├── .all-contributorsrc ├── .bumpversion.cfg ├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .prettierignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── configs.example ├── callbacks │ ├── default.yaml │ ├── none.yaml │ └── wandb.yaml ├── config.yaml ├── datamodule │ ├── configuration │ │ └── example_configuration.yaml │ ├── premade_batches.yaml │ └── streamed_batches.yaml ├── experiment │ ├── baseline.yaml │ ├── conv3d_sat_nwp.yaml │ ├── example_simple.yaml │ └── test.yaml ├── hparams_search │ └── conv3d_optuna.yaml ├── hydra │ └── default.yaml ├── logger │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── baseline.yaml │ ├── multimodal.yaml │ ├── nwp_dwsrf_weighting.yaml │ ├── test.yaml │ └── wind_multimodal.yaml ├── readme.md └── trainer │ ├── all_params.yaml │ └── default.yaml ├── experiments ├── india │ ├── 001_v1 │ │ └── india_pv_wind.md │ ├── 002_wind_meteomatics │ │ └── india_windnet_v2.md │ ├── 003_wind_plevels │ │ ├── MAE.png │ │ ├── MAEvstimesteps.png │ │ ├── p10.png │ │ ├── p50.png │ │ └── plevel.md │ ├── 004_n_training_samples │ │ ├── log-plot.py │ │ ├── mae_samples.png │ │ ├── mae_step.png │ │ └── readme.md │ ├── 005_extra_nwp_variables │ │ ├── mae_steps.png │ │ ├── mae_steps_grouped.png │ │ └── readmd.md │ ├── 006_da_only │ │ ├── bad.png │ │ ├── da_only.md │ │ ├── good.png │ │ └── mae_steps.png │ ├── 007_different_seeds │ │ ├── mae_all_steps.png │ │ ├── mae_steps.png │ │ └── readme.md │ └── 008_coarse4 │ │ ├── mae_step.png │ │ ├── mae_step_smooth.png │ │ └── readme.md ├── mae_analysis.py └── uk │ └── 011 - Extending forecast to 36 hours (updated ECMWF data) │ ├── PVNEt_national_XG_comparison.png │ ├── PVNet_day_ahead.md │ └── PVNets_comparison.png ├── pvnet ├── __init__.py ├── callbacks.py ├── data │ ├── __init__.py │ ├── base_datamodule.py │ ├── site_datamodule.py │ └── uk_regional_datamodule.py ├── load_model.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── baseline │ │ ├── __init__.py │ │ ├── last_value.py │ │ ├── readme.md │ │ └── single_value.py │ ├── ensemble.py │ ├── model_cards │ │ ├── pv_india_model_card_template.md │ │ ├── pv_uk_regional_model_card_template.md │ │ └── wind_india_model_card_template.md │ ├── multimodal │ │ ├── __init__.py │ │ ├── basic_blocks.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── basic_blocks.py │ │ │ ├── encoders2d.py │ │ │ ├── encoders3d.py │ │ │ └── encodersRNN.py │ │ ├── linear_networks │ │ │ ├── __init__.py │ │ │ ├── basic_blocks.py │ │ │ └── networks.py │ │ ├── multimodal.py │ │ ├── readme.md │ │ ├── site_encoders │ │ │ ├── __init__.py │ │ │ ├── basic_blocks.py │ │ │ └── encoders.py │ │ └── unimodal_teacher.py │ └── utils.py ├── optimizers.py ├── training.py └── utils.py ├── pyproject.toml ├── run.py ├── scripts ├── backtest_sites.py ├── backtest_uk_gsp.py ├── checkpoint_to_huggingface.py ├── save_concurrent_samples.py └── save_samples.py └── tests ├── __init__.py ├── conftest.py ├── data └── test_datamodule.py ├── models ├── baseline │ ├── test_last_value.py │ └── test_single_value.py ├── multimodal │ ├── encoders │ │ ├── test_encoders2d.py │ │ ├── test_encoders3d.py │ │ └── test_encodersRNN.py │ ├── linear_networks │ │ └── test_networks.py │ ├── site_encoders │ │ └── test_encoders.py │ ├── test_multimodal.py │ ├── test_save_load_pretrained.py │ └── test_unimodal_teacher.py └── test_ensemble.py ├── test_data └── sample_data │ ├── non_hrv_shell.zarr │ ├── .zattrs │ ├── .zgroup │ ├── .zmetadata │ ├── 
time │ │ ├── 0 │ │ ├── .zarray │ │ └── .zattrs │ ├── variable │ │ ├── 0 │ │ ├── .zarray │ │ └── .zattrs │ ├── x_geostationary │ │ ├── 0 │ │ ├── .zarray │ │ └── .zattrs │ └── y_geostationary │ │ ├── 0 │ │ ├── .zarray │ │ └── .zattrs │ └── nwp_shell.zarr │ ├── .zattrs │ ├── .zgroup │ ├── .zmetadata │ ├── init_time │ ├── 0 │ ├── 1 │ ├── .zarray │ └── .zattrs │ ├── step │ ├── 0 │ ├── .zarray │ └── .zattrs │ ├── variable │ ├── 0 │ ├── .zarray │ └── .zattrs │ ├── x │ ├── 0 │ ├── .zarray │ └── .zattrs │ └── y │ ├── 0 │ ├── .zarray │ └── .zattrs ├── test_end2end.py └── test_utils.py /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "commitType": "docs", 8 | "commitConvention": "angular", 9 | "contributors": [ 10 | { 11 | "login": "felix-e-h-p", 12 | "name": "Felix", 13 | "avatar_url": "https://avatars.githubusercontent.com/u/137530077?v=4", 14 | "profile": "https://github.com/felix-e-h-p", 15 | "contributions": [ 16 | "code" 17 | ] 18 | }, 19 | { 20 | "login": "Sukh-P", 21 | "name": "Sukhil Patel", 22 | "avatar_url": "https://avatars.githubusercontent.com/u/42407101?v=4", 23 | "profile": "https://github.com/Sukh-P", 24 | "contributions": [ 25 | "code" 26 | ] 27 | }, 28 | { 29 | "login": "dfulu", 30 | "name": "James Fulton", 31 | "avatar_url": "https://avatars.githubusercontent.com/u/41546094?v=4", 32 | "profile": "https://github.com/dfulu", 33 | "contributions": [ 34 | "code" 35 | ] 36 | }, 37 | { 38 | "login": "AUdaltsova", 39 | "name": "Alexandra Udaltsova", 40 | "avatar_url": "https://avatars.githubusercontent.com/u/43303448?v=4", 41 | "profile": "https://github.com/AUdaltsova", 42 | "contributions": [ 43 | "code", 44 | "review" 45 | ] 46 | }, 47 | { 48 | "login": "zakwatts", 49 | "name": "Megawattz", 50 | "avatar_url": "https://avatars.githubusercontent.com/u/47150349?v=4", 51 | "profile": "https://github.com/zakwatts", 52 | "contributions": [ 53 | "code" 54 | ] 55 | }, 56 | { 57 | "login": "peterdudfield", 58 | "name": "Peter Dudfield", 59 | "avatar_url": "https://avatars.githubusercontent.com/u/34686298?v=4", 60 | "profile": "https://github.com/peterdudfield", 61 | "contributions": [ 62 | "code" 63 | ] 64 | }, 65 | { 66 | "login": "mahdilamb", 67 | "name": "Mahdi Lamb", 68 | "avatar_url": "https://avatars.githubusercontent.com/u/4696915?v=4", 69 | "profile": "https://github.com/mahdilamb", 70 | "contributions": [ 71 | "infra" 72 | ] 73 | }, 74 | { 75 | "login": "jacobbieker", 76 | "name": "Jacob Prince-Bieker", 77 | "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4", 78 | "profile": "https://www.jacobbieker.com", 79 | "contributions": [ 80 | "code" 81 | ] 82 | }, 83 | { 84 | "login": "codderrrrr", 85 | "name": "codderrrrr", 86 | "avatar_url": "https://avatars.githubusercontent.com/u/149995852?v=4", 87 | "profile": "https://github.com/codderrrrr", 88 | "contributions": [ 89 | "code" 90 | ] 91 | }, 92 | { 93 | "login": "confusedmatrix", 94 | "name": "Chris Briggs", 95 | "avatar_url": "https://avatars.githubusercontent.com/u/617309?v=4", 96 | "profile": "https://chrisxbriggs.com", 97 | "contributions": [ 98 | "code" 99 | ] 100 | }, 101 | { 102 | "login": "tmi", 103 | "name": "tmi", 104 | "avatar_url": "https://avatars.githubusercontent.com/u/147159?v=4", 105 | "profile": "https://github.com/tmi", 106 | "contributions": [ 107 | "code" 108 | ] 109 | }, 110 | { 111 | "login": "carderne", 112 | "name": "Chris Arderne", 113 | "avatar_url": 
"https://avatars.githubusercontent.com/u/19817302?v=4", 114 | "profile": "https://rdrn.me/", 115 | "contributions": [ 116 | "code" 117 | ] 118 | }, 119 | { 120 | "login": "Dakshbir", 121 | "name": "Dakshbir", 122 | "avatar_url": "https://avatars.githubusercontent.com/u/144359831?v=4", 123 | "profile": "https://github.com/Dakshbir", 124 | "contributions": [ 125 | "code" 126 | ] 127 | }, 128 | { 129 | "login": "MAYANK12SHARMA", 130 | "name": "MAYANK SHARMA", 131 | "avatar_url": "https://avatars.githubusercontent.com/u/145884197?v=4", 132 | "profile": "https://github.com/MAYANK12SHARMA", 133 | "contributions": [ 134 | "code" 135 | ] 136 | }, 137 | { 138 | "login": "lambaaryan011", 139 | "name": "aryan lamba ", 140 | "avatar_url": "https://avatars.githubusercontent.com/u/153702847?v=4", 141 | "profile": "https://github.com/lambaaryan011", 142 | "contributions": [ 143 | "code" 144 | ] 145 | }, 146 | { 147 | "login": "michael-gendy", 148 | "name": "michael-gendy", 149 | "avatar_url": "https://avatars.githubusercontent.com/u/64384201?v=4", 150 | "profile": "https://github.com/michael-gendy", 151 | "contributions": [ 152 | "code" 153 | ] 154 | }, 155 | { 156 | "login": "adityasuthar", 157 | "name": "Aditya Suthar", 158 | "avatar_url": "https://avatars.githubusercontent.com/u/95685363?v=4", 159 | "profile": "https://adityasuthar.github.io/", 160 | "contributions": [ 161 | "code" 162 | ] 163 | }, 164 | { 165 | "login": "markus-kreft", 166 | "name": "Markus Kreft", 167 | "avatar_url": "https://avatars.githubusercontent.com/u/129367085?v=4", 168 | "profile": "https://github.com/markus-kreft", 169 | "contributions": [ 170 | "code" 171 | ] 172 | }, 173 | { 174 | "login": "JackKelly", 175 | "name": "Jack Kelly", 176 | "avatar_url": "https://avatars.githubusercontent.com/u/460756?v=4", 177 | "profile": "http://jack-kelly.com", 178 | "contributions": [ 179 | "ideas" 180 | ] 181 | } 182 | ], 183 | "contributorsPerLine": 7, 184 | "skipCi": true, 185 | "repoType": "github", 186 | "repoHost": "https://github.com", 187 | "projectName": "PVNet", 188 | "projectOwner": "openclimatefix" 189 | } 190 | -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | commit = True 3 | tag = True 4 | current_version = 4.1.18 5 | message = Bump version: {current_version} → {new_version} [skip ci] 6 | 7 | [bumpversion:file:pvnet/__init__.py] 8 | search = __version__ = "{current_version}" 9 | replace = __version__ = "{new_version}" 10 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Python Bump Version & release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths-ignore: 8 | - "configs.example/**" # ignores all files in configs.example 9 | - "**/README.md" # ignores all README files 10 | - "experiments/**" # ignores all files in experiments directory 11 | 12 | jobs: 13 | release: 14 | uses: openclimatefix/.github/.github/workflows/python-release.yml@main 15 | secrets: 16 | token: ${{ secrets.PYPI_API_TOKEN }} 17 | PAT_TOKEN: ${{ secrets.PAT_TOKEN }} 18 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Python package tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: 
[opened, reopened] 7 | schedule: 8 | - cron: "0 12 * * 1" 9 | jobs: 10 | call-run-python-tests: 11 | uses: openclimatefix/.github/.github/workflows/python-test.yml@main 12 | with: 13 | # 0 means don't use pytest-xdist 14 | pytest_numcpus: "4" 15 | # pytest-cov looks at this folder 16 | pytest_cov_dir: "pvnet" 17 | # extra things to install 18 | sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin" 19 | # brew_install: "proj geos librttopo" 20 | os_list: '["ubuntu-latest"]' 21 | python-version: "['3.10', '3.11']" 22 | extra_commands: "pip3 install -e '.[all]'" 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | config_tree.txt 3 | configs/ 4 | lightning_logs/ 5 | logs/ 6 | output/ 7 | checkpoints* 8 | csv/ 9 | notebooks/ 10 | *.html 11 | *.csv 12 | latest_logged_train_batch.png 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | .DS_Store 144 | 145 | # vim 146 | *swp 147 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: debug-statements 12 | - id: detect-private-key 13 | 14 | # python code formatting/linting 15 | - repo: https://github.com/astral-sh/ruff-pre-commit 16 | # Ruff version. 17 | rev: "v0.11.0" 18 | hooks: 19 | - id: ruff 20 | args: [--fix] 21 | # yaml formatting 22 | - repo: https://github.com/pre-commit/mirrors-prettier 23 | rev: v3.0.2 24 | hooks: 25 | - id: prettier 26 | types: [yaml] 27 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | configs.example 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Open Climate Fix Ltd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.txt 2 | recursive-include pvnet/models/model_cards *.md 3 | -------------------------------------------------------------------------------- /configs.example/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: pvnet.callbacks.MainEarlyStopping 3 | # name of the logged metric which determines when model is improving 4 | monitor: "${resolve_monitor_loss:${model.output_quantiles}}" 5 | mode: "min" # can be "max" or "min" 6 | patience: 10 # how many epochs (or val check periods) of not improving until training stops 7 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 8 | 9 | learning_rate_monitor: 10 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 11 | logging_interval: "epoch" 12 | 13 | model_summary: 14 | _target_: lightning.pytorch.callbacks.ModelSummary 15 | max_depth: 3 16 | 17 | model_checkpoint: 18 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 19 | # name of the logged metric which determines when model is improving 20 | monitor: "${resolve_monitor_loss:${model.output_quantiles}}" 21 | mode: "min" # can be "max" or "min" 22 | save_top_k: 1 # save k best models (determined by above metric) 23 | save_last: True # additionaly always save model from last epoch 24 | every_n_epochs: 1 25 | verbose: False 26 | filename: "epoch={epoch}-step={step}" 27 | # The path to where the model checkpoints will be stored 28 | dirpath: "PLACEHOLDER/${model_name}" #${..model_name} 29 | auto_insert_metric_name: False 30 | save_on_train_epoch_end: False 31 | -------------------------------------------------------------------------------- /configs.example/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/configs.example/callbacks/none.yaml -------------------------------------------------------------------------------- /configs.example/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | upload_code_as_artifact: 10 | _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 11 | code_dir: ${work_dir}/src 12 | 13 | upload_ckpts_as_artifact: 14 | _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 15 | ckpt_dir: "checkpoints/" 16 | upload_best_only: True 17 | 18 | log_f1_precision_recall_heatmap: 19 | _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 20 | 21 | log_confusion_matrix: 22 | _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 23 | 24 | log_image_predictions: 25 | _target_: src.callbacks.wandb_callbacks.LogImagePredictions 26 | num_samples: 8 27 | -------------------------------------------------------------------------------- /configs.example/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default.yaml 7 | - model: multimodal.yaml 8 | - datamodule: premade_samples.yaml 9 | - callbacks: default.yaml # set this to null if 
you don't want to use callbacks 10 | - logger: wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) 11 | - experiment: null 12 | - hparams_search: null 13 | - hydra: default.yaml 14 | 15 | renewable: "pv_uk" 16 | 17 | # enable color logging 18 | # - override hydra/hydra_logging: colorlog 19 | # - override hydra/job_logging: colorlog 20 | 21 | # path to original working directory 22 | # hydra hijacks working directory by changing it to the current log directory, 23 | # so it's useful to have this path as a special variable 24 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 25 | work_dir: ${hydra:runtime.cwd} 26 | 27 | model_name: "default" 28 | 29 | # use `python run.py debug=true` for easy debugging! 30 | # this will run 1 train, val and test loop with only 1 batch 31 | # equivalent to running `python run.py trainer.fast_dev_run=true` 32 | # (this is placed here just for easier access from command line) 33 | debug: False 34 | 35 | # pretty print config at the start of the run using Rich library 36 | print_config: True 37 | 38 | # disable python warnings if they annoy you 39 | ignore_warnings: True 40 | 41 | # check performance on test set, using the best model achieved during training 42 | # lightning chooses best model based on metric specified in checkpoint callback 43 | test_after_training: False 44 | 45 | seed: 2727831 46 | -------------------------------------------------------------------------------- /configs.example/datamodule/configuration/example_configuration.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | description: Example config for producing PVNet samples 3 | name: example_config 4 | 5 | input_data: 6 | 7 | # Either use Site OR GSP configuration 8 | site: 9 | # Path to Site data in NetCDF format 10 | file_path: PLACEHOLDER.nc 11 | # Path to metadata in CSV format 12 | metadata_file_path: PLACEHOLDER.csv 13 | time_resolution_minutes: 15 14 | interval_start_minutes: -60 15 | # Specified for intraday currently 16 | interval_end_minutes: 480 17 | dropout_timedeltas_minutes: [] 18 | dropout_fraction: 0 # Fraction of samples with dropout 19 | 20 | gsp: 21 | # Path to GSP data in zarr format 22 | # e.g. gs://solar-pv-nowcasting-data/PV/GSP/v7/pv_gsp.zarr 23 | zarr_path: PLACEHOLDER.zarr 24 | interval_start_minutes: -60 25 | # Specified for intraday currently 26 | interval_end_minutes: 480 27 | time_resolution_minutes: 30 28 | # Random value from the list below will be chosen as the delay when dropout is used 29 | # If set to null no dropout is applied. Only values before t0 are dropped out for GSP. 30 | # Values after t0 are assumed as targets and cannot be dropped. 31 | dropout_timedeltas_minutes: [] 32 | dropout_fraction: 0 # Fraction of samples with dropout 33 | 34 | nwp: 35 | 36 | ecmwf: 37 | provider: ecmwf 38 | # Path to ECMWF NWP data in zarr format 39 | # n.b. It is not necessary to use multiple or any NWP data. 
These entries can be removed 40 | zarr_path: PLACEHOLDER.zarr 41 | interval_start_minutes: -60 42 | # Specified for intraday currently 43 | interval_end_minutes: 480 44 | time_resolution_minutes: 60 45 | channels: 46 | - t2m # 2-metre temperature 47 | - dswrf # downwards short-wave radiation flux 48 | - dlwrf # downwards long-wave radiation flux 49 | - hcc # high cloud cover 50 | - mcc # medium cloud cover 51 | - lcc # low cloud cover 52 | - tcc # total cloud cover 53 | - sde # snow depth water equivalent 54 | - sr # direct solar radiation 55 | - duvrs # downwards UV radiation at surface 56 | - prate # precipitation rate 57 | - u10 # 10-metre U component of wind speed 58 | - u100 # 100-metre U component of wind speed 59 | - u200 # 200-metre U component of wind speed 60 | - v10 # 10-metre V component of wind speed 61 | - v100 # 100-metre V component of wind speed 62 | - v200 # 200-metre V component of wind speed 63 | # The following channels are accumulated and need to be diffed 64 | accum_channels: 65 | - dswrf # downwards short-wave radiation flux 66 | - dlwrf # downwards long-wave radiation flux 67 | - sr # direct solar radiation 68 | - duvrs # downwards UV radiation at surface 69 | image_size_pixels_height: 24 70 | image_size_pixels_width: 24 71 | dropout_timedeltas_minutes: [-360] 72 | dropout_fraction: 1.0 # Fraction of samples with dropout 73 | max_staleness_minutes: null 74 | normalisation_constants: 75 | t2m: 76 | mean: 283.48333740234375 77 | std: 3.692270040512085 78 | dswrf: 79 | mean: 11458988.0 80 | std: 13025427.0 81 | dlwrf: 82 | mean: 27187026.0 83 | std: 15855867.0 84 | hcc: 85 | mean: 0.3961029052734375 86 | std: 0.42244860529899597 87 | mcc: 88 | mean: 0.3288780450820923 89 | std: 0.38039860129356384 90 | lcc: 91 | mean: 0.44901806116104126 92 | std: 0.3791404366493225 93 | tcc: 94 | mean: 0.7049227356910706 95 | std: 0.37487083673477173 96 | sde: 97 | mean: 8.107526082312688e-05 98 | std: 0.000913831521756947 # Mapped from "sd" in the Python file 99 | sr: 100 | mean: 12905302.0 101 | std: 16294988.0 102 | duvrs: 103 | mean: 1305651.25 104 | std: 1445635.25 105 | prate: 106 | mean: 3.108070450252853e-05 107 | std: 9.81039775069803e-05 108 | u10: 109 | mean: 1.7677178382873535 110 | std: 5.531515598297119 111 | u100: 112 | mean: 2.393547296524048 113 | std: 7.2320556640625 114 | u200: 115 | mean: 2.7963004112243652 116 | std: 8.049470901489258 117 | v10: 118 | mean: 0.985887885093689 119 | std: 5.411230564117432 120 | v100: 121 | mean: 1.4244288206100464 122 | std: 6.944501876831055 123 | v200: 124 | mean: 1.6010299921035767 125 | std: 7.561611652374268 126 | # Added diff_ keys for the channels under accum_channels: 127 | diff_dlwrf: 128 | mean: 1136464.0 129 | std: 131942.03125 130 | diff_dswrf: 131 | mean: 420584.6875 132 | std: 715366.3125 133 | diff_duvrs: 134 | mean: 48265.4765625 135 | std: 81605.25 136 | diff_sr: 137 | mean: 469169.5 138 | std: 818950.6875 139 | 140 | ukv: 141 | provider: ukv 142 | # Path to UKV NWP data in zarr format 143 | # e.g. gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_intermediate_version_7.zarr 144 | # n.b. It is not necessary to use multiple or any NWP data. 
These entries can be removed 145 | zarr_path: PLACEHOLDER.zarr 146 | interval_start_minutes: -60 147 | # Specified for intraday currently 148 | interval_end_minutes: 480 149 | time_resolution_minutes: 60 150 | channels: 151 | - t # 2-metre temperature 152 | - dswrf # downwards short-wave radiation flux 153 | - dlwrf # downwards long-wave radiation flux 154 | - hcc # high cloud cover 155 | - mcc # medium cloud cover 156 | - lcc # low cloud cover 157 | - sde # snow depth water equivalent 158 | - r # relative humidty 159 | - vis # visibility 160 | - si10 # 10-metre wind speed 161 | - wdir10 # 10-metre wind direction 162 | - prate # precipitation rate 163 | # These variables exist in CEDA training data but not in the live MetOffice live service 164 | - hcct # height of convective cloud top, meters above surface. NaN if no clouds 165 | - cdcb # height of lowest cloud base > 3 oktas 166 | - dpt # dew point temperature 167 | - prmsl # mean sea level pressure 168 | - h # geometrical? (maybe geopotential?) height 169 | image_size_pixels_height: 24 170 | image_size_pixels_width: 24 171 | dropout_timedeltas_minutes: [-360] 172 | dropout_fraction: 1.0 # Fraction of samples with dropout 173 | max_staleness_minutes: null 174 | normalisation_constants: 175 | t: 176 | mean: 283.64913206 177 | std: 4.38818501 178 | dswrf: 179 | mean: 111.28265039 180 | std: 190.47216887 181 | dlwrf: 182 | mean: 325.03130139 183 | std: 39.45988077 184 | hcc: 185 | mean: 29.11949682 186 | std: 38.07184418 187 | mcc: 188 | mean: 40.88984494 189 | std: 41.91144559 190 | lcc: 191 | mean: 50.08362643 192 | std: 39.33210726 193 | sde: 194 | mean: 0.00289545 195 | std: 0.1029753 196 | r: 197 | mean: 81.79229501 198 | std: 11.45012499 199 | vis: 200 | mean: 32262.03285118 201 | std: 21578.97975625 202 | si10: 203 | mean: 6.88348448 204 | std: 3.94718813 205 | wdir10: 206 | mean: 199.41891636 207 | std: 94.08407495 208 | prate: 209 | mean: 3.45793433e-05 210 | std: 0.00021497 211 | hcct: 212 | mean: -18345.97478167 213 | std: 18382.63958991 214 | cdcb: 215 | mean: 1412.26599062 216 | std: 2126.99350113 217 | dpt: 218 | mean: 280.54379901 219 | std: 4.57250482 220 | prmsl: 221 | mean: 101321.61574029 222 | std: 1252.71790539 223 | h: 224 | mean: 2096.51991356 225 | std: 1075.77812282 226 | 227 | satellite: 228 | # Path to Satellite data (non-HRV) in zarr format 229 | # e.g. 
gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_nonhrv.zarr 230 | zarr_path: PLACEHOLDER.zarr 231 | interval_start_minutes: -30 232 | interval_end_minutes: 0 233 | time_resolution_minutes: 5 234 | channels: 235 | - IR_016 # Surface, cloud phase 236 | - IR_039 # Surface, clouds, wind fields 237 | - IR_087 # Surface, clouds, atmospheric instability 238 | - IR_097 # Ozone 239 | - IR_108 # Surface, clouds, wind fields, atmospheric instability 240 | - IR_120 # Surface, clouds, atmospheric instability 241 | - IR_134 # Cirrus cloud height, atmospheric instability 242 | - VIS006 # Surface, clouds, wind fields 243 | - VIS008 # Surface, clouds, wind fields 244 | - WV_062 # Water vapor, high level clouds, upper air analysis 245 | - WV_073 # Water vapor, atmospheric instability, upper-level dynamics 246 | image_size_pixels_height: 24 247 | image_size_pixels_width: 24 248 | dropout_timedeltas_minutes: [] 249 | dropout_fraction: 0 # Fraction of samples with dropout 250 | normalisation_constants: 251 | IR_016: 252 | mean: 0.17594202 253 | std: 0.21462157 254 | IR_039: 255 | mean: 0.86167645 256 | std: 0.04618041 257 | IR_087: 258 | mean: 0.7719318 259 | std: 0.06687243 260 | IR_097: 261 | mean: 0.8014212 262 | std: 0.0468558 263 | IR_108: 264 | mean: 0.71254843 265 | std: 0.17482725 266 | IR_120: 267 | mean: 0.89058584 268 | std: 0.06115861 269 | IR_134: 270 | mean: 0.944365 271 | std: 0.04492306 272 | VIS006: 273 | mean: 0.09633306 274 | std: 0.12184761 275 | VIS008: 276 | mean: 0.11426069 277 | std: 0.13090034 278 | WV_062: 279 | mean: 0.7359355 280 | std: 0.16111417 281 | WV_073: 282 | mean: 0.62479186 283 | std: 0.12924142 284 | 285 | solar_position: 286 | interval_start_minutes: -60 287 | interval_end_minutes: 480 288 | time_resolution_minutes: 30 289 | -------------------------------------------------------------------------------- /configs.example/datamodule/premade_batches.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.data.DataModule 2 | configuration: null 3 | 4 | # The sample_dir is the location batches were saved to using the save_batches.py script 5 | # The sample_dir should contain train and val subdirectories with batches 6 | 7 | sample_dir: "PLACEHOLDER" 8 | num_workers: 10 9 | prefetch_factor: 2 10 | batch_size: 8 11 | -------------------------------------------------------------------------------- /configs.example/datamodule/streamed_batches.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.data.DataModule 2 | # Path to the data configuration yaml file. 
You can find examples in the configuration subdirectory 3 | # in configs.example/datamodule/configuration 4 | # Use the full local path such as: /FULL/PATH/PVNet/configs/datamodule/configuration/gcp_configuration.yaml" 5 | 6 | configuration: "PLACEHOLDER.yaml" 7 | num_workers: 20 8 | prefetch_factor: 2 9 | batch_size: 8 10 | 11 | sample_output_dir: "PLACEHOLDER" 12 | num_train_samples: 2 13 | num_val_samples: 1 14 | 15 | train_period: 16 | - null 17 | - "2022-05-07" 18 | val_period: 19 | - "2022-05-08" 20 | - "2023-05-08" 21 | -------------------------------------------------------------------------------- /configs.example/experiment/baseline.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: baseline.yaml 9 | - override /datamodule: premade_samples.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: neptune.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: 518 17 | validate_only: "1" # by putting this key in the config file, the model does not get trained. 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 1 22 | -------------------------------------------------------------------------------- /configs.example/experiment/conv3d_sat_nwp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: conv3d_sat_nwp.yaml 9 | - override /datamodule: premade_samples.yaml 10 | - override /callbacks: default.yaml 11 | # - override /logger: neptune.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: 518 17 | 18 | trainer: 19 | min_epochs: 1 20 | max_epochs: 10 21 | 22 | model: 23 | conv3d_channels: 32 24 | -------------------------------------------------------------------------------- /configs.example/experiment/example_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: conv3d_sat_nwp.yaml 9 | - override /datamodule: premade_samples.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: tensorboard.yaml 12 | - override /hparams_search: null 13 | - override /hydra: default.yaml 14 | 15 | # all parameters below will be merged with parameters from default configurations set above 16 | # this allows you to overwrite only specified parameters 17 | 18 | seed: 518 19 | 20 | trainer: 21 | min_epochs: 1 22 | max_epochs: 2 23 | 24 | datamodule: 25 | batch_size: 16 26 | 27 | validate_only: "1" # by putting this key in the config file, the model does not get trained. 
28 | -------------------------------------------------------------------------------- /configs.example/experiment/test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=test.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: test.yaml 9 | - override /datamodule: premade_samples.yaml 10 | - override /callbacks: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | seed: 518 16 | 17 | trainer: 18 | min_epochs: 0 19 | max_epochs: 2 20 | reload_dataloaders_every_n_epochs: 0 21 | limit_train_batches: 2000 22 | limit_val_batches: 100 23 | limit_test_batches: 100 24 | val_check_interval: 100 25 | num_sanity_val_steps: 8 26 | accumulate_grad_batches: 4 27 | #fast_dev_run: 3 28 | 29 | datamodule: 30 | num_workers: 10 31 | prefetch_factor: 2 32 | batch_size: 8 33 | #validate_only: '1' # by putting this key in the config file, the model does not get trained. 34 | -------------------------------------------------------------------------------- /configs.example/hparams_search/conv3d_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python run.py -m hparams_search=conv3d_optuna experiment=conv3d_sat_nwp 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | optimized_metric: "MSE/Validation_epoch" 11 | 12 | hydra: 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | sweeper: 17 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 18 | storage: null 19 | study_name: null 20 | n_jobs: 1 21 | 22 | # 'minimize' or 'maximize' the objective 23 | direction: minimize 24 | 25 | # number of experiments that will be executed 26 | n_trials: 20 27 | 28 | # choose Optuna hyperparameter sampler 29 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 30 | sampler: 31 | _target_: optuna.samplers.TPESampler 32 | seed: 12345 33 | consider_prior: true 34 | prior_weight: 1.0 35 | consider_magic_clip: true 36 | consider_endpoints: false 37 | n_startup_trials: 10 38 | n_ei_candidates: 24 39 | multivariate: false 40 | warn_independent_sampling: true 41 | 42 | # define range of hyperparameters 43 | search_space: 44 | model.include_pv_yield_history: 45 | type: categorical 46 | choices: [true, false] 47 | model.include_future_satellite: 48 | type: categorical 49 | choices: [true, false] 50 | -------------------------------------------------------------------------------- /configs.example/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | # Local log directory for hydra 4 | dir: PLACEHOLDER/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | sweep: 6 | # Local log directory for hydra 7 | dir: PLACEHOLDER/multiruns/${now:%Y-%m-%d_%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | 10 | # you can set here environment variables that are universal for all users 11 | # for system specific variables (like 
data paths) it's better to use .env file! 12 | job: 13 | env_set: 14 | EXAMPLE_VAR: "example_value" 15 | -------------------------------------------------------------------------------- /configs.example/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | # local path to log training process 6 | save_dir: "PLACEHOLDER" 7 | name: "csv/" 8 | version: null 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs.example/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | - csv.yaml 5 | # - neptune.yaml 6 | # - tensorboard.yaml 7 | - wandb.yaml 8 | -------------------------------------------------------------------------------- /configs.example/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | # Neptune project placeholder 7 | project: PLACEHOLDER 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /configs.example/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | # Path to use for tensorboard logs 6 | save_dir: "PLACEHOLDER" 7 | name: "default" 8 | version: "${model_name}" 9 | log_graph: False 10 | default_hp_metric: False 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /configs.example/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # wandb project to log to 6 | project: "PLACEHOLDER" 7 | name: "${model_name}" 8 | # location to store the wandb local logs 9 | save_dir: "PLACEHOLDER" 10 | offline: False # set True to store all logs only locally 11 | id: null # pass correct id to resume experiment! 
12 | # entity: "" # set to name of your wandb team or just remove it 13 | log_model: False 14 | prefix: "" 15 | job_type: "train" 16 | group: "" 17 | tags: [] 18 | -------------------------------------------------------------------------------- /configs.example/model/baseline.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.models.baseline.last_value.Model 2 | 3 | forecast_minutes: 120 4 | history_minutes: 30 5 | -------------------------------------------------------------------------------- /configs.example/model/multimodal.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.models.multimodal.multimodal.Model 2 | 3 | output_quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] 4 | 5 | #-------------------------------------------- 6 | # NWP encoder 7 | #-------------------------------------------- 8 | 9 | nwp_encoders_dict: 10 | ukv: 11 | _target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet 12 | _partial_: True 13 | in_channels: 2 14 | out_features: 256 15 | number_of_conv3d_layers: 6 16 | conv3d_channels: 32 17 | image_size_pixels: 24 18 | ecmwf: 19 | _target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet 20 | _partial_: True 21 | in_channels: 12 22 | out_features: 256 23 | number_of_conv3d_layers: 4 24 | conv3d_channels: 32 25 | image_size_pixels: 12 26 | 27 | #-------------------------------------------- 28 | # Sat encoder settings 29 | #-------------------------------------------- 30 | 31 | sat_encoder: 32 | _target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet 33 | _partial_: True 34 | in_channels: 11 35 | out_features: 256 36 | number_of_conv3d_layers: 6 37 | conv3d_channels: 32 38 | image_size_pixels: 24 39 | 40 | add_image_embedding_channel: False 41 | 42 | #-------------------------------------------- 43 | # PV encoder settings 44 | #-------------------------------------------- 45 | 46 | pv_encoder: 47 | _target_: pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork 48 | _partial_: True 49 | num_sites: 349 50 | out_features: 40 51 | num_heads: 4 52 | kdim: 40 53 | id_embed_dim: 20 54 | 55 | #-------------------------------------------- 56 | # Tabular network settings 57 | #-------------------------------------------- 58 | 59 | output_network: 60 | _target_: pvnet.models.multimodal.linear_networks.networks.ResFCNet2 61 | _partial_: True 62 | fc_hidden_features: 128 63 | n_res_blocks: 6 64 | res_block_layers: 2 65 | dropout_frac: 0.0 66 | 67 | embedding_dim: 16 68 | include_sun: True 69 | include_gsp_yield_history: False 70 | include_site_yield_history: False 71 | 72 | # The mapping between the location IDs and their embedding indices 73 | location_id_mapping: 74 | 1: 1 75 | 5: 2 76 | 110: 3 77 | # ... 
78 | 79 | #-------------------------------------------- 80 | # Times 81 | #-------------------------------------------- 82 | 83 | # Foreast and time settings 84 | forecast_minutes: 480 85 | history_minutes: 120 86 | 87 | min_sat_delay_minutes: 60 88 | 89 | # These must also be set even if identical to forecast_minutes and history_minutes 90 | sat_history_minutes: 90 91 | pv_history_minutes: 180 92 | 93 | # These must be set for each NWP encoder 94 | nwp_history_minutes: 95 | ukv: 120 96 | ecmwf: 120 97 | nwp_forecast_minutes: 98 | ukv: 480 99 | ecmwf: 480 100 | # Optional; defaults to 60, so must be set for data with different time resolution 101 | nwp_interval_minutes: 102 | ukv: 60 103 | ecmwf: 60 104 | 105 | # ---------------------------------------------- 106 | # Optimizer 107 | # ---------------------------------------------- 108 | optimizer: 109 | _target_: pvnet.optimizers.EmbAdamWReduceLROnPlateau 110 | lr: 0.0001 111 | weight_decay: 0.01 112 | amsgrad: True 113 | patience: 5 114 | factor: 0.1 115 | threshold: 0.002 116 | -------------------------------------------------------------------------------- /configs.example/model/nwp_dwsrf_weighting.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.models.multimodal.nwp_weighting.Model 2 | 3 | #-------------------------------------------- 4 | # Network settings 5 | #-------------------------------------------- 6 | 7 | # Foreast and time settings 8 | forecast_minutes: 480 9 | history_minutes: 120 10 | 11 | nwp_history_minutes: 120 12 | nwp_forecast_minutes: 480 13 | 14 | nwp_image_size_pixels: 24 15 | dwsrf_channel: 1 16 | 17 | # ---------------------------------------------- 18 | 19 | optimizer: 20 | _target_: pvnet.optimizers.AdamW 21 | lr: 0.0005 22 | -------------------------------------------------------------------------------- /configs.example/model/test.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.models.baseline.single_value.Model 2 | 3 | history_minutes: 120 4 | forecast_minutes: 360 5 | -------------------------------------------------------------------------------- /configs.example/model/wind_multimodal.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.models.multimodal.multimodal.Model 2 | 3 | output_quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] 4 | 5 | #-------------------------------------------- 6 | # NWP encoder 7 | #-------------------------------------------- 8 | nwp_encoders_dict: 9 | ecmwf: 10 | _target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet 11 | _partial_: True 12 | in_channels: 14 13 | out_features: 256 14 | number_of_conv3d_layers: 6 15 | conv3d_channels: 32 16 | image_size_pixels: 16 17 | 18 | #-------------------------------------------- 19 | # Sensor encoder settings 20 | #-------------------------------------------- 21 | 22 | wind_encoder: 23 | _target_: pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork 24 | _partial_: True 25 | num_sites: 19 26 | out_features: 40 27 | num_heads: 4 28 | kdim: 40 29 | id_embed_dim: 20 30 | 31 | #-------------------------------------------- 32 | # Tabular network settings 33 | #-------------------------------------------- 34 | 35 | output_network: 36 | _target_: pvnet.models.multimodal.linear_networks.networks.ResFCNet2 37 | _partial_: True 38 | fc_hidden_features: 128 39 | n_res_blocks: 6 40 | res_block_layers: 2 41 | dropout_frac: 0.0 42 | 43 | embedding_dim: 
16 44 | include_sun: False 45 | include_gsp_yield_history: False 46 | 47 | # The mapping between the location IDs and their embedding indices 48 | location_id_mapping: 49 | 1: 1 50 | 5: 2 51 | 110: 3 52 | # ... 53 | 54 | #-------------------------------------------- 55 | # Times 56 | #-------------------------------------------- 57 | 58 | # Foreast and time settings 59 | forecast_minutes: 480 60 | history_minutes: 120 61 | 62 | min_sat_delay_minutes: 60 63 | 64 | # --- set to null if same as history_minutes --- 65 | sat_history_minutes: 90 66 | nwp_history_minutes: 60 67 | nwp_forecast_minutes: 2880 68 | pv_history_minutes: 180 69 | pv_interval_minutes: 15 70 | sat_interval_minutes: 15 71 | 72 | target_key: "sensor" 73 | # ---------------------------------------------- 74 | # Optimizer 75 | # ---------------------------------------------- 76 | optimizer: 77 | _target_: pvnet.optimizers.EmbAdamWReduceLROnPlateau 78 | lr: 0.0001 79 | weight_decay: 0.01 80 | amsgrad: True 81 | patience: 5 82 | factor: 0.1 83 | threshold: 0.002 84 | -------------------------------------------------------------------------------- /configs.example/readme.md: -------------------------------------------------------------------------------- 1 | This directory contains example configuration files for the PVNet project. Many paths will need to unique to each user. You can find these paths by searching for PLACEHOLDER within these logs. Not all of 2 | the values with a placeholder need to be set. For example in the logger subdirectory there are many different loggers with PLACEHOLDERS. If only one logger is used, then only that placeholder needs to be set. 3 | 4 | run experiments by: 5 | `python run.py experiment=example_simple ` 6 | -------------------------------------------------------------------------------- /configs.example/trainer/all_params.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_save_path: null 33 | num_sanity_val_steps: 2 34 | truncated_bptt_steps: null 35 | resume_from_checkpoint: null 36 | profiler: null 37 | benchmark: False 38 | deterministic: False 39 | reload_dataloaders_every_epoch: False 40 | auto_lr_find: False 41 | replace_sampler_ddp: True 42 | terminate_on_nan: False 43 | auto_scale_batch_size: False 44 | prepare_data_per_node: True 45 | plugins: null 46 | amp_backend: "native" 47 | amp_level: "O2" 48 | move_metrics_to_cpu: False 49 | -------------------------------------------------------------------------------- /configs.example/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.trainer.Trainer 2 | 3 | # set `1` to train on GPU, `0` to train on CPU only 4 | 
accelerator: auto 5 | devices: auto 6 | 7 | min_epochs: null 8 | max_epochs: null 9 | reload_dataloaders_every_n_epochs: 0 10 | num_sanity_val_steps: 8 11 | fast_dev_run: false 12 | 13 | accumulate_grad_batches: 4 14 | log_every_n_steps: 50 15 | -------------------------------------------------------------------------------- /experiments/india/001_v1/india_pv_wind.md: -------------------------------------------------------------------------------- 1 | # PVNet for Wind and PV Sites in India 2 | 3 | ## PVNet for sites 4 | 5 | ### Data 6 | 7 | We use PV generation data for India from April 2019-Nov 2022 for training 8 | and Dec 2022- Nov 2023 for validation. This is only with ECMWF data, and PV generation history. 9 | 10 | The forecast is every 15 minutes for 48 hours for PV generation. 11 | 12 | The input NWP data is hourly, and 32x32 pixels (corresponding to around 320kmx320km) around a central 13 | point in NW-India. 14 | 15 | [WandB Link](https://wandb.ai/openclimatefix/pvnet_india2.1/runs/o4xpvzrc) 16 | 17 | ### Results 18 | 19 | Overall MAE is 4.9% on the validation set, and forecasts look overall good. 20 | 21 | ![batch_idx_1_all_892_2ca7e12db5de2cf2e244](https://github.com/openclimatefix/PVNet/assets/7170359/07e8199a-11b5-4400-9897-37b7738a4f39) 22 | 23 | ![W B Chart 05_02_2024, 10_07_12_pvnet](https://github.com/openclimatefix/PVNet/assets/7170359/abaefdc1-dedd-4a12-8a26-afaf36d7786b) 24 | 25 | ## WindNet 26 | 27 | 28 | ### April-29-2024 WindNet v1 Production Model 29 | 30 | [WandB Link](https://wandb.ai/openclimatefix/india/runs/5llq8iw6) 31 | 32 | Improvements: Larger input size (64x64), 7 hour delay for ECMWF NWP inputs, to match productions. 33 | New, much more efficient encoder for NWP, allowing for more filters and layers, with less parameters. 34 | The 64x64 input size corresponds to 6.4 degrees x 6.4 degrees, which is around 700km x 700km. This allows for the 35 | model to see the wind over the wind generation sites, which seems to be the biggest reason for the improvement in the model. 36 | 37 | 38 | 39 | MAE is 7.6% with real improvements on the production side of things. 40 | 41 | 42 | There were other experiments with slightly different numbers of filters, model parameters and the like, but generally no 43 | improvements were seen. 44 | 45 | 46 | ## WindNet v1 Results 47 | 48 | ### Data 49 | 50 | We use Wind generation data for India from April 2019-Nov 2022 for training 51 | and Dec 2022- Nov 2023 for validation. This is only with ECMWF data, and Wind generation history. 52 | 53 | The forecast is every 15 minutes for 48 hours for Wind generation. 54 | 55 | The input NWP data is hourly, and 32x32 pixels (corresponding to around 320kmx320km) around a central 56 | point in NW-India. Note: The majority of the wind generation is likely not covered in the 320kmx320km area. 57 | 58 | 59 | [WandB Link](https://wandb.ai/openclimatefix/pvnet_india2.1/runs/otdx7axx) 60 | 61 | ### Results 62 | 63 | ![W B Chart 05_02_2024, 10_05_19](https://github.com/openclimatefix/PVNet/assets/7170359/6a8cd9c5-bdfe-41ab-996d-37fd1be2a07c) 64 | 65 | ![W B Chart 05_02_2024, 10_06_51_windnet](https://github.com/openclimatefix/PVNet/assets/7170359/77554ef0-4411-4432-af95-8530aef4a701) 66 | 67 | ![batch_idx_1_all_1730_379a9f881a7f01153f98](https://github.com/openclimatefix/PVNet/assets/7170359/243d9f3e-4cb9-405e-80c5-40c6c218c17f) 68 | 69 | MAE is around 10% overall, although it doesn't seem to do very well on the ramps up and down. 
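As a quick sanity check on the numbers quoted above, here is a small, purely illustrative arithmetic sketch (not PVNet code; the ~10 km/pixel figure is only inferred from "32x32 pixels corresponding to around 320kmx320km"):

```python
# Illustrative arithmetic for the data description above (not part of the PVNet codebase).
forecast_hours = 48
target_resolution_minutes = 15   # generation is forecast every 15 minutes
nwp_resolution_minutes = 60      # ECMWF inputs are hourly
nwp_pixels = 32                  # 32x32 pixel crop around a central point

n_target_steps = forecast_hours * 60 // target_resolution_minutes  # 192 forecast steps
n_nwp_steps = forecast_hours * 60 // nwp_resolution_minutes        # 48 hourly NWP frames
km_per_pixel = 320 / nwp_pixels                                    # ~10 km per pixel (inferred)

print(n_target_steps, n_nwp_steps, km_per_pixel)  # 192 48 10.0
```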
70 | -------------------------------------------------------------------------------- /experiments/india/002_wind_meteomatics/india_windnet_v2.md: -------------------------------------------------------------------------------- 1 | ### WindNet v2 Meteomatics + ECMWF Model 2 | 3 | [WandB Link](https://wandb.ai/openclimatefix/india/runs/v3mja33d) 4 | 5 | This newest experiment uses Meteomatics data in addition to ECMWF data. The Meteomatics data is at specific locations corresponding 6 | to the generation sites we know about. It is smartly downscaled ECMWF data, down to 15 minutes and at a few height levels we are 7 | interested in, primarily 10m, 100m, and 200m. The Meteomatics data is a semi-reanalysis, with each block of 6 hours coming from one forecast run. 8 | For example, in one day, hours 00-06 are from the 00 forecast run, and hours 06-12 are from the 06 forecast run. This is important to note 9 | because it is not a true reanalysis, yet it also can't exactly match the live data, as any forecast steps beyond 6 hours are thrown away. 10 | This does mean that these results should be taken as a best-case (or better-than-best-case) scenario, as every 6 hours, observations from the future 11 | are incorporated into the Meteomatics input data from the next NWP model run. 12 | 13 | For the purposes of WindNet, Meteomatics data is treated as sensor data that extends into the future. 14 | The model encodes the sensor information the same way as the historical PV, wind, and GSP generation, using 15 | a simple, single attention head to encode the information. This is then concatenated with the rest of the data, as in 16 | previous experiments. 17 | 18 | This model also has an even larger ECMWF input size of 81x81 pixels, corresponding to around 810km x 810km. 19 | ![Screenshot_20240430_082855](https://github.com/openclimatefix/PVNet/assets/7170359/6981a088-8664-474b-bfea-c94c777fc119) 20 | 21 | MAE is 7.0% on the validation set, showing a slight improvement over the previous model.
22 | 23 | Comparison with the production model: 24 | 25 | | Timestep | Prod MAE % | No Meteomatics MAE % | Meteomatics MAE % | 26 | | --- | --- | --- | --- | 27 | | 0-0 minutes | 7.586 | 5.920 | 2.475 | 28 | | 15-15 minutes | 8.021 | 5.809 | 2.968 | 29 | | 30-45 minutes | 7.233 | 5.742 | 3.472 | 30 | | 45-60 minutes | 7.187 | 5.698 | 3.804 | 31 | | 60-120 minutes | 7.231 | 5.816 | 4.650 | 32 | | 120-240 minutes | 7.287 | 6.080 | 6.028 | 33 | | 240-360 minutes | 7.319 | 6.375 | 6.738 | 34 | | 360-480 minutes | 7.285 | 6.638 | 6.964 | 35 | | 480-720 minutes | 7.143 | 6.747 | 6.906 | 36 | | 720-1440 minutes | 7.380 | 7.207 | 6.962 | 37 | | 1440-2880 minutes | 7.904 | 7.507 | 7.507 | 38 | 39 | ![mae_per_timestep](https://github.com/openclimatefix/PVNet/assets/7170359/e3c942e8-65c6-4b95-8c51-f25d43e7a082) 40 | 41 | 42 | 43 | 44 | Example plot 45 | 46 | ![Screenshot_20240430_082937](https://github.com/openclimatefix/PVNet/assets/7170359/88db342e-bf82-414e-8255-5ad4af659fb8) 47 | -------------------------------------------------------------------------------- /experiments/india/003_wind_plevels/MAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/003_wind_plevels/MAE.png -------------------------------------------------------------------------------- /experiments/india/003_wind_plevels/MAEvstimesteps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/003_wind_plevels/MAEvstimesteps.png -------------------------------------------------------------------------------- /experiments/india/003_wind_plevels/p10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/003_wind_plevels/p10.png -------------------------------------------------------------------------------- /experiments/india/003_wind_plevels/p50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/003_wind_plevels/p50.png -------------------------------------------------------------------------------- /experiments/india/003_wind_plevels/plevel.md: -------------------------------------------------------------------------------- 1 | # Running WindNet for RUVNL for different Plevels 2 | 3 | https://wandb.ai/openclimatefix/india/runs/5llq8iw6 is the current production model. 4 | This has 7 plevels and a small patch size. 5 | 6 | ## Experiments 7 | 8 | 1. Only used plevel 50 (orange) 9 | https://wandb.ai/openclimatefix/india/runs/ziudzweq/ 10 | 11 | 2. Use plevels of [2, 10, 25, 50, 75, 90, 98] (green). This is what is already used. 12 | https://wandb.ai/openclimatefix/india/runs/xdlew7ib 13 | 14 | 3. Use plevels of [1, 2, 10, 20, 25, 30, 40, 50, 60, 70, 15 | 75, 80, 90, 98, 99] (brown) 16 | https://wandb.ai/openclimatefix/india/runs/pcr2zsrc 17 | 18 | 19 | ## Training 20 | 21 | Each epoch took about 4 hours, so the training runs took several days.
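As a reference for what training on a given set of plevels means, below is a minimal sketch of a pinball (quantile) loss over a list of plevels. This is illustrative plain PyTorch rather than the actual PVNet loss code, and it assumes the plevels correspond to the `output_quantiles` fractions used in the example model configs (e.g. `[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]` for the 7-plevel case), with the model producing one output per quantile per timestep:

```python
import torch


def pinball_loss(y_pred: torch.Tensor, y_true: torch.Tensor, quantiles: list[float]) -> torch.Tensor:
    """Mean pinball loss. y_pred: (batch, steps, n_quantiles), y_true: (batch, steps)."""
    losses = []
    for i, q in enumerate(quantiles):
        err = y_true - y_pred[..., i]
        # under-prediction is penalised by q, over-prediction by (1 - q)
        losses.append(torch.maximum(q * err, (q - 1) * err))
    return torch.stack(losses, dim=-1).mean()


# e.g. the 7-plevel experiment corresponds to quantiles like:
quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
y_pred = torch.rand(8, 192, len(quantiles))  # dummy predictions
y_true = torch.rand(8, 192)                  # dummy outturn
print(pinball_loss(y_pred, y_true, quantiles))
```

Under this loss, a well-calibrated p10 output should sit above the outturn only about 10% of the time, which is presumably what the p10 plot in the results below is checking.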
22 | 23 | TODO add number of samples 24 | 25 | ## Results 26 | 27 | The MAE results show that using only plevel 50 gives better results. 28 | ![](MAE.png "MAE") 29 | 30 | The p50 results are about the same. 31 | ![](p50.png "p50") 32 | 33 | We can see that the p10 results are not right, as they should converge to 0.1. 34 | ![](p10.png "p10") 35 | 36 | Interestingly, the more plevels you have, the better the results before 4 hours, 37 | but the fewer plevels you have, the better the results for >= 8 hours. 38 | 39 | | Timestep | P50 only MAE % | 7 plevels MAE % | 15 plevels MAE % | 7 plevels small patch MAE % | 40 | | --- | --- | --- | --- | --- | 41 | | 0-0 minutes | 5.416 | 5.920 | 3.933 | 7.586 | 42 | | 15-15 minutes | 5.458 | 5.809 | 4.003 | 8.021 | 43 | | 30-45 minutes | 5.525 | 5.742 | 4.442 | 7.233 | 44 | | 45-60 minutes | 5.595 | 5.698 | 4.772 | 7.187 | 45 | | 60-120 minutes | 5.890 | 5.816 | 5.307 | 7.231 | 46 | | 120-240 minutes | 6.423 | 6.080 | 6.275 | 7.287 | 47 | | 240-360 minutes | 6.608 | 6.375 | 6.707 | 7.319 | 48 | | 360-480 minutes | 6.728 | 6.638 | 6.904 | 7.285 | 49 | | 480-720 minutes | 6.634 | 6.747 | 6.872 | 7.143 | 50 | | 720-1440 minutes | 6.940 | 7.207 | 7.176 | 7.380 | 51 | | 1440-2880 minutes | 7.446 | 7.507 | 7.735 | 7.904 | 52 | 53 | 54 | ![](MAEvstimesteps.png "MAEvstimesteps") 55 | -------------------------------------------------------------------------------- /experiments/india/004_n_training_samples/log-plot.py: -------------------------------------------------------------------------------- 1 | """ Small script to make MAE vs number of batches plot""" 2 | 3 | import pandas as pd 4 | import plotly.graph_objects as go 5 | 6 | data = [[100, 7.779], [300, 7.441], [1000, 7.181], [3000, 7.180], [6711, 7.151]] 7 | df = pd.DataFrame(data, columns=["n_samples", "MAE [%]"]) 8 | 9 | fig = go.Figure() 10 | fig.add_trace(go.Scatter(x=df["n_samples"], y=df["MAE [%]"], mode="lines+markers")) 11 | fig.update_layout(title="MAE % for N samples", xaxis_title="N Samples", yaxis_title="MAE %") 12 | # use a log-scale x-axis 13 | fig.update_xaxes(type="log") 14 | fig.show(renderer="browser") 15 | -------------------------------------------------------------------------------- /experiments/india/004_n_training_samples/mae_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/004_n_training_samples/mae_samples.png -------------------------------------------------------------------------------- /experiments/india/004_n_training_samples/mae_step.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/004_n_training_samples/mae_step.png -------------------------------------------------------------------------------- /experiments/india/004_n_training_samples/readme.md: -------------------------------------------------------------------------------- 1 | # N samples experiments 2 | 3 | Kicked off experiments that train on N samples. 4 | This is done by adding `limit_train_batches` to `trainer/default.yaml` (see the sketch below). 5 | 6 | I checked that when limiting the batches, the same batches are shown to the model in each epoch.
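For illustration, a minimal sketch of what this setting looks like when passed to the Lightning `Trainer` is shown here. The real runs configure it through the Hydra config (`trainer/default.yaml`), so the exact keys and values below are assumptions.

```python
# Minimal sketch - illustrative values only; the real setting lives in trainer/default.yaml.
from lightning.pytorch import Trainer

trainer = Trainer(
    max_epochs=100,
    limit_train_batches=1000,  # cap the number of training batches used per epoch
)
```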
7 | 8 | ## Experiments 9 | 10 | The original run uses 6711 batches. 11 | 12 | - 100: 3p6scx2r 13 | - 300: am46tno1 14 | - 1000: u04xlb6p 15 | - 3000: p11lhreo 16 | 17 | ## Results 18 | 19 | Overall 20 | 21 | | Experiment | MAE % | 22 | |------------|-------| 23 | | 100 | 7.779 | 24 | | 300 | 7.441 | 25 | | 1000 | 7.181 | 26 | | 3000 | 7.180 | 27 | | 6711 | 7.151 | 28 | 29 | Results by timestep 30 | 31 | 32 | | Timestep | 100 MAE % | 300 MAE % | 1000 MAE % | 3000 MAE % | 6711 MAE % | 33 | | --- | --- | --- | --- | --- | --- | 34 | | 0-0 minutes | 7.985 | 7.453 | 7.155 | 5.553 | 5.920 | 35 | | 15-15 minutes | 7.953 | 7.055 | 6.923 | 5.453 | 5.809 | 36 | | 30-45 minutes | 8.043 | 7.172 | 6.907 | 5.764 | 5.742 | 37 | | 45-60 minutes | 7.850 | 7.070 | 6.790 | 5.815 | 5.698 | 38 | | 60-120 minutes | 7.698 | 6.809 | 6.597 | 5.890 | 5.816 | 39 | | 120-240 minutes | 7.355 | 6.629 | 6.495 | 6.221 | 6.080 | 40 | | 240-360 minutes | 7.230 | 6.729 | 6.559 | 6.541 | 6.375 | 41 | | 360-480 minutes | 7.415 | 6.997 | 6.770 | 6.855 | 6.638 | 42 | | 480-720 minutes | 7.258 | 7.037 | 6.668 | 6.876 | 6.747 | 43 | | 720-1440 minutes | 7.659 | 7.362 | 7.038 | 7.142 | 7.207 | 44 | | 1440-2880 minutes | 8.027 | 7.745 | 7.518 | 7.535 | 7.507 | 45 | 46 | ![](mae_step.png "mae_steps") 47 | 48 | ![](mae_samples.png "mae_samples") 49 | -------------------------------------------------------------------------------- /experiments/india/005_extra_nwp_variables/mae_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/005_extra_nwp_variables/mae_steps.png -------------------------------------------------------------------------------- /experiments/india/005_extra_nwp_variables/mae_steps_grouped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/005_extra_nwp_variables/mae_steps_grouped.png -------------------------------------------------------------------------------- /experiments/india/005_extra_nwp_variables/readmd.md: -------------------------------------------------------------------------------- 1 | # Adding extra NWP variables 2 | 3 | I wanted to run WindNet while testing some new NWP variables from ECMWF. 4 | 5 | General conclusion (although more experiments could be done): 6 | the current NWP variables are about right. 7 | Adding lots more makes results worse. 8 | Taking some away also makes results worse. 9 | 10 | ## Bugs 11 | 12 | Ran into a problem where we found that some examples had 13 | `d.__getitem__('nwp-ecmwf__init_time_utc').values` of size 50, where it should be just one value. I removed these examples. This might 14 | 15 | ## Experiments 16 | 17 | The number of training samples was 8000. 18 | 19 | ### 15 variables 20 | Run WindNet with `'hcc', 'lcc', 'mcc', 'prate', 'sde', 'sr', 't2m', 'tcc', 'u10', 21 | 'v10', 'u100', 'v100', 'u200', 'v200', 'dlwrf', 'dswrf'`. 22 | 23 | The experiment on wandb is [here](https://wandb.ai/openclimatefix/india/runs/k91rdffo) 24 | 25 | ### 7 variables 26 | Run WindNet with the original 7 variables.
27 | `t2m, u10, u100, u200, v10, v100, v200` 28 | 29 | The experiment on wandb is [here](https://wandb.ai/openclimatefix/india/runs/miszfep5) 30 | 31 | ### 3 variables 32 | Run WindNet with only `t, u10, v100` 33 | 34 | The experiment on wandb is [here](https://wandb.ai/openclimatefix/india/runs/22v3a39g) 35 | 36 | ## Results 37 | 38 | | Timestep | 15 MAE % | 7 MAE % | 3 MAE % | 39 | | --- | --- | --- | --- | 40 | | 0-0 minutes | 7.450 | 6.623 | 7.529 | 41 | | 15-15 minutes | 7.348 | 6.441 | 7.408 | 42 | | 30-45 minutes | 7.242 | 6.544 | 7.294 | 43 | | 45-60 minutes | 7.134 | 6.567 | 7.185 | 44 | | 60-120 minutes | 7.058 | 6.295 | 7.009 | 45 | | 120-240 minutes | 6.965 | 6.290 | 6.800 | 46 | | 240-360 minutes | 6.807 | 6.374 | 6.580 | 47 | | 360-480 minutes | 6.749 | 6.482 | 6.548 | 48 | | 480-720 minutes | 6.892 | 6.686 | 6.685 | 49 | | 720-1440 minutes | 7.020 | 6.756 | 6.780 | 50 | | 1440-2880 minutes | 7.445 | 7.095 | 7.214 | 51 | 52 | ![](mae_steps_grouped.png "mae_steps") 53 | 54 | The raw data is here: 55 | ![](mae_steps.png "mae_steps") 56 | -------------------------------------------------------------------------------- /experiments/india/006_da_only/bad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/006_da_only/bad.png -------------------------------------------------------------------------------- /experiments/india/006_da_only/da_only.md: -------------------------------------------------------------------------------- 1 | ## DA forecasts only 2 | 3 | The idea was to create a day-ahead (DA) only forecast for WindNet. 4 | We hoped this would bring down the DA MAE values. 5 | 6 | We do this by not forecasting the first X hours. 7 | 8 | Unfortunately, it does not look like ignoring the first X hours makes the DA forecast better. 9 | 10 | ## Experiments 11 | 12 | 1. Baseline - [here](https://wandb.ai/openclimatefix/india/runs/miszfep5) 13 | 2. Ignore first 6 hours - [here](https://wandb.ai/openclimatefix/india/runs/uosk0qug) 14 | 3.
Ignore first 12 hours - [here](https://wandb.ai/openclimatefix/india/runs/s9cnn4ei) 15 | 16 | ## Results 17 | 18 | | Timestep | all MAE % | 6 MAE % | 12 MAE % | 19 | | --- | --- | --- | --- | 20 | | 0-0 minutes | nan | nan | nan | 21 | | 15-15 minutes | nan | nan | nan | 22 | | 30-45 minutes | 0.065 | nan | nan | 23 | | 45-60 minutes | 0.066 | nan | nan | 24 | | 60-120 minutes | 0.063 | nan | nan | 25 | | 120-240 minutes | 0.063 | nan | nan | 26 | | 240-360 minutes | 0.064 | nan | nan | 27 | | 360-480 minutes | 0.065 | 0.068 | nan | 28 | | 480-720 minutes | 0.067 | 0.065 | nan | 29 | | 720-1440 minutes | 0.068 | 0.065 | 0.065 | 30 | | 1440-2880 minutes | 0.071 | 0.071 | 0.071 | 31 | 32 | ![](mae_steps.png "mae_steps") 33 | 34 | Here are two examples from the 6-hour-ignore model: one where the forecast was good, and one where it was not. 35 | 36 | ![](bad.png "bad") 37 | ![](good.png "good") 38 | -------------------------------------------------------------------------------- /experiments/india/006_da_only/good.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/006_da_only/good.png -------------------------------------------------------------------------------- /experiments/india/006_da_only/mae_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/006_da_only/mae_steps.png -------------------------------------------------------------------------------- /experiments/india/007_different_seeds/mae_all_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/007_different_seeds/mae_all_steps.png -------------------------------------------------------------------------------- /experiments/india/007_different_seeds/mae_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/007_different_seeds/mae_steps.png -------------------------------------------------------------------------------- /experiments/india/007_different_seeds/readme.md: -------------------------------------------------------------------------------- 1 | # Training models with different seeds 2 | 3 | We want to see the effect of training a model with different seeds.
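As a minimal sketch of what this involves (the actual runs set the seed through the training config, so the seed value and mechanism here are assumptions), the seed can be fixed with Lightning's utility:

```python
# Sketch only - the real experiments configure the seed via the training config.
from lightning.pytorch import seed_everything

seed_everything(42, workers=True)  # repeat with a different seed for each run
```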
4 | 5 | We can see that the results for different seeds can vary by 0.5%, 6 | with some models being better at different time horizons than others. 7 | 8 | ## Experiments 9 | - seed 1 - [miszfep5](https://wandb.ai/openclimatefix/india/runs/miszfep5) 10 | - seed 2 - [cxshv2q4](https://wandb.ai/openclimatefix/india/runs/cxshv2q4) 11 | - seed 3 - [m46wdrr7](https://wandb.ai/openclimatefix/india/runs/m46wdrr7) 12 | 13 | These were trained with 1000 batches, with 300 batches for validation. 14 | 15 | ## Results 16 | 17 | | Timestep | s1 MAE % | s2 MAE % | s3 MAE % | 18 | | --- | --- | --- | --- | 19 | | 0-0 minutes | 0.066 | 0.061 | 0.066 | 20 | | 15-15 minutes | 0.064 | 0.058 | 0.064 | 21 | | 30-45 minutes | 0.065 | 0.060 | 0.063 | 22 | | 45-60 minutes | 0.066 | 0.060 | 0.063 | 23 | | 60-120 minutes | 0.063 | 0.060 | 0.063 | 24 | | 120-240 minutes | 0.063 | 0.063 | 0.065 | 25 | | 240-360 minutes | 0.064 | 0.066 | 0.065 | 26 | | 360-480 minutes | 0.065 | 0.066 | 0.066 | 27 | | 480-720 minutes | 0.067 | 0.066 | 0.065 | 28 | | 720-1440 minutes | 0.068 | 0.068 | 0.066 | 29 | | 1440-2880 minutes | 0.071 | 0.072 | 0.071 | 30 | 31 | ![](mae_steps.png "mae_steps") 32 | 33 | ![](mae_all_steps.png "mae_steps") 34 | -------------------------------------------------------------------------------- /experiments/india/008_coarse4/mae_step.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/008_coarse4/mae_step.png -------------------------------------------------------------------------------- /experiments/india/008_coarse4/mae_step_smooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/india/008_coarse4/mae_step_smooth.png -------------------------------------------------------------------------------- /experiments/india/008_coarse4/readme.md: -------------------------------------------------------------------------------- 1 | # Coarser data and more examples 2 | 3 | We downsampled the ECMWF data from 0.05 to 0.2 degrees. 4 | In previous experiments we used a 0.1 degree resolution, as this is the same as the live ECMWF data. 5 | 6 | By reducing the resolution we can increase the number of samples we have to train on. 7 | We used 41,408 samples to train and 10,352 samples to validate. 8 | This is approximately 5 times more samples than in the previous experiments. 9 | 10 | ## Experiments 11 | 12 | 13 | ### b8_s1 14 | Batch size 8, with 0.2 degree NWP data. 15 | https://wandb.ai/openclimatefix/india/runs/w85hftb6 16 | 17 | 18 | ### b8_s2 19 | Batch size 8, different seed, with 0.2 degree NWP data. 20 | https://wandb.ai/openclimatefix/india/runs/k4x1tunj 21 | 22 | ### b32_s3 23 | Batch size 32, with 0.2 degree NWP data. The learning rate was also kept a bit higher. 24 | https://wandb.ai/openclimatefix/india/runs/ktale7pa 25 | 26 | ### epochs 27 | We set the early stopping epochs from 10 to 15.
This should mean the model trains for a bit longer. 28 | https://wandb.ai/openclimatefix/india/runs/8hfc83uv 29 | 30 | ### small model 31 | We made the model about 50% of the size by reducing the channels in the NWP encoder from 256 to 64 and reducing the hidden features in the output network from 1024 to 256. 32 | https://wandb.ai/openclimatefix/india/runs/sk5ek3pk 33 | 34 | 35 | ### early stopping on MAE/val 36 | Changed the early-stopping metric from quantile_loss to MAE/val. This should mean the model does more training epochs and stops on the metric we are interested in. 37 | https://wandb.ai/openclimatefix/india/runs/a5nkkzj6 38 | 39 | 40 | ### old 41 | Old experiment with 0.1 degree NWP data. 42 | https://wandb.ai/openclimatefix/india/runs/m46wdrr7. 43 | Note the validation batches are different from the experiments above. 44 | 45 | Interestingly, GPU memory use did not increase much between experiments 2 and 3. 46 | Need to check that batches of 32 were actually being passed through. 47 | 48 | ## Results 49 | 50 | Coarsening the data does seem to improve results in the first 10 hours of the forecast. 51 | The DA forecast looks very similar. Note the 0-hour forecast has a large amount of variation. 52 | 53 | 54 | 55 | There are still spiky results in the individual runs. 56 | 57 | | Timestep | b8_s1 MAE % | b8_s2 MAE % | b32_s3 MAE % | epochs MAE % | small MAE % | mae/val MAE % | old MAE % | 58 | | --- | --- | --- | --- | --- | --- | --- | --- | 59 | | 0-0 minutes | 0.052 | 0.047 | 0.027 | 0.030 | 0.041 | 0.041 | 0.066 | 60 | | 15-15 minutes | 0.052 | 0.049 | 0.031 | 0.033 | 0.041 | 0.041 | 0.064 | 61 | | 30-45 minutes | 0.052 | 0.051 | 0.037 | 0.039 | 0.043 | 0.043 | 0.063 | 62 | | 45-60 minutes | 0.053 | 0.052 | 0.040 | 0.043 | 0.044 | 0.044 | 0.063 | 63 | | 60-120 minutes | 0.056 | 0.054 | 0.048 | 0.052 | 0.048 | 0.048 | 0.063 | 64 | | 120-240 minutes | 0.061 | 0.060 | 0.060 | 0.064 | 0.057 | 0.057 | 0.065 | 65 | | 240-360 minutes | 0.061 | 0.062 | 0.063 | 0.065 | 0.061 | 0.061 | 0.065 | 66 | | 360-480 minutes | 0.062 | 0.062 | 0.062 | 0.063 | 0.063 | 0.063 | 0.066 | 67 | | 480-720 minutes | 0.063 | 0.063 | 0.062 | 0.064 | 0.064 | 0.064 | 0.065 | 68 | | 720-1440 minutes | 0.065 | 0.066 | 0.065 | 0.067 | 0.066 | 0.066 | 0.066 | 69 | | 1440-2880 minutes | 0.069 | 0.070 | 0.071 | 0.071 | 0.071 | 0.071 | 0.071 | 70 | 71 | 72 | ![](mae_step.png "mae_steps") 73 | 74 | ![](mae_step_smooth.png "mae_steps") 75 | 76 | I think it's worth noting the model training MAE is around `3`% and the validation MAE is about `7`%, so there is good reason to believe that the model is overfit to the training set. 77 | It would be good to plot some of the training examples, to see if they are less spiky.
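For reference, here is a minimal sketch of the kind of spatial coarsening described above, using xarray. The dataset path, dimension names, and coarsening factor are assumptions for illustration, not the actual preprocessing pipeline.

```python
# Illustrative sketch only - paths, dimension names, and factors are assumptions.
import xarray as xr

ds = xr.open_zarr("ecmwf_india.zarr")  # hypothetical 0.05 degree source dataset

# Average 4x4 blocks of grid points: 0.05 degrees -> 0.2 degrees
ds_coarse = ds.coarsen(latitude=4, longitude=4, boundary="trim").mean()

ds_coarse.to_zarr("ecmwf_india_0p2_degree.zarr")
```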
78 | -------------------------------------------------------------------------------- /experiments/mae_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to generate analysis of MAE values for multiple model forecasts 3 | 4 | Does this for 48 hour horizon forecasts with 15 minute granularity 5 | 6 | """ 7 | 8 | import argparse 9 | 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pandas as pd 14 | import wandb 15 | 16 | matplotlib.rcParams["axes.prop_cycle"] = matplotlib.cycler( 17 | color=[ 18 | "FFD053", # yellow 19 | "7BCDF3", # blue 20 | "63BCAF", # teal 21 | "086788", # dark blue 22 | "FF9736", # dark orange 23 | "E4E4E4", # grey 24 | "14120E", # black 25 | "FFAC5F", # orange 26 | "4C9A8E", # dark teal 27 | ] 28 | ) 29 | 30 | 31 | def main(project: str, runs: list[str], run_names: list[str]) -> None: 32 | """ 33 | Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity 34 | 35 | Args: 36 | project: name of W&B project 37 | runs: W&B ids of runs 38 | run_names: user specified names for runs 39 | 40 | """ 41 | api = wandb.Api() 42 | dfs = [] 43 | epoch_num = [] 44 | for run in runs: 45 | run = api.run(f"openclimatefix/{project}/{run}") 46 | 47 | df = run.history(samples=run.lastHistoryStep + 1) 48 | # Get the columns that are in the format 'MAE_horizon/step_/val` 49 | mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col] 50 | # Sort them 51 | mae_cols.sort() 52 | df = df[mae_cols] 53 | # Get last non-NaN value 54 | # Drop all rows with all NaNs 55 | df = df.dropna(how="all") 56 | # Select the last row 57 | # Get average across entire row, and get the IDX for the one with the smallest values 58 | min_row_mean = np.inf 59 | for idx, (row_idx, row) in enumerate(df.iterrows()): 60 | if row.mean() < min_row_mean: 61 | min_row_mean = row.mean() 62 | min_row_idx = idx 63 | df = df.iloc[min_row_idx] 64 | # Calculate the timedelta for each group 65 | # Get the step from the column name 66 | column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols] 67 | dfs.append(df) 68 | epoch_num.append(min_row_idx) 69 | # Get the timedelta for each group 70 | groupings = [ 71 | [0, 0], 72 | [15, 15], 73 | [30, 45], 74 | [45, 60], 75 | [60, 120], 76 | [120, 240], 77 | [240, 360], 78 | [360, 480], 79 | [480, 720], 80 | [720, 1440], 81 | [1440, 2880], 82 | ] 83 | 84 | groups_df = [] 85 | grouping_starts = [grouping[0] for grouping in groupings] 86 | header = "| Timestep |" 87 | separator = "| --- |" 88 | for run_name in run_names: 89 | header += f" {run_name} MAE % |" 90 | separator += " --- |" 91 | print(header) 92 | print(separator) 93 | for grouping in groupings: 94 | group_string = f"| {grouping[0]}-{grouping[1]} minutes |" 95 | # Select indicies from column_timesteps that are within the grouping, inclusive 96 | group_idx = [ 97 | idx 98 | for idx, timestep in enumerate(column_timesteps) 99 | if timestep >= grouping[0] and timestep <= grouping[1] 100 | ] 101 | data_one_group = [] 102 | for df in dfs: 103 | mean_row = df.iloc[group_idx].mean() 104 | group_string += f" {mean_row:0.3f} |" 105 | data_one_group.append(mean_row) 106 | print(group_string) 107 | 108 | groups_df.append(data_one_group) 109 | 110 | groups_df = pd.DataFrame(groups_df, columns=run_names, index=grouping_starts) 111 | 112 | for idx, df in enumerate(dfs): 113 | print(f"{run_names[idx]}: {df.mean()*100:0.3f}") 114 | 115 | # Plot the error 
per timestep 116 | plt.figure() 117 | for idx, df in enumerate(dfs): 118 | plt.plot( 119 | column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-" 120 | ) 121 | plt.legend() 122 | plt.xlabel("Timestep (minutes)") 123 | plt.ylabel("MAE %") 124 | plt.title("MAE % for each timestep") 125 | plt.savefig("mae_per_timestep.png") 126 | plt.show() 127 | 128 | # Plot the error per grouped timestep 129 | plt.figure() 130 | for idx, run_name in enumerate(run_names): 131 | plt.plot( 132 | groups_df[run_name], 133 | label=f"{run_name}, epoch: {epoch_num[idx]}", 134 | marker="o", 135 | linestyle="-", 136 | ) 137 | plt.legend() 138 | plt.xlabel("Timestep (minutes)") 139 | plt.ylabel("MAE %") 140 | plt.title("MAE % for each grouped timestep") 141 | plt.savefig("mae_per_grouped_timestep.png") 142 | plt.show() 143 | 144 | 145 | if __name__ == "__main__": 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument("--project", type=str, default="") 148 | # Add arguments that is a list of strings 149 | parser.add_argument("--list_of_runs", nargs="+") 150 | parser.add_argument("--run_names", nargs="+") 151 | args = parser.parse_args() 152 | main(args.project, args.list_of_runs, args.run_names) 153 | -------------------------------------------------------------------------------- /experiments/uk/011 - Extending forecast to 36 hours (updated ECMWF data)/PVNEt_national_XG_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/uk/011 - Extending forecast to 36 hours (updated ECMWF data)/PVNEt_national_XG_comparison.png -------------------------------------------------------------------------------- /experiments/uk/011 - Extending forecast to 36 hours (updated ECMWF data)/PVNet_day_ahead.md: -------------------------------------------------------------------------------- 1 | PVNet day ahead was retrained to produce a 36 hour forecast, it was given its [previous configuration](https://huggingface.co/openclimatefix/pvnet_uk_region/tree/main) and data except for being given ECMWF NWP data with a longer forecast horizon (max 85 hours but 37 hours given to the model). Longer horizon UKV NWP data was not available at time of training and will be a further addition in the future. 2 | 3 | **Results** \ 4 | [The training run](https://wandb.ai/openclimatefix/pvnet_day_ahead_36_hours/runs/m4d3wlft/overview) had 3.15% normalised mean absolute error (NMAE) on validation data (100,000 samples from May 2022 to May 2023), [previous training of PVNet day ahead](https://wandb.ai/openclimatefix/pvnet2.1/runs/2ghzwbxg/overview?) had similar results of 3.19% NMAE. 5 | 6 | 7 | ![](PVNets_comparison.png "PVNets comparison") 8 | 9 | When comparing the two versions of PVNet day ahead (the new version in green) by forecast accuracy at each step on the validation dataset samples we see some small differences in the model up to 33 hours, such as first the first few steps and between steps 5 and 10, which could be explained by differences in samples seen and evaluated on between the two versions. 10 | 11 | However the larger difference is an improvement toward the end of the forecast horizon, from 33 hours onwards which is likely due to ECMWF data now being available for this period, when previously no NWP data was given past 33 hours due to the NWP forecast horizon of previous data and factoring in NWP initialization times and production delays. 
12 | 13 | UKV NWP data used in the model is currently up to 30 hours, we would expect a further reduction in error from 30+ hours when training with longer horizon UKV data which would cover up to 36 hours. 14 | 15 | 16 | A very rough comparison is also plotted between these two PVNet model versions and the National XG model which is currently used for day ahead predictions in production. 17 | 18 | ![](PVNEt_national_XG_comparison.png "PVNets national XG comparison") 19 | 20 | 21 | 22 | This comparison is rough and should not be seen as a fair comparison as the national XG numbers are just an estimate derived from backtest data on different time periods. However, it can show roughly what relative improvement could be achieved from replacing the National XG Day ahead model with a PVNet Day Ahead model. 23 | -------------------------------------------------------------------------------- /experiments/uk/011 - Extending forecast to 36 hours (updated ECMWF data)/PVNets_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/experiments/uk/011 - Extending forecast to 36 hours (updated ECMWF data)/PVNets_comparison.png -------------------------------------------------------------------------------- /pvnet/__init__.py: -------------------------------------------------------------------------------- 1 | """PVNet""" 2 | __version__ = "4.1.18" 3 | -------------------------------------------------------------------------------- /pvnet/callbacks.py: -------------------------------------------------------------------------------- 1 | """Custom callbacks 2 | """ 3 | from lightning.pytorch import Trainer 4 | from lightning.pytorch.callbacks import BaseFinetuning, EarlyStopping, LearningRateFinder 5 | from lightning.pytorch.trainer.states import TrainerFn 6 | 7 | 8 | class PhaseEarlyStopping(EarlyStopping): 9 | """Monitor a validation metric and stop training when it stops improving. 10 | 11 | Only functions in a specific phase of training. 12 | """ 13 | 14 | training_phase = None 15 | 16 | def switch_phase(self, phase: str): 17 | """Switch phase of callback""" 18 | if phase == self.training_phase: 19 | self.activate() 20 | else: 21 | self.deactivate() 22 | 23 | def deactivate(self): 24 | """Deactivate callback""" 25 | self.active = False 26 | 27 | def activate(self): 28 | """Activate callback""" 29 | self.active = True 30 | 31 | def _should_skip_check(self, trainer: Trainer) -> bool: 32 | return ( 33 | (trainer.state.fn != TrainerFn.FITTING) or (trainer.sanity_checking) or not self.active 34 | ) 35 | 36 | 37 | class PretrainEarlyStopping(EarlyStopping): 38 | """Monitor a validation metric and stop training when it stops improving. 39 | 40 | Only functions in the 'pretrain' phase of training. 41 | """ 42 | 43 | training_phase = "pretrain" 44 | 45 | 46 | class MainEarlyStopping(EarlyStopping): 47 | """Monitor a validation metric and stop training when it stops improving. 48 | 49 | Only functions in the 'main' phase of training. 
50 | """ 51 | 52 | training_phase = "main" 53 | 54 | 55 | class PretrainFreeze(BaseFinetuning): 56 | """Freeze the satellite and NWP encoders during pretraining""" 57 | 58 | training_phase = "pretrain" 59 | 60 | def __init__(self): 61 | """Freeze the satellite and NWP encoders during pretraining""" 62 | super().__init__() 63 | 64 | def freeze_before_training(self, pl_module): 65 | """Freeze satellite and NWP encoders before training start""" 66 | # freeze any module you want 67 | modules = [] 68 | if pl_module.include_sat: 69 | modules += [pl_module.sat_encoder] 70 | if pl_module.include_nwp: 71 | modules += [pl_module.nwp_encoder] 72 | self.freeze(modules) 73 | 74 | def finetune_function(self, pl_module, current_epoch, optimizer): 75 | """Unfreeze satellite and NWP encoders""" 76 | if not self.active: 77 | modules = [] 78 | if pl_module.include_sat: 79 | modules += [pl_module.sat_encoder] 80 | if pl_module.include_nwp: 81 | modules += [pl_module.nwp_encoder] 82 | self.unfreeze_and_add_param_group( 83 | modules=modules, 84 | optimizer=optimizer, 85 | train_bn=True, 86 | ) 87 | 88 | def switch_phase(self, phase: str): 89 | """Switch phase of callback""" 90 | if phase == self.training_phase: 91 | self.activate() 92 | else: 93 | self.deactivate() 94 | 95 | def deactivate(self): 96 | """Deactivate callback""" 97 | self.active = False 98 | 99 | def activate(self): 100 | """Activate callback""" 101 | self.active = True 102 | 103 | 104 | class PhasedLearningRateFinder(LearningRateFinder): 105 | """Finds a learning rate at the start of each phase of learning""" 106 | 107 | active = True 108 | 109 | def on_fit_start(self, *args, **kwargs): 110 | """Do nothing""" 111 | return 112 | 113 | def on_train_epoch_start(self, trainer, pl_module): 114 | """Run learning rate finder on epoch start and then deactivate""" 115 | if self.active: 116 | self.lr_find(trainer, pl_module) 117 | self.deactivate() 118 | 119 | def switch_phase(self, phase: str): 120 | """Switch training phase""" 121 | self.activate() 122 | 123 | def deactivate(self): 124 | """Deactivate callback""" 125 | self.active = False 126 | 127 | def activate(self): 128 | """Activate callback""" 129 | self.active = True 130 | -------------------------------------------------------------------------------- /pvnet/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Data parts""" 2 | from .site_datamodule import SiteDataModule 3 | from .uk_regional_datamodule import DataModule 4 | -------------------------------------------------------------------------------- /pvnet/data/base_datamodule.py: -------------------------------------------------------------------------------- 1 | """ Data module for pytorch lightning """ 2 | 3 | from glob import glob 4 | 5 | from lightning.pytorch import LightningDataModule 6 | from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch 7 | from ocf_data_sampler.torch_datasets.sample.base import ( 8 | NumpyBatch, 9 | SampleBase, 10 | TensorBatch, 11 | batch_to_tensor, 12 | ) 13 | from torch.utils.data import DataLoader, Dataset 14 | 15 | 16 | def collate_fn(samples: list[NumpyBatch]) -> TensorBatch: 17 | """Convert a list of NumpySample samples to a tensor batch""" 18 | return batch_to_tensor(stack_np_samples_into_batch(samples)) 19 | 20 | 21 | class PremadeSamplesDataset(Dataset): 22 | """Dataset to load samples from 23 | 24 | Args: 25 | sample_dir: Path to the directory of pre-saved samples. 
26 | sample_class: sample class type to use for save/load/to_numpy 27 | """ 28 | 29 | def __init__(self, sample_dir: str, sample_class: SampleBase): 30 | """Initialise PremadeSamplesDataset""" 31 | self.sample_paths = glob(f"{sample_dir}/*") 32 | self.sample_class = sample_class 33 | 34 | def __len__(self): 35 | return len(self.sample_paths) 36 | 37 | def __getitem__(self, idx): 38 | sample = self.sample_class.load(self.sample_paths[idx]) 39 | return sample.to_numpy() 40 | 41 | 42 | class BaseDataModule(LightningDataModule): 43 | """Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler.""" 44 | 45 | def __init__( 46 | self, 47 | configuration: str | None = None, 48 | sample_dir: str | None = None, 49 | batch_size: int = 16, 50 | num_workers: int = 0, 51 | prefetch_factor: int | None = None, 52 | train_period: list[str | None] = [None, None], 53 | val_period: list[str | None] = [None, None], 54 | ): 55 | """Base Datamodule for training pvnet architecture. 56 | 57 | Can also be used with pre-made batches if `sample_dir` is set. 58 | 59 | Args: 60 | configuration: Path to ocf-data-sampler configuration file. 61 | sample_dir: Path to the directory of pre-saved samples. Cannot be used together with 62 | `configuration` or '[train/val]_period'. 63 | batch_size: Batch size. 64 | num_workers: Number of workers to use in multiprocess batch loading. 65 | prefetch_factor: Number of data will be prefetched at the end of each worker process. 66 | train_period: Date range filter for train dataloader. 67 | val_period: Date range filter for val dataloader. 68 | 69 | """ 70 | super().__init__() 71 | 72 | if not ((sample_dir is not None) ^ (configuration is not None)): 73 | raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.") 74 | 75 | if sample_dir is not None: 76 | if any([period != [None, None] for period in [train_period, val_period]]): 77 | raise ValueError("Cannot set `(train/val)_period` with presaved samples") 78 | 79 | self.configuration = configuration 80 | self.sample_dir = sample_dir 81 | self.train_period = train_period 82 | self.val_period = val_period 83 | 84 | self._common_dataloader_kwargs = dict( 85 | batch_size=batch_size, 86 | sampler=None, 87 | batch_sampler=None, 88 | num_workers=num_workers, 89 | collate_fn=collate_fn, 90 | pin_memory=False, 91 | drop_last=False, 92 | timeout=0, 93 | worker_init_fn=None, 94 | prefetch_factor=prefetch_factor, 95 | persistent_workers=False, 96 | ) 97 | 98 | def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: 99 | raise NotImplementedError 100 | 101 | def _get_premade_samples_dataset(self, subdir) -> Dataset: 102 | raise NotImplementedError 103 | 104 | def train_dataloader(self) -> DataLoader: 105 | """Construct train dataloader""" 106 | if self.sample_dir is not None: 107 | dataset = self._get_premade_samples_dataset("train") 108 | else: 109 | dataset = self._get_streamed_samples_dataset(*self.train_period) 110 | return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs) 111 | 112 | def val_dataloader(self) -> DataLoader: 113 | """Construct val dataloader""" 114 | if self.sample_dir is not None: 115 | dataset = self._get_premade_samples_dataset("val") 116 | else: 117 | dataset = self._get_streamed_samples_dataset(*self.val_period) 118 | return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs) 119 | -------------------------------------------------------------------------------- /pvnet/data/site_datamodule.py: 
-------------------------------------------------------------------------------- 1 | """ Data module for pytorch lightning """ 2 | 3 | from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset 4 | from ocf_data_sampler.torch_datasets.sample.site import SiteSample 5 | from torch.utils.data import Dataset 6 | 7 | from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset 8 | 9 | 10 | class SiteDataModule(BaseDataModule): 11 | """Datamodule for training pvnet and using pvnet pipeline in `ocf-data-sampler`.""" 12 | 13 | def __init__( 14 | self, 15 | configuration: str | None = None, 16 | sample_dir: str | None = None, 17 | batch_size: int = 16, 18 | num_workers: int = 0, 19 | prefetch_factor: int | None = None, 20 | train_period: list[str | None] = [None, None], 21 | val_period: list[str | None] = [None, None], 22 | ): 23 | """Datamodule for training pvnet architecture. 24 | 25 | Can also be used with pre-made batches if `sample_dir` is set. 26 | 27 | Args: 28 | configuration: Path to configuration file. 29 | sample_dir: Path to the directory of pre-saved samples. Cannot be used together with 30 | `configuration` or '[train/val]_period'. 31 | batch_size: Batch size. 32 | num_workers: Number of workers to use in multiprocess batch loading. 33 | prefetch_factor: Number of data will be prefetched at the end of each worker process. 34 | train_period: Date range filter for train dataloader. 35 | val_period: Date range filter for val dataloader. 36 | 37 | """ 38 | super().__init__( 39 | configuration=configuration, 40 | sample_dir=sample_dir, 41 | batch_size=batch_size, 42 | num_workers=num_workers, 43 | prefetch_factor=prefetch_factor, 44 | train_period=train_period, 45 | val_period=val_period, 46 | ) 47 | 48 | def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: 49 | return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) 50 | 51 | def _get_premade_samples_dataset(self, subdir) -> Dataset: 52 | split_dir = f"{self.sample_dir}/{subdir}" 53 | return PremadeSamplesDataset(split_dir, SiteSample) 54 | -------------------------------------------------------------------------------- /pvnet/data/uk_regional_datamodule.py: -------------------------------------------------------------------------------- 1 | """ Data module for pytorch lightning """ 2 | 3 | from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset 4 | from ocf_data_sampler.torch_datasets.sample.uk_regional import UKRegionalSample 5 | from torch.utils.data import Dataset 6 | 7 | from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset 8 | 9 | 10 | class DataModule(BaseDataModule): 11 | """Datamodule for training pvnet and using pvnet pipeline in `ocf-data-sampler`.""" 12 | 13 | def __init__( 14 | self, 15 | configuration: str | None = None, 16 | sample_dir: str | None = None, 17 | batch_size: int = 16, 18 | num_workers: int = 0, 19 | prefetch_factor: int | None = None, 20 | train_period: list[str | None] = [None, None], 21 | val_period: list[str | None] = [None, None], 22 | ): 23 | """Datamodule for training pvnet architecture. 24 | 25 | Can also be used with pre-made batches if `sample_dir` is set. 26 | 27 | Args: 28 | configuration: Path to configuration file. 29 | sample_dir: Path to the directory of pre-saved samples. Cannot be used together with 30 | `configuration` or '[train/val]_period'. 31 | batch_size: Batch size. 32 | num_workers: Number of workers to use in multiprocess batch loading. 
33 | prefetch_factor: Number of data will be prefetched at the end of each worker process. 34 | train_period: Date range filter for train dataloader. 35 | val_period: Date range filter for val dataloader. 36 | 37 | """ 38 | super().__init__( 39 | configuration=configuration, 40 | sample_dir=sample_dir, 41 | batch_size=batch_size, 42 | num_workers=num_workers, 43 | prefetch_factor=prefetch_factor, 44 | train_period=train_period, 45 | val_period=val_period, 46 | ) 47 | 48 | def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: 49 | return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) 50 | 51 | def _get_premade_samples_dataset(self, subdir) -> Dataset: 52 | split_dir = f"{self.sample_dir}/{subdir}" 53 | # Returns a dict of np arrays 54 | return PremadeSamplesDataset(split_dir, UKRegionalSample) 55 | -------------------------------------------------------------------------------- /pvnet/load_model.py: -------------------------------------------------------------------------------- 1 | """ Load a model from its checkpoint directory """ 2 | import glob 3 | import os 4 | 5 | import hydra 6 | import torch 7 | from pyaml_env import parse_config 8 | 9 | from pvnet.models.ensemble import Ensemble 10 | from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel 11 | 12 | 13 | def get_model_from_checkpoints( 14 | checkpoint_dir_paths: list[str], 15 | val_best: bool = True, 16 | ): 17 | """Load a model from its checkpoint directory""" 18 | is_ensemble = len(checkpoint_dir_paths) > 1 19 | 20 | model_configs = [] 21 | models = [] 22 | data_configs = [] 23 | 24 | for path in checkpoint_dir_paths: 25 | # Load the model 26 | model_config = parse_config(f"{path}/model_config.yaml") 27 | 28 | model = hydra.utils.instantiate(model_config) 29 | 30 | if val_best: 31 | # Only one epoch (best) saved per model 32 | files = glob.glob(f"{path}/epoch*.ckpt") 33 | if len(files) != 1: 34 | raise ValueError( 35 | f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one." 
36 | ) 37 | # TODO: Loading with weights_only=False is not recommended 38 | checkpoint = torch.load(files[0], map_location="cpu", weights_only=False) 39 | else: 40 | checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu", weights_only=False) 41 | 42 | model.load_state_dict(state_dict=checkpoint["state_dict"]) 43 | 44 | if isinstance(model, UMTModel): 45 | model, model_config = model.convert_to_multimodal_model(model_config) 46 | 47 | # Check for data config 48 | data_config = f"{path}/data_config.yaml" 49 | 50 | if os.path.isfile(data_config): 51 | data_configs.append(data_config) 52 | else: 53 | data_configs.append(None) 54 | 55 | model_configs.append(model_config) 56 | models.append(model) 57 | 58 | if is_ensemble: 59 | model_config = { 60 | "_target_": "pvnet.models.ensemble.Ensemble", 61 | "model_list": model_configs, 62 | } 63 | model = Ensemble(model_list=models) 64 | data_config = data_configs[0] 65 | 66 | else: 67 | model_config = model_configs[0] 68 | model = models[0] 69 | data_config = data_configs[0] 70 | 71 | return model, model_config, data_config 72 | -------------------------------------------------------------------------------- /pvnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Models for PVNet""" 2 | -------------------------------------------------------------------------------- /pvnet/models/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | """Baselines""" 2 | -------------------------------------------------------------------------------- /pvnet/models/baseline/last_value.py: -------------------------------------------------------------------------------- 1 | """Persistence model""" 2 | 3 | 4 | import pvnet 5 | from pvnet.models.base_model import BaseModel 6 | from pvnet.optimizers import AbstractOptimizer 7 | 8 | 9 | class Model(BaseModel): 10 | """Simple baseline model that takes the last gsp yield value and copies it forward.""" 11 | 12 | name = "last_value" 13 | 14 | def __init__( 15 | self, 16 | forecast_minutes: int = 12, 17 | history_minutes: int = 6, 18 | optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), 19 | ): 20 | """Simple baseline model that takes the last gsp yield value and copies it forward. 21 | 22 | Args: 23 | history_minutes (int): Length of the GSP history period in minutes 24 | forecast_minutes (int): Length of the GSP forecast period in minutes 25 | optimizer (AbstractOptimizer): Optimizer 26 | """ 27 | 28 | super().__init__(history_minutes, forecast_minutes, optimizer) 29 | self.save_hyperparameters() 30 | 31 | def forward(self, x: dict): 32 | """Run model forward on dict batch of data""" 33 | # Shape: batch_size, seq_length, n_sites 34 | gsp_yield = x["gsp"] 35 | 36 | # take the last value non forecaster value and the first in the pv yeild 37 | # (this is the pv site we are preditcting for) 38 | y_hat = gsp_yield[:, -self.forecast_len - 1] 39 | 40 | # expand the last valid forward n predict steps 41 | out = y_hat.unsqueeze(1).repeat(1, self.forecast_len) 42 | return out 43 | -------------------------------------------------------------------------------- /pvnet/models/baseline/readme.md: -------------------------------------------------------------------------------- 1 | # Baseline Models 2 | 3 | - `last_value` - Forecast the sample last historical PV yeild for every forecast step 4 | - `single_value` - Learns a single value estimate and predicts this value for every input and every 5 | forecast step. 
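As a minimal usage sketch (the dummy batch, its shape, and the chosen horizons are assumptions, not values from the repo), the persistence baseline simply repeats the last observed value across the forecast horizon:

```python
# Sketch only - dummy inputs; shapes and lengths are illustrative assumptions.
import torch
from pvnet.models.baseline.last_value import Model

model = Model(forecast_minutes=120, history_minutes=60)
# "gsp" just needs enough time steps to cover history + forecast for this toy example
batch = {"gsp": torch.rand(4, 20)}
y_hat = model(batch)  # shape (batch_size, forecast_len): the last observed value, repeated
```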
6 | -------------------------------------------------------------------------------- /pvnet/models/baseline/single_value.py: -------------------------------------------------------------------------------- 1 | """Average value model""" 2 | import torch 3 | from torch import nn 4 | 5 | import pvnet 6 | from pvnet.models.base_model import BaseModel 7 | from pvnet.optimizers import AbstractOptimizer 8 | 9 | 10 | class Model(BaseModel): 11 | """Simple baseline model that predicts always the same value.""" 12 | 13 | name = "single_value" 14 | 15 | def __init__( 16 | self, 17 | forecast_minutes: int = 120, 18 | history_minutes: int = 60, 19 | optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), 20 | ): 21 | """Simple baseline model that predicts always the same value. 22 | 23 | Args: 24 | history_minutes (int): Length of the GSP history period in minutes 25 | forecast_minutes (int): Length of the GSP forecast period in minutes 26 | optimizer (AbstractOptimizer): Optimizer 27 | """ 28 | super().__init__(history_minutes, forecast_minutes, optimizer) 29 | self._value = nn.Parameter(torch.zeros(1), requires_grad=True) 30 | self.save_hyperparameters() 31 | 32 | def forward(self, x: dict): 33 | """Run model forward on dict batch of data""" 34 | # Returns a single value at all steps 35 | y_hat = torch.zeros_like(x["gsp"][:, : self.forecast_len]) + self._value 36 | return y_hat 37 | -------------------------------------------------------------------------------- /pvnet/models/ensemble.py: -------------------------------------------------------------------------------- 1 | """Model which uses mutliple prediction heads""" 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pvnet.models.base_model import BaseModel 8 | 9 | 10 | class Ensemble(BaseModel): 11 | """Ensemble of PVNet models""" 12 | 13 | def __init__( 14 | self, 15 | model_list: list[BaseModel], 16 | weights: Optional[list[float]] = None, 17 | ): 18 | """Ensemble of PVNet models 19 | 20 | Args: 21 | model_list: A list of PVNet models to ensemble 22 | weights: A list of weighting to apply to each model. If None, the models are weighted 23 | equally. 
24 | """ 25 | 26 | # Surface check all the models are compatible 27 | output_quantiles = [] 28 | history_minutes = [] 29 | forecast_minutes = [] 30 | target_key = [] 31 | interval_minutes = [] 32 | 33 | # Get some model properties from each model 34 | for model in model_list: 35 | output_quantiles.append(model.output_quantiles) 36 | history_minutes.append(model.history_minutes) 37 | forecast_minutes.append(model.forecast_minutes) 38 | target_key.append(model._target_key) 39 | interval_minutes.append(model.interval_minutes) 40 | 41 | # Check these properties are all the same 42 | for param_list in [ 43 | output_quantiles, 44 | history_minutes, 45 | forecast_minutes, 46 | target_key, 47 | interval_minutes, 48 | ]: 49 | assert all([p == param_list[0] for p in param_list]), param_list 50 | 51 | super().__init__( 52 | history_minutes=history_minutes[0], 53 | forecast_minutes=forecast_minutes[0], 54 | optimizer=None, 55 | output_quantiles=output_quantiles[0], 56 | target_key=target_key[0], 57 | interval_minutes=interval_minutes[0], 58 | ) 59 | 60 | self.model_list = nn.ModuleList(model_list) 61 | 62 | if weights is None: 63 | weights = torch.ones(len(model_list)) / len(model_list) 64 | else: 65 | assert len(weights) == len(model_list) 66 | weights = torch.Tensor(weights) / sum(weights) 67 | self.weights = nn.Parameter(weights, requires_grad=False) 68 | 69 | def forward(self, batch): 70 | """Run the model forward""" 71 | y_hat = 0 72 | for weight, model in zip(self.weights, self.model_list): 73 | y_hat = model(batch) * weight + y_hat 74 | return y_hat 75 | -------------------------------------------------------------------------------- /pvnet/models/model_cards/pv_india_model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | 6 | 7 | 8 | 9 | 10 | # PVNet India 11 | 12 | ## Model Description 13 | 14 | 15 | This model class uses numerical weather predictions from providers such as ECMWF to forecast the PV power in North West India over the next 48 hours. More information can be found in the model repo [1] and experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india). 16 | 17 | 18 | - **Developed by:** openclimatefix 19 | - **Model type:** Fusion model 20 | - **Language(s) (NLP):** en 21 | - **License:** mit 22 | 23 | 24 | # Training Details 25 | 26 | ## Data 27 | 28 | 29 | 30 | The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india) 31 | 32 | 33 | ### Preprocessing 34 | 35 | Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/site` Dataset [2]. 
36 | 37 | 38 | ## Results 39 | 40 | The training logs for the current model can be found here: 41 | {{ wandb_links }} 42 | 43 | 44 | ### Hardware 45 | 46 | Trained on a single NVIDIA Tesla T4 47 | 48 | ### Software 49 | 50 | This model was trained using the following Open Climate Fix packages: 51 | 52 | - [1] https://github.com/openclimatefix/PVNet 53 | - [2] https://github.com/openclimatefix/ocf-data-sampler 54 | 55 | The versions of these packages can be found below: 56 | {{ package_versions }} 57 | -------------------------------------------------------------------------------- /pvnet/models/model_cards/pv_uk_regional_model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | 6 | 7 | 8 | 9 | 10 | # PVNet2 11 | 12 | ## Model Description 13 | 14 | 15 | This model class uses satellite data, numerical weather predictions, and recent Grid Service Point( GSP) PV power output to forecast the near-term (~8 hours) PV power output at all GSPs. More information can be found in the model repo [1] and experimental notes in [this google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing). 16 | 17 | - **Developed by:** openclimatefix 18 | - **Model type:** Fusion model 19 | - **Language(s) (NLP):** en 20 | - **License:** mit 21 | 22 | 23 | # Training Details 24 | 25 | ## Data 26 | 27 | 28 | 29 | The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes in the [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) for more details. 30 | 31 | 32 | ### Preprocessing 33 | 34 | Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/pvnet_uk` Dataset [2]. 35 | 36 | 37 | ## Results 38 | 39 | The training logs for the current model can be found here: 40 | {{ wandb_links }} 41 | 42 | The training logs for all model runs of PVNet2 can be found [here](https://wandb.ai/openclimatefix/pvnet2.1). 43 | 44 | Some experimental notes can be found at in [the google doc](https://docs.google.com/document/d/1fbkfkBzp16WbnCg7RDuRDvgzInA6XQu3xh4NCjV-WDA/edit?usp=sharing) 45 | 46 | 47 | ### Hardware 48 | 49 | Trained on a single NVIDIA Tesla T4 50 | 51 | ### Software 52 | 53 | This model was trained using the following Open Climate Fix packages: 54 | 55 | - [1] https://github.com/openclimatefix/PVNet 56 | - [2] https://github.com/openclimatefix/ocf-data-sampler 57 | 58 | The versions of these packages can be found below: 59 | {{ package_versions }} 60 | -------------------------------------------------------------------------------- /pvnet/models/model_cards/wind_india_model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | 6 | 7 | 8 | 9 | 10 | # WindNet 11 | 12 | ## Model Description 13 | 14 | 15 | This model class uses numerical weather predictions from providers such as ECMWF to forecast the wind power in North West India over the next 48 hours at 15 minute granularity. More information can be found in the model repo [1] and experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india). 
16 | 17 | 18 | - **Developed by:** openclimatefix 19 | - **Model type:** Fusion model 20 | - **Language(s) (NLP):** en 21 | - **License:** mit 22 | 23 | 24 | # Training Details 25 | 26 | ## Data 27 | 28 | 29 | 30 | The model is trained on data from 2019-2022 and validated on data from 2022-2023. See experimental notes [here](https://github.com/openclimatefix/PVNet/tree/main/experiments/india) 31 | 32 | 33 | ### Preprocessing 34 | 35 | Data is prepared with the `ocf_data_sampler/torch_datasets/datasets/site` Dataset [2]. 36 | 37 | 38 | ## Results 39 | 40 | The training logs for the current model can be found here: 41 | {{ wandb_links }} 42 | 43 | 44 | ### Hardware 45 | 46 | Trained on a single NVIDIA Tesla T4 47 | 48 | ### Software 49 | 50 | This model was trained using the following Open Climate Fix packages: 51 | 52 | - [1] https://github.com/openclimatefix/PVNet 53 | - [2] https://github.com/openclimatefix/ocf-data-sampler 54 | 55 | The versions of these packages can be found below: 56 | {{ package_versions }} 57 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/__init__.py: -------------------------------------------------------------------------------- 1 | """Multimodal Models""" 2 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic layers for composite models""" 2 | 3 | import warnings 4 | 5 | import torch 6 | from torch import _VF, nn 7 | 8 | 9 | class ImageEmbedding(nn.Module): 10 | """A embedding layer which concatenates an ID embedding as a new channel onto 3D inputs.""" 11 | 12 | def __init__(self, num_embeddings, sequence_length, image_size_pixels, **kwargs): 13 | """A embedding layer which concatenates an ID embedding as a new channel onto 3D inputs. 14 | 15 | The embedding is a single 2D image and is appended at each step in the 1st dimension 16 | (assumed to be time). 17 | 18 | Args: 19 | num_embeddings: Size of the dictionary of embeddings 20 | sequence_length: The time sequence length of the data. 21 | image_size_pixels: The spatial size of the image. Assumed square. 22 | **kwargs: See `torch.nn.Embedding` for more possible arguments. 23 | """ 24 | super().__init__() 25 | self.image_size_pixels = image_size_pixels 26 | self.sequence_length = sequence_length 27 | self._embed = nn.Embedding( 28 | num_embeddings=num_embeddings, 29 | embedding_dim=image_size_pixels * image_size_pixels, 30 | **kwargs, 31 | ) 32 | 33 | def forward(self, x, id): 34 | """Append ID embedding to image""" 35 | emb = self._embed(id) 36 | emb = emb.reshape((-1, 1, 1, self.image_size_pixels, self.image_size_pixels)) 37 | emb = emb.repeat(1, 1, self.sequence_length, 1, 1) 38 | x = torch.cat((x, emb), dim=1) 39 | return x 40 | 41 | 42 | class CompleteDropoutNd(nn.Module): 43 | """A layer used to completely drop out all elements of a N-dimensional sample. 44 | 45 | Each sample will be zeroed out independently on every forward call with probability `p` using 46 | samples from a Bernoulli distribution. 47 | 48 | """ 49 | 50 | __constants__ = ["p", "inplace", "n_dim"] 51 | p: float 52 | inplace: bool 53 | n_dim: int 54 | 55 | def __init__(self, n_dim, p=0.5, inplace=False): 56 | """A layer used to completely drop out all elements of a N-dimensional sample. 57 | 58 | Args: 59 | n_dim: Number of dimensions of each sample not including channels. E.g. 
a sample with 60 | shape (channel, time, height, width) would use `n_dim=3`. 61 | p: probability of a channel to be zeroed. Default: 0.5 62 | training: apply dropout if is `True`. Default: `True` 63 | inplace: If set to `True`, will do this operation in-place. Default: `False` 64 | """ 65 | super().__init__() 66 | if p < 0 or p > 1: 67 | raise ValueError( 68 | "dropout probability has to be between 0 and 1, " "but got {}".format(p) 69 | ) 70 | self.p = p 71 | self.inplace = inplace 72 | self.n_dim = n_dim 73 | 74 | def forward(self, input: torch.Tensor) -> torch.Tensor: 75 | """Run dropout""" 76 | p = self.p 77 | inp_dim = input.dim() 78 | 79 | if inp_dim not in (self.n_dim + 1, self.n_dim + 2): 80 | warn_msg = ( 81 | f"CompleteDropoutNd: Received a {inp_dim}-D input. Expected either a single sample" 82 | f" with {self.n_dim+1} dimensions, or a batch of samples with {self.n_dim+2}" 83 | " dimensions." 84 | ) 85 | warnings.warn(warn_msg) 86 | 87 | is_batched = inp_dim == self.n_dim + 2 88 | if not is_batched: 89 | input = input.unsqueeze_(0) if self.inplace else input.unsqueeze(0) 90 | 91 | input = input.unsqueeze_(1) if self.inplace else input.unsqueeze(1) 92 | 93 | result = ( 94 | _VF.feature_dropout_(input, p, self.training) 95 | if self.inplace 96 | else _VF.feature_dropout(input, p, self.training) 97 | ) 98 | 99 | result = result.squeeze_(1) if self.inplace else result.squeeze(1) 100 | 101 | if not is_batched: 102 | result = result.squeeze_(0) if self.inplace else result.squeeze(0) 103 | 104 | return result 105 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to encode satellite and NWP inputs""" 2 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/encoders/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for image sequence encoders""" 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta): 9 | """Abstract class for NWP/satellite encoder. 10 | 11 | The encoder will take an input of shape (batch_size, sequence_length, channels, height, width) 12 | and return an output of shape (batch_size, out_features). 13 | """ 14 | 15 | def __init__( 16 | self, 17 | sequence_length: int, 18 | image_size_pixels: int, 19 | in_channels: int, 20 | out_features: int, 21 | ): 22 | """Abstract class for NWP/satellite encoder. 23 | 24 | Args: 25 | sequence_length: The time sequence length of the data. 26 | image_size_pixels: The spatial size of the image. Assumed square. 27 | in_channels: Number of input channels. 28 | out_features: Number of output features. 29 | """ 30 | super().__init__() 31 | self.out_features = out_features 32 | self.image_size_pixels = image_size_pixels 33 | self.sequence_length = sequence_length 34 | 35 | @abstractmethod 36 | def forward(self): 37 | """Run model forward""" 38 | pass 39 | 40 | 41 | class ResidualConv3dBlock(nn.Module): 42 | """Fully-connected deep network based on ResNet architecture. 43 | 44 | Internally, this network uses ELU activations throughout the residual blocks. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | in_channels, 50 | n_layers: int = 2, 51 | dropout_frac: float = 0.0, 52 | ): 53 | """Fully-connected deep network based on ResNet architecture. 
54 | 55 | Args: 56 | in_channels: Number of input channels. 57 | n_layers: Number of layers in residual pathway. 58 | dropout_frac: Probability of an element to be zeroed. 59 | """ 60 | super().__init__() 61 | 62 | layers = [] 63 | for i in range(n_layers): 64 | layers += [ 65 | nn.ELU(), 66 | nn.Conv3d( 67 | in_channels=in_channels, 68 | out_channels=in_channels, 69 | kernel_size=(3, 3, 3), 70 | padding=(1, 1, 1), 71 | ), 72 | nn.Dropout3d(p=dropout_frac), 73 | ] 74 | 75 | self.model = nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | """Run residual connection""" 79 | return self.model(x) + x 80 | 81 | 82 | class ResidualConv3dBlock2(nn.Module): 83 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 84 | 85 | This was the best performing residual block tested in the study. This implementation differs 86 | from that block just by using LeakyReLU activation to avoid dead neurons, and by including 87 | optional dropout in the residual branch. This is also a 3D convolutional residual block 88 | rather than a 2D one. 89 | 90 | Sources: 91 | [1] https://arxiv.org/pdf/1603.05027.pdf 92 | """ 93 | 94 | def __init__( 95 | self, 96 | in_channels: int, 97 | n_layers: int = 2, 98 | dropout_frac: float = 0.0, 99 | batch_norm: bool = True, 100 | ): 101 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 102 | 103 | Sources: 104 | [1] https://arxiv.org/pdf/1603.05027.pdf 105 | 106 | Args: 107 | in_channels: Number of input channels. 108 | n_layers: Number of layers in residual pathway. 109 | dropout_frac: Probability of an element to be zeroed. 110 | batch_norm: Whether to use batch norm. 111 | """ 112 | super().__init__() 113 | 114 | layers = [] 115 | for i in range(n_layers): 116 | if batch_norm: 117 | layers.append(nn.BatchNorm3d(in_channels)) 118 | layers.extend( 119 | [ 120 | nn.Dropout3d(p=dropout_frac), 121 | nn.LeakyReLU(), 122 | nn.Conv3d( 123 | in_channels=in_channels, 124 | out_channels=in_channels, 125 | kernel_size=(3, 3, 3), 126 | padding=(1, 1, 1), 127 | ), 128 | ] 129 | ) 130 | 131 | self.model = nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | """Run model forward""" 135 | return self.model(x) + x 136 | 137 | 138 | class ImageSequenceEncoder(nn.Module): 139 | """Simple network which independently encodes each image in a sequence into 1D features""" 140 | 141 | def __init__( 142 | self, 143 | image_size_pixels: int, 144 | in_channels: int, 145 | number_of_conv2d_layers: int = 4, 146 | conv2d_channels: int = 32, 147 | fc_features: int = 128, 148 | ): 149 | """Simple network which independently encodes each image in a sequence into 1D features. 150 | 151 | For an input of shape [N, C, L, H, W], the output is of shape [N, L, fc_features], where 152 | N is the number of samples in the batch, C is the number of input channels, L is the length 153 | of the sequence, and H and W are the height and width. 154 | 155 | Args: 156 | image_size_pixels: The spatial size of the image. Assumed square. 157 | in_channels: Number of input channels. 158 | number_of_conv2d_layers: Number of 2D convolution layers that are used. 159 | conv2d_channels: Number of channels used in each conv2d layer. 160 | fc_features: Number of output nodes for each image in each sequence.
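            Example (a minimal illustrative sketch; the batch, channel and sequence sizes below are
            assumed for illustration only and rely on the default layer settings above):
                >>> import torch
                >>> encoder = ImageSequenceEncoder(image_size_pixels=24, in_channels=11)
                >>> x = torch.rand(2, 11, 12, 24, 24)  # [N, C, L, H, W]
                >>> encoder(x).shape  # each of the L=12 images is encoded to fc_features=128 values
                torch.Size([2, 12, 128])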
161 | """ 162 | super().__init__() 163 | 164 | # Check that the output shape of the convolutional layers will be at least 1x1 165 | cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv2d_layers 166 | if not (cnn_spatial_output_size >= 1): 167 | raise ValueError( 168 | f"cannot use this many conv2d layers ({number_of_conv2d_layers}) with this input " 169 | f"spatial size ({image_size_pixels})" 170 | ) 171 | 172 | conv_layers = [] 173 | 174 | conv_layers += [ 175 | nn.Conv2d( 176 | in_channels=in_channels, 177 | out_channels=conv2d_channels, 178 | kernel_size=3, 179 | padding=0, 180 | ), 181 | nn.ELU(), 182 | ] 183 | for i in range(0, number_of_conv2d_layers - 1): 184 | conv_layers += [ 185 | nn.Conv2d( 186 | in_channels=conv2d_channels, 187 | out_channels=conv2d_channels, 188 | kernel_size=3, 189 | padding=0, 190 | ), 191 | nn.ELU(), 192 | ] 193 | 194 | self.conv_layers = nn.Sequential(*conv_layers) 195 | 196 | self.final_block = nn.Sequential( 197 | nn.Linear( 198 | in_features=(cnn_spatial_output_size**2) * conv2d_channels, 199 | out_features=fc_features, 200 | ), 201 | nn.ELU(), 202 | ) 203 | 204 | def forward(self, x): 205 | """Run model forward""" 206 | batch_size, channel, seq_len, height, width = x.shape 207 | 208 | x = torch.swapaxes(x, 1, 2) 209 | x = x.reshape(batch_size * seq_len, channel, height, width) 210 | 211 | out = self.conv_layers(x) 212 | out = out.reshape(batch_size * seq_len, -1) 213 | 214 | out = self.final_block(out) 215 | out = out.reshape(batch_size, seq_len, -1) 216 | 217 | return out 218 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/encoders/encodersRNN.py: -------------------------------------------------------------------------------- 1 | """Encoder modules for the satellite/NWP data based on recursive and 2D convolutional layers. 2 | """ 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pvnet.models.multimodal.encoders.basic_blocks import ( 8 | AbstractNWPSatelliteEncoder, 9 | ImageSequenceEncoder, 10 | ) 11 | 12 | 13 | class ConvLSTM(AbstractNWPSatelliteEncoder): 14 | """Convolutional LSTM block from MetNet.""" 15 | 16 | def __init__( 17 | self, 18 | sequence_length: int, 19 | image_size_pixels: int, 20 | in_channels: int, 21 | out_features: int, 22 | hidden_channels: int = 32, 23 | num_layers: int = 2, 24 | kernel_size: int = 3, 25 | bias: bool = True, 26 | activation=torch.tanh, 27 | batchnorm=False, 28 | ): 29 | """Convolutional LSTM block from MetNet. 30 | 31 | Args: 32 | sequence_length: The time sequence length of the data. 33 | image_size_pixels: The spatial size of the image. Assumed square. 34 | in_channels: Number of input channels. 35 | out_features: Number of output features. 36 | hidden_channels: Hidden dimension size. 37 | num_layers: Depth of ConvLSTM cells. 38 | kernel_size: Kernel size. 39 | bias: Whether to add bias. 40 | activation: Activation function for ConvLSTM cells. 41 | batchnorm: Whether to use batch norm. 
42 | """ 43 | from metnet.layers.ConvLSTM import ConvLSTM as _ConvLSTM 44 | 45 | super().__init__(sequence_length, image_size_pixels, in_channels, out_features) 46 | 47 | self.conv_lstm = _ConvLSTM( 48 | input_dim=in_channels, 49 | hidden_dim=hidden_channels, 50 | kernel_size=kernel_size, 51 | num_layers=num_layers, 52 | bias=bias, 53 | activation=activation, 54 | batchnorm=batchnorm, 55 | ) 56 | 57 | # Calculate the size of the output of the ConvLSTM network 58 | convlstm_output_size = hidden_channels * image_size_pixels**2 59 | 60 | self.final_block = nn.Sequential( 61 | nn.Linear(in_features=convlstm_output_size, out_features=out_features), 62 | nn.ELU(), 63 | ) 64 | 65 | def forward(self, x): 66 | """Run model forward""" 67 | 68 | batch_size, channel, seq_len, height, width = x.shape 69 | x = torch.swapaxes(x, 1, 2) 70 | 71 | res, _ = self.conv_lstm(x) 72 | 73 | # Select last state only 74 | out = res[:, -1] 75 | 76 | # Flatten and fully connected layer 77 | out = out.reshape(batch_size, -1) 78 | out = self.final_block(out) 79 | 80 | return out 81 | 82 | 83 | class FlattenLSTM(AbstractNWPSatelliteEncoder): 84 | """Convolutional blocks followed by LSTM.""" 85 | 86 | def __init__( 87 | self, 88 | sequence_length: int, 89 | image_size_pixels: int, 90 | in_channels: int, 91 | out_features: int, 92 | num_layers: int = 2, 93 | number_of_conv2d_layers: int = 4, 94 | conv2d_channels: int = 32, 95 | ): 96 | """Network consisting of 2D spatial convolutional and LSTM sequence encoder. 97 | 98 | Args: 99 | sequence_length: The time sequence length of the data. 100 | image_size_pixels: The spatial size of the image. Assumed square. 101 | in_channels: Number of input channels. 102 | out_features: Number of output features. Also used for LSTM hidden dimension. 103 | num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking 104 | two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of 105 | the first LSTM and computing the final results. 106 | number_of_conv2d_layers: Number of convolution 2D layers that are used. 107 | conv2d_channels: Number of channels used in each conv2d layer. 
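            Example (a minimal illustrative sketch; the sizes below are assumed and are not taken
            from any PVNet configuration):
                >>> import torch
                >>> encoder = FlattenLSTM(sequence_length=12, image_size_pixels=24,
                ...                       in_channels=11, out_features=128)
                >>> x = torch.rand(2, 11, 12, 24, 24)  # (batch, channels, time, height, width)
                >>> encoder(x).shape
                torch.Size([2, 128])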
108 | """ 109 | 110 | super().__init__(sequence_length, image_size_pixels, in_channels, out_features) 111 | 112 | self.lstm = nn.LSTM( 113 | input_size=out_features, 114 | hidden_size=out_features, 115 | num_layers=num_layers, 116 | batch_first=True, 117 | ) 118 | 119 | self.encode_image_sequence = ImageSequenceEncoder( 120 | image_size_pixels=image_size_pixels, 121 | in_channels=in_channels, 122 | number_of_conv2d_layers=number_of_conv2d_layers, 123 | conv2d_channels=conv2d_channels, 124 | fc_features=out_features, 125 | ) 126 | 127 | self.final_block = nn.Sequential( 128 | nn.Linear(in_features=out_features, out_features=out_features), 129 | nn.ELU(), 130 | ) 131 | 132 | def forward(self, x): 133 | """Run model forward""" 134 | encoded_images = self.encode_image_sequence(x) 135 | 136 | _, (_, c_n) = self.lstm(encoded_images) 137 | 138 | # Take only the deepest level hidden cell state 139 | out = self.final_block(c_n[-1]) 140 | 141 | return out 142 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/linear_networks/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to combine 1D feature vectors from different sources and make final predictions""" 2 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/linear_networks/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for the linear networks""" 2 | from abc import ABCMeta, abstractmethod 3 | from collections import OrderedDict 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class AbstractLinearNetwork(nn.Module, metaclass=ABCMeta): 10 | """Abstract class for a network to combine the features from all the inputs.""" 11 | 12 | def __init__( 13 | self, 14 | in_features: int, 15 | out_features: int, 16 | ): 17 | """Abstract class for a network to combine the features from all the inputs. 18 | 19 | Args: 20 | in_features: Number of input features. 21 | out_features: Number of output features. 22 | """ 23 | super().__init__() 24 | 25 | def cat_modes(self, x): 26 | """Concatenate modes of input data into 1D feature vector""" 27 | if isinstance(x, OrderedDict): 28 | return torch.cat([value for key, value in x.items()], dim=1) 29 | elif isinstance(x, torch.Tensor): 30 | return x 31 | else: 32 | raise ValueError(f"Input of unexpected type {type(x)}") 33 | 34 | @abstractmethod 35 | def forward(self): 36 | """Run model forward""" 37 | pass 38 | 39 | 40 | class ResidualLinearBlock(nn.Module): 41 | """A 1D fully-connected residual block using ELU activations and including optional dropout.""" 42 | 43 | def __init__( 44 | self, 45 | in_features: int, 46 | n_layers: int = 2, 47 | dropout_frac: float = 0.0, 48 | ): 49 | """A 1D fully-connected residual block using ELU activations and including optional dropout. 50 | 51 | Args: 52 | in_features: Number of input features. 53 | n_layers: Number of layers in residual pathway. 54 | dropout_frac: Probability of an element to be zeroed.
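            Example (illustrative only; the feature and batch sizes are arbitrary assumptions):
                >>> import torch
                >>> block = ResidualLinearBlock(in_features=64)
                >>> x = torch.rand(8, 64)
                >>> block(x).shape  # the residual connection preserves the feature dimension
                torch.Size([8, 64])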
55 | """ 56 | super().__init__() 57 | 58 | layers = [] 59 | for i in range(n_layers): 60 | layers += [ 61 | nn.ELU(), 62 | nn.Linear( 63 | in_features=in_features, 64 | out_features=in_features, 65 | ), 66 | nn.Dropout(p=dropout_frac), 67 | ] 68 | self.model = nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | """Run model forward""" 72 | return self.model(x) + x 73 | 74 | 75 | class ResidualLinearBlock2(nn.Module): 76 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 77 | 78 | This was the best performing residual block tested in the study. This implementation differs 79 | from that block just by using LeakyReLU activation to avoid dead neurons, and by including 80 | optional dropout in the residual branch. This is also a 1D fully connected layer residual block 81 | rather than a 2D convolutional block. 82 | 83 | Sources: 84 | [1] https://arxiv.org/pdf/1603.05027.pdf 85 | """ 86 | 87 | def __init__( 88 | self, 89 | in_features: int, 90 | n_layers: int = 2, 91 | dropout_frac: float = 0.0, 92 | ): 93 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 94 | 95 | Sources: 96 | [1] https://arxiv.org/pdf/1603.05027.pdf 97 | 98 | Args: 99 | in_features: Number of input features. 100 | n_layers: Number of layers in residual pathway. 101 | dropout_frac: Probability of an element to be zeroed. 102 | """ 103 | super().__init__() 104 | 105 | layers = [] 106 | for i in range(n_layers): 107 | layers += [ 108 | nn.BatchNorm1d(in_features), 109 | nn.Dropout(p=dropout_frac), 110 | nn.LeakyReLU(), 111 | nn.Linear( 112 | in_features=in_features, 113 | out_features=in_features, 114 | ), 115 | ] 116 | 117 | self.model = nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | """Run model forward""" 121 | return self.model(x) + x 122 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/readme.md: -------------------------------------------------------------------------------- 1 | ## Multimodal model architecture 2 | 3 | These are fusion models which predict GSP power output based on NWP, non-HRV satellite imagery, GSP output history, solar coordinates, and GSP ID. 4 | 5 | The core model is `multimodal.Model`, and its architecture is shown in the diagram below. 6 | 7 | ![multimodal_model_diagram](https://github.com/openclimatefix/PVNet/assets/41546094/118393fa-52ec-4bfe-a0a3-268c94c25f1e) 8 | 9 | This model uses encoders which take 4D (channel, time, x, y) inputs of NWP and satellite data and encode them into 1D feature vectors. Different encoders are contained inside `encoders`. 10 | 11 | Different choices for the fusion model are contained inside `linear_networks`. 12 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/site_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to encode site-level PV data""" 2 | -------------------------------------------------------------------------------- /pvnet/models/multimodal/site_encoders/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for PV-site encoders""" 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from torch import nn 5 | 6 | 7 | class AbstractSitesEncoder(nn.Module, metaclass=ABCMeta): 8 | """Abstract class for an encoder of output data from multiple PV sites.
9 | 10 | The encoder will take an input of shape (batch_size, sequence_length, num_sites) 11 | and return an output of shape (batch_size, out_features). 12 | """ 13 | 14 | def __init__( 15 | self, 16 | sequence_length: int, 17 | num_sites: int, 18 | out_features: int, 19 | ): 20 | """Abstract class for PV site-level encoder. 21 | 22 | Args: 23 | sequence_length: The time sequence length of the data. 24 | num_sites: Number of PV sites in the input data. 25 | out_features: Number of output features. 26 | """ 27 | super().__init__() 28 | self.sequence_length = sequence_length 29 | self.num_sites = num_sites 30 | self.out_features = out_features 31 | 32 | @abstractmethod 33 | def forward(self): 34 | """Run model forward""" 35 | pass 36 | -------------------------------------------------------------------------------- /pvnet/models/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions""" 2 | 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | 12 | 13 | class PredAccumulator: 14 | """A class for accumulating y-predictions using grad accumulation and small batch size. 15 | 16 | Attributes: 17 | _y_hats (list[torch.Tensor]): List of prediction tensors 18 | """ 19 | 20 | def __init__(self): 21 | """Prediction accumulator""" 22 | self._y_hats = [] 23 | 24 | def __bool__(self): 25 | return len(self._y_hats) > 0 26 | 27 | def append(self, y_hat: torch.Tensor): 28 | """Append a sub-batch of predictions""" 29 | self._y_hats.append(y_hat) 30 | 31 | def flush(self) -> torch.Tensor: 32 | """Return all appended predictions as a single tensor and remove them from the accumulated store.""" 33 | y_hat = torch.cat(self._y_hats, dim=0) 34 | self._y_hats = [] 35 | return y_hat 36 | 37 | 38 | class DictListAccumulator: 39 | """Abstract class for accumulating dictionaries of lists""" 40 | 41 | @staticmethod 42 | def _dict_list_append(d1, d2): 43 | for k, v in d2.items(): 44 | d1[k].append(v) 45 | 46 | @staticmethod 47 | def _dict_init_list(d): 48 | return {k: [v] for k, v in d.items()} 49 | 50 | 51 | class MetricAccumulator(DictListAccumulator): 52 | """Dictionary of metrics accumulator. 53 | 54 | A class for accumulating logging metrics, and finding their mean, when using grad 55 | accumulation and the batch size is small. 56 | 57 | Attributes: 58 | _metrics (Dict[str, list[float]]): Dictionary containing lists of metrics. 59 | """ 60 | 61 | def __init__(self): 62 | """Dictionary of metrics accumulator.""" 63 | self._metrics = {} 64 | 65 | def __bool__(self): 66 | return self._metrics != {} 67 | 68 | def append(self, loss_dict: dict[str, float]): 69 | """Append dictionary of metrics to self""" 70 | if not self: 71 | self._metrics = self._dict_init_list(loss_dict) 72 | else: 73 | self._dict_list_append(self._metrics, loss_dict) 74 | 75 | def flush(self) -> dict[str, float]: 76 | """Calculate mean of all accumulated metrics and clear""" 77 | mean_metrics = {k: np.mean(v) for k, v in self._metrics.items()} 78 | self._metrics = {} 79 | return mean_metrics 80 | 81 | 82 | class BatchAccumulator(DictListAccumulator): 83 | """A class for accumulating batches when using grad accumulation and the batch size is small. 84 | 85 | Attributes: 86 | _batches (Dict[str, list[torch.Tensor]]): Dictionary containing lists of batch tensors.
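    Example (an illustrative sketch; the tensor shapes are assumed, and the keys follow the
    default ``key_to_keep="gsp"``):
        >>> import torch
        >>> accumulator = BatchAccumulator()
        >>> accumulator.append({"gsp": torch.rand(2, 16), "gsp_t0_idx": 3})
        >>> accumulator.append({"gsp": torch.rand(2, 16), "gsp_t0_idx": 3})
        >>> accumulator.flush()["gsp"].shape  # sub-batches are concatenated along the batch axis
        torch.Size([4, 16])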
87 | """ 88 | 89 | def __init__(self, key_to_keep: str = "gsp"): 90 | """Batch accumulator""" 91 | self._batches = {} 92 | self.key_to_keep = key_to_keep 93 | 94 | def __bool__(self): 95 | return self._batches != {} 96 | 97 | # @staticmethod 98 | def _filter_batch_dict(self, d): 99 | keep_keys = [ 100 | self.key_to_keep, 101 | f"{self.key_to_keep}_id", 102 | f"{self.key_to_keep}_t0_idx", 103 | f"{self.key_to_keep}_time_utc", 104 | ] 105 | return {k: v for k, v in d.items() if k in keep_keys} 106 | 107 | def append(self, batch: dict[str, list[torch.Tensor]]): 108 | """Append batch to self""" 109 | if not self: 110 | self._batches = self._dict_init_list(self._filter_batch_dict(batch)) 111 | else: 112 | self._dict_list_append(self._batches, self._filter_batch_dict(batch)) 113 | 114 | def flush(self) -> dict[str, list[torch.Tensor]]: 115 | """Concatenate all accumulated batches, return, and clear self""" 116 | batch = {} 117 | for k, v in self._batches.items(): 118 | if k == f"{self.key_to_keep}_t0_idx": 119 | batch[k] = v[0] 120 | else: 121 | batch[k] = torch.cat(v, dim=0) 122 | self._batches = {} 123 | return batch 124 | -------------------------------------------------------------------------------- /pvnet/optimizers.py: -------------------------------------------------------------------------------- 1 | """Optimizer factory-function classes. 2 | """ 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | import torch 7 | 8 | 9 | class AbstractOptimizer(ABC): 10 | """Abstract class for optimizer 11 | 12 | Optimizer classes will be used by model like: 13 | > OptimizerGenerator = AbstractOptimizer() 14 | > optimizer = OptimizerGenerator(model) 15 | The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s 16 | `configure_optimizers()` method. 
17 | See : 18 | https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers 19 | 20 | """ 21 | 22 | @abstractmethod 23 | def __call__(self): 24 | """Abstract call""" 25 | pass 26 | 27 | 28 | class Adam(AbstractOptimizer): 29 | """Adam optimizer""" 30 | 31 | def __init__(self, lr=0.0005, **kwargs): 32 | """Adam optimizer""" 33 | self.lr = lr 34 | self.kwargs = kwargs 35 | 36 | def __call__(self, model): 37 | """Return optimizer""" 38 | return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs) 39 | 40 | 41 | class AdamW(AbstractOptimizer): 42 | """AdamW optimizer""" 43 | 44 | def __init__(self, lr=0.0005, **kwargs): 45 | """AdamW optimizer""" 46 | self.lr = lr 47 | self.kwargs = kwargs 48 | 49 | def __call__(self, model): 50 | """Return optimizer""" 51 | return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs) 52 | 53 | 54 | def find_submodule_parameters(model, search_modules): 55 | """Finds all parameters within given submodule types 56 | 57 | Args: 58 | model: torch Module to search through 59 | search_modules: List of submodule types to search for 60 | """ 61 | if isinstance(model, search_modules): 62 | return model.parameters() 63 | 64 | children = list(model.children()) 65 | if len(children) == 0: 66 | return [] 67 | else: 68 | params = [] 69 | for c in children: 70 | params += find_submodule_parameters(c, search_modules) 71 | return params 72 | 73 | 74 | def find_other_than_submodule_parameters(model, ignore_modules): 75 | """Finds all parameters not with given submodule types 76 | 77 | Args: 78 | model: torch Module to search through 79 | ignore_modules: List of submodule types to ignore 80 | """ 81 | if isinstance(model, ignore_modules): 82 | return [] 83 | 84 | children = list(model.children()) 85 | if len(children) == 0: 86 | return model.parameters() 87 | else: 88 | params = [] 89 | for c in children: 90 | params += find_other_than_submodule_parameters(c, ignore_modules) 91 | return params 92 | 93 | 94 | class EmbAdamWReduceLROnPlateau(AbstractOptimizer): 95 | """AdamW optimizer and reduce on plateau scheduler""" 96 | 97 | def __init__( 98 | self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs 99 | ): 100 | """AdamW optimizer and reduce on plateau scheduler""" 101 | self.lr = lr 102 | self.weight_decay = weight_decay 103 | self.patience = patience 104 | self.factor = factor 105 | self.threshold = threshold 106 | self.opt_kwargs = opt_kwargs 107 | 108 | def __call__(self, model): 109 | """Return optimizer""" 110 | 111 | search_modules = (torch.nn.Embedding,) 112 | 113 | no_decay = find_submodule_parameters(model, search_modules) 114 | decay = find_other_than_submodule_parameters(model, search_modules) 115 | 116 | optim_groups = [ 117 | {"params": decay, "weight_decay": self.weight_decay}, 118 | {"params": no_decay, "weight_decay": 0.0}, 119 | ] 120 | opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs) 121 | 122 | sch = torch.optim.lr_scheduler.ReduceLROnPlateau( 123 | opt, 124 | factor=self.factor, 125 | patience=self.patience, 126 | threshold=self.threshold, 127 | ) 128 | sch = { 129 | "scheduler": sch, 130 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 131 | } 132 | return [opt], [sch] 133 | 134 | 135 | class AdamWReduceLROnPlateau(AbstractOptimizer): 136 | """AdamW optimizer and reduce on plateau scheduler""" 137 | 138 | def __init__( 139 | self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None, **opt_kwargs 140 | ): 
141 | """AdamW optimizer and reduce on plateau scheduler""" 142 | self._lr = lr 143 | self.patience = patience 144 | self.factor = factor 145 | self.threshold = threshold 146 | self.step_freq = step_freq 147 | self.opt_kwargs = opt_kwargs 148 | 149 | def _call_multi(self, model): 150 | remaining_params = {k: p for k, p in model.named_parameters()} 151 | 152 | group_args = [] 153 | 154 | for key in self._lr.keys(): 155 | if key == "default": 156 | continue 157 | 158 | submodule_params = [] 159 | for param_name in list(remaining_params.keys()): 160 | if param_name.startswith(key): 161 | submodule_params += [remaining_params.pop(param_name)] 162 | 163 | group_args += [{"params": submodule_params, "lr": self._lr[key]}] 164 | 165 | remaining_params = [p for k, p in remaining_params.items()] 166 | group_args += [{"params": remaining_params}] 167 | 168 | opt = torch.optim.AdamW( 169 | group_args, lr=self._lr["default"] if model.lr is None else model.lr, **self.opt_kwargs 170 | ) 171 | sch = { 172 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 173 | opt, 174 | factor=self.factor, 175 | patience=self.patience, 176 | threshold=self.threshold, 177 | ), 178 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 179 | } 180 | 181 | return [opt], [sch] 182 | 183 | def __call__(self, model): 184 | """Return optimizer""" 185 | if not isinstance(self._lr, float): 186 | return self._call_multi(model) 187 | else: 188 | default_lr = self._lr if model.lr is None else model.lr 189 | opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs) 190 | sch = torch.optim.lr_scheduler.ReduceLROnPlateau( 191 | opt, 192 | factor=self.factor, 193 | patience=self.patience, 194 | threshold=self.threshold, 195 | ) 196 | sch = { 197 | "scheduler": sch, 198 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 199 | } 200 | return [opt], [sch] 201 | -------------------------------------------------------------------------------- /pvnet/training.py: -------------------------------------------------------------------------------- 1 | """Training""" 2 | import os 3 | import shutil 4 | from typing import Optional 5 | 6 | import hydra 7 | import torch 8 | from lightning.pytorch import ( 9 | Callback, 10 | LightningDataModule, 11 | LightningModule, 12 | Trainer, 13 | seed_everything, 14 | ) 15 | from lightning.pytorch.callbacks import ModelCheckpoint 16 | from lightning.pytorch.loggers import Logger 17 | from lightning.pytorch.loggers.wandb import WandbLogger 18 | from omegaconf import DictConfig, OmegaConf 19 | 20 | from pvnet import utils 21 | 22 | log = utils.get_logger(__name__) 23 | 24 | torch.set_default_dtype(torch.float32) 25 | 26 | 27 | def _callbacks_to_phase(callbacks, phase): 28 | for c in callbacks: 29 | if hasattr(c, "switch_phase"): 30 | c.switch_phase(phase) 31 | 32 | 33 | def resolve_monitor_loss(output_quantiles): 34 | """Return the desired metric to monitor based on whether quantile regression is being used. 35 | 36 | The adds the option to use something like: 37 | monitor: "${resolve_monitor_loss:${model.output_quantiles}}" 38 | 39 | in early stopping and model checkpoint callbacks so the callbacks config does not need to be 40 | modified depending on whether quantile regression is being used or not. 
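    Example (illustrative quantile values):
        >>> resolve_monitor_loss(None)
        'MAE/val'
        >>> resolve_monitor_loss([0.1, 0.5, 0.9])
        'quantile_loss/val'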
41 | """ 42 | if output_quantiles is None: 43 | return "MAE/val" 44 | else: 45 | return "quantile_loss/val" 46 | 47 | 48 | OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss) 49 | 50 | 51 | def train(config: DictConfig) -> Optional[float]: 52 | """Contains training pipeline. 53 | 54 | Instantiates all PyTorch Lightning objects from config. 55 | 56 | Args: 57 | config (DictConfig): Configuration composed by Hydra. 58 | 59 | Returns: 60 | Optional[float]: Metric score for hyperparameter optimization. 61 | """ 62 | 63 | # Set seed for random number generators in pytorch, numpy and python.random 64 | if "seed" in config: 65 | seed_everything(config.seed, workers=True) 66 | 67 | # Init lightning datamodule 68 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 69 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 70 | 71 | # Init lightning model 72 | log.info(f"Instantiating model <{config.model._target_}>") 73 | model: LightningModule = hydra.utils.instantiate(config.model) 74 | 75 | # Init lightning loggers 76 | loggers: list[Logger] = [] 77 | if "logger" in config: 78 | for _, lg_conf in config.logger.items(): 79 | if "_target_" in lg_conf: 80 | log.info(f"Instantiating logger <{lg_conf._target_}>") 81 | loggers.append(hydra.utils.instantiate(lg_conf)) 82 | 83 | # Init lightning callbacks 84 | callbacks: list[Callback] = [] 85 | if "callbacks" in config: 86 | for _, cb_conf in config.callbacks.items(): 87 | if "_target_" in cb_conf: 88 | log.info(f"Instantiating callback <{cb_conf._target_}>") 89 | callbacks.append(hydra.utils.instantiate(cb_conf)) 90 | 91 | # Align the wandb id with the checkpoint path 92 | # - only works if wandb logger and model checkpoint used 93 | # - this makes it easy to push the model to huggingface 94 | use_wandb_logger = False 95 | for logger in loggers: 96 | log.info(f"{logger}") 97 | if isinstance(logger, WandbLogger): 98 | use_wandb_logger = True 99 | wandb_logger = logger 100 | break 101 | 102 | if use_wandb_logger: 103 | for callback in callbacks: 104 | log.info(f"{callback}") 105 | if isinstance(callback, ModelCheckpoint): 106 | # Need to call the .experiment property to initialise the logger 107 | wandb_logger.experiment 108 | callback.dirpath = "/".join( 109 | callback.dirpath.split("/")[:-1] + [wandb_logger.version] 110 | ) 111 | # Also save model config here - this makes for easy model push to huggingface 112 | os.makedirs(callback.dirpath, exist_ok=True) 113 | OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml") 114 | 115 | # Similarly save the data config 116 | data_config = config.datamodule.configuration 117 | if data_config is None: 118 | # Data config can be none if using presaved batches. 
We go to the presaved 119 | # batches to get the data config 120 | data_config = f"{config.datamodule.sample_dir}/data_configuration.yaml" 121 | 122 | assert os.path.isfile(data_config), f"Data config file not found: {data_config}" 123 | shutil.copyfile(data_config, f"{callback.dirpath}/data_config.yaml") 124 | 125 | # upload configuration up to wandb 126 | OmegaConf.save(config, "./experiment_config.yaml") 127 | wandb_logger.experiment.save( 128 | f"{callback.dirpath}/data_config.yaml", callback.dirpath 129 | ) 130 | wandb_logger.experiment.save("./experiment_config.yaml") 131 | 132 | break 133 | 134 | should_pretrain = False 135 | for c in callbacks: 136 | should_pretrain |= hasattr(c, "training_phase") and c.training_phase == "pretrain" 137 | 138 | if should_pretrain: 139 | _callbacks_to_phase(callbacks, "pretrain") 140 | 141 | trainer: Trainer = hydra.utils.instantiate( 142 | config.trainer, 143 | logger=loggers, 144 | _convert_="partial", 145 | callbacks=callbacks, 146 | ) 147 | 148 | # TODO: remove this option 149 | if should_pretrain: 150 | # Pre-train the model 151 | raise NotImplementedError("Pre-training is not yet supported") 152 | # The parameter `block_nwp_and_sat` is not available in data-sampler 153 | # If pretraining is re-supported in the future it is likely any pre-training logic should 154 | # go here or perhaps in the callbacks 155 | # datamodule.block_nwp_and_sat = True 156 | 157 | trainer.fit(model=model, datamodule=datamodule) 158 | 159 | _callbacks_to_phase(callbacks, "main") 160 | 161 | trainer.should_stop = False 162 | 163 | # Train the model completely 164 | trainer.fit(model=model, datamodule=datamodule) 165 | 166 | # Make sure everything closed properly 167 | log.info("Finalizing!") 168 | utils.finish( 169 | config=config, 170 | model=model, 171 | datamodule=datamodule, 172 | trainer=trainer, 173 | callbacks=callbacks, 174 | loggers=loggers, 175 | ) 176 | 177 | # Print path to best checkpoint 178 | log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") 179 | 180 | # Return metric score for hyperparameter optimization 181 | optimized_metric = config.get("optimized_metric") 182 | if optimized_metric: 183 | return trainer.callback_metrics[optimized_metric] 184 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name="PVNet" 3 | description = "PVNet" 4 | authors = [{name="Peter Dudfield", email="info@openclimatefix.org"}] 5 | dynamic = ["version", "readme"] 6 | license={file="LICENCE"} 7 | 8 | dependencies = [ 9 | "ocf-data-sampler>=0.2.10", 10 | "numpy", 11 | "pandas", 12 | "matplotlib", 13 | "xarray", 14 | "h5netcdf", 15 | "torch>=2.0.0", 16 | "lightning", 17 | "torchvision", 18 | "pytest", 19 | "pytest-cov", 20 | "typer", 21 | "sqlalchemy", 22 | "fsspec[s3]", 23 | "wandb", 24 | "huggingface-hub", 25 | "tqdm", 26 | "omegaconf", 27 | "hydra-core", 28 | "rich", 29 | "einops", 30 | ] 31 | 32 | [tool.setuptools.dynamic] 33 | version = {attr = "pvnet.__version__"} 34 | readme = {file = "README.md", content-type = "text/markdown"} 35 | 36 | [tool.setuptools.package-dir] 37 | "pvnet" = "pvnet" 38 | 39 | [project.optional-dependencies] 40 | dev=[ 41 | "pvlive-api", 42 | "ruff", 43 | "mypy", 44 | "pre-commit", 45 | "pytest", 46 | "pytest-cov", 47 | ] 48 | all_models=[ 49 | "pytorch-tabnet", 50 | "efficientnet_pytorch" 51 | ] 52 | all=["PVNet[dev,all_models]"] 53 | 54 | [tool.mypy] 55 | exclude = [ 56 
| "^tests/", 57 | ] 58 | disallow_untyped_defs = true 59 | disallow_any_unimported = true 60 | no_implicit_optional = true 61 | check_untyped_defs = true 62 | warn_return_any = true 63 | warn_unused_ignores = true 64 | show_error_codes = true 65 | warn_unreachable = true 66 | 67 | [[tool.mypy.overrides]] 68 | module = [ 69 | ] 70 | ignore_missing_imports = true 71 | 72 | [tool.pytest.ini_options] 73 | minversion = "6.0" 74 | addopts = "-ra -q" 75 | testpaths = [ 76 | "tests", 77 | ] 78 | 79 | [tool.ruff] 80 | line-length = 100 81 | exclude = [ 82 | ".ipynb_checkpoints", 83 | "configs.example", 84 | ".bzr", 85 | ".direnv", 86 | ".eggs", 87 | ".git", 88 | ".hg", 89 | ".mypy_cache", 90 | ".nox", 91 | ".pants.d", 92 | ".pytype", 93 | ".ruff_cache", 94 | ".svn", 95 | ".tox", 96 | ".venv", 97 | "__pypackages__", 98 | "_build", 99 | "buck-out", 100 | "build", 101 | "dist", 102 | "node_modules", 103 | "venv", 104 | "tests", 105 | ] 106 | 107 | # Assume Python 3.10. 108 | target-version = "py310" 109 | fix = false 110 | 111 | [tool.ruff.lint] 112 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 113 | # Allow autofix for all enabled rules (when `--fix`) is provided. 114 | fixable = ["A", "B", "C", "D", "E", "F", "I"] 115 | unfixable = [] 116 | select = ["E", "F", "D", "I"] 117 | ignore-init-module-imports = true 118 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. 119 | ignore = ["D200","D202","D210","D212","D415","D105",] 120 | 121 | [tool.ruff.lint.mccabe] 122 | # Unlike Flake8, default to a complexity level of 10. 123 | max-complexity = 10 124 | 125 | [tool.ruff.lint.pydocstyle] 126 | # Use Google-style docstrings. 127 | convention = "google" 128 | 129 | [tool.ruff.lint.per-file-ignores] 130 | "__init__.py" = ["F401", "E402"] 131 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """Run training 2 | """ 3 | 4 | import os 5 | 6 | import torch 7 | 8 | try: 9 | torch.multiprocessing.set_start_method("spawn") 10 | import torch.multiprocessing as mp 11 | 12 | mp.set_start_method("spawn") 13 | except RuntimeError: 14 | pass 15 | 16 | import logging 17 | import sys 18 | 19 | # Tired of seeing these warnings 20 | import warnings 21 | 22 | import hydra 23 | from omegaconf import DictConfig 24 | from sqlalchemy import exc as sa_exc 25 | 26 | warnings.filterwarnings("ignore", category=sa_exc.SAWarning) 27 | 28 | logging.basicConfig(stream=sys.stdout, level=logging.ERROR) 29 | 30 | os.environ["HYDRA_FULL_ERROR"] = "1" 31 | 32 | 33 | # this file can be run for example using 34 | # python run.py experiment=example_simple 35 | 36 | 37 | @hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.2") 38 | def main(config: DictConfig): 39 | """Runs training""" 40 | # Imports should be nested inside @hydra.main to optimize tab completion 41 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 42 | from pvnet.training import train 43 | from pvnet.utils import extras, print_config 44 | 45 | # A couple of optional utilities: 46 | # - disabling python warnings 47 | # - easier access to debug mode 48 | # - forcing debug friendly configuration 49 | # - forcing multi-gpu friendly configuration 50 | # You can safely get rid of this line if you don't want those 51 | extras(config) 52 | 53 | # Pretty print config using Rich library 54 | if config.get("print_config"): 55 | print_config(config, resolve=True) 56 | 57 | # Train model 
58 | return train(config) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /scripts/checkpoint_to_huggingface.py: -------------------------------------------------------------------------------- 1 | """Command line tool to push locally save model checkpoints to huggingface 2 | 3 | use: 4 | python checkpoint_to_huggingface.py "path/to/model/checkpoints" \ 5 | --huggingface-repo="openclimatefix/pvnet_uk_region" \ 6 | --wandb-repo="openclimatefix/pvnet2.1" \ 7 | --local-path="~/tmp/this_model" \ 8 | --no-push-to-hub 9 | """ 10 | 11 | import tempfile 12 | 13 | import typer 14 | import wandb 15 | 16 | from pvnet.load_model import get_model_from_checkpoints 17 | 18 | app = typer.Typer(pretty_exceptions_show_locals=False) 19 | 20 | @app.command() 21 | def push_to_huggingface( 22 | checkpoint_dir_paths: list[str], 23 | huggingface_repo: str = "openclimatefix/pvnet_uk_region", # e.g. openclimatefix/windnet_india 24 | wandb_repo: str = "openclimatefix/pvnet2.1", 25 | val_best: bool = True, 26 | wandb_ids: list[str] = [], 27 | local_path: str = None, 28 | push_to_hub: bool = True, 29 | ): 30 | """Push a local model to a huggingface model repo 31 | 32 | Args: 33 | checkpoint_dir_paths: Path(s) of the checkpoint directory(ies) 34 | huggingface_repo: Name of the HuggingFace repo to push the model to 35 | wandb_repo: Name of the wandb repo which has training logs 36 | val_best: Use best model according to val loss, else last saved model 37 | wandb_ids: The wandb ID code(s) 38 | local_path: Where to save the local copy of the model 39 | push_to_hub: Whether to push the model to the hub or just create local version. 40 | """ 41 | 42 | assert push_to_hub or local_path is not None 43 | 44 | is_ensemble = len(checkpoint_dir_paths) > 1 45 | 46 | # Check if checkpoint dir name is wandb run ID 47 | if wandb_ids == []: 48 | all_wandb_ids = [run.id for run in wandb.Api().runs(path=wandb_repo)] 49 | for path in checkpoint_dir_paths: 50 | dirname = path.split("/")[-1] 51 | if dirname in all_wandb_ids: 52 | wandb_ids.append(dirname) 53 | else: 54 | wandb_ids.append(None) 55 | 56 | model, model_config, data_config = get_model_from_checkpoints(checkpoint_dir_paths, val_best) 57 | 58 | if not is_ensemble: 59 | wandb_ids = wandb_ids[0] 60 | 61 | # Push to hub 62 | if local_path is None: 63 | temp_dir = tempfile.TemporaryDirectory() 64 | model_output_dir = temp_dir.name 65 | else: 66 | model_output_dir = local_path 67 | 68 | model.save_pretrained( 69 | model_output_dir, 70 | config=model_config, 71 | data_config=data_config, 72 | wandb_repo=wandb_repo, 73 | wandb_ids=wandb_ids, 74 | push_to_hub=push_to_hub, 75 | repo_id=huggingface_repo if push_to_hub else None, 76 | ) 77 | 78 | if local_path is None: 79 | temp_dir.cleanup() 80 | 81 | 82 | if __name__ == "__main__": 83 | app() 84 | -------------------------------------------------------------------------------- /scripts/save_concurrent_samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constructs batches where each batch includes all GSPs and only a single timestamp. 3 | 4 | Currently a slightly hacky implementation due to the way the configs are done. This script will use 5 | the same config file currently set to train the model. In the datamodule config it is possible 6 | to set the batch_output_dir and number of train/val batches, they can also be overriden in the 7 | command as shown in the example below. 
8 | 9 | use: 10 | ``` 11 | python save_concurrent_samples.py \ 12 | +datamodule.sample_output_dir="/mnt/disks/concurrent_batches/concurrent_samples_sat_pred_test" \ 13 | +datamodule.num_train_samples=20 \ 14 | +datamodule.num_val_samples=20 15 | ``` 16 | 17 | """ 18 | # Ensure this block of code runs only in the main process to avoid issues with worker processes. 19 | if __name__ == "__main__": 20 | import torch.multiprocessing as mp 21 | 22 | # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be 23 | # compatible with dask's multiprocessing. 24 | mp.set_start_method("forkserver") 25 | 26 | # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is 27 | # important because libraries like Zarr may open many files, which can exhaust the file 28 | # descriptor limit if too many workers are used. 29 | mp.set_sharing_strategy("file_system") 30 | 31 | 32 | import logging 33 | import os 34 | import shutil 35 | import sys 36 | import warnings 37 | 38 | import hydra 39 | import numpy as np 40 | import torch 41 | from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset 42 | from omegaconf import DictConfig, OmegaConf 43 | from sqlalchemy import exc as sa_exc 44 | from torch.utils.data import DataLoader, Dataset 45 | from tqdm import tqdm 46 | 47 | from pvnet.utils import print_config 48 | 49 | # ------- filter warning and set up config ------- 50 | 51 | warnings.filterwarnings("ignore", category=sa_exc.SAWarning) 52 | 53 | logger = logging.getLogger(__name__) 54 | 55 | logging.basicConfig(stream=sys.stdout, level=logging.ERROR) 56 | 57 | # ------------------------------------------------- 58 | 59 | 60 | class SaveFuncFactory: 61 | """Factory for creating a function to save a sample to disk.""" 62 | 63 | def __init__(self, save_dir: str): 64 | """Factory for creating a function to save a sample to disk.""" 65 | self.save_dir = save_dir 66 | 67 | def __call__(self, sample, sample_num: int): 68 | """Save a sample to disk""" 69 | torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt") 70 | 71 | 72 | def save_samples_with_dataloader( 73 | dataset: Dataset, 74 | save_dir: str, 75 | num_samples: int, 76 | dataloader_kwargs: dict, 77 | ) -> None: 78 | """Save samples from a dataset using a dataloader.""" 79 | save_func = SaveFuncFactory(save_dir) 80 | 81 | gsp_ids = np.array([loc.id for loc in dataset.locations]) 82 | 83 | dataloader = DataLoader(dataset, **dataloader_kwargs) 84 | 85 | pbar = tqdm(total=num_samples) 86 | for i, sample in zip(range(num_samples), dataloader): 87 | check_sample(sample, gsp_ids) 88 | save_func(sample, i) 89 | pbar.update() 90 | pbar.close() 91 | 92 | 93 | def check_sample(sample, gsp_ids): 94 | """Check if sample is valid concurrent batch for all GSPs""" 95 | # Check all GSP IDs are included and in correct order 96 | assert (sample["gsp_id"].flatten().numpy() == gsp_ids).all() 97 | # Check all times are the same 98 | assert len(np.unique(sample["gsp_time_utc"][:, 0].numpy())) == 1 99 | 100 | 101 | @hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2") 102 | def main(config: DictConfig) -> None: 103 | """Constructs and saves validation and training samples.""" 104 | config_dm = config.datamodule 105 | 106 | print_config(config, resolve=False) 107 | 108 | # Set up directory 109 | os.makedirs(config_dm.sample_output_dir, exist_ok=False) 110 | 111 | # Copy across configs which define the samples into the new sample directory 112 | with 
open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f: 113 | f.write(OmegaConf.to_yaml(config_dm)) 114 | 115 | shutil.copyfile( 116 | config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml" 117 | ) 118 | 119 | # Define the keywargs going into the train and val dataloaders 120 | dataloader_kwargs = dict( 121 | shuffle=True, 122 | batch_size=None, 123 | sampler=None, 124 | batch_sampler=None, 125 | num_workers=config_dm.num_workers, 126 | collate_fn=None, 127 | pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial 128 | drop_last=False, 129 | timeout=0, 130 | worker_init_fn=None, 131 | prefetch_factor=config_dm.prefetch_factor, 132 | persistent_workers=False, # Not needed since we only enter the dataloader loop once 133 | ) 134 | 135 | if config_dm.num_val_samples > 0: 136 | print("----- Saving val samples -----") 137 | 138 | val_output_dir = f"{config_dm.sample_output_dir}/val" 139 | 140 | # Make directory for val samples 141 | os.mkdir(val_output_dir) 142 | 143 | # Get the dataset 144 | val_dataset = PVNetUKConcurrentDataset( 145 | config_dm.configuration, 146 | start_time=config_dm.val_period[0], 147 | end_time=config_dm.val_period[1], 148 | ) 149 | 150 | # Save samples 151 | save_samples_with_dataloader( 152 | dataset=val_dataset, 153 | save_dir=val_output_dir, 154 | num_samples=config_dm.num_val_samples, 155 | dataloader_kwargs=dataloader_kwargs, 156 | ) 157 | 158 | del val_dataset 159 | 160 | if config_dm.num_train_samples > 0: 161 | print("----- Saving train samples -----") 162 | 163 | train_output_dir = f"{config_dm.sample_output_dir}/train" 164 | 165 | # Make directory for train samples 166 | os.mkdir(train_output_dir) 167 | 168 | # Get the dataset 169 | train_dataset = PVNetUKConcurrentDataset( 170 | config_dm.configuration, 171 | start_time=config_dm.train_period[0], 172 | end_time=config_dm.train_period[1], 173 | ) 174 | 175 | # Save samples 176 | save_samples_with_dataloader( 177 | dataset=train_dataset, 178 | save_dir=train_output_dir, 179 | num_samples=config_dm.num_train_samples, 180 | dataloader_kwargs=dataloader_kwargs, 181 | ) 182 | 183 | del train_dataset 184 | 185 | print("----- Saving complete -----") 186 | 187 | 188 | if __name__ == "__main__": 189 | main() 190 | -------------------------------------------------------------------------------- /scripts/save_samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constructs samples and saves them to disk. 3 | 4 | Currently a slightly hacky implementation due to the way the configs are done. This script will use 5 | the same config file currently set to train the model. 6 | 7 | use: 8 | ``` 9 | python save_samples.py 10 | ``` 11 | if setting all values in the datamodule config file, or 12 | 13 | ``` 14 | python save_samples.py \ 15 | +datamodule.sample_output_dir="/mnt/disks/bigbatches/samples_v0" \ 16 | +datamodule.num_train_samples=0 \ 17 | +datamodule.num_val_samples=2 \ 18 | datamodule.num_workers=2 \ 19 | datamodule.prefetch_factor=2 20 | ``` 21 | if wanting to override these values for example 22 | """ 23 | 24 | # Ensure this block of code runs only in the main process to avoid issues with worker processes. 25 | if __name__ == "__main__": 26 | import torch.multiprocessing as mp 27 | 28 | # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be 29 | # compatible with dask's multiprocessing. 
30 | mp.set_start_method("forkserver") 31 | 32 | # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is 33 | # important because libraries like Zarr may open many files, which can exhaust the file 34 | # descriptor limit if too many workers are used. 35 | mp.set_sharing_strategy("file_system") 36 | 37 | 38 | import logging 39 | import os 40 | import shutil 41 | import sys 42 | import warnings 43 | 44 | import dask 45 | import hydra 46 | from ocf_data_sampler.torch_datasets.datasets import PVNetUKRegionalDataset, SitesDataset 47 | from ocf_data_sampler.torch_datasets.sample.site import SiteSample 48 | from ocf_data_sampler.torch_datasets.sample.uk_regional import UKRegionalSample 49 | from omegaconf import DictConfig, OmegaConf 50 | from sqlalchemy import exc as sa_exc 51 | from torch.utils.data import DataLoader, Dataset 52 | from tqdm import tqdm 53 | 54 | from pvnet.utils import print_config 55 | 56 | dask.config.set(scheduler="threads", num_workers=4) 57 | 58 | 59 | # ------- filter warning and set up config ------- 60 | 61 | warnings.filterwarnings("ignore", category=sa_exc.SAWarning) 62 | 63 | logger = logging.getLogger(__name__) 64 | 65 | logging.basicConfig(stream=sys.stdout, level=logging.ERROR) 66 | 67 | # ------------------------------------------------- 68 | 69 | 70 | class SaveFuncFactory: 71 | """Factory for creating a function to save a sample to disk.""" 72 | 73 | def __init__(self, save_dir: str, renewable: str = "pv_uk"): 74 | """Factory for creating a function to save a sample to disk.""" 75 | self.save_dir = save_dir 76 | self.renewable = renewable 77 | 78 | def __call__(self, sample, sample_num: int): 79 | """Save a sample to disk""" 80 | save_path = f"{self.save_dir}/{sample_num:08}" 81 | 82 | if self.renewable == "pv_uk": 83 | sample_class = UKRegionalSample(sample) 84 | filename = f"{save_path}.pt" 85 | elif self.renewable == "site": 86 | sample_class = SiteSample(sample) 87 | filename = f"{save_path}.nc" 88 | else: 89 | raise ValueError(f"Unknown renewable: {self.renewable}") 90 | # Assign data and save 91 | sample_class._data = sample 92 | sample_class.save(filename) 93 | 94 | 95 | def get_dataset( 96 | config_path: str, start_time: str, end_time: str, renewable: str = "pv_uk" 97 | ) -> Dataset: 98 | """Get the dataset for the given renewable type.""" 99 | if renewable == "pv_uk": 100 | dataset_cls = PVNetUKRegionalDataset 101 | elif renewable == "site": 102 | dataset_cls = SitesDataset 103 | else: 104 | raise ValueError(f"Unknown renewable: {renewable}") 105 | 106 | return dataset_cls(config_path, start_time=start_time, end_time=end_time) 107 | 108 | 109 | def save_samples_with_dataloader( 110 | dataset: Dataset, 111 | save_dir: str, 112 | num_samples: int, 113 | dataloader_kwargs: dict, 114 | renewable: str = "pv_uk", 115 | ) -> None: 116 | """Save samples from a dataset using a dataloader.""" 117 | save_func = SaveFuncFactory(save_dir, renewable=renewable) 118 | 119 | dataloader = DataLoader(dataset, **dataloader_kwargs) 120 | 121 | pbar = tqdm(total=num_samples) 122 | for i, sample in zip(range(num_samples), dataloader): 123 | save_func(sample, i) 124 | pbar.update() 125 | pbar.close() 126 | 127 | 128 | @hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2") 129 | def main(config: DictConfig) -> None: 130 | """Constructs and saves validation and training samples.""" 131 | config_dm = config.datamodule 132 | 133 | print_config(config, resolve=False) 134 | 135 | # Set up directory 136 | 
os.makedirs(config_dm.sample_output_dir, exist_ok=False) 137 | 138 | # Copy across configs which define the samples into the new sample directory 139 | with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f: 140 | f.write(OmegaConf.to_yaml(config_dm)) 141 | 142 | shutil.copyfile( 143 | config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml" 144 | ) 145 | 146 | # Define the keywargs going into the train and val dataloaders 147 | dataloader_kwargs = dict( 148 | shuffle=True, 149 | batch_size=None, 150 | sampler=None, 151 | batch_sampler=None, 152 | num_workers=config_dm.num_workers, 153 | collate_fn=None, 154 | pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial 155 | drop_last=False, 156 | timeout=0, 157 | worker_init_fn=None, 158 | prefetch_factor=config_dm.prefetch_factor, 159 | persistent_workers=False, # Not needed since we only enter the dataloader loop once 160 | ) 161 | 162 | if config_dm.num_val_samples > 0: 163 | print("----- Saving val samples -----") 164 | 165 | val_output_dir = f"{config_dm.sample_output_dir}/val" 166 | 167 | # Make directory for val samples 168 | os.mkdir(val_output_dir) 169 | 170 | # Get the dataset 171 | val_dataset = get_dataset( 172 | config_dm.configuration, 173 | *config_dm.val_period, 174 | renewable=config.renewable, 175 | ) 176 | 177 | # Save samples 178 | save_samples_with_dataloader( 179 | dataset=val_dataset, 180 | save_dir=val_output_dir, 181 | num_samples=config_dm.num_val_samples, 182 | dataloader_kwargs=dataloader_kwargs, 183 | renewable=config.renewable, 184 | ) 185 | 186 | del val_dataset 187 | 188 | if config_dm.num_train_samples > 0: 189 | print("----- Saving train samples -----") 190 | 191 | train_output_dir = f"{config_dm.sample_output_dir}/train" 192 | 193 | # Make directory for train samples 194 | os.mkdir(train_output_dir) 195 | 196 | # Get the dataset 197 | train_dataset = get_dataset( 198 | config_dm.configuration, 199 | *config_dm.train_period, 200 | renewable=config.renewable, 201 | ) 202 | 203 | # Save samples 204 | save_samples_with_dataloader( 205 | dataset=train_dataset, 206 | save_dir=train_output_dir, 207 | num_samples=config_dm.num_train_samples, 208 | dataloader_kwargs=dataloader_kwargs, 209 | renewable=config.renewable, 210 | ) 211 | 212 | del train_dataset 213 | 214 | print("----- Saving complete -----") 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/PVNet/a241a2b9bd09f02f91c06559737d8e8dd77194b0/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/test_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import pytest 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | import xarray as xr 8 | from pvnet.data import DataModule, SiteDataModule 9 | 10 | @pytest.fixture 11 | def temp_pt_sample_dir(): 12 | """Create temporary directory with synthetic PT samples""" 13 | with tempfile.TemporaryDirectory() as tmpdirname: 14 | # Create train and val directories 15 | os.makedirs(f"{tmpdirname}/train", exist_ok=True) 16 | os.makedirs(f"{tmpdirname}/val", exist_ok=True) 17 | 18 | # Generate and save synthetic samples 19 | for i in range(5): 20 | sample = { 21 | "gsp": 
torch.rand(21), 22 | "gsp_time_utc": torch.tensor(list(range(21))), 23 | "gsp_nominal_capacity_mwp": torch.tensor(100.0), 24 | "gsp_id": 12 25 | } 26 | torch.save(sample, f"{tmpdirname}/train/{i:08d}.pt") 27 | torch.save(sample, f"{tmpdirname}/val/{i:08d}.pt") 28 | 29 | yield tmpdirname 30 | 31 | 32 | @pytest.fixture 33 | def temp_nc_sample_dir(): 34 | """Create temporary directory with synthetic NC site samples""" 35 | with tempfile.TemporaryDirectory() as tmpdirname: 36 | # Create train and val directories 37 | os.makedirs(f"{tmpdirname}/train", exist_ok=True) 38 | os.makedirs(f"{tmpdirname}/val", exist_ok=True) 39 | 40 | # Create config file 41 | config_path = f"{tmpdirname}/data_configuration.yaml" 42 | with open(config_path, "w") as f: 43 | f.write(f"sample_dir: {tmpdirname}\n") 44 | 45 | # Generate and save synthetic site samples 46 | for i in range(5): 47 | site_time = pd.date_range("2023-01-01", periods=10, freq="15min") 48 | ds = xr.Dataset( 49 | data_vars={ 50 | "site": (["site__time_utc"], np.random.rand(10)), 51 | }, 52 | coords={ 53 | "site__time_utc": site_time, 54 | "site__site_id": np.int32(i % 3 + 1), 55 | "site__latitude": 52.5, 56 | "site__longitude": -1.5, 57 | "site__capacity_kwp": 10000.0, 58 | } 59 | ) 60 | 61 | ds.to_netcdf(f"{tmpdirname}/train/{i:08d}.nc", mode="w", engine="h5netcdf") 62 | ds.to_netcdf(f"{tmpdirname}/val/{i:08d}.nc", mode="w", engine="h5netcdf") 63 | 64 | yield tmpdirname 65 | 66 | 67 | def test_init(temp_pt_sample_dir): 68 | """Test DataModule initialization""" 69 | dm = DataModule( 70 | configuration=None, 71 | sample_dir=temp_pt_sample_dir, 72 | batch_size=2, 73 | num_workers=0, 74 | prefetch_factor=None, 75 | train_period=[None, None], 76 | val_period=[None, None], 77 | ) 78 | 79 | # Verify datamodule initialisation 80 | assert dm is not None 81 | assert hasattr(dm, "train_dataloader") 82 | 83 | 84 | def test_iter(temp_pt_sample_dir): 85 | """Test iteration through DataModule""" 86 | dm = DataModule( 87 | configuration=None, 88 | sample_dir=temp_pt_sample_dir, 89 | batch_size=2, 90 | num_workers=0, 91 | prefetch_factor=None, 92 | train_period=[None, None], 93 | val_period=[None, None], 94 | ) 95 | 96 | # Verify existing keys 97 | batch = next(iter(dm.train_dataloader())) 98 | assert batch is not None 99 | assert "gsp" in batch 100 | 101 | 102 | def test_iter_multiprocessing(temp_pt_sample_dir): 103 | """Test DataModule with multiple workers""" 104 | dm = DataModule( 105 | configuration=None, 106 | sample_dir=temp_pt_sample_dir, 107 | batch_size=1, 108 | num_workers=2, 109 | prefetch_factor=1, 110 | train_period=[None, None], 111 | val_period=[None, None], 112 | ) 113 | 114 | served_batches = 0 115 | for batch in dm.train_dataloader(): 116 | served_batches += 1 117 | 118 | if served_batches == 2: 119 | break 120 | 121 | # Batch verification 122 | assert served_batches == 2 123 | 124 | 125 | def test_site_init_sample_dir(temp_nc_sample_dir): 126 | """Test SiteDataModule initialization with sample dir""" 127 | dm = SiteDataModule( 128 | configuration=None, 129 | sample_dir=temp_nc_sample_dir, 130 | batch_size=2, 131 | num_workers=0, 132 | prefetch_factor=None, 133 | train_period=[None, None], 134 | val_period=[None, None], 135 | ) 136 | 137 | # Verify datamodule initialisation 138 | assert dm is not None 139 | assert hasattr(dm, "train_dataloader") 140 | 141 | 142 | def test_site_init_config(temp_nc_sample_dir): 143 | """Test SiteDataModule initialization with config file""" 144 | config_path = f"{temp_nc_sample_dir}/data_configuration.yaml" 145 | 
146 | dm = SiteDataModule( 147 | configuration=config_path, 148 | batch_size=2, 149 | num_workers=0, 150 | prefetch_factor=None, 151 | train_period=[None, None], 152 | val_period=[None, None], 153 | sample_dir=None, 154 | ) 155 | 156 | # Verify datamodule initialisation w/ config 157 | assert dm is not None 158 | assert hasattr(dm, "train_dataloader") 159 | -------------------------------------------------------------------------------- /tests/models/baseline/test_last_value.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.baseline.last_value import Model 2 | import pytest 3 | 4 | 5 | @pytest.fixture() 6 | def last_value_model(model_minutes_kwargs): 7 | model = Model(**model_minutes_kwargs) 8 | return model 9 | 10 | 11 | def test_model_forward(last_value_model, sample_batch): 12 | y = last_value_model(sample_batch) 13 | 14 | # check output is the correct shape 15 | # batch size=2, forecast_len=15 16 | assert tuple(y.shape) == (2, 16), y.shape 17 | -------------------------------------------------------------------------------- /tests/models/baseline/test_single_value.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.baseline.single_value import Model 2 | import pytest 3 | 4 | 5 | @pytest.fixture() 6 | def single_value_model(model_minutes_kwargs): 7 | model = Model(**model_minutes_kwargs) 8 | return model 9 | 10 | 11 | def test_model_forward(single_value_model, sample_batch): 12 | y = single_value_model(sample_batch) 13 | 14 | # check output is the correct shape 15 | # batch size=2, forecast_len=15 16 | assert tuple(y.shape) == (2, 16), y.shape 17 | -------------------------------------------------------------------------------- /tests/models/multimodal/encoders/test_encoders2d.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.multimodal.encoders.encoders2d import ( 2 | NaiveEfficientNet, 3 | NaiveResNet, 4 | NaiveConvNeXt, 5 | CNBlockConfig, 6 | ) 7 | import pytest 8 | 9 | 10 | @pytest.fixture() 11 | def convnext_model_kwargs(encoder_model_kwargs): 12 | model_kwargs = {k: v for k, v in encoder_model_kwargs.items()} 13 | model_kwargs["block_setting"] = [ 14 | CNBlockConfig(96, 192, 3), 15 | CNBlockConfig(192, 384, 3), 16 | CNBlockConfig(384, 768, 9), 17 | CNBlockConfig(768, None, 3), 18 | ] 19 | return model_kwargs 20 | 21 | 22 | def _test_model_forward(batch, model_class, model_kwargs): 23 | model = model_class(**model_kwargs) 24 | y = model(batch) 25 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 26 | 27 | 28 | def _test_model_backward(batch, model_class, model_kwargs): 29 | model = model_class(**model_kwargs) 30 | y = model(batch) 31 | # Backwards on sum drives sum to zero 32 | y.sum().backward() 33 | 34 | 35 | # Test model forward on all models 36 | def test_naiveefficientnet_forward(sample_satellite_batch, encoder_model_kwargs): 37 | # Skip if optional dependency not installed 38 | pytest.importorskip("efficientnet_pytorch") 39 | _test_model_forward(sample_satellite_batch, NaiveEfficientNet, encoder_model_kwargs) 40 | 41 | 42 | def test_naiveresnet_forward(sample_satellite_batch, encoder_model_kwargs): 43 | _test_model_forward(sample_satellite_batch, NaiveResNet, encoder_model_kwargs) 44 | 45 | 46 | def test_convnext_forward(sample_satellite_batch, convnext_model_kwargs): 47 | _test_model_forward(sample_satellite_batch, NaiveConvNeXt, convnext_model_kwargs) 48 | 49 | 50 | # Test model backward 
on all models 51 | def test_naiveefficientnet_backward(sample_satellite_batch, encoder_model_kwargs): 52 | # Skip if optional dependency not installed 53 | pytest.importorskip("efficientnet_pytorch") 54 | _test_model_backward(sample_satellite_batch, NaiveEfficientNet, encoder_model_kwargs) 55 | 56 | 57 | def test_naiveresnet_backward(sample_satellite_batch, encoder_model_kwargs): 58 | _test_model_backward(sample_satellite_batch, NaiveResNet, encoder_model_kwargs) 59 | 60 | 61 | def test_convnext_backward(sample_satellite_batch, convnext_model_kwargs): 62 | _test_model_backward(sample_satellite_batch, NaiveConvNeXt, convnext_model_kwargs) 63 | -------------------------------------------------------------------------------- /tests/models/multimodal/encoders/test_encoders3d.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.multimodal.encoders.encoders3d import ( 2 | DefaultPVNet, 3 | DefaultPVNet2, 4 | ResConv3DNet2, 5 | EncoderUNET, 6 | ) 7 | import pytest 8 | 9 | 10 | def _test_model_forward(batch, model_class, model_kwargs): 11 | model = model_class(**model_kwargs) 12 | y = model(batch) 13 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 14 | 15 | 16 | def _test_model_backward(batch, model_class, model_kwargs): 17 | model = model_class(**model_kwargs) 18 | y = model(batch) 19 | # Backwards on sum drives sum to zero 20 | y.sum().backward() 21 | 22 | 23 | # Test model forward on all models 24 | def test_defaultpvnet_forward(sample_satellite_batch, encoder_model_kwargs): 25 | _test_model_forward(sample_satellite_batch, DefaultPVNet, encoder_model_kwargs) 26 | 27 | 28 | def test_defaultpvnet2_forward(sample_satellite_batch, encoder_model_kwargs): 29 | _test_model_forward(sample_satellite_batch, DefaultPVNet2, encoder_model_kwargs) 30 | 31 | 32 | def test_resconv3dnet2_forward(sample_satellite_batch, encoder_model_kwargs): 33 | _test_model_forward(sample_satellite_batch, ResConv3DNet2, encoder_model_kwargs) 34 | 35 | 36 | def test_encoderunet_forward(sample_satellite_batch, encoder_model_kwargs): 37 | _test_model_forward(sample_satellite_batch, EncoderUNET, encoder_model_kwargs) 38 | 39 | 40 | # Test model backward on all models 41 | def test_defaultpvnet_backward(sample_satellite_batch, encoder_model_kwargs): 42 | _test_model_backward(sample_satellite_batch, DefaultPVNet, encoder_model_kwargs) 43 | 44 | 45 | def test_defaultpvnet2_backward(sample_satellite_batch, encoder_model_kwargs): 46 | _test_model_backward(sample_satellite_batch, DefaultPVNet2, encoder_model_kwargs) 47 | 48 | 49 | def test_resconv3dnet2_backward(sample_satellite_batch, encoder_model_kwargs): 50 | _test_model_backward(sample_satellite_batch, ResConv3DNet2, encoder_model_kwargs) 51 | 52 | 53 | def test_encoderunet_backward(sample_satellite_batch, encoder_model_kwargs): 54 | _test_model_backward(sample_satellite_batch, EncoderUNET, encoder_model_kwargs) 55 | -------------------------------------------------------------------------------- /tests/models/multimodal/encoders/test_encodersRNN.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.multimodal.encoders.encodersRNN import ( 2 | ConvLSTM, 3 | FlattenLSTM, 4 | ) 5 | import pytest 6 | 7 | 8 | def _test_model_forward(batch, model_class, model_kwargs): 9 | model = model_class(**model_kwargs) 10 | y = model(batch) 11 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 12 | 13 | 14 | def _test_model_backward(batch, model_class, 
model_kwargs): 15 | model = model_class(**model_kwargs) 16 | y = model(batch) 17 | # Backwards on sum drives sum to zero 18 | y.sum().backward() 19 | 20 | 21 | # Test model forward on all models 22 | def test_convlstm_forward(sample_satellite_batch, encoder_model_kwargs): 23 | # Skip if optional dependency not installed 24 | pytest.importorskip("metnet") 25 | _test_model_forward(sample_satellite_batch, ConvLSTM, encoder_model_kwargs) 26 | 27 | 28 | def test_flattenlstm_forward(sample_satellite_batch, encoder_model_kwargs): 29 | _test_model_forward(sample_satellite_batch, FlattenLSTM, encoder_model_kwargs) 30 | 31 | 32 | # Test model backward on all models 33 | def test_convlstm_backward(sample_satellite_batch, encoder_model_kwargs): 34 | # Skip if optional dependency not installed 35 | pytest.importorskip("metnet") 36 | _test_model_backward(sample_satellite_batch, ConvLSTM, encoder_model_kwargs) 37 | 38 | 39 | def test_flattenlstm_backward(sample_satellite_batch, encoder_model_kwargs): 40 | _test_model_backward(sample_satellite_batch, FlattenLSTM, encoder_model_kwargs) 41 | -------------------------------------------------------------------------------- /tests/models/multimodal/linear_networks/test_networks.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.multimodal.linear_networks.networks import ( 2 | DefaultFCNet, 3 | ResFCNet, 4 | ResFCNet2, 5 | SNN, 6 | TabNet, 7 | ) 8 | import pytest 9 | import torch 10 | from collections import OrderedDict 11 | 12 | 13 | @pytest.fixture() 14 | def simple_linear_batch(): 15 | return torch.rand(2, 100) 16 | 17 | 18 | @pytest.fixture() 19 | def multimodal_linear_batch(): 20 | return OrderedDict(nwp=torch.rand(2, 50), sat=torch.rand(2, 40), sun=torch.rand(2, 10)) 21 | 22 | 23 | @pytest.fixture() 24 | def multiple_batch_types(simple_linear_batch, multimodal_linear_batch): 25 | return [simple_linear_batch, multimodal_linear_batch] 26 | 27 | 28 | @pytest.fixture() 29 | def fc_batch_batch(): 30 | return torch.rand(2, 100) 31 | 32 | 33 | @pytest.fixture() 34 | def linear_network_kwargs(): 35 | kwargs = dict(in_features=100, out_features=10) 36 | return kwargs 37 | 38 | 39 | def _test_model_forward(batches, model_class, model_kwargs): 40 | for batch in batches: 41 | model = model_class(**model_kwargs) 42 | y = model(batch) 43 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 44 | 45 | 46 | def _test_model_backward(batch, model_class, model_kwargs): 47 | model = model_class(**model_kwargs) 48 | y = model(batch) 49 | # Backwards on sum drives sum to zero 50 | y.sum().backward() 51 | 52 | 53 | # Test model forward on all models 54 | def test_defaultfcnet_forward(multiple_batch_types, linear_network_kwargs): 55 | _test_model_forward(multiple_batch_types, DefaultFCNet, linear_network_kwargs) 56 | 57 | 58 | def test_resfcnet_forward(multiple_batch_types, linear_network_kwargs): 59 | _test_model_forward(multiple_batch_types, ResFCNet, linear_network_kwargs) 60 | 61 | 62 | def test_resfcnet2_forward(multiple_batch_types, linear_network_kwargs): 63 | _test_model_forward(multiple_batch_types, ResFCNet2, linear_network_kwargs) 64 | 65 | 66 | def test_snn_forward(multiple_batch_types, linear_network_kwargs): 67 | _test_model_forward(multiple_batch_types, SNN, linear_network_kwargs) 68 | 69 | 70 | def test_tabnet_forward(multiple_batch_types, linear_network_kwargs): 71 | # Skip if optional dependency not installed 72 | pytest.importorskip("pytorch_tabnet") 73 | 
_test_model_forward(multiple_batch_types, TabNet, linear_network_kwargs) 74 | 75 | 76 | # Test model backward on all models 77 | def test_defaultfcnet_backward(simple_linear_batch, linear_network_kwargs): 78 | _test_model_backward(simple_linear_batch, DefaultFCNet, linear_network_kwargs) 79 | 80 | 81 | def test_resfcnet_backward(simple_linear_batch, linear_network_kwargs): 82 | _test_model_backward(simple_linear_batch, ResFCNet, linear_network_kwargs) 83 | 84 | 85 | def test_resfcnet2_backward(simple_linear_batch, linear_network_kwargs): 86 | _test_model_backward(simple_linear_batch, ResFCNet2, linear_network_kwargs) 87 | 88 | 89 | def test_snn_backward(simple_linear_batch, linear_network_kwargs): 90 | _test_model_backward(simple_linear_batch, SNN, linear_network_kwargs) 91 | 92 | 93 | def test_tabnet_backward(simple_linear_batch, linear_network_kwargs): 94 | # Skip if optional dependency not installed 95 | pytest.importorskip("pytorch_tabnet") 96 | _test_model_backward(simple_linear_batch, TabNet, linear_network_kwargs) 97 | -------------------------------------------------------------------------------- /tests/models/multimodal/site_encoders/test_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from pvnet.models.multimodal.site_encoders.encoders import ( 5 | SimpleLearnedAggregator, 6 | SingleAttentionNetwork, 7 | ) 8 | 9 | import pytest 10 | 11 | 12 | def _test_model_forward(batch, model_class, kwargs, batch_size): 13 | model = model_class(**kwargs) 14 | y = model(batch) 15 | assert tuple(y.shape) == (batch_size, kwargs["out_features"]), y.shape 16 | 17 | 18 | def _test_model_backward(batch, model_class, kwargs): 19 | model = model_class(**kwargs) 20 | y = model(batch) 21 | # Backwards on sum drives sum to zero 22 | y.sum().backward() 23 | 24 | 25 | # Test model forward on all models 26 | def test_simplelearnedaggregator_forward(sample_pv_batch, site_encoder_model_kwargs): 27 | _test_model_forward( 28 | sample_pv_batch, 29 | SimpleLearnedAggregator, 30 | site_encoder_model_kwargs, 31 | batch_size=8, 32 | ) 33 | 34 | 35 | def test_singleattentionnetwork_forward(sample_site_batch, site_encoder_model_kwargs_dsampler): 36 | _test_model_forward( 37 | sample_site_batch, 38 | SingleAttentionNetwork, 39 | site_encoder_model_kwargs_dsampler, 40 | batch_size=2, 41 | ) 42 | 43 | 44 | # TODO: once we have test data which includes sensor data with sites, include this test 45 | # def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): 46 | # _test_model_forward( 47 | # sample_wind_batch, 48 | # SingleAttentionNetwork, 49 | # site_encoder_sensor_model_kwargs, 50 | # batch_size=2, 51 | # ) 52 | 53 | 54 | # Test model backward on all models 55 | def test_simplelearnedaggregator_backward(sample_pv_batch, site_encoder_model_kwargs): 56 | _test_model_backward(sample_pv_batch, SimpleLearnedAggregator, site_encoder_model_kwargs) 57 | 58 | 59 | def test_singleattentionnetwork_backward(sample_site_batch, site_encoder_model_kwargs_dsampler): 60 | _test_model_backward( 61 | sample_site_batch, SingleAttentionNetwork, site_encoder_model_kwargs_dsampler 62 | ) 63 | -------------------------------------------------------------------------------- /tests/models/multimodal/test_multimodal.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD 2 | 3 | 4 | def test_model_forward(multimodal_model, sample_batch): 5 | y =
multimodal_model(sample_batch) 6 | 7 | # check output is the correct shape 8 | # batch size=2, forecast_len=15 9 | assert tuple(y.shape) == (2, 16), y.shape 10 | 11 | def test_model_forward_site_history(multimodal_model_site_history, sample_site_batch): 12 | 13 | y = multimodal_model_site_history(sample_site_batch) 14 | 15 | # check output is the correct shape 16 | # batch size=2, forecast_len=15 17 | assert tuple(y.shape) == (2, 16), y.shape 18 | 19 | 20 | def test_model_backward(multimodal_model, sample_batch): 21 | opt = SGD(multimodal_model.parameters(), lr=0.001) 22 | 23 | y = multimodal_model(sample_batch) 24 | 25 | # Backwards on sum drives sum to zero 26 | y.sum().backward() 27 | 28 | 29 | def test_quantile_model_forward(multimodal_quantile_model, sample_batch): 30 | y_quantiles = multimodal_quantile_model(sample_batch) 31 | 32 | # check output is the correct shape 33 | # batch size=2, forecast_len=15, num_quantiles=3 34 | assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape 35 | 36 | 37 | def test_quantile_model_backward(multimodal_quantile_model, sample_batch): 38 | opt = SGD(multimodal_quantile_model.parameters(), lr=0.001) 39 | 40 | y_quantiles = multimodal_quantile_model(sample_batch) 41 | 42 | # Backwards on sum drives sum to zero 43 | y_quantiles.sum().backward() 44 | 45 | 46 | def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minutes, sample_batch): 47 | y_quantiles = multimodal_quantile_model_ignore_minutes(sample_batch) 48 | 49 | # check output is the correct shape 50 | # batch size=2, forecast_len=8, num_quantiles=3 51 | assert tuple(y_quantiles.shape) == (2, 8, 3), y_quantiles.shape 52 | 53 | # Backwards on sum drives sum to zero 54 | y_quantiles.sum().backward() 55 | -------------------------------------------------------------------------------- /tests/models/multimodal/test_save_load_pretrained.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import re 3 | import pkg_resources 4 | from pvnet.models.base_model import BaseModel 5 | import yaml 6 | import tempfile 7 | 8 | 9 | 10 | def test_save_pretrained(tmp_path, multimodal_model, raw_multimodal_model_kwargs, sample_datamodule): 11 | with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as temp_file: 12 | # Get sample directory from the datamodule 13 | sample_dir = sample_datamodule.sample_dir 14 | 15 | # Create config with matching structure 16 | data_config = { 17 | "general": { 18 | "description": "Config for training the saved PVNet model", 19 | "name": "test_pvnet" 20 | }, 21 | "input_data": { 22 | "gsp": { 23 | "zarr_path": sample_dir, 24 | "interval_start_minutes": -120, 25 | "interval_end_minutes": 480, 26 | "time_resolution_minutes": 30, 27 | "dropout_timedeltas_minutes": None, 28 | "dropout_fraction": 0 29 | }, 30 | "nwp": { 31 | "ukv": { 32 | "provider": "ukv", 33 | "zarr_path": sample_dir, 34 | "interval_start_minutes": -120, 35 | "interval_end_minutes": 480, 36 | "time_resolution_minutes": 60, 37 | "channels": ["t", "dswrf", "dlwrf"], 38 | "image_size_pixels_height": 24, 39 | "image_size_pixels_width": 24, 40 | "dropout_timedeltas_minutes": None, 41 | "dropout_fraction": 0, 42 | "max_staleness_minutes": None 43 | } 44 | }, 45 | "satellite": { 46 | "zarr_path": sample_dir, 47 | "interval_start_minutes": -30, 48 | "interval_end_minutes": 0, 49 | "time_resolution_minutes": 5, 50 | "channels": ["IR_016", "IR_039", "IR_087"], 51 | "image_size_pixels_height": 24, 52 | "image_size_pixels_width": 24, 
53 | "dropout_timedeltas_minutes": None, 54 | "dropout_fraction": 0 55 | } 56 | }, 57 | "sample_dir": sample_dir, 58 | "train_period": [None, None], 59 | "val_period": [None, None], 60 | "test_period": [None, None] 61 | } 62 | 63 | yaml.dump(data_config, temp_file) 64 | data_config_path = temp_file.name 65 | 66 | # Construct the model config 67 | model_config = {"_target_": "pvnet.models.multimodal.multimodal.Model"} 68 | model_config.update(raw_multimodal_model_kwargs) 69 | 70 | # Save the model 71 | model_output_dir = f"{tmp_path}/model" 72 | multimodal_model.save_pretrained( 73 | model_output_dir, 74 | config=model_config, 75 | data_config=data_config_path, 76 | wandb_repo=None, 77 | wandb_ids="excluded-for-text", 78 | push_to_hub=False, 79 | repo_id="openclimatefix/pvnet_uk_region", 80 | ) 81 | 82 | # Load the model 83 | _ = BaseModel.from_pretrained( 84 | model_id=model_output_dir, 85 | revision=None, 86 | ) 87 | 88 | @pytest.mark.parametrize( 89 | "repo_id, wandb_repo, wandb_ids", 90 | [ 91 | ( 92 | "openclimatefix/pvnet_uk_region", 93 | "None", 94 | "excluded-for-text" 95 | ), 96 | ], 97 | ) 98 | def test_create_hugging_face_model_card(repo_id, wandb_repo, wandb_ids): 99 | 100 | # Create Hugging Face ModelCard 101 | card = BaseModel.create_hugging_face_model_card( 102 | repo_id=repo_id, 103 | wandb_repo=wandb_repo, 104 | wandb_ids=wandb_ids 105 | ) 106 | 107 | # Extract the card markdown 108 | card_markdown = card.content 109 | 110 | # Regex to find if the pvnet and ocf-data-sampler versions are present 111 | pvnet_version = pkg_resources.get_distribution("pvnet").version 112 | has_pvnet = f"pvnet=={pvnet_version}" in card_markdown 113 | 114 | ocf_sampler_version = pkg_resources.get_distribution("ocf-data-sampler").version 115 | has_ocf_data_sampler= f"ocf-data-sampler=={ocf_sampler_version}" in card_markdown 116 | 117 | assert has_pvnet, f"The hugging face card created does not display the PVNet package version" 118 | assert has_ocf_data_sampler, f"The hugging face card created does not display the ocf-data-sampler package version" 119 | -------------------------------------------------------------------------------- /tests/models/multimodal/test_unimodal_teacher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tempfile 4 | import yaml 5 | 6 | import hydra 7 | import pytest 8 | 9 | from torch.optim import SGD 10 | import torch 11 | 12 | import pvnet 13 | from pvnet.models.multimodal.unimodal_teacher import Model 14 | 15 | 16 | @pytest.fixture 17 | def teacher_dir(multimodal_model, raw_multimodal_model_kwargs): 18 | raw_multimodal_model_kwargs["_target_"] = "pvnet.models.multimodal.multimodal.Model" 19 | 20 | with tempfile.TemporaryDirectory() as tmpdirname: 21 | # Save teachers for these modes 22 | for mode in ["sat", "nwp_ukv"]: 23 | mode_dir = f"{tmpdirname}/{mode}" 24 | os.mkdir(mode_dir) 25 | 26 | # Checkpoint paths would be like: epoch={X}-step={N}.ckpt or last.ckpt 27 | path = os.path.join(mode_dir, "epoch=2-step=35002.ckpt") 28 | path = f"{mode_dir}/epoch=2-step=35002.ckpt" 29 | 30 | # Save out themodel config file 31 | with open(os.path.join(mode_dir, "model_config.yaml"), "w") as outfile: 32 | yaml.dump(raw_multimodal_model_kwargs, outfile) 33 | 34 | # Save the weights 35 | torch.save({"model_state_dict": multimodal_model.state_dict()}, path) 36 | 37 | yield tempfile 38 | 39 | 40 | @pytest.fixture 41 | def unimodal_model_kwargs(teacher_dir, model_minutes_kwargs): 42 | # Configure the fusion 
network 43 | kwargs = dict( 44 | output_network=dict( 45 | _target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2, 46 | _partial_=True, 47 | fc_hidden_features=128, 48 | n_res_blocks=6, 49 | res_block_layers=2, 50 | dropout_frac=0.0, 51 | ), 52 | cold_start=True, 53 | location_id_mapping={i:i for i in range(1, 318)}, 54 | ) 55 | 56 | # Get the teacher model save directories 57 | mode_dirs = glob.glob(f"{teacher_dir}/*") 58 | mode_teacher_dict = dict() 59 | for mode_dir in mode_dirs: 60 | mode_name = mode_dir.split("/")[-1].replace("nwp_", "nwp/") 61 | mode_teacher_dict[mode_name] = mode_dir 62 | kwargs["mode_teacher_dict"] = mode_teacher_dict 63 | 64 | # Add the forecast and history minutes to be compatible with the sample batch 65 | kwargs.update(model_minutes_kwargs) 66 | 67 | yield hydra.utils.instantiate(kwargs) 68 | 69 | 70 | @pytest.fixture 71 | def unimodal_teacher_model(unimodal_model_kwargs): 72 | return Model(**unimodal_model_kwargs) 73 | 74 | 75 | def test_model_init(unimodal_model_kwargs): 76 | Model(**unimodal_model_kwargs) 77 | 78 | 79 | def test_model_forward(unimodal_teacher_model, sample_batch): 80 | # assert False 81 | y, _ = unimodal_teacher_model(sample_batch, return_modes=True) 82 | 83 | # check output is the correct shape 84 | # batch size=2, forecast_len=15 85 | assert tuple(y.shape) == (2, 16), y.shape 86 | 87 | 88 | def test_model_backward(unimodal_teacher_model, sample_batch): 89 | opt = SGD(unimodal_teacher_model.parameters(), lr=0.001) 90 | 91 | y = unimodal_teacher_model(sample_batch) 92 | 93 | # Backwards on sum drives sum to zero 94 | y.sum().backward() 95 | 96 | 97 | def test_model_conversion(unimodal_model_kwargs, sample_batch): 98 | # Create the unimodal model 99 | um_model = Model(**unimodal_model_kwargs) 100 | # Convert to the equivalent multimodel model 101 | mm_model, _ = um_model.convert_to_multimodal_model(unimodal_model_kwargs) 102 | 103 | # If the model has been successfully converted the predictions should be identical 104 | y_um = um_model(sample_batch, return_modes=False) 105 | y_mm = mm_model(sample_batch) 106 | 107 | assert (y_um == y_mm).all() 108 | -------------------------------------------------------------------------------- /tests/models/test_ensemble.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.ensemble import Ensemble 2 | 3 | 4 | def test_model_init(multimodal_model): 5 | ensemble_model = Ensemble( 6 | model_list=[multimodal_model] * 3, 7 | weights=None, 8 | ) 9 | 10 | ensemble_model = Ensemble( 11 | model_list=[multimodal_model] * 3, 12 | weights=[1, 2, 3], 13 | ) 14 | 15 | 16 | def test_model_forward(multimodal_model, sample_batch): 17 | ensemble_model = Ensemble( 18 | model_list=[multimodal_model] * 3, 19 | ) 20 | 21 | y = ensemble_model(sample_batch) 22 | 23 | # check output is the correct shape 24 | # batch size=2, forecast_len=15 25 | assert tuple(y.shape) == (2, 16), y.shape 26 | 27 | 28 | def test_quantile_model_forward(multimodal_quantile_model, sample_batch): 29 | ensemble_model = Ensemble( 30 | model_list=[multimodal_quantile_model] * 3, 31 | ) 32 | 33 | y_quantiles = ensemble_model(sample_batch) 34 | 35 | # check output is the correct shape 36 | # batch size=2, forecast_len=15, num_quantiles=3 37 | assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape 38 | -------------------------------------------------------------------------------- /tests/test_data/sample_data/non_hrv_shell.zarr/.zgroup: 
-------------------------------------------------------------------------------- 1 | { 2 | "zarr_format": 2 3 | } 4 | -------------------------------------------------------------------------------- /tests/test_data/sample_data/non_hrv_shell.zarr/time/.zarray: -------------------------------------------------------------------------------- 1 | { 2 | "chunks": [ 3 | 24 4 | ], 5 | "compressor": { 6 | "blocksize": 0, 7 | "clevel": 5, 8 | "cname": "lz4", 9 | "id": "blosc", 10 | "shuffle": 1 11 | }, 12 | "dtype": "