├── pvnet ├── __init__.py ├── models │ ├── late_fusion │ │ ├── __init__.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── basic_blocks.py │ │ │ └── encoders3d.py │ │ ├── site_encoders │ │ │ ├── __init__.py │ │ │ ├── basic_blocks.py │ │ │ └── encoders.py │ │ ├── linear_networks │ │ │ ├── __init__.py │ │ │ ├── networks.py │ │ │ └── basic_blocks.py │ │ ├── README.md │ │ └── basic_blocks.py │ ├── __init__.py │ └── ensemble.py ├── training │ ├── __init__.py │ ├── plots.py │ ├── train.py │ └── lightning_module.py ├── model_cards │ └── empty_model_card_template.md ├── load_model.py ├── datamodule.py ├── utils.py └── optimizers.py ├── tests ├── __init__.py ├── test_datamodule.py ├── models │ ├── late_fusion │ │ ├── site_encoders │ │ │ └── test_encoders.py │ │ ├── test_late_fusion.py │ │ ├── encoders │ │ │ └── test_encoders3d.py │ │ ├── linear_networks │ │ │ └── test_networks.py │ │ └── test_save_load_pretrained.py │ ├── test_ensemble.py │ └── test_validation.py ├── test_end2end.py ├── test_data │ └── data_config.yaml ├── training │ └── test_train.py └── conftest.py ├── .github └── workflows │ ├── merged_ci.yml │ ├── tagged_ci.yml │ ├── branch_ci.yml │ └── pull_ci.yml ├── configs.example ├── trainer │ ├── default.yaml │ └── all_params.yaml ├── readme.md ├── hydra │ └── default.yaml ├── logger │ └── wandb.yaml ├── datamodule │ ├── streamed_samples.yaml │ └── configuration │ │ └── example_configuration.yaml ├── experiment │ └── example_simple.yaml ├── config.yaml ├── callbacks │ └── default.yaml └── model │ └── late_fusion.yaml ├── run.py ├── LICENSE ├── pyproject.toml ├── .gitignore ├── scripts ├── checkpoint_to_huggingface.py ├── mae_analysis.py ├── migrate_old_model.py ├── backtest_sites.py └── backtest_uk_gsp.py ├── .all-contributorsrc └── README.md /pvnet/__init__.py: -------------------------------------------------------------------------------- 1 | """PVNet source code.""" 2 | -------------------------------------------------------------------------------- 
/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for PVNet""" 2 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/__init__.py: -------------------------------------------------------------------------------- 1 | """Late fusion models""" 2 | -------------------------------------------------------------------------------- /pvnet/training/__init__.py: -------------------------------------------------------------------------------- 1 | """Training submodule""" 2 | from .train import train -------------------------------------------------------------------------------- /pvnet/models/late_fusion/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to encode satellite and NWP inputs""" 2 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/site_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to encode site-level PV data""" 2 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/linear_networks/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to combine 1D feature vectors from different sources and make final predictions""" 2 | -------------------------------------------------------------------------------- /pvnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Models for PVNet""" 2 | from .base_model import BaseModel 3 | from .ensemble import Ensemble 4 | from .late_fusion.late_fusion import LateFusionModel 5 | -------------------------------------------------------------------------------- /.github/workflows/merged_ci.yml: 
-------------------------------------------------------------------------------- 1 | name: Merged CI 2 | run-name: 'Bump tag with merge #${{ github.event.number }} "${{ github.event.pull_request.title }}"' 3 | 4 | on: 5 | pull_request_target: 6 | types: ["closed"] 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | bump-tag: 11 | uses: openclimatefix/.github/.github/workflows/bump_tag.yml@main 12 | secrets: inherit 13 | -------------------------------------------------------------------------------- /configs.example/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.Trainer 2 | 3 | # set `gpu` to train on GPU, `cpu` to train on CPU only 4 | accelerator: auto 5 | devices: auto 6 | 7 | min_epochs: null 8 | max_epochs: null 9 | reload_dataloaders_every_n_epochs: 0 10 | num_sanity_val_steps: 8 11 | fast_dev_run: false 12 | 13 | accumulate_grad_batches: 4 14 | log_every_n_steps: 50 15 | -------------------------------------------------------------------------------- /.github/workflows/tagged_ci.yml: -------------------------------------------------------------------------------- 1 | name: Tagged CI 2 | run-name: 'Tagged CI for ${{ github.ref_name }} by ${{ github.actor }}' 3 | 4 | on: 5 | push: 6 | tags: ["v*.*.*"] 7 | 8 | jobs: 9 | tagged-ci: 10 | uses: openclimatefix/.github/.github/workflows/tagged_ci.yml@main 11 | secrets: inherit 12 | with: 13 | containerfile: 'None' 14 | enable_pypi: true 15 | -------------------------------------------------------------------------------- /tests/test_datamodule.py: -------------------------------------------------------------------------------- 1 | from pvnet.datamodule import PVNetDataModule 2 | 3 | 4 | 5 | def test_data_module(data_config_path): 6 | """Test PVNetDataModule initialization""" 7 | 8 | _ = PVNetDataModule( 9 | configuration=data_config_path, 10 | batch_size=2, 11 | num_workers=0, 12 | prefetch_factor=None, 13 | train_period=[None, None], 14 | 
val_period=[None, None], 15 | ) -------------------------------------------------------------------------------- /configs.example/readme.md: -------------------------------------------------------------------------------- 1 | This directory contains example configuration files for the PVNet project. Many paths will need to be unique to each user. You can find these paths by searching for PLACEHOLDER within these configs. Not all of 2 | the values with a placeholder need to be set. For example, the logger subdirectory may contain several different loggers, each with its own PLACEHOLDERs. If only one logger is used, then only that logger's placeholders need to be set. 3 | 4 | Run experiments with: 5 | `python run.py experiment=example_simple` 6 | -------------------------------------------------------------------------------- /configs.example/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | # Local log directory for hydra 4 | dir: PLACEHOLDER/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | sweep: 6 | # Local log directory for hydra 7 | dir: PLACEHOLDER/multiruns/${now:%Y-%m-%d_%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | 10 | # you can set here environment variables that are universal for all users 11 | # for system specific variables (like data paths) it's better to use .env file! 12 | job: 13 | env_set: 14 | EXAMPLE_VAR: "example_value" 15 | -------------------------------------------------------------------------------- /configs.example/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # wandb project to log to 6 | project: "PLACEHOLDER" 7 | name: "${model_name}" 8 | # location to store the wandb local logs 9 | save_dir: "PLACEHOLDER" 10 | offline: False # set True to store all logs only locally 11 | id: null # pass correct id to resume experiment! 
12 | # entity: "" # set to name of your wandb team or just remove it 13 | log_model: False 14 | prefix: "" 15 | job_type: "train" 16 | group: "" 17 | tags: [] 18 | -------------------------------------------------------------------------------- /.github/workflows/branch_ci.yml: -------------------------------------------------------------------------------- 1 | name: Branch CI (Python) 2 | run-name: 'Test branch commit "${{ github.event.head_commit.message }}"' 3 | 4 | on: 5 | push: 6 | branches-ignore: [ "main" ] 7 | paths-ignore: ['README.md'] 8 | 9 | jobs: 10 | branch-ci: 11 | uses: openclimatefix/.github/.github/workflows/branch_ci.yml@main 12 | secrets: inherit 13 | with: 14 | enable_linting: true 15 | enable_typechecking: false 16 | containerfile: 'None' 17 | tests_folder: 'tests' 18 | tests_matrix: true 19 | test_python_versions: '["3.11", "3.12"]' 20 | -------------------------------------------------------------------------------- /.github/workflows/pull_ci.yml: -------------------------------------------------------------------------------- 1 | name: Pull CI (Python) 2 | run-name: 'Test PR edit #${{ github.event.number }} "${{ github.event.pull_request.title }}"' 3 | 4 | on: 5 | pull_request: 6 | paths-ignore: ['README.md'] 7 | 8 | jobs: 9 | 10 | pull-ci: 11 | uses: openclimatefix/.github/.github/workflows/branch_ci.yml@main 12 | if: ${{ github.event.pull_request.head.repo.fork }} 13 | with: 14 | enable_linting: true 15 | enable_typechecking: false 16 | containerfile: 'None' 17 | tests_folder: 'tests' 18 | tests_matrix: true 19 | test_python_versions: '["3.11", "3.12"]' 20 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/README.md: -------------------------------------------------------------------------------- 1 | ## Multimodal model architecture 2 | 3 | These are fusion models which predict GSP power output based on NWP, non-HRV satellite, GSP output history, solar coordinates, and GSP ID. 
4 | 5 | The core model is `late_fusion.LateFusionModel`, and its architecture is shown in the diagram below. 6 | 7 | ![multimodal_model_diagram](https://github.com/openclimatefix/PVNet/assets/41546094/118393fa-52ec-4bfe-a0a3-268c94c25f1e) 8 | 9 | This model uses encoders which take 4D (time, channel, x, y) inputs of NWP and satellite and encode them into 1D feature vectors. Different encoders are contained inside `encoders`. 10 | 11 | Different choices for the fusion model are contained inside `linear_networks`. 12 | -------------------------------------------------------------------------------- /configs.example/datamodule/streamed_samples.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.datamodule.PVNetDataModule 2 | # Path to the data configuration yaml file. You can find examples in the configuration subdirectory 3 | # in configs.example/datamodule/configuration 4 | # Use the full local path such as: /FULL/PATH/PVNet/configs/datamodule/configuration/example_configuration.yaml 5 | configuration: "PLACEHOLDER.yaml" 6 | num_workers: 20 7 | prefetch_factor: 2 8 | persistent_workers: false 9 | batch_size: 8 10 | 11 | train_period: 12 | - null 13 | - "2022-05-07" 14 | val_period: 15 | - "2022-05-08" 16 | - "2023-05-08" 17 | 18 | seed: "${seed}" 19 | 20 | # Setting the dataset pickle dir will speed up initialisation of multiple workers 21 | dataset_pickle_dir: null 22 | -------------------------------------------------------------------------------- /configs.example/experiment/example_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: late_fusion.yaml 9 | - override /datamodule: streamed_samples.yaml 10 | - override /callbacks: default.yaml 
11 | - override /logger: wandb.yaml 12 | - override /hydra: default.yaml 13 | 14 | # all parameters below will be merged with parameters from default configurations set above 15 | # this allows you to overwrite only specified parameters 16 | 17 | seed: 518 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 2 22 | 23 | datamodule: 24 | batch_size: 16 25 | -------------------------------------------------------------------------------- /tests/models/late_fusion/site_encoders/test_encoders.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.late_fusion.site_encoders.encoders import SingleAttentionNetwork 2 | 3 | 4 | def _test_model_forward(batch, model_class, kwargs, batch_size): 5 | model = model_class(**kwargs) 6 | y = model(batch) 7 | assert tuple(y.shape) == (batch_size, kwargs["out_features"]), y.shape 8 | 9 | 10 | def _test_model_backward(batch, model_class, kwargs): 11 | model = model_class(**kwargs) 12 | y = model(batch) 13 | # Backwards on sum drives sum to zero 14 | y.sum().backward() 15 | 16 | 17 | def test_singleattentionnetwork_forward(batch, site_encoder_model_kwargs): 18 | _test_model_forward( 19 | batch, 20 | SingleAttentionNetwork, 21 | site_encoder_model_kwargs, 22 | batch_size=2, 23 | ) 24 | 25 | 26 | def test_singleattentionnetwork_backward(batch, site_encoder_model_kwargs): 27 | _test_model_backward(batch, SingleAttentionNetwork, site_encoder_model_kwargs) 28 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """Run training. 
2 | 3 | This file can be run for example using 4 | >> python run.py experiment=example_simple 5 | """ 6 | 7 | import logging 8 | import sys 9 | 10 | import hydra 11 | from omegaconf import DictConfig 12 | 13 | from pvnet.training import train 14 | from pvnet.utils import print_config, run_config_utilities, validate_gpu_config 15 | 16 | logging.basicConfig(stream=sys.stdout, level=logging.ERROR) 17 | 18 | 19 | 20 | @hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.2") 21 | def main(config: DictConfig) -> None: 22 | """Runs training""" 23 | 24 | # A couple of optional utilities: 25 | # - disabling python warnings 26 | # - forcing debug friendly configuration 27 | # - forcing multi-gpu friendly configuration 28 | run_config_utilities(config) 29 | validate_gpu_config(config) 30 | print_config(config, resolve=True) 31 | 32 | return train(config) 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /configs.example/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default.yaml 7 | - model: late_fusion.yaml 8 | - datamodule: streamed_samples.yaml 9 | - callbacks: default.yaml # set this to null if you don't want to use callbacks 10 | - logger: wandb.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`) 11 | - experiment: null 12 | - hparams_search: null 13 | - hydra: default.yaml 14 | 15 | renewable: "pv_uk" 16 | 17 | # enable color logging 18 | # - override hydra/hydra_logging: colorlog 19 | # - override hydra/job_logging: colorlog 20 | 21 | # path to original working directory 22 | # hydra hijacks working directory by changing it to the current log directory, 23 | # so it's useful to have this path as a special variable 24 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 25 | work_dir: ${hydra:runtime.cwd} 26 | 27 | model_name: "default" 28 | 29 | seed: 2727831 30 | -------------------------------------------------------------------------------- /tests/models/test_ensemble.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.ensemble import Ensemble 2 | 3 | 4 | def test_model_init(late_fusion_model): 5 | # Without weighting 6 | ensemble_model = Ensemble(model_list=[late_fusion_model] * 3, weights=None) 7 | 8 | # With weighting 9 | ensemble_model = Ensemble(model_list=[late_fusion_model] * 3, weights=[1, 2, 3]) 10 | 11 | 12 | def test_model_forward(late_fusion_model, batch): 13 | ensemble_model = Ensemble(model_list=[late_fusion_model] * 3) 14 | 15 | y = ensemble_model(batch) 16 | 17 | # Check output is the correct shape: [batch size=2, forecast_len=16] 18 | assert tuple(y.shape) == (2, 16), y.shape 19 | 20 | 21 | def test_quantile_model_forward(late_fusion_quantile_model, batch): 22 | ensemble_model = Ensemble(model_list=[late_fusion_quantile_model] * 3) 23 | 24 | y_quantiles = ensemble_model(batch) 25 | 26 | # Check output is the correct shape: [batch size=2, forecast_len=16, num_quantiles=3] 27 | assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape 28 | -------------------------------------------------------------------------------- /tests/test_end2end.py: 
-------------------------------------------------------------------------------- 1 | import lightning 2 | 3 | from pvnet.datamodule import PVNetDataModule 4 | from pvnet.optimizers import EmbAdamWReduceLROnPlateau 5 | from pvnet.training.lightning_module import PVNetLightningModule 6 | 7 | 8 | def test_model_trainer_fit(session_tmp_path, data_config_path, late_fusion_model): 9 | """Test end-to-end training.""" 10 | 11 | datamodule = PVNetDataModule( 12 | configuration=data_config_path, 13 | batch_size=2, 14 | num_workers=2, 15 | prefetch_factor=None, 16 | dataset_pickle_dir=f"{session_tmp_path}/dataset_pickles" 17 | ) 18 | 19 | lightning_model = PVNetLightningModule( 20 | model=late_fusion_model, 21 | optimizer=EmbAdamWReduceLROnPlateau(), 22 | ) 23 | 24 | # Train the model for two batches 25 | trainer = lightning.Trainer( 26 | max_epochs=2, 27 | limit_val_batches=2, 28 | limit_train_batches=2, 29 | accelerator="cpu", 30 | logger=False, 31 | enable_checkpointing=False, 32 | ) 33 | trainer.fit(model=lightning_model, datamodule=datamodule) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Open Climate Fix Ltd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/site_encoders/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for PV-site encoders""" 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class AbstractSitesEncoder(nn.Module, metaclass=ABCMeta): 9 | """Abstract class for encoder for output data from multiple PV sites. 10 | 11 | The encoder will take an input of shape (batch_size, sequence_length, num_sites) 12 | and return an output of shape (batch_size, out_features). 13 | """ 14 | 15 | def __init__( 16 | self, 17 | sequence_length: int, 18 | num_sites: int, 19 | out_features: int, 20 | ): 21 | """Abstract class for PV site-level encoder. 22 | 23 | Args: 24 | sequence_length: The time sequence length of the data. 25 | num_sites: Number of PV sites in the input data. 26 | out_features: Number of output features. 
27 | """ 28 | super().__init__() 29 | self.sequence_length = sequence_length 30 | self.num_sites = num_sites 31 | self.out_features = out_features 32 | 33 | @abstractmethod 34 | def forward(self) -> torch.Tensor: 35 | """Run model forward. TODO: add the input argument described in the class docstring.""" 36 | pass 37 | -------------------------------------------------------------------------------- /tests/models/late_fusion/test_late_fusion.py: -------------------------------------------------------------------------------- 1 | def test_model_forward(late_fusion_model, batch): 2 | y = late_fusion_model(batch) 3 | 4 | # Check output is the correct shape: [batch size=2, forecast_len=16] 5 | assert tuple(y.shape) == (2, 16), y.shape 6 | 7 | def test_model_forward_generation_history(late_fusion_model_generation_history, batch): 8 | 9 | y = late_fusion_model_generation_history(batch) 10 | 11 | # Check output is the correct shape: [batch size=2, forecast_len=16] 12 | assert tuple(y.shape) == (2, 16), y.shape 13 | 14 | 15 | def test_model_backward(late_fusion_model, batch): 16 | y = late_fusion_model(batch) 17 | 18 | # Backwards on sum drives sum to zero 19 | y.sum().backward() 20 | 21 | 22 | def test_quantile_model_forward(late_fusion_quantile_model, batch): 23 | y_quantiles = late_fusion_quantile_model(batch) 24 | 25 | # Check output is the correct shape: [batch size=2, forecast_len=16, num_quantiles=3] 26 | assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape 27 | 28 | 29 | def test_quantile_model_backward(late_fusion_quantile_model, batch): 30 | 31 | y_quantiles = late_fusion_quantile_model(batch) 32 | 33 | # Backwards on sum drives sum to zero 34 | y_quantiles.sum().backward() 35 | -------------------------------------------------------------------------------- /configs.example/trainer/all_params.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | 
default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_save_path: null 33 | num_sanity_val_steps: 2 34 | truncated_bptt_steps: null 35 | resume_from_checkpoint: null 36 | profiler: null 37 | benchmark: False 38 | deterministic: False 39 | reload_dataloaders_every_epoch: False 40 | auto_lr_find: False 41 | replace_sampler_ddp: True 42 | terminate_on_nan: False 43 | auto_scale_batch_size: False 44 | prepare_data_per_node: True 45 | plugins: null 46 | amp_backend: "native" 47 | amp_level: "O2" 48 | move_metrics_to_cpu: False 49 | -------------------------------------------------------------------------------- /tests/models/late_fusion/encoders/test_encoders3d.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.late_fusion.encoders.encoders3d import DefaultPVNet, ResConv3DNet 2 | 3 | 4 | def _test_model_forward(batch, model_class, model_kwargs): 5 | model = model_class(**model_kwargs) 6 | y = model(batch) 7 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 8 | 9 | 10 | def _test_model_backward(batch, model_class, model_kwargs): 11 | model = model_class(**model_kwargs) 12 | y = model(batch) 13 | # Backwards on sum drives sum to zero 14 | y.sum().backward() 15 | 16 | 17 | # Test model forward on all models 18 | def test_defaultpvnet_forward(satellite_batch_component, 
encoder_model_kwargs): 19 | _test_model_forward(satellite_batch_component, DefaultPVNet, encoder_model_kwargs) 20 | 21 | 22 | def test_resconv3dnet_forward(satellite_batch_component, encoder_model_kwargs): 23 | _test_model_forward(satellite_batch_component, ResConv3DNet, encoder_model_kwargs) 24 | 25 | 26 | # Test model backward on all models 27 | def test_defaultpvnet_backward(satellite_batch_component, encoder_model_kwargs): 28 | _test_model_backward(satellite_batch_component, DefaultPVNet, encoder_model_kwargs) 29 | 30 | 31 | def test_resconv3dnet_backward(satellite_batch_component, encoder_model_kwargs): 32 | _test_model_backward(satellite_batch_component, ResConv3DNet, encoder_model_kwargs) 33 | -------------------------------------------------------------------------------- /configs.example/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: lightning.pytorch.callbacks.EarlyStopping 3 | # name of the logged metric which determines when model is improving 4 | monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}" 5 | mode: "min" # can be "max" or "min" 6 | patience: 10 # how many epochs (or val check periods) of not improving until training stops 7 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 8 | 9 | learning_rate_monitor: 10 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 11 | logging_interval: "epoch" 12 | 13 | model_summary: 14 | _target_: lightning.pytorch.callbacks.ModelSummary 15 | max_depth: 3 16 | 17 | model_checkpoint: 18 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 19 | # name of the logged metric which determines when model is improving 20 | monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}" 21 | mode: "min" # can be "max" or "min" 22 | save_top_k: 1 # save k best models (determined by above metric) 23 | save_last: True # additionally always save model from last epoch 
24 | every_n_epochs: 1 25 | verbose: False 26 | filename: "epoch={epoch}-step={step}" 27 | # The path to where the model checkpoints will be stored 28 | dirpath: "PLACEHOLDER/${model_name}" #${..model_name} 29 | auto_insert_metric_name: False 30 | save_on_train_epoch_end: False 31 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic layers for composite models""" 2 | 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class ImageEmbedding(nn.Module): 9 | """An embedding layer which concatenates an ID embedding as a new channel onto 3D inputs.""" 10 | 11 | def __init__(self, num_embeddings: int, sequence_length: int, image_size_pixels: int, **kwargs): 12 | """An embedding layer which concatenates an ID embedding as a new channel onto 3D inputs. 13 | 14 | The embedding is a single 2D image which is repeated along dimension 2 (assumed to be 15 | time) and concatenated onto the input as a new channel along dimension 1. 16 | 17 | Args: 18 | num_embeddings: Size of the dictionary of embeddings 19 | sequence_length: The time sequence length of the data. 20 | image_size_pixels: The spatial size of the image. Assumed square. 21 | **kwargs: See `torch.nn.Embedding` for more possible arguments. 
22 | """ 23 | super().__init__() 24 | self.image_size_pixels = image_size_pixels 25 | self.sequence_length = sequence_length 26 | self._embed = nn.Embedding( 27 | num_embeddings=num_embeddings, 28 | embedding_dim=image_size_pixels * image_size_pixels, 29 | **kwargs, 30 | ) 31 | 32 | def forward(self, x: torch.Tensor, id: torch.Tensor) -> torch.Tensor: 33 | """Append ID embedding to image""" 34 | emb = self._embed(id) 35 | emb = emb.reshape((-1, 1, 1, self.image_size_pixels, self.image_size_pixels)) 36 | emb = emb.repeat(1, 1, self.sequence_length, 1, 1) 37 | return torch.cat((x, emb), dim=1) 38 | -------------------------------------------------------------------------------- /tests/models/late_fusion/linear_networks/test_networks.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.late_fusion.linear_networks.networks import ResFCNet 2 | import pytest 3 | import torch 4 | from collections import OrderedDict 5 | 6 | 7 | @pytest.fixture() 8 | def simple_linear_batch(): 9 | return torch.rand(2, 100) 10 | 11 | 12 | @pytest.fixture() 13 | def late_fusion_linear_batch(): 14 | return OrderedDict(nwp=torch.rand(2, 50), sat=torch.rand(2, 40), sun=torch.rand(2, 10)) 15 | 16 | 17 | @pytest.fixture() 18 | def multiple_batch_types(simple_linear_batch, late_fusion_linear_batch): 19 | return [simple_linear_batch, late_fusion_linear_batch] 20 | 21 | 22 | @pytest.fixture() 23 | def fc_batch_batch(): 24 | return torch.rand(2, 100) 25 | 26 | 27 | @pytest.fixture() 28 | def linear_network_kwargs(): 29 | return dict(in_features=100, out_features=10) 30 | 31 | 32 | def _test_model_forward(batches, model_class, model_kwargs): 33 | for batch in batches: 34 | model = model_class(**model_kwargs) 35 | y = model(batch) 36 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 37 | 38 | 39 | def _test_model_backward(batch, model_class, model_kwargs): 40 | model = model_class(**model_kwargs) 41 | y = model(batch) 42 | # 
Backwards on sum drives sum to zero 43 | y.sum().backward() 44 | 45 | 46 | # Test model forward on all models 47 | def test_resfcnet_forward(multiple_batch_types, linear_network_kwargs): 48 | _test_model_forward(multiple_batch_types, ResFCNet, linear_network_kwargs) 49 | 50 | 51 | def test_resfcnet_backward(simple_linear_batch, linear_network_kwargs): 52 | _test_model_backward(simple_linear_batch, ResFCNet, linear_network_kwargs) 53 | -------------------------------------------------------------------------------- /tests/models/test_validation.py: -------------------------------------------------------------------------------- 1 | """Tests for model and trainer configuration validation utilities.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from pvnet.utils import validate_batch_against_config, validate_gpu_config 7 | 8 | 9 | def test_validate_batch_against_config( 10 | batch: dict, 11 | late_fusion_model, 12 | ): 13 | """Test batch validation utility function.""" 14 | # This should pass as full uk_batch is valid 15 | validate_batch_against_config(batch=batch, model=late_fusion_model) 16 | 17 | 18 | def test_validate_batch_against_config_raises_error(late_fusion_model): 19 | """Test that the validation raises an error for a mismatched batch.""" 20 | # Create batch that is missing required NWP data 21 | minimal_batch = {"generation": torch.randn(2, 17)} 22 | with pytest.raises( 23 | ValueError, 24 | match="Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch.", 25 | ): 26 | validate_batch_against_config(batch=minimal_batch, model=late_fusion_model) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "trainer", 31 | [ 32 | {"devices": 1}, 33 | {"devices": [0]}, 34 | {"accelerator": "cpu"}, 35 | ], 36 | ids=["devices=1", "devices=[0]", "accelerator=cpu"], 37 | ) 38 | def test_validate_gpu_config_single_device(trainer_cfg, trainer): 39 | """Accept single GPU or explicit CPU configurations.""" 40 | validate_gpu_config(trainer_cfg(trainer)) 41 | 42 
| 43 | @pytest.mark.parametrize( 44 | "trainer", 45 | [ 46 | {"devices": 2}, 47 | ], 48 | ids=["devices=2"], 49 | ) 50 | def test_validate_gpu_config_multiple_devices(trainer_cfg, trainer): 51 | """Reject accidental multi-GPU setups.""" 52 | with pytest.raises(ValueError, match="Parallel training not supported"): 53 | validate_gpu_config(trainer_cfg(trainer)) 54 | -------------------------------------------------------------------------------- /tests/models/late_fusion/test_save_load_pretrained.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | from pvnet.models import BaseModel 3 | import pvnet.model_cards 4 | 5 | 6 | card_path = f"{pvnet.model_cards.__path__[0]}/empty_model_card_template.md" 7 | 8 | 9 | def test_save_pretrained( 10 | tmp_path, 11 | late_fusion_model, 12 | raw_late_fusion_model_kwargs, 13 | data_config_path 14 | ): 15 | 16 | # Construct the model config 17 | model_config = { 18 | "_target_": "pvnet.models.LateFusionModel", 19 | **raw_late_fusion_model_kwargs, 20 | } 21 | 22 | # Save the model 23 | model_output_dir = f"{tmp_path}/saved_model" 24 | late_fusion_model.save_pretrained( 25 | save_directory=model_output_dir, 26 | model_config=model_config, 27 | data_config_path=data_config_path, 28 | wandb_repo="test", 29 | wandb_ids="abc", 30 | card_template_path=card_path, 31 | push_to_hub=False, 32 | ) 33 | 34 | # Load the model 35 | _ = BaseModel.from_pretrained(model_id=model_output_dir, revision=None) 36 | 37 | 38 | def test_create_hugging_face_model_card(): 39 | 40 | # Create Hugging Face ModelCard 41 | card = BaseModel.create_hugging_face_model_card(card_path, wandb_repo="test", wandb_ids="abc") 42 | 43 | # Extract the card markdown 44 | card_markdown = card.content 45 | 46 | # Regex to find if the pvnet and ocf-data-sampler versions are present 47 | pvnet_version = version("pvnet") 48 | has_pvnet = f"pvnet=={pvnet_version}" in card_markdown 49 | 50 | 
ocf_sampler_version = version("ocf-data-sampler") 51 | has_ocf_data_sampler= f"ocf-data-sampler=={ocf_sampler_version}" in card_markdown 52 | 53 | assert has_pvnet, f"The hugging face card created does not display the PVNet package version" 54 | assert has_ocf_data_sampler, f"The hugging face card created does not display the ocf-data-sampler package version" 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=67", "wheel", "setuptools-git-versioning>=2.0,<3"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name="PVNet" 7 | description = "PVNet" 8 | authors = [{name="Peter Dudfield", email="info@openclimatefix.org"}] 9 | dynamic = ["version"] 10 | license={file="LICENSE"} 11 | readme = {file="README.md", content-type="text/markdown"} 12 | requires-python = ">=3.11,<3.14" 13 | 14 | dependencies = [ 15 | "ocf-data-sampler>=0.6.0", 16 | "numpy", 17 | "pandas", 18 | "matplotlib", 19 | "xarray", 20 | "h5netcdf", 21 | "torch>=2.0.0", 22 | "lightning", 23 | "typer", 24 | "sqlalchemy", 25 | "fsspec[s3]", 26 | "wandb", 27 | "huggingface-hub", 28 | "tqdm", 29 | "omegaconf", 30 | "hydra-core", 31 | "rich", 32 | "einops", 33 | "safetensors", 34 | ] 35 | 36 | [dependency-groups] 37 | dev=[ 38 | "ruff", 39 | "mypy", 40 | "pytest", 41 | "pytest-cov", 42 | ] 43 | 44 | [tool.setuptools-git-versioning] 45 | enabled = true 46 | 47 | [tool.setuptools.package-dir] 48 | "pvnet" = "pvnet" 49 | 50 | [tool.mypy] 51 | exclude = [ 52 | "^tests/", 53 | ] 54 | disallow_untyped_defs = true 55 | disallow_any_unimported = true 56 | no_implicit_optional = true 57 | check_untyped_defs = true 58 | warn_return_any = true 59 | warn_unused_ignores = true 60 | show_error_codes = true 61 | warn_unreachable = true 62 | 63 | [[tool.mypy.overrides]] 64 | module = [] 65 | ignore_missing_imports = true 66 |
[tool.pytest.ini_options] 68 | minversion = "6.0" 69 | addopts = "-ra -q" 70 | testpaths = [ 71 | "tests", 72 | ] 73 | 74 | [tool.ruff] 75 | line-length = 100 76 | exclude = ["tests"] 77 | target-version = "py311" 78 | 79 | [tool.ruff.lint] 80 | extend-select = ["E", "D", "I"] 81 | ignore = ["D200","D202","D210","D212","D415","D105"] 82 | 83 | [tool.ruff.lint.mccabe] 84 | # Unlike Flake8, default to a complexity level of 10. 85 | max-complexity = 10 86 | 87 | [tool.ruff.lint.pydocstyle] 88 | # Use Google-style docstrings. 89 | convention = "google" 90 | 91 | [tool.ruff.lint.per-file-ignores] 92 | "__init__.py" = ["F401", "E402"] 93 | -------------------------------------------------------------------------------- /pvnet/model_cards/empty_model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 7 | 8 | 9 | # TEMPLATE 10 | 11 | 12 | ## Model Description 13 | 14 | 17 | 18 | - **Developed by:** openclimatefix 19 | - **Model type:** Fusion model 20 | - **Language(s) (NLP):** en 21 | - **License:** mit 22 | 23 | # Training Details 24 | 25 | ## Data 26 | 27 | 32 | 33 | 34 | ### Preprocessing 35 | 36 | 39 | 40 | ## Results 41 | 42 | 43 | The training logs for this model commit can be found here: 44 | {{ wandb_links }} 45 | 46 | 47 | ### Hardware 48 | Trained on a single NVIDIA Tesla T4 49 | 50 | 51 | ### Software 52 | 53 | This model was trained using the following Open Climate Fix packages: 54 | 55 | - [1] https://github.com/openclimatefix/PVNet 56 | - [2] https://github.com/openclimatefix/ocf-data-sampler 57 | 58 | 59 | The versions of these packages can be found below: 60 | {{ package_versions }} 61 | -------------------------------------------------------------------------------- /pvnet/models/ensemble.py: -------------------------------------------------------------------------------- 1 | """Ensemble model which combines the predictions of multiple PVNet models""" 2 | import torch 3 | from
ocf_data_sampler.numpy_sample.common_types import TensorBatch 4 | from torch import nn 5 | 6 | from pvnet.models.base_model import BaseModel 7 | 8 | 9 | class Ensemble(BaseModel): 10 | """Ensemble of PVNet models""" 11 | 12 | def __init__( 13 | self, 14 | model_list: list[BaseModel], 15 | weights: list[float] | None = None, 16 | ): 17 | """Ensemble of PVNet models 18 | 19 | Args: 20 | model_list: A list of PVNet models to ensemble 21 | weights: A list of weighting to apply to each model. If None, the models are weighted 22 | equally. 23 | """ 24 | 25 | # Surface check all the models are compatible 26 | output_quantiles = [] 27 | history_minutes = [] 28 | forecast_minutes = [] 29 | interval_minutes = [] 30 | 31 | # Get some model properties from each model 32 | for model in model_list: 33 | output_quantiles.append(model.output_quantiles) 34 | history_minutes.append(model.history_minutes) 35 | forecast_minutes.append(model.forecast_minutes) 36 | interval_minutes.append(model.interval_minutes) 37 | 38 | # Check these properties are all the same 39 | for param_list in [ 40 | output_quantiles, 41 | history_minutes, 42 | forecast_minutes, 43 | interval_minutes, 44 | ]: 45 | assert all([p == param_list[0] for p in param_list]), param_list 46 | 47 | super().__init__( 48 | history_minutes=history_minutes[0], 49 | forecast_minutes=forecast_minutes[0], 50 | output_quantiles=output_quantiles[0], 51 | interval_minutes=interval_minutes[0], 52 | ) 53 | 54 | self.model_list = nn.ModuleList(model_list) 55 | 56 | if weights is None: 57 | weights = torch.ones(len(model_list)) / len(model_list) 58 | else: 59 | assert len(weights) == len(model_list) 60 | weights = torch.Tensor(weights) / sum(weights) 61 | self.weights = nn.Parameter(weights, requires_grad=False) 62 | 63 | def forward(self, x: TensorBatch) -> torch.Tensor: 64 | """Run the model forward""" 65 | y_hat = 0 66 | for weight, model in zip(self.weights, self.model_list): 67 | y_hat = model(x) * weight + y_hat 68 | return y_hat 
69 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/linear_networks/networks.py: -------------------------------------------------------------------------------- 1 | """Linear networks used for the fusion model""" 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pvnet.models.late_fusion.linear_networks.basic_blocks import ( 8 | AbstractLinearNetwork, 9 | ResidualLinearBlock, 10 | ) 11 | 12 | 13 | class ResFCNet(AbstractLinearNetwork): 14 | """Fully connected deep network based on ResNet architecture. 15 | 16 | This architecture is a stack of fully connected residual blocks. It uses LeakyReLU 17 | activations internally, and batchnorm in the residual 18 | branches. The residual blocks are implemented based on the best performing block in [1]. 19 | 20 | Sources: 21 | [1] https://arxiv.org/pdf/1603.05027.pdf 22 | """ 23 | 24 | def __init__( 25 | self, 26 | in_features: int, 27 | out_features: int, 28 | fc_hidden_features: int = 128, 29 | n_res_blocks: int = 4, 30 | res_block_layers: int = 2, 31 | dropout_frac: float = 0.0, 32 | ): 33 | """Fully connected deep network based on ResNet architecture. 34 | 35 | Args: 36 | in_features: Number of input features. 37 | out_features: Number of output features. 38 | fc_hidden_features: Number of features in middle hidden layers. 39 | n_res_blocks: Number of residual blocks to use. 40 | res_block_layers: Number of fully-connected layers used in each residual block. 41 | dropout_frac: Probability of an element to be zeroed in the residual pathways.
42 | """ 43 | super().__init__(in_features, out_features) 44 | 45 | model = [nn.Linear(in_features=in_features, out_features=fc_hidden_features)] 46 | 47 | for i in range(n_res_blocks): 48 | model += [ 49 | ResidualLinearBlock( 50 | in_features=fc_hidden_features, 51 | n_layers=res_block_layers, 52 | dropout_frac=dropout_frac, 53 | ) 54 | ] 55 | 56 | model += [ 57 | nn.LeakyReLU(), 58 | nn.Linear(in_features=fc_hidden_features, out_features=out_features), 59 | nn.LeakyReLU(negative_slope=0.01), 60 | ] 61 | 62 | self.model = nn.Sequential(*model) 63 | 64 | def forward(self, x: OrderedDict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: 65 | """Run model forward""" 66 | x = self.cat_modes(x) 67 | return self.model(x) 68 | -------------------------------------------------------------------------------- /pvnet/training/plots.py: -------------------------------------------------------------------------------- 1 | """Plots logged during training""" 2 | from collections.abc import Sequence 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import pylab 7 | import torch 8 | import wandb 9 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 10 | 11 | 12 | def wandb_line_plot( 13 | x: Sequence[float], 14 | y: Sequence[float], 15 | xlabel: str, 16 | ylabel: str, 17 | title: str | None = None 18 | ) -> wandb.plot.CustomChart: 19 | """Make a wandb line plot""" 20 | data = [[xi, yi] for (xi, yi) in zip(x, y)] 21 | table = wandb.Table(data=data, columns=[xlabel, ylabel]) 22 | return wandb.plot.line(table, xlabel, ylabel, title=title) 23 | 24 | 25 | def plot_sample_forecasts( 26 | batch: TensorBatch, 27 | y_hat: torch.Tensor, 28 | quantiles: list[float] | None, 29 | key_to_plot: str, 30 | ) -> plt.Figure: 31 | """Plot a batch of data and the forecast from that batch""" 32 | 33 | y = batch[key_to_plot].cpu().numpy() 34 | y_hat = y_hat.cpu().numpy() 35 | ids = batch["location_id"].cpu().numpy().squeeze() 36 | times_utc = pd.to_datetime( 37 | 
batch["time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]") 38 | ) 39 | batch_size = y.shape[0] 40 | 41 | fig, axes = plt.subplots(4, 4, figsize=(16, 16)) 42 | 43 | for i, ax in enumerate(axes.ravel()[:batch_size]): 44 | 45 | ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$") 46 | 47 | if quantiles is None: 48 | ax.plot( 49 | times_utc[i][-len(y_hat[i]) :], 50 | y_hat[i], 51 | marker=".", 52 | color="r", 53 | label=r"$\hat{y}$", 54 | ) 55 | else: 56 | cm = pylab.get_cmap("twilight") 57 | for nq, q in enumerate(quantiles): 58 | ax.plot( 59 | times_utc[i][-len(y_hat[i]) :], 60 | y_hat[i, :, nq], 61 | color=cm(q), 62 | label=r"$\hat{y}$" + f"({q})", 63 | alpha=0.7, 64 | ) 65 | 66 | ax.set_title(f"ID: {ids[i]} | {times_utc[i][0].date()}", fontsize="small") 67 | 68 | xticks = [t for t in times_utc[i] if t.minute == 0][::2] 69 | ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90) 70 | ax.grid() 71 | 72 | axes[0, 0].legend(loc="best") 73 | 74 | if batch_size<16: 75 | for ax in axes.ravel()[batch_size:]: 76 | ax.axis("off") 77 | 78 | for ax in axes[-1, :]: 79 | ax.set_xlabel("Time (hour of day)") 80 | 81 | title = f"Normed {key_to_plot.upper()} output" 82 | 83 | plt.suptitle(title) 84 | plt.tight_layout() 85 | 86 | return fig 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | config_tree.txt 3 | configs/ 4 | lightning_logs/ 5 | logs/ 6 | output/ 7 | checkpoints* 8 | csv/ 9 | notebooks/ 10 | *.html 11 | *.csv 12 | latest_logged_train_batch.png 13 | 14 | # Ignore all model cards... 15 | pvnet/model_cards/* 16 | 17 | # ...except for the empty template. 
18 | !pvnet/model_cards/empty_model_card_template.md 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 111 | #Pipfile.lock 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | .DS_Store 150 | 151 | # vim 152 | *swp 153 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/linear_networks/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for the linear networks""" 2 | from abc import ABCMeta, abstractmethod 3 | from collections import OrderedDict 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class AbstractLinearNetwork(nn.Module, metaclass=ABCMeta): 10 | """Abstract class for a network to combine the features from all the inputs.""" 11 | 12 | def __init__( 13 | self, 14 | in_features: int, 15 | out_features: int, 16 | ): 17 | """Abstract class for a network to combine the features from all the inputs. 18 | 19 | Args: 20 | in_features: Number of input features. 21 | out_features: Number of output features.
22 | """ 23 | super().__init__() 24 | 25 | def cat_modes(self, x: OrderedDict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: 26 | """Concatenate modes of input data into 1D feature vector""" 27 | if isinstance(x, OrderedDict): 28 | return torch.cat([value for key, value in x.items()], dim=1) 29 | elif isinstance(x, torch.Tensor): 30 | return x 31 | else: 32 | raise ValueError(f"Input of unexpected type {type(x)}") 33 | 34 | @abstractmethod 35 | def forward(self, x: OrderedDict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: 36 | """Run model forward""" 37 | pass 38 | 39 | 40 | class ResidualLinearBlock(nn.Module): 41 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 42 | 43 | This was the best performing residual block tested in the study. This implementation differs 44 | from that block just by using LeakyReLU activation to avoid dead neuron, and by including 45 | optional dropout in the residual branch. This is also a 1D fully connected layer residual block 46 | rather than a 2D convolutional block. 47 | 48 | Sources: 49 | [1] https://arxiv.org/pdf/1603.05027.pdf 50 | """ 51 | 52 | def __init__( 53 | self, 54 | in_features: int, 55 | n_layers: int = 2, 56 | dropout_frac: float = 0.0, 57 | ): 58 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 59 | 60 | Sources: 61 | [1] https://arxiv.org/pdf/1603.05027.pdf 62 | 63 | Args: 64 | in_features: Number of input features. 65 | n_layers: Number of layers in residual pathway. 66 | dropout_frac: Probability of an element to be zeroed. 
67 | """ 68 | super().__init__() 69 | 70 | layers = [] 71 | for i in range(n_layers): 72 | layers += [ 73 | nn.BatchNorm1d(in_features), 74 | nn.Dropout(p=dropout_frac), 75 | nn.LeakyReLU(), 76 | nn.Linear( 77 | in_features=in_features, 78 | out_features=in_features, 79 | ), 80 | ] 81 | 82 | self.model = nn.Sequential(*layers) 83 | 84 | def forward(self, x: torch.Tensor) -> torch.Tensor: 85 | """Run model forward""" 86 | return self.model(x) + x 87 | -------------------------------------------------------------------------------- /tests/test_data/data_config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | description: Test config for PVNet 3 | name: pvnet_test 4 | 5 | input_data: 6 | 7 | generation: 8 | zarr_path: set_in_temp_file 9 | interval_start_minutes: -60 10 | interval_end_minutes: 480 11 | time_resolution_minutes: 30 12 | dropout_timedeltas_minutes: [] 13 | dropout_fraction: 0 14 | 15 | nwp: 16 | ukv: 17 | provider: ukv 18 | zarr_path: set_in_temp_file 19 | interval_start_minutes: -120 20 | interval_end_minutes: 480 21 | time_resolution_minutes: 60 22 | channels: ["si10", "dswrf", "t", "prate"] 23 | image_size_pixels_height: 24 24 | image_size_pixels_width: 24 25 | dropout_timedeltas_minutes: [-180] 26 | dropout_fraction: 1.0 27 | max_staleness_minutes: null 28 | normalisation_constants: 29 | si10: 30 | mean: 1 31 | std: 1 32 | dswrf: 33 | mean: 1 34 | std: 1 35 | t: 36 | mean: 1 37 | std: 1 38 | prate: 39 | mean: 1 40 | std: 1 41 | 42 | ecmwf: 43 | provider: ecmwf 44 | zarr_path: set_in_temp_file 45 | interval_start_minutes: -120 46 | interval_end_minutes: 480 47 | time_resolution_minutes: 60 48 | channels: ["t2m", "dswrf", "mcc"] 49 | image_size_pixels_height: 12 50 | image_size_pixels_width: 12 51 | dropout_timedeltas_minutes: [-180] 52 | dropout_fraction: 1.0 53 | max_staleness_minutes: null 54 | normalisation_constants: 55 | t2m: 56 | mean: 1 57 | std: 1 58 | dswrf: 59 | mean: 1 60 | std: 1 61 | 
mcc: 62 | mean: 1 63 | std: 1 64 | 65 | satellite: 66 | zarr_path: set_in_temp_file 67 | interval_start_minutes: -30 68 | interval_end_minutes: 0 69 | time_resolution_minutes: 5 70 | 71 | image_size_pixels_height: 24 72 | image_size_pixels_width: 24 73 | dropout_timedeltas_minutes: [] 74 | dropout_fraction: 0 75 | 76 | channels: 77 | - IR_016 78 | - IR_039 79 | - IR_087 80 | - IR_097 81 | - IR_108 82 | - IR_120 83 | - IR_134 84 | - VIS006 85 | - VIS008 86 | - WV_062 87 | - WV_073 88 | 89 | normalisation_constants: 90 | IR_016: 91 | mean: 1 92 | std: 1 93 | IR_039: 94 | mean: 1 95 | std: 1 96 | IR_087: 97 | mean: 1 98 | std: 1 99 | IR_097: 100 | mean: 1 101 | std: 1 102 | IR_108: 103 | mean: 1 104 | std: 1 105 | IR_120: 106 | mean: 1 107 | std: 1 108 | IR_134: 109 | mean: 1 110 | std: 1 111 | VIS006: 112 | mean: 1 113 | std: 1 114 | VIS008: 115 | mean: 1 116 | std: 1 117 | WV_062: 118 | mean: 1 119 | std: 1 120 | WV_073: 121 | mean: 1 122 | std: 1 123 | 124 | solar_position: 125 | interval_start_minutes: -60 126 | interval_end_minutes: 480 127 | time_resolution_minutes: 30 128 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/encoders/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for image sequence encoders""" 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta): 9 | """Abstract class for NWP/satellite encoder. 10 | 11 | The encoder will take an input of shape (batch_size, sequence_length, channels, height, width) 12 | and return an output of shape (batch_size, out_features). 13 | """ 14 | 15 | def __init__( 16 | self, 17 | sequence_length: int, 18 | image_size_pixels: int, 19 | in_channels: int, 20 | out_features: int, 21 | ): 22 | """Abstract class for NWP/satellite encoder. 
23 | 24 | Args: 25 | sequence_length: The time sequence length of the data. 26 | image_size_pixels: The spatial size of the image. Assumed square. 27 | in_channels: Number of input channels. 28 | out_features: Number of output features. 29 | """ 30 | super().__init__() 31 | self.out_features = out_features 32 | self.image_size_pixels = image_size_pixels 33 | self.sequence_length = sequence_length 34 | self.in_channels = in_channels 35 | 36 | @abstractmethod 37 | def forward(self): 38 | """Run model forward""" 39 | pass 40 | 41 | 42 | class ResidualConv3dBlock(nn.Module): 43 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 44 | 45 | This was the best performing residual block tested in the study. This implementation differs 46 | from that block just by using LeakyReLU activation to avoid dead neurons, and by including 47 | optional dropout in the residual branch. This is also a 3D convolutional residual block 48 | rather than a 2D convolutional block. 49 | 50 | Sources: 51 | [1] https://arxiv.org/pdf/1603.05027.pdf 52 | """ 53 | 54 | def __init__( 55 | self, 56 | in_channels: int, 57 | n_layers: int = 2, 58 | dropout_frac: float = 0.0, 59 | batch_norm: bool = True, 60 | ): 61 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 62 | 63 | Sources: 64 | [1] https://arxiv.org/pdf/1603.05027.pdf 65 | 66 | Args: 67 | in_channels: Number of input channels. 68 | n_layers: Number of layers in residual pathway. 69 | dropout_frac: Probability of an element to be zeroed.
70 | batch_norm: Whether to use batchnorm 71 | """ 72 | super().__init__() 73 | 74 | layers = [] 75 | for i in range(n_layers): 76 | if batch_norm: 77 | layers.append(nn.BatchNorm3d(in_channels)) 78 | layers.extend( 79 | [ 80 | nn.Dropout3d(p=dropout_frac), 81 | nn.LeakyReLU(), 82 | nn.Conv3d( 83 | in_channels=in_channels, 84 | out_channels=in_channels, 85 | kernel_size=(3, 3, 3), 86 | padding=(1, 1, 1), 87 | ), 88 | ] 89 | ) 90 | 91 | self.model = nn.Sequential(*layers) 92 | 93 | def forward(self, x: torch.Tensor) -> torch.Tensor: 94 | """Run model forward""" 95 | return self.model(x) + x 96 | -------------------------------------------------------------------------------- /configs.example/model/late_fusion.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.training.lightning_module.PVNetLightningModule 2 | 3 | model: 4 | _target_: pvnet.models.LateFusionModel 5 | output_quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] 6 | 7 | #-------------------------------------------- 8 | # NWP encoder 9 | #-------------------------------------------- 10 | 11 | nwp_encoders_dict: 12 | ukv: 13 | _target_: pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet 14 | _partial_: True 15 | in_channels: 2 16 | out_features: 256 17 | number_of_conv3d_layers: 6 18 | conv3d_channels: 32 19 | image_size_pixels: 24 20 | ecmwf: 21 | _target_: pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet 22 | _partial_: True 23 | in_channels: 12 24 | out_features: 256 25 | number_of_conv3d_layers: 4 26 | conv3d_channels: 32 27 | image_size_pixels: 12 28 | 29 | #-------------------------------------------- 30 | # Sat encoder settings 31 | #-------------------------------------------- 32 | 33 | sat_encoder: 34 | _target_: pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet 35 | _partial_: True 36 | in_channels: 11 37 | out_features: 256 38 | number_of_conv3d_layers: 6 39 | conv3d_channels: 32 40 | image_size_pixels: 24 41 | 42 | 
add_image_embedding_channel: False 43 | 44 | #-------------------------------------------- 45 | # PV encoder settings 46 | #-------------------------------------------- 47 | 48 | pv_encoder: 49 | _target_: pvnet.models.late_fusion.site_encoders.encoders.SingleAttentionNetwork 50 | _partial_: True 51 | num_sites: 349 52 | out_features: 40 53 | num_heads: 4 54 | kdim: 40 55 | id_embed_dim: 20 56 | 57 | #-------------------------------------------- 58 | # Tabular network settings 59 | #-------------------------------------------- 60 | 61 | output_network: 62 | _target_: pvnet.models.late_fusion.linear_networks.networks.ResFCNet 63 | _partial_: True 64 | fc_hidden_features: 128 65 | n_res_blocks: 6 66 | res_block_layers: 2 67 | dropout_frac: 0.0 68 | 69 | embedding_dim: 16 70 | include_sun: True 71 | include_generation_history: False 72 | 73 | # The mapping between the location IDs and their embedding indices 74 | location_id_mapping: 75 | 1: 1 76 | 5: 2 77 | 110: 3 78 | # ... 79 | 80 | #-------------------------------------------- 81 | # Times 82 | #-------------------------------------------- 83 | 84 | # Forecast and time settings 85 | forecast_minutes: 480 86 | history_minutes: 120 87 | 88 | min_sat_delay_minutes: 60 89 | 90 | # These must also be set even if identical to forecast_minutes and history_minutes 91 | sat_history_minutes: 90 92 | pv_history_minutes: 180 93 | 94 | # These must be set for each NWP encoder 95 | nwp_history_minutes: 96 | ukv: 120 97 | ecmwf: 120 98 | nwp_forecast_minutes: 99 | ukv: 480 100 | ecmwf: 480 101 | # Optional; defaults to 60, so must be set for data with different time resolution 102 | nwp_interval_minutes: 103 | ukv: 60 104 | ecmwf: 60 105 | 106 | # ---------------------------------------------- 107 | # Optimizer 108 | # ---------------------------------------------- 109 | optimizer: 110 | _target_: pvnet.optimizers.EmbAdamWReduceLROnPlateau 111 | lr: 0.0001 112 | weight_decay: 0.01 113 | amsgrad: True 114 | patience: 5 115 |
factor: 0.1 116 | threshold: 0.002 117 | -------------------------------------------------------------------------------- /tests/training/test_train.py: -------------------------------------------------------------------------------- 1 | """Tests and fixtures for CPU-only Trainer and offline W&B logging.""" 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | from omegaconf import DictConfig 7 | 8 | from pvnet.training.train import train as pvnet_train 9 | 10 | 11 | @pytest.fixture() 12 | def wandb_save_dir(session_tmp_path) -> str: 13 | """Return W&B save dir under session temp path.""" 14 | save_dir = str(session_tmp_path / "wandb") 15 | return save_dir 16 | 17 | 18 | @pytest.fixture() 19 | def trainer_cfg_cpu() -> dict: 20 | """Tiny CPU-only Trainer config.""" 21 | return { 22 | "_target_": "lightning.pytorch.Trainer", 23 | "max_epochs": 1, 24 | "limit_train_batches": 1, 25 | "limit_val_batches": 1, 26 | "accelerator": "cpu", 27 | "enable_checkpointing": True, 28 | "log_every_n_steps": 1, 29 | "enable_progress_bar": False, 30 | } 31 | 32 | 33 | @pytest.fixture() 34 | def logger_cfg(wandb_save_dir: str) -> dict: 35 | """W&B logger config.""" 36 | return { 37 | "wandb": { 38 | "_target_": "lightning.pytorch.loggers.wandb.WandbLogger", 39 | "project": "pvnet-tests", 40 | "save_dir": wandb_save_dir, 41 | "offline": True, 42 | "name": "train-offline-integration", 43 | "log_model": False, 44 | } 45 | } 46 | 47 | 48 | @pytest.fixture() 49 | def ckpt_cfg(wandb_save_dir: str) -> dict: 50 | """ModelCheckpoint config.""" 51 | return { 52 | "ckpt": { 53 | "_target_": "lightning.pytorch.callbacks.ModelCheckpoint", 54 | "dirpath": str(Path(wandb_save_dir).parent / "ckpts"), 55 | "save_last": True, 56 | "save_top_k": 1, 57 | "monitor": "MAE/val", 58 | "mode": "min", 59 | } 60 | } 61 | 62 | 63 | def build_lit_late_fusion_cfg( 64 | interval_minutes: int, 65 | include_time: bool, 66 | forecast_minutes: int = 480, 67 | history_minutes: int = 60, 68 | ) -> dict: 69 | """Build 
config for PVNetLightningModule + minimal LateFusionModel.""" 70 | return { 71 | "_target_": "pvnet.training.lightning_module.PVNetLightningModule", 72 | "model": { 73 | "_target_": "pvnet.models.LateFusionModel", 74 | "sat_encoder": None, 75 | "nwp_encoders_dict": None, 76 | "add_image_embedding_channel": False, 77 | "pv_encoder": None, 78 | "output_network": { 79 | "_target_": "pvnet.models.late_fusion.linear_networks.networks.ResFCNet", 80 | "_partial_": True, 81 | "fc_hidden_features": 128, 82 | "n_res_blocks": 6, 83 | "res_block_layers": 2, 84 | "dropout_frac": 0.0, 85 | }, 86 | "location_id_mapping": None, 87 | "embedding_dim": None, 88 | "include_sun": False, 89 | "include_time": include_time, 90 | "include_generation_history": True, 91 | "forecast_minutes": forecast_minutes, 92 | "history_minutes": history_minutes, 93 | "interval_minutes": interval_minutes, 94 | }, 95 | "optimizer": { 96 | "_target_": "pvnet.optimizers.Adam", 97 | "lr": 1e-3, 98 | }, 99 | "save_all_validation_results": False, 100 | } 101 | 102 | def test_train_pvnet( 103 | data_config_path, 104 | trainer_cfg_cpu, 105 | logger_cfg, 106 | ckpt_cfg, 107 | ): 108 | """Train pvnet model with W&B offline.""" 109 | cfg = DictConfig({ 110 | "seed": 42, 111 | "datamodule": { 112 | "_target_": "pvnet.datamodule.PVNetDataModule", 113 | "configuration": str(data_config_path), 114 | "batch_size": 2, 115 | "num_workers": 0, 116 | "prefetch_factor": None, 117 | }, 118 | "model": build_lit_late_fusion_cfg( 119 | interval_minutes=30, 120 | include_time=False, 121 | ), 122 | "logger": logger_cfg, 123 | "callbacks": ckpt_cfg, 124 | "trainer": trainer_cfg_cpu, 125 | }) 126 | 127 | pvnet_train(cfg) 128 | -------------------------------------------------------------------------------- /pvnet/load_model.py: -------------------------------------------------------------------------------- 1 | """Load a model from its checkpoint directory""" 2 | 3 | import glob 4 | import os 5 | 6 | import hydra 7 | import torch 8 | 
import yaml 9 | 10 | from pvnet.models.ensemble import Ensemble 11 | from pvnet.utils import ( 12 | DATA_CONFIG_NAME, 13 | DATAMODULE_CONFIG_NAME, 14 | FULL_CONFIG_NAME, 15 | MODEL_CONFIG_NAME, 16 | ) 17 | 18 | 19 | def get_model_from_checkpoints( 20 | checkpoint_dir_paths: list[str], 21 | val_best: bool = True, 22 | ) -> tuple[torch.nn.Module, dict, str, str | None, str | None]: 23 | """Load a model from its checkpoint directory 24 | 25 | Returns: 26 | tuple: 27 | model: nn.Module of pretrained model. 28 | model_config: path to model config used to train the model. 29 | data_config: path to data config used to create samples for the model. 30 | datamodule_config: path to datamodule used to create samples e.g train/test split info. 31 | experiment_configs: path to the full experimental config. 32 | 33 | """ 34 | is_ensemble = len(checkpoint_dir_paths) > 1 35 | 36 | model_configs = [] 37 | models = [] 38 | data_configs = [] 39 | datamodule_configs = [] 40 | experiment_configs = [] 41 | 42 | for path in checkpoint_dir_paths: 43 | 44 | # Load lightning training module 45 | with open(f"{path}/{MODEL_CONFIG_NAME}") as cfg: 46 | model_config = yaml.load(cfg, Loader=yaml.FullLoader) 47 | 48 | lightning_module = hydra.utils.instantiate(model_config) 49 | 50 | if val_best: 51 | # Only one epoch (best) saved per model 52 | files = glob.glob(f"{path}/epoch*.ckpt") 53 | if len(files) != 1: 54 | raise ValueError( 55 | f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one." 
56 | ) 57 | # TODO: Loading with weights_only=False is not recommended 58 | checkpoint = torch.load(files[0], map_location="cpu", weights_only=True) 59 | else: 60 | checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu", weights_only=True) 61 | 62 | lightning_module.load_state_dict(state_dict=checkpoint["state_dict"]) 63 | 64 | # Extract the model from the lightning module 65 | models.append(lightning_module.model) 66 | model_configs.append(model_config["model"]) 67 | 68 | # Store the data config used for the model 69 | data_config = f"{path}/{DATA_CONFIG_NAME}" 70 | 71 | if os.path.isfile(data_config): 72 | data_configs.append(data_config) 73 | else: 74 | raise FileNotFoundError(f"File {data_config} does not exist") 75 | 76 | # TODO: This should be removed in a future release since no new models will be trained on 77 | # presaved samples 78 | # Check for datamodule config 79 | # This only exists if the model was trained with presaved samples 80 | datamodule_config = f"{path}/{DATAMODULE_CONFIG_NAME}" 81 | if os.path.isfile(datamodule_config): 82 | datamodule_configs.append(datamodule_config) 83 | else: 84 | datamodule_configs.append(None) 85 | 86 | # Check for experiment config 87 | # For backwards compatibility - this might always exist 88 | experiment_config = f"{path}/{FULL_CONFIG_NAME}" 89 | if os.path.isfile(datamodule_config): 90 | experiment_configs.append(experiment_config) 91 | else: 92 | experiment_configs.append(None) 93 | 94 | if is_ensemble: 95 | model_config = { 96 | "_target_": "pvnet.models.ensemble.Ensemble", 97 | "model_list": model_configs, 98 | } 99 | model = Ensemble(model_list=models) 100 | 101 | else: 102 | model_config = model_configs[0] 103 | model = models[0] 104 | 105 | # Assume if using an ensemble that the members were trained on the same input data 106 | data_config = data_configs[0] 107 | datamodule_config = datamodule_configs[0] 108 | 109 | # TODO: How should we save the experimental configs if we had an ensemble? 
110 | experiment_config = experiment_configs[0] 111 | 112 | return model, model_config, data_config, datamodule_config, experiment_config 113 | -------------------------------------------------------------------------------- /scripts/checkpoint_to_huggingface.py: -------------------------------------------------------------------------------- 1 | """Command line tool to push locally save model checkpoints to huggingface 2 | 3 | To use this script, you will need to write a custom model card. You can copy and fill out 4 | `pvnet/model_cards/empty_model_card_template.md` to get you started. 5 | 6 | These model cards should not be added to and version controlled in the repo since they are specific 7 | to each user. 8 | 9 | Then run using: 10 | 11 | ``` 12 | python checkpoint_to_huggingface.py "path/to/model/checkpoints" \ 13 | --huggingface-repo="openclimatefix/pvnet_uk_region" \ 14 | --wandb-repo="openclimatefix/pvnet2.1" \ 15 | --card-template-path="pvnet/models/model_cards/my_custom_model_card.md" \ 16 | --local-path="~/tmp/this_model" \ 17 | --no-push-to-hub 18 | ``` 19 | """ 20 | 21 | import tempfile 22 | 23 | import typer 24 | import wandb 25 | 26 | from pvnet.load_model import get_model_from_checkpoints 27 | 28 | app = typer.Typer(pretty_exceptions_show_locals=False) 29 | 30 | 31 | @app.command() 32 | def push_to_huggingface( 33 | checkpoint_dir_paths: list[str] = typer.Argument(...,), 34 | huggingface_repo: str = typer.Option(..., "--huggingface-repo"), 35 | wandb_repo: str = typer.Option(..., "--wandb-repo"), 36 | card_template_path: str = typer.Option(..., "--card-template-path"), 37 | wandb_ids: list[str] = typer.Option([], "--wandb-id"), 38 | val_best: bool = typer.Option(True), 39 | local_path: str = typer.Option(None, "--local-path"), 40 | push_to_hub: bool = typer.Option(True), 41 | ): 42 | """Push a local model to a huggingface model repo 43 | 44 | Args: 45 | checkpoint_dir_paths: Path(s) of the checkpoint directory(ies) 46 | huggingface_repo: Name of 
the HuggingFace repo to push the model to 47 | wandb_repo: Name of the wandb repo which has training logs 48 | card_template_path: Path to the model card template. 49 | wandb_ids: The wandb ID code(s) - if not filled out these are taken 50 | val_best: Use best model according to val loss, else last saved model 51 | local_path: Where to save the local copy of the model 52 | push_to_hub: Whether to push the model to the hub or just create local version. 53 | """ 54 | 55 | assert push_to_hub or local_path is not None 56 | 57 | is_ensemble = len(checkpoint_dir_paths) > 1 58 | 59 | # Check that the wandb-IDs are correct 60 | all_wandb_ids = [run.id for run in wandb.Api().runs(path=wandb_repo)] 61 | 62 | # If the IDs are not supplied try and pull them from the checkpoint dir name 63 | if wandb_ids == []: 64 | for path in checkpoint_dir_paths: 65 | dirname = path.split("/")[-1] 66 | if dirname in all_wandb_ids: 67 | wandb_ids.append(dirname) 68 | else: 69 | raise Exception(f"Could not find wand run for {path} within {wandb_repo}") 70 | 71 | # Else if they are provided check that they exist 72 | else: 73 | for wandb_id in wandb_ids: 74 | if wandb_id not in all_wandb_ids: 75 | raise Exception(f"Could not find wand run for {path} within {wandb_repo}") 76 | 77 | ( 78 | model, 79 | model_config, 80 | data_config_path, 81 | datamodule_config_path, 82 | experiment_config_path, 83 | ) = get_model_from_checkpoints(checkpoint_dir_paths, val_best) 84 | 85 | if not is_ensemble: 86 | wandb_ids = wandb_ids[0] 87 | 88 | # Push to hub 89 | if local_path is None: 90 | temp_dir = tempfile.TemporaryDirectory() 91 | model_output_dir = temp_dir.name 92 | else: 93 | model_output_dir = local_path 94 | 95 | model.save_pretrained( 96 | save_directory=model_output_dir, 97 | model_config=model_config, 98 | data_config_path=data_config_path, 99 | datamodule_config_path=datamodule_config_path, 100 | experiment_config_path=experiment_config_path, 101 | wandb_repo=wandb_repo, 102 | wandb_ids=wandb_ids, 
103 | card_template_path=card_template_path, 104 | push_to_hub=push_to_hub, 105 | hf_repo_id=huggingface_repo if push_to_hub else None, 106 | ) 107 | 108 | if local_path is None: 109 | temp_dir.cleanup() 110 | 111 | 112 | if __name__ == "__main__": 113 | app() 114 | -------------------------------------------------------------------------------- /pvnet/training/train.py: -------------------------------------------------------------------------------- 1 | """Training""" 2 | import logging 3 | import os 4 | import shutil 5 | 6 | import hydra 7 | from lightning.pytorch import ( 8 | Callback, 9 | LightningDataModule, 10 | LightningModule, 11 | Trainer, 12 | seed_everything, 13 | ) 14 | from lightning.pytorch.callbacks import ModelCheckpoint 15 | from lightning.pytorch.loggers import Logger, WandbLogger 16 | from omegaconf import DictConfig, OmegaConf 17 | 18 | from pvnet.utils import ( 19 | DATA_CONFIG_NAME, 20 | FULL_CONFIG_NAME, 21 | MODEL_CONFIG_NAME, 22 | ) 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | 27 | def resolve_monitor_loss(output_quantiles: list | None) -> str: 28 | """Return the desired metric to monitor based on whether quantile regression is being used. 29 | 30 | The adds the option to use something like: 31 | monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}" 32 | 33 | in early stopping and model checkpoint callbacks so the callbacks config does not need to be 34 | modified depending on whether quantile regression is being used or not. 35 | """ 36 | if output_quantiles is None: 37 | return "MAE/val" 38 | else: 39 | return "quantile_loss/val" 40 | 41 | 42 | OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss) 43 | 44 | 45 | def train(config: DictConfig) -> None: 46 | """Contains training pipeline. 47 | 48 | Instantiates all PyTorch Lightning objects from config. 49 | 50 | Args: 51 | config (DictConfig): Configuration composed by Hydra. 
52 | """ 53 | 54 | # Set seed for random number generators in pytorch, numpy and python.random 55 | if "seed" in config: 56 | seed_everything(config.seed, workers=True) 57 | 58 | # Init lightning datamodule 59 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 60 | 61 | # Init lightning model 62 | model: LightningModule = hydra.utils.instantiate(config.model) 63 | 64 | # Init lightning loggers 65 | loggers: list[Logger] = [] 66 | if "logger" in config: 67 | for _, lg_conf in config.logger.items(): 68 | loggers.append(hydra.utils.instantiate(lg_conf)) 69 | 70 | # Init lightning callbacks 71 | callbacks: list[Callback] = [] 72 | if "callbacks" in config: 73 | for _, cb_conf in config.callbacks.items(): 74 | callbacks.append(hydra.utils.instantiate(cb_conf)) 75 | 76 | # Align the wandb id with the checkpoint path 77 | # - only works if wandb logger and model checkpoint used 78 | # - this makes it easy to push the model to huggingface 79 | use_wandb_logger = False 80 | for logger in loggers: 81 | if isinstance(logger, WandbLogger): 82 | use_wandb_logger = True 83 | wandb_logger = logger 84 | break 85 | 86 | # Set the output directory based in the wandb-id of the run 87 | if use_wandb_logger: 88 | for callback in callbacks: 89 | if isinstance(callback, ModelCheckpoint): 90 | # Calling the .experiment property instantiates a wandb run 91 | wandb_id = wandb_logger.experiment.id 92 | 93 | # Save the run results to the expected parent folder but with the folder name 94 | # set by the wandb ID 95 | save_dir = "/".join(callback.dirpath.split("/")[:-1] + [wandb_id]) 96 | 97 | callback.dirpath = save_dir 98 | 99 | # Save the model config 100 | os.makedirs(save_dir, exist_ok=True) 101 | OmegaConf.save(config.model, f"{save_dir}/{MODEL_CONFIG_NAME}") 102 | 103 | # Save the data config to the output directory and to wandb 104 | data_config = config.datamodule.configuration 105 | shutil.copyfile(data_config, f"{save_dir}/{DATA_CONFIG_NAME}") 106 | 
wandb_logger.experiment.save(f"{save_dir}/{DATA_CONFIG_NAME}", base_path=save_dir) 107 | 108 | # Save the full hydra config to the output directory and to wandb 109 | OmegaConf.save(config, f"{save_dir}/{FULL_CONFIG_NAME}") 110 | wandb_logger.experiment.save(f"{save_dir}/{FULL_CONFIG_NAME}", base_path=save_dir) 111 | 112 | break 113 | 114 | trainer: Trainer = hydra.utils.instantiate( 115 | config.trainer, 116 | logger=loggers, 117 | _convert_="partial", 118 | callbacks=callbacks, 119 | ) 120 | 121 | # Train the model completely 122 | trainer.fit(model=model, datamodule=datamodule) 123 | -------------------------------------------------------------------------------- /scripts/mae_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to generate analysis of MAE values for multiple model forecasts 3 | 4 | Does this for 48 hour horizon forecasts with 15 minute granularity 5 | 6 | """ 7 | 8 | import argparse 9 | 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pandas as pd 14 | import wandb 15 | 16 | matplotlib.rcParams["axes.prop_cycle"] = matplotlib.cycler( 17 | color=[ 18 | "FFD053", # yellow 19 | "7BCDF3", # blue 20 | "63BCAF", # teal 21 | "086788", # dark blue 22 | "FF9736", # dark orange 23 | "E4E4E4", # grey 24 | "14120E", # black 25 | "FFAC5F", # orange 26 | "4C9A8E", # dark teal 27 | ] 28 | ) 29 | 30 | 31 | def main(project: str, runs: list[str], run_names: list[str]) -> None: 32 | """ 33 | Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity 34 | 35 | Args: 36 | project: name of W&B project 37 | runs: W&B ids of runs 38 | run_names: user specified names for runs 39 | 40 | """ 41 | api = wandb.Api() 42 | dfs = [] 43 | epoch_num = [] 44 | for run in runs: 45 | run = api.run(f"openclimatefix/{project}/{run}") 46 | 47 | df = run.history(samples=run.lastHistoryStep + 1) 48 | # Get the columns that are in the format 
'val_step_MAE/step_' 49 | mae_cols = [col for col in df.columns if "val_step_MAE/step_" in col] 50 | # Sort them 51 | mae_cols.sort() 52 | df = df[mae_cols] 53 | # Get last non-NaN value 54 | # Drop all rows with all NaNs 55 | df = df.dropna(how="all") 56 | # Select the last row 57 | # Get average across entire row, and get the IDX for the one with the smallest values 58 | min_row_mean = np.inf 59 | for idx, (row_idx, row) in enumerate(df.iterrows()): 60 | if row.mean() < min_row_mean: 61 | min_row_mean = row.mean() 62 | min_row_idx = idx 63 | df = df.iloc[min_row_idx] 64 | # Calculate the timedelta for each group 65 | # Get the step from the column name 66 | column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols] 67 | dfs.append(df) 68 | epoch_num.append(min_row_idx) 69 | # Get the timedelta for each group 70 | groupings = [ 71 | [0, 0], 72 | [15, 15], 73 | [30, 45], 74 | [45, 60], 75 | [60, 120], 76 | [120, 240], 77 | [240, 360], 78 | [360, 480], 79 | [480, 720], 80 | [720, 1440], 81 | [1440, 2880], 82 | ] 83 | 84 | groups_df = [] 85 | grouping_starts = [grouping[0] for grouping in groupings] 86 | header = "| Timestep |" 87 | separator = "| --- |" 88 | for run_name in run_names: 89 | header += f" {run_name} MAE % |" 90 | separator += " --- |" 91 | print(header) 92 | print(separator) 93 | for grouping in groupings: 94 | group_string = f"| {grouping[0]}-{grouping[1]} minutes |" 95 | # Select indicies from column_timesteps that are within the grouping, inclusive 96 | group_idx = [ 97 | idx 98 | for idx, timestep in enumerate(column_timesteps) 99 | if timestep >= grouping[0] and timestep <= grouping[1] 100 | ] 101 | data_one_group = [] 102 | for df in dfs: 103 | mean_row = df.iloc[group_idx].mean() 104 | group_string += f" {mean_row:0.3f} |" 105 | data_one_group.append(mean_row) 106 | print(group_string) 107 | 108 | groups_df.append(data_one_group) 109 | 110 | groups_df = pd.DataFrame(groups_df, columns=run_names, index=grouping_starts) 
111 | 112 | for idx, df in enumerate(dfs): 113 | print(f"{run_names[idx]}: {df.mean()*100:0.3f}") 114 | 115 | # Plot the error per timestep 116 | plt.figure() 117 | for idx, df in enumerate(dfs): 118 | plt.plot( 119 | column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-" 120 | ) 121 | plt.legend() 122 | plt.xlabel("Timestep (minutes)") 123 | plt.ylabel("MAE %") 124 | plt.title("MAE % for each timestep") 125 | plt.savefig("mae_per_timestep.png") 126 | plt.show() 127 | 128 | # Plot the error per grouped timestep 129 | plt.figure() 130 | for idx, run_name in enumerate(run_names): 131 | plt.plot( 132 | groups_df[run_name], 133 | label=f"{run_name}, epoch: {epoch_num[idx]}", 134 | marker="o", 135 | linestyle="-", 136 | ) 137 | plt.legend() 138 | plt.xlabel("Timestep (minutes)") 139 | plt.ylabel("MAE %") 140 | plt.title("MAE % for each grouped timestep") 141 | plt.savefig("mae_per_grouped_timestep.png") 142 | plt.show() 143 | 144 | 145 | if __name__ == "__main__": 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument("--project", type=str, default="") 148 | # Add arguments that is a list of strings 149 | parser.add_argument("--list_of_runs", nargs="+") 150 | parser.add_argument("--run_names", nargs="+") 151 | args = parser.parse_args() 152 | main(args.project, args.list_of_runs, args.run_names) 153 | -------------------------------------------------------------------------------- /pvnet/datamodule.py: -------------------------------------------------------------------------------- 1 | """Data module for pytorch lightning""" 2 | 3 | import os 4 | 5 | import numpy as np 6 | from lightning.pytorch import LightningDataModule 7 | from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch 8 | from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch 9 | from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetDataset 10 | from ocf_data_sampler.torch_datasets.utils.torch_batch_utils 
import batch_to_tensor 11 | from torch.utils.data import DataLoader, Subset 12 | 13 | 14 | def collate_fn(samples: list[NumpySample]) -> TensorBatch: 15 | """Convert a list of NumpySample samples to a tensor batch""" 16 | return batch_to_tensor(stack_np_samples_into_batch(samples)) 17 | 18 | 19 | class PVNetDataModule(LightningDataModule): 20 | """Base Datamodule which streams samples using a sampler from ocf-data-sampler.""" 21 | 22 | def __init__( 23 | self, 24 | configuration: str, 25 | batch_size: int = 16, 26 | num_workers: int = 0, 27 | prefetch_factor: int | None = None, 28 | persistent_workers: bool = False, 29 | pin_memory: bool = False, 30 | train_period: list[str | None] = [None, None], 31 | val_period: list[str | None] = [None, None], 32 | seed: int | None = None, 33 | dataset_pickle_dir: str | None = None, 34 | ): 35 | """Base Datamodule for streaming samples. 36 | 37 | Args: 38 | configuration: Path to ocf-data-sampler configuration file. 39 | batch_size: Batch size. 40 | num_workers: Number of workers to use in multiprocess batch loading. 41 | prefetch_factor: Number of batches loaded in advance by each worker. 42 | persistent_workers: If True, the data loader will not shut down the worker processes 43 | after a dataset has been consumed once. This allows to maintain the workers Dataset 44 | instances alive. 45 | pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory 46 | before returning them. 47 | train_period: Date range filter for train dataloader. 48 | val_period: Date range filter for val dataloader. 49 | seed: Random seed used in shuffling datasets. 50 | dataset_pickle_dir: Directory in which the val and train set will be presaved as 51 | pickle objects. Setting this speeds up instantiation of multiple workers a lot. 
52 | """ 53 | super().__init__() 54 | 55 | self.configuration = configuration 56 | self.train_period = train_period 57 | self.val_period = val_period 58 | self.seed = seed 59 | self.dataset_pickle_dir = dataset_pickle_dir 60 | 61 | self._common_dataloader_kwargs = dict( 62 | batch_size=batch_size, 63 | batch_sampler=None, 64 | num_workers=num_workers, 65 | collate_fn=collate_fn, 66 | pin_memory=pin_memory, 67 | drop_last=False, 68 | timeout=0, 69 | worker_init_fn=None, 70 | prefetch_factor=prefetch_factor, 71 | persistent_workers=persistent_workers, 72 | multiprocessing_context="spawn" if num_workers > 0 else None, 73 | ) 74 | 75 | def setup(self, stage: str | None = None): 76 | """Called once to prepare the datasets.""" 77 | 78 | # This logic runs only once at the start of training, therefore the val dataset is only 79 | # shuffled once 80 | if stage == "fit": 81 | # Prepare the train dataset 82 | self.train_dataset = self._get_dataset(*self.train_period) 83 | 84 | # Prepare and pre-shuffle the val dataset and set seed for reproducibility 85 | val_dataset = self._get_dataset(*self.val_period) 86 | 87 | shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset)) 88 | self.val_dataset = Subset(val_dataset, shuffled_indices) 89 | 90 | if self.dataset_pickle_dir is not None: 91 | os.makedirs(self.dataset_pickle_dir, exist_ok=True) 92 | train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl" 93 | val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl" 94 | 95 | # For safety, these pickled datasets cannot be overwritten. 96 | # See: https://github.com/openclimatefix/pvnet/pull/445 97 | for path in [train_dataset_path, val_dataset_path]: 98 | if os.path.exists(path): 99 | raise FileExistsError( 100 | f"The pickled dataset path '{path}' already exists. Make sure that " 101 | "this can be safely deleted (i.e. not currently being used by any " 102 | "training run) and delete it manually. 
Else change the " 103 | "`dataset_pickle_dir` to a different directory." 104 | ) 105 | 106 | self.train_dataset.presave_pickle(train_dataset_path) 107 | self.train_dataset.presave_pickle(val_dataset_path) 108 | 109 | def teardown(self, stage: str | None = None) -> None: 110 | """Clean up the pickled datasets""" 111 | if self.dataset_pickle_dir is not None: 112 | for filename in ["val_dataset.pkl", "train_dataset.pkl"]: 113 | filepath = f"{self.dataset_pickle_dir}/{filename}" 114 | if os.path.exists(filepath): 115 | os.remove(filepath) 116 | 117 | def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetDataset: 118 | return PVNetDataset(self.configuration, start_time=start_time, end_time=end_time) 119 | 120 | def train_dataloader(self) -> DataLoader: 121 | """Construct train dataloader""" 122 | return DataLoader(self.train_dataset, shuffle=True, **self._common_dataloader_kwargs) 123 | 124 | def val_dataloader(self) -> DataLoader: 125 | """Construct val dataloader""" 126 | return DataLoader(self.val_dataset, shuffle=False, **self._common_dataloader_kwargs) 127 | -------------------------------------------------------------------------------- /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "commitType": "docs", 8 | "commitConvention": "angular", 9 | "contributors": [ 10 | { 11 | "login": "felix-e-h-p", 12 | "name": "Felix", 13 | "avatar_url": "https://avatars.githubusercontent.com/u/137530077?v=4", 14 | "profile": "https://github.com/felix-e-h-p", 15 | "contributions": [ 16 | "code" 17 | ] 18 | }, 19 | { 20 | "login": "Sukh-P", 21 | "name": "Sukhil Patel", 22 | "avatar_url": "https://avatars.githubusercontent.com/u/42407101?v=4", 23 | "profile": "https://github.com/Sukh-P", 24 | "contributions": [ 25 | "code" 26 | ] 27 | }, 28 | { 29 | "login": "dfulu", 30 | "name": "James Fulton", 31 | "avatar_url": 
"https://avatars.githubusercontent.com/u/41546094?v=4", 32 | "profile": "https://github.com/dfulu", 33 | "contributions": [ 34 | "code" 35 | ] 36 | }, 37 | { 38 | "login": "AUdaltsova", 39 | "name": "Alexandra Udaltsova", 40 | "avatar_url": "https://avatars.githubusercontent.com/u/43303448?v=4", 41 | "profile": "https://github.com/AUdaltsova", 42 | "contributions": [ 43 | "code", 44 | "review" 45 | ] 46 | }, 47 | { 48 | "login": "zakwatts", 49 | "name": "Megawattz", 50 | "avatar_url": "https://avatars.githubusercontent.com/u/47150349?v=4", 51 | "profile": "https://github.com/zakwatts", 52 | "contributions": [ 53 | "code" 54 | ] 55 | }, 56 | { 57 | "login": "peterdudfield", 58 | "name": "Peter Dudfield", 59 | "avatar_url": "https://avatars.githubusercontent.com/u/34686298?v=4", 60 | "profile": "https://github.com/peterdudfield", 61 | "contributions": [ 62 | "code" 63 | ] 64 | }, 65 | { 66 | "login": "mahdilamb", 67 | "name": "Mahdi Lamb", 68 | "avatar_url": "https://avatars.githubusercontent.com/u/4696915?v=4", 69 | "profile": "https://github.com/mahdilamb", 70 | "contributions": [ 71 | "infra" 72 | ] 73 | }, 74 | { 75 | "login": "jacobbieker", 76 | "name": "Jacob Prince-Bieker", 77 | "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4", 78 | "profile": "https://www.jacobbieker.com", 79 | "contributions": [ 80 | "code" 81 | ] 82 | }, 83 | { 84 | "login": "codderrrrr", 85 | "name": "codderrrrr", 86 | "avatar_url": "https://avatars.githubusercontent.com/u/149995852?v=4", 87 | "profile": "https://github.com/codderrrrr", 88 | "contributions": [ 89 | "code" 90 | ] 91 | }, 92 | { 93 | "login": "confusedmatrix", 94 | "name": "Chris Briggs", 95 | "avatar_url": "https://avatars.githubusercontent.com/u/617309?v=4", 96 | "profile": "https://chrisxbriggs.com", 97 | "contributions": [ 98 | "code" 99 | ] 100 | }, 101 | { 102 | "login": "tmi", 103 | "name": "tmi", 104 | "avatar_url": "https://avatars.githubusercontent.com/u/147159?v=4", 105 | "profile": 
"https://github.com/tmi", 106 | "contributions": [ 107 | "code" 108 | ] 109 | }, 110 | { 111 | "login": "carderne", 112 | "name": "Chris Arderne", 113 | "avatar_url": "https://avatars.githubusercontent.com/u/19817302?v=4", 114 | "profile": "https://rdrn.me/", 115 | "contributions": [ 116 | "code" 117 | ] 118 | }, 119 | { 120 | "login": "Dakshbir", 121 | "name": "Dakshbir", 122 | "avatar_url": "https://avatars.githubusercontent.com/u/144359831?v=4", 123 | "profile": "https://github.com/Dakshbir", 124 | "contributions": [ 125 | "code" 126 | ] 127 | }, 128 | { 129 | "login": "MAYANK12SHARMA", 130 | "name": "MAYANK SHARMA", 131 | "avatar_url": "https://avatars.githubusercontent.com/u/145884197?v=4", 132 | "profile": "https://github.com/MAYANK12SHARMA", 133 | "contributions": [ 134 | "code" 135 | ] 136 | }, 137 | { 138 | "login": "lambaaryan011", 139 | "name": "aryan lamba ", 140 | "avatar_url": "https://avatars.githubusercontent.com/u/153702847?v=4", 141 | "profile": "https://github.com/lambaaryan011", 142 | "contributions": [ 143 | "code" 144 | ] 145 | }, 146 | { 147 | "login": "michael-gendy", 148 | "name": "michael-gendy", 149 | "avatar_url": "https://avatars.githubusercontent.com/u/64384201?v=4", 150 | "profile": "https://github.com/michael-gendy", 151 | "contributions": [ 152 | "code" 153 | ] 154 | }, 155 | { 156 | "login": "adityasuthar", 157 | "name": "Aditya Suthar", 158 | "avatar_url": "https://avatars.githubusercontent.com/u/95685363?v=4", 159 | "profile": "https://adityasuthar.github.io/", 160 | "contributions": [ 161 | "code" 162 | ] 163 | }, 164 | { 165 | "login": "markus-kreft", 166 | "name": "Markus Kreft", 167 | "avatar_url": "https://avatars.githubusercontent.com/u/129367085?v=4", 168 | "profile": "https://github.com/markus-kreft", 169 | "contributions": [ 170 | "code" 171 | ] 172 | }, 173 | { 174 | "login": "JackKelly", 175 | "name": "Jack Kelly", 176 | "avatar_url": "https://avatars.githubusercontent.com/u/460756?v=4", 177 | "profile": 
"http://jack-kelly.com", 178 | "contributions": [ 179 | "ideas" 180 | ] 181 | }, 182 | { 183 | "login": "zaryab-ali", 184 | "name": "zaryab-ali", 185 | "avatar_url": "https://avatars.githubusercontent.com/u/85732412?v=4", 186 | "profile": "https://github.com/zaryab-ali", 187 | "contributions": [ 188 | "code" 189 | ] 190 | }, 191 | { 192 | "login": "Lex-Ashu", 193 | "name": "Lex-Ashu", 194 | "avatar_url": "https://avatars.githubusercontent.com/u/181084934?v=4", 195 | "profile": "https://github.com/Lex-Ashu", 196 | "contributions": [ 197 | "code" 198 | ] 199 | } 200 | ], 201 | "contributorsPerLine": 7, 202 | "skipCi": true, 203 | "repoType": "github", 204 | "repoHost": "https://github.com", 205 | "projectName": "pvnet", 206 | "projectOwner": "openclimatefix" 207 | } 208 | -------------------------------------------------------------------------------- /pvnet/utils.py: -------------------------------------------------------------------------------- 1 | """Utils""" 2 | 3 | import logging 4 | from typing import TYPE_CHECKING 5 | 6 | import rich.syntax 7 | import rich.tree 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | if TYPE_CHECKING: 12 | from pvnet.models.base_model import BaseModel 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | PYTORCH_WEIGHTS_NAME = "model_weights.safetensors" 18 | MODEL_CONFIG_NAME = "model_config.yaml" 19 | DATA_CONFIG_NAME = "data_config.yaml" 20 | DATAMODULE_CONFIG_NAME = "datamodule_config.yaml" 21 | FULL_CONFIG_NAME = "full_experiment_config.yaml" 22 | MODEL_CARD_NAME = "README.md" 23 | 24 | 25 | def run_config_utilities(config: DictConfig) -> None: 26 | """A couple of optional utilities. 27 | 28 | Controlled by main config file: 29 | - forcing debug friendly configuration 30 | 31 | Modifies DictConfig in place. 32 | 33 | Args: 34 | config (DictConfig): Configuration composed by Hydra. 
35 | """ 36 | 37 | # Enable adding new keys to config 38 | OmegaConf.set_struct(config, False) 39 | 40 | # Force debugger friendly configuration if 41 | if config.trainer.get("fast_dev_run"): 42 | logger.info("Forcing debugger friendly configuration! ") 43 | # Debuggers don't like GPUs or multiprocessing 44 | if config.trainer.get("gpus"): 45 | config.trainer.gpus = 0 46 | if config.datamodule.get("pin_memory"): 47 | config.datamodule.pin_memory = False 48 | if config.datamodule.get("num_workers"): 49 | config.datamodule.num_workers = 0 50 | if config.datamodule.get("prefetch_factor"): 51 | config.datamodule.prefetch_factor = None 52 | 53 | # Disable adding new keys to config 54 | OmegaConf.set_struct(config, True) 55 | 56 | 57 | @rank_zero_only 58 | def print_config( 59 | config: DictConfig, 60 | fields: tuple[str] = ( 61 | "trainer", 62 | "model", 63 | "datamodule", 64 | "callbacks", 65 | "logger", 66 | "seed", 67 | ), 68 | resolve: bool = True, 69 | ) -> None: 70 | """Prints content of DictConfig using Rich library and its tree structure. 71 | 72 | Args: 73 | config (DictConfig): Configuration composed by Hydra. 74 | fields (Sequence[str], optional): Determines which main fields from config will 75 | be printed and in what order. 76 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 
77 | """ 78 | 79 | style = "dim" 80 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 81 | 82 | for field in fields: 83 | branch = tree.add(field, style=style, guide_style=style) 84 | 85 | config_section = config.get(field) 86 | 87 | branch_content = str(config_section) 88 | if isinstance(config_section, DictConfig): 89 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 90 | 91 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 92 | 93 | rich.print(tree) 94 | 95 | 96 | def validate_batch_against_config( 97 | batch: dict, 98 | model: "BaseModel", 99 | ) -> None: 100 | """Validates tensor shapes in batch against model configuration.""" 101 | logger.info("Performing batch shape validation against model config.") 102 | 103 | # NWP validation 104 | if hasattr(model, "nwp_encoders_dict"): 105 | if "nwp" not in batch: 106 | raise ValueError( 107 | "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch." 108 | ) 109 | 110 | for source, nwp_data in batch["nwp"].items(): 111 | if source in model.nwp_encoders_dict: 112 | enc = model.nwp_encoders_dict[source] 113 | expected_channels = enc.in_channels 114 | if model.add_image_embedding_channel: 115 | expected_channels -= 1 116 | 117 | expected = ( 118 | nwp_data["nwp"].shape[0], 119 | enc.sequence_length, 120 | expected_channels, 121 | enc.image_size_pixels, 122 | enc.image_size_pixels, 123 | ) 124 | if tuple(nwp_data["nwp"].shape) != expected: 125 | actual_shape = tuple(nwp_data["nwp"].shape) 126 | raise ValueError( 127 | f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}" 128 | ) 129 | 130 | # Satellite validation 131 | if hasattr(model, "sat_encoder"): 132 | if "satellite_actual" not in batch: 133 | raise ValueError( 134 | "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch." 
135 | ) 136 | 137 | enc = model.sat_encoder 138 | expected_channels = enc.in_channels 139 | if model.add_image_embedding_channel: 140 | expected_channels -= 1 141 | 142 | expected = ( 143 | batch["satellite_actual"].shape[0], 144 | enc.sequence_length, 145 | expected_channels, 146 | enc.image_size_pixels, 147 | enc.image_size_pixels, 148 | ) 149 | if tuple(batch["satellite_actual"].shape) != expected: 150 | actual_shape = tuple(batch["satellite_actual"].shape) 151 | raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}") 152 | 153 | # generation validation 154 | key = "generation" 155 | if key in batch: 156 | total_minutes = model.history_minutes + model.forecast_minutes 157 | interval = model.interval_minutes 158 | expected_len = total_minutes // interval + 1 159 | expected = (batch[key].shape[0], expected_len) 160 | if tuple(batch[key].shape) != expected: 161 | actual_shape = tuple(batch[key].shape) 162 | raise ValueError( 163 | f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}" 164 | ) 165 | 166 | logger.info("Batch shape validation successful!") 167 | 168 | 169 | def validate_gpu_config(config: DictConfig) -> None: 170 | """Abort if multiple GPUs requested by mistake i.e. `devices: 2` instead of `[2]`.""" 171 | tr = config.get("trainer", {}) 172 | dev = tr.get("devices") 173 | 174 | if isinstance(dev, int) and dev > 1: 175 | raise ValueError( 176 | f"Detected `devices: {dev}` — this requests {dev} GPUs. " 177 | "If you meant a specific GPU (e.g. GPU 2), use `devices: [2]`. " 178 | "Parallel training not supported." 179 | ) 180 | -------------------------------------------------------------------------------- /pvnet/optimizers.py: -------------------------------------------------------------------------------- 1 | """Optimizer factory-function classes. 
2 | """ 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | import torch 7 | from torch.nn import Module 8 | from torch.nn.parameter import Parameter 9 | 10 | 11 | def find_submodule_parameters(model: Module, search_modules: list[Module]) -> list[Parameter]: 12 | """Finds all parameters within given submodule types 13 | 14 | Args: 15 | model: torch Module to search through 16 | search_modules: List of submodule types to search for 17 | """ 18 | if isinstance(model, search_modules): 19 | return model.parameters() 20 | 21 | children = list(model.children()) 22 | if len(children) == 0: 23 | return [] 24 | else: 25 | params = [] 26 | for c in children: 27 | params += find_submodule_parameters(c, search_modules) 28 | return params 29 | 30 | 31 | def find_other_than_submodule_parameters( 32 | model: Module, 33 | ignore_modules: list[Module], 34 | ) -> list[Parameter]: 35 | """Finds all parameters not with given submodule types 36 | 37 | Args: 38 | model: torch Module to search through 39 | ignore_modules: List of submodule types to ignore 40 | """ 41 | if isinstance(model, ignore_modules): 42 | return [] 43 | 44 | children = list(model.children()) 45 | if len(children) == 0: 46 | return model.parameters() 47 | else: 48 | params = [] 49 | for c in children: 50 | params += find_other_than_submodule_parameters(c, ignore_modules) 51 | return params 52 | 53 | 54 | class AbstractOptimizer(ABC): 55 | """Abstract class for optimizer 56 | 57 | Optimizer classes will be used by model like: 58 | > OptimizerGenerator = AbstractOptimizer() 59 | > optimizer = OptimizerGenerator(model) 60 | The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s 61 | `configure_optimizers()` method. 
62 | See : 63 | https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers 64 | 65 | """ 66 | 67 | @abstractmethod 68 | def __call__(self): 69 | """Abstract call""" 70 | pass 71 | 72 | 73 | class Adam(AbstractOptimizer): 74 | """Adam optimizer""" 75 | 76 | def __init__(self, lr: float = 0.0005, **kwargs): 77 | """Adam optimizer""" 78 | self.lr = lr 79 | self.kwargs = kwargs 80 | 81 | def __call__(self, model: Module): 82 | """Return optimizer""" 83 | return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs) 84 | 85 | 86 | class AdamW(AbstractOptimizer): 87 | """AdamW optimizer""" 88 | 89 | def __init__(self, lr: float = 0.0005, **kwargs): 90 | """AdamW optimizer""" 91 | self.lr = lr 92 | self.kwargs = kwargs 93 | 94 | def __call__(self, model: Module): 95 | """Return optimizer""" 96 | return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs) 97 | 98 | 99 | 100 | class EmbAdamWReduceLROnPlateau(AbstractOptimizer): 101 | """AdamW optimizer and reduce on plateau scheduler""" 102 | 103 | def __init__( 104 | self, 105 | lr: float = 0.0005, 106 | weight_decay: float = 0.01, 107 | patience: int = 3, 108 | factor: float = 0.5, 109 | threshold: float = 2e-4, 110 | **opt_kwargs, 111 | ): 112 | """AdamW optimizer and reduce on plateau scheduler""" 113 | self.lr = lr 114 | self.weight_decay = weight_decay 115 | self.patience = patience 116 | self.factor = factor 117 | self.threshold = threshold 118 | self.opt_kwargs = opt_kwargs 119 | 120 | def __call__(self, model): 121 | """Return optimizer""" 122 | 123 | search_modules = (torch.nn.Embedding,) 124 | 125 | no_decay = find_submodule_parameters(model, search_modules) 126 | decay = find_other_than_submodule_parameters(model, search_modules) 127 | 128 | optim_groups = [ 129 | {"params": decay, "weight_decay": self.weight_decay}, 130 | {"params": no_decay, "weight_decay": 0.0}, 131 | ] 132 | opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs) 133 | 134 | 
sch = torch.optim.lr_scheduler.ReduceLROnPlateau( 135 | opt, 136 | factor=self.factor, 137 | patience=self.patience, 138 | threshold=self.threshold, 139 | ) 140 | sch = { 141 | "scheduler": sch, 142 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 143 | } 144 | return [opt], [sch] 145 | 146 | 147 | class AdamWReduceLROnPlateau(AbstractOptimizer): 148 | """AdamW optimizer and reduce on plateau scheduler""" 149 | 150 | def __init__( 151 | self, 152 | lr: float = 0.0005, 153 | patience: int = 3, 154 | factor: float = 0.5, 155 | threshold: float = 2e-4, 156 | step_freq=None, 157 | **opt_kwargs, 158 | ): 159 | """AdamW optimizer and reduce on plateau scheduler""" 160 | self._lr = lr 161 | self.patience = patience 162 | self.factor = factor 163 | self.threshold = threshold 164 | self.step_freq = step_freq 165 | self.opt_kwargs = opt_kwargs 166 | 167 | def _call_multi(self, model): 168 | remaining_params = {k: p for k, p in model.named_parameters()} 169 | 170 | group_args = [] 171 | 172 | for key in self._lr.keys(): 173 | if key == "default": 174 | continue 175 | 176 | submodule_params = [] 177 | for param_name in list(remaining_params.keys()): 178 | if param_name.startswith(key): 179 | submodule_params += [remaining_params.pop(param_name)] 180 | 181 | group_args += [{"params": submodule_params, "lr": self._lr[key]}] 182 | 183 | remaining_params = [p for k, p in remaining_params.items()] 184 | group_args += [{"params": remaining_params}] 185 | 186 | opt = torch.optim.AdamW( 187 | group_args, 188 | lr=self._lr["default"] if model.lr is None else model.lr, 189 | **self.opt_kwargs, 190 | ) 191 | sch = { 192 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 193 | opt, 194 | factor=self.factor, 195 | patience=self.patience, 196 | threshold=self.threshold, 197 | ), 198 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 199 | } 200 | 201 | return [opt], [sch] 202 | 203 | def __call__(self, model): 204 | 
"""Return optimizer""" 205 | if not isinstance(self._lr, float): 206 | return self._call_multi(model) 207 | else: 208 | default_lr = self._lr if model.lr is None else model.lr 209 | opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs) 210 | sch = torch.optim.lr_scheduler.ReduceLROnPlateau( 211 | opt, 212 | factor=self.factor, 213 | patience=self.patience, 214 | threshold=self.threshold, 215 | ) 216 | sch = { 217 | "scheduler": sch, 218 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 219 | } 220 | return [opt], [sch] 221 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/encoders/encoders3d.py: -------------------------------------------------------------------------------- 1 | """Encoder modules for the satellite/NWP data based on 3D concolutions. 2 | """ 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pvnet.models.late_fusion.encoders.basic_blocks import ( 8 | AbstractNWPSatelliteEncoder, 9 | ResidualConv3dBlock, 10 | ) 11 | 12 | 13 | class DefaultPVNet(AbstractNWPSatelliteEncoder): 14 | """This is the original encoding module used in PVNet, with a few minor tweaks.""" 15 | 16 | def __init__( 17 | self, 18 | sequence_length: int, 19 | image_size_pixels: int, 20 | in_channels: int, 21 | out_features: int, 22 | number_of_conv3d_layers: int = 4, 23 | conv3d_channels: int = 32, 24 | fc_features: int = 128, 25 | spatial_kernel_size: int = 3, 26 | temporal_kernel_size: int = 3, 27 | padding: int | tuple[int, ...] = (1, 0, 0), 28 | stride: int | tuple[int, ...] = 1, 29 | ): 30 | """This is the original encoding module used in PVNet, with a few minor tweaks. 31 | 32 | Args: 33 | sequence_length: The time sequence length of the data. 34 | image_size_pixels: The spatial size of the image. Assumed square. 35 | in_channels: Number of input channels. 36 | out_features: Number of output features. 
37 | number_of_conv3d_layers: Number of convolution 3d layers that are used. 38 | conv3d_channels: Number of channels used in each conv3d layer. 39 | fc_features: number of output nodes out of the hidden fully connected layer. 40 | spatial_kernel_size: The spatial size of the kernel used in the conv3d layers. 41 | temporal_kernel_size: The temporal size of the kernel used in the conv3d layers. 42 | padding: The padding used in the conv3d layers. If an int, the same padding 43 | is used in all dimensions. The dimensions are (time, space, space) 44 | stride: The stride used in conv3d layers. If an int, the same stride is used 45 | in all dimensions 46 | """ 47 | 48 | super().__init__(sequence_length, image_size_pixels, in_channels, out_features) 49 | 50 | if isinstance(padding, int): 51 | padding = (padding, padding, padding) 52 | 53 | if isinstance(stride, int): 54 | stride = (stride, stride, stride) 55 | 56 | # Check that the output shape of the convolutional layers will be at least 1x1 57 | cnn_spatial_output_size = image_size_pixels 58 | 59 | for _ in range(number_of_conv3d_layers): 60 | cnn_spatial_output_size = ( 61 | cnn_spatial_output_size - spatial_kernel_size + 2 * padding[1] 62 | ) // stride[1] + 1 63 | 64 | if not (cnn_spatial_output_size >= 1): 65 | raise ValueError( 66 | f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input " 67 | f"spatial size ({image_size_pixels})" 68 | ) 69 | 70 | cnn_sequence_length = ( 71 | sequence_length 72 | - ((temporal_kernel_size - 2 * padding[0]) - 1) * number_of_conv3d_layers 73 | ) 74 | 75 | conv_layers = [] 76 | 77 | conv_layers += [ 78 | nn.Conv3d( 79 | in_channels=in_channels, 80 | out_channels=conv3d_channels, 81 | kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size), 82 | padding=padding, 83 | stride=stride, 84 | ), 85 | nn.ELU(), 86 | ] 87 | for _ in range(0, number_of_conv3d_layers - 1): 88 | conv_layers += [ 89 | nn.Conv3d( 90 | in_channels=conv3d_channels, 91 | 
out_channels=conv3d_channels, 92 | kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size), 93 | padding=padding, 94 | stride=stride, 95 | ), 96 | nn.ELU(), 97 | ] 98 | 99 | self.conv_layers = nn.Sequential(*conv_layers) 100 | 101 | # Calculate the size of the output of the 3D convolutional layers 102 | cnn_output_size = conv3d_channels * cnn_spatial_output_size**2 * cnn_sequence_length 103 | 104 | self.final_block = nn.Sequential( 105 | nn.Linear(in_features=cnn_output_size, out_features=fc_features), 106 | nn.ELU(), 107 | nn.Linear(in_features=fc_features, out_features=out_features), 108 | nn.ELU(), 109 | ) 110 | 111 | def forward(self, x: torch.Tensor) -> torch.Tensor: 112 | """Run model forward""" 113 | out = self.conv_layers(x) 114 | out = out.reshape(x.shape[0], -1) 115 | return self.final_block(out) 116 | 117 | 118 | class ResConv3DNet(AbstractNWPSatelliteEncoder): 119 | """3D convolutional network based on ResNet architecture. 120 | 121 | The residual blocks are implemented based on the best performing block in [1]. 122 | 123 | Sources: 124 | [1] https://arxiv.org/pdf/1603.05027.pdf 125 | """ 126 | 127 | def __init__( 128 | self, 129 | sequence_length: int, 130 | image_size_pixels: int, 131 | in_channels: int, 132 | out_features: int, 133 | hidden_channels: int = 32, 134 | n_res_blocks: int = 4, 135 | res_block_layers: int = 2, 136 | batch_norm: bool = True, 137 | dropout_frac: float = 0.0, 138 | ): 139 | """Fully connected deep network based on ResNet architecture. 140 | 141 | Args: 142 | sequence_length: The time sequence length of the data. 143 | image_size_pixels: The spatial size of the image. Assumed square. 144 | in_channels: Number of input channels. 145 | out_features: Number of output features. 146 | hidden_channels: Number of channels in middle hidden layers. 147 | n_res_blocks: Number of residual blocks to use. 148 | res_block_layers: Number of Conv3D layers used in each residual block. 
149 | batch_norm: Whether to include batch normalisation. 150 | dropout_frac: Probability of an element to be zeroed in the residual pathways. 151 | """ 152 | super().__init__(sequence_length, image_size_pixels, in_channels, out_features) 153 | 154 | model = [ 155 | nn.Conv3d( 156 | in_channels=in_channels, 157 | out_channels=hidden_channels, 158 | kernel_size=(3, 3, 3), 159 | padding=(1, 1, 1), 160 | ), 161 | ] 162 | 163 | for i in range(n_res_blocks): 164 | model.extend( 165 | [ 166 | ResidualConv3dBlock( 167 | in_channels=hidden_channels, 168 | n_layers=res_block_layers, 169 | dropout_frac=dropout_frac, 170 | batch_norm=batch_norm, 171 | ), 172 | nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2)), 173 | ] 174 | ) 175 | 176 | # Calculate the size of the output of the 3D convolutional layers 177 | final_im_size = image_size_pixels // (2**n_res_blocks) 178 | cnn_output_size = hidden_channels * sequence_length * final_im_size * final_im_size 179 | 180 | model.extend( 181 | [ 182 | nn.ELU(), 183 | nn.Flatten(start_dim=1, end_dim=-1), 184 | nn.Linear(in_features=cnn_output_size, out_features=out_features), 185 | nn.ELU(), 186 | ] 187 | ) 188 | 189 | self.model = nn.Sequential(*model) 190 | 191 | def forward(self, x: torch.Tensor) -> torch.Tensor: 192 | """Run model forward""" 193 | return self.model(x) 194 | -------------------------------------------------------------------------------- /scripts/migrate_old_model.py: -------------------------------------------------------------------------------- 1 | """Script to migrate old PVNet models which are hosted on huggingface to current version 2 | 3 | This script can be used to update models from version >= v4.1 4 | """ 5 | 6 | import datetime 7 | import os 8 | import tempfile 9 | from importlib.metadata import version 10 | 11 | import torch 12 | import yaml 13 | from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, file_exists 14 | from safetensors.torch import save_file 15 | 16 | from 
pvnet.models.base_model import BaseModel 17 | from pvnet.utils import DATA_CONFIG_NAME, MODEL_CARD_NAME, MODEL_CONFIG_NAME, PYTORCH_WEIGHTS_NAME 18 | 19 | # ------------------------------------------ 20 | # USER SETTINGS 21 | 22 | # The huggingface commit of the model you want to update 23 | repo_id: str = "openclimatefix-models/pvnet_uk_region" 24 | revision: str = "30c406ea8455d9d43aa1284cba23c3a59102f549" 25 | 26 | # The local directory which will be downloaded to 27 | # If set to None a temporary directory will be used 28 | local_dir: str | None = None 29 | 30 | # Whether to upload the migrated model back to the huggingface - else just saved locally 31 | upload: bool = False 32 | 33 | # ------------------------------------------ 34 | # SETUP 35 | 36 | if local_dir is None: 37 | temp_dir = tempfile.TemporaryDirectory() 38 | save_dir = temp_dir.name 39 | 40 | else: 41 | os.makedirs(local_dir, exist_ok=False) 42 | save_dir = local_dir 43 | 44 | # Set up huggingface API 45 | api = HfApi() 46 | 47 | # Download the model repo 48 | _ = api.snapshot_download( 49 | repo_id=repo_id, 50 | revision=revision, 51 | local_dir=save_dir, 52 | force_download=True, 53 | ) 54 | 55 | # ------------------------------------------ 56 | # MIGRATION STEPS 57 | 58 | # Modify the model config 59 | with open(f"{save_dir}/{MODEL_CONFIG_NAME}") as cfg: 60 | model_config = yaml.load(cfg, Loader=yaml.FullLoader) 61 | 62 | # Get rid of the optimiser - we don't store this anymore 63 | if "optimizer" in model_config: 64 | del model_config["optimizer"] 65 | 66 | # This parameter has been moved out of the model to the pytorch lightning module 67 | if "save_validation_results_csv" in model_config: 68 | del model_config["save_validation_results_csv"] 69 | 70 | # This parameter has been removed 71 | if "adapt_batches" in model_config: 72 | del model_config["adapt_batches"] 73 | 74 | # This parameter has been removed 75 | if "target_key" in model_config: 76 | if model_config["target_key"] == "site": 77 
| if "include_site_yield_history" in model_config: 78 | model_config["include_generation_history"] = model_config.pop( 79 | "include_site_yield_history" 80 | ) 81 | 82 | if model_config["target_key"] == "gsp" or not model_config["target_key"]: 83 | if "include_site_yield_history" in model_config: 84 | del model_config["include_site_yield_history"] 85 | 86 | del model_config["target_key"] 87 | 88 | if "include_gsp_yield_history" in model_config: 89 | # Set to false on all current gsp models and now defaults to false 90 | # Also false for all site models so can be removed 91 | del model_config["include_gsp_yield_history"] 92 | 93 | # Rename the top level model 94 | if model_config["_target_"] == "pvnet.models.multimodal.multimodal.Model": 95 | model_config["_target_"] = "pvnet.models.LateFusionModel" 96 | elif model_config["_target_"] == "pvnet.models.LateFusionModel": 97 | pass 98 | else: 99 | raise Exception("Unknown model: " + model_config["_target_"]) 100 | 101 | # Re-find the model components in the new package structure 102 | if model_config.get("nwp_encoders_dict", None) is not None: 103 | for k, v in model_config["nwp_encoders_dict"].items(): 104 | v["_target_"] = ( 105 | v["_target_"] 106 | .replace("multimodal", "late_fusion") 107 | .replace("ResConv3DNet2", "ResConv3DNet") 108 | ) 109 | 110 | 111 | for component in ["sat_encoder", "pv_encoder", "output_network"]: 112 | if model_config.get(component, None) is not None: 113 | model_config[component]["_target_"] = ( 114 | model_config[component]["_target_"] 115 | .replace("multimodal", "late_fusion") 116 | .replace("ResConv3DNet2", "ResConv3DNet") 117 | .replace("ResFCNet2", "ResFCNet") 118 | ) 119 | 120 | with open(f"{save_dir}/{MODEL_CONFIG_NAME}", "w") as f: 121 | yaml.dump(model_config, f, sort_keys=False, default_flow_style=False) 122 | 123 | # Modify the data config 124 | with open(f"{save_dir}/{DATA_CONFIG_NAME}") as cfg: 125 | data_config = yaml.load(cfg, Loader=yaml.FullLoader) 126 | 127 | # Reformat 
gsp/site generation input data config to new format 128 | if "gsp" in data_config["input_data"]: 129 | data_config["input_data"]["generation"] = data_config["input_data"].pop("gsp") 130 | 131 | if "boundaries_version" in data_config["input_data"]["generation"]: 132 | del data_config["input_data"]["generation"]["boundaries_version"] 133 | 134 | if "site" in data_config["input_data"]: 135 | data_config["input_data"]["generation"] = data_config["input_data"].pop("site") 136 | 137 | data_config["input_data"]["generation"]["zarr_path"] = data_config["input_data"][ 138 | "generation" 139 | ].pop("file_path") 140 | 141 | if "metadata_file_path" in data_config["input_data"]["generation"]: 142 | del data_config["input_data"]["generation"]["metadata_file_path"] 143 | 144 | with open(f"{save_dir}/{DATA_CONFIG_NAME}", "w") as f: 145 | yaml.dump(data_config, f, sort_keys=False, default_flow_style=False) 146 | 147 | # Resave the model weights as safetensors if in old format 148 | if os.path.exists(f"{save_dir}/pytorch_model.bin"): 149 | state_dict = torch.load(f"{save_dir}/pytorch_model.bin", map_location="cpu", weights_only=True) 150 | save_file(state_dict, f"{save_dir}/{PYTORCH_WEIGHTS_NAME}") 151 | os.remove(f"{save_dir}/pytorch_model.bin") 152 | else: 153 | assert os.path.exists(f"{save_dir}/{PYTORCH_WEIGHTS_NAME}") 154 | 155 | # Add a note to the model card to say the model has been migrated 156 | with open(f"{save_dir}/{MODEL_CARD_NAME}", "a") as f: 157 | current_date = datetime.date.today().strftime("%Y-%m-%d") 158 | pvnet_version = version("pvnet") 159 | f.write( 160 | f"\n\n---\n**Migration Note**: This model was migrated on {current_date} " 161 | f"to pvnet version {pvnet_version}\n" 162 | ) 163 | 164 | # ------------------------------------------ 165 | # CHECKS 166 | 167 | # Check the model can be loaded 168 | model = BaseModel.from_pretrained(model_id=save_dir, revision=None) 169 | 170 | print("Model checkpoint successfully migrated") 171 | 172 | # 
------------------------------------------ 173 | # UPLOAD TO HUGGINGFACE 174 | 175 | if upload: 176 | print("Uploading migrated model to huggingface") 177 | 178 | operations = [] 179 | for file in [MODEL_CARD_NAME, MODEL_CONFIG_NAME, PYTORCH_WEIGHTS_NAME, DATA_CONFIG_NAME]: 180 | # Stage modified files for upload 181 | operations.append( 182 | CommitOperationAdd( 183 | path_in_repo=file, # Name of the file in the repo 184 | path_or_fileobj=f"{save_dir}/{file}", # Local path to the file 185 | ), 186 | ) 187 | 188 | # Remove old pytorch weights file if it exists in the most recent commit 189 | if file_exists(repo_id, "pytorch_model.bin"): 190 | operations.append(CommitOperationDelete(path_in_repo="pytorch_model.bin")) 191 | 192 | commit_info = api.create_commit( 193 | repo_id=repo_id, 194 | operations=operations, 195 | commit_message=f"Migrate model (HF commit {revision[:7]}) to pvnet version {pvnet_version}", 196 | ) 197 | 198 | # Print the most recent commit hash 199 | c = api.list_repo_commits(repo_id=repo_id, repo_type="model")[0] 200 | 201 | print( 202 | f"\nThe latest commit is now: \n" 203 | f" date: {c.created_at} \n" 204 | f" commit hash: {c.commit_id}\n" 205 | f" by: {c.authors}\n" 206 | f" title: {c.title}\n" 207 | ) 208 | 209 | if local_dir is None: 210 | temp_dir.cleanup() 211 | -------------------------------------------------------------------------------- /configs.example/datamodule/configuration/example_configuration.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | description: Example config for producing PVNet samples 3 | name: example_config 4 | 5 | input_data: 6 | 7 | generation: 8 | # Path to GSP data in zarr format 9 | # e.g. 
gs://solar-pv-nowcasting-data/PV/GSP/v7/pv_gsp.zarr 10 | zarr_path: PLACEHOLDER.zarr 11 | interval_start_minutes: -60 12 | # Specified for intraday currently 13 | interval_end_minutes: 480 14 | time_resolution_minutes: 30 15 | # Random value from the list below will be chosen as the delay when dropout is used 16 | # If set to null no dropout is applied. Only values before t0 are dropped out for GSP. 17 | # Values after t0 are assumed as targets and cannot be dropped. 18 | dropout_timedeltas_minutes: [] 19 | dropout_fraction: 0 # Fraction of samples with dropout 20 | 21 | nwp: 22 | 23 | ecmwf: 24 | provider: ecmwf 25 | # Path to ECMWF NWP data in zarr format 26 | # n.b. It is not necessary to use multiple or any NWP data. These entries can be removed 27 | zarr_path: PLACEHOLDER.zarr 28 | interval_start_minutes: -60 29 | # Specified for intraday currently 30 | interval_end_minutes: 480 31 | time_resolution_minutes: 60 32 | channels: 33 | - t2m # 2-metre temperature 34 | - dswrf # downwards short-wave radiation flux 35 | - dlwrf # downwards long-wave radiation flux 36 | - hcc # high cloud cover 37 | - mcc # medium cloud cover 38 | - lcc # low cloud cover 39 | - tcc # total cloud cover 40 | - sde # snow depth water equivalent 41 | - sr # direct solar radiation 42 | - duvrs # downwards UV radiation at surface 43 | - prate # precipitation rate 44 | - u10 # 10-metre U component of wind speed 45 | - u100 # 100-metre U component of wind speed 46 | - u200 # 200-metre U component of wind speed 47 | - v10 # 10-metre V component of wind speed 48 | - v100 # 100-metre V component of wind speed 49 | - v200 # 200-metre V component of wind speed 50 | # The following channels are accumulated and need to be diffed 51 | accum_channels: 52 | - dswrf # downwards short-wave radiation flux 53 | - dlwrf # downwards long-wave radiation flux 54 | - sr # direct solar radiation 55 | - duvrs # downwards UV radiation at surface 56 | image_size_pixels_height: 24 57 | image_size_pixels_width: 24 58 
| dropout_timedeltas_minutes: [-360] 59 | dropout_fraction: 1.0 # Fraction of samples with dropout 60 | max_staleness_minutes: null 61 | normalisation_constants: 62 | t2m: 63 | mean: 283.48333740234375 64 | std: 3.692270040512085 65 | dswrf: 66 | mean: 11458988.0 67 | std: 13025427.0 68 | dlwrf: 69 | mean: 27187026.0 70 | std: 15855867.0 71 | hcc: 72 | mean: 0.3961029052734375 73 | std: 0.42244860529899597 74 | mcc: 75 | mean: 0.3288780450820923 76 | std: 0.38039860129356384 77 | lcc: 78 | mean: 0.44901806116104126 79 | std: 0.3791404366493225 80 | tcc: 81 | mean: 0.7049227356910706 82 | std: 0.37487083673477173 83 | sde: 84 | mean: 8.107526082312688e-05 85 | std: 0.000913831521756947 # Mapped from "sd" in the Python file 86 | sr: 87 | mean: 12905302.0 88 | std: 16294988.0 89 | duvrs: 90 | mean: 1305651.25 91 | std: 1445635.25 92 | prate: 93 | mean: 3.108070450252853e-05 94 | std: 9.81039775069803e-05 95 | u10: 96 | mean: 1.7677178382873535 97 | std: 5.531515598297119 98 | u100: 99 | mean: 2.393547296524048 100 | std: 7.2320556640625 101 | u200: 102 | mean: 2.7963004112243652 103 | std: 8.049470901489258 104 | v10: 105 | mean: 0.985887885093689 106 | std: 5.411230564117432 107 | v100: 108 | mean: 1.4244288206100464 109 | std: 6.944501876831055 110 | v200: 111 | mean: 1.6010299921035767 112 | std: 7.561611652374268 113 | # Added diff_ keys for the channels under accum_channels: 114 | diff_dlwrf: 115 | mean: 1136464.0 116 | std: 131942.03125 117 | diff_dswrf: 118 | mean: 420584.6875 119 | std: 715366.3125 120 | diff_duvrs: 121 | mean: 48265.4765625 122 | std: 81605.25 123 | diff_sr: 124 | mean: 469169.5 125 | std: 818950.6875 126 | 127 | ukv: 128 | provider: ukv 129 | # Path to UKV NWP data in zarr format 130 | # e.g. gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_intermediate_version_7.zarr 131 | # n.b. It is not necessary to use multiple or any NWP data. 
These entries can be removed 132 | zarr_path: PLACEHOLDER.zarr 133 | interval_start_minutes: -60 134 | # Specified for intraday currently 135 | interval_end_minutes: 480 136 | time_resolution_minutes: 60 137 | channels: 138 | - t # 2-metre temperature 139 | - dswrf # downwards short-wave radiation flux 140 | - dlwrf # downwards long-wave radiation flux 141 | - hcc # high cloud cover 142 | - mcc # medium cloud cover 143 | - lcc # low cloud cover 144 | - sde # snow depth water equivalent 145 | - r # relative humidity 146 | - vis # visibility 147 | - si10 # 10-metre wind speed 148 | - wdir10 # 10-metre wind direction 149 | - prate # precipitation rate 150 | # These variables exist in CEDA training data but not in the live Met Office service 151 | - hcct # height of convective cloud top, meters above surface. NaN if no clouds 152 | - cdcb # height of lowest cloud base > 3 oktas 153 | - dpt # dew point temperature 154 | - prmsl # mean sea level pressure 155 | - h # geometrical? (maybe geopotential?)
height 156 | image_size_pixels_height: 24 157 | image_size_pixels_width: 24 158 | dropout_timedeltas_minutes: [-360] 159 | dropout_fraction: 1.0 # Fraction of samples with dropout 160 | max_staleness_minutes: null 161 | normalisation_constants: 162 | t: 163 | mean: 283.64913206 164 | std: 4.38818501 165 | dswrf: 166 | mean: 111.28265039 167 | std: 190.47216887 168 | dlwrf: 169 | mean: 325.03130139 170 | std: 39.45988077 171 | hcc: 172 | mean: 29.11949682 173 | std: 38.07184418 174 | mcc: 175 | mean: 40.88984494 176 | std: 41.91144559 177 | lcc: 178 | mean: 50.08362643 179 | std: 39.33210726 180 | sde: 181 | mean: 0.00289545 182 | std: 0.1029753 183 | r: 184 | mean: 81.79229501 185 | std: 11.45012499 186 | vis: 187 | mean: 32262.03285118 188 | std: 21578.97975625 189 | si10: 190 | mean: 6.88348448 191 | std: 3.94718813 192 | wdir10: 193 | mean: 199.41891636 194 | std: 94.08407495 195 | prate: 196 | mean: 3.45793433e-05 197 | std: 0.00021497 198 | hcct: 199 | mean: -18345.97478167 200 | std: 18382.63958991 201 | cdcb: 202 | mean: 1412.26599062 203 | std: 2126.99350113 204 | dpt: 205 | mean: 280.54379901 206 | std: 4.57250482 207 | prmsl: 208 | mean: 101321.61574029 209 | std: 1252.71790539 210 | h: 211 | mean: 2096.51991356 212 | std: 1075.77812282 213 | 214 | satellite: 215 | # Path to Satellite data (non-HRV) in zarr format 216 | # e.g. 
gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_nonhrv.zarr 217 | zarr_path: PLACEHOLDER.zarr 218 | interval_start_minutes: -30 219 | interval_end_minutes: 0 220 | time_resolution_minutes: 5 221 | channels: 222 | - IR_016 # Surface, cloud phase 223 | - IR_039 # Surface, clouds, wind fields 224 | - IR_087 # Surface, clouds, atmospheric instability 225 | - IR_097 # Ozone 226 | - IR_108 # Surface, clouds, wind fields, atmospheric instability 227 | - IR_120 # Surface, clouds, atmospheric instability 228 | - IR_134 # Cirrus cloud height, atmospheric instability 229 | - VIS006 # Surface, clouds, wind fields 230 | - VIS008 # Surface, clouds, wind fields 231 | - WV_062 # Water vapor, high level clouds, upper air analysis 232 | - WV_073 # Water vapor, atmospheric instability, upper-level dynamics 233 | image_size_pixels_height: 24 234 | image_size_pixels_width: 24 235 | dropout_timedeltas_minutes: [] 236 | dropout_fraction: 0 # Fraction of samples with dropout 237 | normalisation_constants: 238 | IR_016: 239 | mean: 0.17594202 240 | std: 0.21462157 241 | IR_039: 242 | mean: 0.86167645 243 | std: 0.04618041 244 | IR_087: 245 | mean: 0.7719318 246 | std: 0.06687243 247 | IR_097: 248 | mean: 0.8014212 249 | std: 0.0468558 250 | IR_108: 251 | mean: 0.71254843 252 | std: 0.17482725 253 | IR_120: 254 | mean: 0.89058584 255 | std: 0.06115861 256 | IR_134: 257 | mean: 0.944365 258 | std: 0.04492306 259 | VIS006: 260 | mean: 0.09633306 261 | std: 0.12184761 262 | VIS008: 263 | mean: 0.11426069 264 | std: 0.13090034 265 | WV_062: 266 | mean: 0.7359355 267 | std: 0.16111417 268 | WV_073: 269 | mean: 0.62479186 270 | std: 0.12924142 271 | 272 | solar_position: 273 | interval_start_minutes: -60 274 | interval_end_minutes: 480 275 | time_resolution_minutes: 30 276 | -------------------------------------------------------------------------------- /scripts/backtest_sites.py: -------------------------------------------------------------------------------- 1 | """ 2 | A 
script to run backtest for PVNet for specific sites 3 | 4 | Use: 5 | 6 | - This script uses hydra to construct the config, just like in `run.py`. So you need to make sure 7 | that the data config is set up appropriately for the model being run in this script 8 | - The following variables are hard-coded near the top of the script and should be changed prior to 9 | use: 10 | - number of workers to use; 11 | - the PVNet model checkpoint (either local or HuggingFace repo details); 12 | - the time range over which predictions are made; 13 | - the output directory where the results are stored; 14 | 15 | - Outputs netCDF files with the predictions for each t0 in separate files; 16 | each file contains forecasts for all sites. 17 | Time resolution of the forecast t0s is the same as the time resolution of the generation data. 18 | 19 | - WARNING: if you are running the backtest for multiple sites (i.e. the generation data 20 | contains multiple sites), this script currently assumes that they all have the same t0s 21 | available in the generation data. If the sites have non-overlapping periods, it may be best 22 | to run this script multiple times with a separate generation file for each site, otherwise 23 | silent errors could occur.
24 | 25 | ``` 26 | python scripts/backtest_sites.py 27 | ``` 28 | 29 | """ 30 | import os 31 | 32 | import hydra 33 | import numpy as np 34 | import pandas as pd 35 | import torch 36 | import xarray as xr 37 | from ocf_data_sampler.config import load_yaml_configuration 38 | from ocf_data_sampler.load.load_dataset import get_dataset_dict 39 | from ocf_data_sampler.numpy_sample.common_types import NumpyBatch 40 | from ocf_data_sampler.torch_datasets.pvnet_dataset import PVNetConcurrentDataset 41 | from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import ( 42 | batch_to_tensor, 43 | copy_batch_to_device, 44 | ) 45 | from omegaconf import DictConfig 46 | from torch.utils.data import DataLoader 47 | from tqdm import tqdm 48 | 49 | from pvnet.load_model import get_model_from_checkpoints 50 | from pvnet.models.base_model import BaseModel as PVNetBaseModel 51 | 52 | # ------------------------------------------------------------------ 53 | # USER CONFIGURED VARIABLES TO RUN THE SCRIPT 54 | 55 | num_workers = 2 56 | 57 | # Directory path to save results 58 | output_dir: str = "example_repo" 59 | 60 | # Local directory to load the PVNet checkpoint from. 
By default this should pull the best performing 61 | # checkpoint on the val set, set to None if using HF 62 | model_checkpoint_dir: str | None = None 63 | 64 | 65 | # Location to download exported PVNet model on HF, set to None if using local 66 | hf_model_id: str | None = "openclimatefix/example_repo" 67 | hf_revision: str | None = "95b1658c2b771e567fb3a0379e9bd600e0b1d209" 68 | 69 | # Forecasts will be made for all available init times between these 70 | start_datetime = "2024-06-05 00:00" 71 | end_datetime = "2024-06-05 03:00" 72 | 73 | # ------------------------------------------------------------------ 74 | # DERIVED VARIABLES 75 | 76 | # This will run on GPU if it exists 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | 79 | # ------------------------------------------------------------------ 80 | # GLOBAL VARIABLES 81 | 82 | # When sun as elevation below this, the forecast is set to zero 83 | MIN_DAY_ELEVATION = 0 84 | 85 | # ------------------------------------------------------------------ 86 | # FUNCTIONS 87 | 88 | def preds_to_dataarray(preds, model, valid_times, site_ids): 89 | """Put numpy array of predictions into a dataarray""" 90 | 91 | if model.use_quantile_regression: 92 | output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles] 93 | output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw" 94 | else: 95 | output_labels = ["forecast_mw"] 96 | preds = preds[..., np.newaxis] 97 | da = xr.DataArray( 98 | data=preds, 99 | dims=["site_id", "target_datetime_utc", "output_label"], 100 | coords=dict( 101 | site_id = site_ids, 102 | target_datetime_utc=valid_times, 103 | output_label=output_labels, 104 | ), 105 | ) 106 | return da 107 | 108 | def get_sites_ds(config_path: str) -> xr.Dataset: 109 | """Load site data from the path in the data config. 
110 | 111 | Args: 112 | config_path: Path to the data configuration file 113 | 114 | Returns: 115 | xarray.Dataset of PV sites data 116 | """ 117 | config = load_yaml_configuration(config_path) 118 | datasets_dict = get_dataset_dict(config.input_data) 119 | return datasets_dict["site"].to_dataset(name="site") 120 | 121 | 122 | class ModelPipe: 123 | """A class to conveniently make and process predictions from batches""" 124 | 125 | def __init__(self, model, ds_site: xr.Dataset, interval_start, interval_end, time_resolution): 126 | """A class to conveniently make and process predictions from batches 127 | 128 | Args: 129 | model: PVNet site level model 130 | ds_site: xarray dataset of pv site true values and capacities 131 | interval_start: The start timestamp (inclusive) for the prediction interval. 132 | interval_end: The end timestamp (exclusive) for the prediction interval. 133 | time_resolution: The time resolution (e.g., in minutes) for the prediction intervals. 134 | 135 | """ 136 | self.model = model 137 | self.ds_site = ds_site 138 | self.interval_start = interval_start 139 | self.interval_end = interval_end 140 | self.time_resolution = time_resolution 141 | 142 | def predict_batch(self, batch: NumpyBatch) -> xr.Dataset: 143 | """Run the batch through the model and compile the predictions into an xarray DataArray 144 | 145 | Args: 146 | batch: A batch containing inputs for a site 147 | 148 | Returns: 149 | xarray.Dataset of site forecasts for the sample 150 | """ 151 | 152 | tensor_batch = batch_to_tensor(batch) 153 | # First available timestamp in the sample (this is t0 + interval_start) 154 | first_time = pd.Timestamp(tensor_batch["site_time_utc"][0][0].item()) 155 | # Compute t0 (true start of forecast) 156 | t0 = first_time - pd.Timedelta(self.interval_start) 157 | 158 | # Generate valid times for inference (only t0 to t0 + interval_end) 159 | valid_times = pd.date_range( 160 | start=t0 + pd.Timedelta(self.time_resolution.astype(int), "min"), 161 | 
end=t0 + pd.Timedelta(self.interval_end), 162 | freq=f"{self.time_resolution.astype(int)}min", 163 | ) 164 | # Get capacity for this site 165 | site_capacities = [float(i) for i in self.ds_site["capacity_kwp"].values] 166 | # Get solar elevation and create sundown mask 167 | elevation = (tensor_batch['solar_elevation'] - 0.5) * 180 168 | # We only need elevation mask for forecasted values, not history 169 | elevation = elevation[:, -valid_times.shape[0]:] 170 | site_ids = self.ds_site["site_id"].values 171 | 172 | da_sundown_mask = xr.DataArray( 173 | data=elevation < MIN_DAY_ELEVATION, 174 | dims=["site_id", "target_datetime_utc"], 175 | coords=dict(site_id=site_ids, 176 | target_datetime_utc=valid_times, 177 | ), 178 | ) 179 | with torch.no_grad(): 180 | # Run through model to get 0-1 predictions 181 | tensor_batch = copy_batch_to_device(tensor_batch, device) 182 | y_normed = self.model(tensor_batch).detach().cpu().numpy() 183 | 184 | da_normed = preds_to_dataarray(y_normed, self.model, valid_times, site_ids) 185 | 186 | # Multiply normalised forecasts by capacity and clip negatives 187 | # Define multipliers for each id 188 | capacity_multipliers = xr.DataArray( 189 | data=site_capacities, 190 | dims=["site_id"], 191 | coords={"site_id": site_ids} 192 | ) 193 | da_abs = da_normed.clip(0, None) * capacity_multipliers 194 | 195 | # Apply sundown mask 196 | da_abs = da_abs.where(~da_sundown_mask).fillna(0.0) 197 | da_abs = da_abs.expand_dims(dim="init_time_utc", axis=0).assign_coords( 198 | init_time_utc=np.array([t0], dtype="datetime64[ns]") 199 | ) 200 | 201 | return da_abs 202 | 203 | 204 | @hydra.main(config_path="../configs", config_name="config.yaml", version_base="1.2") 205 | def main(config: DictConfig): 206 | """Runs the backtest""" 207 | 208 | dataloader_kwargs = dict( 209 | shuffle=False, 210 | batch_size=None, 211 | num_workers=num_workers, 212 | prefetch_factor=2 if num_workers>0 else None, 213 | multiprocessing_context="spawn" if num_workers>0 else 
None, 214 | pin_memory=False, 215 | drop_last=False, 216 | persistent_workers=False, 217 | sampler=None, 218 | batch_sampler=None, 219 | collate_fn=None, 220 | timeout=0, 221 | worker_init_fn=None, 222 | ) 223 | 224 | # Set up output dir 225 | os.makedirs(output_dir) 226 | 227 | # load yaml file 228 | unpacked_configuration = load_yaml_configuration(config.datamodule.configuration) 229 | 230 | interval_start = np.timedelta64( 231 | unpacked_configuration.input_data.site.interval_start_minutes, "m" 232 | ) 233 | interval_end = np.timedelta64(unpacked_configuration.input_data.site.interval_end_minutes, "m") 234 | time_resolution = np.timedelta64( 235 | unpacked_configuration.input_data.site.time_resolution_minutes, "m" 236 | ) 237 | 238 | # Create dataset 239 | dataset = PVNetConcurrentDataset( 240 | config.datamodule.configuration, start_time=start_datetime, end_time=end_datetime 241 | ) 242 | 243 | # Load the site data 244 | ds_sites = get_sites_ds(config.datamodule.configuration) 245 | 246 | # Create a dataloader 247 | dataloader = DataLoader(dataset, **dataloader_kwargs) 248 | 249 | # Load the PVNet model 250 | if model_checkpoint_dir: 251 | model, *_ = get_model_from_checkpoints([model_checkpoint_dir], val_best=True) 252 | model.eval() 253 | model.to(device) 254 | elif hf_model_id: 255 | model = PVNetBaseModel.from_pretrained( 256 | model_id=hf_model_id, 257 | revision=hf_revision).to(device).eval() 258 | else: 259 | raise ValueError("Provide a model checkpoint or a HuggingFace model") 260 | 261 | # Create object to make predictions 262 | model_pipe = ModelPipe(model, ds_sites, interval_start, interval_end, time_resolution) 263 | 264 | # Loop through the batches 265 | pbar = tqdm(total=len(dataset)) 266 | for i, batch in enumerate(dataloader): 267 | try: 268 | # Make predictions 269 | ds_abs_all = model_pipe.predict_batch(batch) 270 | t0 = ds_abs_all.init_time_utc.values[0] 271 | # Save the predictions 272 | filename = f"{output_dir}/{t0}.nc" 273 | 
ds_abs_all.to_netcdf(filename) 274 | 275 | pbar.update() 276 | except Exception as e: 277 | print(f"Exception {e} at batch {i}") 278 | pass 279 | 280 | pbar.close() 281 | del dataloader 282 | 283 | 284 | if __name__ == "__main__": 285 | main() 286 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dask.array 4 | import hydra 5 | import numpy as np 6 | import pandas as pd 7 | import pytest 8 | import torch 9 | import xarray as xr 10 | from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration 11 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 12 | from omegaconf import OmegaConf 13 | 14 | from pvnet.datamodule import PVNetDataModule 15 | from pvnet.models import LateFusionModel 16 | 17 | _top_test_directory = os.path.dirname(os.path.realpath(__file__)) 18 | 19 | 20 | uk_sat_area_string = """msg_seviri_rss_3km: 21 | description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution 22 | projection: 23 | proj: geos 24 | lon_0: 9.5 25 | h: 35785831 26 | x_0: 0 27 | y_0: 0 28 | a: 6378169 29 | rf: 295.488065897014 30 | no_defs: null 31 | type: crs 32 | shape: 33 | height: 298 34 | width: 615 35 | area_extent: 36 | lower_left_xy: [28503.830075263977, 5090183.970808983] 37 | upper_right_xy: [-1816744.1169023514, 4196063.827395439] 38 | units: m 39 | """ 40 | 41 | 42 | @pytest.fixture(scope="session") 43 | def session_tmp_path(tmp_path_factory): 44 | return tmp_path_factory.mktemp("data") 45 | 46 | 47 | @pytest.fixture(scope="session") 48 | def sat_zarr_path(session_tmp_path) -> str: 49 | variables = [ 50 | "IR_016", "IR_039", "IR_087", "IR_097", "IR_108", "IR_120", 51 | "IR_134", "VIS006", "VIS008", "WV_062", "WV_073", 52 | ] 53 | times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min") 54 | y = np.linspace(start=4191563, 
stop=5304712, num=100) 55 | x = np.linspace(start=15002, stop=-1824245, num=100) 56 | 57 | coords = ( 58 | ("variable", variables), 59 | ("time", times), 60 | ("y_geostationary", y), 61 | ("x_geostationary", x), 62 | ) 63 | 64 | data = dask.array.zeros( 65 | shape=tuple(len(coord_values) for _, coord_values in coords), 66 | chunks=(-1, 10, -1, -1), 67 | dtype=np.float32, 68 | ) 69 | 70 | attrs = {"area": uk_sat_area_string} 71 | 72 | ds = xr.DataArray(data=data, coords=coords, attrs=attrs).to_dataset(name="data") 73 | 74 | zarr_path = session_tmp_path / "test_sat.zarr" 75 | ds.to_zarr(zarr_path) 76 | 77 | return zarr_path 78 | 79 | 80 | @pytest.fixture(scope="session") 81 | def ukv_zarr_path(session_tmp_path) -> str: 82 | init_times = pd.date_range(start="2023-01-01 00:00", freq="180min", periods=24 * 7) 83 | variables = ["si10", "dswrf", "t", "prate"] 84 | steps = pd.timedelta_range("0h", "24h", freq="1h") 85 | x = np.linspace(-239_000, 857_000, 200) 86 | y = np.linspace(-183_000, 1425_000, 200) 87 | 88 | coords = ( 89 | ("init_time", init_times), 90 | ("variable", variables), 91 | ("step", steps), 92 | ("x", x), 93 | ("y", y), 94 | ) 95 | 96 | data = dask.array.random.uniform( 97 | low=0, 98 | high=200, 99 | size=tuple(len(coord_values) for _, coord_values in coords), 100 | chunks=(1, -1, -1, 50, 50), 101 | ).astype(np.float32) 102 | 103 | ds = xr.DataArray(data=data, coords=coords).to_dataset(name="UKV") 104 | 105 | zarr_path = session_tmp_path / "ukv_nwp.zarr" 106 | ds.to_zarr(zarr_path) 107 | return zarr_path 108 | 109 | 110 | @pytest.fixture(scope="session") 111 | def ecmwf_zarr_path(session_tmp_path) -> str: 112 | init_times = pd.date_range(start="2023-01-01 00:00", freq="6h", periods=24 * 7) 113 | variables = ["t2m", "dswrf", "mcc"] 114 | steps = pd.timedelta_range("0h", "14h", freq="1h") 115 | lons = np.arange(-12.0, 3.0, 0.1) 116 | lats = np.arange(48.0, 65.0, 0.1) 117 | 118 | coords = ( 119 | ("init_time", init_times), 120 | ("variable", variables), 121 
| ("step", steps), 122 | ("longitude", lons), 123 | ("latitude", lats), 124 | ) 125 | 126 | data = dask.array.random.uniform( 127 | low=0, 128 | high=200, 129 | size=tuple(len(coord_values) for _, coord_values in coords), 130 | chunks=(1, -1, -1, 50, 50), 131 | ).astype(np.float32) 132 | 133 | ds = xr.DataArray(data=data, coords=coords).to_dataset(name="ECMWF_UK") 134 | 135 | zarr_path = session_tmp_path / "ukv_ecmwf.zarr" 136 | ds.to_zarr(zarr_path) 137 | yield zarr_path 138 | 139 | 140 | @pytest.fixture(scope="session") 141 | def generation_zarr_path(session_tmp_path) -> str: 142 | 143 | times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min") 144 | location_ids = np.arange(318) 145 | # Rough UK bounding box 146 | lat_min, lat_max = 49.9, 58.7 147 | lon_min, lon_max = -8.6, 1.8 148 | 149 | # Generate random uniform points 150 | latitudes = np.random.uniform(lat_min, lat_max, len(location_ids)).astype("float64") 151 | longitudes = np.random.uniform(lon_min, lon_max, len(location_ids)).astype("float64") 152 | 153 | capacity = np.ones((len(times), len(location_ids))) 154 | 155 | generation = np.random.uniform(0, 200, (len(times), len(location_ids))).astype(np.float32) 156 | 157 | # Build Dataset 158 | ds_uk = xr.Dataset( 159 | data_vars={ 160 | "capacity_mwp": (("time_utc", "location_id"), capacity), 161 | "generation_mw": (("time_utc", "location_id"), generation), 162 | }, 163 | coords={ 164 | "time_utc": times, 165 | "location_id": location_ids, 166 | "latitude": ("location_id", latitudes), 167 | "longitude": ("location_id", longitudes), 168 | }, 169 | ) 170 | 171 | zarr_path = session_tmp_path / "uk_generation.zarr" 172 | ds_uk.to_zarr(zarr_path) 173 | return zarr_path 174 | 175 | 176 | @pytest.fixture(scope="session") 177 | def data_config_path( 178 | session_tmp_path, 179 | sat_zarr_path, 180 | ukv_zarr_path, 181 | ecmwf_zarr_path, 182 | generation_zarr_path 183 | ) -> str: 184 | 185 | # Populate the config with the generated zarr paths 186 | 
config = load_yaml_configuration(f"{_top_test_directory}/test_data/data_config.yaml") 187 | config.input_data.nwp["ukv"].zarr_path = str(ukv_zarr_path) 188 | config.input_data.nwp["ecmwf"].zarr_path = str(ecmwf_zarr_path) 189 | config.input_data.satellite.zarr_path = str(sat_zarr_path) 190 | config.input_data.generation.zarr_path = str(generation_zarr_path) 191 | 192 | filename = f"{session_tmp_path}/data_config.yaml" 193 | save_yaml_configuration(config, filename) 194 | return filename 195 | 196 | 197 | @pytest.fixture(scope="session") 198 | def streamed_datamodule(data_config_path) -> PVNetDataModule: 199 | dm = PVNetDataModule( 200 | configuration=data_config_path, 201 | batch_size=2, 202 | num_workers=0, 203 | prefetch_factor=None, 204 | ) 205 | dm.setup(stage="fit") 206 | return dm 207 | 208 | @pytest.fixture(scope="session") 209 | def batch(streamed_datamodule) -> TensorBatch: 210 | return next(iter(streamed_datamodule.train_dataloader())) 211 | 212 | 213 | @pytest.fixture(scope="session") 214 | def satellite_batch_component(batch) -> torch.Tensor: 215 | return torch.swapaxes(batch["satellite_actual"], 1, 2).float() 216 | 217 | 218 | @pytest.fixture() 219 | def model_minutes_kwargs() -> dict: 220 | return dict(forecast_minutes=480, history_minutes=60) 221 | 222 | 223 | @pytest.fixture() 224 | def encoder_model_kwargs() -> dict: 225 | # Used to test encoder model on satellite data 226 | return dict( 227 | sequence_length=7, # 30 minutes of 5 minutely satellite data = 7 time steps 228 | image_size_pixels=24, 229 | in_channels=11, 230 | out_features=128, 231 | ) 232 | 233 | 234 | @pytest.fixture() 235 | def site_encoder_model_kwargs() -> dict: 236 | """Used to test site encoder model on PV data with data sampler""" 237 | return dict( 238 | sequence_length=60 // 15 + 1, 239 | num_sites=1, 240 | out_features=128, 241 | key_to_use="generation", 242 | ) 243 | 244 | @pytest.fixture() 245 | def raw_late_fusion_model_kwargs(model_minutes_kwargs) -> dict: 246 | return 
dict( 247 | sat_encoder=dict( 248 | _target_="pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet", 249 | _partial_=True, 250 | in_channels=11, 251 | out_features=128, 252 | number_of_conv3d_layers=6, 253 | conv3d_channels=32, 254 | image_size_pixels=24, 255 | ), 256 | nwp_encoders_dict={ 257 | "ukv": dict( 258 | _target_="pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet", 259 | _partial_=True, 260 | in_channels=4, 261 | out_features=128, 262 | number_of_conv3d_layers=6, 263 | conv3d_channels=32, 264 | image_size_pixels=24, 265 | ), 266 | "ecmwf": dict( 267 | _target_="pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet", 268 | _partial_=True, 269 | in_channels=3, 270 | out_features=128, 271 | number_of_conv3d_layers=2, 272 | stride=[1,2,2], 273 | conv3d_channels=32, 274 | image_size_pixels=12, 275 | ), 276 | }, 277 | 278 | add_image_embedding_channel=True, 279 | output_network=dict( 280 | _target_="pvnet.models.late_fusion.linear_networks.networks.ResFCNet", 281 | _partial_=True, 282 | fc_hidden_features=128, 283 | n_res_blocks=6, 284 | res_block_layers=2, 285 | dropout_frac=0.0, 286 | ), 287 | location_id_mapping={i:i for i in range(1, 318)}, 288 | embedding_dim=16, 289 | include_sun=True, 290 | include_generation_history=True, 291 | sat_history_minutes=30, 292 | nwp_history_minutes={"ukv": 120, "ecmwf": 120}, 293 | nwp_forecast_minutes={"ukv": 480, "ecmwf": 480}, 294 | nwp_interval_minutes={"ukv": 60, "ecmwf": 60}, 295 | min_sat_delay_minutes=0, 296 | **model_minutes_kwargs, 297 | ) 298 | 299 | 300 | @pytest.fixture() 301 | def late_fusion_model_kwargs(raw_late_fusion_model_kwargs) -> dict: 302 | return hydra.utils.instantiate(raw_late_fusion_model_kwargs) 303 | 304 | 305 | @pytest.fixture() 306 | def late_fusion_model(late_fusion_model_kwargs) -> LateFusionModel: 307 | return LateFusionModel(**late_fusion_model_kwargs) 308 | 309 | 310 | @pytest.fixture() 311 | def raw_late_fusion_model_kwargs_generation_history(model_minutes_kwargs) -> 
dict: 312 | return dict( 313 | # Set inputs to None/False apart from generation history 314 | sat_encoder=None, 315 | nwp_encoders_dict=None, 316 | add_image_embedding_channel=False, 317 | pv_encoder=None, 318 | output_network=dict( 319 | _target_="pvnet.models.late_fusion.linear_networks.networks.ResFCNet", 320 | _partial_=True, 321 | fc_hidden_features=128, 322 | n_res_blocks=6, 323 | res_block_layers=2, 324 | dropout_frac=0.0, 325 | ), 326 | location_id_mapping=None, 327 | embedding_dim=None, 328 | include_sun=False, 329 | include_time=True, 330 | include_generation_history=True, 331 | forecast_minutes=480, 332 | history_minutes=60, 333 | interval_minutes=30, 334 | ) 335 | 336 | 337 | @pytest.fixture() 338 | def late_fusion_model_kwargs_generation_history(raw_late_fusion_model_kwargs_generation_history) -> dict: 339 | return hydra.utils.instantiate(raw_late_fusion_model_kwargs_generation_history) 340 | 341 | 342 | @pytest.fixture() 343 | def late_fusion_model_generation_history(late_fusion_model_kwargs_generation_history) -> LateFusionModel: 344 | return LateFusionModel(**late_fusion_model_kwargs_generation_history) 345 | 346 | 347 | @pytest.fixture() 348 | def late_fusion_quantile_model(late_fusion_model_kwargs) -> LateFusionModel: 349 | return LateFusionModel(output_quantiles=[0.1, 0.5, 0.9], **late_fusion_model_kwargs) 350 | 351 | 352 | @pytest.fixture 353 | def trainer_cfg(): 354 | def _make(trainer_dict): 355 | return OmegaConf.create({"trainer": trainer_dict}) 356 | return _make 357 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/site_encoders/encoders.py: -------------------------------------------------------------------------------- 1 | """Encoder modules for the site-level PV data.""" 2 | 3 | import einops 4 | import torch 5 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 6 | from torch import nn 7 | 8 | from pvnet.models.late_fusion.linear_networks.networks import 
ResFCNet 9 | from pvnet.models.late_fusion.site_encoders.basic_blocks import AbstractSitesEncoder 10 | 11 | 12 | # TODO update this to work with the new sample data format 13 | class SimpleLearnedAggregator(AbstractSitesEncoder): 14 | """A simple model which learns a different weighted-average across all PV sites for each GSP. 15 | 16 | Each sequence from each site is independently encodeded through some dense layers wih skip- 17 | connections, then the encoded form of each sequence is aggregated through a learned weighted-sum 18 | and finally put through more dense layers. 19 | 20 | This model was written to be a simplified version of a single-headed attention layer. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | sequence_length: int, 26 | num_sites: int, 27 | out_features: int, 28 | value_dim: int = 10, 29 | value_enc_resblocks: int = 2, 30 | final_resblocks: int = 2, 31 | ): 32 | """A simple sequence encoder and weighted-average model. 33 | 34 | Args: 35 | sequence_length: The time sequence length of the data. 36 | num_sites: Number of PV sites in the input data. 37 | out_features: Number of output features. 38 | value_dim: The number of features in each encoded sequence. Similar to the value 39 | dimension in single- or multi-head attention. 40 | value_dim: The number of features in each encoded sequence. Similar to the value 41 | dimension in single- or multi-head attention. 42 | value_enc_resblocks: Number of residual blocks in the value-encoder sub-network. 43 | final_resblocks: Number of residual blocks in the final sub-network. 
44 | """ 45 | 46 | super().__init__(sequence_length, num_sites, out_features) 47 | 48 | # Network used to encode each PV site sequence 49 | self._value_encoder = nn.Sequential( 50 | ResFCNet( 51 | in_features=sequence_length, 52 | out_features=value_dim, 53 | fc_hidden_features=value_dim, 54 | n_res_blocks=value_enc_resblocks, 55 | res_block_layers=2, 56 | dropout_frac=0, 57 | ), 58 | ) 59 | 60 | # The learned weighted average is stored in an embedding layer for ease of use 61 | self._attention_network = nn.Sequential( 62 | nn.Embedding(318, num_sites), 63 | nn.Softmax(dim=1), 64 | ) 65 | 66 | # Network used to process weighted average 67 | self.output_network = ResFCNet( 68 | in_features=value_dim, 69 | out_features=out_features, 70 | fc_hidden_features=value_dim, 71 | n_res_blocks=final_resblocks, 72 | res_block_layers=2, 73 | dropout_frac=0, 74 | ) 75 | 76 | def _calculate_attention(self, x: TensorBatch) -> torch.Tensor: 77 | gsp_ids = x["gsp_id"].squeeze().int() 78 | attention = self._attention_network(gsp_ids) 79 | return attention 80 | 81 | def _encode_value(self, x: TensorBatch) -> torch.Tensor: 82 | # Shape: [batch size, sequence length, PV site] 83 | pv_site_seqs = x["pv"].float() 84 | batch_size = pv_site_seqs.shape[0] 85 | 86 | pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1) 87 | 88 | x_seq_enc = self._value_encoder(pv_site_seqs) 89 | x_seq_out = x_seq_enc.unflatten(0, (batch_size, self.num_sites)) 90 | return x_seq_out 91 | 92 | def forward(self, x: TensorBatch) -> torch.Tensor: 93 | """Run model forward""" 94 | # Output has shape: [batch size, num_sites, value_dim] 95 | encodeded_seqs = self._encode_value(x) 96 | 97 | # Calculate learned averaging weights 98 | attn_avg_weights = self._calculate_attention(x) 99 | 100 | # Take weighted average across num_sites 101 | value_weighted_avg = (encodeded_seqs * attn_avg_weights.unsqueeze(-1)).sum(dim=1) 102 | 103 | # Put through final processing layers 104 | x_out = 
self.output_network(value_weighted_avg) 105 | 106 | return x_out 107 | 108 | 109 | class SingleAttentionNetwork(AbstractSitesEncoder): 110 | """A simple attention-based model with a single multihead attention layer 111 | 112 | For the attention layer the query is based on the target alone, the key is based on the 113 | input ID and the recent input data, the value is based on the recent input data. 114 | 115 | """ 116 | 117 | def __init__( 118 | self, 119 | sequence_length: int, 120 | num_sites: int, 121 | out_features: int, 122 | kdim: int = 10, 123 | id_embed_dim: int = 10, 124 | num_heads: int = 2, 125 | n_kv_res_blocks: int = 2, 126 | kv_res_block_layers: int = 2, 127 | use_id_in_value: bool = False, 128 | target_id_dim: int = 318, 129 | key_to_use: str = "generation", 130 | num_channels: int = 1, 131 | num_sites_in_inference: int = 1, 132 | ): 133 | """A simple attention-based model with a single multihead attention layer 134 | 135 | Args: 136 | sequence_length: The time sequence length of the data. 137 | num_sites: Number of sites in the input data. 138 | out_features: Number of output features. In this network this is also the embed and 139 | value dimension in the multi-head attention layer. 140 | kdim: The dimensions used the keys. 141 | id_embed_dim: Number of dimensiosn used in the site ID embedding layer(s). 142 | num_heads: Number of parallel attention heads. Note that `out_features` will be split 143 | across `num_heads` so `out_features` must be a multiple of `num_heads`. 144 | n_kv_res_blocks: Number of residual blocks to use in the key and value encoders. 145 | kv_res_block_layers: Number of fully-connected layers used in each residual block within 146 | the key and value encoders. 147 | use_id_in_value: Whether to use a site ID embedding in network used to produce the 148 | value for the attention layer. 149 | target_id_dim: The number of unique IDs. 150 | key_to_use: The key to use in the attention layer. 
151 | num_channels: Number of channels in the input data 152 | num_sites_in_inference: Number of sites to use in inference. 153 | This is used to determine the number of sites to use in the 154 | attention layer, for a single site, 1 works, while for multiple sites 155 | this would be higher than that 156 | 157 | """ 158 | super().__init__(sequence_length, num_sites, out_features) 159 | self.sequence_length = sequence_length 160 | self.target_id_embedding = nn.Embedding(target_id_dim, out_features) 161 | self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim) 162 | self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False) 163 | self.use_id_in_value = use_id_in_value 164 | self.key_to_use = key_to_use 165 | self.num_channels = num_channels 166 | self.num_sites_in_inference = num_sites_in_inference 167 | 168 | if use_id_in_value: 169 | self.value_id_embedding = nn.Embedding(num_sites, id_embed_dim) 170 | 171 | self._value_encoder = nn.Sequential( 172 | ResFCNet( 173 | in_features=sequence_length * self.num_channels 174 | + int(use_id_in_value) * id_embed_dim, 175 | out_features=out_features, 176 | fc_hidden_features=sequence_length * self.num_channels, 177 | n_res_blocks=n_kv_res_blocks, 178 | res_block_layers=kv_res_block_layers, 179 | dropout_frac=0, 180 | ), 181 | ) 182 | 183 | self._key_encoder = nn.Sequential( 184 | ResFCNet( 185 | in_features=id_embed_dim + sequence_length * self.num_channels, 186 | out_features=kdim, 187 | fc_hidden_features=id_embed_dim + sequence_length * self.num_channels, 188 | n_res_blocks=n_kv_res_blocks, 189 | res_block_layers=kv_res_block_layers, 190 | dropout_frac=0, 191 | ), 192 | ) 193 | 194 | self.multihead_attn = nn.MultiheadAttention( 195 | embed_dim=out_features, 196 | kdim=kdim, 197 | vdim=out_features, 198 | num_heads=num_heads, 199 | batch_first=True, 200 | ) 201 | 202 | def _encode_inputs(self, x: TensorBatch) -> tuple[torch.Tensor, int]: 203 | # Shape: [batch size, sequence length, number of 
sites] 204 | # Shape: [batch size, station_id, sequence length, channels] 205 | input_data = x[f"{self.key_to_use}"] 206 | if len(input_data.shape) == 2: # one site per sample 207 | input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D 208 | if len(input_data.shape) == 4: # Has multiple channels 209 | input_data = input_data[:, :, : self.sequence_length] 210 | input_data = einops.rearrange(input_data, "b id s c -> b (s c) id") 211 | else: 212 | input_data = input_data[:, : self.sequence_length] 213 | site_seqs = input_data.float() 214 | batch_size = site_seqs.shape[0] 215 | site_seqs = site_seqs.swapaxes(1, 2) # [batch size, location ID, sequence length] 216 | return site_seqs, batch_size 217 | 218 | def _encode_query(self, x: TensorBatch) -> torch.Tensor: 219 | ids = x["location_id"].int() 220 | query = self.target_id_embedding(ids).unsqueeze(1) 221 | return query 222 | 223 | def _encode_key(self, x: TensorBatch) -> torch.Tensor: 224 | site_seqs, batch_size = self._encode_inputs(x) 225 | 226 | # site ID embeddings are the same for each sample 227 | id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1)) 228 | # Each concated (site sequence, site ID embedding) is processed with encoder 229 | x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1) 230 | key = self._key_encoder(x_seq_in) 231 | 232 | # Reshape to [batch size, site, kdim] 233 | key = key.unflatten(0, (batch_size, self.num_sites)) 234 | return key 235 | 236 | def _encode_value(self, x: TensorBatch) -> torch.Tensor: 237 | site_seqs, batch_size = self._encode_inputs(x) 238 | 239 | if self.use_id_in_value: 240 | # site ID embeddings are the same for each sample 241 | id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1)) 242 | # Each concated (site sequence, site ID embedding) is processed with encoder 243 | x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1) 244 | else: 245 | # Encode each site sequence independently 
246 | x_seq_in = site_seqs.flatten(0, 1) 247 | value = self._value_encoder(x_seq_in) 248 | 249 | # Reshape to [batch size, site, vdim] 250 | value = value.unflatten(0, (batch_size, self.num_sites)) 251 | return value 252 | 253 | def _attention_forward(  # Encode query/key/value from the batch and return (attn_output, attn_weights) from self.multihead_attn 254 | self, x: dict, average_attn_weights: bool = True 255 | ) -> tuple[torch.Tensor, torch.Tensor]: 256 | query = self._encode_query(x) 257 | key = self._encode_key(x) 258 | value = self._encode_value(x) 259 | attn_output, attn_weights = self.multihead_attn( 260 | query, key, value, average_attn_weights=average_attn_weights 261 | ) 262 | 263 | return attn_output, attn_weights 264 | 265 | def forward(self, x: TensorBatch) -> torch.Tensor: 266 | """Run model forward and return the attention output reshaped to [batch_size, vdim]""" 267 | 268 | attn_output, _ = self._attention_forward(x) 269 | 270 | # Reshape from [batch_size, 1, vdim] to [batch_size, vdim]. squeeze() drops every size-1 dim, so a batch of one collapses to 1D and is restored below 271 | x_out = attn_output.squeeze() 272 | if len(x_out.shape) == 1: 273 | x_out = x_out.unsqueeze(0) 274 | 275 | return x_out 276 | -------------------------------------------------------------------------------- /pvnet/training/lightning_module.py: -------------------------------------------------------------------------------- 1 | """Pytorch lightning module for training PVNet models""" 2 | 3 | import lightning.pytorch as pl 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.nn.functional as F 9 | import wandb 10 | import xarray as xr 11 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 12 | from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import copy_batch_to_device 13 | 14 | from pvnet.datamodule import collate_fn 15 | from pvnet.models.base_model import BaseModel 16 | from pvnet.optimizers import AbstractOptimizer 17 | from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot 18 | from pvnet.utils import validate_batch_against_config 19 | 20 | 21 | class PVNetLightningModule(pl.LightningModule): 22 | """Lightning
module for training PVNet models""" 23 | 24 | def __init__( 25 | self, 26 | model: BaseModel, 27 | optimizer: AbstractOptimizer, 28 | save_all_validation_results: bool = False, 29 | ): 30 | """Lightning module for training PVNet models 31 | 32 | Args: 33 | model: The PVNet model 34 | optimizer: Optimizer 35 | save_all_validation_results: Whether to save all the validation predictions to wandb 36 | """ 37 | super().__init__() 38 | 39 | self.model = model 40 | self._optimizer = optimizer 41 | self.save_all_validation_results = save_all_validation_results 42 | 43 | # Model must have lr to allow tuning 44 | # This setting is only used when lr is tuned with callback 45 | self.lr = None 46 | 47 | def transfer_batch_to_device( 48 | self, 49 | batch: TensorBatch, 50 | device: torch.device, 51 | dataloader_idx: int, 52 | ) -> dict: 53 | """Method to move custom batches to a given device""" 54 | return copy_batch_to_device(batch, device) 55 | 56 | def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 57 | """Calculate quantile loss.
58 | 59 | Note: 60 | Implementation copied from: 61 | https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting 62 | /metrics/quantile.html#QuantileLoss.loss 63 | 64 | Args: 65 | y_quantiles: Quantile prediction of network; the last dim indexes self.model.output_quantiles 66 | y: Target values 67 | 68 | Returns: 69 | Quantile (pinball) loss, multiplied by 2 and averaged over all elements 70 | """ 71 | losses = [] 72 | for i, q in enumerate(self.model.output_quantiles): 73 | errors = y - y_quantiles[..., i] 74 | losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) 75 | losses = 2 * torch.cat(losses, dim=2) 76 | 77 | return losses.mean() 78 | 79 | def configure_optimizers(self): 80 | """Configure the optimizers using learning rate found with LR finder if used""" 81 | if self.lr is not None: 82 | # Use learning rate found by learning rate finder callback 83 | self._optimizer.lr = self.lr 84 | return self._optimizer(self.model) 85 | 86 | def _calculate_common_losses( 87 | self, 88 | y: torch.Tensor, 89 | y_hat: torch.Tensor, 90 | ) -> dict[str, torch.Tensor]: 91 | """Calculate losses common to training and validation""" 92 | 93 | losses = {} 94 | 95 | if self.model.use_quantile_regression: 96 | losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y) 97 | y_hat = self.model._quantiles_to_prediction(y_hat)  # point forecast used for MSE/MAE below 98 | 99 | losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)}) 100 | 101 | return losses 102 | 103 | def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor: 104 | """Run training step""" 105 | y_hat = self.model(batch) 106 | 107 | y = batch["generation"][:, -self.model.forecast_len :] 108 | 109 | losses = self._calculate_common_losses(y, y_hat) 110 | losses = {f"{k}/train": v for k, v in losses.items()} 111 | 112 | self.log_dict(losses, on_step=True, on_epoch=True) 113 | 114 | if self.model.use_quantile_regression: 115 | opt_target = losses["quantile_loss/train"] 116 | else: 117 | opt_target = losses["MAE/train"] 118 | return opt_target 119 | 120 | def _calculate_val_losses( 121 | self, 122 | 
y: torch.Tensor, 123 | y_hat: torch.Tensor, 124 | ) -> dict[str, torch.Tensor]: 125 | """Calculate additional losses only run in validation""" 126 | 127 | losses = {} 128 | 129 | if self.model.use_quantile_regression: 130 | metric_name = "val_fraction_below/fraction_below_{:.2f}_quantile" 131 | # Add fraction below each quantile for calibration 132 | for i, quantile in enumerate(self.model.output_quantiles): 133 | below_quant = y <= y_hat[..., i] 134 | # Mask out small target values (y < 0.01), which are dominated by night-time periods 135 | mask = y >= 0.01 136 | losses[metric_name.format(quantile)] = below_quant[mask].float().mean() 137 | 138 | return losses 139 | 140 | def _calculate_step_metrics( 141 | self, 142 | y: torch.Tensor, 143 | y_hat: torch.Tensor, 144 | ) -> tuple[np.ndarray, np.ndarray]: 145 | """Calculate the MAE and MSE at each forecast step, averaged over dim 0 (samples)""" 146 | 147 | mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy() 148 | mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy() 149 | 150 | return mae_each_step, mse_each_step 151 | 152 | def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> None: 153 | """Internally store the validation predictions""" 154 | 155 | y = batch["generation"][:, -self.model.forecast_len :].cpu().numpy() 156 | y_hat = y_hat.cpu().numpy() 157 | ids = batch["location_id"].cpu().numpy() 158 | init_times_utc = pd.to_datetime(  # NOTE(review): confirm index history_len + 1 is the intended init-time; validation_step treats generation[:, -(forecast_len + 1)] as the last value before the forecast window 159 | batch["time_utc"][:, self.model.history_len + 1].cpu().numpy().astype("datetime64[ns]") 160 | ) 161 | 162 | if self.model.use_quantile_regression: 163 | p_levels = self.model.output_quantiles 164 | else: 165 | p_levels = [0.5] 166 | y_hat = y_hat[..., None] 167 | 168 | ds_preds_batch = xr.Dataset( 169 | data_vars=dict( 170 | y_hat=(["sample_num", "forecast_step", "p_level"], y_hat), 171 | y=(["sample_num", "forecast_step"], y), 172 | ), 173 | coords=dict( 174 | ids=("sample_num", ids), 175 | init_times_utc=("sample_num", init_times_utc), 176 | p_level=p_levels, 177 | ), 178 | ) 179 | 
self.all_val_results.append(ds_preds_batch) 180 | 181 | def on_validation_epoch_start(self): 182 | """Run at start of val period""" 183 | # Set up stores which we will fill during validation 184 | self.all_val_results: list[xr.Dataset] = [] 185 | self._val_horizon_maes: list[np.ndarray] = [] 186 | if self.current_epoch == 0: 187 | self._val_persistence_horizon_maes: list[np.ndarray] = [] 188 | 189 | # Plot some sample forecasts (up to num_figures x plots_per_figure samples from the start of the val dataset) 190 | val_dataset = self.trainer.val_dataloaders.dataset 191 | 192 | plots_per_figure = 16 193 | num_figures = 2 194 | 195 | for plot_num in range(num_figures): 196 | idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure 197 | idxs = idxs[idxs < len(val_dataset)] 198 | 199 | if len(idxs) == 0: 200 | continue 201 | 202 | batch = collate_fn([val_dataset[i] for i in idxs]) 203 | batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0) 204 | 205 | # Batch validation check only during sanity check phase - use first batch 206 | if self.trainer.sanity_checking and plot_num == 0: 207 | validate_batch_against_config(batch=batch, model=self.model) 208 | 209 | with torch.no_grad(): 210 | y_hat = self.model(batch) 211 | 212 | fig = plot_sample_forecasts( 213 | batch, 214 | y_hat, 215 | quantiles=self.model.output_quantiles, 216 | key_to_plot="generation", 217 | ) 218 | 219 | plot_name = f"val_forecast_samples/sample_set_{plot_num}" 220 | 221 | # Disabled for testing or using no logger 222 | if self.logger: 223 | self.logger.experiment.log({plot_name: wandb.Image(fig)}) 224 | 225 | plt.close(fig) 226 | 227 | def validation_step(self, batch: TensorBatch, batch_idx: int) -> None: 228 | """Run validation step""" 229 | 230 | y_hat = self.model(batch) 231 | 232 | # Internally store the val predictions 233 | self._store_val_predictions(batch, y_hat) 234 | 235 | y = batch["generation"][:, -self.model.forecast_len :] 236 | 237 | losses = self._calculate_common_losses(y, y_hat) 238 | losses = {f"{k}/val": v for k, v in losses.items()}
239 | 240 | losses.update(self._calculate_val_losses(y, y_hat)) 241 | 242 | # Calculate the horizon MAE/MSE metrics 243 | if self.model.use_quantile_regression: 244 | y_hat_mid = self.model._quantiles_to_prediction(y_hat) 245 | else: 246 | y_hat_mid = y_hat 247 | 248 | mae_step, mse_step = self._calculate_step_metrics(y, y_hat_mid) 249 | 250 | # Store to make horizon-MAE plot 251 | self._val_horizon_maes.append(mae_step) 252 | 253 | # Also add each step to logged metrics 254 | losses.update({f"val_step_MAE/step_{i:03}": m for i, m in enumerate(mae_step)}) 255 | losses.update({f"val_step_MSE/step_{i:03}": m for i, m in enumerate(mse_step)}) 256 | 257 | # Calculate the persistence losses (last value before the forecast window repeated over every step) - we only need to do this once per training run 258 | # not every epoch 259 | if self.current_epoch == 0: 260 | y_persist = ( 261 | batch["generation"][:, -(self.model.forecast_len + 1)] 262 | .unsqueeze(1) 263 | .expand(-1, self.model.forecast_len) 264 | ) 265 | mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist) 266 | self._val_persistence_horizon_maes.append(mae_step_persist) 267 | losses.update( 268 | { 269 | "MAE/val_persistence": mae_step_persist.mean(), 270 | "MSE/val_persistence": mse_step_persist.mean(), 271 | } 272 | ) 273 | 274 | # Log the metrics 275 | self.log_dict(losses, on_step=False, on_epoch=True) 276 | 277 | def on_validation_epoch_end(self) -> None: 278 | """Run on epoch end""" 279 | 280 | ds_val_results = xr.concat(self.all_val_results, dim="sample_num") 281 | self.all_val_results = [] 282 | 283 | val_horizon_maes = np.mean(self._val_horizon_maes, axis=0)  # unweighted mean over batches, so a smaller final batch counts as much as a full one 284 | self._val_horizon_maes = [] 285 | 286 | # We only run this on the first epoch 287 | if self.current_epoch == 0: 288 | val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0) 289 | self._val_persistence_horizon_maes = [] 290 | 291 | if isinstance(self.logger, pl.loggers.WandbLogger): 292 | # Calculate and log extreme error metrics 293 | val_error =
ds_val_results["y"] - ds_val_results["y_hat"].sel(p_level=0.5) 294 | 295 | # Factor out this part of the string for brevity below 296 | s = "error_extremes/{}_percentile_median_forecast_error" 297 | s_abs = "error_extremes/{}_percentile_median_forecast_absolute_error" 298 | 299 | extreme_error_metrics = { 300 | s.format("2nd"): val_error.quantile(0.02).item(), 301 | s.format("5th"): val_error.quantile(0.05).item(), 302 | s.format("95th"): val_error.quantile(0.95).item(), 303 | s.format("98th"): val_error.quantile(0.98).item(), 304 | s_abs.format("95th"): np.abs(val_error).quantile(0.95).item(), 305 | s_abs.format("98th"): np.abs(val_error).quantile(0.98).item(), 306 | } 307 | 308 | self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True) 309 | 310 | # Optionally save all validation results - these are overwritten each epoch 311 | if self.save_all_validation_results: 312 | # Add attributes 313 | ds_val_results.attrs["epoch"] = self.current_epoch 314 | 315 | # Save locally to the wandb output dir 316 | wandb_log_dir = self.logger.experiment.dir 317 | filepath = f"{wandb_log_dir}/validation_results.netcdf" 318 | ds_val_results.to_netcdf(filepath) 319 | 320 | # Upload to wandb 321 | self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now") 322 | 323 | # Create the horizon accuracy curve 324 | horizon_mae_plot = wandb_line_plot( 325 | x=np.arange(self.model.forecast_len), 326 | y=val_horizon_maes, 327 | xlabel="Horizon step", 328 | ylabel="MAE", 329 | title="Val horizon loss curve", 330 | ) 331 | 332 | wandb.log({"val_horizon_mae_plot": horizon_mae_plot}) 333 | 334 | # Create persistence horizon accuracy curve but only on first epoch 335 | if self.current_epoch == 0: 336 | persist_horizon_mae_plot = wandb_line_plot( 337 | x=np.arange(self.model.forecast_len), 338 | y=val_persistence_horizon_maes, 339 | xlabel="Horizon step", 340 | ylabel="MAE", 341 | title="Val persistence horizon loss curve", 342 | ) 343 | 
wandb.log({"persistence_val_horizon_mae_plot": persist_horizon_mae_plot}) 344 | -------------------------------------------------------------------------------- /scripts/backtest_uk_gsp.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to run a backtest for PVNet and the summation model for UK regional (GSP-level) and national forecasts 3 | 4 | Use: 5 | 6 | - This script uses exported PVNet and PVNet summation models stored either locally or on huggingface 7 | - The save directory, model paths, the backtest time range, the input data paths, and number of 8 | workers used are near the top of the script as hard-coded user variables. These should be changed. 9 | 10 | 11 | ``` 12 | python backtest_uk_gsp.py 13 | ``` 14 | 15 | """ 16 | 17 | import logging 18 | import os 19 | import shutil 20 | 21 | import numpy as np 22 | import pandas as pd 23 | import torch 24 | import xarray as xr 25 | import yaml 26 | from ocf_data_sampler.torch_datasets.utils.torch_batch_utils import ( 27 | batch_to_tensor, 28 | copy_batch_to_device, 29 | ) 30 | from pvnet_summation.data.datamodule import StreamedDataset 31 | from pvnet_summation.models.base_model import BaseModel as SummationBaseModel 32 | from torch.utils.data import DataLoader 33 | from tqdm import tqdm 34 | 35 | from pvnet.models.base_model import BaseModel as PVNetBaseModel 36 | 37 | # ------------------------------------------------------------------ 38 | # USER CONFIGURED VARIABLES 39 | 40 | # This dir is used as a working space during the backtest (it must not already exist and is deleted once the results are saved) and the final results will be saved 41 | # under the path {output_dir}.zarr 42 | output_dir = "/home/james/tmp/test_backtest/pvnet_v2" 43 | 44 | # Number of workers to use in the dataloader 45 | num_workers = 16 46 | 47 | # Location of the exported PVNet and summation model pair 48 | pvnet_model_name: str = "openclimatefix-models/pvnet_uk_region" 49 | pvnet_model_version: str | None = "ff09e4aee871fe094d3a2dabe9d9cea50e4b5485" 50 | 51 | # If
set to None, no national forecast is made 52 | summation_model_name: str | None = "openclimatefix-models/pvnet_v2_summation" 53 | summation_model_version: str | None = "d746683893330fe3380e57e65d40812daa343c8e" 54 | 55 | # Forecasts will be made for all available init-times between these. If set to None, predictions 56 | # will be made for all init-times available according to the input data 57 | start_datetime: str | None = "2022-01-01 00:00" 58 | end_datetime: str | None = "2022-12-31 23:30" 59 | 60 | # The paths to the input data for the backtest. NWP paths are keyed by the provider name used in the model's data config 61 | backtest_paths = { 62 | "gsp": "/mnt/storage_u2_30tb_a/ml_training_zarrs/pv/pvlive_gsp_new_boundaries_2019-2025.zarr", 63 | "nwp": { 64 | "ukv": [ 65 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/nwp/ukv_v7/UKV_intermediate_version_7.1.zarr", 66 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/nwp/ukv_v7/UKV_2021_missing.zarr", 67 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/nwp/ukv_v7/UKV_2022.zarr", 68 | ], 69 | "ecmwf": [ 70 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/nwp/ecmwf_v3/ECMWF_2019.zarr", 71 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/nwp/ecmwf_v3/ECMWF_2020.zarr", 72 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/nwp/ecmwf_v3/ECMWF_2021.zarr", 73 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/nwp/ecmwf_v3/ECMWF_2022.zarr", 74 | ], 75 | "cloudcasting": "/mnt/raphael/fast/cloudcasting/simvp.zarr", 76 | }, 77 | "satellite": [ 78 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/sat/uk_sat_crops/v1/2019_nonhrv.zarr", 79 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/sat/uk_sat_crops/v1/2020_nonhrv.zarr", 80 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/sat/uk_sat_crops/v1/2021_nonhrv.zarr", 81 | "/mnt/storage_u2_30tb_a/ml_training_zarrs/sat/uk_sat_crops/v1/2022_nonhrv.zarr", 82 | ], 83 | } 84 | 85 | # When the sun's elevation (in degrees) is below this, the forecast is set to zero 86 | MIN_DAY_ELEVATION = 0 87 | 88 | # ------------------------------------------------------------------ 89 | 90 | logger = logging.getLogger(__name__) 91 | 92 | # This
will run on GPU if it exists 93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 94 | 95 | # ------------------------------------------------------------------ 96 | # FUNCTIONS 97 | 98 | _model_mismatch_msg = ( 99 | "The PVNet version running in this app is {}/{}. The summation model running in this app was " 100 | "trained on outputs from PVNet version {}/{}. Combining these models may lead to an error if " 101 | "the shape of PVNet output doesn't match the expected shape of the summation model. Combining " 102 | "may lead to unreliable results even if the shapes match." 103 | ) 104 | 105 | def populate_config_with_data_data_filepaths(config: dict) -> dict: 106 | """Populate the data source filepaths in the config (the config is modified in place and returned) 107 | 108 | Args: 109 | config: The data config 110 | """ 111 | 112 | # Replace the GSP data path 113 | config["input_data"]["gsp"]["zarr_path"] = backtest_paths["gsp"] 114 | 115 | # Replace satellite data path if using it (an empty zarr_path means satellite is not used) 116 | if "satellite" in config["input_data"]: 117 | if config["input_data"]["satellite"]["zarr_path"] != "": 118 | config["input_data"]["satellite"]["zarr_path"] = backtest_paths["satellite"] 119 | 120 | # NWP is nested (one entry per source, each with its own provider) so must be treated separately 121 | if "nwp" in config["input_data"]: 122 | nwp_config = config["input_data"]["nwp"] 123 | for nwp_source in nwp_config.keys(): 124 | provider = nwp_config[nwp_source]["provider"] 125 | assert provider in backtest_paths["nwp"], f"Missing NWP path: {provider}"  # NOTE(review): assert is stripped under python -O; consider raising instead 126 | nwp_config[nwp_source]["zarr_path"] = backtest_paths["nwp"][provider] 127 | 128 | return config 129 | 130 | 131 | def overwrite_config_dropouts(config: dict) -> dict: 132 | """Overwrite the config dropout parameters for the backtest (the config is modified in place and returned) 133 | 134 | Args: 135 | config: The data config 136 | """ 137 | if "satellite" in config["input_data"]: 138 | 139 | satellite_config = config["input_data"]["satellite"] 140 | 141 | if satellite_config["zarr_path"] != "": 142 | satellite_config["dropout_timedeltas_minutes"] =
[] 143 | satellite_config["dropout_fraction"] = 0 144 | 145 | # Don't modify NWP dropout since this accounts for the expected NWP delay 146 | 147 | return config 148 | 149 | 150 | class BacktestStreamedDataset(StreamedDataset): 151 | """A torch dataset object used only for backtesting""" 152 | 153 | def _get_sample(self, t0: pd.Timestamp) -> ...: 154 | """Generate a concurrent PVNet sample for the given init-time, adding the init-time ("backtest_t0") and national capacity ("backtest_national_capacity") for backtesting. 155 | 156 | Args: 157 | t0: init-time for sample 158 | """ 159 | 160 | sample = super()._get_sample(t0) 161 | 162 | total_capacity = self.national_gsp_data.sel(time_utc=t0).effective_capacity_mwp.item() 163 | 164 | sample.update( 165 | { 166 | "backtest_t0": t0, 167 | "backtest_national_capacity": total_capacity, 168 | } 169 | ) 170 | 171 | return sample 172 | 173 | 174 | class Forecaster: 175 | """Class for making solar forecasts for all GB GSPs and the national total""" 176 | 177 | def __init__(self): 178 | """Load the GSP-level PVNet model and, if configured, the summation model 179 | """ 180 | 181 | # Load the GSP-level model 182 | self.model = PVNetBaseModel.from_pretrained( 183 | model_id=pvnet_model_name, 184 | revision=pvnet_model_version, 185 | ).to(device).eval() 186 | 187 | # Load the summation model 188 | if summation_model_name is not None: 189 | self.sum_model = SummationBaseModel.from_pretrained( 190 | model_id=summation_model_name, 191 | revision=summation_model_version, 192 | ).to(device).eval() 193 | 194 | # Compare the current GSP model with the one the summation model was trained on 195 | datamodule_path = SummationBaseModel.get_datamodule_config( 196 | model_id=summation_model_name, 197 | revision=summation_model_version, 198 | ) 199 | with open(datamodule_path) as cfg:  # NOTE(review): this config comes from the summation model repo; yaml.safe_load would suffice 200 | sum_pvnet_cfg = yaml.load(cfg, Loader=yaml.FullLoader)["pvnet_model"] 201 | 202 | sum_expected_gsp_model = (sum_pvnet_cfg["model_id"], sum_pvnet_cfg["revision"]) 203 | this_gsp_model = (pvnet_model_name, pvnet_model_version) 204 | 205 | if
sum_expected_gsp_model != this_gsp_model: 206 | logger.warning(_model_mismatch_msg.format(*this_gsp_model, *sum_expected_gsp_model)) 207 | 208 | # These are the forecast horizons (30-minute steps after the init-time) this forecast will predict for 209 | self.steps = pd.timedelta_range( 210 | start="30min", 211 | freq="30min", 212 | periods=self.model.forecast_len, 213 | ) 214 | 215 | @torch.inference_mode() 216 | def predict(self, sample: dict) -> xr.Dataset: 217 | """Make GSP-level (and, if a summation model is configured, national) predictions for a single init-time sample and return them as an xr.Dataset; `sample` is modified in place when the summation model is used""" 218 | 219 | x = copy_batch_to_device(batch_to_tensor(sample["pvnet_inputs"]), device) 220 | 221 | # Run batch through model 222 | normed_preds = self.model(x).detach().cpu().numpy() 223 | 224 | # Calculate sun mask 225 | # The dataloader normalises solar elevation data to the range [0, 1] 226 | elevation_degrees = (sample["pvnet_inputs"]["solar_elevation"] - 0.5) * 180 227 | # We only need elevation mask for forecasted values, not history 228 | elevation_degrees = elevation_degrees[:, -normed_preds.shape[1]:] 229 | sun_down_masks = elevation_degrees < MIN_DAY_ELEVATION 230 | 231 | # Convert GSP results to xarray DataArray 232 | t0 = sample["backtest_t0"] 233 | gsp_ids = sample["pvnet_inputs"]["gsp_id"] 234 | 235 | da_normed = self.to_dataarray( 236 | normed_preds, 237 | t0, 238 | gsp_ids, 239 | self.model.output_quantiles, 240 | ) 241 | 242 | da_sundown_mask = self.to_dataarray(sun_down_masks, t0, gsp_ids, None) 243 | 244 | # Clip negative normalised forecasts and multiply by the GSP capacities 245 | da_abs = ( 246 | da_normed.clip(0, None) 247 | * sample["pvnet_inputs"]["gsp_effective_capacity_mwp"][None, :, None, None].numpy() 248 | ) 249 | 250 | # Apply sundown mask 251 | da_abs = da_abs.where(~da_sundown_mask).fillna(0.0) 252 | 253 | if summation_model_name is not None: 254 | # Make national predictions using summation model 255 | # - Need to add batch dimension and convert to torch tensors on device 256 | sample["pvnet_outputs"] = torch.tensor(normed_preds[None]).to(device) 257 | for k in
["relative_capacity", "azimuth", "elevation"]: 258 | sample[k] = sample[k][None].to(device) 259 | normed_national = self.sum_model(sample).detach().squeeze().cpu().numpy() 260 | 261 | # Convert national predictions to DataArray 262 | da_normed_national = self.to_dataarray( 263 | normed_national[np.newaxis], 264 | t0, 265 | gsp_ids=[0], 266 | output_quantiles=self.sum_model.output_quantiles, 267 | ) 268 | 269 | # Clip negative normalised forecasts and multiply by the national capacity 270 | national_capacity = sample["backtest_national_capacity"] 271 | da_abs_national = da_normed_national.clip(0, None) * national_capacity 272 | 273 | # Apply sundown mask - All GSPs must be masked to mask national 274 | da_abs_national = da_abs_national.where(~da_sundown_mask.all(dim="gsp_id")).fillna(0.0) 275 | 276 | # Prepend the national forecast (gsp_id=0) to the GSP-level forecasts 277 | da_abs = xr.concat([da_abs_national, da_abs], dim="gsp_id") 278 | # Convert to Dataset and add attrs about the models used 279 | da_abs = da_abs.to_dataset(name="hindcast") 280 | da_abs.attrs.update( 281 | { 282 | "pvnet_model_name": pvnet_model_name, 283 | "pvnet_model_version": pvnet_model_version or "none", 284 | "summation_model_name": summation_model_name or "none", 285 | "summation_model_version": summation_model_version or "none", 286 | } 287 | ) 288 | 289 | return da_abs 290 | 291 | def to_dataarray( 292 | self, 293 | preds: np.ndarray, 294 | t0: pd.Timestamp, 295 | gsp_ids: list[int], 296 | output_quantiles: list[float] | None, 297 | ) -> xr.DataArray: 298 | """Wrap a predictions array (adding a leading init-time axis) in a DataArray with init_time_utc, gsp_id, step and, if output_quantiles is given, quantile dims""" 299 | 300 | dims = ["init_time_utc", "gsp_id", "step"] 301 | coords = dict( 302 | init_time_utc=[t0], 303 | gsp_id=gsp_ids, 304 | step=self.steps, 305 | ) 306 | 307 | if output_quantiles is not None: 308 | dims.append("quantile") 309 | coords["quantile"] = output_quantiles 310 | 311 | 
return xr.DataArray(data=preds[np.newaxis, ...], dims=dims, coords=coords) 312 | 313 | # ------------------------------------------------------------------ 314 | # RUN 315 | 316 | if __name__=="__main__": 317 | 318 | # Set up output dir (raises if it already exists) 319 | os.makedirs(output_dir, exist_ok=False) 320 | 321 | data_config_path = PVNetBaseModel.get_data_config( 322 | model_id=pvnet_model_name, 323 | revision=pvnet_model_version, 324 | ) 325 | 326 | with open(data_config_path) as file: 327 | data_config = yaml.load(file, Loader=yaml.FullLoader) 328 | 329 | data_config = populate_config_with_data_data_filepaths(data_config) 330 | data_config = overwrite_config_dropouts(data_config) 331 | 332 | modified_data_config_filepath = f"{output_dir}/data_config.yaml" 333 | 334 | with open(modified_data_config_filepath, "w") as file: 335 | yaml.dump(data_config, file, default_flow_style=False) 336 | 337 | 338 | dataset = BacktestStreamedDataset( 339 | config_filename=modified_data_config_filepath, 340 | start_time=start_datetime, 341 | end_time=end_datetime, 342 | ) 343 | 344 | dataloader_kwargs = dict( 345 | num_workers=num_workers, 346 | prefetch_factor=2 if num_workers>0 else None, 347 | multiprocessing_context="spawn" if num_workers>0 else None, 348 | shuffle=False, 349 | batch_size=None,  # disable automatic batching: the loader yields one init-time sample at a time 350 | sampler=None, 351 | batch_sampler=None, 352 | collate_fn=None, 353 | drop_last=False, 354 | timeout=0, 355 | worker_init_fn=None, 356 | persistent_workers=False, 357 | ) 358 | 359 | if num_workers>0: 360 | dataset.presave_pickle(f"{output_dir}/dataset.pkl") 361 | 362 | dataloader = DataLoader(dataset, **dataloader_kwargs) 363 | forecaster = Forecaster() 364 | 365 | # Loop through the samples (one per init-time) 366 | pbar = tqdm(total=len(dataloader)) 367 | for sample in dataloader: 368 | # Make predictions for the init-time 369 | ds_abs_all = forecaster.predict(sample) 370 | 371 | # Save the predictions 372 | t0 = pd.Timestamp(ds_abs_all.init_time_utc.item()) 373 | filename = f"{output_dir}/{t0}.nc" 374 | ds_abs_all.to_netcdf(filename) 375 | 376 | pbar.update() 377 | 378 | # Close down 379 | pbar.close() 380 | 381 | # Clean up 382 | if num_workers>0: 383 | os.remove(f"{output_dir}/dataset.pkl")
384 | 385 | # Reload all the forecasts and resave as a single zarr 386 | ds_all_forecast = xr.open_mfdataset(f"{output_dir}/*.nc", parallel=True).compute() 387 | ds_all_forecast = ds_all_forecast.chunk({"init_time_utc": 32}) 388 | ds_all_forecast.to_zarr(f"{output_dir}.zarr") 389 | 390 | # Remove the intermediate results 391 | shutil.rmtree(output_dir) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PVNet 2 | 3 | [![All Contributors](https://img.shields.io/badge/all_contributors-21-orange.svg?style=flat-square)](#contributors-) 4 | 5 | 6 | [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/PVNet?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/PVNet/tags) 7 | [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories) 8 | 9 | 10 | This project is used for training PVNet and running PVNet on live data. 11 | 12 | PVNet is a multi-modal late-fusion model for predicting renewable energy generation from weather 13 | data. The NWP (Numerical Weather Prediction) and satellite data are sent through a neural network 14 | which encodes them down to 1D intermediate representations. These are concatenated together with 15 | recent generation, the calculated solar coordinates (azimuth and elevation) and the location ID 16 | which has been put through an embedding layer. This 1D concatenated feature vector is put through 17 | an output network which outputs predictions of the future energy yield. 18 | 19 | 20 | ## Experiments 21 | 22 | Our paper based on this repo was accepted into the Tackling Climate Change with Machine Learning 23 | workshop at ICLR 2024 and can be viewed [here](https://www.climatechange.ai/papers/iclr2024/46).
24 | 25 | Some more structured notes on experiments we have performed with PVNet are 26 | [here](https://docs.google.com/document/d/1VumDwWd8YAfvXbOtJEv3ZJm_FHQDzrKXR0jU9vnvGQg). 27 | 28 | 29 | ## Setup / Installation 30 | 31 | ```bash 32 | git clone git@github.com:openclimatefix/PVNet.git 33 | cd PVNet 34 | pip install . 35 | ``` 36 | 37 | The commit history is extensive. To save download time, use a depth of 1: 38 | ```bash 39 | git clone --depth 1 git@github.com:openclimatefix/PVNet.git 40 | ``` 41 | This means only the latest commit and its associated files will be downloaded. 42 | 43 | Next, in the PVNet repo, install PVNet as an editable package: 44 | 45 | ```bash 46 | pip install -e . 47 | ``` 48 | 49 | ### Additional development dependencies 50 | 51 | ```bash 52 | pip install ".[dev]" 53 | ``` 54 | 55 | 56 | 57 | ## Getting started with running PVNet 58 | 59 | Before running any code in PVNet, copy the example configuration to a 60 | configs directory: 61 | 62 | ``` 63 | cp -r configs.example configs 64 | ``` 65 | 66 | You will be making local amendments to these configs. See the README in 67 | `configs.example` for more info. 68 | 69 | ### Datasets 70 | 71 | As a minimum, in order to create samples of data/run PVNet, you will need to 72 | supply paths to NWP and GSP data. PV data can also be used. 
We list some 73 | suggested locations for downloading such datasets below: 74 | 75 | **GSP (Grid Supply Point)** - Regional PV generation data\ 76 | The University of Sheffield provides API access to download this data: 77 | https://www.solar.sheffield.ac.uk/api/ 78 | 79 | Documentation for querying generation data aggregated by GSP region can be found 80 | here: 81 | https://docs.google.com/document/d/e/2PACX-1vSDFb-6dJ2kIFZnsl-pBQvcH4inNQCA4lYL9cwo80bEHQeTK8fONLOgDf6Wm4ze_fxonqK3EVBVoAIz/pub#h.9d97iox3wzmd 82 | 83 | **NWP (Numerical weather predictions)**\ 84 | OCF maintains a Zarr formatted version of the German Weather Service's (DWD) 85 | ICON-EU NWP model here: 86 | https://huggingface.co/datasets/openclimatefix/dwd-icon-eu which includes the UK 87 | 88 | **PV**\ 89 | OCF maintains a dataset of PV generation from 1311 private PV installations 90 | here: https://huggingface.co/datasets/openclimatefix/uk_pv 91 | 92 | 93 | ### Connecting with ocf-data-sampler for sample creation 94 | 95 | Outside the PVNet repo, exit the conda env created for PVNet, then clone the ocf-data-sampler repo (https://github.com/openclimatefix/ocf-data-sampler) and create a separate environment for it: 96 | ```bash 97 | git clone git@github.com:openclimatefix/ocf-data-sampler.git 98 | conda create -n ocf-data-sampler python=3.11 99 | ``` 100 | 101 | Then activate this environment (`conda activate ocf-data-sampler`), go inside the ocf-data-sampler repo and install it: 102 | 103 | ```bash 104 | pip install . 105 | ``` 106 | 107 | Then exit this environment, re-enter the PVNet conda environment and install ocf-data-sampler in editable mode (`-e`), pointing pip at your local clone. This means the package is directly linked to the source code in the ocf-data-sampler repo. 108 | 109 | ```bash 110 | pip install -e /path/to/ocf-data-sampler 111 | ``` 112 | 113 | If you install a local version of `ocf-data-sampler` that is more recent than the version 114 | specified in `PVNet`, it is not guaranteed to function properly with this library.
115 | 116 | 117 | ### Set up and config example for streaming 118 | 119 | We will use the following example config file to describe your data sources: `/PVNet/configs/datamodule/configuration/example_configuration.yaml`. Ensure that the file paths are set to the correct locations in `example_configuration.yaml`: search for `PLACEHOLDER` to find where to input the location of the files. Delete or comment the parts for data you are not using. 120 | 121 | At run time, the datamodule config `PVNet/configs/datamodule/streamed_samples.yaml` points to your chosen configuration file: 122 | 123 | configuration: "/FULL-PATH-TO-REPO/PVNet/configs/datamodule/configuration/example_configuration.yaml" 124 | 125 | You can also update train/val/test time ranges here to match the period you have access to. 126 | 127 | If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication): 128 | 129 | gcloud auth login 130 | 131 | You can provide multiple storage locations as a list. For example: 132 | 133 | satellite: 134 | zarr_path: 135 | - "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_nonhrv.zarr" 136 | - "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v4/2021_nonhrv.zarr" 137 | 138 | `ocf-data-sampler` is currently set up to use 11 channels from the satellite data (the 12th, HRV, is not used). 139 | 140 | ⚠️ NB: Our publicly accessible satellite data is currently saved with a blosc2 compressor, which is not supported by the tensorstore backend PVNet relies on now. We are in the process of updating this; for now, the paths above cannot be used with this codebase. 141 | 142 | ### Training PVNet 143 | 144 | How PVNet is run is determined by the configuration files. The example configs in `PVNet/configs.example` work with **streamed_samples** using `datamodule/streamed_samples.yaml`. 145 | 146 | Update the following before training: 147 | 148 | 1. 
In `configs/model/late_fusion.yaml`: 149 | - Update the list of encoders to match the data sources you are using. For different NWP sources, keep the same structure but ensure: 150 | - `in_channels`: the number of variables your NWP source supplies 151 | - `image_size_pixels`: spatial crop matching your NWP resolution and the settings in your datamodule configuration (unless you coarsened, e.g. for ECMWF) 152 | 2. In `configs/trainer/default.yaml`: 153 | - Set `accelerator: cpu` if running on a system without a supported GPU 154 | 3. In `configs/datamodule/streamed_samples.yaml`: 155 | - Point `configuration:` to your local `example_configuration.yaml` (or your custom one) 156 | - Adjust the train/val/test time ranges to your available data 157 | 158 | If you create custom config files, update the main `./configs/config.yaml` defaults: 159 | 160 | defaults: 161 | - trainer: default.yaml 162 | - model: late_fusion.yaml 163 | - datamodule: streamed_samples.yaml 164 | - callbacks: null 165 | - experiment: null 166 | - hparams_search: null 167 | - hydra: default.yaml 168 | 169 | Now train PVNet: 170 | 171 | python run.py 172 | 173 | You can override any setting with Hydra, e.g.: 174 | 175 | python run.py datamodule=streamed_samples datamodule.configuration="/FULL-PATH/PVNet/configs/datamodule/configuration/example_configuration.yaml" 176 | 177 | ## Backtest 178 | 179 | If you have successfully trained a PVNet model and have a saved model checkpoint, you can use it to create a backtest, i.e. forecasts on historical data used to evaluate forecast accuracy/skill. This can be done by running one of the scripts in this repo, such as [the UK GSP backtest script](scripts/backtest_uk_gsp.py) or [the PV site backtest script](scripts/backtest_sites.py); further info on how to run these is in each backtest file.
180 | 181 | ## Testing 182 | 183 | You can use `python -m pytest tests` to run the tests. 184 | 185 | ## Contributors ✨ 186 | 187 | Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 |
Felix
Felix

💻
Sukhil Patel
Sukhil Patel

💻
James Fulton
James Fulton

💻
Alexandra Udaltsova
Alexandra Udaltsova

💻 👀
Megawattz
Megawattz

💻
Peter Dudfield
Peter Dudfield

💻
Mahdi Lamb
Mahdi Lamb

🚇
Jacob Prince-Bieker
Jacob Prince-Bieker

💻
codderrrrr
codderrrrr

💻
Chris Briggs
Chris Briggs

💻
tmi
tmi

💻
Chris Arderne
Chris Arderne

💻
Dakshbir
Dakshbir

💻
MAYANK SHARMA
MAYANK SHARMA

💻
aryan lamba
aryan lamba

💻
michael-gendy
michael-gendy

💻
Aditya Suthar
Aditya Suthar

💻
Markus Kreft
Markus Kreft

💻
Jack Kelly
Jack Kelly

🤔
zaryab-ali
zaryab-ali

💻
Lex-Ashu
Lex-Ashu

💻
223 | 224 | 225 | 226 | 227 | 228 | 229 | This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! 230 | --------------------------------------------------------------------------------