├── models ├── __init__.py ├── unet_parts.py ├── discriminator.py ├── unet_parts_depthwise_separable.py ├── layers.py ├── regression_SmaAt_UNet.py ├── regression_SmaAt_GNet.py ├── regression_SmaAt_GNet_aleatoric.py ├── unet_precip_regression_lightning.py ├── regression_GA_SmaAt_GNet_mnist.py └── regression_GA_SmaAt_GNet.py ├── utils ├── __init__.py └── dataset_precip.py ├── root.py ├── imgs ├── SmaAt-GNet.png ├── data_prep.png ├── GA-SmaAt-GNet.png └── Attention-PatchGAN.png ├── .gitignore ├── pyproject.toml ├── train_moving_mnist.py ├── train_precip.py ├── README.md ├── test_precip.py ├── grad-cam.py └── requirements.txt /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /root.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | ROOT_DIR = Path(__file__).parent 4 | -------------------------------------------------------------------------------- /imgs/SmaAt-GNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EloyReulen/GA-SmaAt-GNet/HEAD/imgs/SmaAt-GNet.png -------------------------------------------------------------------------------- /imgs/data_prep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EloyReulen/GA-SmaAt-GNet/HEAD/imgs/data_prep.png -------------------------------------------------------------------------------- /imgs/GA-SmaAt-GNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EloyReulen/GA-SmaAt-GNet/HEAD/imgs/GA-SmaAt-GNet.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | data 4 | runs 5 | lightning 6 | checkpoints 7 | .ruff_cache 8 | .mypy_cache 9 | -------------------------------------------------------------------------------- /imgs/Attention-PatchGAN.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EloyReulen/GA-SmaAt-GNet/HEAD/imgs/Attention-PatchGAN.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "smaat-unet" 3 | version = "0.1.0" 4 | description = "Code for the paper `SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture`" 5 | authors = ["Kevin Trebing "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">3.9.7,<4.0" 10 | tqdm = "^4.65.0" 11 | torch = { version = "^2.0.0+cu118" } # Sadly poetry still installs the CPU version 12 | torchvision = "^0.15.1" 13 | torchsummary = "^1.5.1" 14 | h5py = "^3.8.0" 15 | fastapi = ">=0.80" 16 | lightning = {extras = ["extra"], version = "^2.0.1.post0"} 17 | tensorboard = "^2.13.0" 18 | pytorch-msssim = "^1.0.0" 19 | grad-cam = "^1.5.0" 20 | 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | black = "^23.3.0" 24 | ruff = "^0.0.262" 25 | mypy = "^1.2.0" 26 | pre-commit = "^3.2.2" 27 | 28 | 29 | [build-system] 30 | requires = ["poetry-core"] 31 | build-backend = "poetry.core.masonry.api" 32 | 33 | 34 | [tool.black] 35 | line-length = 120 36 | 37 | [tool.mypy] 38 | python_version = "3.9" 39 | ignore_missing_imports = true 40 | 41 | 42 | [tool.ruff] 43 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. 44 | select = ["E", "F"] 45 | ignore = [] 46 | 47 | # Allow autofix for all enabled rules (when `--fix` is provided). 48 | fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] 49 | unfixable = [] 50 | 51 | # Exclude a variety of commonly ignored directories. 52 | exclude = [ 53 | ".bzr", 54 | ".direnv", 55 | ".eggs", 56 | ".git", 57 | ".hg", 58 | ".mypy_cache", 59 | ".nox", 60 | ".pants.d", 61 | ".pytype", 62 | ".ruff_cache", 63 | ".svn", 64 | ".tox", 65 | ".venv", 66 | "__pypackages__", 67 | "_build", 68 | "buck-out", 69 | "build", 70 | "dist", 71 | "node_modules", 72 | "venv", 73 | ] 74 | 75 | # Same as Black. 76 | line-length = 120 77 | 78 | # Allow unused variables when underscore-prefixed. 79 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 80 | 81 | # Assume Python 3.9. 82 | target-version = "py39" 83 | 84 | [tool.ruff.mccabe] 85 | # Unlike Flake8, default to a complexity level of 10. 
86 | max-complexity = 10 87 | -------------------------------------------------------------------------------- /models/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True), 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) 34 | 35 | def forward(self, x): 36 | return self.maxpool_conv(x) 37 | 38 | 39 | class Up(nn.Module): 40 | """Upscaling then double conv""" 41 | 42 | def __init__(self, in_channels, out_channels, bilinear=True): 43 | super().__init__() 44 | 45 | # if bilinear, use the normal convolutions to reduce the number of channels 46 | if bilinear: 47 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 48 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 49 | else: 50 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 51 | self.conv = DoubleConv(in_channels, out_channels) 52 | 53 | def forward(self, x1, x2): 54 | x1 = self.up(x1) 55 | # input is CHW 56 | diffY = x2.size()[2] - x1.size()[2] 57 | diffX = x2.size()[3] - x1.size()[3] 58 | 59 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) 60 | # if you have padding issues, see 61 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 62 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 63 | x = torch.cat([x2, x1], dim=1) 64 | return self.conv(x) 65 | 66 | 67 | class OutConv(nn.Module): 68 | def __init__(self, in_channels, out_channels): 69 | super(OutConv, self).__init__() 70 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 71 | 72 | def forward(self, x): 73 | return self.conv(x) 74 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.layers import CBAM 5 | 6 | 7 | # Base code obtained from https://github.com/togheppi/pix2pix/blob/master/model.py 8 | class LargePix2PixDiscriminatorCBAM(nn.Module): 9 | # initializers 10 | def __init__(self, hparams, in_channels=24, d=64): 11 | super(LargePix2PixDiscriminatorCBAM, self).__init__() 12 | reduction_ratio = 16 13 | self.conv1 = nn.Conv2d(in_channels, d, 4, 1, 1) 14 | self.cbam1 = CBAM(d, reduction_ratio=reduction_ratio) 15 | self.conv_down_1 = nn.Conv2d(d, d, 4, 2, 1) 16 | self.cbam_down_1 = CBAM(d, reduction_ratio=reduction_ratio) 17 | 18 | 
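# The remaining stages repeat the pattern started above: a stride-1 conv followed by a
# downsampling conv (stride 2, except for the deepest stage), each followed by a CBAM
# attention block — together forming the Attention-PatchGAN discriminator of the paper.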
self.conv2 = nn.Conv2d(d, d, 4, 1, 1) 19 | self.conv2_bn = nn.BatchNorm2d(d) 20 | self.cbam2 = CBAM(d, reduction_ratio=reduction_ratio) 21 | self.conv_down_2 = nn.Conv2d(d, d * 2, 4, 2, 1) 22 | self.conv_down_2_bn = nn.BatchNorm2d(d * 2) 23 | self.cbam_down_2 = CBAM(d * 2, reduction_ratio=reduction_ratio) 24 | 25 | self.conv3 = nn.Conv2d(d * 2, d * 2, 4, 1, 1) 26 | self.conv3_bn = nn.BatchNorm2d(d * 2) 27 | self.cbam3 = CBAM(d * 2, reduction_ratio=reduction_ratio) 28 | self.conv_down_3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1) 29 | self.conv_down_3_bn = nn.BatchNorm2d(d * 4) 30 | self.cbam_down_3 = CBAM(d * 4, reduction_ratio=reduction_ratio) 31 | 32 | self.conv4 = nn.Conv2d(d * 4, d * 4, 4, 1, 1) 33 | self.conv4_bn = nn.BatchNorm2d(d * 4) 34 | self.cbam4 = CBAM(d * 4, reduction_ratio=reduction_ratio) 35 | self.conv_down_4 = nn.Conv2d(d * 4, d * 8, 4, 1, 1) 36 | self.conv_down_4_bn = nn.BatchNorm2d(d * 8) 37 | self.cbam_down_4 = CBAM(d * 8, reduction_ratio=reduction_ratio) 38 | 39 | self.conv5 = nn.Conv2d(d * 8, 1, 4, 1, 1) 40 | 41 | # weight_init 42 | def weight_init(self, mean, std): 43 | for m in self._modules: 44 | normal_init(self._modules[m], mean, std) 45 | 46 | # forward method 47 | def forward(self, input, label): 48 | x = torch.cat([input, label], 1) 49 | 50 | x = F.leaky_relu(self.conv1(x), 0.2) 51 | x = self.cbam1(x) 52 | x = F.leaky_relu(self.conv_down_1(x), 0.2) 53 | x = self.cbam_down_1(x) 54 | 55 | x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2) 56 | x = self.cbam2(x) 57 | x = F.leaky_relu(self.conv_down_2_bn(self.conv_down_2(x)), 0.2) 58 | x = self.cbam_down_2(x) 59 | 60 | x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2) 61 | x = self.cbam3(x) 62 | x = F.leaky_relu(self.conv_down_3_bn(self.conv_down_3(x)), 0.2) 63 | x = self.cbam_down_3(x) 64 | 65 | x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2) 66 | x = self.cbam4(x) 67 | x = F.leaky_relu(self.conv_down_4_bn(self.conv_down_4(x)), 0.2) 68 | x = self.cbam_down_4(x) 69 | 70 | x = torch.sigmoid(self.conv5(x)) # torch.sigmoid: F.sigmoid is deprecated 71 | return x 72 | 73 | 74 | def normal_init(m, mean, std): 75 | if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d): 76 | m.weight.data.normal_(mean, std) 77 | m.bias.data.zero_() 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /utils/dataset_precip.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch.multiprocessing 3 | import h5py 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | class precipitation_maps_masked_h5(Dataset): 8 | def __init__(self, in_file, num_input_images, num_output_images, mode="train", transform=None, use_timestamps=False): 9 | super(precipitation_maps_masked_h5, self).__init__() 10 | # The default sharing strategy is not supported on macOS 11 | torch.multiprocessing.set_sharing_strategy('file_system') 12 | 13 | self.file_name = in_file 14 | self.samples, _, _, _ = h5py.File(self.file_name, "r")[mode]["images"].shape 15 | self.samples = int(self.samples * 1) # currently a no-op (multiplier of 1) 16 | self.num_input = num_input_images 17 | self.num_output = num_output_images 18 | 19 | self.mode = mode 20 | self.use_timestamps = use_timestamps 21 | self.transform = transform 22 | self.dataset = None 23 | self.timestamps = None 24 | 25 | def __getitem__(self, index): 26 | # load the file here (load as singleton) 27 | if self.dataset is None: 28 | self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)[self.mode]["images"] 29 | if self.timestamps is 
None and self.use_timestamps: 30 | self.timestamps = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)[self.mode]["timestamps"] 31 | imgs = np.array(self.dataset[index], dtype="float32") 32 | 33 | # add transforms 34 | if self.transform is not None: 35 | imgs = self.transform(imgs) 36 | _, h, w = imgs.shape 37 | # crop 38 | start_row = (h - 64) // 2 39 | start_col = (w - 64) // 2 40 | imgs = imgs[:,start_row:start_row+64, start_col:start_col+64] 41 | 42 | input_imgs = imgs[: self.num_input] 43 | target_imgs = imgs[self.num_input:] 44 | 45 | factor = 52.52 46 | input_imgs_sum = (np.sum(input_imgs, axis=0) * factor)[None,:] 47 | target_imgs_sum = (np.sum(target_imgs, axis=0) * factor)[None,:] 48 | 49 | thresholds = range(1,26,1) # cumulative-rain thresholds of 1..25 mm for the binary masks 50 | for i, t in enumerate(thresholds): 51 | input_mask = (input_imgs_sum >= t).astype("float32") 52 | input_masks = np.concatenate((input_masks, input_mask), axis=0) if i>0 else input_mask 53 | 54 | target_mask = (target_imgs_sum >= t).astype("float32") 55 | target_masks = np.concatenate((target_masks, target_mask), axis=0) if i>0 else target_mask 56 | 57 | if not self.use_timestamps: 58 | return input_imgs, input_masks, target_imgs, target_masks 59 | else: 60 | month = str(self.timestamps[index][self.num_input][0])[5:8] 61 | season = self.get_season(month) 62 | return input_imgs, input_masks, target_imgs, target_masks, season 63 | 64 | def __len__(self): 65 | return self.samples 66 | 67 | def get_season(self, month): 68 | months = { # month -> (month_index, season_index) 69 | 'JAN': (0, 0), 70 | 'FEB': (1, 0), 71 | 'MAR': (2, 1), 72 | 'APR': (3, 1), 73 | 'MAY': (4, 1), 74 | 'JUN': (5, 2), 75 | 'JUL': (6, 2), 76 | 'AUG': (7, 2), 77 | 'SEP': (8, 3), 78 | 'OCT': (9, 3), 79 | 'NOV': (10, 3), 80 | 'DEC': (11, 0) 81 | } 82 | if month in months: 83 | return months[month] 84 | else: 85 | return None # Return None for invalid input 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /train_moving_mnist.py: -------------------------------------------------------------------------------- 1 | from root import ROOT_DIR 2 | 3 | import lightning.pytorch as pl 4 | from lightning.pytorch.callbacks import ( 5 | ModelCheckpoint, 6 | LearningRateMonitor, 7 | EarlyStopping, 8 | ) 9 | from lightning.pytorch import loggers 10 | from lightning.pytorch.tuner import Tuner 11 | import argparse 12 | import torch 13 | from models import unet_precip_regression_lightning as unet_regr 14 | import models.regression_GA_SmaAt_GNet_mnist as gan 15 | 16 | def train_regression(hparams): 17 | if hparams.model == "SmaAt-UNet": 18 | net = unet_regr.SmaAt_UNet(hparams=hparams) 19 | elif hparams.model == "SmaAt-GNet": 20 | net = unet_regr.SmaAt_GNet(hparams=hparams) 21 | elif hparams.model == "SmaAt-GNet-Aleatoric": 22 | net = unet_regr.SmaAt_GNet_aleatoric(hparams=hparams) 23 | elif hparams.model == "GA-SmaAt-GNet": 24 | net = gan.GAN(hparams=hparams) 25 | else: 26 | raise Exception(f"{hparams.model} is not a valid model name") 27 | 28 | default_save_path = hparams.default_save_path 29 | 30 | checkpoint_callback = ModelCheckpoint( 31 | dirpath=default_save_path / net.__class__.__name__, 32 | filename=net.__class__.__name__ + "_rain_threshhold_50_{epoch}-{val_loss:.6f}", 33 | save_top_k=3, 34 | save_last=True, 35 | verbose=False, 36 | monitor="val_loss", 37 | mode="min", 38 | ) 39 | lr_monitor = LearningRateMonitor() 40 | tb_logger = loggers.TensorBoardLogger(save_dir=default_save_path, name=net.__class__.__name__) 41 | 42 | earlystopping_callback = EarlyStopping( 43 | 
monitor="val_loss", 44 | mode="min", 45 | patience=hparams.es_patience, 46 | ) 47 | trainer = pl.Trainer( 48 | accelerator="mps", 49 | devices=1, 50 | fast_dev_run=hparams.fast_dev_run, 51 | max_epochs=hparams.epochs, 52 | default_root_dir=default_save_path, 53 | logger=tb_logger, 54 | callbacks=[checkpoint_callback, earlystopping_callback, lr_monitor], 55 | val_check_interval=hparams.val_check_interval, 56 | ) 57 | trainer.fit(model=net, ckpt_path=hparams.resume_from_checkpoint) 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | 63 | parser = unet_regr.Precip_regression_base.add_model_specific_args(parser) 64 | 65 | parser.add_argument( 66 | "--dataset_folder", 67 | default=ROOT_DIR / "data" / "precipitation" / "train_test_1998-2022_input-length_12_img-ahead_12_rain-threshhold_50_normalized.h5", 68 | type=str, 69 | ) 70 | parser.add_argument("--batch_size", type=int, default=32) 71 | parser.add_argument("--learning_rate", type=float, default=0.001) 72 | parser.add_argument("--epochs", type=int, default=200) 73 | parser.add_argument("--fast_dev_run", action="store_true") # type=bool would treat any non-empty string as True 74 | parser.add_argument("--resume_from_checkpoint", type=str, default=None) 75 | parser.add_argument("--val_check_interval", type=float, default=None) 76 | 77 | args = parser.parse_args() 78 | 79 | # args.fast_dev_run = True 80 | args.n_channels = 10 81 | args.n_masks = 10 82 | args.n_classes = 10 83 | args.n_output_images = 10 84 | 85 | args.gpus = 1 86 | args.model = "GA-SmaAt-GNet" 87 | args.lr_patience = 4 88 | args.es_patience = 15 89 | # GAN options 90 | args.l = 0.01 91 | args.disc_every_n_steps = 2 92 | 93 | # args.val_check_interval = 0.25 94 | args.kernels_per_layer = 2 95 | args.default_save_path = ROOT_DIR / "lightning" / "mnist" / f"{args.model}_batch-{args.batch_size}_v1.0" 96 | 97 | args.dropout = 0.5 98 | # The default sharing strategy is not supported on mac 99 | torch.multiprocessing.set_sharing_strategy('file_system') 100 | 101 | # args.resume_from_checkpoint = f"lightning/precip_regression/[filename].ckpt" 102 | 103 | train_regression(args) 104 | 105 | -------------------------------------------------------------------------------- /train_precip.py: -------------------------------------------------------------------------------- 1 | from root import ROOT_DIR 2 | 3 | import lightning.pytorch as pl 4 | from lightning.pytorch.callbacks import ( 5 | ModelCheckpoint, 6 | LearningRateMonitor, 7 | EarlyStopping, 8 | ) 9 | from lightning.pytorch import loggers 10 | from lightning.pytorch.tuner import Tuner 11 | import argparse 12 | import torch 13 | from models import unet_precip_regression_lightning as unet_regr 14 | import models.regression_GA_SmaAt_GNet as gan 15 | 16 | def train_regression(hparams): 17 | if hparams.model == "SmaAt-UNet": 18 | net = unet_regr.SmaAt_UNet(hparams=hparams) 19 | elif hparams.model == "SmaAt-GNet": 20 | net = unet_regr.SmaAt_GNet(hparams=hparams) 21 | elif hparams.model == "SmaAt-GNet-Aleatoric": 22 | net = unet_regr.SmaAt_GNet_aleatoric(hparams=hparams) 23 | elif hparams.model == "GA-SmaAt-GNet": 24 | net = gan.GAN(hparams=hparams) 25 | else: 26 | raise Exception(f"{hparams.model} is not a valid model name") 27 | 28 | default_save_path = hparams.default_save_path 29 | 30 | checkpoint_callback = ModelCheckpoint( 31 | dirpath=default_save_path / net.__class__.__name__, 32 | filename=net.__class__.__name__ + "_rain_threshhold_50_{epoch}-{val_loss:.6f}", 33 | save_top_k=3, 34 | save_last=True, 35 | verbose=False, 36 | monitor="val_loss", 37 | 
mode="min", 38 | ) 39 | lr_monitor = LearningRateMonitor() 40 | tb_logger = loggers.TensorBoardLogger(save_dir=default_save_path, name=net.__class__.__name__) 41 | 42 | earlystopping_callback = EarlyStopping( 43 | monitor="val_loss", 44 | mode="min", 45 | patience=hparams.es_patience, 46 | ) 47 | trainer = pl.Trainer( 48 | accelerator="mps", 49 | devices=1, 50 | fast_dev_run=hparams.fast_dev_run, 51 | max_epochs=hparams.epochs, 52 | default_root_dir=default_save_path, 53 | logger=tb_logger, 54 | callbacks=[checkpoint_callback, earlystopping_callback, lr_monitor], 55 | val_check_interval=hparams.val_check_interval, 56 | ) 57 | trainer.fit(model=net, ckpt_path=hparams.resume_from_checkpoint) 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | 63 | parser = unet_regr.Precip_regression_base.add_model_specific_args(parser) 64 | 65 | parser.add_argument( 66 | "--dataset_folder", 67 | default=ROOT_DIR / "data" / "precipitation" / "train_test_1998-2022_input-length_12_img-ahead_12_rain-threshhold_50_normalized.h5", 68 | type=str, 69 | ) 70 | parser.add_argument("--batch_size", type=int, default=32) 71 | parser.add_argument("--learning_rate", type=float, default=0.001) 72 | parser.add_argument("--epochs", type=int, default=200) 73 | parser.add_argument("--fast_dev_run", action="store_true") # type=bool would treat any non-empty string as True 74 | parser.add_argument("--resume_from_checkpoint", type=str, default=None) 75 | parser.add_argument("--val_check_interval", type=float, default=None) 76 | 77 | args = parser.parse_args() 78 | 79 | # args.fast_dev_run = True 80 | args.n_channels = 12 81 | args.n_masks = 25 82 | args.n_classes = 12 83 | args.n_output_images = 12 84 | 85 | args.gpus = 1 86 | args.model = "GA-SmaAt-GNet" # SmaAt-UNet, SmaAt-GNet, SmaAt-GNet-Aleatoric or GA-SmaAt-GNet 87 | args.lr_patience = 4 88 | args.es_patience = 15 89 | # GAN options 90 | args.l = 1000000 91 | args.disc_every_n_steps = 2 92 | 93 | # args.val_check_interval = 0.25 94 | args.kernels_per_layer = 2 95 | args.dataset_folder = ( 96 | ROOT_DIR / "data" / "precipitation" / "train_test_1998-2022_input-length_12_img-ahead_12_rain-threshhold_50_normalized.h5" 97 | ) 98 | args.dataset = "train" 99 | args.default_save_path = ROOT_DIR / "lightning" / "1998-2022" / f"{args.model}_batch-{args.batch_size}_v1.0" 100 | 101 | args.dropout = 0.5 102 | # The default sharing strategy is not supported on mac 103 | torch.multiprocessing.set_sharing_strategy('file_system') 104 | 105 | # args.resume_from_checkpoint = f"lightning/precip_regression/[filename].ckpt" 106 | 107 | train_regression(args) 108 | 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GA-SmaAt-GNet 2 | 3 | PyTorch code for the paper "GA-SmaAt-GNet: Generative Adversarial Small Attention GNet for Extreme Precipitation Nowcasting" [arXiv-link](https://arxiv.org/abs/2401.09881), [Elsevier-link](https://www.sciencedirect.com/science/article/pii/S0950705124012462). 4 | 5 | ![GAN](imgs/GA-SmaAt-GNet.png) 6 | ![SmaAt-GNet](imgs/SmaAt-GNet.png) 7 | ![Discriminator](imgs/Attention-PatchGAN.png) 8 | 9 | The proposed GA-SmaAt-GNet's generator architecture (SmaAt-GNet) can be found in the models folder in [unet_precip_regression_lightning.py](models/unet_precip_regression_lightning.py). 
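For example, a trained SmaAt-GNet generator can be restored from a Lightning checkpoint as follows (a minimal sketch mirroring the pattern used in [test_precip.py](test_precip.py); the checkpoint path is a placeholder):

```python
from models import unet_precip_regression_lightning as unet_regr

# hypothetical path — point this at one of your own trained checkpoints
model = unet_regr.SmaAt_GNet.load_from_checkpoint("checkpoints/SmaAt_GNet_example.ckpt")
model.eval()
```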
The discriminator architecture can be found in [discriminator.py](models/discriminator.py). 10 | 11 | The code in this repository is based on the code of the SmaAt-UNet model, which can be found at https://github.com/HansBambel/SmaAt-UNet 12 | 13 | ## Installing dependencies 14 | 15 | This project uses [poetry](https://python-poetry.org/) for dependency management. Therefore, installing the required dependencies is as easy as this: 16 | 17 | ```shell 18 | conda create --name smaat-unet python=3.9 19 | conda activate smaat-unet 20 | poetry install 21 | # Sadly poetry < 1.5 does not allow installing the GPU variant, so you need to do that separately afterwards: 22 | pip3 install torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/cu118 23 | ``` 24 | 25 | A [requirements.txt](requirements.txt), generated from the poetry export, is also included. 26 | 27 | The following requirements are needed: 28 | 29 | ``` 30 | tqdm 31 | torch 32 | lightning 33 | tensorboard 34 | torchsummary 35 | h5py 36 | numpy 37 | grad-cam 38 | ``` 39 | 40 | --- 41 | 42 | For the paper we used [PyTorch Lightning](https://github.com/Lightning-AI/lightning), which simplifies the training process and allows easy addition of loggers and checkpoint creation. 43 | 44 | If you have any questions about the code, you can write an email to ereulen.uu@gmail.com and s.mehrkanoon@uu.nl. 45 | 46 | ### Training 47 | 48 | An example training script, [train_moving_mnist.py](train_moving_mnist.py), is provided that trains the GA-SmaAt-GNet model on the [Moving MNIST](https://www.cs.toronto.edu/~nitish/unsupervised_video/) dataset. Running this script will automatically download the dataset. 49 | 50 | For training on the precipitation data, the [train_precip.py](train_precip.py) file can be used. Make sure that the model you want to train is correctly specified in the model parameter of the script. 51 | The training will save a checkpoint file for the top 3 epochs in the directory specified with the `default_save_path` variable. 52 | 53 | The [test_precip.py](test_precip.py) script can be used to calculate the performance of the trained model on the test set; specify the location of the checkpoint file with the lowest validation loss and the model name in the script. The results will be saved as .csv in the specified results folder. 54 | Pretrained checkpoint files of the models discussed in the paper are available upon request. Please write an email to: s.mehrkanoon@uu.nl. 55 | 56 | ### Extreme Precipitation dataset 57 | 58 | The data consists of 25 years of precipitation maps in 5-minute intervals from 1998-2022 on a 2.4 km grid. 59 | 60 | The dataset is based on radar precipitation maps from [The Royal Netherlands Meteorological Institute (KNMI)](https://www.knmi.nl/over-het-knmi/about). The original images were cropped as can be seen in the example below: 61 | ![Precip cutout](imgs/data_prep.png) 62 | 63 | If you are interested in the extreme precipitation dataset we used, please write an email to: s.mehrkanoon@uu.nl. 64 | 65 | We normalized the data using Min-Max normalization. In order to revert this, you need to multiply the images by 52.52 (the `factor` used throughout the code); this results in the images showing the amount of rain in mm/5min. 66 | 67 | ### Grad-CAM 68 | 69 | We used Grad-CAM to generate activation heatmaps for different parts of GA-SmaAt-GNet to gain more insight into our model's predictions. These activation heatmaps can be generated by running the [grad-cam.py](grad-cam.py) script. 
The code in this script was obtained from [https://github.com/mathieurenault1/SAR-UNet/blob/master/cam_segmentation_precip.py](https://github.com/mathieurenault1/SAR-UNet/blob/master/cam_segmentation_precip.py) and modified for our model and dataset. 70 | 71 | ### Citation 72 | 73 | ``` 74 | @article{reulen2024ga, 75 | title={Ga-smaat-gnet: Generative adversarial small attention gnet for extreme precipitation nowcasting}, 76 | author={Reulen, Eloy and Shi, Jie and Mehrkanoon, Siamak}, 77 | journal={Knowledge-Based Systems}, 78 | pages={112612}, 79 | year={2024}, 80 | publisher={Elsevier} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /models/unet_parts_depthwise_separable.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | # Base model taken from: https://github.com/milesial/Pytorch-UNet 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from models.layers import DepthwiseSeparableConv 7 | 8 | 9 | class DoubleConvDS(nn.Module): 10 | """(convolution => [BN] => ReLU) * 2""" 11 | 12 | def __init__(self, in_channels, out_channels, mid_channels=None, kernels_per_layer=1): 13 | super().__init__() 14 | if not mid_channels: 15 | mid_channels = out_channels 16 | self.double_conv = nn.Sequential( 17 | DepthwiseSeparableConv( 18 | in_channels, 19 | mid_channels, 20 | kernel_size=3, 21 | kernels_per_layer=kernels_per_layer, 22 | padding=1, 23 | ), 24 | nn.BatchNorm2d(mid_channels), 25 | nn.ReLU(inplace=True), 26 | DepthwiseSeparableConv( 27 | mid_channels, 28 | out_channels, 29 | kernel_size=3, 30 | kernels_per_layer=kernels_per_layer, 31 | padding=1, 32 | ), 33 | nn.BatchNorm2d(out_channels), 34 | nn.ReLU(inplace=True), 35 | ) 36 | 37 | def forward(self, x): 38 | return self.double_conv(x) 39 | 40 | 41 | class DownDS(nn.Module): 42 | """Downscaling with maxpool then double conv""" 43 | 44 | def __init__(self, in_channels, out_channels, kernels_per_layer=1): 45 | super().__init__() 46 | self.maxpool_conv = nn.Sequential( 47 | nn.MaxPool2d(2), 48 | DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer), 49 | ) 50 | 51 | def forward(self, x): 52 | return self.maxpool_conv(x) 53 | 54 | 55 | class UpDS(nn.Module): 56 | """Upscaling then double conv""" 57 | 58 | def __init__(self, in_channels, out_channels, bilinear=True, kernels_per_layer=1): 59 | super().__init__() 60 | 61 | # if bilinear, use the normal convolutions to reduce the number of channels 62 | if bilinear: 63 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 64 | self.conv = DoubleConvDS( 65 | in_channels, 66 | out_channels, 67 | in_channels // 2, 68 | kernels_per_layer=kernels_per_layer, 69 | ) 70 | else: 71 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 72 | self.conv = DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer) 73 | 74 | def forward(self, x1, x2): 75 | x1 = self.up(x1) 76 | # input is CHW 77 | diffY = x2.size()[2] - x1.size()[2] 78 | diffX = x2.size()[3] - x1.size()[3] 79 | 80 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) 81 | # if you have padding issues, see 82 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 83 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 84 | x = torch.cat([x2, x1], dim=1) 85 | return 
self.conv(x) 86 | 87 | 88 | class OutConv(nn.Module): 89 | def __init__(self, in_channels, out_channels): 90 | super(OutConv, self).__init__() 91 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 92 | 93 | def forward(self, x): 94 | return self.conv(x) 95 | 96 | class ResDoubleConvDS(nn.Module): 97 | def __init__(self, in_channels, out_channels,kernels_per_layer=1): 98 | super(ResDoubleConvDS, self).__init__() 99 | self.doubleconv = nn.Sequential( 100 | DoubleConvDS(in_channels, out_channels,kernels_per_layer=kernels_per_layer) 101 | ) 102 | self.Conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 103 | 104 | def forward(self, x): 105 | x1 = self.Conv_1x1(x) 106 | x2 = self.doubleconv(x) 107 | return x1 + x2 108 | 109 | class ConvDS(nn.Module): 110 | def __init__(self,in_channels, out_channels, kernels_per_layer=1): 111 | super().__init__() 112 | self.conv = nn.Sequential( 113 | DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, kernels_per_layer=kernels_per_layer, padding=1), 114 | nn.BatchNorm2d(out_channels), 115 | nn.ReLU(inplace=True) 116 | ) 117 | 118 | def forward(self, x): 119 | return self.conv(x) 120 | 121 | class UpDS_Simple(nn.Module): 122 | def __init__(self, in_channels, out_channels, bilinear=True, kernels_per_layer=1): 123 | super().__init__() 124 | 125 | # if bilinear, use the normal convolutions to reduce the number of channels 126 | if bilinear: 127 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 128 | self.conv = ConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer) 129 | else: 130 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 131 | self.conv = ConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer) 132 | 133 | def forward(self, x): 134 | x = self.up(x) 135 | return self.conv(x) -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Taken from https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14 7 | class DepthToSpace(nn.Module): 8 | def __init__(self, block_size): 9 | super().__init__() 10 | self.bs = block_size 11 | 12 | def forward(self, x): 13 | N, C, H, W = x.size() 14 | x = x.view(N, self.bs, self.bs, C // (self.bs**2), H, W) # (N, bs, bs, C//bs^2, H, W) 15 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 16 | x = x.view(N, C // (self.bs**2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 17 | return x 18 | 19 | 20 | class SpaceToDepth(nn.Module): 21 | # Expects the following shape: Batch, Channel, Height, Width 22 | def __init__(self, block_size): 23 | super().__init__() 24 | self.bs = block_size 25 | 26 | def forward(self, x): 27 | N, C, H, W = x.size() 28 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 29 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 30 | x = x.view(N, C * (self.bs**2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 31 | return x 32 | 33 | 34 | class DepthwiseSeparableConv(nn.Module): 35 | def __init__(self, in_channels, output_channels, kernel_size, padding=0, kernels_per_layer=1): 36 | super(DepthwiseSeparableConv, self).__init__() 37 | # In Tensorflow DepthwiseConv2D has depth_multiplier instead of 
kernels_per_layer 38 | self.depthwise = nn.Conv2d( 39 | in_channels, 40 | in_channels * kernels_per_layer, 41 | kernel_size=kernel_size, 42 | padding=padding, 43 | groups=in_channels, 44 | ) 45 | self.pointwise = nn.Conv2d(in_channels * kernels_per_layer, output_channels, kernel_size=1) 46 | 47 | def forward(self, x): 48 | x = self.depthwise(x) 49 | x = self.pointwise(x) 50 | return x 51 | 52 | 53 | class DoubleDense(nn.Module): 54 | def __init__(self, in_channels, hidden_neurons, output_channels): 55 | super(DoubleDense, self).__init__() 56 | self.dense1 = nn.Linear(in_channels, out_features=hidden_neurons) 57 | self.dense2 = nn.Linear(in_features=hidden_neurons, out_features=hidden_neurons // 2) 58 | self.dense3 = nn.Linear(in_features=hidden_neurons // 2, out_features=output_channels) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.dense1(x.view(x.size(0), -1))) 62 | out = F.relu(self.dense2(out)) 63 | out = self.dense3(out) 64 | return out 65 | 66 | 67 | class DoubleDSConv(nn.Module): 68 | """(convolution => [BN] => ReLU) * 2""" 69 | 70 | def __init__(self, in_channels, out_channels): 71 | super().__init__() 72 | self.double_ds_conv = nn.Sequential( 73 | DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, padding=1), 74 | nn.BatchNorm2d(out_channels), 75 | nn.ReLU(inplace=True), 76 | DepthwiseSeparableConv(out_channels, out_channels, kernel_size=3, padding=1), 77 | nn.BatchNorm2d(out_channels), 78 | nn.ReLU(inplace=True), 79 | ) 80 | 81 | def forward(self, x): 82 | return self.double_ds_conv(x) 83 | 84 | 85 | class Flatten(nn.Module): 86 | def forward(self, x): 87 | return x.view(x.size(0), -1) 88 | 89 | 90 | class ChannelAttention(nn.Module): 91 | def __init__(self, input_channels, reduction_ratio=16): 92 | super(ChannelAttention, self).__init__() 93 | self.input_channels = input_channels 94 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 95 | self.max_pool = nn.AdaptiveMaxPool2d(1) 96 | # https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py 97 | # uses Convolutions instead of Linear 98 | self.MLP = nn.Sequential( 99 | Flatten(), 100 | nn.Linear(input_channels, input_channels // reduction_ratio), 101 | nn.ReLU(), 102 | nn.Linear(input_channels // reduction_ratio, input_channels), 103 | ) 104 | 105 | def forward(self, x): 106 | # Take the input and apply average and max pooling 107 | avg_values = self.avg_pool(x) 108 | max_values = self.max_pool(x) 109 | out = self.MLP(avg_values) + self.MLP(max_values) 110 | scale = x * torch.sigmoid(out).unsqueeze(2).unsqueeze(3).expand_as(x) 111 | return scale 112 | 113 | 114 | class SpatialAttention(nn.Module): 115 | def __init__(self, kernel_size=7): 116 | super(SpatialAttention, self).__init__() 117 | assert kernel_size in (3, 7), "kernel size must be 3 or 7" 118 | padding = 3 if kernel_size == 7 else 1 119 | self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False) 120 | self.bn = nn.BatchNorm2d(1) 121 | 122 | def forward(self, x): 123 | avg_out = torch.mean(x, dim=1, keepdim=True) 124 | max_out, _ = torch.max(x, dim=1, keepdim=True) 125 | out = torch.cat([avg_out, max_out], dim=1) 126 | out = self.conv(out) 127 | out = self.bn(out) 128 | scale = x * torch.sigmoid(out) 129 | return scale 130 | 131 | 132 | class CBAM(nn.Module): 133 | def __init__(self, input_channels, reduction_ratio=16, kernel_size=7): 134 | super(CBAM, self).__init__() 135 | self.channel_att = ChannelAttention(input_channels, reduction_ratio=reduction_ratio) 136 | self.spatial_att = 
SpatialAttention(kernel_size=kernel_size) 137 | 138 | def forward(self, x): 139 | out = self.channel_att(x) 140 | out = self.spatial_att(out) 141 | return out 142 | -------------------------------------------------------------------------------- /models/regression_SmaAt_UNet.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn, optim, multiprocessing 3 | import torch 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | from torch.utils.data import DataLoader 6 | from utils import dataset_precip 7 | import argparse 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import math 11 | 12 | 13 | class UNet_base(pl.LightningModule): 14 | @staticmethod 15 | def add_model_specific_args(parent_parser): 16 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 17 | parser.add_argument("--n_channels", type=int, default=12) 18 | parser.add_argument("--n_classes", type=int, default=12) 19 | parser.add_argument("--kernels_per_layer", type=int, default=1) 20 | parser.add_argument("--bilinear", type=bool, default=True) 21 | parser.add_argument("--reduction_ratio", type=int, default=16) 22 | parser.add_argument("--lr_patience", type=int, default=5) 23 | return parser 24 | 25 | def __init__(self, hparams): 26 | super().__init__() 27 | self.save_hyperparameters(hparams) 28 | 29 | def forward(self, x): 30 | pass 31 | 32 | def configure_optimizers(self): 33 | opt = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) 34 | scheduler = { 35 | "scheduler": optim.lr_scheduler.ReduceLROnPlateau( 36 | opt, mode="min", factor=0.1, patience=self.hparams.lr_patience 37 | ), 38 | "monitor": "val_loss", # Default: val_loss 39 | } 40 | return [opt], [scheduler] 41 | 42 | def loss_func(self, y_pred, y_true): 43 | return nn.functional.mse_loss( 44 | y_pred, y_true, reduction="mean" 45 | ) 46 | 47 | def training_step(self, batch, batchid): 48 | x, _, y, _ = batch 49 | y_pred = self(x) 50 | if batchid % 100 == 0: 51 | display_list = [x[0], y[0], y_pred[0]] 52 | title = ['Input Image', 'Ground Truth', 'Predicted Image'] 53 | 54 | f, axarr = plt.subplots(3,12) 55 | f.set_figwidth(20) 56 | for i in range(3): 57 | for j in range(12): 58 | axarr[i, j].imshow(display_list[i][j,:,:].detach().cpu().numpy()) 59 | plt.axis("off") 60 | plt.savefig(self.hparams.default_save_path / "imgs.png") 61 | loss = self.loss_func(y_pred.squeeze(), y) 62 | # logs metrics for each training_step, 63 | # and the average across the epoch, to the progress bar and logger 64 | # *100 for readability in progress bar 65 | self.log("train_loss", loss*100, on_step=True, on_epoch=True, prog_bar=True, logger=True) 66 | return loss 67 | 68 | def validation_step(self, batch, batch_idx): 69 | x, _, y, _ = batch 70 | y_pred = self(x) 71 | mse = self.loss_func(y_pred.squeeze(), y) 72 | self.log("val_loss", mse * 100, prog_bar=True) 73 | 74 | def on_validation_epoch_end(self): 75 | plt.close("all") 76 | 77 | class Precip_regression_base(UNet_base): 78 | @staticmethod 79 | def add_model_specific_args(parent_parser): 80 | parent_parser = UNet_base.add_model_specific_args(parent_parser) 81 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 82 | parser.add_argument("--num_input_images", type=int, default=12) 83 | parser.add_argument("--num_output_images", type=int, default=12) 84 | parser.add_argument("--valid_size", type=float, default=0.1) 85 | parser.n_channels = parser.parse_args().num_input_images 
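# Note: the attributes below are set on the parser object itself rather than via argparse
# defaults; the train scripts subsequently override n_channels/n_classes on the parsed args.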
86 | parser.n_classes = 12 87 | return parser 88 | 89 | def __init__(self, hparams): 90 | super(Precip_regression_base, self).__init__(hparams=hparams) 91 | self.train_dataset = None 92 | self.valid_dataset = None 93 | self.train_sampler = None 94 | self.valid_sampler = None 95 | 96 | def prepare_data(self): 97 | # train_transform = transforms.Compose([ 98 | # transforms.RandomHorizontalFlip()] 99 | # ) 100 | train_transform = None 101 | valid_transform = None 102 | precip_dataset = dataset_precip.precipitation_maps_masked_h5 103 | self.train_dataset = precip_dataset( 104 | in_file=self.hparams.dataset_folder, 105 | num_input_images=self.hparams.num_input_images, 106 | num_output_images=self.hparams.num_output_images, 107 | mode=self.hparams.dataset, 108 | transform=train_transform, 109 | ) 110 | self.valid_dataset = precip_dataset( 111 | in_file=self.hparams.dataset_folder, 112 | num_input_images=self.hparams.num_input_images, 113 | num_output_images=self.hparams.num_output_images, 114 | mode=self.hparams.dataset, 115 | transform=valid_transform, 116 | ) 117 | 118 | num_train = len(self.train_dataset) 119 | indices = list(range(num_train)) 120 | split = int(np.floor(self.hparams.valid_size * num_train)) 121 | 122 | np.random.seed(123) 123 | np.random.shuffle(indices) 124 | 125 | train_idx, valid_idx = indices[split:], indices[:split] 126 | self.train_sampler = SubsetRandomSampler(train_idx) 127 | self.valid_sampler = SubsetRandomSampler(valid_idx) 128 | 129 | def train_dataloader(self): 130 | 131 | train_loader = DataLoader( 132 | self.train_dataset, 133 | batch_size=self.hparams.batch_size, 134 | sampler=self.train_sampler, 135 | num_workers=0, 136 | pin_memory=True, 137 | ) 138 | return train_loader 139 | 140 | def val_dataloader(self): 141 | 142 | valid_loader = DataLoader( 143 | self.valid_dataset, 144 | batch_size=self.hparams.batch_size, 145 | sampler=self.valid_sampler, 146 | num_workers=0, 147 | pin_memory=True, 148 | ) 149 | return valid_loader 150 | -------------------------------------------------------------------------------- /models/regression_SmaAt_GNet.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn, optim, multiprocessing 3 | import torch 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | from torch.utils.data import DataLoader 6 | from utils import dataset_precip 7 | import argparse 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import math 11 | 12 | 13 | class UNet_base(pl.LightningModule): 14 | @staticmethod 15 | def add_model_specific_args(parent_parser): 16 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 17 | parser.add_argument("--n_channels", type=int, default=12) 18 | parser.add_argument("--n_classes", type=int, default=1) 19 | parser.add_argument("--kernels_per_layer", type=int, default=1) 20 | parser.add_argument("--bilinear", type=bool, default=True) 21 | parser.add_argument("--reduction_ratio", type=int, default=16) 22 | parser.add_argument("--lr_patience", type=int, default=5) 23 | return parser 24 | 25 | def __init__(self, hparams): 26 | super().__init__() 27 | self.save_hyperparameters(hparams) 28 | 29 | def forward(self, x): 30 | pass 31 | 32 | def configure_optimizers(self): 33 | opt = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) 34 | scheduler = { 35 | "scheduler": optim.lr_scheduler.ReduceLROnPlateau( 36 | opt, mode="min", factor=0.1, patience=self.hparams.lr_patience 37 | ), 38 | 
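# "monitor" tells Lightning which logged metric drives the ReduceLROnPlateau scheduler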
"monitor": "val_loss", # Default: val_loss 39 | } 40 | return [opt], [scheduler] 41 | 42 | def loss_func(self, y_pred, y_true): 43 | return nn.functional.mse_loss( 44 | y_pred, y_true, reduction="mean" 45 | ) 46 | 47 | def training_step(self, batch, batchid): 48 | x, mask, y, _ = batch 49 | y_pred = self(x, mask) 50 | 51 | if batchid % 100 == 0: 52 | # log sampled images 53 | display_list = [x[0], y[0], y_pred[0]] 54 | title = ['Input Image', 'Ground Truth', 'Predicted Image'] 55 | 56 | f, axarr = plt.subplots(3,12) 57 | f.set_figwidth(20) 58 | for i in range(3): 59 | for j in range(12): 60 | axarr[i, j].imshow(display_list[i][j,:,:].detach().cpu().numpy()) 61 | plt.axis("off") 62 | plt.savefig(self.hparams.default_save_path / "imgs.png") 63 | loss = self.loss_func(y_pred.squeeze(), y) 64 | # logs metrics for each training_step, 65 | # and the average across the epoch, to the progress bar and logger 66 | # *100 for readability in progress bar 67 | self.log("train_loss", loss*100, on_step=True, on_epoch=True, prog_bar=True, logger=True) 68 | return loss 69 | 70 | def validation_step(self, batch, batch_idx): 71 | x, mask, y, _ = batch 72 | y_pred = self(x, mask) 73 | loss = self.loss_func(y_pred.squeeze(), y) 74 | self.log("val_loss", loss * 100, prog_bar=True) 75 | 76 | def on_validation_epoch_end(self): 77 | plt.close("all") 78 | 79 | class Precip_regression_base_gnet(UNet_base): 80 | @staticmethod 81 | def add_model_specific_args(parent_parser): 82 | parent_parser = UNet_base.add_model_specific_args(parent_parser) 83 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 84 | parser.add_argument("--num_input_images", type=int, default=12) 85 | parser.add_argument("--num_output_images", type=int, default=12) 86 | parser.add_argument("--valid_size", type=float, default=0.1) 87 | parser.n_channels = parser.parse_args().num_input_images 88 | parser.n_classes = 12 89 | return parser 90 | 91 | def __init__(self, hparams): 92 | super(Precip_regression_base_gnet, self).__init__(hparams=hparams) 93 | self.train_dataset = None 94 | self.valid_dataset = None 95 | self.train_sampler = None 96 | self.valid_sampler = None 97 | 98 | def prepare_data(self): 99 | # train_transform = transforms.Compose([ 100 | # transforms.RandomHorizontalFlip()] 101 | # ) 102 | train_transform = None 103 | valid_transform = None 104 | precip_dataset = dataset_precip.precipitation_maps_masked_h5 105 | self.train_dataset = precip_dataset( 106 | in_file=self.hparams.dataset_folder, 107 | num_input_images=self.hparams.num_input_images, 108 | num_output_images=self.hparams.num_output_images, 109 | mode=self.hparams.dataset, 110 | transform=train_transform, 111 | ) 112 | self.valid_dataset = precip_dataset( 113 | in_file=self.hparams.dataset_folder, 114 | num_input_images=self.hparams.num_input_images, 115 | num_output_images=self.hparams.num_output_images, 116 | mode=self.hparams.dataset, 117 | transform=valid_transform, 118 | ) 119 | 120 | num_train = len(self.train_dataset) 121 | indices = list(range(num_train)) 122 | split = int(np.floor(self.hparams.valid_size * num_train)) 123 | 124 | np.random.seed(123) 125 | np.random.shuffle(indices) 126 | 127 | train_idx, valid_idx = indices[split:], indices[:split] 128 | self.train_sampler = SubsetRandomSampler(train_idx) 129 | self.valid_sampler = SubsetRandomSampler(valid_idx) 130 | 131 | def train_dataloader(self): 132 | 133 | train_loader = DataLoader( 134 | self.train_dataset, 135 | batch_size=self.hparams.batch_size, 136 | sampler=self.train_sampler, 137 | 
num_workers=0, 138 | pin_memory=True, 139 | ) 140 | return train_loader 141 | 142 | def val_dataloader(self): 143 | 144 | valid_loader = DataLoader( 145 | self.valid_dataset, 146 | batch_size=self.hparams.batch_size, 147 | sampler=self.valid_sampler, 148 | num_workers=0, 149 | pin_memory=True, 150 | ) 151 | return valid_loader 152 | -------------------------------------------------------------------------------- /models/regression_SmaAt_GNet_aleatoric.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn, optim, multiprocessing 3 | import torch 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | from torch.utils.data import DataLoader 6 | from utils import dataset_precip 7 | import argparse 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import math 11 | 12 | 13 | class UNet_base(pl.LightningModule): 14 | @staticmethod 15 | def add_model_specific_args(parent_parser): 16 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 17 | parser.add_argument("--n_channels", type=int, default=12) 18 | parser.add_argument("--n_classes", type=int, default=1) 19 | parser.add_argument("--kernels_per_layer", type=int, default=1) 20 | parser.add_argument("--bilinear", type=bool, default=True) 21 | parser.add_argument("--reduction_ratio", type=int, default=16) 22 | parser.add_argument("--lr_patience", type=int, default=5) 23 | return parser 24 | 25 | def __init__(self, hparams): 26 | super().__init__() 27 | self.save_hyperparameters(hparams) 28 | 29 | def forward(self, x): 30 | pass 31 | 32 | def configure_optimizers(self): 33 | opt = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) 34 | scheduler = { 35 | "scheduler": optim.lr_scheduler.ReduceLROnPlateau( 36 | opt, mode="min", factor=0.1, patience=self.hparams.lr_patience 37 | ), 38 | "monitor": "val_loss", # Default: val_loss 39 | } 40 | return [opt], [scheduler] 41 | 42 | def loss_func(self, y_pred, y_true): 43 | return nn.functional.mse_loss( 44 | y_pred, y_true, reduction="mean" 45 | ) 46 | 47 | def loss_variance(self,y_pred,y_var,y_true): 48 | # Code from https://github.com/pmorerio/dl-uncertainty/blob/master/aleatoric-uncertainty/model.py 49 | loss1 = torch.mean( torch.exp(-y_var) * torch.square(y_pred-y_true)) 50 | loss2 = torch.mean(y_var) 51 | # "From What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?" 
NIPS 2017 52 | # In practice, we train the network to predict the log variance 53 | loss = .5*(loss1+loss2) 54 | return loss 55 | 56 | def training_step(self, batch, batchid): 57 | x, mask, y, _ = batch 58 | y_pred, y_var = self(x, mask) 59 | 60 | if batchid % 100 == 0: 61 | # log sampled images 62 | var = torch.exp(y_var) 63 | display_list = [x[0], y[0], y_pred[0], var[0]] 64 | title = ['Input Image', 'Ground Truth', 'Predicted Image'] 65 | 66 | f, axarr = plt.subplots(4,12) 67 | f.set_figwidth(20) 68 | for i in range(4): 69 | for j in range(12): 70 | axarr[i, j].imshow(display_list[i][j,:,:].detach().cpu().numpy()) 71 | plt.axis("off") 72 | plt.savefig(self.hparams.default_save_path / "imgs.png") 73 | mse = self.loss_func(y_pred.squeeze(), y) 74 | loss = self.loss_variance(y_pred.squeeze(),y_var.squeeze(),y) 75 | # logs metrics for each training_step, 76 | # and the average across the epoch, to the progress bar and logger 77 | self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 78 | self.log("train_mse", mse*100, on_step=True, on_epoch=True, prog_bar=True, logger=True) 79 | return loss 80 | 81 | def validation_step(self, batch, batch_idx): 82 | x, mask, y, _ = batch 83 | y_pred, y_var = self(x, mask) 84 | loss = self.loss_variance(y_pred.squeeze(), y_var.squeeze(), y) 85 | self.log("val_loss", loss, prog_bar=True) 86 | mse = self.loss_func(y_pred.squeeze(), y) 87 | self.log("val_mse", mse * 100, prog_bar=True) 88 | 89 | def on_validation_epoch_end(self): 90 | plt.close("all") 91 | 92 | class Precip_regression_base_gnet_aleatoric(UNet_base): 93 | @staticmethod 94 | def add_model_specific_args(parent_parser): 95 | parent_parser = UNet_base.add_model_specific_args(parent_parser) 96 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 97 | parser.add_argument("--num_input_images", type=int, default=12) 98 | parser.add_argument("--num_output_images", type=int, default=12) 99 | parser.add_argument("--valid_size", type=float, default=0.1) 100 | parser.n_channels = parser.parse_args().num_input_images 101 | parser.n_classes = 12 102 | return parser 103 | 104 | def __init__(self, hparams): 105 | super(Precip_regression_base_gnet_aleatoric, self).__init__(hparams=hparams) 106 | self.train_dataset = None 107 | self.valid_dataset = None 108 | self.train_sampler = None 109 | self.valid_sampler = None 110 | 111 | def prepare_data(self): 112 | # train_transform = transforms.Compose([ 113 | # transforms.RandomHorizontalFlip()] 114 | # ) 115 | train_transform = None 116 | valid_transform = None 117 | precip_dataset = dataset_precip.precipitation_maps_masked_h5 118 | self.train_dataset = precip_dataset( 119 | in_file=self.hparams.dataset_folder, 120 | num_input_images=self.hparams.num_input_images, 121 | num_output_images=self.hparams.num_output_images, 122 | mode=self.hparams.dataset, 123 | transform=train_transform, 124 | ) 125 | self.valid_dataset = precip_dataset( 126 | in_file=self.hparams.dataset_folder, 127 | num_input_images=self.hparams.num_input_images, 128 | num_output_images=self.hparams.num_output_images, 129 | mode=self.hparams.dataset, 130 | transform=valid_transform, 131 | ) 132 | 133 | num_train = len(self.train_dataset) 134 | indices = list(range(num_train)) 135 | split = int(np.floor(self.hparams.valid_size * num_train)) 136 | 137 | np.random.seed(123) 138 | np.random.shuffle(indices) 139 | 140 | train_idx, valid_idx = indices[split:], indices[:split] 141 | self.train_sampler = SubsetRandomSampler(train_idx) 142 | 
self.valid_sampler = SubsetRandomSampler(valid_idx) 143 | 144 | def train_dataloader(self): 145 | 146 | train_loader = DataLoader( 147 | self.train_dataset, 148 | batch_size=self.hparams.batch_size, 149 | sampler=self.train_sampler, 150 | num_workers=0, 151 | pin_memory=True, 152 | ) 153 | return train_loader 154 | 155 | def val_dataloader(self): 156 | 157 | valid_loader = DataLoader( 158 | self.valid_dataset, 159 | batch_size=self.hparams.batch_size, 160 | sampler=self.valid_sampler, 161 | num_workers=0, 162 | pin_memory=True, 163 | ) 164 | return valid_loader 165 | -------------------------------------------------------------------------------- /test_precip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | from os import path 7 | import pickle 8 | from tqdm import tqdm 9 | import math 10 | from pathlib import Path 11 | 12 | from root import ROOT_DIR 13 | from utils import dataset_precip 14 | from models import unet_precip_regression_lightning as unet_regr 15 | import models.regression_GA_SmaAt_GNet as gan 16 | 17 | 18 | def apply_dropout(m): 19 | if type(m) == nn.Dropout: 20 | m.train() 21 | 22 | def get_metrics(model, model_name, test_dl, denormalize=True, threshold=0.5, k=10): 23 | with torch.no_grad(): 24 | mps = torch.device("mps") 25 | if model_name != "Persistence": 26 | model.eval() # or model.freeze()? 27 | model.apply(apply_dropout) 28 | model.to(mps) 29 | loss_func = nn.functional.mse_loss 30 | 31 | factor = 1 32 | if denormalize: 33 | factor = 52.52 34 | 35 | threshold = threshold 36 | epsilon = 1e-6 37 | 38 | total_tp = 0 39 | total_fp = 0 40 | total_tn = 0 41 | total_fn = 0 42 | 43 | loss_denorm = 0.0 44 | f1 = 0.0 45 | csi = 0.0 46 | uncertainty = 0.0 47 | count = 0 48 | for x, mask, y_true, _ in tqdm(test_dl, leave=False): 49 | count += 1 50 | x = x.to(mps) 51 | mask = mask.to(mps) 52 | y_true = y_true.to(mps).squeeze() 53 | y_true = y_true 54 | y_pred = None 55 | y_preds = [] 56 | 57 | if model_name == "Persistence": 58 | y_pred = x.squeeze()[11].repeat(12, 1, 1) 59 | uncertainty += 0 60 | else: 61 | for _ in range(k): 62 | if model_name == "SmaAt-UNet": 63 | pred = model(x) 64 | else: 65 | pred = model(x,mask) 66 | y_preds.append(pred.squeeze()) 67 | y_preds = torch.stack(y_preds, dim=0) 68 | y_pred = torch.mean(y_preds, dim=0) 69 | uncertainty += torch.mean(torch.var(y_preds, dim=0)).item() 70 | 71 | # denormalize 72 | y_pred_adj = y_pred * factor 73 | y_true_adj = y_true * factor 74 | # calculate loss on denormalized data 75 | loss_denorm += loss_func(y_pred_adj, y_true_adj, reduction="sum") 76 | # sum all output frames 77 | y_pred_adj = torch.sum(y_pred_adj, axis=0) 78 | y_true_adj = torch.sum(y_true_adj, axis=0) 79 | # convert to masks for comparison 80 | y_pred_mask = y_pred_adj > threshold 81 | y_true_mask = y_true_adj > threshold 82 | y_pred_mask = y_pred_mask.cpu() 83 | y_true_mask = y_true_mask.cpu() 84 | 85 | tn, fp, fn, tp = np.bincount(y_true_mask.view(-1) * 2 + y_pred_mask.view(-1), minlength=4) 86 | total_tp += tp 87 | total_fp += fp 88 | total_tn += tn 89 | total_fn += fn 90 | 91 | uncertainty /= len(test_dl) 92 | mse_image = loss_denorm / len(test_dl) 93 | mse_pixel = mse_image / torch.numel(y_true) 94 | # get metrics 95 | precision = total_tp / (total_tp + total_fp + epsilon) 96 | recall = total_tp / (total_tp + total_fn + epsilon) 97 | f1 = 2 * precision * recall / (precision + recall + epsilon) 98 | 
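# CSI (critical success index) = TP / (TP + FN + FP)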
csi = total_tp / (total_tp + total_fn + total_fp + epsilon) 99 | hss = (total_tp * total_tn - total_fn * total_fp) / ((total_tp + total_fn) * (total_fn + total_tn) + (total_tp + total_fp) * (total_fp + total_tn) + epsilon) 100 | mcc = calculate_mcc(total_tp, total_tn, total_fp, total_fn) 101 | return mse_pixel.item(), f1, csi, hss, mcc, uncertainty 102 | 103 | def calculate_mcc(total_tp, total_tn, total_fp, total_fn): 104 | total_tp = np.array(total_tp, dtype=np.float64) 105 | total_tn = np.array(total_tn, dtype=np.float64) 106 | total_fp = np.array(total_fp, dtype=np.float64) 107 | total_fn = np.array(total_fn, dtype=np.float64) 108 | 109 | numerator = (total_tp * total_tn) - (total_fp * total_fn) 110 | denominator = np.sqrt((total_tp + total_fp) * (total_tp + total_fn) * (total_tn + total_fp) * (total_tn + total_fn)) 111 | mcc = numerator / denominator if denominator != 0 else 0 112 | return mcc 113 | 114 | 115 | def get_model_losses(model_file, model_name, data_file, denormalize): 116 | test_losses = dict() 117 | dataset = dataset_precip.precipitation_maps_masked_h5( 118 | in_file=data_file, 119 | num_input_images=12, 120 | num_output_images=12, 121 | mode="test") 122 | 123 | test_dl = torch.utils.data.DataLoader( 124 | dataset, 125 | batch_size=1, 126 | shuffle=False, 127 | num_workers=0, 128 | pin_memory=True 129 | ) 130 | 131 | # load the model 132 | if model_name == "SmaAt-UNet": 133 | model = unet_regr.SmaAt_UNet 134 | model = model.load_from_checkpoint(f"{model_file}") 135 | elif model_name == "SmaAt-GNet": 136 | model = unet_regr.SmaAt_GNet 137 | model = model.load_from_checkpoint(f"{model_file}") 138 | elif model_name == "GA-SmaAt-GNet": 139 | model = gan.GAN 140 | model = model.load_from_checkpoint(f"{model_file}") 141 | elif model_name == "Persistence": 142 | model = None 143 | else: 144 | raise Exception(f"{model_name} is not a valid model name") 145 | 146 | thresholds = [0.5, 10, 20] 147 | for threshold in thresholds: 148 | print(str(int(threshold*100))) 149 | test_losses[f"binary_{str(int(threshold*100))}"] = [] 150 | 151 | for threshold in thresholds: 152 | losses = get_metrics(model, model_name, test_dl, denormalize, threshold=threshold, k=10) 153 | test_losses[f"binary_{str(int(threshold*100))}"].append([threshold, model_name] + list(losses)) 154 | 155 | 156 | return test_losses 157 | 158 | def losses_to_csv(losses, path): 159 | csv = "threshold, name, mse, f1, csi, hss, mcc, uncertainty\n" 160 | for loss in losses: 161 | row = ",".join(str(l) for l in loss) 162 | csv += row + "\n" 163 | 164 | with open(path,"w+") as f: 165 | f.write(csv) 166 | 167 | return csv 168 | 169 | 170 | if __name__ == '__main__': 171 | denormalize = True 172 | data_file = ( 173 | ROOT_DIR / "data" / "precipitation" / "train_test_1998-2022_input-length_12_img-ahead_12_rain-threshhold_50_normalized.h5" 174 | ) 175 | results_folder = ROOT_DIR / "results" 176 | 177 | model_file = ROOT_DIR / "checkpoints" / "top_models/GA-SmaAt-GNet_rain_threshhold_50_epoch=26-val_loss=0.000288.ckpt" 178 | model_name = "GA-SmaAt-GNet" #Persistence, SmaAt-UNet, SmaAt-GNet or GA-SmaAt-GNet 179 | 180 | test_losses = get_model_losses(model_file, model_name, data_file, denormalize) 181 | 182 | print(losses_to_csv(test_losses['binary_50'], (results_folder / f"{model_name}_res_50.csv"))) 183 | print(losses_to_csv(test_losses['binary_1000'], (results_folder / f"{model_name}_res_1000.csv"))) 184 | print(losses_to_csv(test_losses['binary_2000'], (results_folder / f"{model_name}_res_2000.csv"))) 185 | 186 | 
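A note on the confusion-matrix accumulation in get_metrics above: encoding each pixel pair as 2 * truth + prediction maps (truth, pred) to 0 = TN, 1 = FP, 2 = FN, 3 = TP, so a single np.bincount call yields all four counts at once. A minimal self-contained sketch, with mask values made up purely for illustration:

import numpy as np

# hypothetical ground-truth and predicted rain masks (True = rain above threshold)
y_true_mask = np.array([True, True, False, False, True])
y_pred_mask = np.array([True, False, False, True, True])

# 2 * truth + prediction encodes each pixel as 0=TN, 1=FP, 2=FN, 3=TP
tn, fp, fn, tp = np.bincount(y_true_mask * 2 + y_pred_mask, minlength=4)

eps = 1e-6
precision = tp / (tp + fp + eps)
recall = tp / (tp + fn + eps)
f1 = 2 * precision * recall / (precision + recall + eps)
csi = tp / (tp + fn + fp + eps)  # critical success index
print(tn, fp, fn, tp, round(f1, 3), round(csi, 3))  # 1 1 1 2 0.667 0.5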
-------------------------------------------------------------------------------- /grad-cam.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | from utils import dataset_precip 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | import warnings 10 | from pytorch_grad_cam.utils.image import show_cam_on_image 11 | from pytorch_grad_cam import GradCAM 12 | import models.regression_GA_SmaAt_GNet as gan 13 | from root import ROOT_DIR 14 | 15 | warnings.filterwarnings('ignore') 16 | warnings.simplefilter('ignore') 17 | 18 | 19 | class SemanticSegmentationTarget: 20 | def __init__(self, category, mask, device): 21 | self.category = category 22 | self.mask = torch.from_numpy(mask) 23 | if device == 'cuda': 24 | self.mask = self.mask.cuda() 25 | 26 | def __call__(self, model_output): 27 | return (model_output[self.category, :, :] * self.mask).sum() 28 | 29 | 30 | def load_model(model, model_folder, device): 31 | models = [m for m in os.listdir(model_folder) if ".ckpt" in m] 32 | model_file = models[-1] 33 | model = model.load_from_checkpoint(f"{model_folder}/{model_file}") 34 | model.eval() 35 | model.to(torch.device(device)) 36 | return model 37 | 38 | 39 | def get_segmentation_data(): 40 | dataset_masked = dataset_precip.precipitation_maps_masked_h5( 41 | in_file=data_file, 42 | num_input_images=12, 43 | num_output_images=12, 44 | mode="test", 45 | use_timestamps=False) 46 | 47 | test_dl_masked = torch.utils.data.DataLoader( 48 | dataset_masked, 49 | batch_size=1, 50 | shuffle=False, 51 | num_workers=0, 52 | pin_memory=True 53 | ) 54 | return test_dl_masked 55 | 56 | 57 | def run_cam(model, target_layers, device): 58 | test_dl = get_segmentation_data() 59 | count = 0 60 | for x, masks, y_true, _ in tqdm(test_dl, leave=False): 61 | count += 1 62 | if count < 13871: 63 | continue 64 | x = x.to(torch.device(device)) 65 | masks = masks.to(torch.device(device)) 66 | model = model.to(torch.device(device)) 67 | input = torch.cat((x,masks),dim=1) 68 | output = model(input) 69 | 70 | x = torch.sum(x, dim=1) 71 | output = torch.sum(output, dim=1) 72 | 73 | mask = np.digitize((output[0] * 52.52).detach().cpu().numpy(), np.array([0.5]), right=True) 74 | mask_float = np.float32(mask) 75 | image = torch.stack([x[0], x[0], x[0]], dim=2) 76 | image = image.cpu().numpy() 77 | targets = [SemanticSegmentationTarget(0, mask_float, device)] 78 | cam_image = [] 79 | 80 | for layer in target_layers: 81 | with GradCAM(model=model, target_layers=layer) as cam: 82 | grayscale_cam = cam(input_tensor=input, targets=targets)[0, :] 83 | cam_image.append(show_cam_on_image(image, grayscale_cam, use_rgb=True)) 84 | 85 | # Plot encoder 86 | fig, axes = plt.subplots(5,4, figsize=(8,10)) 87 | # map encoder 88 | # 0 89 | axes[0][0].imshow(cam_image[0]) 90 | axes[0][0].axis("off") 91 | axes[0][1].imshow(cam_image[1]) 92 | axes[0][1].axis("off") 93 | # 1 94 | axes[1][0].imshow(cam_image[2]) 95 | axes[1][0].axis("off") 96 | axes[1][1].imshow(cam_image[3]) 97 | axes[1][1].axis("off") 98 | # 2 99 | axes[2][0].imshow(cam_image[4]) 100 | axes[2][0].axis("off") 101 | axes[2][1].imshow(cam_image[5]) 102 | axes[2][1].axis("off") 103 | # 3 104 | axes[3][0].imshow(cam_image[6]) 105 | axes[3][0].axis("off") 106 | axes[3][1].imshow(cam_image[7]) 107 | axes[3][1].axis("off") 108 | # 4 109 | axes[4][0].imshow(cam_image[8]) 110 | axes[4][0].axis("off") 111 | 
axes[4][1].imshow(cam_image[9])
112 |     axes[4][1].axis("off")
113 | 
114 |     # mask encoder
115 |     # 0
116 |     axes[0][2].imshow(cam_image[10])
117 |     axes[0][2].axis("off")
118 |     axes[0][3].imshow(cam_image[11])
119 |     axes[0][3].axis("off")
120 |     # 1
121 |     axes[1][2].imshow(cam_image[12])
122 |     axes[1][2].axis("off")
123 |     axes[1][3].imshow(cam_image[13])
124 |     axes[1][3].axis("off")
125 |     # 2
126 |     axes[2][2].imshow(cam_image[14])
127 |     axes[2][2].axis("off")
128 |     axes[2][3].imshow(cam_image[15])
129 |     axes[2][3].axis("off")
130 |     # 3
131 |     axes[3][2].imshow(cam_image[16])
132 |     axes[3][2].axis("off")
133 |     axes[3][3].imshow(cam_image[17])
134 |     axes[3][3].axis("off")
135 |     # 4
136 |     axes[4][2].imshow(cam_image[18])
137 |     axes[4][2].axis("off")
138 |     axes[4][3].imshow(cam_image[19])
139 |     axes[4][3].axis("off")
140 |     plt.tight_layout()
141 |     plt.savefig("imgs/enc_grad.png")
142 |     plt.show()
143 | 
144 |     # Plot decoder
145 |     fig, axes = plt.subplots(4, 1, figsize=(2, 7))
146 |     # 0
147 |     axes[0].imshow(cam_image[20])
148 |     axes[0].axis("off")
149 |     # 1
150 |     axes[1].imshow(cam_image[21])
151 |     axes[1].axis("off")
152 |     # 2
153 |     axes[2].imshow(cam_image[22])
154 |     axes[2].axis("off")
155 |     # 3
156 |     axes[3].imshow(cam_image[23])
157 |     axes[3].axis("off")
158 |     plt.tight_layout()
159 |     plt.savefig("imgs/dec_grad.png")
160 |     plt.show()
161 | 
162 | # Wrapper for GA-SmaAt-GNet that splits one input into two, because Grad-CAM expects a single input tensor
163 | class GradCAMWrapper(nn.Module):
164 |     def __init__(self, model):
165 |         super(GradCAMWrapper, self).__init__()
166 |         self.model = model
167 | 
168 |     def forward(self, input_tensor):
169 |         # Split the input tensor into the map and mask halves
170 |         input_tensor1 = input_tensor[:, :12]
171 |         input_tensor2 = input_tensor[:, 12:]
172 | 
173 |         # Forward pass through the original model
174 |         output = self.model(input_tensor1, input_tensor2)
175 | 
176 |         return output
177 | 
178 | 
179 | if __name__ == '__main__':
180 |     data_file = (
181 |         ROOT_DIR / "data" / "precipitation" / "train_test_1998-2022_input-length_12_img-ahead_12_rain-threshhold_50_normalized.h5"
182 |     )
183 |     device = 'cpu'
184 |     # load the models
185 |     ga_smaat_gnet = gan.GAN
186 |     ga_smaat_gnet = ga_smaat_gnet.load_from_checkpoint("checkpoints/top_models/GA-SmaAt-GNet_rain_threshhold_50_epoch=26-val_loss=0.000288.ckpt")
187 |     model = GradCAMWrapper(ga_smaat_gnet)
188 |     print(model)
189 |     # 10 map-encoder layers, 10 mask-encoder layers, then the 4 decoder up-blocks (cam_image indices 0-23)
190 |     target_layers = [
191 |         [model.model.generator.inc1], [model.model.generator.cbam11],
192 |         [model.model.generator.down11.maxpool_conv[1]], [model.model.generator.cbam12],
193 |         [model.model.generator.down12.maxpool_conv[1]], [model.model.generator.cbam13],
194 |         [model.model.generator.down13.maxpool_conv[1]], [model.model.generator.cbam14],
195 |         [model.model.generator.down14.maxpool_conv[1]], [model.model.generator.cbam15],
196 |         [model.model.generator.inc2], [model.model.generator.cbam21],
197 |         [model.model.generator.down21.maxpool_conv[1]], [model.model.generator.cbam22],
198 |         [model.model.generator.down22.maxpool_conv[1]], [model.model.generator.cbam23],
199 |         [model.model.generator.down23.maxpool_conv[1]], [model.model.generator.cbam24],
200 |         [model.model.generator.down24.maxpool_conv[1]], [model.model.generator.cbam25],
201 |         [model.model.generator.up4.conv],
202 |         [model.model.generator.up3.conv],
203 |         [model.model.generator.up2.conv],
204 |         [model.model.generator.up1.conv],
205 |     ]
206 |     run_cam(model, target_layers, device)
--------------------------------------------------------------------------------
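The GradCAMWrapper pattern above exists because pytorch-grad-cam drives the model through a single input tensor, while GA-SmaAt-GNet's generator takes a precipitation-map tensor and a mask tensor separately. A minimal sketch of the same idea with a toy two-input network; all names here are illustrative stand-ins, not from the repo:

import torch
import torch.nn as nn
from pytorch_grad_cam import GradCAM

class TwoInputNet(nn.Module):
    # stand-in for a model that expects two separate inputs
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(24, 12, kernel_size=3, padding=1)

    def forward(self, x, m):
        return self.conv(torch.cat((x, m), dim=1))

class Wrapper(nn.Module):
    # one tensor in, split into the two expected inputs inside
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, t):
        return self.net(t[:, :12], t[:, 12:])

net = Wrapper(TwoInputNet())
inp = torch.rand(1, 24, 64, 64)  # 12 map channels + 12 mask channels

# summing the output acts as a simple CAM target, standing in for the
# SemanticSegmentationTarget used in grad-cam.py above
targets = [lambda model_output: model_output.sum()]
with GradCAM(model=net, target_layers=[net.net.conv]) as cam:
    heatmap = cam(input_tensor=inp, targets=targets)[0]
print(heatmap.shape)  # (64, 64) attribution map at the input's spatial size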
/models/unet_precip_regression_lightning.py: -------------------------------------------------------------------------------- 1 | from models.unet_parts import Down, DoubleConv, Up, OutConv 2 | from models.unet_parts_depthwise_separable import DoubleConvDS, ResDoubleConvDS, UpDS, UpDS_Simple, DownDS 3 | from models.layers import CBAM 4 | from models.regression_SmaAt_UNet import Precip_regression_base 5 | from models.regression_SmaAt_GNet import Precip_regression_base_gnet 6 | from models.regression_SmaAt_GNet_aleatoric import Precip_regression_base_gnet_aleatoric 7 | import torch 8 | import torch.nn as nn 9 | 10 | class SmaAt_UNet(Precip_regression_base): 11 | def __init__(self, hparams): 12 | super(SmaAt_UNet, self).__init__(hparams=hparams) 13 | self.n_channels = self.hparams.n_channels 14 | self.n_classes = self.hparams.n_classes 15 | self.bilinear = self.hparams.bilinear 16 | reduction_ratio = self.hparams.reduction_ratio 17 | kernels_per_layer = self.hparams.kernels_per_layer 18 | dropout_prob = self.hparams.dropout 19 | 20 | self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer) 21 | self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio) 22 | self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer) 23 | self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio) 24 | self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer) 25 | self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio) 26 | self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer) 27 | self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio) 28 | factor = 2 if self.bilinear else 1 29 | self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer) 30 | self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio) 31 | self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 32 | self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 33 | self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 34 | self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer) 35 | 36 | self.outc = OutConv(64, self.n_classes) 37 | 38 | self.dropout = nn.Dropout(p=dropout_prob) 39 | 40 | def forward(self, x): 41 | x1 = self.inc(x) 42 | x1Att = self.cbam1(x1) 43 | x2 = self.down1(x1) 44 | x2Att = self.cbam2(x2) 45 | x3 = self.down2(x2) 46 | x3Att = self.cbam3(x3) 47 | x4 = self.down3(x3) 48 | x4Att = self.cbam4(x4) 49 | x5 = self.down4(x4) 50 | x5Att = self.cbam5(x5) 51 | 52 | x = self.up1(x5Att, x4Att) 53 | x = self.dropout(x) 54 | x = self.up2(x, x3Att) 55 | x = self.dropout(x) 56 | x = self.up3(x, x2Att) 57 | x = self.up4(x, x1Att) 58 | logits = self.outc(x) 59 | return logits 60 | 61 | 62 | class SmaAt_GNet(Precip_regression_base_gnet): 63 | def __init__(self, hparams): 64 | super(Precip_regression_base_gnet, self).__init__(hparams=hparams) 65 | self.n_channels = self.hparams.n_channels 66 | self.n_classes = self.hparams.n_classes 67 | self.n_masks = self.hparams.n_masks 68 | self.bilinear = self.hparams.bilinear 69 | reduction_ratio = self.hparams.reduction_ratio 70 | kernels_per_layer = self.hparams.kernels_per_layer 71 | dropout_prob = self.hparams.dropout 72 | 73 | # map down 74 | self.inc1 = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer) 75 | self.cbam11 = CBAM(64, reduction_ratio=reduction_ratio) 76 | self.down11 = DownDS(64, 128, kernels_per_layer=kernels_per_layer) 77 | self.cbam12 = CBAM(128, reduction_ratio=reduction_ratio) 
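# The mask branch (inc2/down2x/cbam2x below) mirrors this map encoder level for level;
# the forward pass concatenates the two branches' CBAM outputs along dim=1, which is why
# every UpDS in the decoder is sized at twice the single-encoder width (1024*2 down to
# 64*2) and OutConv takes 64*2 input channels.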
78 | self.down12 = DownDS(128, 256, kernels_per_layer=kernels_per_layer) 79 | self.cbam13 = CBAM(256, reduction_ratio=reduction_ratio) 80 | self.down13 = DownDS(256, 512, kernels_per_layer=kernels_per_layer) 81 | self.cbam14 = CBAM(512, reduction_ratio=reduction_ratio) 82 | factor = 2 if self.bilinear else 1 83 | self.down14 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer) 84 | self.cbam15 = CBAM(1024 // factor, reduction_ratio=reduction_ratio) 85 | 86 | # mask down 87 | self.inc2 = DoubleConvDS(self.n_masks, 64, kernels_per_layer=kernels_per_layer) 88 | self.cbam21 = CBAM(64, reduction_ratio=reduction_ratio) 89 | self.down21 = DownDS(64, 128, kernels_per_layer=kernels_per_layer) 90 | self.cbam22 = CBAM(128, reduction_ratio=reduction_ratio) 91 | self.down22 = DownDS(128, 256, kernels_per_layer=kernels_per_layer) 92 | self.cbam23 = CBAM(256, reduction_ratio=reduction_ratio) 93 | self.down23 = DownDS(256, 512, kernels_per_layer=kernels_per_layer) 94 | self.cbam24 = CBAM(512, reduction_ratio=reduction_ratio) 95 | factor = 2 if self.bilinear else 1 96 | self.down24 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer) 97 | self.cbam25 = CBAM(1024 // factor, reduction_ratio=reduction_ratio) 98 | # up 99 | self.up1 = UpDS(1024*2, 512*2 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 100 | self.up2 = UpDS(512*2, 256*2 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 101 | self.up3 = UpDS(256*2, 128*2 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 102 | self.up4 = UpDS(128*2, 64*2, self.bilinear, kernels_per_layer=kernels_per_layer) 103 | 104 | self.outc = OutConv(64*2, self.n_classes) 105 | 106 | self.dropout = nn.Dropout(p=dropout_prob) 107 | 108 | def forward(self, x, m): 109 | # down map 110 | x1 = self.inc1(x) 111 | x1Att = self.cbam11(x1) 112 | x2 = self.down11(x1) 113 | x2Att = self.cbam12(x2) 114 | x3 = self.down12(x2) 115 | x3Att = self.cbam13(x3) 116 | x4 = self.down13(x3) 117 | x4Att = self.cbam14(x4) 118 | x5 = self.down14(x4) 119 | x5Att = self.cbam15(x5) 120 | 121 | # down mask 122 | m1 = self.inc2(m) 123 | m1Att = self.cbam21(m1) 124 | m2 = self.down21(m1) 125 | m2Att = self.cbam22(m2) 126 | m3 = self.down22(m2) 127 | m3Att = self.cbam23(m3) 128 | m4 = self.down23(m3) 129 | m4Att = self.cbam24(m4) 130 | m5 = self.down24(m4) 131 | m5Att = self.cbam25(m5) 132 | 133 | # concatenate 134 | x5Att = torch.cat((x5Att, m5Att), dim=1) 135 | x4Att = torch.cat((x4Att, m4Att), dim=1) 136 | x3Att = torch.cat((x3Att, m3Att), dim=1) 137 | x2Att = torch.cat((x2Att, m2Att), dim=1) 138 | x1Att = torch.cat((x1Att, m1Att), dim=1) 139 | 140 | # up 141 | x = self.up1(x5Att, x4Att) 142 | x = self.dropout(x) 143 | x = self.up2(x, x3Att) 144 | x = self.dropout(x) 145 | x = self.up3(x, x2Att) 146 | x = self.up4(x, x1Att) 147 | logits = self.outc(x) 148 | return logits 149 | 150 | class SmaAt_GNet_aleatoric(Precip_regression_base_gnet_aleatoric): 151 | def __init__(self, hparams): 152 | super(Precip_regression_base_gnet_aleatoric, self).__init__(hparams=hparams) 153 | self.n_channels = self.hparams.n_channels 154 | self.n_classes = self.hparams.n_classes 155 | self.n_masks = self.hparams.n_masks 156 | self.bilinear = self.hparams.bilinear 157 | reduction_ratio = self.hparams.reduction_ratio 158 | kernels_per_layer = self.hparams.kernels_per_layer 159 | dropout_prob = self.hparams.dropout 160 | 161 | # map down 162 | self.inc1 = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer) 163 | self.cbam11 = CBAM(64, 
reduction_ratio=reduction_ratio) 164 | self.down11 = DownDS(64, 128, kernels_per_layer=kernels_per_layer) 165 | self.cbam12 = CBAM(128, reduction_ratio=reduction_ratio) 166 | self.down12 = DownDS(128, 256, kernels_per_layer=kernels_per_layer) 167 | self.cbam13 = CBAM(256, reduction_ratio=reduction_ratio) 168 | self.down13 = DownDS(256, 512, kernels_per_layer=kernels_per_layer) 169 | self.cbam14 = CBAM(512, reduction_ratio=reduction_ratio) 170 | factor = 2 if self.bilinear else 1 171 | self.down14 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer) 172 | self.cbam15 = CBAM(1024 // factor, reduction_ratio=reduction_ratio) 173 | 174 | # mask down 175 | self.inc2 = DoubleConvDS(self.n_masks, 64, kernels_per_layer=kernels_per_layer) 176 | self.cbam21 = CBAM(64, reduction_ratio=reduction_ratio) 177 | self.down21 = DownDS(64, 128, kernels_per_layer=kernels_per_layer) 178 | self.cbam22 = CBAM(128, reduction_ratio=reduction_ratio) 179 | self.down22 = DownDS(128, 256, kernels_per_layer=kernels_per_layer) 180 | self.cbam23 = CBAM(256, reduction_ratio=reduction_ratio) 181 | self.down23 = DownDS(256, 512, kernels_per_layer=kernels_per_layer) 182 | self.cbam24 = CBAM(512, reduction_ratio=reduction_ratio) 183 | factor = 2 if self.bilinear else 1 184 | self.down24 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer) 185 | self.cbam25 = CBAM(1024 // factor, reduction_ratio=reduction_ratio) 186 | # up 187 | self.up1 = UpDS(1024*2, 512*2 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 188 | self.up2 = UpDS(512*2, 256*2 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 189 | self.up3 = UpDS(256*2, 128*2 // factor, self.bilinear, kernels_per_layer=kernels_per_layer) 190 | self.up4 = UpDS(128*2, 64*2, self.bilinear, kernels_per_layer=kernels_per_layer) 191 | 192 | self.outc = OutConv(64*2, self.n_classes) 193 | self.outc_var = OutConv(64*2, self.n_classes) 194 | 195 | self.dropout = nn.Dropout(p=dropout_prob) 196 | 197 | def forward(self, x, m): 198 | # down map 199 | x1 = self.inc1(x) 200 | x1Att = self.cbam11(x1) 201 | x2 = self.down11(x1) 202 | x2Att = self.cbam12(x2) 203 | x3 = self.down12(x2) 204 | x3Att = self.cbam13(x3) 205 | x4 = self.down13(x3) 206 | x4Att = self.cbam14(x4) 207 | x5 = self.down14(x4) 208 | x5Att = self.cbam15(x5) 209 | 210 | # down mask 211 | m1 = self.inc2(m) 212 | m1Att = self.cbam21(m1) 213 | m2 = self.down21(m1) 214 | m2Att = self.cbam22(m2) 215 | m3 = self.down22(m2) 216 | m3Att = self.cbam23(m3) 217 | m4 = self.down23(m3) 218 | m4Att = self.cbam24(m4) 219 | m5 = self.down24(m4) 220 | m5Att = self.cbam25(m5) 221 | 222 | # concatenate 223 | x5Att = torch.cat((x5Att, m5Att), dim=1) 224 | x4Att = torch.cat((x4Att, m4Att), dim=1) 225 | x3Att = torch.cat((x3Att, m3Att), dim=1) 226 | x2Att = torch.cat((x2Att, m2Att), dim=1) 227 | x1Att = torch.cat((x1Att, m1Att), dim=1) 228 | 229 | # up 230 | x = self.up1(x5Att, x4Att) 231 | x = self.dropout(x) 232 | x = self.up2(x, x3Att) 233 | x = self.dropout(x) 234 | x = self.up3(x, x2Att) 235 | x = self.up4(x, x1Att) 236 | logits = self.outc(x) 237 | var = self.outc_var(x) 238 | return logits, var 239 | -------------------------------------------------------------------------------- /models/regression_GA_SmaAt_GNet_mnist.py: -------------------------------------------------------------------------------- 1 | from root import ROOT_DIR 2 | import lightning.pytorch as pl 3 | from torch import nn, optim, multiprocessing 4 | import torch 5 | import torch.nn.functional as F 6 | from 
torch.utils.data.sampler import SubsetRandomSampler 7 | from torch.utils.data import DataLoader 8 | import torchvision 9 | from utils import dataset_precip 10 | from models.discriminator import * 11 | from models.unet_precip_regression_lightning import SmaAt_UNet, SmaAt_GNet 12 | import argparse 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from matplotlib.lines import Line2D 16 | import math 17 | 18 | 19 | class GAN_base(pl.LightningModule): 20 | @staticmethod 21 | def add_model_specific_args(parent_parser): 22 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 23 | parser.add_argument("--n_channels", type=int, default=10) 24 | parser.add_argument("--n_classes", type=int, default=10) 25 | parser.add_argument("--kernels_per_layer", type=int, default=1) 26 | parser.add_argument("--bilinear", type=bool, default=True) 27 | parser.add_argument("--reduction_ratio", type=int, default=16) 28 | parser.add_argument("--lr_patience", type=int, default=5) 29 | return parser 30 | 31 | def __init__(self, hparams): 32 | super().__init__() 33 | self.save_hyperparameters(hparams) 34 | 35 | self.automatic_optimization = False 36 | 37 | # networks 38 | self.generator = SmaAt_GNet(hparams=hparams) 39 | self.discriminator = LargePix2PixDiscriminatorCBAM(hparams=hparams, in_channels=20) 40 | 41 | self.g_losses = [] 42 | self.d_losses = [] 43 | self.log("val_g_loss",float("inf"),prog_bar=False) 44 | self.log("val_d_loss",float("inf"),prog_bar=False) 45 | 46 | def forward(self, x, m): 47 | return self.generator(x, m) 48 | 49 | def configure_optimizers(self): 50 | lr = self.hparams.learning_rate 51 | 52 | opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr) 53 | opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr) 54 | scheduler_g = { 55 | "scheduler": optim.lr_scheduler.ReduceLROnPlateau( 56 | opt_g, mode="min", factor=0.1, patience=self.hparams.lr_patience, 57 | 58 | ), 59 | "monitor": "val_g_loss", # Default: val_loss 60 | } 61 | scheduler_d = { 62 | "scheduler": optim.lr_scheduler.ReduceLROnPlateau( 63 | opt_d, mode="min", factor=0.1, patience=self.hparams.lr_patience 64 | ), 65 | "monitor": "val_d_loss", # Default: val_loss 66 | } 67 | return [opt_g, opt_d], [scheduler_g, scheduler_d] 68 | 69 | def on_validation_epoch_end(self): 70 | plt.close("all") 71 | 72 | def adversarial_loss(self, y_hat, y): 73 | return F.binary_cross_entropy(y_hat, y) 74 | 75 | def loss_func(self, y_pred, y_true): 76 | return nn.functional.mse_loss( 77 | y_pred, y_true, reduction="mean" 78 | ) 79 | 80 | def training_step(self, batch, batchid): 81 | # lamda param for generator loss 82 | l = self.hparams.l 83 | 84 | batch = batch.squeeze().float() 85 | imgs = batch[:,:10,:,:] 86 | # flip horizontally and vertically 87 | imgs_flip = torch.flip(imgs,[2,3]) 88 | tar = batch[:,10:,:,:] 89 | optimizer_g, optimizer_d = self.optimizers() 90 | scheduler_g, scheduler_d = self.lr_schedulers() 91 | 92 | 93 | # how well can it label as real? 94 | valid = torch.ones((imgs.size(0),1,4,4)) 95 | valid = valid.type_as(imgs) 96 | 97 | # how well can it label as fake? 
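# The discriminator is a PatchGAN-style critic: rather than one scalar per image it
# outputs a grid of per-patch real/fake probabilities, so the BCE targets are whole
# tensors of ones (real) and zeros (fake) with the matching (batch, 1, 4, 4) shape.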
98 | fake = torch.zeros((imgs.size(0),1,4,4)) 99 | fake = fake.type_as(imgs) 100 | 101 | # train discriminator 102 | # Measure discriminator's ability to classify real from generated samples 103 | generated_imgs = self.generator(imgs, imgs_flip) 104 | self.toggle_optimizer(optimizer_d) 105 | 106 | real_loss = self.adversarial_loss(self.discriminator(imgs,tar), valid) 107 | fake_loss = self.adversarial_loss(self.discriminator(imgs, generated_imgs.detach()), fake) 108 | 109 | # discriminator loss is the sum of these 110 | d_loss = (real_loss + fake_loss) 111 | self.log("d_loss", d_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 112 | 113 | # only do backpropagation every n steps 114 | if batchid % self.hparams.disc_every_n_steps == 0: 115 | optimizer_d.zero_grad() 116 | self.manual_backward(d_loss) 117 | optimizer_d.step() 118 | 119 | self.untoggle_optimizer(optimizer_d) 120 | 121 | # Train generator 122 | self.toggle_optimizer(optimizer_g) 123 | optimizer_g.zero_grad() 124 | 125 | generated_imgs = self(imgs, imgs_flip) 126 | 127 | # Generator loss is the combination of adverarial loss and MSE loss 128 | g_loss = self.adversarial_loss(self.discriminator(imgs,generated_imgs), valid) 129 | structural_g_loss = self.loss_func(generated_imgs, tar) 130 | total_g_loss = g_loss + l * structural_g_loss 131 | self.log("g_total", total_g_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 132 | self.log("g_loss", g_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 133 | 134 | 135 | self.manual_backward(total_g_loss) 136 | optimizer_g.step() 137 | self.untoggle_optimizer(optimizer_g) 138 | 139 | if batchid % 100 == 0: 140 | display_list = [imgs[0], tar[0], generated_imgs[0]] 141 | title = ['Input Image', 'Ground Truth', 'Predicted Image'] 142 | # Plot generated images 143 | f, axarr = plt.subplots(3,10) 144 | f.set_figwidth(20) 145 | for i in range(3): 146 | for j in range(10): 147 | axarr[i, j].imshow(display_list[i][j,:,:].detach().cpu().numpy()) 148 | plt.axis("off") 149 | plt.savefig(self.hparams.default_save_path / "imgs.png") 150 | 151 | # Plot losses 152 | xs = [x * 10 for x in range(len(self.g_losses))] 153 | xs = xs[10:] 154 | fig, ax1 = plt.subplots() 155 | color = 'tab:red' 156 | ax1.set_xlabel('steps') 157 | ax1.set_ylabel('g_loss', color=color) 158 | g_losses = self.g_losses[10:] 159 | ax1.plot(xs, g_losses, color=color) 160 | ax1.tick_params(axis='y', labelcolor=color) 161 | 162 | ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis 163 | 164 | color = 'tab:blue' 165 | ax2.set_ylabel('d_loss', color=color) # we already handled the x-label with ax1 166 | d_losses = self.d_losses[10:] 167 | ax2.plot(xs, d_losses, color=color) 168 | ax2.tick_params(axis='y', labelcolor=color) 169 | fig.tight_layout() # otherwise the right y-label is slightly clipped 170 | plt.savefig(self.hparams.default_save_path / "losses.png") 171 | 172 | 173 | # Save train losses for plotting 174 | if batchid > 0 and batchid % 10 == 0: 175 | self.g_losses.append(total_g_loss.item()) 176 | self.d_losses.append(d_loss.item()) 177 | 178 | #Update lr schedulers on end of epoch 179 | if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % 1 == 0 and self.trainer.current_epoch > 0: 180 | scheduler_d.step(self.trainer.callback_metrics["val_d_loss"]) 181 | scheduler_g.step(self.trainer.callback_metrics["val_g_loss"]) 182 | 183 | return total_g_loss 184 | 185 | def validation_step(self, batch, batch_idx): 186 | # lamda param for generator loss 187 | l = 
self.hparams.l 188 | 189 | batch = batch.squeeze().float() 190 | imgs = batch[:,:10,:,:] 191 | # flip horizontally and vertically 192 | imgs_flip = torch.flip(imgs,[2,3]) 193 | tar = batch[:,10:,:,:] 194 | 195 | 196 | # how well can it label as real? 197 | valid = torch.ones((imgs.size(0),1,4,4)) 198 | valid = valid.type_as(imgs) 199 | 200 | # how well can it label as fake? 201 | fake = torch.zeros((imgs.size(0),1,6,6)) 202 | fake = torch.zeros((imgs.size(0),1,4,4)) 203 | fake = fake.type_as(imgs) 204 | 205 | generated_imgs = self.generator(imgs, imgs_flip) 206 | 207 | # train discriminator 208 | # Measure discriminator's ability to classify real from generated samples 209 | real_loss = self.adversarial_loss(self.discriminator(imgs,tar), valid) 210 | fake_loss = self.adversarial_loss(self.discriminator(imgs, generated_imgs.detach()), fake) 211 | 212 | # discriminator loss is the sum of these 213 | d_loss = (real_loss + fake_loss) 214 | self.log("val_d_loss", d_loss, prog_bar=True) 215 | 216 | # Train generator 217 | # Generator loss is the combination of adverarial loss and MSE loss 218 | generated_imgs = self(imgs, imgs_flip) 219 | g_loss = self.adversarial_loss(self.discriminator(imgs,generated_imgs), valid) 220 | structural_g_loss = self.loss_func(generated_imgs, tar) 221 | 222 | total_g_loss = g_loss + l * structural_g_loss 223 | self.log("val_loss",structural_g_loss*100,prog_bar=True) 224 | self.log("val_g_loss",total_g_loss,prog_bar=True) 225 | 226 | 227 | class GAN(GAN_base): 228 | @staticmethod 229 | def add_model_specific_args(parent_parser): 230 | parent_parser = GAN_base.add_model_specific_args(parent_parser) 231 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 232 | parser.add_argument("--num_input_images", type=int, default=10) 233 | parser.add_argument("--num_output_images", type=int, default=10) 234 | parser.add_argument("--valid_size", type=float, default=0.1) 235 | parser.n_channels = parser.parse_args().num_input_images 236 | parser.n_classes = 10 237 | return parser 238 | 239 | def __init__(self, hparams): 240 | super(GAN, self).__init__(hparams=hparams) 241 | self.train_dataset = None 242 | self.valid_dataset = None 243 | self.train_sampler = None 244 | self.valid_sampler = None 245 | 246 | def prepare_data(self): 247 | # train_transform = transforms.Compose([ 248 | # transforms.RandomHorizontalFlip()] 249 | # ) 250 | train_transform = None 251 | valid_transform = None 252 | dataset = torchvision.datasets.MovingMNIST 253 | self.train_dataset = dataset( 254 | root='./data', download=True, transform=train_transform 255 | ) 256 | self.valid_dataset = dataset( 257 | root='./data', download=True, transform=valid_transform 258 | ) 259 | 260 | num_train = len(self.train_dataset) 261 | indices = list(range(num_train)) 262 | split = int(np.floor(self.hparams.valid_size * num_train)) 263 | 264 | np.random.seed(123) 265 | np.random.shuffle(indices) 266 | 267 | train_idx, valid_idx = indices[split:], indices[:split] 268 | self.train_sampler = SubsetRandomSampler(train_idx) 269 | self.valid_sampler = SubsetRandomSampler(valid_idx) 270 | 271 | def train_dataloader(self): 272 | 273 | train_loader = DataLoader( 274 | self.train_dataset, 275 | batch_size=self.hparams.batch_size, 276 | sampler=self.train_sampler, 277 | num_workers=0, 278 | pin_memory=True, 279 | ) 280 | return train_loader 281 | 282 | def val_dataloader(self): 283 | 284 | valid_loader = DataLoader( 285 | self.valid_dataset, 286 | batch_size=self.hparams.batch_size, 287 | 
sampler=self.valid_sampler, 288 | num_workers=0, 289 | pin_memory=True, 290 | ) 291 | return valid_loader 292 | -------------------------------------------------------------------------------- /models/regression_GA_SmaAt_GNet.py: -------------------------------------------------------------------------------- 1 | from root import ROOT_DIR 2 | import lightning.pytorch as pl 3 | from torch import nn, optim, multiprocessing 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data.sampler import SubsetRandomSampler 7 | from torch.utils.data import DataLoader 8 | import torchvision 9 | from utils import dataset_precip 10 | from models.discriminator import * 11 | from models.unet_precip_regression_lightning import SmaAt_UNet, SmaAt_GNet 12 | import argparse 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from matplotlib.lines import Line2D 16 | import math 17 | 18 | 19 | class GAN_base(pl.LightningModule): 20 | @staticmethod 21 | def add_model_specific_args(parent_parser): 22 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 23 | parser.add_argument("--n_channels", type=int, default=12) 24 | parser.add_argument("--n_classes", type=int, default=12) 25 | parser.add_argument("--kernels_per_layer", type=int, default=1) 26 | parser.add_argument("--bilinear", type=bool, default=True) 27 | parser.add_argument("--reduction_ratio", type=int, default=16) 28 | parser.add_argument("--lr_patience", type=int, default=5) 29 | return parser 30 | 31 | def __init__(self, hparams): 32 | super().__init__() 33 | self.save_hyperparameters(hparams) 34 | 35 | self.automatic_optimization = False 36 | 37 | # networks 38 | self.generator = SmaAt_GNet(hparams=hparams) 39 | self.discriminator = LargePix2PixDiscriminatorCBAM(hparams=hparams) 40 | 41 | self.g_losses = [] 42 | self.d_losses = [] 43 | self.log("val_g_loss",float("inf"),prog_bar=False) 44 | self.log("val_d_loss",float("inf"),prog_bar=False) 45 | 46 | def forward(self, x, m): 47 | return self.generator(x, m) 48 | 49 | def configure_optimizers(self): 50 | lr = self.hparams.learning_rate 51 | 52 | opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr) 53 | opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr) 54 | scheduler_g = { 55 | "scheduler": optim.lr_scheduler.ReduceLROnPlateau( 56 | opt_g, mode="min", factor=0.1, patience=self.hparams.lr_patience, 57 | 58 | ), 59 | "monitor": "val_g_loss", # Default: val_loss 60 | } 61 | scheduler_d = { 62 | "scheduler": optim.lr_scheduler.ReduceLROnPlateau( 63 | opt_d, mode="min", factor=0.1, patience=self.hparams.lr_patience 64 | ), 65 | "monitor": "val_d_loss", # Default: val_loss 66 | } 67 | return [opt_g, opt_d], [scheduler_g, scheduler_d] 68 | 69 | def on_validation_epoch_end(self): 70 | plt.close("all") 71 | 72 | def adversarial_loss(self, y_hat, y): 73 | return F.binary_cross_entropy(y_hat, y) 74 | 75 | def loss_func(self, y_pred, y_true): 76 | return nn.functional.mse_loss( 77 | y_pred, y_true, reduction="mean" 78 | ) 79 | 80 | def training_step(self, batch, batchid): 81 | # lamda param for generator loss 82 | l = self.hparams.l 83 | 84 | imgs, masks_in, tar, masks_true = batch 85 | optimizer_g, optimizer_d = self.optimizers() 86 | scheduler_g, scheduler_d = self.lr_schedulers() 87 | 88 | 89 | # how well can it label as real? 90 | valid = torch.ones((imgs.size(0),1,4,4)) 91 | valid = valid.type_as(imgs) 92 | 93 | # how well can it label as fake? 
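# Note on the discriminator update below: the fake loss uses generated_imgs.detach() so
# that no gradients from the discriminator step flow back into the generator; the
# generator then runs its own forward pass (without detach) to compute its adversarial
# loss, and the discriminator is only stepped every hparams.disc_every_n_steps batches.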
94 | fake = torch.zeros((imgs.size(0),1,4,4)) 95 | fake = fake.type_as(imgs) 96 | 97 | # train discriminator 98 | # Measure discriminator's ability to classify real from generated samples 99 | generated_imgs = self.generator(imgs, masks_in) 100 | self.toggle_optimizer(optimizer_d) 101 | 102 | real_loss = self.adversarial_loss(self.discriminator(imgs,tar), valid) 103 | fake_loss = self.adversarial_loss(self.discriminator(imgs, generated_imgs.detach()), fake) 104 | 105 | # discriminator loss is the sum of these 106 | d_loss = (real_loss + fake_loss) 107 | self.log("d_loss", d_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 108 | 109 | # only do backpropagation every n steps 110 | if batchid % self.hparams.disc_every_n_steps == 0: 111 | optimizer_d.zero_grad() 112 | self.manual_backward(d_loss) 113 | optimizer_d.step() 114 | 115 | self.untoggle_optimizer(optimizer_d) 116 | 117 | # Train generator 118 | self.toggle_optimizer(optimizer_g) 119 | optimizer_g.zero_grad() 120 | 121 | generated_imgs = self(imgs, masks_in) 122 | 123 | # Generator loss is the combination of adverarial loss and MSE loss 124 | g_loss = self.adversarial_loss(self.discriminator(imgs,generated_imgs), valid) 125 | structural_g_loss = self.loss_func(generated_imgs, tar) 126 | total_g_loss = g_loss + l * structural_g_loss 127 | self.log("g_total", total_g_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 128 | self.log("g_loss", g_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 129 | 130 | 131 | self.manual_backward(total_g_loss) 132 | optimizer_g.step() 133 | self.untoggle_optimizer(optimizer_g) 134 | 135 | if batchid % 100 == 0: 136 | display_list = [imgs[0], tar[0], generated_imgs[0]] 137 | title = ['Input Image', 'Ground Truth', 'Predicted Image'] 138 | # Plot generated images 139 | f, axarr = plt.subplots(3,12) 140 | f.set_figwidth(20) 141 | for i in range(3): 142 | for j in range(12): 143 | axarr[i, j].imshow(display_list[i][j,:,:].detach().cpu().numpy()) 144 | plt.axis("off") 145 | plt.savefig(self.hparams.default_save_path / "imgs.png") 146 | 147 | # Plot losses 148 | xs = [x * 10 for x in range(len(self.g_losses))] 149 | xs = xs[10:] 150 | fig, ax1 = plt.subplots() 151 | color = 'tab:red' 152 | ax1.set_xlabel('steps') 153 | ax1.set_ylabel('g_loss', color=color) 154 | g_losses = self.g_losses[10:] 155 | ax1.plot(xs, g_losses, color=color) 156 | ax1.tick_params(axis='y', labelcolor=color) 157 | 158 | ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis 159 | 160 | color = 'tab:blue' 161 | ax2.set_ylabel('d_loss', color=color) # we already handled the x-label with ax1 162 | d_losses = self.d_losses[10:] 163 | ax2.plot(xs, d_losses, color=color) 164 | ax2.tick_params(axis='y', labelcolor=color) 165 | fig.tight_layout() # otherwise the right y-label is slightly clipped 166 | plt.savefig(self.hparams.default_save_path / "losses.png") 167 | 168 | 169 | # Save train losses for plotting 170 | if batchid > 0 and batchid % 10 == 0: 171 | self.g_losses.append(total_g_loss.item()) 172 | self.d_losses.append(d_loss.item()) 173 | 174 | #Update lr schedulers on end of epoch 175 | if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % 1 == 0 and self.trainer.current_epoch > 0: 176 | scheduler_d.step(self.trainer.callback_metrics["val_d_loss"]) 177 | scheduler_g.step(self.trainer.callback_metrics["val_g_loss"]) 178 | 179 | return total_g_loss 180 | 181 | def validation_step(self, batch, batch_idx): 182 | # lamda param for generator loss 183 | l = 
self.hparams.l 184 | 185 | imgs, masks_in, tar, masks_true = batch 186 | 187 | 188 | # how well can it label as real? 189 | valid = torch.ones((imgs.size(0),1,4,4)) 190 | valid = valid.type_as(imgs) 191 | 192 | # how well can it label as fake? 193 | fake = torch.zeros((imgs.size(0),1,6,6)) 194 | fake = torch.zeros((imgs.size(0),1,4,4)) 195 | fake = fake.type_as(imgs) 196 | 197 | generated_imgs = self.generator(imgs, masks_in) 198 | 199 | # train discriminator 200 | # Measure discriminator's ability to classify real from generated samples 201 | real_loss = self.adversarial_loss(self.discriminator(imgs,tar), valid) 202 | fake_loss = self.adversarial_loss(self.discriminator(imgs, generated_imgs.detach()), fake) 203 | 204 | # discriminator loss is the sum of these 205 | d_loss = (real_loss + fake_loss) 206 | self.log("val_d_loss", d_loss, prog_bar=True) 207 | 208 | # Train generator 209 | # Generator loss is the combination of adverarial loss and MSE loss 210 | generated_imgs = self(imgs, masks_in) 211 | g_loss = self.adversarial_loss(self.discriminator(imgs,generated_imgs), valid) 212 | structural_g_loss = self.loss_func(generated_imgs, tar) 213 | 214 | total_g_loss = g_loss + l * structural_g_loss 215 | self.log("val_loss",structural_g_loss*100,prog_bar=True) 216 | self.log("val_g_loss",total_g_loss,prog_bar=True) 217 | 218 | 219 | class GAN(GAN_base): 220 | @staticmethod 221 | def add_model_specific_args(parent_parser): 222 | parent_parser = GAN_base.add_model_specific_args(parent_parser) 223 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 224 | parser.add_argument("--num_input_images", type=int, default=12) 225 | parser.add_argument("--num_output_images", type=int, default=12) 226 | parser.add_argument("--valid_size", type=float, default=0.1) 227 | parser.n_channels = parser.parse_args().num_input_images 228 | parser.n_classes = 12 229 | return parser 230 | 231 | def __init__(self, hparams): 232 | super(GAN, self).__init__(hparams=hparams) 233 | self.train_dataset = None 234 | self.valid_dataset = None 235 | self.train_sampler = None 236 | self.valid_sampler = None 237 | 238 | def prepare_data(self): 239 | # train_transform = transforms.Compose([ 240 | # transforms.RandomHorizontalFlip()] 241 | # ) 242 | train_transform = None 243 | valid_transform = None 244 | precip_dataset = dataset_precip.precipitation_maps_masked_h5 245 | self.train_dataset = precip_dataset( 246 | in_file=self.hparams.dataset_folder, 247 | num_input_images=self.hparams.num_input_images, 248 | num_output_images=self.hparams.num_output_images, 249 | mode=self.hparams.dataset, 250 | transform=train_transform, 251 | ) 252 | self.valid_dataset = precip_dataset( 253 | in_file=self.hparams.dataset_folder, 254 | num_input_images=self.hparams.num_input_images, 255 | num_output_images=self.hparams.num_output_images, 256 | mode=self.hparams.dataset, 257 | transform=valid_transform, 258 | ) 259 | 260 | num_train = len(self.train_dataset) 261 | indices = list(range(num_train)) 262 | split = int(np.floor(self.hparams.valid_size * num_train)) 263 | 264 | np.random.seed(123) 265 | np.random.shuffle(indices) 266 | 267 | train_idx, valid_idx = indices[split:], indices[:split] 268 | self.train_sampler = SubsetRandomSampler(train_idx) 269 | self.valid_sampler = SubsetRandomSampler(valid_idx) 270 | 271 | def train_dataloader(self): 272 | 273 | train_loader = DataLoader( 274 | self.train_dataset, 275 | batch_size=self.hparams.batch_size, 276 | sampler=self.train_sampler, 277 | num_workers=0, 278 | 
pin_memory=True, 279 | ) 280 | return train_loader 281 | 282 | def val_dataloader(self): 283 | 284 | valid_loader = DataLoader( 285 | self.valid_dataset, 286 | batch_size=self.hparams.batch_size, 287 | sampler=self.valid_sampler, 288 | num_workers=0, 289 | pin_memory=True, 290 | ) 291 | return valid_loader 292 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 ; python_full_version > "3.9.7" and python_version < "4.0" 2 | aiobotocore==2.4.2 ; python_full_version > "3.9.7" and python_version < "4.0" 3 | aiohttp==3.8.4 ; python_full_version > "3.9.7" and python_version < "4.0" 4 | aioitertools==0.11.0 ; python_full_version > "3.9.7" and python_version < "4.0" 5 | aiosignal==1.3.1 ; python_full_version > "3.9.7" and python_version < "4.0" 6 | altair==4.2.2 ; python_full_version > "3.9.7" and python_version < "4.0" 7 | ansicon==1.89.0 ; python_full_version > "3.9.7" and python_version < "4.0" and platform_system == "Windows" 8 | antlr4-python3-runtime==4.9.3 ; python_full_version > "3.9.7" and python_version < "4.0" 9 | anyio==3.6.2 ; python_full_version > "3.9.7" and python_version < "4.0" 10 | arrow==1.2.3 ; python_full_version > "3.9.7" and python_version < "4.0" 11 | async-timeout==4.0.2 ; python_full_version > "3.9.7" and python_version < "4.0" 12 | attrs==23.1.0 ; python_full_version > "3.9.7" and python_version < "4.0" 13 | beautifulsoup4==4.12.2 ; python_full_version > "3.9.7" and python_version < "4.0" 14 | bleach==6.0.0 ; python_full_version > "3.9.7" and python_version < "4.0" 15 | blessed==1.20.0 ; python_full_version > "3.9.7" and python_version < "4.0" 16 | blinker==1.6.2 ; python_full_version > "3.9.7" and python_version < "4.0" 17 | bokeh==2.4.3 ; python_full_version > "3.9.7" and python_version < "4.0" 18 | botocore==1.27.59 ; python_full_version > "3.9.7" and python_version < "4.0" 19 | cachetools==5.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 20 | certifi==2023.5.7 ; python_full_version > "3.9.7" and python_version < "4.0" 21 | charset-normalizer==3.1.0 ; python_full_version > "3.9.7" and python_version < "4.0" 22 | click==8.1.3 ; python_full_version > "3.9.7" and python_version < "4.0" 23 | colorama==0.4.6 ; python_full_version > "3.9.7" and python_version < "4.0" and platform_system == "Windows" 24 | contourpy==1.0.7 ; python_full_version > "3.9.7" and python_version < "4.0" 25 | croniter==1.3.14 ; python_full_version > "3.9.7" and python_version < "4.0" 26 | cycler==0.11.0 ; python_full_version > "3.9.7" and python_version < "4.0" 27 | dateutils==0.6.12 ; python_full_version > "3.9.7" and python_version < "4.0" 28 | decorator==5.1.1 ; python_full_version > "3.9.7" and python_version < "4.0" 29 | deepdiff==6.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 30 | docker==6.1.1 ; python_full_version > "3.9.7" and python_version < "4.0" 31 | docstring-parser==0.15 ; python_full_version > "3.9.7" and python_version < "4.0" 32 | entrypoints==0.4 ; python_full_version > "3.9.7" and python_version < "4.0" 33 | fastapi==0.88.0 ; python_full_version > "3.9.7" and python_version < "4.0" 34 | filelock==3.12.0 ; python_full_version > "3.9.7" and python_version < "4.0" 35 | fonttools==4.39.3 ; python_full_version > "3.9.7" and python_version < "4.0" 36 | frozenlist==1.3.3 ; python_full_version > "3.9.7" and python_version < "4.0" 37 | fsspec==2022.11.0 ; python_full_version > "3.9.7" and python_version < 
"4.0" 38 | fsspec[http]==2022.11.0 ; python_full_version > "3.9.7" and python_version < "4.0" 39 | gitdb==4.0.10 ; python_full_version > "3.9.7" and python_version < "4.0" 40 | gitpython==3.1.31 ; python_full_version > "3.9.7" and python_version < "4.0" 41 | google-auth-oauthlib==1.0.0 ; python_full_version > "3.9.7" and python_version < "4.0" 42 | google-auth==2.17.3 ; python_full_version > "3.9.7" and python_version < "4.0" 43 | grad-cam==1.5.0 ; python_full_version > "3.9.7" and python_version < "4.0" 44 | grpcio==1.54.0 ; python_full_version > "3.9.7" and python_version < "4.0" 45 | h11==0.14.0 ; python_full_version > "3.9.7" and python_version < "4.0" 46 | h5py==3.8.0 ; python_full_version > "3.9.7" and python_version < "4.0" 47 | hydra-core==1.3.2 ; python_full_version > "3.9.7" and python_version < "4.0" 48 | idna==3.4 ; python_full_version > "3.9.7" and python_version < "4.0" 49 | importlib-metadata==6.6.0 ; python_full_version > "3.9.7" and python_version < "4.0" 50 | importlib-resources==5.12.0 ; python_full_version > "3.9.7" and python_version < "4.0" 51 | inquirer==3.1.3 ; python_full_version > "3.9.7" and python_version < "4.0" 52 | itsdangerous==2.1.2 ; python_full_version > "3.9.7" and python_version < "4.0" 53 | jinja2==3.1.2 ; python_full_version > "3.9.7" and python_version < "4.0" 54 | jinxed==1.2.0 ; python_full_version > "3.9.7" and python_version < "4.0" and platform_system == "Windows" 55 | jmespath==1.0.1 ; python_full_version > "3.9.7" and python_version < "4.0" 56 | joblib==1.3.2 ; python_full_version > "3.9.7" and python_version < "4.0" 57 | jsonargparse[signatures]==4.21.0 ; python_full_version > "3.9.7" and python_version < "4.0" 58 | jsonschema==4.17.3 ; python_full_version > "3.9.7" and python_version < "4.0" 59 | kiwisolver==1.4.4 ; python_full_version > "3.9.7" and python_version < "4.0" 60 | lightning-api-access==0.0.5 ; python_full_version > "3.9.7" and python_version < "4.0" 61 | lightning-cloud==0.5.34 ; python_full_version > "3.9.7" and python_version < "4.0" 62 | lightning-fabric==2.0.2 ; python_full_version > "3.9.7" and python_version < "4.0" 63 | lightning-utilities==0.8.0 ; python_full_version > "3.9.7" and python_version < "4.0" 64 | lightning[extra]==2.0.2 ; python_full_version > "3.9.7" and python_version < "4.0" 65 | markdown-it-py==2.2.0 ; python_full_version > "3.9.7" and python_version < "4.0" 66 | markdown==3.4.3 ; python_full_version > "3.9.7" and python_version < "4.0" 67 | markupsafe==2.1.2 ; python_full_version > "3.9.7" and python_version < "4.0" 68 | matplotlib==3.7.1 ; python_full_version > "3.9.7" and python_version < "4.0" 69 | mdurl==0.1.2 ; python_full_version > "3.9.7" and python_version < "4.0" 70 | mpmath==1.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 71 | multidict==6.0.4 ; python_full_version > "3.9.7" and python_version < "4.0" 72 | networkx==3.1 ; python_full_version > "3.9.7" and python_version < "4.0" 73 | numpy==1.24.3 ; python_full_version > "3.9.7" and python_version < "4.0" 74 | oauthlib==3.2.2 ; python_full_version > "3.9.7" and python_version < "4.0" 75 | omegaconf==2.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 76 | opencv-python==4.9.0.80 ; python_full_version > "3.9.7" and python_version < "4.0" 77 | ordered-set==4.1.0 ; python_full_version > "3.9.7" and python_version < "4.0" 78 | packaging==23.1 ; python_full_version > "3.9.7" and python_version < "4.0" 79 | pandas==2.0.1 ; python_full_version > "3.9.7" and python_version < "4.0" 80 | panel==0.14.4 ; python_full_version > 
"3.9.7" and python_version < "4.0" 81 | param==1.13.0 ; python_full_version > "3.9.7" and python_version < "4.0" 82 | pillow==9.5.0 ; python_full_version > "3.9.7" and python_version < "4.0" 83 | protobuf==3.20.3 ; python_full_version > "3.9.7" and python_version < "4.0" 84 | psutil==5.9.5 ; python_full_version > "3.9.7" and python_version < "4.0" 85 | pyarrow==12.0.0 ; python_full_version > "3.9.7" and python_version < "4.0" 86 | pyasn1-modules==0.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 87 | pyasn1==0.5.0 ; python_full_version > "3.9.7" and python_version < "4.0" 88 | pyct==0.5.0 ; python_full_version > "3.9.7" and python_version < "4.0" 89 | pydantic==1.10.7 ; python_full_version > "3.9.7" and python_version < "4.0" 90 | pydeck==0.8.0 ; python_full_version > "3.9.7" and python_version < "4.0" 91 | pygments==2.15.1 ; python_full_version > "3.9.7" and python_version < "4.0" 92 | pyjwt==2.6.0 ; python_full_version > "3.9.7" and python_version < "4.0" 93 | pympler==1.0.1 ; python_full_version > "3.9.7" and python_version < "4.0" 94 | pyparsing==3.0.9 ; python_full_version > "3.9.7" and python_version < "4.0" 95 | pyrsistent==0.19.3 ; python_full_version > "3.9.7" and python_version < "4.0" 96 | python-dateutil==2.8.2 ; python_full_version > "3.9.7" and python_version < "4.0" 97 | python-editor==1.0.4 ; python_full_version > "3.9.7" and python_version < "4.0" 98 | python-multipart==0.0.6 ; python_full_version > "3.9.7" and python_version < "4.0" 99 | pytorch-lightning==2.0.2 ; python_full_version > "3.9.7" and python_version < "4.0" 100 | pytorch-msssim==1.0.0 ; python_full_version > "3.9.7" and python_version < "4.0" 101 | pytz-deprecation-shim==0.1.0.post0 ; python_full_version > "3.9.7" and python_version < "4.0" 102 | pytz==2023.3 ; python_full_version > "3.9.7" and python_version < "4.0" 103 | pyviz-comms==2.2.1 ; python_full_version > "3.9.7" and python_version < "4.0" 104 | pywin32==306 ; python_full_version > "3.9.7" and python_version < "4.0" and sys_platform == "win32" 105 | pyyaml==6.0 ; python_full_version > "3.9.7" and python_version < "4.0" 106 | readchar==4.0.5 ; python_full_version > "3.9.7" and python_version < "4.0" 107 | redis==4.5.5 ; python_full_version > "3.9.7" and python_version < "4.0" 108 | requests-oauthlib==1.3.1 ; python_full_version > "3.9.7" and python_version < "4.0" 109 | requests==2.30.0 ; python_full_version > "3.9.7" and python_version < "4.0" 110 | rich==13.3.5 ; python_full_version > "3.9.7" and python_version < "4.0" 111 | rsa==4.9 ; python_full_version > "3.9.7" and python_version < "4" 112 | s3fs==2022.11.0 ; python_full_version > "3.9.7" and python_version < "4.0" 113 | scikit-learn==1.3.2 ; python_full_version > "3.9.7" and python_version < "4.0" 114 | scipy==1.11.4 ; python_full_version > "3.9.7" and python_version < "4.0" 115 | setuptools==67.7.2 ; python_full_version > "3.9.7" and python_version < "4.0" 116 | six==1.16.0 ; python_full_version > "3.9.7" and python_version < "4.0" 117 | smmap==5.0.0 ; python_full_version > "3.9.7" and python_version < "4.0" 118 | sniffio==1.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 119 | soupsieve==2.4.1 ; python_full_version > "3.9.7" and python_version < "4.0" 120 | starlette==0.22.0 ; python_full_version > "3.9.7" and python_version < "4.0" 121 | starsessions==1.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 122 | streamlit==1.22.0 ; python_full_version > "3.9.7" and python_version < "4.0" 123 | sympy==1.11.1 ; python_full_version > "3.9.7" and 
python_version < "4.0" 124 | tenacity==8.2.2 ; python_full_version > "3.9.7" and python_version < "4.0" 125 | tensorboard-data-server==0.7.0 ; python_full_version > "3.9.7" and python_version < "4.0" 126 | tensorboard==2.13.0 ; python_full_version > "3.9.7" and python_version < "4.0" 127 | tensorboardx==2.6 ; python_full_version > "3.9.7" and python_version < "4.0" 128 | threadpoolctl==3.2.0 ; python_full_version > "3.9.7" and python_version < "4.0" 129 | toml==0.10.2 ; python_full_version > "3.9.7" and python_version < "4.0" 130 | toolz==0.12.0 ; python_full_version > "3.9.7" and python_version < "4.0" 131 | torch==2.0.1 ; python_full_version > "3.9.7" and python_version < "4.0" 132 | torchmetrics==0.11.4 ; python_full_version > "3.9.7" and python_version < "4.0" 133 | torchsummary==1.5.1 ; python_full_version > "3.9.7" and python_version < "4.0" 134 | torchvision==0.15.2 ; python_full_version > "3.9.7" and python_version < "4.0" 135 | tornado==6.3.1 ; python_full_version > "3.9.7" and python_version < "4.0" 136 | tqdm==4.65.0 ; python_full_version > "3.9.7" and python_version < "4.0" 137 | traitlets==5.9.0 ; python_full_version > "3.9.7" and python_version < "4.0" 138 | ttach==0.0.3 ; python_full_version > "3.9.7" and python_version < "4.0" 139 | typeshed-client==2.3.0 ; python_full_version > "3.9.7" and python_version < "4.0" 140 | typing-extensions==4.5.0 ; python_full_version > "3.9.7" and python_version < "4.0" 141 | tzdata==2023.3 ; python_full_version > "3.9.7" and python_version < "4.0" 142 | tzlocal==4.3 ; python_full_version > "3.9.7" and python_version < "4.0" 143 | urllib3==1.26.15 ; python_full_version > "3.9.7" and python_version < "4.0" 144 | uvicorn==0.22.0 ; python_full_version > "3.9.7" and python_version < "4.0" 145 | validators==0.20.0 ; python_full_version > "3.9.7" and python_version < "4.0" 146 | watchdog==3.0.0 ; python_full_version > "3.9.7" and python_version < "4.0" and platform_system != "Darwin" 147 | wcwidth==0.2.6 ; python_full_version > "3.9.7" and python_version < "4.0" 148 | webencodings==0.5.1 ; python_full_version > "3.9.7" and python_version < "4.0" 149 | websocket-client==1.5.1 ; python_full_version > "3.9.7" and python_version < "4.0" 150 | websockets==11.0.3 ; python_full_version > "3.9.7" and python_version < "4.0" 151 | werkzeug==2.3.4 ; python_full_version > "3.9.7" and python_version < "4.0" 152 | wheel==0.40.0 ; python_full_version > "3.9.7" and python_version < "4.0" 153 | wrapt==1.15.0 ; python_full_version > "3.9.7" and python_version < "4.0" 154 | yarl==1.9.2 ; python_full_version > "3.9.7" and python_version < "4.0" 155 | zipp==3.15.0 ; python_full_version > "3.9.7" and python_version < "4.0" 156 | --------------------------------------------------------------------------------
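For reference, the pieces above compose as follows at inference time; a minimal sketch mirroring test_precip.py, where the checkpoint and dataset paths are placeholders to be substituted with your own files:

import torch
from models import regression_GA_SmaAt_GNet as gan
from utils import dataset_precip

# placeholder paths: substitute a real checkpoint and h5 dataset
model = gan.GAN.load_from_checkpoint("checkpoints/top_models/your_checkpoint.ckpt")
model.eval()

dataset = dataset_precip.precipitation_maps_masked_h5(
    in_file="data/precipitation/your_dataset.h5",
    num_input_images=12,
    num_output_images=12,
    mode="test",
)
test_dl = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

x, mask, y_true, _ = next(iter(test_dl))
with torch.no_grad():
    y_pred = model(x, mask)
print(y_pred.shape)  # one predicted frame per output image, e.g. (1, 12, H, W)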