├── src
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   └── conv_block.py
│   │   ├── .DS_Store
│   │   ├── conv1d_module.py
│   │   └── rcnn_module.py
│   ├── datamodules
│   │   ├── __init__.py
│   │   ├── components
│   │   │   └── __init__.py
│   │   ├── .DS_Store
│   │   └── weather_datamodule.py
│   ├── .DS_Store
│   ├── vendor
│   │   └── __init__.py
│   ├── utils
│   │   ├── training_utils.py
│   │   ├── data_utils.py
│   │   ├── metrics.py
│   │   ├── plotting.py
│   │   └── __init__.py
│   ├── testing_pipeline.py
│   └── training_pipeline.py
├── configs
│   ├── local
│   │   └── .gitkeep
│   ├── callbacks
│   │   ├── none.yaml
│   │   └── default.yaml
│   ├── .DS_Store
│   ├── trainer
│   │   ├── ddp.yaml
│   │   └── default.yaml
│   ├── debug
│   │   ├── test_only.yaml
│   │   ├── overfit.yaml
│   │   ├── step.yaml
│   │   ├── profiler.yaml
│   │   ├── limit_batches.yaml
│   │   └── default.yaml
│   ├── logger
│   │   ├── csv.yaml
│   │   ├── many_loggers.yaml
│   │   ├── comet.yaml
│   │   ├── tensorboard.yaml
│   │   ├── mlflow.yaml
│   │   ├── neptune.yaml
│   │   └── wandb.yaml
│   ├── log_dir
│   │   ├── debug.yaml
│   │   ├── default.yaml
│   │   └── evaluation.yaml
│   ├── model
│   │   ├── conv1d.yaml
│   │   ├── vit.yaml
│   │   └── rcnn.yaml
│   ├── datamodule
│   │   └── weather.yaml
│   ├── experiment
│   │   └── example.yaml
│   ├── hparams_search
│   │   └── optuna.yaml
│   ├── train.yaml
│   └── test.yaml
├── .DS_Store
├── tests
│   └── .DS_Store
├── .gitignore
├── scripts
│   └── schedule.sh
├── train.py
├── Dockerfile
├── requirements.txt
├── test.py
├── README.md
├── test_ensemble.py
└── preprocess.py
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/local/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/datamodules/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/vendor/__init__.py:
--------------------------------------------------------------------------------
1 | # use this folder for storing third party code that cannot be installed using pip/conda
2 |
--------------------------------------------------------------------------------
/configs/trainer/ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | gpus: 4
5 | strategy: ddp
6 | sync_batchnorm: True
7 |
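8 | # select this trainer from the command line, e.g.: python train.py trainer=ddp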
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **.csv
2 | **.sh
3 | **.ipynb
4 | **.npy
5 | **.png
6 | **.DS_Store
7 | *.pyc
8 | .ipynb_checkpoints/
9 | data/
10 | logs/
11 |
--------------------------------------------------------------------------------
/configs/debug/test_only.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs only test epoch
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | train: False
9 | test: True
10 |
--------------------------------------------------------------------------------
/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # CSV logger built into Lightning
2 |
3 | csv:
4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5 | save_dir: "."
6 | name: "csv/"
7 | prefix: ""
8 |
--------------------------------------------------------------------------------
/configs/debug/overfit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # overfits to 3 batches
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 20
10 | overfit_batches: 3
11 |
--------------------------------------------------------------------------------
/configs/debug/step.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs 1 train, 1 validation and 1 test step
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | fast_dev_run: true
10 |
--------------------------------------------------------------------------------
/configs/logger/many_loggers.yaml:
--------------------------------------------------------------------------------
1 | # train with many loggers at once
2 |
3 | defaults:
4 | # - comet.yaml
5 | - csv.yaml
6 | # - mlflow.yaml
7 | # - neptune.yaml
8 | - tensorboard.yaml
9 | - wandb.yaml
10 |
--------------------------------------------------------------------------------
/scripts/schedule.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Schedule execution of many runs
3 | # Run from root folder with: bash scripts/schedule.sh
4 |
5 | python train.py trainer.max_epochs=5
6 |
7 | python train.py trainer.max_epochs=10 logger=csv
8 |
--------------------------------------------------------------------------------
/configs/log_dir/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | run:
5 | dir: logs/debugs/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
6 | sweep:
7 | dir: logs/debugs/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
8 | subdir: ${hydra.job.num}
9 |
--------------------------------------------------------------------------------
/configs/debug/profiler.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs with execution time profiling
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 1
10 | profiler: "simple"
11 | # profiler: "advanced"
12 | # profiler: "pytorch"
13 |
--------------------------------------------------------------------------------
/configs/log_dir/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | run:
5 | dir: logs/experiments/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
6 | sweep:
7 | dir: logs/experiments/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
8 | subdir: ${hydra.job.num}
9 |
--------------------------------------------------------------------------------
/configs/log_dir/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | run:
5 | dir: logs/evaluations/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
6 | sweep:
7 | dir: logs/evaluations/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
8 | subdir: ${hydra.job.num}
9 |
--------------------------------------------------------------------------------
/configs/logger/comet.yaml:
--------------------------------------------------------------------------------
1 | # https://www.comet.ml
2 |
3 | comet:
4 | _target_: pytorch_lightning.loggers.comet.CometLogger
5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6 | project_name: "binary_classifier_weather_convlstm"
7 | experiment_name: ${name}
8 |
--------------------------------------------------------------------------------
/configs/model/conv1d.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.conv1d_module.Conv1dModule
2 |
3 | n_cells_hor: ${n_cells_hor}
4 | n_cells_ver: ${n_cells_ver}
5 | history_length: ${history_length}
6 | periods_forward: ${periods_forward}
7 | batch_size: ${batch_size}
8 | lr: 0.003
9 | weight_decay: 0
10 |
11 |
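12 | # select this model from the command line, e.g.: python train.py model=conv1d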
--------------------------------------------------------------------------------
/configs/debug/limit_batches.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # uses only 1% of the training data and 5% of validation/test data
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 3
10 | limit_train_batches: 0.01
11 | limit_val_batches: 0.05
12 | limit_test_batches: 0.05
13 |
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | gpus: 1
4 |
5 | min_epochs: 1
6 | max_epochs: ${num_epochs}
7 |
8 | # number of validation steps to execute at the beginning of the training
9 | # num_sanity_val_steps: 0
10 |
11 | # ckpt path
12 | resume_from_checkpoint: null
13 |
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # https://www.tensorflow.org/tensorboard/
2 |
3 | tensorboard:
4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
5 | save_dir: "tensorboard/"
6 | name: null
7 | version: ${name}
8 | log_graph: False
9 | default_hp_metric: True
10 | prefix: ""
11 |
--------------------------------------------------------------------------------
/configs/model/vit.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.vit_module.ViTModule
2 |
3 | embed_dim: 16
4 | hidden_dim: 32
5 | num_channels: 1
6 | num_heads: 1
7 | num_layers: 1
8 | patch_size: 10
9 | num_patches: 5
10 | dropout: 0.3
11 | n_cells_hor: 200
12 | n_cells_ver: 250
13 | batch_size: 1
14 | lr: 0.003
15 | weight_decay: 0
16 |
--------------------------------------------------------------------------------
/configs/logger/mlflow.yaml:
--------------------------------------------------------------------------------
1 | # https://mlflow.org
2 |
3 | mlflow:
4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
5 | experiment_name: ${name}
6 | tracking_uri: ${original_work_dir}/logs/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
7 | tags: null
8 | prefix: ""
9 | artifact_location: null
10 |
--------------------------------------------------------------------------------
/configs/logger/neptune.yaml:
--------------------------------------------------------------------------------
1 | # https://neptune.ai
2 |
3 | neptune:
4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
6 | project_name: your_name/template-tests
7 | close_after_fit: True
8 | offline_mode: False
9 | experiment_name: ${name}
10 | experiment_id: null
11 | prefix: ""
12 |
--------------------------------------------------------------------------------
/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # https://wandb.ai
2 |
3 | wandb:
4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
5 | project: "template-tests"
6 | # name: ${name}
7 | save_dir: "."
8 | offline: False # set True to store all logs only locally
9 | id: null # pass correct id to resume experiment!
10 | # entity: "" # set to name of your wandb team
11 | log_model: False
12 | prefix: ""
13 | job_type: "train"
14 | group: ""
15 | tags: []
16 |
--------------------------------------------------------------------------------
/configs/model/rcnn.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.rcnn_module.RCNNModule
2 |
3 | mode: ${mode} # "classification" or "regression"
4 | embedding_size: 16
5 | hidden_state_size: 32
6 | kernel_size: 3
7 | groups: 1
8 | dilation: 1
9 | n_cells_hor: 100
10 | n_cells_ver: 100
11 | history_length: ${history_length}
12 | periods_forward: ${periods_forward}
13 | batch_size: ${batch_size}
14 | num_of_additional_features: ${num_of_additional_features}
15 | num_classes: ${num_classes}
16 | boundaries: ${boundaries}
17 | values_range: 10
18 | dropout: 0.0
19 | lr: 0.05
20 | weight_decay: 0
21 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import dotenv
2 | import hydra
3 | from omegaconf import DictConfig
4 |
5 | # load environment variables from `.env` file if it exists
6 | # recursively searches for `.env` in all folders starting from work dir
7 | dotenv.load_dotenv(override=True)
8 |
9 |
10 | @hydra.main(config_path="configs/", config_name="train.yaml")
11 | def main(config: DictConfig):
12 |
13 | # Imports can be nested inside @hydra.main to optimize tab completion
14 | # https://github.com/facebookresearch/hydra/issues/934
15 | from src import utils
16 | from src.training_pipeline import train
17 |
18 | # Applies optional utilities
19 | utils.extras(config)
20 |
21 | # Train model
22 | return train(config)
23 |
24 |
25 | if __name__ == "__main__":
26 | main()
27 |
--------------------------------------------------------------------------------
/configs/datamodule/weather.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodules.weather_datamodule.WeatherDataModule
2 |
3 | mode: ${mode}
4 | data_dir: ${data_dir} # data_dir is specified in config.yaml
5 | dataset_name: ${dataset_name}
6 | left_border: 0
7 | down_border: 0
8 | right_border: 350
9 | up_border: 350
10 | time_col: "date"
11 | event_col: "val"
12 | x_col: "x"
13 | y_col: "y"
14 | train_val_test_split: [0.7, 0.3, 0.3]
15 | periods_forward: ${periods_forward}
16 | history_length: ${history_length}
17 | data_start: 0
18 | data_len: 1000
19 | feature_to_predict: ${feature_to_predict}
20 | num_of_additional_features: ${num_of_additional_features}
21 | additional_features: ${additional_features}
22 | boundaries: ${boundaries}
23 | patch_size: 8
24 | normalize: True
25 | batch_size: ${batch_size}
26 | num_workers: 0
27 | pin_memory: False
28 |
--------------------------------------------------------------------------------
/src/utils/training_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def make_mask(lengths, size):
5 | mask = torch.arange(size).unsqueeze(0).repeat(len(lengths), 1)
6 | mask = mask < lengths.unsqueeze(1)
7 | return mask
8 |
9 |
10 | def smape(pred, targ):
11 | return torch.mean(2 * torch.abs(pred - targ) / (torch.abs(pred) + torch.abs(targ)))
12 |
13 |
14 | def normalize(seq_x, seq_y):
15 | with torch.no_grad():
16 | norm_consts = torch.max(seq_x[:, 0, :], dim=1).values.unsqueeze(1)
17 | seq_x[:, 0, :] /= norm_consts
18 | seq_y[:, :, 0] /= norm_consts
19 | return seq_x, seq_y, norm_consts
20 |
21 |
22 | def denormalize(seq_x, seq_y, out, norm_consts):
23 | with torch.no_grad():
24 | seq_x[:, 0, :] *= norm_consts
25 | seq_y *= norm_consts
26 | out *= norm_consts
27 | return seq_x, seq_y, out
28 |
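29 | # Illustrative usage of the normalization helpers around a model forward pass
30 | # (`model` below stands for any forecasting module):
31 | #   seq_x, seq_y, norm_consts = normalize(seq_x, seq_y)
32 | #   out = model(seq_x)
33 | #   seq_x, seq_y, out = denormalize(seq_x, seq_y, out, norm_consts)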
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Pull base image
2 | FROM ubuntu:20.04
3 |
4 | # Set environment variables
5 | ENV PYTHONDONTWRITEBYTECODE 1
6 | ENV PYTHONUNBUFFERED 1
7 |
8 | ENV TZ=Europe/Moscow
9 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
10 |
11 | # Set work directory
12 | WORKDIR /app
13 |
14 | # Install dependencies
15 | COPY requirements.txt /app
16 | RUN apt-get update && apt-get install -y python3 python3-pip libgdal-dev git vim
17 | RUN pip3 install -r requirements.txt
18 | RUN pip3 install --force-reinstall torch==1.10.1+cu113 --extra-index-url https://download.pytorch.org/whl/
19 | # Install GDAL (for preprocessing)
20 | ARG CPLUS_INCLUDE_PATH=/usr/include/gdal
21 | ARG C_INCLUDE_PATH=/usr/include/gdal
22 | RUN pip3 install gdal==$(gdal-config --version)
23 | # Create folders
24 | RUN mkdir -p data/raw data/preprocessed data/celled
25 | # Copy project
26 | COPY . /app/
27 |
--------------------------------------------------------------------------------
/configs/experiment/example.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /datamodule: mnist.yaml
8 | - override /model: mnist.yaml
9 | - override /callbacks: default.yaml
10 | - override /logger: null
11 | - override /trainer: default.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | # name of the run determines folder name in logs
17 | name: "simple_dense_net"
18 |
19 | seed: 12345
20 |
21 | trainer:
22 | min_epochs: 10
23 | max_epochs: 10
24 | gradient_clip_val: 0.5
25 |
26 | model:
27 | lr: 0.002
28 | net:
29 | lin1_size: 128
30 | lin2_size: 256
31 | lin3_size: 64
32 |
33 | datamodule:
34 | batch_size: 64
35 |
36 | logger:
37 | wandb:
38 | tags: ["mnist", "${name}"]
39 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # --------- pytorch --------- #
2 | torch==1.10.1
3 | pytorch-lightning==1.5.1
4 | torchmetrics>=0.7.0
5 |
6 | # --------- hydra --------- #
7 | hydra>=2.5
8 | hydra-core>=1.3.0
9 | hydra-colorlog>=1.2.0
10 | hydra-optuna-sweeper>=1.1.0
11 |
12 | # --------- loggers --------- #
13 | # wandb
14 | # tensorboard
15 | comet-ml>=3.33
16 |
17 | # --------- linters --------- #
18 | pre-commit # hooks for applying linters on commit
19 | black # code formatting
20 | isort # import sorting
21 | flake8 # code analysis
22 | nbstripout # remove output from jupyter notebooks
23 |
24 | # --------- others --------- #
25 | python-dotenv # loading env variables from .env file
26 | rich            # beautiful text formatting in terminal
27 | pytest # tests
28 | sh # for running bash commands in some tests
29 | pudb # debugger
30 | seaborn>=0.10.1 # plotting utils
31 | tqdm
32 |
--------------------------------------------------------------------------------
/configs/debug/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default debugging setup, runs 1 full epoch
4 | # other debugging configs can inherit from this one
5 |
6 | defaults:
7 | - override /log_dir: debug.yaml
8 |
9 | trainer:
10 | max_epochs: 1
11 | gpus: 0 # debuggers don't like gpus
12 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
13 | track_grad_norm: 2 # track gradient norm with loggers
14 |
15 | datamodule:
16 | num_workers: 0 # debuggers don't like multiprocessing
17 | pin_memory: False # disable gpu memory pin
18 |
19 | # sets level of all command line loggers to 'DEBUG'
20 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
21 | hydra:
22 | verbose: True
23 |
24 | # use this to set level of only chosen command line loggers to 'DEBUG':
25 | # verbose: [src.train, src.utils]
26 |
27 | # config is already printed by hydra when `hydra/verbose: True`
28 | print_config: False
29 |
--------------------------------------------------------------------------------
/src/models/components/conv_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class ConvBlock(nn.Module):
5 | def __init__(
6 | self,
7 | in_channels,
8 | out_channels,
9 | kernel_size,
10 | stride=1,
11 | padding=1,
12 | dilation=1,
13 | groups=1,
14 | ):
15 | super(ConvBlock, self).__init__()
16 |
17 | self.CONV = nn.Conv2d(
18 | in_channels,
19 | out_channels,
20 | kernel_size=kernel_size,
21 | stride=stride,
22 | padding=padding,
23 | dilation=dilation,
24 | groups=groups,
25 | bias=False,
26 | )
27 |
28 | self.BNORM = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=False)
29 |
30 | self.MAXPOOL = nn.MaxPool2d(3, stride=1, padding=1, dilation=1)
31 |
32 | def forward(self, x):
33 | x = self.CONV(x)
34 | x = self.BNORM(x)
35 | x = self.MAXPOOL(x)
36 |
37 | return x
38 |
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | model_checkpoint:
2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
3 | monitor: "val/rocauc_median" # name of the logged metric which determines when model is improving
4 | mode: "max" # "max" means higher metric value is better, can be also "min"
5 | save_top_k: 1 # save k best models (determined by above metric)
6 |   save_last: True # additionally always save model from last epoch
7 | verbose: False
8 | dirpath: "checkpoints/"
9 | filename: "epoch_{epoch:03d}"
10 | auto_insert_metric_name: False
11 |
12 | early_stopping:
13 | _target_: pytorch_lightning.callbacks.EarlyStopping
14 | monitor: "val/rocauc_median" # name of the logged metric which determines when model is improving
15 | mode: "max" # "max" means higher metric value is better, can be also "min"
16 | patience: 50 # how many validation epochs of not improving until training stops
17 | min_delta: 0.02 # minimum change in the monitored metric needed to qualify as an improvement
18 |
19 | model_summary:
20 | _target_: pytorch_lightning.callbacks.RichModelSummary
21 | max_depth: -1
22 |
23 | rich_progress_bar:
24 | _target_: pytorch_lightning.callbacks.RichProgressBar
25 |
--------------------------------------------------------------------------------
/src/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pathlib
6 | import torch
7 |
8 | import tqdm
9 |
10 |
11 | def create_celled_data(
12 | data_path,
13 | dataset_name,
14 | time_col: str = "time",
15 | event_col: str = "val",
16 | x_col: str = "x",
17 | y_col: str = "y",
18 | ):
19 | data_path = pathlib.Path(
20 | data_path,
21 | dataset_name,
22 | )
23 |
24 | df = pd.read_csv(data_path)
25 | df.sort_values(by=[time_col], inplace=True)
26 | df = df[[event_col, x_col, y_col, time_col]]
27 |
28 | indicies = range(df.shape[0])
29 |     start_date = int(df[time_col].iloc[0])  # positional access: df is sorted by time_col
30 |     finish_date = int(df[time_col].iloc[-1])
31 | n_cells_hor = df[x_col].max() - df[x_col].min() + 1
32 | n_cells_ver = df[y_col].max() - df[y_col].min() + 1
33 | celled_data = torch.zeros([finish_date - start_date + 1, n_cells_hor, n_cells_ver])
34 |
35 | for i in tqdm.tqdm(indicies):
36 | x = int(df[x_col][i])
37 | y = int(df[y_col][i])
38 | celled_data[int(df[time_col][i]) - start_date, x, y] = df[event_col][i]
39 |
40 | return celled_data
41 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import dotenv
2 | import hydra
3 | import glob
4 | from pathlib import Path
5 | import torch
6 | import os
7 | from omegaconf import DictConfig
8 |
9 | # load environment variables from `.env` file if it exists
10 | # recursively searches for `.env` in all folders starting from work dir
11 | dotenv.load_dotenv(override=True)
12 | from src import utils
13 |
14 | log = utils.get_logger(__name__)
15 |
16 | @hydra.main(config_path="configs/", config_name="test.yaml")
17 | def main(config: DictConfig):
18 |
19 | # Imports can be nested inside @hydra.main to optimize tab completion
20 | # https://github.com/facebookresearch/hydra/issues/934
21 | from src import utils
22 | from src.utils import metrics
23 | from src.testing_pipeline import test
24 |
25 | # Applies optional utilities
26 | utils.extras(config)
27 | # Evaluate model
28 | chkpts = []
29 | os.chdir("/Weather-Prediction-NN")
30 | path = config.ckpt_folder
31 | print(path)
32 | for ck in Path(path).rglob("*.ckpt"):
33 |         if "last" not in str(ck):
34 | chkpts.append(ck)
35 |
36 | print(chkpts)
37 |     config.ckpt_path = str(chkpts[-1])  # OmegaConf configs only accept primitive types
38 |
39 | return test(config)
40 |
41 |
42 | if __name__ == "__main__":
43 |
44 |
45 | main()
46 |
--------------------------------------------------------------------------------
/configs/hparams_search/optuna.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # example hyperparameter optimization of some experiment with Optuna:
4 | # python train.py -m hparams_search=optuna experiment=example
5 |
6 | defaults:
7 | - override /hydra/sweeper: optuna
8 |
9 | # choose metric which will be optimized by Optuna
10 | # make sure this is the correct name of some metric logged in lightning module!
11 | optimized_metric: "val/R2"
12 |
13 | # here we define Optuna hyperparameter search
14 | # it optimizes for value returned from function with @hydra.main decorator
15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16 | hydra:
17 | sweeper:
18 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
19 |
20 | # storage URL to persist optimization results
21 | # for example, you can use SQLite if you set 'sqlite:///example.db'
22 | storage: null
23 |
24 | # name of the study to persist optimization results
25 | study_name: null
26 |
27 | # number of parallel workers
28 | n_jobs: 1
29 |
30 | # 'minimize' or 'maximize' the objective
31 | direction: maximize
32 |
33 | # total number of runs that will be executed
34 | n_trials: 25
35 |
36 | # choose Optuna hyperparameter sampler
37 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
38 | sampler:
39 | _target_: optuna.samplers.TPESampler
40 | seed: 23
41 | n_startup_trials: 10 # number of random sampling runs before optimization starts
42 |
43 | # define range of hyperparameters
44 | search_space:
45 | #datamodule.batch_size:
46 | # type: categorical
47 | # choices: [32, 64, 128]
48 | model.lr:
49 | type: float
50 | low: 0.0001
51 | high: 0.2
52 | #model.embedding_size:
53 | # type: categorical
54 | # choices: [16, 32, 64, 128, 256]
55 | #model.hidden_state_size:
56 | # type: categorical
57 | # choices: [16, 32, 64, 128, 256]
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConvLSTM model for Weather Forecasting
2 |
3 | A PyTorch Lightning implementation of a drought forecasting (classification) model (Convolutional LSTM). Classification is based on the [PDSI index](https://en.wikipedia.org/wiki/Palmer_drought_index) and its corresponding bins.
4 |
5 |
6 |
7 | We solve a binary classification problem, where the drought threshold can be adjusted in the config file.
8 |
9 | ## Docker container launch
10 |
11 | First, build an image:
12 |
13 | ```
14 | docker build . -t=<image_name>
15 | ```
16 | Then run a container with the required parameters:
17 |
18 | ```
19 | docker run --mount type=bind,source=/local_path/Droughts/,destination=/Droughts/ -p <host_port>:<container_port> --memory=64g --cpuset-cpus="0-7" --gpus '"device=0"' -it --rm --name=<container_name> <image_name>
20 | ```
21 |
22 | ## Preprocessing ##
23 |
24 | The input is monthly geospatial data, downloaded as .tif files from public sources (e.g. Google Earth Engine) and placed into the "data/raw" folder. The naming convention is "region_feature.tif". Please run
25 |
26 | ```
27 | python3 preprocess.py --region region_name --band feature_name --endyear last_year_of_data --endmonth last_month_of_data
28 | ```
29 |
30 | Results (both .csv and .npy files) can be found in the "data/preprocessed" folder.
31 |
32 | ## Training ##
33 |
34 | To train a model, first change the datamodule and network configs (if necessary) and edit the relevant parameters (e.g. the data path in train.yaml), then run
35 | ```
36 | python3 train.py --config-name=train.yaml
37 | ```
38 |
39 | Experiment results can be tracked via Comet ML (please add your token to the logger config file or export it as an environment variable).
40 |
41 | ## Inference ##
42 |
43 | To run the model on the test dataset, calculate metrics and save predictions, first change the datamodule and network configs (if necessary) and add the path to the model checkpoint, then run
44 | ```
45 | python3 test.py --config-name=test.yaml
46 | ```
47 |
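48 | Both train.py and test.py are configured with [Hydra](https://hydra.cc), so any parameter or config group can also be overridden from the command line instead of editing the YAML files. A few illustrative examples (the values are arbitrary):
49 |
50 | ```
51 | # switch to the conv1d model config and change shared parameters
52 | python3 train.py model=conv1d num_epochs=20 batch_size=16
53 |
54 | # quick sanity check: run a single train/validation/test batch
55 | python3 train.py debug=step
56 |
57 | # log to CSV instead of Comet ML
58 | python3 train.py logger=csv
59 | ```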
--------------------------------------------------------------------------------
/test_ensemble.py:
--------------------------------------------------------------------------------
1 | import dotenv
2 | import hydra
3 | import glob
4 | from pathlib import Path
5 | import torch
6 | import os
7 | from omegaconf import DictConfig
8 |
9 | # load environment variables from `.env` file if it exists
10 | # recursively searches for `.env` in all folders starting from work dir
11 | dotenv.load_dotenv(override=True)
12 | from src import utils
13 |
14 | log = utils.get_logger(__name__)
15 |
16 | @hydra.main(config_path="configs/", config_name="test.yaml")
17 | def main(config: DictConfig):
18 |
19 | # Imports can be nested inside @hydra.main to optimize tab completion
20 | # https://github.com/facebookresearch/hydra/issues/934
21 | from src import utils
22 | from src.utils import metrics
23 | from src.testing_pipeline import test
24 |
25 | # Applies optional utilities
26 | utils.extras(config)
27 | all_preds = []
28 | os.chdir("/Weather-Prediction-NN")
29 | # Evaluate model
30 | chkpts = []
31 | path = config.ckpt_folder
32 | for ck in Path(path).rglob("*.ckpt"):
33 |         if "last" not in str(ck):
34 | chkpts.append(ck)
35 | for c in chkpts:
36 |         config.ckpt_path = str(c)  # OmegaConf configs only accept primitive types
37 | preds, all_targets = test(config)
38 | all_preds.append(preds)
39 |
40 |     all_preds = torch.stack(all_preds)
41 | all_preds = torch.mean(all_preds, dim=0)
42 |     rocauc_table, ap_table, f1_table, acc_table, thresholds = metrics.metrics_celled(all_targets, all_preds, mode="test")
43 | res_rocauc = torch.median(rocauc_table)
44 | res_ap = torch.median(ap_table)
45 | res_f1 = torch.median(f1_table)
46 | log.info(f"test_ensemble_median_rocauc: {res_rocauc}")
47 | log.info(f"test_ensemble_median_ap: {res_ap}")
48 | log.info(f"test_ensemble_median_f1: {res_f1}")
49 | with open("ens.txt", "a") as f:
50 | f.write(config.ckpt_folder + "\n")
51 | f.write("median_rocauc: " + str(res_rocauc) + "\n")
52 | f.write("\n")
53 | f.write("median_ap: " + str(res_ap) + "\n")
54 | f.write("\n")
55 | f.write("median_f1: " + str(res_f1) + "\n")
56 | f.write("\n")
57 | return
58 |
59 |
60 | if __name__ == "__main__":
61 |
62 |
63 | main()
64 |
--------------------------------------------------------------------------------
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchmetrics.classification import (
4 | AUROC,
5 | AveragePrecision,
6 | ROC,
7 | )
8 | from torcheval.metrics.functional import binary_f1_score, binary_accuracy
9 |
10 |
11 | def metrics_celled(all_targets, all_preds, mode: str = "train"):
12 | rocauc_table = torch.zeros(all_preds.shape[1], all_preds.shape[2])
13 | rocauc = AUROC(task="binary", num_classes=1)
14 |     rocauc_table = torch.tensor(  # shape: (n_cells_hor, n_cells_ver), matching the other tables
15 |         [
16 |             [
17 |                 rocauc(all_preds[:, x, y], all_targets[:, x, y])
18 |                 for y in range(all_preds.shape[2])
19 |             ]
20 |             for x in range(all_preds.shape[1])
21 |         ]
22 |     )
23 | rocauc_table = torch.nan_to_num(rocauc_table, nan=0.0)
24 |
25 | ap_table = torch.zeros(all_preds.shape[1], all_preds.shape[2])
26 | acc_table = torch.zeros(all_preds.shape[1], all_preds.shape[2])
27 | f1_table = torch.zeros(all_preds.shape[1], all_preds.shape[2])
28 | thresholds = torch.zeros(all_preds.shape[1], all_preds.shape[2])
29 |
30 | if mode == "test":
31 | ap = AveragePrecision(task="binary")
32 | roc = ROC(task="binary")
33 | for x in range(all_preds.shape[1]):
34 | for y in range(all_preds.shape[2]):
35 | ap_table[x][y] = ap(all_preds[:, x, y], all_targets[:, x, y])
36 | fpr, tpr, thr = roc(all_preds[:, x, y], all_targets[:, x, y])
37 | j_stat = tpr - fpr
38 | ind = torch.argmax(j_stat).item()
39 | thresholds[x][y] = thr[ind]
40 | acc_table[x][y] = binary_accuracy(
41 | all_preds[:, x, y], all_targets[:, x, y], threshold=thresholds[x][y]
42 | )
43 | f1_table[x][y] = binary_f1_score(
44 | all_preds[:, x, y], all_targets[:, x, y], threshold=thresholds[x][y]
45 | )
46 |
47 | ap_table = torch.nan_to_num(ap_table, nan=0.0)
48 | f1_table = torch.nan_to_num(f1_table, nan=0.0)
49 | acc_table = torch.nan_to_num(acc_table, nan=0.0)
50 |
51 | return rocauc_table, ap_table, f1_table, acc_table, thresholds
52 |
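53 | # Illustrative usage: `all_preds` and `all_targets` are tensors of shape
54 | # (num_samples, n_cells_hor, n_cells_ver). ROC AUC is computed per cell; with
55 | # mode="test", per-cell AP, F1, accuracy and Youden-J thresholds are filled in as well.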
--------------------------------------------------------------------------------
/src/testing_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 |
4 | import hydra
5 | from omegaconf import DictConfig
6 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
7 | from pytorch_lightning.loggers import LightningLoggerBase
8 |
9 | from src import utils
10 |
11 | log = utils.get_logger(__name__)
12 |
13 |
14 | def test(config: DictConfig) -> None:
15 | """Contains minimal example of the testing pipeline. Evaluates given checkpoint on a testset.
16 |
17 | Args:
18 | config (DictConfig): Configuration composed by Hydra.
19 |
20 | Returns:
21 | None
22 | """
23 |
24 | # Set seed for random number generators in pytorch, numpy and python.random
25 | if config.get("seed"):
26 | seed_everything(config.seed, workers=True)
27 |
28 | # Convert relative ckpt path to absolute path if necessary
29 | if not os.path.isabs(config.ckpt_path):
30 | config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path)
31 |
32 | # Init lightning datamodule
33 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
34 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
35 | datamodule.setup()
36 | config.model.n_cells_hor = datamodule.h
37 | config.model.n_cells_ver = datamodule.w
38 |
39 | # Init lightning model
40 | log.info(f"Instantiating model <{config.model._target_}>")
41 | model: LightningModule = hydra.utils.instantiate(config.model)
42 | model.global_avg = datamodule.data_test.global_avg
43 |
44 | # Init lightning loggers
45 | logger: List[LightningLoggerBase] = []
46 | if "logger" in config:
47 | for _, lg_conf in config.logger.items():
48 | if "_target_" in lg_conf:
49 | log.info(f"Instantiating logger <{lg_conf._target_}>")
50 | logger.append(hydra.utils.instantiate(lg_conf))
51 |
52 | # Init lightning trainer
53 | log.info(f"Instantiating trainer <{config.trainer._target_}>")
54 | trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger)
55 |
56 | # Log hyperparameters
57 | if trainer.logger:
58 | trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path})
59 |
60 | log.info("Starting testing!")
61 | trainer.test(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path)
62 |
63 | return
64 | #return model.saved_predictions, model.saved_targets
65 |
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 | - _self_
6 | - datamodule: weather.yaml
7 | - model: rcnn.yaml
8 | - callbacks: default.yaml
9 | - logger: comet.yaml # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
10 | - trainer: default.yaml
11 | - log_dir: default.yaml
12 |
13 | # experiment configs allow for version control of specific configurations
14 | # e.g. best hyperparameters for each combination of model and datamodule
15 | - experiment: null
16 |
17 |   # debugging config (enable through command line, e.g. `python train.py debug=default`)
18 | - debug: null
19 |
20 | # config for hyperparameter optimization
21 | - hparams_search: null # optuna.yaml
22 |
23 | # optional local config for machine/user specific settings
24 | # it's optional since it doesn't need to exist and is excluded from version control
25 | - optional local: default.yaml
26 |
27 | # enable color logging
28 | - override hydra/hydra_logging: colorlog
29 | - override hydra/job_logging: colorlog
30 |
31 | # path to original working directory
32 | # hydra hijacks working directory by changing it to the new log directory
33 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
34 | original_work_dir: ${hydra:runtime.cwd}
35 |
36 | # path to folder with data
37 | data_dir: ${original_work_dir}/data/
38 | dataset_name: "pdsi_Belarus_with_neighb.csv"
39 | batch_size: 8
40 | num_epochs: 10
41 |
42 | # pretty print config at the start of the run using Rich library
43 | print_config: True
44 |
45 | # disable python warnings if they annoy you
46 | ignore_warnings: True
47 |
48 | # set False to skip model training
49 | train: True
50 |
51 | # evaluate on test set, using best model weights achieved during training
52 | # lightning chooses best weights based on the metric specified in checkpoint callback
53 | test: True
54 |
55 | # seed for random number generators in pytorch, numpy and python.random
56 | seed: null
57 |
58 | # parameters of the model that are shared with dataloader
59 | history_length: 9
60 | periods_forward: 1
61 | mode: "classification" # "classification" or "regression"
62 | num_classes: 2 # for classifier
63 | boundaries: [-2] # for classifier
64 | feature_to_predict: "pdsi"
65 | num_of_additional_features: 0
66 | additional_features: [] #["pr", "pet", "tmmn", "tmmx"]
67 |
68 | # default name for the experiment, determines logging folder path
69 | # (you can overwrite this name in experiment configs)
70 | name: acc_newdimensions_${dataset_name}_${model._target_}_history_${history_length}_forward_${periods_forward}
71 |
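72 | # any of the global parameters above can also be overridden from the command line,
73 | # e.g.: python train.py history_length=12 periods_forward=3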
--------------------------------------------------------------------------------
/src/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import seaborn as sns
6 | import torch
7 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8 | from matplotlib.figure import Figure
9 | from sklearn.metrics import confusion_matrix
10 |
11 |
12 | def make_heatmap(table, filename="rocauc_spatial.png", size=(8, 6)):
13 | fig = Figure(figsize=size, frameon=True)
14 | canvas = FigureCanvas(fig)
15 | ax = fig.add_subplot(111)
16 | ax = sns.heatmap(table, vmin=0.0, vmax=1.0)
17 | ax.set_xticklabels([])
18 | ax.set_yticklabels([])
19 |
20 | full_path = os.path.expanduser(filename)
21 | ax.figure.savefig(full_path)
22 | ax.cla()
23 |
24 | return full_path
25 |
26 |
27 | def make_cf_matrix(
28 | targets, preds, thresholds, filename: str = "cf_matrix.png", size=(8, 6)
29 | ):
30 | targets = torch.flatten(targets).cpu().numpy()
31 | for x in range(preds.shape[1]):
32 | for y in range(preds.shape[2]):
33 | preds[:,x,y] = torch.bucketize(preds[:,x,y], torch.Tensor([thresholds[x][y]]).cuda())
34 | preds = torch.flatten(preds).cpu().numpy()
35 |
36 | cf_matrix = confusion_matrix(targets, preds)
37 | fig = Figure(figsize=size, frameon=True)
38 | ax = fig.add_subplot(111)
39 | ax = sns.heatmap(
40 | cf_matrix / np.sum(cf_matrix), annot=True, fmt=".2%", cmap="Blues", cbar=False
41 | )
42 |
43 | full_path = os.path.expanduser(filename)
44 | ax.figure.savefig(full_path)
45 | ax.cla()
46 |
47 | return full_path
48 |
49 |
50 | def make_pred_vs_target_plot(
51 | preds,
52 | targets,
53 | title="Comparison",
54 | size=(8, 6),
55 | xlabel=None,
56 | xlabel_rotate=45,
57 | ylabel=None,
58 | ylabel_rotate=0,
59 | filename="forecasts.png",
60 | ):
61 | fig = Figure(figsize=size, frameon=False)
62 | canvas = FigureCanvas(fig)
63 | ax = fig.add_subplot(111)
64 | x_length = targets.shape[1]
65 | y_length = targets.shape[2]
66 | x_random = random.choice(list(range(x_length)))
67 | y_random = random.choice(list(range(y_length)))
68 | targets = targets.cpu()
69 | targets = torch.mean(targets, dim=[1, 2])
70 | preds = preds.cpu()
71 | preds = torch.mean(preds, dim=[1, 2])
72 | time_periods = np.arange(0, targets.shape[0])
73 | ax.plot(time_periods, targets, "g-", label="actual")
74 | ax.plot(time_periods, preds, "b--", label="predictions")
75 | ax.legend()
76 |
77 | if xlabel:
78 | ax.set_xlabel(xlabel)
79 | if ylabel:
80 | ax.set_ylabel(ylabel)
81 | if title:
82 | ax.set_title(title)
83 |
84 | fig.tight_layout()
85 | filename = os.path.expanduser(filename)
86 | fig.savefig(filename)
87 | return fig
88 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | from osgeo import gdal
2 | import numpy as np
3 | import pandas as pd
4 | import tqdm
5 | import argparse
6 | import os
7 |
8 |
9 | parser = argparse.ArgumentParser(description="preprocess tif files")
10 | parser.add_argument("--region", type=str, help="name of region")
11 | parser.add_argument(
12 | "--band", type=str, default="pdsi", help="name of variable to process"
13 | )
14 | parser.add_argument("--endyear", type=int, default=2020, help="last year of data")
15 | parser.add_argument(
16 | "--endmonth", type=int, default=1, help="last month of data, from 1 to 12"
17 | )
18 | args = parser.parse_args()
19 | region = args.region
20 | feature = args.band
21 | endyear = args.endyear
22 | endmonth = args.endmonth
23 |
24 | save_path = "data/preprocessed/"
25 | print(f"region {region}")
26 | print(f"band {feature}")
27 |
28 |
29 | ds = gdal.Open("data/raw/" + region + "_" + feature + ".tif")
30 | print(f"number of months {ds.RasterCount}")
31 | print(f"x dim {ds.RasterXSize}")
32 | print(f"y dim {ds.RasterYSize}")
33 |
34 | num_of_months = ds.RasterCount
35 | xsize = ds.RasterXSize
36 | ysize = ds.RasterYSize
37 | all_data = np.zeros((num_of_months, 1, ysize, xsize))
38 |
39 | curr_month = endmonth
40 | curr_year = endyear
41 | total_df = pd.DataFrame(columns=["y", "x", "value", "date"])
42 |
43 |
44 | for i in tqdm.tqdm(range(num_of_months - 1, 1, -1)):
45 | if curr_month == 0:
46 | curr_month = 12
47 | curr_year -= 1
48 |
49 | curr_date = str(curr_year) + "-" + str(curr_month)
50 | band = ds.GetRasterBand(i)
51 | data = band.ReadAsArray()
52 | # terraclim features need to be normalized
53 | if feature == "pdsi":
54 | data = data / 100
55 | elif feature == "pet" or feature == "tmmn" or feature == "tmmx":
56 | data = data / 10
57 | all_data[i][0] = data
58 |
59 | df_row = (
60 | pd.DataFrame(data, columns=list(range(xsize)))
61 | .reset_index()
62 | .melt(id_vars="index")
63 | .rename(columns={"index": "y", "variable": "x"})
64 | )
65 | df_row["date"] = curr_date
66 | total_df = pd.concat([total_df, df_row])
67 |
68 | curr_month -= 1
69 |
70 | np.save(save_path + region + "_" + feature + ".npy", all_data)
71 | total_df.to_csv(save_path + region + "_" + feature + ".csv")
72 | print(f"{region} global stats")
73 | print(f"mean: {np.mean(all_data)}")
74 | print(f"std: {np.std(all_data)}")
75 | num_of_channels = 1
76 | global_means = np.zeros((1, num_of_channels, 1, 1))
77 | global_stds = np.zeros((1, num_of_channels, 1, 1))
78 | global_means[0, 0, 0, 0] = np.mean(all_data)
79 | global_stds[0, 0, 0, 0] = np.std(all_data)
80 | np.save(save_path + region + "_" + feature + "_global_means.npy", global_means)
81 | np.save(save_path + region + "_" + feature + "_global_stds.npy", global_stds)
82 |
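83 | # Illustrative invocation (region and dates are example values only):
84 | #   python3 preprocess.py --region Belarus --band pdsi --endyear 2020 --endmonth 12
85 | # It expects data/raw/Belarus_pdsi.tif and writes outputs to data/preprocessed/.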
--------------------------------------------------------------------------------
/configs/test.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default evaluation configuration
4 | defaults:
5 | - _self_
6 | - datamodule: weather.yaml
7 | - model: rcnn.yaml
8 | - callbacks: null
9 | - logger: comet.yaml # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
10 | - trainer: default.yaml
11 | - log_dir: evaluation.yaml
12 |
13 | # experiment configs allow for version control of specific configurations
14 | # e.g. best hyperparameters for each combination of model and datamodule
15 | - experiment: null
16 |
17 |   # debugging config (enable through command line, e.g. `python train.py debug=default`)
18 | - debug: null
19 |
20 | # config for hyperparameter optimization
21 | - hparams_search: null # optuna.yaml
22 |
23 | # optional local config for machine/user specific settings
24 | # it's optional since it doesn't need to exist and is excluded from version control
25 | - optional local: default.yaml
26 |
27 | # enable color logging
28 | - override hydra/hydra_logging: colorlog
29 | - override hydra/job_logging: colorlog
30 |
31 | # path to original working directory
32 | # hydra hijacks working directory by changing it to the new log directory
33 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
34 | original_work_dir: ${hydra:runtime.cwd}
35 |
36 | # path to folder with data
37 | data_dir: ${original_work_dir}/data/
38 | dataset_name: "pdsi_Belarus_with_neighb.csv"
39 | batch_size: 1
40 | num_epochs: 100
41 |
42 | # pretty print config at the start of the run using Rich library
43 | print_config: True
44 |
45 | # disable python warnings if they annoy you
46 | ignore_warnings: True
47 |
48 | # set False to skip model training
49 | train: True
50 |
51 | # evaluate on test set, using best model weights achieved during training
52 | # lightning chooses best weights based on the metric specified in checkpoint callback
53 | test: True
54 |
55 | # seed for random number generators in pytorch, numpy and python.random
56 | seed: null
57 |
58 | # parameters of the model that are shared with dataloader
59 | n_cells_hor: 66
60 | n_cells_ver: 123
61 | mode: "classification" # "classification" or "regression"
62 | num_classes: 2 # for classifier
63 | boundaries: [-2] # for classifier
64 | feature_to_predict: "pdsi"
65 | num_of_additional_features: 0
66 | additional_features: [] #["pr", "pet", "tmmn", "tmmx"]
67 |
68 | # default name for the experiment, determines logging folder path
69 | # (you can overwrite this name in experiment configs)
70 | history_length: 6
71 | periods_forward: 1
72 | name: GlobalBaseline_${dataset_name}_${model._target_}_history_${history_length}_forward_${periods_forward}
73 |
74 | # checkpoints
75 | ckpt_folder: logs/experiments/runs/recalc_newdimensions_${dataset_name}_${model._target_}_history_${history_length}_forward_${periods_forward}
76 | ckpt_path: null
77 |
--------------------------------------------------------------------------------
/src/training_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Optional
3 |
4 | import hydra
5 | from omegaconf import DictConfig
6 | from pytorch_lightning import (
7 | Callback,
8 | LightningDataModule,
9 | LightningModule,
10 | Trainer,
11 | seed_everything,
12 | )
13 | from pytorch_lightning.loggers import LightningLoggerBase
14 |
15 | from src import utils
16 |
17 | log = utils.get_logger(__name__)
18 |
19 |
20 | def train(config: DictConfig) -> Optional[float]:
21 | """Contains the training pipeline. Can additionally evaluate model on a testset, using best
22 | weights achieved during training.
23 |
24 | Args:
25 | config (DictConfig): Configuration composed by Hydra.
26 |
27 | Returns:
28 | Optional[float]: Metric score for hyperparameter optimization.
29 | """
30 |
31 | # Set seed for random number generators in pytorch, numpy and python.random
32 | if config.get("seed"):
33 | seed_everything(config.seed, workers=True)
34 |
35 | # Convert relative ckpt path to absolute path if necessary
36 | ckpt_path = config.trainer.get("resume_from_checkpoint")
37 | if ckpt_path and not os.path.isabs(ckpt_path):
38 | config.trainer.resume_from_checkpoint = os.path.join(
39 | hydra.utils.get_original_cwd(), ckpt_path
40 | )
41 |
42 |
43 | # Init lightning datamodule
44 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
45 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
46 | datamodule.setup()
47 | config.model.n_cells_hor = datamodule.h
48 | config.model.n_cells_ver = datamodule.w
49 |
50 | # Init lightning model
51 | log.info(f"Instantiating model <{config.model._target_}>")
52 | model: LightningModule = hydra.utils.instantiate(config.model)
53 |
54 | # Init lightning callbacks
55 | callbacks: List[Callback] = []
56 | if "callbacks" in config:
57 | for _, cb_conf in config.callbacks.items():
58 | if "_target_" in cb_conf:
59 | log.info(f"Instantiating callback <{cb_conf._target_}>")
60 | callbacks.append(hydra.utils.instantiate(cb_conf))
61 |
62 | # Init lightning loggers
63 | logger: List[LightningLoggerBase] = []
64 | if "logger" in config:
65 | for _, lg_conf in config.logger.items():
66 | if "_target_" in lg_conf:
67 | log.info(f"Instantiating logger <{lg_conf._target_}>")
68 | logger.append(hydra.utils.instantiate(lg_conf))
69 |
70 | # Init lightning trainer
71 | log.info(f"Instantiating trainer <{config.trainer._target_}>")
72 | trainer: Trainer = hydra.utils.instantiate(
73 | config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
74 | )
75 |
76 | # Send some parameters from config to all lightning loggers
77 | log.info("Logging hyperparameters!")
78 | utils.log_hyperparameters(
79 | config=config,
80 | model=model,
81 | datamodule=datamodule,
82 | trainer=trainer,
83 | callbacks=callbacks,
84 | logger=logger,
85 | )
86 |
87 | # Train the model
88 | if config.get("train"):
89 | log.info("Starting training!")
90 | trainer.fit(model=model, datamodule=datamodule)
91 |
92 | # Get metric score for hyperparameter optimization
93 | optimized_metric = config.get("optimized_metric")
94 | if optimized_metric and optimized_metric not in trainer.callback_metrics:
95 | raise Exception(
96 | "Metric for hyperparameter optimization not found! "
97 | "Make sure the `optimized_metric` in `hparams_search` config is correct!"
98 | )
99 | score = trainer.callback_metrics.get(optimized_metric)
100 |
101 | # Test the model
102 | if config.get("test"):
103 | ckpt_path = "best"
104 | if not config.get("train") or config.trainer.get("fast_dev_run"):
105 | ckpt_path = None
106 | log.info("Starting testing!")
107 | model.global_avg = datamodule.data_test.global_avg
108 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
109 |
110 | # Make sure everything closed properly
111 | log.info("Finalizing!")
112 | utils.finish(
113 | config=config,
114 | model=model,
115 | datamodule=datamodule,
116 | trainer=trainer,
117 | callbacks=callbacks,
118 | logger=logger,
119 | )
120 |
121 | # Print path to best checkpoint
122 | if not config.trainer.get("fast_dev_run") and config.get("train"):
123 | log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}")
124 |
125 | # Return metric score for hyperparameter optimization
126 | return score
127 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | from typing import List, Sequence
4 |
5 | import pytorch_lightning as pl
6 | import rich.syntax
7 | import rich.tree
8 | from omegaconf import DictConfig, OmegaConf
9 | from pytorch_lightning.utilities import rank_zero_only
10 |
11 |
12 | def get_logger(name=__name__) -> logging.Logger:
13 | """Initializes multi-GPU-friendly python command line logger."""
14 |
15 | logger = logging.getLogger(name)
16 |
17 | # this ensures all logging levels get marked with the rank zero decorator
18 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup
19 | for level in (
20 | "debug",
21 | "info",
22 | "warning",
23 | "error",
24 | "exception",
25 | "fatal",
26 | "critical",
27 | ):
28 | setattr(logger, level, rank_zero_only(getattr(logger, level)))
29 |
30 | return logger
31 |
32 |
33 | log = get_logger(__name__)
34 |
35 |
36 | def extras(config: DictConfig) -> None:
37 | """Applies optional utilities, controlled by config flags.
38 |
39 | Utilities:
40 | - Ignoring python warnings
41 | - Rich config printing
42 | """
43 |
44 |     # disable python warnings if <config.ignore_warnings=True>
45 |     if config.get("ignore_warnings"):
46 |         log.info("Disabling python warnings! <config.ignore_warnings=True>")
47 | warnings.filterwarnings("ignore")
48 |
49 |     # pretty print config tree using Rich library if <config.print_config=True>
50 |     if config.get("print_config"):
51 |         log.info("Printing config tree with Rich! <config.print_config=True>")
52 | print_config(config, resolve=True)
53 |
54 |
55 | @rank_zero_only
56 | def print_config(
57 | config: DictConfig,
58 | print_order: Sequence[str] = (
59 | "datamodule",
60 | "model",
61 | "callbacks",
62 | "logger",
63 | "trainer",
64 | ),
65 | resolve: bool = True,
66 | ) -> None:
67 | """Prints content of DictConfig using Rich library and its tree structure.
68 |
69 | Args:
70 | config (DictConfig): Configuration composed by Hydra.
71 | print_order (Sequence[str], optional): Determines in what order config components are printed.
72 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
73 | """
74 |
75 | style = "dim"
76 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
77 |
78 | quee = []
79 |
80 | for field in print_order:
81 | quee.append(field) if field in config else log.info(
82 | f"Field '{field}' not found in config"
83 | )
84 |
85 | for field in config:
86 | if field not in quee:
87 | quee.append(field)
88 |
89 | for field in quee:
90 | branch = tree.add(field, style=style, guide_style=style)
91 |
92 | config_group = config[field]
93 | if isinstance(config_group, DictConfig):
94 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
95 | else:
96 | branch_content = str(config_group)
97 |
98 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
99 |
100 | rich.print(tree)
101 |
102 | with open("config_tree.log", "w") as file:
103 | rich.print(tree, file=file)
104 |
105 |
106 | @rank_zero_only
107 | def log_hyperparameters(
108 | config: DictConfig,
109 | model: pl.LightningModule,
110 | datamodule: pl.LightningDataModule,
111 | trainer: pl.Trainer,
112 | callbacks: List[pl.Callback],
113 | logger: List[pl.loggers.LightningLoggerBase],
114 | ) -> None:
115 | """Controls which config parts are saved by Lightning loggers.
116 |
117 |     Additionally saves:
118 | - number of model parameters
119 | """
120 |
121 | if not trainer.logger:
122 | return
123 |
124 | hparams = {}
125 |
126 | # choose which parts of hydra config will be saved to loggers
127 | hparams["model"] = config["model"]
128 |
129 | # save number of model parameters
130 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
131 | hparams["model/params/trainable"] = sum(
132 | p.numel() for p in model.parameters() if p.requires_grad
133 | )
134 | hparams["model/params/non_trainable"] = sum(
135 | p.numel() for p in model.parameters() if not p.requires_grad
136 | )
137 |
138 | hparams["datamodule"] = config["datamodule"]
139 | hparams["trainer"] = config["trainer"]
140 |
141 | if "seed" in config:
142 | hparams["seed"] = config["seed"]
143 | if "callbacks" in config:
144 | hparams["callbacks"] = config["callbacks"]
145 |
146 | # send hparams to all loggers
147 | trainer.logger.log_hyperparams(hparams)
148 |
149 |
150 | def finish(
151 | config: DictConfig,
152 | model: pl.LightningModule,
153 | datamodule: pl.LightningDataModule,
154 | trainer: pl.Trainer,
155 | callbacks: List[pl.Callback],
156 | logger: List[pl.loggers.LightningLoggerBase],
157 | ) -> None:
158 | """Makes sure everything closed properly."""
159 |
160 | # without this sweeps with wandb logger might crash!
161 | for lg in logger:
162 | if isinstance(lg, pl.loggers.wandb.WandbLogger):
163 | import wandb
164 |
165 | wandb.finish()
166 |
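# Illustrative wiring (a sketch, not code copied from this repo): in a Hydra-driven
# pipeline such as src/training_pipeline.py these helpers are typically used as
#
#     from src import utils
#
#     log = utils.get_logger(__name__)
#     utils.extras(config)
#     utils.log_hyperparameters(config=config, model=model, datamodule=datamodule,
#                               trainer=trainer, callbacks=callbacks, logger=logger)
#     trainer.fit(model=model, datamodule=datamodule)
#     utils.finish(config=config, model=model, datamodule=datamodule,
#                  trainer=trainer, callbacks=callbacks, logger=logger)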
--------------------------------------------------------------------------------
/src/models/conv1d_module.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning import LightningModule
7 | from torch.autograd import Variable
8 |
9 | from src.models.components.conv_block import ConvBlock
10 | from src.utils.metrics import rmse, rsquared, smape
11 | from sklearn.metrics import r2_score
12 | from src.utils.plotting import make_heatmap, make_pred_vs_target_plot
13 |
14 | import seaborn as sns
15 | import matplotlib.pyplot as plt
16 |
17 |
18 |
19 | class Conv1dModule(LightningModule):
20 | """Example of LightningModule for MNIST classification.
21 |
22 | A LightningModule organizes your PyTorch code into 5 sections:
23 | - Computations (init).
24 | - Train loop (training_step)
25 | - Validation loop (validation_step)
26 | - Test loop (test_step)
27 | - Optimizers (configure_optimizers)
28 |
29 | Read the docs:
30 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
31 | """
32 |
33 | def __init__(
34 | self,
35 | n_cells_hor: int = 200,
36 | n_cells_ver: int = 250,
37 | history_length: int = 1,
38 | periods_forward: int = 1,
39 | batch_size: int = 1,
40 | lr: float = 0.003,
41 | weight_decay: float = 0.0,
42 | ):
43 | super(Conv1dModule, self).__init__()
44 |
45 | # this line allows to access init params with 'self.hparams' attribute
46 | # it also ensures init params will be stored in ckpt
47 | self.save_hyperparameters(logger=False)
48 |
49 | self.n_cells_hor = n_cells_hor
50 | self.n_cells_ver = n_cells_ver
51 | self.history_length = history_length
52 | self.periods_forward = periods_forward
53 | self.batch_size = batch_size
54 | self.lr = lr
55 | self.weight_decay = weight_decay
56 |
57 |
58 | self.conv1x1 = nn.Conv2d(
59 | self.history_length,
60 | self.periods_forward,
61 | kernel_size=1,
62 | stride=1,
63 | padding=0,
64 | bias=False,
65 | )
66 |
67 | # loss
68 | self.criterion = nn.MSELoss()
69 |
70 | def forward(self, x: torch.Tensor):
71 |
72 | prediction = self.conv1x1(x)
73 |
74 | return prediction
75 |
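    # Shape sketch for `forward` above (illustrative, not asserted anywhere in the code):
    #   x:          (batch, history_length, H, W)
    #   prediction: (batch, periods_forward, H, W)
    # The 1x1 convolution only remaps the history channels to periods_forward output
    # channels; the spatial grid is left untouched.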
76 | def step(self, batch: Any):
77 | x, y = batch
78 | preds = self.forward(x)
79 | loss = self.criterion(preds, y)
80 | return loss, preds, y
81 |
82 | def training_step(self, batch: Any, batch_idx: int):
83 | loss, preds, targets = self.step(batch)
84 |
85 | self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
86 |
87 |         # here we can return a dict with any tensors
88 |         # and then read it in some callback or in `training_epoch_end()` below
89 |         # remember to always return the loss from `training_step()`, or else backpropagation will fail!
90 | return {"loss": loss, "preds": preds, "targets": targets}
91 |
92 | def training_epoch_end(self, outputs: List[Any]):
93 | # `outputs` is a list of dicts returned from `training_step()`
94 | all_preds = outputs[0]["preds"]
95 | all_targets = outputs[0]["targets"]
96 |
97 | for i in range(1, len(outputs)):
98 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0)
99 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0)
100 |
101 | # log metrics
102 | r2table = rsquared(all_targets, all_preds, mode="mean")
103 | self.log("train/R2_std", np.std(r2table), on_epoch=True, prog_bar=True)
104 | self.log("train/R2", np.median(r2table), on_epoch=True, prog_bar=True)
105 | self.log("train/R2_min", np.min(r2table), on_epoch=True, prog_bar=True)
106 | self.log("train/R2_max", np.max(r2table), on_epoch=True, prog_bar=True)
107 | self.log("train/MSE", rmse(all_targets, all_preds), on_epoch=True, prog_bar=True)
108 |
109 | def validation_step(self, batch: Any, batch_idx: int):
110 | loss, preds, targets = self.step(batch)
111 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
112 |
113 | return {"loss": loss, "preds": preds, "targets": targets}
114 |
115 | def validation_epoch_end(self, outputs: List[Any]):
116 | all_preds = outputs[0]["preds"]
117 | all_targets = outputs[0]["targets"]
118 |
119 | for i in range(1, len(outputs)):
120 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0)
121 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0)
122 |
123 | # log metrics
124 | r2table = rsquared(all_targets, all_preds, mode="mean")
125 | self.log("val/R2_std", np.std(r2table), on_epoch=True, prog_bar=True)
126 | self.log("val/R2", np.median(r2table), on_epoch=True, prog_bar=True)
127 | self.log("val/R2_min", np.min(r2table), on_epoch=True, prog_bar=True)
128 | self.log("val/R2_max", np.max(r2table), on_epoch=True, prog_bar=True)
129 | self.log("val/MSE", rmse(all_targets, all_preds), on_epoch=True, prog_bar=True)
130 |
131 |
132 | def test_step(self, batch: Any, batch_idx: int):
133 | loss, preds, targets = self.step(batch)
134 | self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
135 |
136 | return {"loss": loss, "preds": preds, "targets": targets}
137 |
138 | def test_epoch_end(self, outputs: List[Any]):
139 | all_preds = outputs[0]["preds"]
140 | all_targets = outputs[0]["targets"]
141 |
142 | for i in range(1, len(outputs)):
143 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0)
144 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0)
145 |
146 | print(all_preds.shape)
147 | print(all_targets.shape)
148 |
149 | # log metrics
150 | test_r2table = rsquared(all_targets, all_preds, mode="full")
151 | self.log("test/R2_std", np.std(test_r2table), on_epoch=True, prog_bar=True)
152 | self.log(
153 | "test/R2_median", np.median(test_r2table), on_epoch=True, prog_bar=True
154 | )
155 | self.log("test/R2_min", np.min(test_r2table), on_epoch=True, prog_bar=True)
156 | self.log("test/R2_max", np.max(test_r2table), on_epoch=True, prog_bar=True)
157 | self.log("test/MSE", rmse(all_targets, all_preds), on_epoch=True, prog_bar=True)
158 |
159 | # log graphs
160 | mse_conv = []
161 | r2_conv = []
162 |
163 | for i in range(1, self.periods_forward+1):
164 |
165 | preds_i = all_preds[:, :i, :, :]
166 | targets_i = all_targets[:, :i, :, :]
167 | mse_conv.append(rmse(preds_i, targets_i).item())
168 | mean_preds = torch.mean(preds_i, axis=(2, 3))
169 | mean_targets = torch.mean(targets_i, axis=(2, 3))
170 | r2_conv.append(r2_score(mean_targets.cpu().numpy(), mean_preds.cpu().numpy()))
171 |
172 | h = [i for i in range(1, self.periods_forward+1)]
173 | fig1 = plt.figure(figsize=(7, 7))
174 | ax = fig1.add_subplot(1, 1, 1)
175 |         sns.lineplot(x=h, y=mse_conv, ax=ax)
176 |         ax.legend(['conv1d'])
177 |         ax.set_xlabel("horizon (in months)")
178 |         ax.set_title("MSE")
179 |         fig1.savefig("conv_mse.png")
180 |
181 |         fig2 = plt.figure(figsize=(7, 7))
182 |         ax = fig2.add_subplot(1, 1, 1)
183 |         sns.lineplot(x=h, y=r2_conv, ax=ax)
184 |         ax.legend(['conv1d'])
185 |         ax.set_xlabel("horizon (in months)")
186 |         ax.set_title("R2")
187 |         fig2.savefig("conv_r2.png")
188 |
189 |
190 | def on_epoch_end(self):
191 | pass
192 |
193 | def configure_optimizers(self):
194 | """Choose what optimizers and learning-rate schedulers to use in your optimization.
195 | Normally you'd need one. But in the case of GANs or similar you might have multiple.
196 |
197 | See examples here:
198 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
199 | """
200 | return torch.optim.Adam(
201 | params=self.parameters(),
202 | lr=self.hparams.lr,
203 | weight_decay=self.hparams.weight_decay,
204 | )
205 |
--------------------------------------------------------------------------------
/src/datamodules/weather_datamodule.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from typing import Optional, Tuple, List
3 |
4 | import torch
5 | from pytorch_lightning import LightningDataModule
6 | from torch.utils.data import DataLoader, Dataset
7 |
8 | from src.utils.data_utils import create_celled_data
9 | from src import utils
10 |
11 | log = utils.get_logger(__name__)
12 |
13 |
14 | class Dataset_RNN(Dataset):
15 |     """
16 |     Simple Torch Dataset for many-to-many RNN
17 |     celled_data: source tensor of the target variable,
18 |     celled_features_list: list of additional feature tensors,
19 |     start_date / end_date: start and end date indices,
20 |     periods_forward: number of future periods for a target,
21 |     history_length: number of past periods for an input,
22 |     boundaries, mode, normalize, moments, global_avg: binning, task mode and normalization settings
23 |     """
24 |
25 | def __init__(
26 | self,
27 | celled_data: torch.Tensor,
28 | celled_features_list: List[torch.Tensor],
29 | start_date: int,
30 | end_date: int,
31 | periods_forward: int,
32 | history_length: int,
33 |         boundaries: Optional[torch.Tensor],
34 |         mode: str,
35 |         normalize: bool,
36 |         moments: Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
37 |         global_avg: Optional[torch.Tensor],
38 | ):
39 | self.data = celled_data[start_date:end_date, :, :]
40 | self.features = [
41 | feature[start_date:end_date, :, :]
42 | for feature in celled_features_list
43 | ]
44 | self.periods_forward = periods_forward
45 | self.history_length = history_length
46 | self.mode = mode
47 | self.target = self.data
48 | # bins for pdsi
49 | self.boundaries = boundaries
50 | if self.mode == "classification":
51 | # 1 is for drought
52 | self.target = 1 - torch.bucketize(self.target, self.boundaries)
53 | if len(global_avg) > 0:
54 | self.global_avg = global_avg
55 | else:
56 | self.global_avg, _ = torch.mode(self.target, dim=0)
57 | # normalization
58 | if moments:
59 | self.moments = moments
60 | if normalize:
61 | self.data = (self.data - self.moments[0][0]) / (
62 | self.moments[0][1] - self.moments[0][0]
63 | )
64 | for i in range(1, len(self.moments)):
65 | self.features[i - 1] = (
66 | self.features[i - 1] - self.moments[i][0]
67 | ) / (self.moments[i][1] - self.moments[i][0])
68 | else:
69 | self.moments = []
70 | if normalize:
71 | self.data = (self.data - torch.min(self.data)) / (
72 | torch.max(self.data) - torch.min(self.data)
73 | )
74 | self.moments.append((torch.min(self.data), torch.max(self.data)))
75 | for i in range(len(self.features)):
76 | self.features[i] = (
77 | self.features[i] - torch.min(self.features[i])
78 | ) / (torch.max(self.features[i]) - torch.min(self.features[i]))
79 | self.moments.append(
80 | (torch.min(self.features[i]), torch.max(self.features[i]))
81 | )
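        # Note on the normalization above: `self.moments` collects one (min, max) pair per
        # channel (the target first, then each additional feature). WeatherDataModule.setup()
        # passes the training split's moments into the val/test datasets, so every split is
        # scaled with training statistics.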
82 |
83 | def __len__(self):
84 | return len(self.data) - self.periods_forward - self.history_length
85 |
86 | def __getitem__(self, idx):
87 | input_tensor = self.data[idx : idx + self.history_length]
88 | for feature in self.features:
89 | input_tensor = torch.cat(
90 | (input_tensor, feature[idx : idx + self.history_length]), dim=0
91 | )
92 |
93 | target = self.target[
94 | idx + self.history_length : idx + self.history_length + self.periods_forward
95 | ]
96 |
97 | return (
98 | input_tensor,
99 | target,
100 | )
101 |
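    # Shape sketch for `__getitem__` above (illustrative): with an H x W spatial grid,
    #   input_tensor: ((1 + num_features) * history_length, H, W)  - target history plus
    #                 each additional feature's history, concatenated along dim 0
    #   target:       (periods_forward, H, W)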
102 |
103 | class WeatherDataModule(LightningDataModule):
104 | """LightningDataModule for Weather dataset.
105 |
106 | A DataModule implements 5 key methods:
107 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
108 | - setup (things to do on every accelerator in distributed mode)
109 | - train_dataloader (the training dataloader)
110 | - val_dataloader (the validation dataloader(s))
111 | - test_dataloader (the test dataloader(s))
112 |
113 | This allows you to share a full dataset without explaining how to download,
114 | split, transform and process the data.
115 |
116 | Read the docs:
117 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
118 | """
119 |
120 | def __init__(
121 | self,
122 | mode: str = "regression",
123 | data_dir: str = "data",
124 | dataset_name: str = "dataset_name",
125 | left_border: int = 0,
126 | down_border: int = 0,
127 | right_border: int = 2000,
128 | up_border: int = 2500,
129 | time_col: str = "time",
130 | event_col: str = "value",
131 | x_col: str = "x",
132 | y_col: str = "y",
133 |         train_val_test_split: Tuple[float, float, float] = (0.8, 0.1, 0.1),
134 | periods_forward: int = 1,
135 | history_length: int = 1,
136 | data_start: int = 0,
137 | data_len: int = 100,
138 | feature_to_predict: str = "pdsi",
139 | num_of_additional_features: int = 0,
140 | additional_features: Optional[List[str]] = None,
141 |         boundaries: Optional[List[float]] = None,
142 | patch_size: int = 8,
143 | normalize: bool = False,
144 | batch_size: int = 64,
145 | num_workers: int = 0,
146 | pin_memory: bool = False,
147 | ):
148 | super().__init__()
149 |
150 | # this line allows to access init params with 'self.hparams' attribute
151 | self.save_hyperparameters(logger=False)
152 | self.mode = mode
153 | self.data_dir = data_dir
154 | self.dataset_name = dataset_name
155 | self.left_border = left_border
156 | self.right_border = right_border
157 | self.down_border = down_border
158 | self.up_border = up_border
159 | self.h = 0
160 | self.w = 0
161 | self.time_col = time_col
162 | self.event_col = event_col
163 | self.x_col = x_col
164 | self.y_col = y_col
165 |
166 | self.data_train: Optional[Dataset] = None
167 | self.data_val: Optional[Dataset] = None
168 | self.data_test: Optional[Dataset] = None
169 | self.train_val_test_split = train_val_test_split
170 | self.periods_forward = periods_forward
171 | self.history_length = history_length
172 | self.data_start = data_start
173 | self.data_len = data_len
174 | self.feature_to_predict = feature_to_predict
175 | self.num_of_features = num_of_additional_features + 1
176 | self.additional_features = additional_features
177 | self.boundaries = torch.Tensor(boundaries)
178 | self.patch_size = patch_size
179 | self.normalize = normalize
180 |
181 | self.batch_size = batch_size
182 | self.num_workers = num_workers
183 | self.pin_memory = pin_memory
184 |
185 | def prepare_data(self):
186 | """Download data if needed.
187 |
188 | This method is called only from a single GPU.
189 | Do not use it to assign state (self.x = y).
190 | """
191 | celled_data_path = pathlib.Path(self.data_dir, "celled", self.dataset_name)
192 | if not celled_data_path.is_file():
193 | celled_data = create_celled_data(
194 | self.data_dir,
195 | self.dataset_name,
196 | self.time_col,
197 | self.event_col,
198 | self.x_col,
199 | self.y_col,
200 | )
201 | log.info(f"Original dataset shape: {celled_data.shape}")
202 | torch.save(celled_data, celled_data_path)
203 |
204 | data_dir_geo = self.dataset_name.split(self.feature_to_predict)[1]
205 | for feature in self.additional_features:
206 | celled_feature_path = pathlib.Path(
207 | self.data_dir, "celled", feature + data_dir_geo
208 | )
209 | if not celled_feature_path.is_file():
210 | celled_feature = create_celled_data(
211 | self.data_dir,
212 | feature + data_dir_geo,
213 | self.time_col,
214 | self.event_col,
215 | self.x_col,
216 | self.y_col,
217 | )
218 | torch.save(celled_feature, celled_feature_path)
219 |
220 | def setup(self, stage: Optional[str] = None):
221 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
222 |
223 | This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
224 | so be careful not to execute the random split twice! The `stage` can be used to
225 |         differentiate whether it's called before `trainer.fit()` or `trainer.test()`.
226 | """
227 |
228 | # load datasets only if they're not loaded already
229 | if not self.data_train and not self.data_val and not self.data_test:
230 | celled_data_path = pathlib.Path(self.data_dir, "celled", self.dataset_name)
231 | celled_data = torch.load(celled_data_path)
232 | # make borders divisible by patch size
233 | h = celled_data.shape[1]
234 | self.h = h - h % self.patch_size
235 | w = celled_data.shape[2]
236 | self.w = w - w % self.patch_size
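                # e.g. (illustrative values) with patch_size=8: h=250 -> self.h=248, w=203 -> self.w=200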
237 | celled_data = celled_data[:, :self.h, :self.w]
238 | celled_data = celled_data[
239 | self.data_start : self.data_start + self.data_len,
240 | self.down_border : self.up_border,
241 | self.left_border : self.right_border,
242 | ]
243 | # loading features
244 | celled_features_list = []
245 | data_dir_geo = self.dataset_name.split(self.feature_to_predict)[1]
246 | for feature in self.additional_features:
247 | celled_feature_path = pathlib.Path(
248 | self.data_dir, "celled", feature + data_dir_geo
249 | )
250 | celled_feature = torch.load(celled_feature_path)
251 | # make borders divisible by patch size
252 | h = celled_feature.shape[1]
253 | h = h - h % self.patch_size
254 | w = celled_feature.shape[2]
255 | w = w - w % self.patch_size
256 | celled_feature = celled_feature[:, :h, :w]
257 | celled_feature = celled_feature[
258 | self.data_start : self.data_start + self.data_len,
259 | self.down_border : self.up_border,
260 | self.left_border : self.right_border,
261 | ]
262 | celled_features_list.append(celled_feature)
263 |
264 | train_start = 0
265 | train_end = int(self.train_val_test_split[0] * celled_data.shape[0])
266 | self.data_train = Dataset_RNN(
267 | celled_data,
268 | celled_features_list,
269 | train_start,
270 | train_end,
271 | self.periods_forward,
272 | self.history_length,
273 | self.boundaries,
274 | self.mode,
275 | self.normalize,
276 | [],
277 | [],
278 | )
279 | # valid_end = int(
280 | # (self.train_val_test_split[0] + self.train_val_test_split[1])
281 | # * celled_data.shape[0]
282 | # )
283 | valid_end = celled_data.shape[0]
284 | self.data_val = Dataset_RNN(
285 | celled_data,
286 | celled_features_list,
287 | train_end - self.history_length,
288 | valid_end,
289 | self.periods_forward,
290 | self.history_length,
291 | self.boundaries,
292 | self.mode,
293 | self.normalize,
294 | self.data_train.moments,
295 | self.data_train.global_avg,
296 | )
297 | test_end = celled_data.shape[0]
298 | self.data_test = Dataset_RNN(
299 | celled_data,
300 | celled_features_list,
301 | train_end - self.history_length,
302 | test_end,
303 | self.periods_forward,
304 | self.history_length,
305 | self.boundaries,
306 | self.mode,
307 | self.normalize,
308 | self.data_train.moments,
309 | self.data_train.global_avg,
310 | )
311 | log.info(f"train dataset shape {self.data_train.data.shape}")
312 | log.info(f"val dataset shape {self.data_val.data.shape}")
313 | log.info(f"test dataset shape {self.data_test.data.shape}")
314 |
315 | def train_dataloader(self):
316 | return DataLoader(
317 | self.data_train,
318 | batch_size=self.batch_size,
319 | num_workers=self.num_workers,
320 | pin_memory=self.pin_memory,
321 | shuffle=False,
322 | )
323 |
324 | def val_dataloader(self):
325 | return DataLoader(
326 | self.data_val,
327 | batch_size=self.batch_size,
328 | num_workers=self.num_workers,
329 | pin_memory=self.pin_memory,
330 | shuffle=False,
331 | )
332 |
333 | def test_dataloader(self):
334 | return DataLoader(
335 | self.data_test,
336 | batch_size=self.batch_size,
337 | num_workers=self.num_workers,
338 | pin_memory=self.pin_memory,
339 | shuffle=False,
340 | )
341 |
--------------------------------------------------------------------------------
/src/models/rcnn_module.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Any, List
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from pytorch_lightning import LightningModule
8 | from sklearn.metrics import r2_score, roc_auc_score
9 | from torch.autograd import Variable
10 |
11 | from src.models.components.conv_block import ConvBlock
12 | from src.utils.metrics import metrics_celled
13 | from src.utils.plotting import make_heatmap, make_cf_matrix
14 |
15 |
16 | class ScaledTanh(nn.Module):
17 | def __init__(self, coef: int = 10):
18 | super().__init__()
19 | self.c = coef
20 |
21 | def forward(self, x):
22 | output = torch.mul(torch.tanh(x), self.c)
23 | return output
24 |
25 |
26 | class RCNNModule(LightningModule):
27 | """Example of LightningModule for MNIST classification.
28 |
29 | A LightningModule organizes your PyTorch code into 5 sections:
30 | - Computations (init).
31 | - Train loop (training_step)
32 | - Validation loop (validation_step)
33 | - Test loop (test_step)
34 | - Optimizers (configure_optimizers)
35 |
36 | Read the docs:
37 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
38 | """
39 |
40 | def __init__(
41 | self,
42 | mode: str = "regression",
43 | embedding_size: int = 16,
44 | hidden_state_size: int = 32,
45 | kernel_size: int = 3,
46 | groups: int = 1,
47 | dilation: int = 1,
48 | n_cells_hor: int = 200,
49 | n_cells_ver: int = 250,
50 | history_length: int = 1,
51 | periods_forward: int = 1,
52 | batch_size: int = 1,
53 | num_of_additional_features: int = 0,
54 | boundaries: List[int] = [-2],
55 | num_classes: int = 2,
56 | values_range: int = 10,
57 | dropout: float = 0.0,
58 | lr: float = 0.003,
59 | weight_decay: float = 0.0,
60 | ):
61 |         super().__init__()
62 |
63 | # this line allows to access init params with 'self.hparams' attribute
64 | # it also ensures init params will be stored in ckpt
65 | self.save_hyperparameters(logger=False)
66 |
67 | self.n_cells_hor = n_cells_hor
68 | self.n_cells_ver = n_cells_ver
69 | self.history_length = history_length
70 | self.periods_forward = periods_forward
71 | self.batch_size = batch_size
72 | self.lr = lr
73 | self.weight_decay = weight_decay
74 |
75 | self.num_of_features = num_of_additional_features + 1
76 | self.tanh_coef = values_range
77 | # number of bins for pdsi
78 | self.dropout = torch.nn.Dropout2d(p=dropout)
79 | self.num_class = num_classes
80 | self.boundaries = torch.tensor(boundaries).cuda()
81 |
82 | self.emb_size = embedding_size
83 | self.hid_size = hidden_state_size
84 | self.kernel_size = kernel_size
85 | self.dilation = dilation
86 | self.groups = groups
87 |
88 | self.global_avg = None
89 | self.saved_predictions = None
90 | self.saved_targets = None
91 |
92 | self.embedding = nn.Sequential(
93 | ConvBlock(
94 | self.num_of_features * history_length,
95 | self.emb_size,
96 | self.kernel_size,
97 | stride=1,
98 | padding=self.kernel_size // 2,
99 | dilation=self.dilation,
100 | groups=self.groups,
101 | ),
102 | nn.ReLU(),
103 | ConvBlock(
104 | self.emb_size,
105 | self.emb_size,
106 | self.kernel_size,
107 | stride=1,
108 | padding=self.kernel_size // 2,
109 | ),
110 | )
111 |
112 | self.f_t = nn.Sequential(
113 | ConvBlock(
114 | self.hid_size + self.emb_size,
115 | self.hid_size,
116 | self.kernel_size,
117 | stride=1,
118 | padding=self.kernel_size // 2,
119 | ),
120 | nn.Sigmoid(),
121 | )
122 | self.i_t = nn.Sequential(
123 | ConvBlock(
124 | self.hid_size + self.emb_size,
125 | self.hid_size,
126 | self.kernel_size,
127 | stride=1,
128 | padding=self.kernel_size // 2,
129 | ),
130 | nn.Sigmoid(),
131 | )
132 | self.c_t = nn.Sequential(
133 | ConvBlock(
134 | self.hid_size + self.emb_size,
135 | self.hid_size,
136 | self.kernel_size,
137 | stride=1,
138 | padding=self.kernel_size // 2,
139 | ),
140 | nn.Tanh(),
141 | )
142 | self.o_t = nn.Sequential(
143 | ConvBlock(
144 | self.hid_size + self.emb_size,
145 | self.hid_size,
146 | self.kernel_size,
147 | stride=1,
148 | padding=self.kernel_size // 2,
149 | ),
150 | nn.Sigmoid(),
151 | )
152 |
153 | self.final_conv = nn.Sequential(
154 | nn.Conv2d(
155 | self.hid_size,
156 | self.periods_forward,
157 | kernel_size=1,
158 | stride=1,
159 | padding=0,
160 | bias=False,
161 | ),
162 | ScaledTanh(self.tanh_coef),
163 | # nn.Tanh(),
164 | # nn.Conv2d(
165 | # self.hid_size,
166 | # self.periods_forward,
167 | # kernel_size=1,
168 | # stride=1,
169 | # padding=0,
170 | # bias=False,
171 | # ),
172 | )
173 |
174 | self.final_classify = nn.Sequential(
175 | ConvBlock(
176 | self.hid_size,
177 | self.num_class,
178 | kernel_size=3,
179 | stride=1,
180 | padding=1,
181 | dilation=1,
182 | groups=1,
183 | ),
184 | nn.Sigmoid(),
185 | )
186 |
187 | self.register_buffer(
188 | "prev_state_h",
189 | torch.zeros(
190 | self.batch_size,
191 | self.hid_size,
192 | self.n_cells_hor,
193 | self.n_cells_ver,
194 | requires_grad=False,
195 | ),
196 | )
197 | self.register_buffer(
198 | "prev_state_c",
199 | torch.zeros(
200 | self.batch_size,
201 | self.hid_size,
202 | self.n_cells_hor,
203 | self.n_cells_ver,
204 | requires_grad=False,
205 | ),
206 | )
207 |
208 | self.mode = mode
209 | # loss
210 | if self.mode == "regression":
211 | self.criterion = nn.MSELoss()
212 | self.loss_name = "MSE"
213 | else:
214 | self.criterion = nn.CrossEntropyLoss()
215 | self.loss_name = "CrossEntropy"
216 |
217 | def forward(self, x: torch.Tensor):
218 | prev_c = self.prev_state_c
219 | prev_h = self.prev_state_h
220 | x = self.dropout(x)
221 | x_emb = self.embedding(x)
222 | if x_emb.shape[0] < self.batch_size:
223 | x_emb = torch.nn.functional.pad(
224 |                 x_emb, pad=(0, 0, 0, 0, 0, 0, 0, self.batch_size - x_emb.shape[0]), value=0
225 | )
226 | x_and_h = torch.cat([prev_h, x_emb], dim=1)
227 |
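        # The four gate branches below implement a ConvLSTM-style cell update; as a sketch
        # (Conv = ConvBlock, [.,.] = channel-wise concatenation):
        #   f_t = sigmoid(Conv([h_{t-1}, x_t]))   # forget gate
        #   i_t = sigmoid(Conv([h_{t-1}, x_t]))   # input gate
        #   g_t = tanh(Conv([h_{t-1}, x_t]))      # candidate cell state (self.c_t)
        #   o_t = sigmoid(Conv([h_{t-1}, x_t]))   # output gate
        #   c_t = f_t * c_{t-1} + i_t * g_t
        #   h_t = o_t * tanh(c_t)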
228 | f_i = self.f_t(x_and_h)
229 | i_i = self.i_t(x_and_h)
230 | c_i = self.c_t(x_and_h)
231 | o_i = self.o_t(x_and_h)
232 |
233 | # print("prev_c", prev_c.shape)
234 | # print("f_i", f_i.shape)
235 | # print("i_i", i_i.shape)
236 | # print("c_i", c_i.shape)
237 |
238 | next_c = prev_c * f_i + i_i * c_i
239 | next_h = torch.tanh(next_c) * o_i
240 |
241 | assert prev_h.shape == next_h.shape
242 | assert prev_c.shape == next_c.shape
243 |
244 | if self.mode == "regression":
245 | prediction = self.final_conv(next_h)
246 | elif self.mode == "classification":
247 | prediction = self.final_classify(next_h)
248 | self.prev_state_c = next_c
249 | self.prev_state_h = next_h
250 |
251 | return prediction
252 |
253 | def step(self, batch: Any):
254 | x, y = batch
255 | preds = self.forward(x)
256 | if y.shape[0] < self.batch_size:
257 | y = torch.nn.functional.pad(
258 |                 y, pad=(0, 0, 0, 0, 0, 0, 0, self.batch_size - y.shape[0]), value=0
259 | )
260 | # checking last (forward) value of target
261 | loss = self.criterion(preds, y[:, -1, :, :])
262 | return loss, preds, y[:, -1, :, :]
263 |
264 | def rolling_step(self, batch: Any):
265 | x, y = batch
266 | # x -> B*Hist*W*H or B*(Hist*Feat)*W*H
267 | # pdsi is first feature in tensor
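        # Rolling-mean baseline sketch (illustrative, one cell): with history [1, 2, 3] the
        # first forecast is mean(1, 2, 3) = 2; the window then becomes [2, 3, 2], giving
        # mean = 7/3 for the next step, and so on for periods_forward steps.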
268 | x = x[:, : self.history_length, :, :]
269 | rolling = torch.mean(x, dim=1)
270 | rolling_forecast = rolling[:, None, :, :]
271 |
272 | for i in range(1, self.periods_forward):
273 | x = torch.cat((x[:, 1:, :, :], rolling[:, None, :, :]), dim=1)
274 | rolling = torch.mean(x, dim=1)
275 | rolling_forecast = torch.cat(
276 | (rolling_forecast, rolling[:, None, :, :]), dim=1
277 | )
278 |
279 | return rolling_forecast
280 |
281 | def class_baseline(self, batch: Any):
282 | x, y = batch
283 | # return most frequent class along history_dim
284 | x_binned = torch.bucketize(x, self.boundaries)
285 | most_freq_values, most_freq_indices = torch.mode(x_binned, dim=1)
286 | return most_freq_values
287 |
288 | def on_after_backward(self) -> None:
289 | self.prev_state_c.detach_()
290 | self.prev_state_h.detach_()
291 |
292 | def training_step(self, batch: Any, batch_idx: int):
293 | loss, preds, targets = self.step(batch)
294 | self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
295 |
296 |         # here we can return a dict with any tensors
297 |         # and then read it in some callback or in `training_epoch_end()` below
298 |         # remember to always return the loss from `training_step()`, or else backpropagation will fail!
299 | return {"loss": loss, "preds": preds, "targets": targets}
300 |
301 | def training_epoch_end(self, outputs: List[Any]):
302 | # `outputs` is a list of dicts returned from `training_step()`
303 | all_targets = outputs[0]["targets"]
304 | all_preds = outputs[0]["preds"]
305 |
306 | for i in range(1, len(outputs)):
307 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0)
308 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0)
309 |         # classification metrics below are only meaningful in classification mode
310 |         if self.mode == "classification":
311 |             all_preds = torch.softmax(all_preds, dim=1)
312 |             all_preds = all_preds[:, 1, :, :]
313 |             rocauc_table, ap_table, f1_table, acc_table, thr = metrics_celled(all_targets, all_preds)
314 | self.log(
315 | "train/f1_median",
316 | torch.median(f1_table),
317 | on_epoch=True,
318 | prog_bar=True,
319 | )
320 | self.log(
321 | "train/ap_median",
322 | torch.median(ap_table),
323 | on_epoch=True,
324 | prog_bar=True,
325 | )
326 | self.log(
327 | "train/rocauc_median",
328 | torch.median(rocauc_table),
329 | on_epoch=True,
330 | prog_bar=True,
331 | )
332 | self.log(
333 | "train/accuracy_median",
334 | torch.median(acc_table),
335 | on_epoch=True,
336 | prog_bar=True,
337 | )
338 |
339 | # log metrics
340 | # r2table = rsquared(all_targets, all_preds, mode="mean")
341 | # self.log("train/R2_std", np.std(r2table), on_epoch=True, prog_bar=True)
342 | # self.log("train/R2", np.median(r2table), on_epoch=True, prog_bar=True)
343 | # self.log("train/R2_min", np.min(r2table), on_epoch=True, prog_bar=True)
344 | # self.log("train/R2_max", np.max(r2table), on_epoch=True, prog_bar=True)
345 |
346 | def validation_step(self, batch: Any, batch_idx: int):
347 | loss, preds, targets = self.step(batch)
348 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
349 |
350 | return {"loss": loss, "preds": preds, "targets": targets}
351 |
352 | def validation_epoch_end(self, outputs: List[Any]):
353 | all_targets = outputs[0]["targets"]
354 | all_preds = outputs[0]["preds"]
355 |
356 | for i in range(1, len(outputs)):
357 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0)
358 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0)
359 |
360 |         # classification metrics below are only meaningful in classification mode
361 |         if self.mode == "classification":
362 |             all_preds = torch.softmax(all_preds, dim=1)
363 |             all_preds = all_preds[:, 1, :, :]
364 |             rocauc_table, ap_table, f1_table, acc_table, thr = metrics_celled(all_targets, all_preds)
365 | self.log(
366 | "val/f1_median",
367 | torch.median(f1_table),
368 | on_epoch=True,
369 | prog_bar=True,
370 | )
371 | self.log(
372 | "val/ap_median",
373 | torch.median(ap_table),
374 | on_epoch=True,
375 | prog_bar=True,
376 | )
377 | self.log(
378 | "val/rocauc_median",
379 | torch.median(rocauc_table),
380 | on_epoch=True,
381 | prog_bar=True,
382 | )
383 | self.log(
384 | "val/accuracy_median",
385 | torch.median(acc_table),
386 | on_epoch=True,
387 | prog_bar=True,
388 | )
389 |
390 |
391 | # log metrics
392 | # r2table = rsquared(all_targets, all_preds, mode="mean")
393 | # self.log("val/R2_std", np.std(r2table), on_epoch=True, prog_bar=True)
394 | # self.log("val/R2", np.median(r2table), on_epoch=True, prog_bar=True)
395 | # self.log("val/R2_min", np.min(r2table), on_epoch=True, prog_bar=True)
396 | # self.log("val/R2_max", np.max(r2table), on_epoch=True, prog_bar=True)
397 |
398 | def test_step(self, batch: Any, batch_idx: int):
399 | loss, preds, targets = self.step(batch)
400 | if self.mode == "regression":
401 | baseline = self.rolling_step(batch)
402 | elif self.mode == "classification":
403 | baseline = self.class_baseline(batch)
404 | self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
405 |
406 | return {"loss": loss, "preds": preds, "targets": targets, "baseline": baseline}
407 |
408 | def test_epoch_end(self, outputs: List[Any]):
409 | all_targets = outputs[0]["targets"]
410 | all_baselines = outputs[0]["baseline"]
411 | all_preds = outputs[0]["preds"]
412 |
413 | for i in range(1, len(outputs)):
414 | all_preds = torch.cat((all_preds, outputs[i]["preds"]), 0)
415 | all_targets = torch.cat((all_targets, outputs[i]["targets"]), 0)
416 | all_baselines = torch.cat((all_baselines, outputs[i]["baseline"]), 0)
417 |
418 | # remove padded values
419 | init_len = all_baselines.shape[0]
420 |         all_preds = all_preds[:init_len, :, :, :]
421 |         all_targets = all_targets[:init_len, :, :]
422 |
423 | all_preds = torch.softmax(all_preds, dim=1)
424 | # log confusion matrix
425 | preds_for_cm = torch.argmax(all_preds, dim=1)
426 | # self.logger.experiment[0].log_confusion_matrix(torch.flatten(all_targets), torch.flatten(preds_for_cm))
427 | # probability of first class
428 | all_preds = all_preds[:, 1, :, :]
429 |
430 | self.saved_predictions = all_preds
431 | self.saved_targets = all_targets
432 |
433 | # global baseline
434 | all_global_baselines = self.global_avg.to("cuda:0")
435 | all_global_baselines = all_global_baselines.unsqueeze(0).repeat(
436 | len(all_targets), 1, 1
437 | )
438 |         all_global_baselines = torch.where(all_global_baselines > 0, 100.0, 0.0)
439 | all_global_baselines = all_global_baselines.double()
440 | # all zeros baseline - no drought
441 | all_zeros = torch.zeros(
442 | all_preds.shape[0], all_preds.shape[1], all_preds.shape[2], dtype=torch.long,
443 | ).to("cuda:0")
444 | all_zeros = all_zeros.double()
445 | rocauc_table_zeros, ap_table_zeros, f1_table_zeros, acc_table_zeros, thr = metrics_celled(
446 | all_targets, all_zeros, "test"
447 | )
448 | rocauc_table_global, ap_table_global, f1_table_global, acc_table_global, thr = metrics_celled(
449 | all_targets, all_global_baselines, "test"
450 | )
451 | rocauc_table, ap_table, f1_table, acc_table, thr = metrics_celled(
452 | all_targets, all_preds, "test"
453 | )
454 | # log metrics
455 | if self.mode == "classification":
456 | self.log(
457 | "test/convlstm/f1_median",
458 | torch.median(f1_table),
459 | on_epoch=True,
460 | prog_bar=True,
461 | )
462 | self.log(
463 | "test/convlstm/ap_median",
464 | torch.median(ap_table),
465 | on_epoch=True,
466 | prog_bar=True,
467 | )
468 | self.log(
469 | "test/convlstm/rocauc_median",
470 | torch.median(rocauc_table),
471 | on_epoch=True,
472 | prog_bar=True,
473 | )
474 | self.log(
475 | "test/convlstm/accuracy_median",
476 | torch.median(acc_table),
477 | on_epoch=True,
478 | prog_bar=True,
479 | )
480 | self.log(
481 | "test/global/f1_median",
482 | torch.median(f1_table_global),
483 | on_epoch=True,
484 | prog_bar=True,
485 | )
486 | self.log(
487 | "test/global/ap_median",
488 | torch.median(ap_table_global),
489 | on_epoch=True,
490 | prog_bar=True,
491 | )
492 | self.log(
493 | "test/global/rocauc_median",
494 | torch.median(rocauc_table_global),
495 | on_epoch=True,
496 | prog_bar=True,
497 | )
498 | self.log(
499 | "test/global/accuracy_median",
500 | torch.median(acc_table_global),
501 | on_epoch=True,
502 | prog_bar=True,
503 | )
504 | self.log(
505 | "test/zeros/f1_median",
506 | torch.median(f1_table_zeros),
507 | on_epoch=True,
508 | prog_bar=True,
509 | )
510 | self.log(
511 | "test/zeros/ap_median",
512 | torch.median(ap_table_zeros),
513 | on_epoch=True,
514 | prog_bar=True,
515 | )
516 | self.log(
517 | "test/zeros/rocauc_median",
518 | torch.median(rocauc_table_zeros),
519 | on_epoch=True,
520 | prog_bar=True,
521 | )
522 | self.log(
523 | "test/zeros/accuracy_median",
524 | torch.median(acc_table_zeros),
525 | on_epoch=True,
526 | prog_bar=True,
527 | )
528 |
529 | rocauc_path = make_heatmap(rocauc_table, filename="rocauc_spatial.png")
530 | torch.save(rocauc_table, "rocauc_table.pt")
531 | self.logger.experiment[0].log_image(rocauc_path)
532 |
533 | # log metrics
534 | # test_r2table = rsquared(all_targets, all_preds, mode="full")
535 | # self.log("test/R2_std", np.std(test_r2table), on_epoch=True, prog_bar=True)
536 | # self.log(
537 | # "test/R2_median", np.median(test_r2table), on_epoch=True, prog_bar=True
538 | # )
539 | # self.log("test/R2_min", np.min(test_r2table), on_epoch=True, prog_bar=True)
540 | # self.log("test/R2_max", np.max(test_r2table), on_epoch=True, prog_bar=True)
541 | if self.mode == "regression":
542 | self.log(
543 | "test/baseline_MSE",
544 | self.criterion(all_baselines, all_targets),
545 | on_epoch=True,
546 | prog_bar=True,
547 | )
548 |
549 | def on_epoch_end(self):
550 | pass
551 |
552 | def configure_optimizers(self):
553 | """Choose what optimizers and learning-rate schedulers to use in your optimization.
554 | Normally you'd need one. But in the case of GANs or similar you might have multiple.
555 |
556 | See examples here:
557 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
558 | """
559 | return torch.optim.Adam(
560 | params=self.parameters(),
561 | lr=self.hparams.lr,
562 | weight_decay=self.hparams.weight_decay,
563 | )
564 |
--------------------------------------------------------------------------------