├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ └── conv_block.py │ ├── .DS_Store │ ├── conv1d_module.py │ └── rcnn_module.py ├── datamodules │ ├── __init__.py │ ├── components │ │ └── __init__.py │ ├── .DS_Store │ └── weather_datamodule.py ├── .DS_Store ├── vendor │ └── __init__.py ├── utils │ ├── training_utils.py │ ├── data_utils.py │ ├── metrics.py │ ├── plotting.py │ └── __init__.py ├── testing_pipeline.py └── training_pipeline.py ├── configs ├── local │ └── .gitkeep ├── callbacks │ ├── none.yaml │ └── default.yaml ├── .DS_Store ├── trainer │ ├── ddp.yaml │ └── default.yaml ├── debug │ ├── test_only.yaml │ ├── overfit.yaml │ ├── step.yaml │ ├── profiler.yaml │ ├── limit_batches.yaml │ └── default.yaml ├── logger │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── comet.yaml │ ├── tensorboard.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ └── wandb.yaml ├── log_dir │ ├── debug.yaml │ ├── default.yaml │ └── evaluation.yaml ├── model │ ├── conv1d.yaml │ ├── vit.yaml │ └── rcnn.yaml ├── datamodule │ └── weather.yaml ├── experiment │ └── example.yaml ├── hparams_search │ └── optuna.yaml ├── train.yaml └── test.yaml ├── .DS_Store ├── tests └── .DS_Store ├── .gitignore ├── scripts └── schedule.sh ├── train.py ├── Dockerfile ├── requirements.txt ├── test.py ├── README.md ├── test_ensemble.py └── preprocess.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datamodules/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VGrabar/Weather-Prediction-NN/HEAD/.DS_Store -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VGrabar/Weather-Prediction-NN/HEAD/src/.DS_Store -------------------------------------------------------------------------------- /tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VGrabar/Weather-Prediction-NN/HEAD/tests/.DS_Store -------------------------------------------------------------------------------- /configs/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VGrabar/Weather-Prediction-NN/HEAD/configs/.DS_Store -------------------------------------------------------------------------------- /src/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VGrabar/Weather-Prediction-NN/HEAD/src/models/.DS_Store -------------------------------------------------------------------------------- /src/datamodules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VGrabar/Weather-Prediction-NN/HEAD/src/datamodules/.DS_Store -------------------------------------------------------------------------------- /src/vendor/__init__.py: -------------------------------------------------------------------------------- 1 | # use this folder for storing third party code that cannot be installed using pip/conda 2 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | gpus: 4 5 | strategy: ddp 6 | sync_batchnorm: True 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **.csv 2 | **.sh 3 | **.ipynb 4 | **.npy 5 | **.png 6 | **.DS_Store 7 | *.pyc 8 | .ipynb_checkpoints/ 9 | data/ 10 | logs/ 11 | -------------------------------------------------------------------------------- /configs/debug/test_only.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs only test epoch 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | train: False 9 | test: True 10 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 
6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | -------------------------------------------------------------------------------- /configs/debug/step.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Shedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python train.py trainer.max_epochs=5 6 | 7 | python train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /configs/log_dir/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: logs/debugs/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: logs/debugs/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/log_dir/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: logs/experiments/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: logs/experiments/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /configs/log_dir/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: logs/evaluations/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: logs/evaluations/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is 
loaded from environment variable 6 | project_name: "binary_classifier_weather_convlstm" 7 | experiment_name: ${name} 8 | -------------------------------------------------------------------------------- /configs/model/conv1d.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.conv1d_module.Conv1dModule 2 | 3 | n_cells_hor: ${n_cells_hor} 4 | n_cells_ver: ${n_cells_ver} 5 | history_length: ${history_length} 6 | periods_forward: ${periods_forward} 7 | batch_size: ${batch_size} 8 | lr: 0.003 9 | weight_decay: 0 10 | 11 | -------------------------------------------------------------------------------- /configs/debug/limit_batches.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: 1 4 | 5 | min_epochs: 1 6 | max_epochs: ${num_epochs} 7 | 8 | # number of validation steps to execute at the beginning of the training 9 | # num_sanity_val_steps: 0 10 | 11 | # ckpt path 12 | resume_from_checkpoint: null 13 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: null 7 | version: ${name} 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /configs/model/vit.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.vit_module.ViTModule 2 | 3 | embed_dim: 16 4 | hidden_dim: 32 5 | num_channels: 1 6 | num_heads: 1 7 | num_layers: 1 8 | patch_size: 10 9 | num_patches: 5 10 | dropout: 0.3 11 | n_cells_hor: 200 12 | n_cells_ver: 250 13 | batch_size: 1 14 | lr: 0.003 15 | weight_decay: 0 16 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | experiment_name: ${name} 6 | tracking_uri: ${original_work_dir}/logs/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 7 | tags: null 8 | prefix: "" 9 | artifact_location: null 10 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project_name: your_name/template-tests 7 | close_after_fit: True 8 | offline_mode: False 9 | experiment_name: ${name} 10 | experiment_id: null 11 | prefix: "" 12 | 
-------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: "template-tests" 6 | # name: ${name} 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # pass correct id to resume experiment! 10 | # entity: "" # set to name of your wandb team 11 | log_model: False 12 | prefix: "" 13 | job_type: "train" 14 | group: "" 15 | tags: [] 16 | -------------------------------------------------------------------------------- /configs/model/rcnn.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.rcnn_module.RCNNModule 2 | 3 | mode: ${mode} # "classification" or "regression" 4 | embedding_size: 16 5 | hidden_state_size: 32 6 | kernel_size: 3 7 | groups: 1 8 | dilation: 1 9 | n_cells_hor: 100 10 | n_cells_ver: 100 11 | history_length: ${history_length} 12 | periods_forward: ${periods_forward} 13 | batch_size: ${batch_size} 14 | num_of_additional_features: ${num_of_additional_features} 15 | num_classes: ${num_classes} 16 | boundaries: ${boundaries} 17 | values_range: 10 18 | dropout: 0.0 19 | lr: 0.05 20 | weight_decay: 0 21 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | from omegaconf import DictConfig 4 | 5 | # load environment variables from `.env` file if it exists 6 | # recursively searches for `.env` in all folders starting from work dir 7 | dotenv.load_dotenv(override=True) 8 | 9 | 10 | @hydra.main(config_path="configs/", config_name="train.yaml") 11 | def main(config: DictConfig): 12 | 13 | # Imports can be nested inside @hydra.main to optimize tab completion 14 | # https://github.com/facebookresearch/hydra/issues/934 15 | from src import utils 16 | from src.training_pipeline import train 17 | 18 | # Applies optional utilities 19 | utils.extras(config) 20 | 21 | # Train model 22 | return train(config) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /configs/datamodule/weather.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.weather_datamodule.WeatherDataModule 2 | 3 | mode: ${mode} 4 | data_dir: ${data_dir} # data_dir is specified in config.yaml 5 | dataset_name: ${dataset_name} 6 | left_border: 0 7 | down_border: 0 8 | right_border: 350 9 | up_border: 350 10 | time_col: "date" 11 | event_col: "val" 12 | x_col: "x" 13 | y_col: "y" 14 | train_val_test_split: [0.7, 0.3, 0.3] 15 | periods_forward: ${periods_forward} 16 | history_length: ${history_length} 17 | data_start: 0 18 | data_len: 1000 19 | feature_to_predict: ${feature_to_predict} 20 | num_of_additional_features: ${num_of_additional_features} 21 | additional_features: ${additional_features} 22 | boundaries: ${boundaries} 23 | patch_size: 8 24 | normalize: True 25 | batch_size: ${batch_size} 26 | num_workers: 0 27 | pin_memory: False 28 | -------------------------------------------------------------------------------- /src/utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_mask(lengths, size): 5 | 
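# builds a boolean mask of shape (len(lengths), size): row i is True at the first lengths[i] positions, e.g. to mark valid (non-padded) timesteps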
mask = torch.arange(size).unsqueeze(0).repeat(len(lengths), 1) 6 | mask = mask < lengths.unsqueeze(1) 7 | return mask 8 | 9 | 10 | def smape(pred, targ): 11 | return torch.mean(2 * torch.abs(pred - targ) / (torch.abs(pred) + torch.abs(targ))) 12 | 13 | 14 | def normalize(seq_x, seq_y): 15 | with torch.no_grad(): 16 | norm_consts = torch.max(seq_x[:, 0, :], dim=1).values.unsqueeze(1) 17 | seq_x[:, 0, :] /= norm_consts 18 | seq_y[:, :, 0] /= norm_consts 19 | return seq_x, seq_y, norm_consts 20 | 21 | 22 | def denormalize(seq_x, seq_y, out, norm_consts): 23 | with torch.no_grad(): 24 | seq_x[:, 0, :] *= norm_consts 25 | seq_y *= norm_consts 26 | out *= norm_consts 27 | return seq_x, seq_y, out 28 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Pull base image 2 | FROM ubuntu:20.04 3 | 4 | # Set environment variables 5 | ENV PYTHONDONTWRITEBYTECODE 1 6 | ENV PYTHONUNBUFFERED 1 7 | 8 | ENV TZ=Europe/Moscow 9 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 10 | 11 | # Set work directory 12 | WORKDIR /app 13 | 14 | # Install dependencies 15 | COPY requirements.txt /app 16 | RUN apt-get update && apt-get install -y python3 python3-pip libgdal-dev git vim 17 | RUN pip install -r requirements.txt 18 | RUN pip install --force-reinstall torch==1.10.1+cu113 --extra-index-url https://download.pytorch.org/whl/ 19 | # Install GDAL (for preprocessing) 20 | ARG CPLUS_INCLUDE_PATH=/usr/include/gdal 21 | ARG C_INCLUDE_PATH=/usr/include/gdal 22 | RUN pip3 install gdal==$(gdal-config --version) 23 | # Create folders 24 | RUN mkdir -p data/raw data/preprocessed data/celled 25 | # Copy project 26 | COPY . /app/ 27 | -------------------------------------------------------------------------------- /configs/experiment/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /datamodule: mnist.yaml 8 | - override /model: mnist.yaml 9 | - override /callbacks: default.yaml 10 | - override /logger: null 11 | - override /trainer: default.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | name: "simple_dense_net" 18 | 19 | seed: 12345 20 | 21 | trainer: 22 | min_epochs: 10 23 | max_epochs: 10 24 | gradient_clip_val: 0.5 25 | 26 | model: 27 | lr: 0.002 28 | net: 29 | lin1_size: 128 30 | lin2_size: 256 31 | lin3_size: 64 32 | 33 | datamodule: 34 | batch_size: 64 35 | 36 | logger: 37 | wandb: 38 | tags: ["mnist", "${name}"] 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch==1.10.1 3 | pytorch-lightning==1.5.1 4 | torchmetrics>=0.7.0 5 | 6 | # --------- hydra --------- # 7 | hydra>=2.5 8 | hydra-core>=1.3.0 9 | hydra-colorlog>=1.2.0 10 | hydra-optuna-sweeper>=1.1.0 11 | 12 | # --------- loggers --------- # 13 | # wandb 14 | # tensorboard 15 | comet-ml>=3.33 16 | 17 | # --------- linters --------- # 18 | pre-commit # hooks for applying linters on commit 19 | black # code formatting 20 | isort # import sorting 21 | flake8 # code 
analysis 22 | nbstripout # remove output from jupyter notebooks 23 | 24 | # --------- others --------- # 25 | python-dotenv # loading env variables from .env file 26 | rich # beautiful text formatting in terminal 27 | pytest # tests 28 | sh # for running bash commands in some tests 29 | pudb # debugger 30 | seaborn>=0.10.1 # plotting utils 31 | tqdm 32 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | defaults: 7 | - override /log_dir: debug.yaml 8 | 9 | trainer: 10 | max_epochs: 1 11 | gpus: 0 # debuggers don't like gpus 12 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 13 | track_grad_norm: 2 # track gradient norm with loggers 14 | 15 | datamodule: 16 | num_workers: 0 # debuggers don't like multiprocessing 17 | pin_memory: False # disable gpu memory pin 18 | 19 | # sets level of all command line loggers to 'DEBUG' 20 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 21 | hydra: 22 | verbose: True 23 | 24 | # use this to set level of only chosen command line loggers to 'DEBUG': 25 | # verbose: [src.train, src.utils] 26 | 27 | # config is already printed by hydra when `hydra/verbose: True` 28 | print_config: False 29 | -------------------------------------------------------------------------------- /src/models/components/conv_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ConvBlock(nn.Module): 5 | def __init__( 6 | self, 7 | in_channels, 8 | out_channels, 9 | kernel_size, 10 | stride=1, 11 | padding=1, 12 | dilation=1, 13 | groups=1, 14 | ): 15 | super(ConvBlock, self).__init__() 16 | 17 | self.CONV = nn.Conv2d( 18 | in_channels, 19 | out_channels, 20 | kernel_size=kernel_size, 21 | stride=stride, 22 | padding=padding, 23 | dilation=dilation, 24 | groups=groups, 25 | bias=False, 26 | ) 27 | 28 | self.BNORM = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=False) 29 | 30 | self.MAXPOOL = nn.MaxPool2d(3, stride=1, padding=1, dilation=1) 31 | 32 | def forward(self, x): 33 | x = self.CONV(x) 34 | x = self.BNORM(x) 35 | x = self.MAXPOOL(x) 36 | 37 | return x 38 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/rocauc_median" # name of the logged metric which determines when model is improving 4 | mode: "max" # "max" means higher metric value is better, can also be "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionally always save model from last epoch 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "epoch_{epoch:03d}" 10 | auto_insert_metric_name: False 11 | 12 | early_stopping: 13 | _target_: pytorch_lightning.callbacks.EarlyStopping 14 | monitor: "val/rocauc_median" # name of the logged metric which determines when model is improving 15 | mode: "max" # "max" means higher metric value is better, can also be "min" 16 | patience: 50 # how many validation epochs of not improving until training stops 17 | min_delta: 0.02 # minimum change in the monitored
metric needed to qualify as an improvement 18 | 19 | model_summary: 20 | _target_: pytorch_lightning.callbacks.RichModelSummary 21 | max_depth: -1 22 | 23 | rich_progress_bar: 24 | _target_: pytorch_lightning.callbacks.RichProgressBar 25 | -------------------------------------------------------------------------------- /src/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pathlib 6 | import torch 7 | 8 | import tqdm 9 | 10 | 11 | def create_celled_data( 12 | data_path, 13 | dataset_name, 14 | time_col: str = "time", 15 | event_col: str = "val", 16 | x_col: str = "x", 17 | y_col: str = "y", 18 | ): 19 | data_path = pathlib.Path( 20 | data_path, 21 | dataset_name, 22 | ) 23 | 24 | df = pd.read_csv(data_path) 25 | df.sort_values(by=[time_col], inplace=True) 26 | df = df[[event_col, x_col, y_col, time_col]] 27 | 28 | indicies = range(df.shape[0]) 29 | start_date = int(df[time_col][indicies[0]]) 30 | finish_date = int(df[time_col][indicies[-1]]) 31 | n_cells_hor = df[x_col].max() - df[x_col].min() + 1 32 | n_cells_ver = df[y_col].max() - df[y_col].min() + 1 33 | celled_data = torch.zeros([finish_date - start_date + 1, n_cells_hor, n_cells_ver]) 34 | 35 | for i in tqdm.tqdm(indicies): 36 | x = int(df[x_col][i]) 37 | y = int(df[y_col][i]) 38 | celled_data[int(df[time_col][i]) - start_date, x, y] = df[event_col][i] 39 | 40 | return celled_data 41 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | import glob 4 | from pathlib import Path 5 | import torch 6 | import os 7 | from omegaconf import DictConfig 8 | 9 | # load environment variables from `.env` file if it exists 10 | # recursively searches for `.env` in all folders starting from work dir 11 | dotenv.load_dotenv(override=True) 12 | from src import utils 13 | 14 | log = utils.get_logger(__name__) 15 | 16 | @hydra.main(config_path="configs/", config_name="test.yaml") 17 | def main(config: DictConfig): 18 | 19 | # Imports can be nested inside @hydra.main to optimize tab completion 20 | # https://github.com/facebookresearch/hydra/issues/934 21 | from src import utils 22 | from src.utils import metrics 23 | from src.testing_pipeline import test 24 | 25 | # Applies optional utilities 26 | utils.extras(config) 27 | # Evaluate model 28 | chkpts = [] 29 | os.chdir("/Weather-Prediction-NN") 30 | path = config.ckpt_folder 31 | print(path) 32 | for ck in Path(path).rglob("*.ckpt"): 33 | if not "last" in str(ck): 34 | chkpts.append(ck) 35 | 36 | print(chkpts) 37 | config.ckpt_path = chkpts[-1] 38 | 39 | return test(config) 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | 45 | main() 46 | -------------------------------------------------------------------------------- /configs/hparams_search/optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
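# (the checkpoint callback in configs/callbacks/default.yaml monitors "val/rocauc_median", which could be optimized here instead of the regression metric below)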
11 | optimized_metric: "val/R2" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | sweeper: 18 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 19 | 20 | # storage URL to persist optimization results 21 | # for example, you can use SQLite if you set 'sqlite:///example.db' 22 | storage: null 23 | 24 | # name of the study to persist optimization results 25 | study_name: null 26 | 27 | # number of parallel workers 28 | n_jobs: 1 29 | 30 | # 'minimize' or 'maximize' the objective 31 | direction: maximize 32 | 33 | # total number of runs that will be executed 34 | n_trials: 25 35 | 36 | # choose Optuna hyperparameter sampler 37 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 38 | sampler: 39 | _target_: optuna.samplers.TPESampler 40 | seed: 23 41 | n_startup_trials: 10 # number of random sampling runs before optimization starts 42 | 43 | # define range of hyperparameters 44 | search_space: 45 | #datamodule.batch_size: 46 | # type: categorical 47 | # choices: [32, 64, 128] 48 | model.lr: 49 | type: float 50 | low: 0.0001 51 | high: 0.2 52 | #model.embedding_size: 53 | # type: categorical 54 | # choices: [16, 32, 64, 128, 256] 55 | #model.hidden_state_size: 56 | # type: categorical 57 | # choices: [16, 32, 64, 128, 256] 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvLSTM model for Weather Forecasting 2 | 3 | PyTorch Lightning implementation of a drought forecasting (classification) model (Convolutional LSTM). Classification is based on the [PDSI index](https://en.wikipedia.org/wiki/Palmer_drought_index) and its corresponding bins. 4 | 5 | 6 | 7 | We solve a binary classification problem, where the threshold for a drought can be adjusted in the config file. 8 | 9 | ## Docker container launch 10 | 11 | First, build an image 12 | 13 | ``` 14 | docker build . -t=<image_name> 15 | ``` 16 | Then run a container with the required parameters 17 | 18 | ``` 19 | docker run --mount type=bind,source=/local_path/Droughts/,destination=/Droughts/ -p <host_port>:<container_port> --memory=64g --cpuset-cpus="0-7" --gpus '"device=0"' -it --rm --name=<container_name> <image_name> 20 | ``` 21 | 22 | ## Preprocessing ## 23 | 24 | Input is geospatial monthly data, downloaded as .tif from public sources (e.g. from Google Earth Engine) and put into the "data/raw" folder. The naming convention is "region_feature.tif". Please run 25 | 26 | ``` 27 | python3 preprocess.py --region region_name --band feature_name --endyear last_year_of_data --endmonth last_month_of_data 28 | ``` 29 | 30 | Results (both as .csv and .npy files) can be found in the "data/preprocessed" folder. 31 | 32 | ## Training ## 33 | 34 | To train the model, first change the configs of the datamodule and network (if necessary) and edit the necessary parameters (e.g. the data path in train.yaml), and then run 35 | ``` 36 | python3 train.py --config-name=train.yaml 37 | ```
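Since configuration is handled by Hydra, any of these parameters can also be overridden directly from the command line; the group and key names below follow the files under configs/, and the values are only illustrative: ``` python3 train.py model=conv1d logger=tensorboard trainer.max_epochs=50 batch_size=16 ```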
38 | 39 | Experiment results can be tracked via Comet ML (please add your token to the logger config file or export it as an environment variable) 40 | 41 | ## Inference ## 42 | 43 | To run the model on the test dataset, calculate metrics and save predictions, first change the configs of the datamodule and network (if necessary) and add the path to the model checkpoint, and then run 44 | ``` 45 | python3 test.py --config-name=test.yaml 46 | ``` 47 | -------------------------------------------------------------------------------- /test_ensemble.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | import glob 4 | from pathlib import Path 5 | import torch 6 | import os 7 | from omegaconf import DictConfig 8 | 9 | # load environment variables from `.env` file if it exists 10 | # recursively searches for `.env` in all folders starting from work dir 11 | dotenv.load_dotenv(override=True) 12 | from src import utils 13 | 14 | log = utils.get_logger(__name__) 15 | 16 | @hydra.main(config_path="configs/", config_name="test.yaml") 17 | def main(config: DictConfig): 18 | 19 | # Imports can be nested inside @hydra.main to optimize tab completion 20 | # https://github.com/facebookresearch/hydra/issues/934 21 | from src import utils 22 | from src.utils import metrics 23 | from src.testing_pipeline import test 24 | 25 | # Applies optional utilities 26 | utils.extras(config) 27 | all_preds = [] 28 | os.chdir("/Weather-Prediction-NN") 29 | # Evaluate model 30 | chkpts = [] 31 | path = config.ckpt_folder 32 | for ck in Path(path).rglob("*.ckpt"): 33 | if not "last" in str(ck): 34 | chkpts.append(ck) 35 | for c in chkpts: 36 | config.ckpt_path = c 37 | preds, all_targets = test(config) 38 | all_preds.append(preds) 39 | 40 | all_preds = torch.stack(all_preds) 41 | all_preds = torch.mean(all_preds, dim=0) 42 | rocauc_table, ap_table, f1_table, acc_table, thresholds = metrics.metrics_celled(all_targets, all_preds, mode="test") 43 | res_rocauc = torch.median(rocauc_table) 44 | res_ap = torch.median(ap_table) 45 | res_f1 = torch.median(f1_table) 46 | log.info(f"test_ensemble_median_rocauc: {res_rocauc}") 47 | log.info(f"test_ensemble_median_ap: {res_ap}") 48 | log.info(f"test_ensemble_median_f1: {res_f1}") 49 | with open("ens.txt", "a") as f: 50 | f.write(config.ckpt_folder + "\n") 51 | f.write("median_rocauc: " + str(res_rocauc) + "\n") 52 | f.write("\n") 53 | f.write("median_ap: " + str(res_ap) + "\n") 54 | f.write("\n") 55 | f.write("median_f1: " + str(res_f1) + "\n") 56 | f.write("\n") 57 | return 58 | 59 | 60 | if __name__ == "__main__": 61 | 62 | 63 | main() 64 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchmetrics.classification import ( 4 | AUROC, 5 | AveragePrecision, 6 | ROC, 7 | ) 8 | from torcheval.metrics.functional import binary_f1_score, binary_accuracy 9 | 10 | 11 | def metrics_celled(all_targets, all_preds, mode: str = "train"): 12 | rocauc_table = torch.zeros(all_preds.shape[1], all_preds.shape[2]) 13 | rocauc = AUROC(task="binary", num_classes=1) 14 | rocauc_table = torch.tensor( 15 | [ 16 | [ 17 | rocauc(all_preds[:, x, y], all_targets[:, x, y]) 18 | for x in range(all_preds.shape[1]) 19 | ] 20 | for y in range(all_preds.shape[2]) 21 | ] 22 | ) 23 | rocauc_table = torch.nan_to_num(rocauc_table, nan=0.0)
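# the per-cell average precision, accuracy, F1 and threshold tables below are only filled in "test" mode; each cell's threshold is chosen by maximizing Youden's J statistic (TPR - FPR) on that cell's ROC curve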
24 | 25 | ap_table = torch.zeros(all_preds.shape[1], all_preds.shape[2]) 26 | acc_table = torch.zeros(all_preds.shape[1], all_preds.shape[2]) 27 | f1_table = torch.zeros(all_preds.shape[1], all_preds.shape[2]) 28 | thresholds = torch.zeros(all_preds.shape[1], all_preds.shape[2]) 29 | 30 | if mode == "test": 31 | ap = AveragePrecision(task="binary") 32 | roc = ROC(task="binary") 33 | for x in range(all_preds.shape[1]): 34 | for y in range(all_preds.shape[2]): 35 | ap_table[x][y] = ap(all_preds[:, x, y], all_targets[:, x, y]) 36 | fpr, tpr, thr = roc(all_preds[:, x, y], all_targets[:, x, y]) 37 | j_stat = tpr - fpr 38 | ind = torch.argmax(j_stat).item() 39 | thresholds[x][y] = thr[ind] 40 | acc_table[x][y] = binary_accuracy( 41 | all_preds[:, x, y], all_targets[:, x, y], threshold=thresholds[x][y] 42 | ) 43 | f1_table[x][y] = binary_f1_score( 44 | all_preds[:, x, y], all_targets[:, x, y], threshold=thresholds[x][y] 45 | ) 46 | 47 | ap_table = torch.nan_to_num(ap_table, nan=0.0) 48 | f1_table = torch.nan_to_num(f1_table, nan=0.0) 49 | acc_table = torch.nan_to_num(acc_table, nan=0.0) 50 | 51 | return rocauc_table, ap_table, f1_table, acc_table, thresholds 52 | -------------------------------------------------------------------------------- /src/testing_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import hydra 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything 7 | from pytorch_lightning.loggers import LightningLoggerBase 8 | 9 | from src import utils 10 | 11 | log = utils.get_logger(__name__) 12 | 13 | 14 | def test(config: DictConfig) -> None: 15 | """Contains minimal example of the testing pipeline. Evaluates given checkpoint on a testset. 16 | 17 | Args: 18 | config (DictConfig): Configuration composed by Hydra. 
19 | 20 | Returns: 21 | None 22 | """ 23 | 24 | # Set seed for random number generators in pytorch, numpy and python.random 25 | if config.get("seed"): 26 | seed_everything(config.seed, workers=True) 27 | 28 | # Convert relative ckpt path to absolute path if necessary 29 | if not os.path.isabs(config.ckpt_path): 30 | config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path) 31 | 32 | # Init lightning datamodule 33 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 34 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 35 | datamodule.setup() 36 | config.model.n_cells_hor = datamodule.h 37 | config.model.n_cells_ver = datamodule.w 38 | 39 | # Init lightning model 40 | log.info(f"Instantiating model <{config.model._target_}>") 41 | model: LightningModule = hydra.utils.instantiate(config.model) 42 | model.global_avg = datamodule.data_test.global_avg 43 | 44 | # Init lightning loggers 45 | logger: List[LightningLoggerBase] = [] 46 | if "logger" in config: 47 | for _, lg_conf in config.logger.items(): 48 | if "_target_" in lg_conf: 49 | log.info(f"Instantiating logger <{lg_conf._target_}>") 50 | logger.append(hydra.utils.instantiate(lg_conf)) 51 | 52 | # Init lightning trainer 53 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 54 | trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger) 55 | 56 | # Log hyperparameters 57 | if trainer.logger: 58 | trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path}) 59 | 60 | log.info("Starting testing!") 61 | trainer.test(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path) 62 | 63 | return 64 | #return model.saved_predictions, model.saved_targets 65 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - datamodule: weather.yaml 7 | - model: rcnn.yaml 8 | - callbacks: default.yaml 9 | - logger: comet.yaml # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 10 | - trainer: default.yaml 11 | - log_dir: default.yaml 12 | 13 | # experiment configs allow for version control of specific configurations 14 | # e.g. best hyperparameters for each combination of model and datamodule 15 | - experiment: null 16 | 17 | # debugging config (enable through command line, e.g. 
`python train.py debug=default) 18 | - debug: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null # optuna.yaml 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # enable color logging 28 | - override hydra/hydra_logging: colorlog 29 | - override hydra/job_logging: colorlog 30 | 31 | # path to original working directory 32 | # hydra hijacks working directory by changing it to the new log directory 33 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 34 | original_work_dir: ${hydra:runtime.cwd} 35 | 36 | # path to folder with data 37 | data_dir: ${original_work_dir}/data/ 38 | dataset_name: "pdsi_Belarus_with_neighb.csv" 39 | batch_size: 8 40 | num_epochs: 10 41 | 42 | # pretty print config at the start of the run using Rich library 43 | print_config: True 44 | 45 | # disable python warnings if they annoy you 46 | ignore_warnings: True 47 | 48 | # set False to skip model training 49 | train: True 50 | 51 | # evaluate on test set, using best model weights achieved during training 52 | # lightning chooses best weights based on the metric specified in checkpoint callback 53 | test: True 54 | 55 | # seed for random number generators in pytorch, numpy and python.random 56 | seed: null 57 | 58 | # parameters of the model that are shared with dataloader 59 | history_length: 9 60 | periods_forward: 1 61 | mode: "classification" # "classification" or "regression" 62 | num_classes: 2 # for classifier 63 | boundaries: [-2] # for classifier 64 | feature_to_predict: "pdsi" 65 | num_of_additional_features: 0 66 | additional_features: [] #["pr", "pet", "tmmn", "tmmx"] 67 | 68 | # default name for the experiment, determines logging folder path 69 | # (you can overwrite this name in experiment configs) 70 | name: acc_newdimensions_${dataset_name}_${model._target_}_history_${history_length}_forward_${periods_forward} 71 | -------------------------------------------------------------------------------- /src/utils/plotting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import seaborn as sns 6 | import torch 7 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 8 | from matplotlib.figure import Figure 9 | from sklearn.metrics import confusion_matrix 10 | 11 | 12 | def make_heatmap(table, filename="rocauc_spatial.png", size=(8, 6)): 13 | fig = Figure(figsize=size, frameon=True) 14 | canvas = FigureCanvas(fig) 15 | ax = fig.add_subplot(111) 16 | ax = sns.heatmap(table, vmin=0.0, vmax=1.0) 17 | ax.set_xticklabels([]) 18 | ax.set_yticklabels([]) 19 | 20 | full_path = os.path.expanduser(filename) 21 | ax.figure.savefig(full_path) 22 | ax.cla() 23 | 24 | return full_path 25 | 26 | 27 | def make_cf_matrix( 28 | targets, preds, thresholds, filename: str = "cf_matrix.png", size=(8, 6) 29 | ): 30 | targets = torch.flatten(targets).cpu().numpy() 31 | for x in range(preds.shape[1]): 32 | for y in range(preds.shape[2]): 33 | preds[:,x,y] = torch.bucketize(preds[:,x,y], torch.Tensor([thresholds[x][y]]).cuda()) 34 | preds = torch.flatten(preds).cpu().numpy() 35 | 36 | cf_matrix = confusion_matrix(targets, preds) 37 | fig = Figure(figsize=size, frameon=True) 38 | ax = fig.add_subplot(111) 39 | ax = sns.heatmap( 40 | cf_matrix / np.sum(cf_matrix), annot=True, fmt=".2%", cmap="Blues", 
cbar=False 41 | ) 42 | 43 | full_path = os.path.expanduser(filename) 44 | ax.figure.savefig(full_path) 45 | ax.cla() 46 | 47 | return full_path 48 | 49 | 50 | def make_pred_vs_target_plot( 51 | preds, 52 | targets, 53 | title="Comparison", 54 | size=(8, 6), 55 | xlabel=None, 56 | xlabel_rotate=45, 57 | ylabel=None, 58 | ylabel_rotate=0, 59 | filename="forecasts.png", 60 | ): 61 | fig = Figure(figsize=size, frameon=False) 62 | canvas = FigureCanvas(fig) 63 | ax = fig.add_subplot(111) 64 | x_length = targets.shape[1] 65 | y_length = targets.shape[2] 66 | x_random = random.choice(list(range(x_length))) 67 | y_random = random.choice(list(range(y_length))) 68 | targets = targets.cpu() 69 | targets = torch.mean(targets, dim=[1, 2]) 70 | preds = preds.cpu() 71 | preds = torch.mean(preds, dim=[1, 2]) 72 | time_periods = np.arange(0, targets.shape[0]) 73 | ax.plot(time_periods, targets, "g-", label="actual") 74 | ax.plot(time_periods, preds, "b--", label="predictions") 75 | ax.legend() 76 | 77 | if xlabel: 78 | ax.set_xlabel(xlabel) 79 | if ylabel: 80 | ax.set_ylabel(ylabel) 81 | if title: 82 | ax.set_title(title) 83 | 84 | fig.tight_layout() 85 | filename = os.path.expanduser(filename) 86 | fig.savefig(filename) 87 | return fig 88 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from osgeo import gdal 2 | import numpy as np 3 | import pandas as pd 4 | import tqdm 5 | import argparse 6 | import os 7 | 8 | 9 | parser = argparse.ArgumentParser(description="preprocess tif files") 10 | parser.add_argument("--region", type=str, help="name of region") 11 | parser.add_argument( 12 | "--band", type=str, default="pdsi", help="name of variable to process" 13 | ) 14 | parser.add_argument("--endyear", type=int, default=2020, help="last year of data") 15 | parser.add_argument( 16 | "--endmonth", type=int, default=1, help="last month of data, from 1 to 12" 17 | ) 18 | args = parser.parse_args() 19 | region = args.region 20 | feature = args.band 21 | endyear = args.endyear 22 | endmonth = args.endmonth 23 | 24 | save_path = "data/preprocessed/" 25 | print(f"region {region}") 26 | print(f"band {feature}") 27 | 28 | 29 | ds = gdal.Open("data/raw/" + region + "_" + feature + ".tif") 30 | print(f"number of months {ds.RasterCount}") 31 | print(f"x dim {ds.RasterXSize}") 32 | print(f"y dim {ds.RasterYSize}") 33 | 34 | num_of_months = ds.RasterCount 35 | xsize = ds.RasterXSize 36 | ysize = ds.RasterYSize 37 | all_data = np.zeros((num_of_months, 1, ysize, xsize)) 38 | 39 | curr_month = endmonth 40 | curr_year = endyear 41 | total_df = pd.DataFrame(columns=["y", "x", "value", "date"]) 42 | 43 | 44 | for i in tqdm.tqdm(range(num_of_months - 1, 1, -1)): 45 | if curr_month == 0: 46 | curr_month = 12 47 | curr_year -= 1 48 | 49 | curr_date = str(curr_year) + "-" + str(curr_month) 50 | band = ds.GetRasterBand(i) 51 | data = band.ReadAsArray() 52 | # terraclim features need to be normalized 53 | if feature == "pdsi": 54 | data = data / 100 55 | elif feature == "pet" or feature == "tmmn" or feature == "tmmx": 56 | data = data / 10 57 | all_data[i][0] = data 58 | 59 | df_row = ( 60 | pd.DataFrame(data, columns=list(range(xsize))) 61 | .reset_index() 62 | .melt(id_vars="index") 63 | .rename(columns={"index": "y", "variable": "x"}) 64 | ) 65 | df_row["date"] = curr_date 66 | total_df = pd.concat([total_df, df_row]) 67 | 68 | curr_month -= 1 69 | 70 | np.save(save_path + region + "_" + feature + 
".npy", all_data) 71 | total_df.to_csv(save_path + region + "_" + feature + ".csv") 72 | print(f"{region} global stats") 73 | print(f"mean: {np.mean(all_data)}") 74 | print(f"std: {np.std(all_data)}") 75 | num_of_channels = 1 76 | global_means = np.zeros((1, num_of_channels, 1, 1)) 77 | global_stds = np.zeros((1, num_of_channels, 1, 1)) 78 | global_means[0, 0, 0, 0] = np.mean(all_data) 79 | global_stds[0, 0, 0, 0] = np.std(all_data) 80 | np.save(save_path + region + "_" + feature + "_global_means.npy", global_means) 81 | np.save(save_path + region + "_" + feature + "_global_stds.npy", global_stds) 82 | -------------------------------------------------------------------------------- /configs/test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - datamodule: weather.yaml 7 | - model: rcnn.yaml 8 | - callbacks: null 9 | - logger: comet.yaml # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 10 | - trainer: default.yaml 11 | - log_dir: evaluation.yaml 12 | 13 | # experiment configs allow for version control of specific configurations 14 | # e.g. best hyperparameters for each combination of model and datamodule 15 | - experiment: null 16 | 17 | # debugging config (enable through command line, e.g. `python train.py debug=default) 18 | - debug: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null # optuna.yaml 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # enable color logging 28 | - override hydra/hydra_logging: colorlog 29 | - override hydra/job_logging: colorlog 30 | 31 | # path to original working directory 32 | # hydra hijacks working directory by changing it to the new log directory 33 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 34 | original_work_dir: ${hydra:runtime.cwd} 35 | 36 | # path to folder with data 37 | data_dir: ${original_work_dir}/data/ 38 | dataset_name: "pdsi_Belarus_with_neighb.csv" 39 | batch_size: 1 40 | num_epochs: 100 41 | 42 | # pretty print config at the start of the run using Rich library 43 | print_config: True 44 | 45 | # disable python warnings if they annoy you 46 | ignore_warnings: True 47 | 48 | # set False to skip model training 49 | train: True 50 | 51 | # evaluate on test set, using best model weights achieved during training 52 | # lightning chooses best weights based on the metric specified in checkpoint callback 53 | test: True 54 | 55 | # seed for random number generators in pytorch, numpy and python.random 56 | seed: null 57 | 58 | # parameters of the model that are shared with dataloader 59 | n_cells_hor: 66 60 | n_cells_ver: 123 61 | mode: "classification" # "classification" or "regression" 62 | num_classes: 2 # for classifier 63 | boundaries: [-2] # for classifier 64 | feature_to_predict: "pdsi" 65 | num_of_additional_features: 0 66 | additional_features: [] #["pr", "pet", "tmmn", "tmmx"] 67 | 68 | # default name for the experiment, determines logging folder path 69 | # (you can overwrite this name in experiment configs) 70 | history_length: 6 71 | periods_forward: 1 72 | name: GlobalBaseline_${dataset_name}_${model._target_}_history_${history_length}_forward_${periods_forward} 73 | 74 | # checkpoints 75 | ckpt_folder: 
logs/experiments/runs/recalc_newdimensions_${dataset_name}_${model._target_}_history_${history_length}_forward_${periods_forward} 76 | ckpt_path: null 77 | -------------------------------------------------------------------------------- /src/training_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import hydra 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import ( 7 | Callback, 8 | LightningDataModule, 9 | LightningModule, 10 | Trainer, 11 | seed_everything, 12 | ) 13 | from pytorch_lightning.loggers import LightningLoggerBase 14 | 15 | from src import utils 16 | 17 | log = utils.get_logger(__name__) 18 | 19 | 20 | def train(config: DictConfig) -> Optional[float]: 21 | """Contains the training pipeline. Can additionally evaluate model on a testset, using best 22 | weights achieved during training. 23 | 24 | Args: 25 | config (DictConfig): Configuration composed by Hydra. 26 | 27 | Returns: 28 | Optional[float]: Metric score for hyperparameter optimization. 29 | """ 30 | 31 | # Set seed for random number generators in pytorch, numpy and python.random 32 | if config.get("seed"): 33 | seed_everything(config.seed, workers=True) 34 | 35 | # Convert relative ckpt path to absolute path if necessary 36 | ckpt_path = config.trainer.get("resume_from_checkpoint") 37 | if ckpt_path and not os.path.isabs(ckpt_path): 38 | config.trainer.resume_from_checkpoint = os.path.join( 39 | hydra.utils.get_original_cwd(), ckpt_path 40 | ) 41 | 42 | 43 | # Init lightning datamodule 44 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 45 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 46 | datamodule.setup() 47 | config.model.n_cells_hor = datamodule.h 48 | config.model.n_cells_ver = datamodule.w 49 | 50 | # Init lightning model 51 | log.info(f"Instantiating model <{config.model._target_}>") 52 | model: LightningModule = hydra.utils.instantiate(config.model) 53 | 54 | # Init lightning callbacks 55 | callbacks: List[Callback] = [] 56 | if "callbacks" in config: 57 | for _, cb_conf in config.callbacks.items(): 58 | if "_target_" in cb_conf: 59 | log.info(f"Instantiating callback <{cb_conf._target_}>") 60 | callbacks.append(hydra.utils.instantiate(cb_conf)) 61 | 62 | # Init lightning loggers 63 | logger: List[LightningLoggerBase] = [] 64 | if "logger" in config: 65 | for _, lg_conf in config.logger.items(): 66 | if "_target_" in lg_conf: 67 | log.info(f"Instantiating logger <{lg_conf._target_}>") 68 | logger.append(hydra.utils.instantiate(lg_conf)) 69 | 70 | # Init lightning trainer 71 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 72 | trainer: Trainer = hydra.utils.instantiate( 73 | config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" 74 | ) 75 | 76 | # Send some parameters from config to all lightning loggers 77 | log.info("Logging hyperparameters!") 78 | utils.log_hyperparameters( 79 | config=config, 80 | model=model, 81 | datamodule=datamodule, 82 | trainer=trainer, 83 | callbacks=callbacks, 84 | logger=logger, 85 | ) 86 | 87 | # Train the model 88 | if config.get("train"): 89 | log.info("Starting training!") 90 | trainer.fit(model=model, datamodule=datamodule) 91 | 92 | # Get metric score for hyperparameter optimization 93 | optimized_metric = config.get("optimized_metric") 94 | if optimized_metric and optimized_metric not in trainer.callback_metrics: 95 | raise Exception( 96 | "Metric for hyperparameter 
optimization not found! " 97 | "Make sure the `optimized_metric` in `hparams_search` config is correct!" 98 | ) 99 | score = trainer.callback_metrics.get(optimized_metric) 100 | 101 | # Test the model 102 | if config.get("test"): 103 | ckpt_path = "best" 104 | if not config.get("train") or config.trainer.get("fast_dev_run"): 105 | ckpt_path = None 106 | log.info("Starting testing!") 107 | model.global_avg = datamodule.data_test.global_avg 108 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 109 | 110 | # Make sure everything closed properly 111 | log.info("Finalizing!") 112 | utils.finish( 113 | config=config, 114 | model=model, 115 | datamodule=datamodule, 116 | trainer=trainer, 117 | callbacks=callbacks, 118 | logger=logger, 119 | ) 120 | 121 | # Print path to best checkpoint 122 | if not config.trainer.get("fast_dev_run") and config.get("train"): 123 | log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}") 124 | 125 | # Return metric score for hyperparameter optimization 126 | return score 127 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from typing import List, Sequence 4 | 5 | import pytorch_lightning as pl 6 | import rich.syntax 7 | import rich.tree 8 | from omegaconf import DictConfig, OmegaConf 9 | from pytorch_lightning.utilities import rank_zero_only 10 | 11 | 12 | def get_logger(name=__name__) -> logging.Logger: 13 | """Initializes multi-GPU-friendly python command line logger.""" 14 | 15 | logger = logging.getLogger(name) 16 | 17 | # this ensures all logging levels get marked with the rank zero decorator 18 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 19 | for level in ( 20 | "debug", 21 | "info", 22 | "warning", 23 | "error", 24 | "exception", 25 | "fatal", 26 | "critical", 27 | ): 28 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 29 | 30 | return logger 31 | 32 | 33 | log = get_logger(__name__) 34 | 35 | 36 | def extras(config: DictConfig) -> None: 37 | """Applies optional utilities, controlled by config flags. 38 | 39 | Utilities: 40 | - Ignoring python warnings 41 | - Rich config printing 42 | """ 43 | 44 | # disable python warnings if 45 | if config.get("ignore_warnings"): 46 | log.info("Disabling python warnings! ") 47 | warnings.filterwarnings("ignore") 48 | 49 | # pretty print config tree using Rich library if 50 | if config.get("print_config"): 51 | log.info("Printing config tree with Rich! ") 52 | print_config(config, resolve=True) 53 | 54 | 55 | @rank_zero_only 56 | def print_config( 57 | config: DictConfig, 58 | print_order: Sequence[str] = ( 59 | "datamodule", 60 | "model", 61 | "callbacks", 62 | "logger", 63 | "trainer", 64 | ), 65 | resolve: bool = True, 66 | ) -> None: 67 | """Prints content of DictConfig using Rich library and its tree structure. 68 | 69 | Args: 70 | config (DictConfig): Configuration composed by Hydra. 71 | print_order (Sequence[str], optional): Determines in what order config components are printed. 72 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 
73 | """ 74 | 75 | style = "dim" 76 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 77 | 78 | quee = [] 79 | 80 | for field in print_order: 81 | quee.append(field) if field in config else log.info( 82 | f"Field '{field}' not found in config" 83 | ) 84 | 85 | for field in config: 86 | if field not in quee: 87 | quee.append(field) 88 | 89 | for field in quee: 90 | branch = tree.add(field, style=style, guide_style=style) 91 | 92 | config_group = config[field] 93 | if isinstance(config_group, DictConfig): 94 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 95 | else: 96 | branch_content = str(config_group) 97 | 98 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 99 | 100 | rich.print(tree) 101 | 102 | with open("config_tree.log", "w") as file: 103 | rich.print(tree, file=file) 104 | 105 | 106 | @rank_zero_only 107 | def log_hyperparameters( 108 | config: DictConfig, 109 | model: pl.LightningModule, 110 | datamodule: pl.LightningDataModule, 111 | trainer: pl.Trainer, 112 | callbacks: List[pl.Callback], 113 | logger: List[pl.loggers.LightningLoggerBase], 114 | ) -> None: 115 | """Controls which config parts are saved by Lightning loggers. 116 | 117 | Additionaly saves: 118 | - number of model parameters 119 | """ 120 | 121 | if not trainer.logger: 122 | return 123 | 124 | hparams = {} 125 | 126 | # choose which parts of hydra config will be saved to loggers 127 | hparams["model"] = config["model"] 128 | 129 | # save number of model parameters 130 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 131 | hparams["model/params/trainable"] = sum( 132 | p.numel() for p in model.parameters() if p.requires_grad 133 | ) 134 | hparams["model/params/non_trainable"] = sum( 135 | p.numel() for p in model.parameters() if not p.requires_grad 136 | ) 137 | 138 | hparams["datamodule"] = config["datamodule"] 139 | hparams["trainer"] = config["trainer"] 140 | 141 | if "seed" in config: 142 | hparams["seed"] = config["seed"] 143 | if "callbacks" in config: 144 | hparams["callbacks"] = config["callbacks"] 145 | 146 | # send hparams to all loggers 147 | trainer.logger.log_hyperparams(hparams) 148 | 149 | 150 | def finish( 151 | config: DictConfig, 152 | model: pl.LightningModule, 153 | datamodule: pl.LightningDataModule, 154 | trainer: pl.Trainer, 155 | callbacks: List[pl.Callback], 156 | logger: List[pl.loggers.LightningLoggerBase], 157 | ) -> None: 158 | """Makes sure everything closed properly.""" 159 | 160 | # without this sweeps with wandb logger might crash! 161 | for lg in logger: 162 | if isinstance(lg, pl.loggers.wandb.WandbLogger): 163 | import wandb 164 | 165 | wandb.finish() 166 | -------------------------------------------------------------------------------- /src/models/conv1d_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning import LightningModule 7 | from torch.autograd import Variable 8 | 9 | from src.models.components.conv_block import ConvBlock 10 | from src.utils.metrics import rmse, rsquared, smape 11 | from sklearn.metrics import r2_score 12 | from src.utils.plotting import make_heatmap, make_pred_vs_target_plot 13 | 14 | import seaborn as sns 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | 19 | class Conv1dModule(LightningModule): 20 | """Example of LightningModule for MNIST classification. 
21 | 22 | A LightningModule organizes your PyTorch code into 5 sections: 23 | - Computations (init). 24 | - Train loop (training_step) 25 | - Validation loop (validation_step) 26 | - Test loop (test_step) 27 | - Optimizers (configure_optimizers) 28 | 29 | Read the docs: 30 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html 31 | """ 32 | 33 | def __init__( 34 | self, 35 | n_cells_hor: int = 200, 36 | n_cells_ver: int = 250, 37 | history_length: int = 1, 38 | periods_forward: int = 1, 39 | batch_size: int = 1, 40 | lr: float = 0.003, 41 | weight_decay: float = 0.0, 42 | ): 43 | super(Conv1dModule, self).__init__() 44 | 45 | # this line allows to access init params with 'self.hparams' attribute 46 | # it also ensures init params will be stored in ckpt 47 | self.save_hyperparameters(logger=False) 48 | 49 | self.n_cells_hor = n_cells_hor 50 | self.n_cells_ver = n_cells_ver 51 | self.history_length = history_length 52 | self.periods_forward = periods_forward 53 | self.batch_size = batch_size 54 | self.lr = lr 55 | self.weight_decay = weight_decay 56 | 57 | 58 | self.conv1x1 = nn.Conv2d( 59 | self.history_length, 60 | self.periods_forward, 61 | kernel_size=1, 62 | stride=1, 63 | padding=0, 64 | bias=False, 65 | ) 66 | 67 | # loss 68 | self.criterion = nn.MSELoss() 69 | 70 | def forward(self, x: torch.Tensor): 71 | 72 | prediction = self.conv1x1(x) 73 | 74 | return prediction 75 | 76 | def step(self, batch: Any): 77 | x, y = batch 78 | preds = self.forward(x) 79 | loss = self.criterion(preds, y) 80 | return loss, preds, y 81 | 82 | def training_step(self, batch: Any, batch_idx: int): 83 | loss, preds, targets = self.step(batch) 84 | 85 | self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True) 86 | 87 | # we can return here dict with any tensors 88 | # and then read it in some callback or in `training_epoch_end()`` below 89 | # remember to always return loss from `training_step()` or else backpropagation will fail! 
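# here `preds` and `targets` are (batch, periods_forward, height, width) maps from the 1x1 convolution, which `training_epoch_end()` below concatenates along the batch dimension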
90 | return {"loss": loss, "preds": preds, "targets": targets} 91 | 92 | def training_epoch_end(self, outputs: List[Any]): 93 | # `outputs` is a list of dicts returned from `training_step()` 94 | all_preds = outputs[0]["preds"] 95 | all_targets = outputs[0]["targets"] 96 | 97 | for i in range(1, len(outputs)): 98 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0) 99 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0) 100 | 101 | # log metrics 102 | r2table = rsquared(all_targets, all_preds, mode="mean") 103 | self.log("train/R2_std", np.std(r2table), on_epoch=True, prog_bar=True) 104 | self.log("train/R2", np.median(r2table), on_epoch=True, prog_bar=True) 105 | self.log("train/R2_min", np.min(r2table), on_epoch=True, prog_bar=True) 106 | self.log("train/R2_max", np.max(r2table), on_epoch=True, prog_bar=True) 107 | self.log("train/MSE", rmse(all_targets, all_preds), on_epoch=True, prog_bar=True) 108 | 109 | def validation_step(self, batch: Any, batch_idx: int): 110 | loss, preds, targets = self.step(batch) 111 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True) 112 | 113 | return {"loss": loss, "preds": preds, "targets": targets} 114 | 115 | def validation_epoch_end(self, outputs: List[Any]): 116 | all_preds = outputs[0]["preds"] 117 | all_targets = outputs[0]["targets"] 118 | 119 | for i in range(1, len(outputs)): 120 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0) 121 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0) 122 | 123 | # log metrics 124 | r2table = rsquared(all_targets, all_preds, mode="mean") 125 | self.log("val/R2_std", np.std(r2table), on_epoch=True, prog_bar=True) 126 | self.log("val/R2", np.median(r2table), on_epoch=True, prog_bar=True) 127 | self.log("val/R2_min", np.min(r2table), on_epoch=True, prog_bar=True) 128 | self.log("val/R2_max", np.max(r2table), on_epoch=True, prog_bar=True) 129 | self.log("val/MSE", rmse(all_targets, all_preds), on_epoch=True, prog_bar=True) 130 | 131 | 132 | def test_step(self, batch: Any, batch_idx: int): 133 | loss, preds, targets = self.step(batch) 134 | self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True) 135 | 136 | return {"loss": loss, "preds": preds, "targets": targets} 137 | 138 | def test_epoch_end(self, outputs: List[Any]): 139 | all_preds = outputs[0]["preds"] 140 | all_targets = outputs[0]["targets"] 141 | 142 | for i in range(1, len(outputs)): 143 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0) 144 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0) 145 | 146 | print(all_preds.shape) 147 | print(all_targets.shape) 148 | 149 | # log metrics 150 | test_r2table = rsquared(all_targets, all_preds, mode="full") 151 | self.log("test/R2_std", np.std(test_r2table), on_epoch=True, prog_bar=True) 152 | self.log( 153 | "test/R2_median", np.median(test_r2table), on_epoch=True, prog_bar=True 154 | ) 155 | self.log("test/R2_min", np.min(test_r2table), on_epoch=True, prog_bar=True) 156 | self.log("test/R2_max", np.max(test_r2table), on_epoch=True, prog_bar=True) 157 | self.log("test/MSE", rmse(all_targets, all_preds), on_epoch=True, prog_bar=True) 158 | 159 | # log graphs 160 | mse_conv = [] 161 | r2_conv = [] 162 | 163 | for i in range(1, self.periods_forward+1): 164 | 165 | preds_i = all_preds[:, :i, :, :] 166 | targets_i = all_targets[:, :i, :, :] 167 | mse_conv.append(rmse(preds_i, targets_i).item()) 168 | mean_preds = torch.mean(preds_i, axis=(2, 3)) 169 | mean_targets = torch.mean(targets_i, axis=(2, 3)) 170 | 
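# collapse each map to its spatial mean, then score the resulting per-horizon series with sklearn's r2_score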
r2_conv.append(r2_score(mean_targets.cpu().numpy(), mean_preds.cpu().numpy())) 171 | 172 | h = [i for i in range(1, self.periods_forward+1)] 173 | fig1 = plt.figure(figsize=(7, 7)) 174 | ax = fig1.add_subplot(1, 1, 1) 175 | sns.lineplot(x=h, y=mse_conv, ax = ax) 176 | ax.legend(['conv1d']) 177 | ax.set_xlabel("horizon (in months)") 178 | ax.set_title("MSE") 179 | fig1.savefig("conv_mse.png") 180 | 181 | fig2 = plt.figure(figsize=(7, 7)) 182 | ax = fig2.add_subplot(1, 1, 1) 183 | sns.lineplot(x=h, y=r2_conv, ax = ax) 184 | ax.legend(['conv1d']) 185 | ax.set_xlabel("horizon (in months)") 186 | ax.set_title("R2") 187 | fig2.savefig("r2_mse.png") 188 | 189 | 190 | def on_epoch_end(self): 191 | pass 192 | 193 | def configure_optimizers(self): 194 | """Choose what optimizers and learning-rate schedulers to use in your optimization. 195 | Normally you'd need one. But in the case of GANs or similar you might have multiple. 196 | 197 | See examples here: 198 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers 199 | """ 200 | return torch.optim.Adam( 201 | params=self.parameters(), 202 | lr=self.hparams.lr, 203 | weight_decay=self.hparams.weight_decay, 204 | ) 205 | -------------------------------------------------------------------------------- /src/datamodules/weather_datamodule.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Optional, Tuple, List 3 | 4 | import torch 5 | from pytorch_lightning import LightningDataModule 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | from src.utils.data_utils import create_celled_data 9 | from src import utils 10 | 11 | log = utils.get_logger(__name__) 12 | 13 | 14 | class Dataset_RNN(Dataset): 15 | """ 16 | Simple Torch Dataset for many-to-many RNN 17 | celled_data: source of data, 18 | start_date: start date index, 19 | end_date: end date index, 20 | periods_forward: number of future periods for a target, 21 | history_length: number of past periods for an input, 22 | transforms: input data manipulations 23 | """ 24 | 25 | def __init__( 26 | self, 27 | celled_data: torch.Tensor, 28 | celled_features_list: List[torch.Tensor], 29 | start_date: int, 30 | end_date: int, 31 | periods_forward: int, 32 | history_length: int, 33 | boundaries: Optional[List[None]], 34 | mode, 35 | normalize: bool, 36 | moments: Optional[List[None]], 37 | global_avg: Optional[List[None]], 38 | ): 39 | self.data = celled_data[start_date:end_date, :, :] 40 | self.features = [ 41 | feature[start_date:end_date, :, :] 42 | for feature in celled_features_list 43 | ] 44 | self.periods_forward = periods_forward 45 | self.history_length = history_length 46 | self.mode = mode 47 | self.target = self.data 48 | # bins for pdsi 49 | self.boundaries = boundaries 50 | if self.mode == "classification": 51 | # 1 is for drought 52 | self.target = 1 - torch.bucketize(self.target, self.boundaries) 53 | if len(global_avg) > 0: 54 | self.global_avg = global_avg 55 | else: 56 | self.global_avg, _ = torch.mode(self.target, dim=0) 57 | # normalization 58 | if moments: 59 | self.moments = moments 60 | if normalize: 61 | self.data = (self.data - self.moments[0][0]) / ( 62 | self.moments[0][1] - self.moments[0][0] 63 | ) 64 | for i in range(1, len(self.moments)): 65 | self.features[i - 1] = ( 66 | self.features[i - 1] - self.moments[i][0] 67 | ) / (self.moments[i][1] - self.moments[i][0]) 68 | else: 69 | self.moments = [] 70 | if normalize: 71 | self.data = (self.data - 
torch.min(self.data)) / ( 72 | torch.max(self.data) - torch.min(self.data) 73 | ) 74 | self.moments.append((torch.min(self.data), torch.max(self.data))) 75 | for i in range(len(self.features)): 76 | self.features[i] = ( 77 | self.features[i] - torch.min(self.features[i]) 78 | ) / (torch.max(self.features[i]) - torch.min(self.features[i])) 79 | self.moments.append( 80 | (torch.min(self.features[i]), torch.max(self.features[i])) 81 | ) 82 | 83 | def __len__(self): 84 | return len(self.data) - self.periods_forward - self.history_length 85 | 86 | def __getitem__(self, idx): 87 | input_tensor = self.data[idx : idx + self.history_length] 88 | for feature in self.features: 89 | input_tensor = torch.cat( 90 | (input_tensor, feature[idx : idx + self.history_length]), dim=0 91 | ) 92 | 93 | target = self.target[ 94 | idx + self.history_length : idx + self.history_length + self.periods_forward 95 | ] 96 | 97 | return ( 98 | input_tensor, 99 | target, 100 | ) 101 | 102 | 103 | class WeatherDataModule(LightningDataModule): 104 | """LightningDataModule for Weather dataset. 105 | 106 | A DataModule implements 5 key methods: 107 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) 108 | - setup (things to do on every accelerator in distributed mode) 109 | - train_dataloader (the training dataloader) 110 | - val_dataloader (the validation dataloader(s)) 111 | - test_dataloader (the test dataloader(s)) 112 | 113 | This allows you to share a full dataset without explaining how to download, 114 | split, transform and process the data. 115 | 116 | Read the docs: 117 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html 118 | """ 119 | 120 | def __init__( 121 | self, 122 | mode: str = "regression", 123 | data_dir: str = "data", 124 | dataset_name: str = "dataset_name", 125 | left_border: int = 0, 126 | down_border: int = 0, 127 | right_border: int = 2000, 128 | up_border: int = 2500, 129 | time_col: str = "time", 130 | event_col: str = "value", 131 | x_col: str = "x", 132 | y_col: str = "y", 133 | train_val_test_split: Tuple[float] = (0.8, 0.1, 0.1), 134 | periods_forward: int = 1, 135 | history_length: int = 1, 136 | data_start: int = 0, 137 | data_len: int = 100, 138 | feature_to_predict: str = "pdsi", 139 | num_of_additional_features: int = 0, 140 | additional_features: Optional[List[str]] = None, 141 | boundaries: Optional[List[str]] = None, 142 | patch_size: int = 8, 143 | normalize: bool = False, 144 | batch_size: int = 64, 145 | num_workers: int = 0, 146 | pin_memory: bool = False, 147 | ): 148 | super().__init__() 149 | 150 | # this line allows to access init params with 'self.hparams' attribute 151 | self.save_hyperparameters(logger=False) 152 | self.mode = mode 153 | self.data_dir = data_dir 154 | self.dataset_name = dataset_name 155 | self.left_border = left_border 156 | self.right_border = right_border 157 | self.down_border = down_border 158 | self.up_border = up_border 159 | self.h = 0 160 | self.w = 0 161 | self.time_col = time_col 162 | self.event_col = event_col 163 | self.x_col = x_col 164 | self.y_col = y_col 165 | 166 | self.data_train: Optional[Dataset] = None 167 | self.data_val: Optional[Dataset] = None 168 | self.data_test: Optional[Dataset] = None 169 | self.train_val_test_split = train_val_test_split 170 | self.periods_forward = periods_forward 171 | self.history_length = history_length 172 | self.data_start = data_start 173 | self.data_len = data_len 174 | self.feature_to_predict = feature_to_predict 175 | 
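# the predicted variable itself (pdsi by default) counts as one input feature, hence the +1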
self.num_of_features = num_of_additional_features + 1 176 | self.additional_features = additional_features 177 | self.boundaries = torch.Tensor(boundaries) 178 | self.patch_size = patch_size 179 | self.normalize = normalize 180 | 181 | self.batch_size = batch_size 182 | self.num_workers = num_workers 183 | self.pin_memory = pin_memory 184 | 185 | def prepare_data(self): 186 | """Download data if needed. 187 | 188 | This method is called only from a single GPU. 189 | Do not use it to assign state (self.x = y). 190 | """ 191 | celled_data_path = pathlib.Path(self.data_dir, "celled", self.dataset_name) 192 | if not celled_data_path.is_file(): 193 | celled_data = create_celled_data( 194 | self.data_dir, 195 | self.dataset_name, 196 | self.time_col, 197 | self.event_col, 198 | self.x_col, 199 | self.y_col, 200 | ) 201 | log.info(f"Original dataset shape: {celled_data.shape}") 202 | torch.save(celled_data, celled_data_path) 203 | 204 | data_dir_geo = self.dataset_name.split(self.feature_to_predict)[1] 205 | for feature in self.additional_features: 206 | celled_feature_path = pathlib.Path( 207 | self.data_dir, "celled", feature + data_dir_geo 208 | ) 209 | if not celled_feature_path.is_file(): 210 | celled_feature = create_celled_data( 211 | self.data_dir, 212 | feature + data_dir_geo, 213 | self.time_col, 214 | self.event_col, 215 | self.x_col, 216 | self.y_col, 217 | ) 218 | torch.save(celled_feature, celled_feature_path) 219 | 220 | def setup(self, stage: Optional[str] = None): 221 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 222 | 223 | This method is called by lightning when doing `trainer.fit()` and `trainer.test()`, 224 | so be careful not to execute the random split twice! The `stage` can be used to 225 | differentiate whether it's called before trainer.fit()` or `trainer.test()`. 
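Here the split is chronological: the first `train_val_test_split[0]` fraction of time steps becomes the train set, while the val and test sets both start `history_length` steps before that boundary and run to the end of the series.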
226 | """ 227 | 228 | # load datasets only if they're not loaded already 229 | if not self.data_train and not self.data_val and not self.data_test: 230 | celled_data_path = pathlib.Path(self.data_dir, "celled", self.dataset_name) 231 | celled_data = torch.load(celled_data_path) 232 | # make borders divisible by patch size 233 | h = celled_data.shape[1] 234 | self.h = h - h % self.patch_size 235 | w = celled_data.shape[2] 236 | self.w = w - w % self.patch_size 237 | celled_data = celled_data[:, :self.h, :self.w] 238 | celled_data = celled_data[ 239 | self.data_start : self.data_start + self.data_len, 240 | self.down_border : self.up_border, 241 | self.left_border : self.right_border, 242 | ] 243 | # loading features 244 | celled_features_list = [] 245 | data_dir_geo = self.dataset_name.split(self.feature_to_predict)[1] 246 | for feature in self.additional_features: 247 | celled_feature_path = pathlib.Path( 248 | self.data_dir, "celled", feature + data_dir_geo 249 | ) 250 | celled_feature = torch.load(celled_feature_path) 251 | # make borders divisible by patch size 252 | h = celled_feature.shape[1] 253 | h = h - h % self.patch_size 254 | w = celled_feature.shape[2] 255 | w = w - w % self.patch_size 256 | celled_feature = celled_feature[:, :h, :w] 257 | celled_feature = celled_feature[ 258 | self.data_start : self.data_start + self.data_len, 259 | self.down_border : self.up_border, 260 | self.left_border : self.right_border, 261 | ] 262 | celled_features_list.append(celled_feature) 263 | 264 | train_start = 0 265 | train_end = int(self.train_val_test_split[0] * celled_data.shape[0]) 266 | self.data_train = Dataset_RNN( 267 | celled_data, 268 | celled_features_list, 269 | train_start, 270 | train_end, 271 | self.periods_forward, 272 | self.history_length, 273 | self.boundaries, 274 | self.mode, 275 | self.normalize, 276 | [], 277 | [], 278 | ) 279 | # valid_end = int( 280 | # (self.train_val_test_split[0] + self.train_val_test_split[1]) 281 | # * celled_data.shape[0] 282 | # ) 283 | valid_end = celled_data.shape[0] 284 | self.data_val = Dataset_RNN( 285 | celled_data, 286 | celled_features_list, 287 | train_end - self.history_length, 288 | valid_end, 289 | self.periods_forward, 290 | self.history_length, 291 | self.boundaries, 292 | self.mode, 293 | self.normalize, 294 | self.data_train.moments, 295 | self.data_train.global_avg, 296 | ) 297 | test_end = celled_data.shape[0] 298 | self.data_test = Dataset_RNN( 299 | celled_data, 300 | celled_features_list, 301 | train_end - self.history_length, 302 | test_end, 303 | self.periods_forward, 304 | self.history_length, 305 | self.boundaries, 306 | self.mode, 307 | self.normalize, 308 | self.data_train.moments, 309 | self.data_train.global_avg, 310 | ) 311 | log.info(f"train dataset shape {self.data_train.data.shape}") 312 | log.info(f"val dataset shape {self.data_val.data.shape}") 313 | log.info(f"test dataset shape {self.data_test.data.shape}") 314 | 315 | def train_dataloader(self): 316 | return DataLoader( 317 | self.data_train, 318 | batch_size=self.batch_size, 319 | num_workers=self.num_workers, 320 | pin_memory=self.pin_memory, 321 | shuffle=False, 322 | ) 323 | 324 | def val_dataloader(self): 325 | return DataLoader( 326 | self.data_val, 327 | batch_size=self.batch_size, 328 | num_workers=self.num_workers, 329 | pin_memory=self.pin_memory, 330 | shuffle=False, 331 | ) 332 | 333 | def test_dataloader(self): 334 | return DataLoader( 335 | self.data_test, 336 | batch_size=self.batch_size, 337 | num_workers=self.num_workers, 338 | 
pin_memory=self.pin_memory, 339 | shuffle=False, 340 | ) 341 | -------------------------------------------------------------------------------- /src/models/rcnn_module.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Any, List 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from pytorch_lightning import LightningModule 8 | from sklearn.metrics import r2_score, roc_auc_score 9 | from torch.autograd import Variable 10 | 11 | from src.models.components.conv_block import ConvBlock 12 | from src.utils.metrics import metrics_celled 13 | from src.utils.plotting import make_heatmap, make_cf_matrix 14 | 15 | 16 | class ScaledTanh(nn.Module): 17 | def __init__(self, coef: int = 10): 18 | super().__init__() 19 | self.c = coef 20 | 21 | def forward(self, x): 22 | output = torch.mul(torch.tanh(x), self.c) 23 | return output 24 | 25 | 26 | class RCNNModule(LightningModule): 27 | """Convolutional-LSTM (RCNN) LightningModule for weather/drought map forecasting. 28 | 29 | A LightningModule organizes your PyTorch code into 5 sections: 30 | - Computations (init). 31 | - Train loop (training_step) 32 | - Validation loop (validation_step) 33 | - Test loop (test_step) 34 | - Optimizers (configure_optimizers) 35 | 36 | Read the docs: 37 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html 38 | """ 39 | 40 | def __init__( 41 | self, 42 | mode: str = "regression", 43 | embedding_size: int = 16, 44 | hidden_state_size: int = 32, 45 | kernel_size: int = 3, 46 | groups: int = 1, 47 | dilation: int = 1, 48 | n_cells_hor: int = 200, 49 | n_cells_ver: int = 250, 50 | history_length: int = 1, 51 | periods_forward: int = 1, 52 | batch_size: int = 1, 53 | num_of_additional_features: int = 0, 54 | boundaries: List[int] = [-2], 55 | num_classes: int = 2, 56 | values_range: int = 10, 57 | dropout: float = 0.0, 58 | lr: float = 0.003, 59 | weight_decay: float = 0.0, 60 | ): 61 | super().__init__() 62 | 63 | # this line allows to access init params with 'self.hparams' attribute 64 | # it also ensures init params will be stored in ckpt 65 | self.save_hyperparameters(logger=False) 66 | 67 | self.n_cells_hor = n_cells_hor 68 | self.n_cells_ver = n_cells_ver 69 | self.history_length = history_length 70 | self.periods_forward = periods_forward 71 | self.batch_size = batch_size 72 | self.lr = lr 73 | self.weight_decay = weight_decay 74 | 75 | self.num_of_features = num_of_additional_features + 1 76 | self.tanh_coef = values_range 77 | # number of bins for pdsi 78 | self.dropout = torch.nn.Dropout2d(p=dropout) 79 | self.num_class = num_classes 80 | self.boundaries = torch.tensor(boundaries).cuda() 81 | 82 | self.emb_size = embedding_size 83 | self.hid_size = hidden_state_size 84 | self.kernel_size = kernel_size 85 | self.dilation = dilation 86 | self.groups = groups 87 | 88 | self.global_avg = None 89 | self.saved_predictions = None 90 | self.saved_targets = None 91 | 92 | self.embedding = nn.Sequential( 93 | ConvBlock( 94 | self.num_of_features * history_length, 95 | self.emb_size, 96 | self.kernel_size, 97 | stride=1, 98 | padding=self.kernel_size // 2, 99 | dilation=self.dilation, 100 | groups=self.groups, 101 | ), 102 | nn.ReLU(), 103 | ConvBlock( 104 | self.emb_size, 105 | self.emb_size, 106 | self.kernel_size, 107 | stride=1, 108 | padding=self.kernel_size // 2, 109 | ), 110 | ) 111 | 112 | self.f_t = nn.Sequential( 113 | ConvBlock( 114 | self.hid_size + self.emb_size, 115 | self.hid_size, 116 |
self.kernel_size, 117 | stride=1, 118 | padding=self.kernel_size // 2, 119 | ), 120 | nn.Sigmoid(), 121 | ) 122 | self.i_t = nn.Sequential( 123 | ConvBlock( 124 | self.hid_size + self.emb_size, 125 | self.hid_size, 126 | self.kernel_size, 127 | stride=1, 128 | padding=self.kernel_size // 2, 129 | ), 130 | nn.Sigmoid(), 131 | ) 132 | self.c_t = nn.Sequential( 133 | ConvBlock( 134 | self.hid_size + self.emb_size, 135 | self.hid_size, 136 | self.kernel_size, 137 | stride=1, 138 | padding=self.kernel_size // 2, 139 | ), 140 | nn.Tanh(), 141 | ) 142 | self.o_t = nn.Sequential( 143 | ConvBlock( 144 | self.hid_size + self.emb_size, 145 | self.hid_size, 146 | self.kernel_size, 147 | stride=1, 148 | padding=self.kernel_size // 2, 149 | ), 150 | nn.Sigmoid(), 151 | ) 152 | 153 | self.final_conv = nn.Sequential( 154 | nn.Conv2d( 155 | self.hid_size, 156 | self.periods_forward, 157 | kernel_size=1, 158 | stride=1, 159 | padding=0, 160 | bias=False, 161 | ), 162 | ScaledTanh(self.tanh_coef), 163 | # nn.Tanh(), 164 | # nn.Conv2d( 165 | # self.hid_size, 166 | # self.periods_forward, 167 | # kernel_size=1, 168 | # stride=1, 169 | # padding=0, 170 | # bias=False, 171 | # ), 172 | ) 173 | 174 | self.final_classify = nn.Sequential( 175 | ConvBlock( 176 | self.hid_size, 177 | self.num_class, 178 | kernel_size=3, 179 | stride=1, 180 | padding=1, 181 | dilation=1, 182 | groups=1, 183 | ), 184 | nn.Sigmoid(), 185 | ) 186 | 187 | self.register_buffer( 188 | "prev_state_h", 189 | torch.zeros( 190 | self.batch_size, 191 | self.hid_size, 192 | self.n_cells_hor, 193 | self.n_cells_ver, 194 | requires_grad=False, 195 | ), 196 | ) 197 | self.register_buffer( 198 | "prev_state_c", 199 | torch.zeros( 200 | self.batch_size, 201 | self.hid_size, 202 | self.n_cells_hor, 203 | self.n_cells_ver, 204 | requires_grad=False, 205 | ), 206 | ) 207 | 208 | self.mode = mode 209 | # loss 210 | if self.mode == "regression": 211 | self.criterion = nn.MSELoss() 212 | self.loss_name = "MSE" 213 | else: 214 | self.criterion = nn.CrossEntropyLoss() 215 | self.loss_name = "CrossEntropy" 216 | 217 | def forward(self, x: torch.Tensor): 218 | prev_c = self.prev_state_c 219 | prev_h = self.prev_state_h 220 | x = self.dropout(x) 221 | x_emb = self.embedding(x) 222 | if x_emb.shape[0] < self.batch_size: 223 | x_emb = torch.nn.functional.pad( 224 | x_emb, pad=(0,0,0,0,0,0,0,self.batch_size - x_emb.shape[0]), value=0 225 | ) 226 | x_and_h = torch.cat([prev_h, x_emb], dim=1) 227 | 228 | f_i = self.f_t(x_and_h) 229 | i_i = self.i_t(x_and_h) 230 | c_i = self.c_t(x_and_h) 231 | o_i = self.o_t(x_and_h) 232 | 233 | # print("prev_c", prev_c.shape) 234 | # print("f_i", f_i.shape) 235 | # print("i_i", i_i.shape) 236 | # print("c_i", c_i.shape) 237 | 238 | next_c = prev_c * f_i + i_i * c_i 239 | next_h = torch.tanh(next_c) * o_i 240 | 241 | assert prev_h.shape == next_h.shape 242 | assert prev_c.shape == next_c.shape 243 | 244 | if self.mode == "regression": 245 | prediction = self.final_conv(next_h) 246 | elif self.mode == "classification": 247 | prediction = self.final_classify(next_h) 248 | self.prev_state_c = next_c 249 | self.prev_state_h = next_h 250 | 251 | return prediction 252 | 253 | def step(self, batch: Any): 254 | x, y = batch 255 | preds = self.forward(x) 256 | if y.shape[0] < self.batch_size: 257 | y = torch.nn.functional.pad( 258 | y, pad=(0,0,0,0,0,0,0,self.batch_size - y.shape[0]), value=0 259 | ) 260 | # checking last (forward) value of target 261 | loss = self.criterion(preds, y[:, -1, :, :]) 262 | return loss, preds, y[:, -1, :, :] 263 | 
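# `rolling_step` below builds the naive rolling-mean baseline reported at test time: the mean of the last `history_length` pdsi maps becomes the first forecast, the window then slides over its own predictions, and this repeats `periods_forward` times (e.g. for a cell with history [2, 4] the forecasts are 3.0, then mean([4, 3.0]) = 3.5, and so on)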
264 | def rolling_step(self, batch: Any): 265 | x, y = batch 266 | # x -> B*Hist*W*H or B*(Hist*Feat)*W*H 267 | # pdsi is first feature in tensor 268 | x = x[:, : self.history_length, :, :] 269 | rolling = torch.mean(x, dim=1) 270 | rolling_forecast = rolling[:, None, :, :] 271 | 272 | for i in range(1, self.periods_forward): 273 | x = torch.cat((x[:, 1:, :, :], rolling[:, None, :, :]), dim=1) 274 | rolling = torch.mean(x, dim=1) 275 | rolling_forecast = torch.cat( 276 | (rolling_forecast, rolling[:, None, :, :]), dim=1 277 | ) 278 | 279 | return rolling_forecast 280 | 281 | def class_baseline(self, batch: Any): 282 | x, y = batch 283 | # return most frequent class along history_dim 284 | x_binned = torch.bucketize(x, self.boundaries) 285 | most_freq_values, most_freq_indices = torch.mode(x_binned, dim=1) 286 | return most_freq_values 287 | 288 | def on_after_backward(self) -> None: 289 | self.prev_state_c.detach_() 290 | self.prev_state_h.detach_() 291 | 292 | def training_step(self, batch: Any, batch_idx: int): 293 | loss, preds, targets = self.step(batch) 294 | self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True) 295 | 296 | # we can return here dict with any tensors 297 | # and then read it in some callback or in `training_epoch_end()`` below 298 | # remember to always return loss from `training_step()` or else backpropagation will fail! 299 | return {"loss": loss, "preds": preds, "targets": targets} 300 | 301 | def training_epoch_end(self, outputs: List[Any]): 302 | # `outputs` is a list of dicts returned from `training_step()` 303 | all_targets = outputs[0]["targets"] 304 | all_preds = outputs[0]["preds"] 305 | 306 | for i in range(1, len(outputs)): 307 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0) 308 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0) 309 | all_preds = torch.softmax(all_preds, dim=1) 310 | all_preds = all_preds[:, 1, :, :] 311 | rocauc_table, ap_table, f1_table, acc_table, thr = metrics_celled(all_targets, all_preds) 312 | # log metrics 313 | if self.mode == "classification": 314 | self.log( 315 | "train/f1_median", 316 | torch.median(f1_table), 317 | on_epoch=True, 318 | prog_bar=True, 319 | ) 320 | self.log( 321 | "train/ap_median", 322 | torch.median(ap_table), 323 | on_epoch=True, 324 | prog_bar=True, 325 | ) 326 | self.log( 327 | "train/rocauc_median", 328 | torch.median(rocauc_table), 329 | on_epoch=True, 330 | prog_bar=True, 331 | ) 332 | self.log( 333 | "train/accuracy_median", 334 | torch.median(acc_table), 335 | on_epoch=True, 336 | prog_bar=True, 337 | ) 338 | 339 | # log metrics 340 | # r2table = rsquared(all_targets, all_preds, mode="mean") 341 | # self.log("train/R2_std", np.std(r2table), on_epoch=True, prog_bar=True) 342 | # self.log("train/R2", np.median(r2table), on_epoch=True, prog_bar=True) 343 | # self.log("train/R2_min", np.min(r2table), on_epoch=True, prog_bar=True) 344 | # self.log("train/R2_max", np.max(r2table), on_epoch=True, prog_bar=True) 345 | 346 | def validation_step(self, batch: Any, batch_idx: int): 347 | loss, preds, targets = self.step(batch) 348 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True) 349 | 350 | return {"loss": loss, "preds": preds, "targets": targets} 351 | 352 | def validation_epoch_end(self, outputs: List[Any]): 353 | all_targets = outputs[0]["targets"] 354 | all_preds = outputs[0]["preds"] 355 | 356 | for i in range(1, len(outputs)): 357 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0) 358 | all_targets = torch.cat((all_targets, 
outputs[i]["targets"]), 0) 359 | 360 | all_preds = torch.softmax(all_preds, dim=1) 361 | all_preds = all_preds[:, 1, :, :] 362 | rocauc_table, ap_table, f1_table, acc_table, thr = metrics_celled(all_targets, all_preds) 363 | # log metrics 364 | if self.mode == "classification": 365 | self.log( 366 | "val/f1_median", 367 | torch.median(f1_table), 368 | on_epoch=True, 369 | prog_bar=True, 370 | ) 371 | self.log( 372 | "val/ap_median", 373 | torch.median(ap_table), 374 | on_epoch=True, 375 | prog_bar=True, 376 | ) 377 | self.log( 378 | "val/rocauc_median", 379 | torch.median(rocauc_table), 380 | on_epoch=True, 381 | prog_bar=True, 382 | ) 383 | self.log( 384 | "val/accuracy_median", 385 | torch.median(acc_table), 386 | on_epoch=True, 387 | prog_bar=True, 388 | ) 389 | 390 | 391 | # log metrics 392 | # r2table = rsquared(all_targets, all_preds, mode="mean") 393 | # self.log("val/R2_std", np.std(r2table), on_epoch=True, prog_bar=True) 394 | # self.log("val/R2", np.median(r2table), on_epoch=True, prog_bar=True) 395 | # self.log("val/R2_min", np.min(r2table), on_epoch=True, prog_bar=True) 396 | # self.log("val/R2_max", np.max(r2table), on_epoch=True, prog_bar=True) 397 | 398 | def test_step(self, batch: Any, batch_idx: int): 399 | loss, preds, targets = self.step(batch) 400 | if self.mode == "regression": 401 | baseline = self.rolling_step(batch) 402 | elif self.mode == "classification": 403 | baseline = self.class_baseline(batch) 404 | self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True) 405 | 406 | return {"loss": loss, "preds": preds, "targets": targets, "baseline": baseline} 407 | 408 | def test_epoch_end(self, outputs: List[Any]): 409 | all_targets = outputs[0]["targets"] 410 | all_baselines = outputs[0]["baseline"] 411 | all_preds = outputs[0]["preds"] 412 | 413 | for i in range(1, len(outputs)): 414 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0) 415 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0) 416 | all_baselines = torch.cat((all_baselines, outputs[i]["baseline"]), 0) 417 | 418 | # remove padded values 419 | init_len = all_baselines.shape[0] 420 | all_preds = all_preds[:init_len,:,:,:] 421 | all_targets = all_targets[:init_len,:,:] 422 | 423 | all_preds = torch.softmax(all_preds, dim=1) 424 | # log confusion matrix 425 | preds_for_cm = torch.argmax(all_preds, dim=1) 426 | # self.logger.experiment[0].log_confusion_matrix(torch.flatten(all_targets), torch.flatten(preds_for_cm)) 427 | # probability of first class 428 | all_preds = all_preds[:, 1, :, :] 429 | 430 | self.saved_predictions = all_preds 431 | self.saved_targets = all_targets 432 | 433 | # global baseline 434 | all_global_baselines = self.global_avg.to("cuda:0") 435 | all_global_baselines = all_global_baselines.unsqueeze(0).repeat( 436 | len(all_targets), 1, 1 437 | ) 438 | all_global_baselines = torch.where(all_global_baselines > 0,100.0, 0.0) 439 | all_global_baselines = all_global_baselines.double() 440 | # all zeros baseline - no drought 441 | all_zeros = torch.zeros( 442 | all_preds.shape[0], all_preds.shape[1], all_preds.shape[2], dtype=torch.long, 443 | ).to("cuda:0") 444 | all_zeros = all_zeros.double() 445 | rocauc_table_zeros, ap_table_zeros, f1_table_zeros, acc_table_zeros, thr = metrics_celled( 446 | all_targets, all_zeros, "test" 447 | ) 448 | rocauc_table_global, ap_table_global, f1_table_global, acc_table_global, thr = metrics_celled( 449 | all_targets, all_global_baselines, "test" 450 | ) 451 | rocauc_table, ap_table, f1_table, acc_table, thr = metrics_celled( 
452 | all_targets, all_preds, "test" 453 | ) 454 | # log metrics 455 | if self.mode == "classification": 456 | self.log( 457 | "test/convlstm/f1_median", 458 | torch.median(f1_table), 459 | on_epoch=True, 460 | prog_bar=True, 461 | ) 462 | self.log( 463 | "test/convlstm/ap_median", 464 | torch.median(ap_table), 465 | on_epoch=True, 466 | prog_bar=True, 467 | ) 468 | self.log( 469 | "test/convlstm/rocauc_median", 470 | torch.median(rocauc_table), 471 | on_epoch=True, 472 | prog_bar=True, 473 | ) 474 | self.log( 475 | "test/convlstm/accuracy_median", 476 | torch.median(acc_table), 477 | on_epoch=True, 478 | prog_bar=True, 479 | ) 480 | self.log( 481 | "test/global/f1_median", 482 | torch.median(f1_table_global), 483 | on_epoch=True, 484 | prog_bar=True, 485 | ) 486 | self.log( 487 | "test/global/ap_median", 488 | torch.median(ap_table_global), 489 | on_epoch=True, 490 | prog_bar=True, 491 | ) 492 | self.log( 493 | "test/global/rocauc_median", 494 | torch.median(rocauc_table_global), 495 | on_epoch=True, 496 | prog_bar=True, 497 | ) 498 | self.log( 499 | "test/global/accuracy_median", 500 | torch.median(acc_table_global), 501 | on_epoch=True, 502 | prog_bar=True, 503 | ) 504 | self.log( 505 | "test/zeros/f1_median", 506 | torch.median(f1_table_zeros), 507 | on_epoch=True, 508 | prog_bar=True, 509 | ) 510 | self.log( 511 | "test/zeros/ap_median", 512 | torch.median(ap_table_zeros), 513 | on_epoch=True, 514 | prog_bar=True, 515 | ) 516 | self.log( 517 | "test/zeros/rocauc_median", 518 | torch.median(rocauc_table_zeros), 519 | on_epoch=True, 520 | prog_bar=True, 521 | ) 522 | self.log( 523 | "test/zeros/accuracy_median", 524 | torch.median(acc_table_zeros), 525 | on_epoch=True, 526 | prog_bar=True, 527 | ) 528 | 529 | rocauc_path = make_heatmap(rocauc_table, filename="rocauc_spatial.png") 530 | torch.save(rocauc_table, "rocauc_table.pt") 531 | self.logger.experiment[0].log_image(rocauc_path) 532 | 533 | # log metrics 534 | # test_r2table = rsquared(all_targets, all_preds, mode="full") 535 | # self.log("test/R2_std", np.std(test_r2table), on_epoch=True, prog_bar=True) 536 | # self.log( 537 | # "test/R2_median", np.median(test_r2table), on_epoch=True, prog_bar=True 538 | # ) 539 | # self.log("test/R2_min", np.min(test_r2table), on_epoch=True, prog_bar=True) 540 | # self.log("test/R2_max", np.max(test_r2table), on_epoch=True, prog_bar=True) 541 | if self.mode == "regression": 542 | self.log( 543 | "test/baseline_MSE", 544 | self.criterion(all_baselines, all_targets), 545 | on_epoch=True, 546 | prog_bar=True, 547 | ) 548 | 549 | def on_epoch_end(self): 550 | pass 551 | 552 | def configure_optimizers(self): 553 | """Choose what optimizers and learning-rate schedulers to use in your optimization. 554 | Normally you'd need one. But in the case of GANs or similar you might have multiple. 555 | 556 | See examples here: 557 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers 558 | """ 559 | return torch.optim.Adam( 560 | params=self.parameters(), 561 | lr=self.hparams.lr, 562 | weight_decay=self.hparams.weight_decay, 563 | ) 564 | --------------------------------------------------------------------------------
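A minimal usage sketch (not part of the repository) of how `Dataset_RNN` from `src/datamodules/weather_datamodule.py` turns a gridded series into (history, future) windows; the toy tensor size and parameter values below are illustrative assumptions:

import torch
from src.datamodules.weather_datamodule import Dataset_RNN

# hypothetical toy grid: 24 monthly maps on a 4x4 cell grid
celled = torch.randn(24, 4, 4)

ds = Dataset_RNN(
    celled_data=celled,
    celled_features_list=[],          # no additional features
    start_date=0,
    end_date=24,
    periods_forward=3,
    history_length=6,
    boundaries=torch.tensor([-2.0]),  # only used for binning in classification mode
    mode="regression",
    normalize=False,
    moments=[],
    global_avg=[],
)

x, y = ds[0]
print(len(ds))    # 24 - 3 - 6 = 15 windows
print(x.shape)    # torch.Size([6, 4, 4]) -> history window fed to the model
print(y.shape)    # torch.Size([3, 4, 4]) -> future maps used as the target

Since the dataloaders in `WeatherDataModule` are built with `shuffle=False`, these windows are consumed in chronological order.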