├── .DS_Store ├── .gitignore ├── Benchmark └── IrradianceNet.py ├── Dataset ├── .DS_Store └── dataset.py ├── Denoiser.drawio ├── README.md ├── SHADECast ├── .DS_Store ├── Blocks │ ├── .DS_Store │ ├── AFNO.py │ ├── ResBlock3D.py │ ├── TimeStep.py │ └── attention.py ├── Models │ ├── .DS_Store │ ├── Diffusion │ │ ├── .DS_Store │ │ ├── DiffusionModel.py │ │ ├── ema.py │ │ └── utils.py │ ├── Nowcaster │ │ └── Nowcast.py │ ├── Sampler │ │ ├── .DS_Store │ │ ├── PLMS.py │ │ └── utils.py │ ├── UNet │ │ ├── .DS_Store │ │ ├── UNet.py │ │ └── utils.py │ └── VAE │ │ └── VariationalAutoEncoder.py └── Training │ ├── Nowcast_training │ ├── IrradianceNetTraining_pl.py │ ├── IrradianceNettrainingconf.yml │ ├── NowcasterTraining_pl.py │ └── Nowcastertrainingconf.yml │ ├── SHADECastTraining.py │ ├── SHADECastTrainingconf.yml │ └── VAE_training │ ├── VAETraining_pl.py │ └── VAEtrainingconf.yml ├── Test ├── Test_IrrNet.py └── Test_SHADECast.py ├── compute_metrics.py ├── requirements.txt ├── utils.py └── validation_utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /Benchmark/IrradianceNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | import torch 4 | import logging 5 | import pytorch_lightning as pl 6 | 7 | 8 | def make_layers(block): 9 | layers = [] 10 | for layer_name, v in block.items(): 11 | if 'pool' in layer_name: 12 | layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) 13 | layers.append((layer_name, layer)) 14 | elif 'deconv' in layer_name: 15 | transposeConv2d = nn.ConvTranspose2d(in_channels=v[0], 16 | out_channels=v[1], 17 | kernel_size=v[2], 18 | stride=v[3], 19 | padding=v[4]) 20 | layers.append((layer_name, transposeConv2d)) 21 | if 'relu' in layer_name: 22 | layers.append(('relu_' + layer_name, nn.ReLU(inplace=True))) 23 | elif 'leaky' in layer_name: 24 | layers.append(('leaky_' + layer_name, 25 | nn.LeakyReLU(negative_slope=0.2, inplace=True))) 26 | elif 'conv' in layer_name: 27 | conv2d = nn.Conv2d(in_channels=v[0], 28 | out_channels=v[1], 29 | kernel_size=v[2], 30 | stride=v[3], 31 | padding=v[4]) 32 | layers.append((layer_name, conv2d)) 33 | if 'relu' in layer_name: 34 | layers.append(('relu_' + layer_name, nn.ReLU(inplace=True))) 35 | elif 'leaky' in layer_name: 36 | layers.append(('leaky_' + layer_name, 37 | nn.LeakyReLU(negative_slope=0.2, inplace=True))) 38 | else: 39 | raise NotImplementedError 40 | return nn.Sequential(OrderedDict(layers)) 41 | 42 | 43 | class CLSTM_cell(nn.Module): 44 | """ConvLSTMCell 45 | """ 46 | 47 | def __init__(self, shape, input_channels, filter_size, num_features, seq_len=8, device='cpu'): 48 | super(CLSTM_cell, self).__init__() 49 | 50 | self.shape = shape # H, 
W 51 | self.input_channels = input_channels 52 | self.filter_size = filter_size 53 | self.device = device 54 | self.num_features = num_features 55 | # in this way the output has the same size 56 | self.padding = (filter_size - 1) // 2 57 | self.conv = nn.Sequential( 58 | nn.Conv2d(self.input_channels + self.num_features, 59 | 4 * self.num_features, self.filter_size, 1, 60 | self.padding), 61 | nn.GroupNorm(4 * self.num_features // 32, 4 * self.num_features) # best for regression 62 | ) 63 | 64 | self.seq_len = seq_len 65 | 66 | def forward(self, inputs=None, hidden_state=None): 67 | if hidden_state is None: 68 | hx = torch.zeros(inputs.size(1), self.num_features, self.shape[0], 69 | self.shape[1]).to(self.device) 70 | cx = torch.zeros(inputs.size(1), self.num_features, self.shape[0], 71 | self.shape[1]).to(self.device) 72 | else: 73 | hx, cx = hidden_state 74 | output_inner = [] 75 | for index in range(self.seq_len): 76 | if inputs is None: 77 | x = torch.zeros(hx.size(0), self.input_channels, self.shape[0], 78 | self.shape[1]).to(self.device) 79 | else: 80 | x = inputs[index, ...] 81 | 82 | combined = torch.cat((x, hx), 1) 83 | gates = self.conv(combined) # gates: S, num_features*4, H, W 84 | 85 | # it should return 4 tensors: i,f,g,o 86 | ingate, forgetgate, cellgate, outgate = torch.split( 87 | gates, self.num_features, dim=1) 88 | ingate = torch.sigmoid(ingate) 89 | forgetgate = torch.sigmoid(forgetgate) 90 | cellgate = torch.tanh(cellgate) 91 | outgate = torch.sigmoid(outgate) 92 | 93 | cy = (forgetgate * cx) + (ingate * cellgate) 94 | hy = outgate * torch.tanh(cy) 95 | output_inner.append(hy) 96 | hx = hy 97 | cx = cy 98 | return torch.stack(output_inner), (hy, cy) 99 | 100 | 101 | def convlstm_encoder_params(in_chan=7, image_size=128, device='cpu'): 102 | size_l1 = image_size 103 | size_l2 = image_size - (image_size // 4) 104 | size_l3 = image_size - (image_size // 2) 105 | size_l4 = size_l1 - size_l2 106 | 107 | convlstm_encoder_params = [ 108 | [ 109 | OrderedDict({'conv1_leaky_1': [in_chan, size_l4, 3, 1, 1]}), # [1, 32, 3, 1, 1] 110 | OrderedDict({'conv2_leaky_1': [size_l3, size_l3, 3, 2, 1]}), 111 | OrderedDict({'conv3_leaky_1': [size_l2, size_l2, 3, 2, 1]}), 112 | ], 113 | [ 114 | CLSTM_cell(shape=(size_l1, size_l1), input_channels=size_l4, filter_size=5, num_features=size_l3, 115 | seq_len=4, device=device), 116 | CLSTM_cell(shape=(size_l3, size_l3), input_channels=size_l3, filter_size=5, num_features=size_l2, 117 | seq_len=4, device=device), 118 | CLSTM_cell(shape=(size_l4, size_l4), input_channels=size_l2, filter_size=5, num_features=size_l1, 119 | seq_len=4, device=device) 120 | ] 121 | ] 122 | return convlstm_encoder_params 123 | 124 | 125 | def convlstm_decoder_params(seq_len, image_size=128, device='cpu'): 126 | size_l1 = image_size 127 | size_l2 = image_size - (image_size // 4) 128 | size_l3 = image_size - (image_size // 2) 129 | size_l4 = size_l1 - size_l2 130 | 131 | convlstm_decoder_params = [ 132 | [ 133 | OrderedDict({'deconv1_leaky_1': [size_l1, size_l1, 4, 2, 1]}), 134 | OrderedDict({'deconv2_leaky_1': [size_l2, size_l2, 4, 2, 1]}), 135 | OrderedDict({ 136 | 'conv3_leaky_1': [size_l3, size_l4, 3, 1, 1], 137 | 'conv4_leaky_1': [size_l4, 1, 1, 1, 0] 138 | }), 139 | ], 140 | [ 141 | CLSTM_cell(shape=(size_l4, size_l4), input_channels=size_l1, filter_size=5, num_features=size_l1, 142 | seq_len=4, device=device), 143 | CLSTM_cell(shape=(size_l3, size_l3), input_channels=size_l1, filter_size=5, num_features=size_l2, 144 | seq_len=4, device=device), 145 | 
CLSTM_cell(shape=(size_l1, size_l1), input_channels=size_l2, filter_size=5, num_features=size_l3, 146 | seq_len=4, device=device) 147 | ] 148 | ] 149 | return convlstm_decoder_params 150 | 151 | 152 | class Encoder(nn.Module): 153 | def __init__(self, subnets, rnns): 154 | super().__init__() 155 | assert len(subnets) == len(rnns) 156 | self.blocks = len(subnets) 157 | 158 | for index, (params, rnn) in enumerate(zip(subnets, rnns), 1): 159 | # index sign from 1 160 | setattr(self, 'stage' + str(index), make_layers(params)) 161 | setattr(self, 'rnn' + str(index), rnn) 162 | 163 | def forward_by_stage(self, inputs, subnet, rnn): 164 | seq_number, batch_size, input_channel, height, width = inputs.size() 165 | inputs = torch.reshape(inputs, (-1, input_channel, height, width)) 166 | inputs = subnet(inputs) 167 | inputs = torch.reshape(inputs, (seq_number, batch_size, inputs.size(1), 168 | inputs.size(2), inputs.size(3))) 169 | outputs_stage, state_stage = rnn(inputs, None) 170 | return outputs_stage, state_stage 171 | 172 | def forward(self, inputs): 173 | inputs = inputs.transpose(0, 1) # to S,B,1,64,64 174 | hidden_states = [] 175 | logging.debug(inputs.size()) 176 | for i in range(1, self.blocks + 1): 177 | inputs, state_stage = self.forward_by_stage( 178 | inputs, getattr(self, 'stage' + str(i)), 179 | getattr(self, 'rnn' + str(i))) 180 | hidden_states.append(state_stage) 181 | return tuple(hidden_states) 182 | 183 | 184 | class Decoder(nn.Module): 185 | def __init__(self, subnets, rnns, seq_len): 186 | super().__init__() 187 | assert len(subnets) == len(rnns) 188 | 189 | self.blocks = len(subnets) 190 | self.seq_len = seq_len 191 | 192 | for index, (params, rnn) in enumerate(zip(subnets, rnns)): 193 | setattr(self, 'rnn' + str(self.blocks - index), rnn) 194 | setattr(self, 'stage' + str(self.blocks - index), 195 | make_layers(params)) 196 | 197 | def forward_by_stage(self, inputs, state, subnet, rnn): 198 | inputs, state_stage = rnn(inputs, state) # , seq_len=8 199 | seq_number, batch_size, input_channel, height, width = inputs.size() 200 | inputs = torch.reshape(inputs, (-1, input_channel, height, width)) 201 | inputs = subnet(inputs) 202 | inputs = torch.reshape(inputs, (seq_number, batch_size, inputs.size(1), 203 | inputs.size(2), inputs.size(3))) 204 | return inputs 205 | 206 | # input: 5D S*B*C*H*W 207 | 208 | def forward(self, hidden_states): 209 | inputs = self.forward_by_stage(None, hidden_states[-1], 210 | getattr(self, 'stage3'), 211 | getattr(self, 'rnn3')) 212 | for i in list(range(1, self.blocks))[::-1]: 213 | inputs = self.forward_by_stage(inputs, hidden_states[i - 1], 214 | getattr(self, 'stage' + str(i)), 215 | getattr(self, 'rnn' + str(i))) 216 | inputs = inputs.transpose(0, 1) # to B,S,1,64,64 217 | return inputs 218 | 219 | 220 | class ConvLSTM_patch(nn.Module): 221 | 222 | def __init__(self, seq_len, in_chan=7, image_size=128, device='cpu'): 223 | super(ConvLSTM_patch, self).__init__() 224 | encoder_params = convlstm_encoder_params(in_chan, image_size, device=device) 225 | decoder_params = convlstm_decoder_params(seq_len, image_size, device=device) 226 | 227 | self.encoder = Encoder(encoder_params[0], encoder_params[1]) 228 | self.decoder = Decoder(decoder_params[0], decoder_params[1], seq_len=seq_len) 229 | 230 | def forward(self, x, future_seq=10): 231 | x = x.permute(0, 1, 4, 2, 3) 232 | state = self.encoder(x) 233 | output = self.decoder(state) 234 | 235 | return output 236 | 237 | 238 | class IrradianceNet(pl.LightningModule): 239 | def __init__(self, model, 
opt_patience): 240 | super().__init__() 241 | self.model = model 242 | self.opt_patience = opt_patience 243 | 244 | def forward(self, x): 245 | x = x.permute(0, 2, 3, 4, 1) 246 | y_pred1 = self.model(x).permute(0, 1, 3, 4, 2) 247 | y_pred2 = self.model(y_pred1).permute(0, 1, 3, 4, 2) 248 | y_pred = torch.concat((y_pred1, y_pred2), axis=1).permute(0, 4, 1, 2, 3) 249 | return y_pred 250 | 251 | def _loss(self, batch): 252 | x, y = batch 253 | y_pred = self.forward(x) 254 | return (y - y_pred).square().mean() 255 | 256 | def training_step(self, batch, batch_idx): 257 | loss = self._loss(batch) 258 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 259 | self.log('train_loss', loss, **log_params) 260 | return loss 261 | 262 | @torch.no_grad() 263 | def val_test_step(self, batch, batch_idx, split="val"): 264 | loss = self._loss(batch) 265 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 266 | self.log(f"{split}_loss", loss, **log_params) 267 | 268 | def validation_step(self, batch, batch_idx): 269 | self.val_test_step(batch, batch_idx, split="val") 270 | 271 | def test_step(self, batch, batch_idx): 272 | self.val_test_step(batch, batch_idx, split="test") 273 | 274 | def configure_optimizers(self): 275 | optimizer = torch.optim.AdamW( 276 | self.parameters(), lr=0.002, 277 | betas=(0.5, 0.9), weight_decay=1e-3 278 | ) 279 | reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( 280 | optimizer, patience=self.opt_patience, factor=0.5, verbose=True 281 | ) 282 | 283 | optimizer_spec = { 284 | "optimizer": optimizer, 285 | "lr_scheduler": { 286 | "scheduler": reduce_lr, 287 | "monitor": "val_loss", 288 | "frequency": 1, 289 | }, 290 | } 291 | return optimizer_spec 292 | -------------------------------------------------------------------------------- /Dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/Dataset/.DS_Store -------------------------------------------------------------------------------- /Dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from utils import open_pkl, save_pkl 2 | from torch.utils.data import Dataset 3 | import pytorch_lightning as pl 4 | import torch 5 | import os 6 | import numpy as np 7 | from torch.nn import AvgPool3d 8 | 9 | 10 | class KIDataset(Dataset): 11 | def __init__(self, 12 | data_path, 13 | coordinate_data_path, 14 | n, 15 | length=None, 16 | return_all=False, 17 | forecast=False, 18 | validation=False, 19 | return_t=False, 20 | **kwargs): 21 | super().__init__() 22 | self.data_path = data_path 23 | self.coordinate_data_path = coordinate_data_path 24 | self.return_all = return_all 25 | self.forecast = forecast 26 | self.validation = validation 27 | self.return_t = return_t 28 | f = os.listdir(self.data_path) 29 | self.filenames = [] 30 | self.n = n 31 | if length is None: 32 | self.filenames += f 33 | else: 34 | while length > len(self.filenames): 35 | self.filenames += f 36 | self.filenames = self.filenames[:length] 37 | self.nitems = len(self.filenames) 38 | if self.validation: 39 | np.random.seed(0) 40 | self.seeds = np.random.randint(0, 1000000, self.nitems) 41 | if self.return_t: 42 | self.t_lst = np.random.randint(0, 1000, self.nitems) 43 | self.norm_method = kwargs['norm_method'] if 'norm_method' in kwargs else 'rescaling' 44 | if self.norm_method == 
'normalization': 45 | self.a, self.b = kwargs['mean'], kwargs['std'] 46 | elif self.norm_method == 'rescaling': 47 | self.a, self.b = kwargs['min'], kwargs['max'] 48 | 49 | def to_tensor(self, x): 50 | return torch.FloatTensor(x) 51 | 52 | def __getitem__(self, idx): 53 | item_idx = self.filenames[idx] 54 | if self.validation: 55 | seed = self.seeds[idx] 56 | np.random.seed(seed) 57 | if self.return_t: 58 | t = int(self.t_lst[idx]) 59 | coord_idx = int(item_idx.split('_')[1].split('.')[0]) 60 | item_dict = open_pkl(self.data_path + item_idx) 61 | 62 | starting_idx = np.random.choice(item_dict['starting_idx'], 1, replace=False)[0] 63 | if self.validation: 64 | print(idx, starting_idx) 65 | seq = np.array(item_dict['ki_maps'])[starting_idx:starting_idx + self.n] 66 | seq = seq.reshape(1, *seq.shape) 67 | 68 | if self.norm_method == 'normalization': 69 | seq = (seq - self.a) / self.b 70 | elif self.norm_method == 'rescaling': 71 | seq = 2 * ((seq - self.a) / (self.b - self.a)) - 1 72 | 73 | if self.return_all: 74 | lon = np.array(open_pkl(self.coordinate_data_path + '{}_lon.pkl'.format(coord_idx))) 75 | lat = np.array(open_pkl(self.coordinate_data_path + '{}_lat.pkl'.format(coord_idx))) 76 | alt = np.array(open_pkl(self.coordinate_data_path + '{}_alt.pkl'.format(coord_idx))) 77 | lon = 2 * ((lon - 0) / (90 - 0)) - 1 78 | lat = 2 * ((lat - 0) / (90 - 0)) - 1 79 | alt = 2 * ((alt - (-13)) / (4294 - 0)) - 1 80 | lon = lon.reshape(1, 1, *lon.shape) 81 | lat = lat.reshape(1, 1, *lat.shape) 82 | alt = alt.reshape(1, 1, *alt.shape) 83 | c = np.concatenate((alt, lon, lat), axis=0) 84 | if self.forecast: 85 | return self.to_tensor(seq[:, :4]), self.to_tensor(seq[:, 4:]), self.to_tensor(c) 86 | 87 | else: 88 | return self.to_tensor(seq), self.to_tensor(c) 89 | else: 90 | if self.forecast: 91 | if self.return_t: 92 | return self.to_tensor(seq[:, :4]), self.to_tensor(seq[:, 4:]), t 93 | else: 94 | return self.to_tensor(seq[:, :4]), self.to_tensor(seq[:, 4:]) 95 | else: 96 | return self.to_tensor(seq) 97 | 98 | def __len__(self): 99 | return self.nitems -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GenerativeNowcasting 2 | Code repository for "Extending intraday solar forecast horizons with deep generative models". 3 | Check out our ArXiv paper: https://arxiv.org/abs/2312.11966 4 | 5 | SHADECast is a solar irradiance nowcasting model based on a latent diffusion model (LDM) and precipitation nowcasting model (LDCast). 6 | Tutorial on training and running SHADECast coming soon! 
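Until the tutorial is available, here is a minimal sketch of how the clear-sky-index dataset in `Dataset/dataset.py` can be wrapped in a PyTorch `DataLoader` for training. The paths, sequence length and rescaling bounds below are placeholders, not values shipped with this repository.

```python
from torch.utils.data import DataLoader
from Dataset.dataset import KIDataset

# Placeholder paths and normalization bounds -- adjust them to your own data.
train_ds = KIDataset(
    data_path="data/train/",              # directory of per-crop .pkl files with 'ki_maps' and 'starting_idx'
    coordinate_data_path="data/coords/",  # directory of <crop_idx>_lon.pkl / _lat.pkl / _alt.pkl files
    n=12,                                 # frames per sample: 4 context + 8 target when forecast=True
    forecast=True,                        # return (past, future) pairs instead of one full sequence
    norm_method="rescaling",              # rescale clear-sky index maps to [-1, 1]
    min=0.05, max=1.2,                    # assumed clear-sky index range used for the rescaling
)
x, y = train_ds[0]                        # x: (1, 4, H, W) context, y: (1, 8, H, W) target
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4)
```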
7 | -------------------------------------------------------------------------------- /SHADECast/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/SHADECast/.DS_Store -------------------------------------------------------------------------------- /SHADECast/Blocks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/SHADECast/Blocks/.DS_Store -------------------------------------------------------------------------------- /SHADECast/Blocks/AFNO.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/afno.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class AFNO3D(nn.Module): 11 | def __init__( 12 | self, hidden_size, num_blocks=8, sparsity_threshold=0.01, 13 | hard_thresholding_fraction=1, hidden_size_factor=1, res_mult=1 14 | ): 15 | super().__init__() 16 | assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" 17 | 18 | self.hidden_size = hidden_size 19 | self.sparsity_threshold = sparsity_threshold 20 | self.num_blocks = num_blocks 21 | self.block_size = self.hidden_size // self.num_blocks 22 | self.hard_thresholding_fraction = hard_thresholding_fraction 23 | self.hidden_size_factor = hidden_size_factor 24 | self.scale = 0.02 25 | self.res_mult = res_mult 26 | 27 | self.w1 = nn.Parameter( 28 | self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) 29 | self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) 30 | self.w2 = nn.Parameter( 31 | self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) 32 | self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) 33 | 34 | def forward(self, x): 35 | bias = x 36 | 37 | dtype = x.dtype 38 | x = x.float() 39 | B, D, H, W, C = x.shape 40 | 41 | x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho") 42 | x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size) 43 | 44 | o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], 45 | device=x.device) 46 | o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], 47 | device=x.device) 48 | o2_real = torch.zeros(x.shape, device=x.device) 49 | o2_imag = torch.zeros(x.shape, device=x.device) 50 | 51 | total_modes = H // 2 + 1 52 | kept_modes = int(total_modes * self.hard_thresholding_fraction) 53 | 54 | o1_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = F.relu( 55 | torch.einsum('...bi,bio->...bo', 56 | x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].real, self.w1[0]) - 57 | torch.einsum('...bi,bio->...bo', 58 | x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].imag, self.w1[1]) + 59 | self.b1[0] 60 | ) 61 | 62 | o1_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = F.relu( 63 | torch.einsum('...bi,bio->...bo', 64 | x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].imag, 
self.w1[0]) + 65 | torch.einsum('...bi,bio->...bo', 66 | x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].real, self.w1[1]) + 67 | self.b1[1] 68 | ) 69 | 70 | o2_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = ( 71 | torch.einsum('...bi,bio->...bo', 72 | o1_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], 73 | self.w2[0]) - 74 | torch.einsum('...bi,bio->...bo', 75 | o1_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], 76 | self.w2[1]) + 77 | self.b2[0] 78 | ) 79 | 80 | o2_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = ( 81 | torch.einsum('...bi,bio->...bo', 82 | o1_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], 83 | self.w2[0]) + 84 | torch.einsum('...bi,bio->...bo', 85 | o1_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], 86 | self.w2[1]) + 87 | self.b2[1] 88 | ) 89 | 90 | x = torch.stack([o2_real, o2_imag], dim=-1) 91 | x = F.softshrink(x, lambd=self.sparsity_threshold) 92 | x = torch.view_as_complex(x) 93 | x = x.reshape(B, D, H, W // 2 + 1, C) 94 | x = torch.fft.irfftn(x, s=(D, H*self.res_mult, W*self.res_mult), dim=(1, 2, 3), norm="ortho") 95 | x = x.type(dtype) 96 | if self.res_mult>1: 97 | return x 98 | else: 99 | return x + bias 100 | 101 | 102 | class Mlp(nn.Module): 103 | def __init__( 104 | self, 105 | in_features, hidden_features=None, out_features=None, 106 | act_layer=nn.GELU, drop=0.0 107 | ): 108 | super().__init__() 109 | out_features = out_features or in_features 110 | hidden_features = hidden_features or in_features 111 | self.fc1 = nn.Linear(in_features, hidden_features) 112 | self.act = act_layer() 113 | self.fc2 = nn.Linear(hidden_features, out_features) 114 | self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() 115 | 116 | def forward(self, x): 117 | x = self.fc1(x) 118 | x = self.act(x) 119 | x = self.drop(x) 120 | x = self.fc2(x) 121 | x = self.drop(x) 122 | return x 123 | 124 | 125 | class AFNOBlock3d(nn.Module): 126 | def __init__( 127 | self, 128 | dim, 129 | mlp_ratio=4., 130 | drop=0., 131 | act_layer=nn.GELU, 132 | norm_layer=nn.LayerNorm, 133 | double_skip=True, 134 | num_blocks=8, 135 | sparsity_threshold=0.01, 136 | hard_thresholding_fraction=1.0, 137 | data_format="channels_last", 138 | mlp_out_features=None, 139 | afno_res_mult=1, 140 | 141 | ): 142 | super().__init__() 143 | self.norm_layer = norm_layer 144 | self.afno_res_mult = afno_res_mult 145 | self.norm1 = norm_layer(dim) 146 | self.filter = AFNO3D(dim, num_blocks, sparsity_threshold, 147 | hard_thresholding_fraction, res_mult=afno_res_mult) 148 | self.norm2 = norm_layer(dim) 149 | mlp_hidden_dim = int(dim * mlp_ratio) 150 | self.mlp = Mlp( 151 | in_features=dim, out_features=mlp_out_features, 152 | hidden_features=mlp_hidden_dim, 153 | act_layer=act_layer, drop=drop 154 | ) 155 | self.double_skip = double_skip 156 | self.channels_first = (data_format == "channels_first") 157 | 158 | def forward(self, x): 159 | if self.channels_first: 160 | # AFNO natively uses a channels-last data format 161 | x = x.permute(0, 2, 3, 4, 1) 162 | 163 | residual = x 164 | x = self.norm1(x) 165 | x = self.filter(x) 166 | if self.afno_res_mult > 1: 167 | residual = F.interpolate(residual, x.shape[2:]) 168 | if self.double_skip: 169 | x = x + residual 170 | residual = x 171 | 172 | x = self.norm2(x) 173 | x = self.mlp(x) 174 | x = x + residual 175 | 176 | if self.channels_first: 177 | x = x.permute(0, 4, 1, 2, 3) 178 | 179 | 
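        # Note: when afno_res_mult > 1 the AFNO filter enlarges H and W by res_mult
        # (and the residual is interpolated to match above), so the block returns a
        # tensor whose spatial dimensions are res_mult times those of its input.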
return x 180 | 181 | 182 | class AFNOCrossAttentionBlock3d(nn.Module): 183 | """ AFNO 3D Block with channel mixing from two sources. 184 | """ 185 | 186 | def __init__( 187 | self, 188 | dim, 189 | context_dim, 190 | mlp_ratio=2., 191 | drop=0., 192 | act_layer=nn.GELU, 193 | norm_layer=nn.Identity, 194 | double_skip=True, 195 | num_blocks=8, 196 | sparsity_threshold=0.01, 197 | hard_thresholding_fraction=1.0, 198 | data_format="channels_last", 199 | timesteps=None 200 | ): 201 | super().__init__() 202 | 203 | self.norm1 = norm_layer(dim) 204 | self.norm2 = norm_layer(dim + context_dim) 205 | mlp_hidden_dim = int((dim + context_dim) * mlp_ratio) 206 | self.pre_proj = nn.Linear(dim + context_dim, dim + context_dim) 207 | self.filter = AFNO3D(dim + context_dim, num_blocks, sparsity_threshold, 208 | hard_thresholding_fraction) 209 | self.mlp = Mlp( 210 | in_features=dim + context_dim, 211 | out_features=dim, 212 | hidden_features=mlp_hidden_dim, 213 | act_layer=act_layer, drop=drop 214 | ) 215 | self.channels_first = (data_format == "channels_first") 216 | 217 | def forward(self, x, y): 218 | if self.channels_first: 219 | # AFNO natively uses a channels-last order 220 | x = x.permute(0, 2, 3, 4, 1) 221 | y = y.permute(0, 2, 3, 4, 1) 222 | 223 | xy = torch.concat((self.norm1(x), y), axis=-1) 224 | xy = self.pre_proj(xy) + xy 225 | xy = self.filter(self.norm2(xy)) + xy # AFNO filter 226 | x = self.mlp(xy) + x # feed-forward 227 | 228 | if self.channels_first: 229 | x = x.permute(0, 4, 1, 2, 3) 230 | 231 | return x -------------------------------------------------------------------------------- /SHADECast/Blocks/ResBlock3D.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/resnet.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.utils.parametrizations import spectral_norm as sn 8 | from utils import activation, normalization 9 | 10 | 11 | class ResBlock3D(nn.Module): 12 | def __init__( 13 | self, in_channels, out_channels, resample=None, 14 | resample_factor=(1, 1, 1), kernel_size=(3, 3, 3), 15 | act='swish', norm='group', norm_kwargs=None, 16 | spectral_norm=False, 17 | **kwargs 18 | ): 19 | super().__init__(**kwargs) 20 | if in_channels != out_channels: 21 | self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1) 22 | else: 23 | self.proj = nn.Identity() 24 | 25 | padding = tuple(k // 2 for k in kernel_size) 26 | if resample == "down": 27 | self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True) 28 | self.conv1 = nn.Conv3d(in_channels, out_channels, 29 | kernel_size=kernel_size, stride=resample_factor, 30 | padding=padding) 31 | self.conv2 = nn.Conv3d(out_channels, out_channels, 32 | kernel_size=kernel_size, padding=padding) 33 | elif resample == "up": 34 | self.resample = nn.Upsample( 35 | scale_factor=resample_factor, mode='trilinear') 36 | self.conv1 = nn.ConvTranspose3d(in_channels, out_channels, 37 | kernel_size=kernel_size, padding=padding) 38 | output_padding = tuple( 39 | 2 * p + s - k for (p, s, k) in zip(padding, resample_factor, kernel_size) 40 | ) 41 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, 42 | kernel_size=kernel_size, stride=resample_factor, 43 | padding=padding, output_padding=output_padding) 44 | else: 45 | self.resample = nn.Identity() 46 | self.conv1 = nn.Conv3d(in_channels, out_channels, 47 | kernel_size=kernel_size, padding=padding) 48 | self.conv2 = nn.Conv3d(out_channels, out_channels, 49 | 
kernel_size=kernel_size, padding=padding) 50 | 51 | if isinstance(act, str): 52 | act = (act, act) 53 | self.act1 = activation(act_type=act[0]) 54 | self.act2 = activation(act_type=act[1]) 55 | 56 | if norm_kwargs is None: 57 | norm_kwargs = {} 58 | self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs) 59 | self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs) 60 | if spectral_norm: 61 | self.conv1 = sn(self.conv1) 62 | self.conv2 = sn(self.conv2) 63 | if not isinstance(self.proj, nn.Identity): 64 | self.proj = sn(self.proj) 65 | 66 | def forward(self, x): 67 | x_in = self.resample(self.proj(x)) 68 | x = self.norm1(x) 69 | x = self.act1(x) 70 | x = self.conv1(x) 71 | x = self.norm2(x) 72 | x = self.act2(x) 73 | x = self.conv2(x) 74 | return x + x_in 75 | 76 | 77 | class ResBlock2D(nn.Module): 78 | def __init__( 79 | self, in_channels, out_channels, resample=None, 80 | resample_factor=(1, 1), kernel_size=(3, 3), 81 | act='swish', norm='group', norm_kwargs=None, 82 | spectral_norm=False, 83 | **kwargs 84 | ): 85 | super().__init__(**kwargs) 86 | if in_channels != out_channels: 87 | self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1) 88 | else: 89 | self.proj = nn.Identity() 90 | 91 | padding = tuple(k // 2 for k in kernel_size) 92 | if resample == "down": 93 | self.resample = nn.AvgPool2d(resample_factor, ceil_mode=True) 94 | self.conv1 = nn.Conv2d(in_channels, out_channels, 95 | kernel_size=kernel_size, stride=resample_factor, 96 | padding=padding) 97 | self.conv2 = nn.Conv2d(out_channels, out_channels, 98 | kernel_size=kernel_size, padding=padding) 99 | elif resample == "up": 100 | self.resample = nn.Upsample( 101 | scale_factor=resample_factor, mode='trilinear') 102 | self.conv1 = nn.ConvTranspose3d(in_channels, out_channels, 103 | kernel_size=kernel_size, padding=padding) 104 | output_padding = tuple( 105 | 2 * p + s - k for (p, s, k) in zip(padding, resample_factor, kernel_size) 106 | ) 107 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, 108 | kernel_size=kernel_size, stride=resample_factor, 109 | padding=padding, output_padding=output_padding) 110 | else: 111 | self.resample = nn.Identity() 112 | self.conv1 = nn.Conv2d(in_channels, out_channels, 113 | kernel_size=kernel_size, padding=padding) 114 | self.conv2 = nn.Conv2d(out_channels, out_channels, 115 | kernel_size=kernel_size, padding=padding) 116 | 117 | if isinstance(act, str): 118 | act = (act, act) 119 | self.act1 = activation(act_type=act[0]) 120 | self.act2 = activation(act_type=act[1]) 121 | 122 | if norm_kwargs is None: 123 | norm_kwargs = {} 124 | self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs) 125 | self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs) 126 | if spectral_norm: 127 | self.conv1 = sn(self.conv1) 128 | self.conv2 = sn(self.conv2) 129 | if not isinstance(self.proj, nn.Identity): 130 | self.proj = sn(self.proj) 131 | 132 | def forward(self, x): 133 | x_in = self.resample(self.proj(x)) 134 | x = self.norm1(x) 135 | x = self.act1(x) 136 | x = self.conv1(x) 137 | x = self.norm2(x) 138 | x = self.act2(x) 139 | x = self.conv2(x) 140 | return x + x_in 141 | -------------------------------------------------------------------------------- /SHADECast/Blocks/TimeStep.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torch.nn as nn 3 | from SHADECast.Blocks.AFNO import AFNOCrossAttentionBlock3d 4 | 5 | class TimestepBlock(nn.Module): 6 | """ 7 
| Any module where forward() takes timestep embeddings as a second argument. 8 | """ 9 | 10 | @abstractmethod 11 | def forward(self, x, emb): 12 | """ 13 | Apply the module to `x` given `emb` timestep embeddings. 14 | """ 15 | 16 | 17 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 18 | """ 19 | A sequential module that passes timestep embeddings to the children that 20 | support it as an extra input. 21 | """ 22 | 23 | def forward(self, x, emb, context=None): 24 | for layer in self: 25 | if isinstance(layer, TimestepBlock): 26 | x = layer(x, emb) 27 | elif isinstance(layer, AFNOCrossAttentionBlock3d): 28 | img_shape = tuple(x.shape[-2:]) 29 | x = layer(x, context[img_shape]) 30 | else: 31 | x = layer(x) 32 | return x -------------------------------------------------------------------------------- /SHADECast/Blocks/attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/attention.py 3 | """ 4 | 5 | import math 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class TemporalAttention(nn.Module): 12 | def __init__( 13 | self, channels, context_channels=None, 14 | head_dim=32, num_heads=8 15 | ): 16 | super().__init__() 17 | self.channels = channels 18 | if context_channels is None: 19 | context_channels = channels 20 | self.context_channels = context_channels 21 | self.head_dim = head_dim 22 | self.num_heads = num_heads 23 | self.inner_dim = head_dim * num_heads 24 | self.attn_scale = self.head_dim ** -0.5 25 | if channels % num_heads: 26 | raise ValueError("channels must be divisible by num_heads") 27 | self.KV = nn.Linear(context_channels, self.inner_dim * 2) 28 | self.Q = nn.Linear(channels, self.inner_dim) 29 | self.proj = nn.Linear(self.inner_dim, channels) 30 | 31 | def forward(self, x, y=None): 32 | if y is None: 33 | y = x 34 | 35 | (K, V) = self.KV(y).chunk(2, dim=-1) 36 | (B, Dk, H, W, C) = K.shape 37 | shape = (B, Dk, H, W, self.num_heads, self.head_dim) 38 | K = K.reshape(shape) 39 | V = V.reshape(shape) 40 | 41 | Q = self.Q(x) 42 | (B, Dq, H, W, C) = Q.shape 43 | shape = (B, Dq, H, W, self.num_heads, self.head_dim) 44 | Q = Q.reshape(shape) 45 | 46 | K = K.permute((0, 2, 3, 4, 5, 1)) # K^T 47 | V = V.permute((0, 2, 3, 4, 1, 5)) 48 | Q = Q.permute((0, 2, 3, 4, 1, 5)) 49 | 50 | attn = torch.matmul(Q, K) * self.attn_scale 51 | attn = F.softmax(attn, dim=-1) 52 | y = torch.matmul(attn, V) 53 | y = y.permute((0, 4, 1, 2, 3, 5)) 54 | y = y.reshape((B, Dq, H, W, C)) 55 | y = self.proj(y) 56 | return y 57 | 58 | 59 | class TemporalTransformer(nn.Module): 60 | def __init__(self, 61 | channels, 62 | mlp_dim_mul=1, 63 | **kwargs 64 | ): 65 | super().__init__() 66 | self.attn1 = TemporalAttention(channels, **kwargs) 67 | self.attn2 = TemporalAttention(channels, **kwargs) 68 | self.norm1 = nn.LayerNorm(channels) 69 | self.norm2 = nn.LayerNorm(channels) 70 | self.norm3 = nn.LayerNorm(channels) 71 | self.mlp = MLP(channels, dim_mul=mlp_dim_mul) 72 | 73 | def forward(self, x, y): 74 | x = self.attn1(self.norm1(x)) + x # self attention 75 | x = self.attn2(self.norm2(x), y) + x # cross attention 76 | return self.mlp(self.norm3(x)) + x # feed-forward 77 | 78 | 79 | class MLP(nn.Sequential): 80 | def __init__(self, dim, dim_mul=4): 81 | inner_dim = dim * dim_mul 82 | sequence = [ 83 | nn.Linear(dim, inner_dim), 84 | nn.SiLU(), 85 | nn.Linear(inner_dim, dim) 86 | ] 87 | super().__init__(*sequence) 88 | 89 | 90 | def 
positional_encoding(position, dims, add_dims=()): 91 | div_term = torch.exp( 92 | torch.arange(0, dims, 2, device=position.device) * 93 | (-math.log(10000.0) / dims) 94 | ) 95 | if position.ndim == 1: 96 | arg = position[:, None] * div_term[None, :] 97 | else: 98 | arg = position[:, :, None] * div_term[None, None, :] 99 | 100 | pos_enc = torch.concat( 101 | [torch.sin(arg), torch.cos(arg)], 102 | dim=-1 103 | ) 104 | if add_dims: 105 | for dim in add_dims: 106 | pos_enc = pos_enc.unsqueeze(dim) 107 | return pos_enc 108 | -------------------------------------------------------------------------------- /SHADECast/Models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/SHADECast/Models/.DS_Store -------------------------------------------------------------------------------- /SHADECast/Models/Diffusion/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/SHADECast/Models/Diffusion/.DS_Store -------------------------------------------------------------------------------- /SHADECast/Models/Diffusion/DiffusionModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py 3 | 4 | The original file acknowledges: 5 | https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 6 | https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py 7 | https://github.com/CompVis/taming-transformers 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import numpy as np 13 | import pytorch_lightning as pl 14 | from contextlib import contextmanager 15 | from functools import partial 16 | 17 | 18 | from SHADECast.Models.Diffusion.utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding 19 | from SHADECast.Models.Diffusion.ema import LitEma 20 | 21 | 22 | class LatentDiffusion(pl.LightningModule): 23 | def __init__(self, 24 | model, 25 | autoencoder, 26 | context_encoder=None, 27 | timesteps=1000, 28 | beta_schedule="linear", 29 | loss_type="l2", 30 | use_ema=True, 31 | lr=1e-4, 32 | lr_warmup=0, 33 | linear_start=1e-4, 34 | linear_end=2e-2, 35 | cosine_s=8e-3, 36 | parameterization="eps", # all assuming fixed variance schedules 37 | opt_patience=5, 38 | get_t=False, 39 | **kwargs 40 | ): 41 | super().__init__() 42 | self.model = model 43 | self.autoencoder = autoencoder.requires_grad_(False) 44 | self.conditional = (context_encoder is not None) 45 | self.context_encoder = context_encoder 46 | self.lr = lr 47 | self.lr_warmup = lr_warmup 48 | self.opt_patience = opt_patience 49 | self.get_t = get_t 50 | assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' 51 | self.parameterization = parameterization 52 | 53 | self.use_ema = use_ema 54 | if self.use_ema: 55 | self.model_ema = LitEma(self.model) 56 | 57 | self.register_schedule( 58 | beta_schedule=beta_schedule, timesteps=timesteps, 59 | linear_start=linear_start, linear_end=linear_end, 60 | cosine_s=cosine_s 61 | ) 62 | 63 | self.loss_type = loss_type 64 | 65 | def register_schedule(self, 
beta_schedule="linear", timesteps=1000, 66 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 67 | 68 | betas = make_beta_schedule( 69 | beta_schedule, timesteps, 70 | linear_start=linear_start, linear_end=linear_end, 71 | cosine_s=cosine_s 72 | ) 73 | alphas = 1. - betas 74 | alphas_cumprod = np.cumprod(alphas, axis=0) 75 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 76 | 77 | timesteps, = betas.shape 78 | self.num_timesteps = int(timesteps) 79 | self.linear_start = linear_start 80 | self.linear_end = linear_end 81 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 82 | 83 | to_torch = partial(torch.tensor, dtype=torch.float32) 84 | 85 | self.register_buffer('betas', to_torch(betas)) 86 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 87 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 88 | 89 | # calculations for diffusion q(x_t | x_{t-1}) and others 90 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 91 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 92 | 93 | @contextmanager 94 | def ema_scope(self, context=None): 95 | if self.use_ema: 96 | self.model_ema.store(self.model.parameters()) 97 | self.model_ema.copy_to(self.model) 98 | if context is not None: 99 | print(f"{context}: Switched to EMA weights") 100 | try: 101 | yield None 102 | finally: 103 | if self.use_ema: 104 | self.model_ema.restore(self.model.parameters()) 105 | if context is not None: 106 | print(f"{context}: Restored training weights") 107 | 108 | def apply_model(self, x_noisy, t, cond=None, return_ids=False): 109 | # if self.conditional: 110 | # cond = self.context_encoder(cond) 111 | with self.ema_scope(): 112 | return self.model(x_noisy, t, context=cond) 113 | 114 | def q_sample(self, x_start, t, noise=None): 115 | if noise is None: 116 | noise = torch.randn_like(x_start) 117 | return ( 118 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 119 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 120 | ) 121 | 122 | def get_loss(self, pred, target): 123 | if self.loss_type == 'l1': 124 | loss = (target - pred).abs() 125 | elif self.loss_type == 'l2': 126 | loss = torch.nn.functional.mse_loss(target, pred, reduction='none') 127 | else: 128 | raise NotImplementedError("unknown loss type '{loss_type}'") 129 | return loss.mean() 130 | 131 | def p_losses(self, x_start, t, noise=None, context=None): 132 | if noise is None: 133 | noise = torch.randn_like(x_start) 134 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 135 | model_out = self.model(x_noisy, t, context=context) 136 | if self.parameterization == "eps": 137 | target = noise 138 | yhat = x_noisy - model_out 139 | elif self.parameterization == "x0": 140 | target = x_start 141 | yhat = model_out 142 | else: 143 | raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") 144 | return self.get_loss(model_out, target) 145 | 146 | def forward(self, x, t=None, *args, **kwargs): 147 | if t is None: 148 | t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() 149 | return self.p_losses(x, t, *args, **kwargs) 150 | 151 | def shared_step(self, batch, t=None): 152 | if len(batch) == 2: 153 | (x, y) = batch 154 | context = self.context_encoder(x) if self.conditional else None 155 | elif len(batch) == 3: 156 | (x, y, c) = batch 157 | context = 
self.context_encoder(x, c) if self.conditional else None 158 | loss = self(self.autoencoder.encode(y)[0], t=t, context=context) 159 | return loss 160 | 161 | 162 | def training_step(self, batch, batch_idx): 163 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 164 | loss = self.shared_step(batch) 165 | self.log("train_loss", loss, **log_params) 166 | return loss 167 | 168 | @torch.no_grad() 169 | def validation_step(self, batch, batch_idx): 170 | if self.get_t: 171 | t = torch.tensor(batch[-1], dtype=torch.int8) 172 | batch = batch[:-1] 173 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 174 | loss = self.shared_step(batch, t) 175 | 176 | with self.ema_scope(): 177 | loss_ema = self.shared_step(batch, t) 178 | self.log("val_loss", loss, **log_params) 179 | self.log("val_loss_ema", loss_ema, **log_params) 180 | 181 | def on_train_batch_end(self, *args, **kwargs): 182 | if self.use_ema: 183 | self.model_ema(self.model) 184 | 185 | def configure_optimizers(self): 186 | optimizer = torch.optim.AdamW(self.parameters(), 187 | lr=self.lr, 188 | betas=(0.5, 0.9), 189 | weight_decay=1e-3) 190 | reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( 191 | optimizer, patience=self.opt_patience, factor=0.25, verbose=True 192 | ) 193 | return { 194 | "optimizer": optimizer, 195 | "lr_scheduler": { 196 | "scheduler": reduce_lr, 197 | "monitor": "val_loss_ema", 198 | "frequency": 1, 199 | }, 200 | } 201 | 202 | # def on_before_optimizer_step(self, optimizer): 203 | # # Compute the 2-norm for each layer 204 | # # If using mixed precision, the gradients are already unscaled here 205 | # norms = grad_norm(self.layer, norm_type=2) 206 | # self.log_dict(norms) 207 | 208 | def optimizer_step( 209 | self, 210 | epoch, 211 | batch_idx, 212 | optimizer, 213 | optimizer_idx, 214 | optimizer_closure, 215 | **kwargs 216 | ): 217 | if self.trainer.global_step < self.lr_warmup: 218 | lr_scale = (self.trainer.global_step + 1) / self.lr_warmup 219 | for pg in optimizer.param_groups: 220 | pg['lr'] = lr_scale * self.lr 221 | 222 | super().optimizer_step( 223 | epoch, batch_idx, optimizer, 224 | optimizer_idx, optimizer_closure, 225 | **kwargs 226 | ) 227 | 228 | 229 | -------------------------------------------------------------------------------- /SHADECast/Models/Diffusion/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = 
dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /SHADECast/Models/Diffusion/utils.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
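#
# Contents of this module: beta schedules (make_beta_schedule, betas_for_alpha_bar),
# DDIM timestep selection (make_ddim_timesteps, make_ddim_sampling_parameters),
# sinusoidal timestep embeddings (timestep_embedding), gradient checkpointing
# (checkpoint / CheckpointFunction) and small helpers such as extract_into_tensor,
# noise_like, zero_module, scale_module, mean_flat, conv_nd and avg_pool_nd.
# DiffusionModel.py imports make_beta_schedule, extract_into_tensor, noise_like
# and timestep_embedding from here.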
9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | import numpy as np 14 | from einops import repeat 15 | 16 | 17 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 18 | if schedule == "linear": 19 | betas = ( 20 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 21 | ) 22 | 23 | elif schedule == "cosine": 24 | timesteps = ( 25 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 26 | ) 27 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 28 | alphas = torch.cos(alphas).pow(2) 29 | alphas = alphas / alphas[0] 30 | betas = 1 - alphas[1:] / alphas[:-1] 31 | betas = np.clip(betas, a_min=0, a_max=0.999) 32 | 33 | elif schedule == "sqrt_linear": 34 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 35 | elif schedule == "sqrt": 36 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 37 | else: 38 | raise ValueError(f"schedule '{schedule}' unknown.") 39 | return betas.numpy() 40 | 41 | 42 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 43 | if ddim_discr_method == 'uniform': 44 | c = num_ddpm_timesteps // num_ddim_timesteps 45 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 46 | elif ddim_discr_method == 'quad': 47 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 48 | else: 49 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 50 | 51 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 52 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 53 | steps_out = ddim_timesteps + 1 54 | if verbose: 55 | print(f'Selected timesteps for ddim sampler: {steps_out}') 56 | return steps_out 57 | 58 | 59 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 60 | # select alphas for computing the variance schedule 61 | alphas = alphacums[ddim_timesteps] 62 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 63 | 64 | # according the the formula provided in https://arxiv.org/abs/2010.02502 65 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 66 | if verbose: 67 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 68 | print(f'For the chosen value of eta, which is {eta}, ' 69 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 70 | return sigmas, alphas, alphas_prev 71 | 72 | 73 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 74 | """ 75 | Create a beta schedule that discretizes the given alpha_t_bar function, 76 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 77 | :param num_diffusion_timesteps: the number of betas to produce. 78 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 79 | produces the cumulative product of (1-beta) up to that 80 | part of the diffusion process. 81 | :param max_beta: the maximum beta to use; use values lower than 1 to 82 | prevent singularities. 
83 | """ 84 | betas = [] 85 | for i in range(num_diffusion_timesteps): 86 | t1 = i / num_diffusion_timesteps 87 | t2 = (i + 1) / num_diffusion_timesteps 88 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 89 | return np.array(betas) 90 | 91 | 92 | def extract_into_tensor(a, t, x_shape): 93 | b, *_ = t.shape 94 | out = a.gather(-1, t) 95 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 96 | 97 | 98 | def checkpoint(func, inputs, params, flag): 99 | """ 100 | Evaluate a function without caching intermediate activations, allowing for 101 | reduced memory at the expense of extra compute in the backward pass. 102 | :param func: the function to evaluate. 103 | :param inputs: the argument sequence to pass to `func`. 104 | :param params: a sequence of parameters `func` depends on but does not 105 | explicitly take as arguments. 106 | :param flag: if False, disable gradient checkpointing. 107 | """ 108 | if flag: 109 | args = tuple(inputs) + tuple(params) 110 | return CheckpointFunction.apply(func, len(inputs), *args) 111 | else: 112 | return func(*inputs) 113 | 114 | 115 | class CheckpointFunction(torch.autograd.Function): 116 | @staticmethod 117 | def forward(ctx, run_function, length, *args): 118 | ctx.run_function = run_function 119 | ctx.input_tensors = list(args[:length]) 120 | ctx.input_params = list(args[length:]) 121 | 122 | with torch.no_grad(): 123 | output_tensors = ctx.run_function(*ctx.input_tensors) 124 | return output_tensors 125 | 126 | @staticmethod 127 | def backward(ctx, *output_grads): 128 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 129 | with torch.enable_grad(): 130 | # Fixes a bug where the first op in run_function modifies the 131 | # Tensor storage in place, which is not allowed for detach()'d 132 | # Tensors. 133 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 134 | output_tensors = ctx.run_function(*shallow_copies) 135 | input_grads = torch.autograd.grad( 136 | output_tensors, 137 | ctx.input_tensors + ctx.input_params, 138 | output_grads, 139 | allow_unused=True, 140 | ) 141 | del ctx.input_tensors 142 | del ctx.input_params 143 | del output_tensors 144 | return (None, None) + input_grads 145 | 146 | 147 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 148 | """ 149 | Create sinusoidal timestep embeddings. 150 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 151 | These may be fractional. 152 | :param dim: the dimension of the output. 153 | :param max_period: controls the minimum frequency of the embeddings. 154 | :return: an [N x dim] Tensor of positional embeddings. 155 | """ 156 | if not repeat_only: 157 | half = dim // 2 158 | freqs = torch.exp( 159 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 160 | ).to(device=timesteps.device) 161 | args = timesteps[:, None].float() * freqs[None] 162 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 163 | if dim % 2: 164 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 165 | else: 166 | embedding = repeat(timesteps, 'b -> b d', d=dim) 167 | return embedding 168 | 169 | 170 | def zero_module(module): 171 | """ 172 | Zero out the parameters of a module and return it. 173 | """ 174 | for p in module.parameters(): 175 | p.detach().zero_() 176 | return module 177 | 178 | 179 | def scale_module(module, scale): 180 | """ 181 | Scale the parameters of a module and return it. 
182 | """ 183 | for p in module.parameters(): 184 | p.detach().mul_(scale) 185 | return module 186 | 187 | 188 | def mean_flat(tensor): 189 | """ 190 | Take the mean over all non-batch dimensions. 191 | """ 192 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 193 | 194 | 195 | class GroupNorm32(nn.GroupNorm): 196 | def forward(self, x): 197 | return super().forward(x.float()).type(x.dtype) 198 | 199 | 200 | def normalization(channels): 201 | """ 202 | Make a standard normalization layer. 203 | :param channels: number of input channels. 204 | :return: an nn.Module for normalization. 205 | """ 206 | return nn.Identity() #GroupNorm32(32, channels) 207 | 208 | 209 | def noise_like(shape, device, repeat=False): 210 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 211 | noise = lambda: torch.randn(shape, device=device) 212 | return repeat_noise() if repeat else noise() 213 | 214 | 215 | def conv_nd(dims, *args, **kwargs): 216 | """ 217 | Create a 1D, 2D, or 3D convolution module. 218 | """ 219 | if dims == 1: 220 | return nn.Conv1d(*args, **kwargs) 221 | elif dims == 2: 222 | return nn.Conv2d(*args, **kwargs) 223 | elif dims == 3: 224 | return nn.Conv3d(*args, **kwargs) 225 | raise ValueError(f"unsupported dimensions: {dims}") 226 | 227 | 228 | def linear(*args, **kwargs): 229 | """ 230 | Create a linear module. 231 | """ 232 | return nn.Linear(*args, **kwargs) 233 | 234 | 235 | def avg_pool_nd(dims, *args, **kwargs): 236 | """ 237 | Create a 1D, 2D, or 3D average pooling module. 238 | """ 239 | if dims == 1: 240 | return nn.AvgPool1d(*args, **kwargs) 241 | elif dims == 2: 242 | return nn.AvgPool2d(*args, **kwargs) 243 | elif dims == 3: 244 | return nn.AvgPool3d(*args, **kwargs) 245 | raise ValueError(f"unsupported dimensions: {dims}") -------------------------------------------------------------------------------- /SHADECast/Models/Nowcaster/Nowcast.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import pytorch_lightning as pl 6 | 7 | from SHADECast.Blocks.attention import TemporalTransformer, positional_encoding 8 | from SHADECast.Blocks.AFNO import AFNOBlock3d 9 | from SHADECast.Blocks.ResBlock3D import ResBlock3D 10 | import numpy as np 11 | 12 | 13 | class Nowcaster(pl.LightningModule): 14 | def __init__(self, nowcast_net, opt_patience, loss_type='l1'): 15 | super().__init__() 16 | self.nowcast_net = nowcast_net 17 | self.opt_patience = opt_patience 18 | self.loss_type = loss_type 19 | 20 | def forward(self, x): 21 | return self.nowcast_net(x) 22 | 23 | def _loss(self, batch): 24 | x, y = batch 25 | 26 | if self.loss_type == 'l1': 27 | y_pred = self.forward(x) 28 | return (y - y_pred).abs().mean() 29 | 30 | elif self.loss_type == 'l2': 31 | y_pred = self.forward(x) 32 | return (y - y_pred).square().mean() 33 | 34 | elif self.loss_type == 'latent': 35 | y, _ = self.nowcast_net.autoencoder.encode(y) 36 | x = self.nowcast_net.latent_forward(x) 37 | y_pred = self.nowcast_net.out_proj(x) 38 | return (y - y_pred).abs().mean() 39 | else: 40 | AssertionError('Loss type must be "l1" or "l2"') 41 | 42 | def training_step(self, batch, batch_idx): 43 | loss = self._loss(batch) 44 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 45 | self.log('train_loss', loss, **log_params) 46 | return loss 47 | 48 | @torch.no_grad() 49 | def 
val_test_step(self, batch, batch_idx, split="val"): 50 | loss = self._loss(batch) 51 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 52 | self.log(f"{split}_loss", loss, **log_params) 53 | 54 | def validation_step(self, batch, batch_idx): 55 | self.val_test_step(batch, batch_idx, split="val") 56 | 57 | def test_step(self, batch, batch_idx): 58 | self.val_test_step(batch, batch_idx, split="test") 59 | 60 | def configure_optimizers(self): 61 | optimizer = torch.optim.AdamW( 62 | self.parameters(), lr=1e-3, 63 | betas=(0.5, 0.9), weight_decay=1e-3 64 | ) 65 | reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( 66 | optimizer, patience=self.opt_patience, factor=0.25, verbose=True 67 | ) 68 | 69 | optimizer_spec = { 70 | "optimizer": optimizer, 71 | "lr_scheduler": { 72 | "scheduler": reduce_lr, 73 | "monitor": "val_loss", 74 | "frequency": 1, 75 | }, 76 | } 77 | return optimizer_spec 78 | 79 | 80 | class FusionBlock3d(nn.Module): 81 | def __init__(self, dim, size_ratios, dim_out=None, afno_fusion=False): 82 | super().__init__() 83 | 84 | N_sources = len(size_ratios) 85 | if not isinstance(dim, collections.abc.Sequence): 86 | dim = (dim,) * N_sources 87 | if dim_out is None: 88 | dim_out = dim[0] 89 | 90 | self.scale = nn.ModuleList() 91 | for (i, size_ratio) in enumerate(size_ratios): 92 | if size_ratio == 1: 93 | scale = nn.Identity() 94 | else: 95 | scale = [] 96 | while size_ratio > 1: 97 | scale.append(nn.ConvTranspose3d( 98 | dim[i], dim_out if size_ratio == 2 else dim[i], 99 | kernel_size=(1, 3, 3), stride=(1, 2, 2), 100 | padding=(0, 1, 1), output_padding=(0, 1, 1) 101 | )) 102 | size_ratio //= 2 103 | scale = nn.Sequential(*scale) 104 | self.scale.append(scale) 105 | 106 | self.afno_fusion = afno_fusion 107 | 108 | if self.afno_fusion: 109 | if N_sources > 1: 110 | self.fusion = nn.Sequential( 111 | nn.Linear(sum(dim), sum(dim)), 112 | AFNOBlock3d(dim * N_sources, mlp_ratio=2), 113 | nn.Linear(sum(dim), dim_out) 114 | ) 115 | else: 116 | self.fusion = nn.Identity() 117 | 118 | def resize_proj(self, x, i): 119 | x = x.permute(0, 4, 1, 2, 3) 120 | x = self.scale[i](x) 121 | x = x.permute(0, 2, 3, 4, 1) 122 | return x 123 | 124 | def forward(self, x): 125 | x = [self.resize_proj(xx, i) for (i, xx) in enumerate(x)] 126 | if self.afno_fusion: 127 | x = torch.concat(x, axis=-1) 128 | x = self.fusion(x) 129 | else: 130 | x = sum(x) 131 | return x 132 | 133 | 134 | class AFNONowcastNetBase(nn.Module): 135 | def __init__( 136 | self, 137 | autoencoder, 138 | embed_dim=128, 139 | embed_dim_out=None, 140 | analysis_depth=4, 141 | forecast_depth=4, 142 | input_steps=1, 143 | output_steps=2, 144 | train_autoenc=False 145 | ): 146 | super().__init__() 147 | 148 | self.train_autoenc = train_autoenc 149 | self.embed_dim = embed_dim 150 | self.embed_dim_out = embed_dim_out 151 | self.output_steps = output_steps 152 | self.input_steps = input_steps 153 | 154 | # encoding + analysis for each input 155 | ae = autoencoder.requires_grad_(train_autoenc) 156 | self.autoencoder = ae 157 | 158 | self.proj = nn.Conv3d(ae.hidden_width, embed_dim, kernel_size=1) 159 | 160 | self.analysis = nn.Sequential( 161 | *(AFNOBlock3d(embed_dim) for _ in range(analysis_depth)) 162 | ) 163 | 164 | # temporal transformer 165 | self.use_temporal_transformer = input_steps != output_steps 166 | if self.use_temporal_transformer: 167 | self.temporal_transformer = TemporalTransformer(embed_dim) 168 | 169 | # # data fusion 170 | # self.fusion = FusionBlock3d(embed_dim, input_size_ratios, 
171 | # afno_fusion=afno_fusion, dim_out=embed_dim_out) 172 | 173 | # forecast 174 | self.forecast = nn.Sequential( 175 | *(AFNOBlock3d(embed_dim_out) for _ in range(forecast_depth)) 176 | ) 177 | 178 | def add_pos_enc(self, x, t): 179 | if t.shape[1] != x.shape[1]: 180 | # this can happen if x has been compressed 181 | # by the autoencoder in the time dimension 182 | ds_factor = t.shape[1] // x.shape[1] 183 | t = F.avg_pool1d(t.unsqueeze(1), ds_factor)[:, 0, :] 184 | 185 | pos_enc = positional_encoding(t, x.shape[-1], add_dims=(2, 3)) 186 | return x + pos_enc 187 | 188 | def forward(self, x): 189 | # (x, t_relative) = list(zip(*x)) 190 | 191 | # encoding + analysis for each input 192 | # def process_input(i): 193 | x = self.autoencoder.encode(x)[0] 194 | x = self.proj(x) 195 | x = x.permute(0, 2, 3, 4, 1) 196 | x = self.analysis(x) 197 | if self.use_temporal_transformer: 198 | # add positional encoding 199 | t = torch.arange(0, self.input_steps, device=x.device) 200 | expand_shape = x.shape[:1] + (-1,) + x.shape[2:] 201 | pos_enc_output = positional_encoding( 202 | t, 203 | self.embed_dim, add_dims=(0, 2, 3) 204 | ) 205 | pe_out = pos_enc_output.expand(*expand_shape) 206 | x = x + pe_out 207 | 208 | # transform to output shape and coordinates 209 | pos_enc_output = positional_encoding( 210 | torch.arange(self.input_steps, self.output_steps + 1, device=x.device), 211 | self.embed_dim, add_dims=(0, 2, 3) 212 | ) 213 | pe_out = pos_enc_output.expand(*expand_shape) 214 | x = self.temporal_transformer(pe_out, x) 215 | 216 | x = self.forecast(x) 217 | return x.permute(0, 4, 1, 2, 3) # to channels-first order 218 | 219 | 220 | class AFNONowcastNet(AFNONowcastNetBase): 221 | def __init__(self, autoencoder, **kwargs): 222 | super().__init__(autoencoder, **kwargs) 223 | 224 | self.output_autoencoder = autoencoder.requires_grad_( 225 | self.train_autoenc) 226 | self.out_proj = nn.Conv3d( 227 | self.embed_dim_out, autoencoder.hidden_width, kernel_size=1 228 | ) 229 | 230 | def latent_forward(self, x): 231 | x = super().forward(x) 232 | return x 233 | 234 | def forward(self, x): 235 | x = self.latent_forward(x) 236 | x = self.out_proj(x) 237 | return self.output_autoencoder.decode(x) 238 | 239 | 240 | class AFNONowcastNetCascade(nn.Module): 241 | def __init__(self, 242 | nowcast_net, 243 | cascade_depth=4, 244 | train_net=False): 245 | super().__init__() 246 | self.cascade_depth = cascade_depth 247 | self.nowcast_net = nowcast_net 248 | for p in self.nowcast_net.parameters(): 249 | p.requires_grad = train_net 250 | self.resnet = nn.ModuleList() 251 | ch = self.nowcast_net.embed_dim_out 252 | self.cascade_dims = [ch] 253 | for i in range(cascade_depth - 1): 254 | ch_out = 2 * ch 255 | self.cascade_dims.append(ch_out) 256 | self.resnet.append( 257 | ResBlock3D(ch, ch_out, kernel_size=(1, 3, 3), norm=None) 258 | ) 259 | ch = ch_out 260 | 261 | def forward(self, x): 262 | x = self.nowcast_net.latent_forward(x) 263 | img_shape = tuple(x.shape[-2:]) 264 | cascade = {img_shape: x} 265 | for i in range(self.cascade_depth - 1): 266 | x = F.avg_pool3d(x, (1, 2, 2)) 267 | x = self.resnet[i](x) 268 | img_shape = tuple(x.shape[-2:]) 269 | cascade[img_shape] = x 270 | return cascade -------------------------------------------------------------------------------- /SHADECast/Models/Sampler/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/SHADECast/Models/Sampler/.DS_Store -------------------------------------------------------------------------------- /SHADECast/Models/Sampler/PLMS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from Models.Sampler.utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 5 | 6 | 7 | """ 8 | From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py 9 | """ 10 | 11 | 12 | """SAMPLING ONLY.""" 13 | 14 | import torch 15 | import numpy as np 16 | from tqdm import tqdm 17 | 18 | from Models.Diffusion.utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 19 | 20 | 21 | class PLMSSampler: 22 | def __init__(self, model, schedule="linear", **kwargs): 23 | self.model = model 24 | self.ddpm_num_timesteps = model.num_timesteps 25 | self.schedule = schedule 26 | 27 | def register_buffer(self, name, attr): 28 | #if type(attr) == torch.Tensor: 29 | # if attr.device != torch.device("cuda"): 30 | # attr = attr.to(torch.device("cuda")) 31 | setattr(self, name, attr) 32 | 33 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 34 | if ddim_eta != 0: 35 | raise ValueError('ddim_eta must be 0 for PLMS') 36 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 37 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 38 | alphas_cumprod = self.model.alphas_cumprod 39 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 40 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 41 | 42 | self.register_buffer('betas', to_torch(self.model.betas)) 43 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 44 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 45 | 46 | # calculations for diffusion q(x_t | x_{t-1}) and others 47 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 48 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 49 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 50 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 51 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 52 | 53 | # ddim sampling parameters 54 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 55 | ddim_timesteps=self.ddim_timesteps, 56 | eta=ddim_eta,verbose=verbose) 57 | self.register_buffer('ddim_sigmas', ddim_sigmas) 58 | self.register_buffer('ddim_alphas', ddim_alphas) 59 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 60 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 61 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 62 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 63 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 64 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 65 | 66 | @torch.no_grad() 67 | def sample(self, 68 | S, 69 | batch_size, 70 | shape, 71 | conditioning=None, 72 | callback=None, 73 | normals_sequence=None, 74 | img_callback=None, 75 | quantize_x0=False, 76 | eta=0., 77 | mask=None, 78 | x0=None, 79 | temperature=1., 80 | noise_dropout=0., 81 | score_corrector=None, 82 | corrector_kwargs=None, 83 | verbose=True, 84 | x_T=None, 85 | log_every_t=100, 86 | unconditional_guidance_scale=1., 87 | unconditional_conditioning=None, 88 | progbar=True, 89 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 90 | **kwargs 91 | ): 92 | """ 93 | if conditioning is not None: 94 | if isinstance(conditioning, dict): 95 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 96 | if cbs != batch_size: 97 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 98 | else: 99 | if conditioning.shape[0] != batch_size: 100 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 101 | """ 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 104 | # sampling 105 | size = (batch_size,) + shape 106 | print(f'Data shape for PLMS sampling is {size}') 107 | 108 | samples, intermediates = self.plms_sampling(conditioning, size, 109 | callback=callback, 110 | img_callback=img_callback, 111 | quantize_denoised=quantize_x0, 112 | mask=mask, x0=x0, 113 | ddim_use_original_steps=False, 114 | noise_dropout=noise_dropout, 115 | temperature=temperature, 116 | score_corrector=score_corrector, 117 | corrector_kwargs=corrector_kwargs, 118 | x_T=x_T, 119 | log_every_t=log_every_t, 120 | unconditional_guidance_scale=unconditional_guidance_scale, 121 | unconditional_conditioning=unconditional_conditioning, 122 | progbar=progbar 123 | ) 124 | return samples, intermediates 125 | 126 | @torch.no_grad() 127 | def plms_sampling(self, cond, shape, 128 | x_T=None, ddim_use_original_steps=False, 129 | callback=None, timesteps=None, quantize_denoised=False, 130 | mask=None, x0=None, img_callback=None, log_every_t=100, 131 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 132 | unconditional_guidance_scale=1., unconditional_conditioning=None, progbar=True): 133 | device = self.model.betas.device 134 | b = shape[0] 135 | if x_T is None: 136 | img = torch.randn(shape, device=device) 137 | else: 138 | img = x_T 139 | 140 | if timesteps is None: 141 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 142 | elif timesteps is not None and not ddim_use_original_steps: 143 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 144 | timesteps = self.ddim_timesteps[:subset_end] 145 | 146 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 147 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 148 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 149 | print(f"Running PLMS Sampling with {total_steps} timesteps") 150 | 151 | iterator = time_range 152 | if progbar: 153 | iterator = tqdm(iterator, desc='PLMS Sampler', total=total_steps) 154 | old_eps = [] 155 | 156 | for i, step in 
enumerate(iterator): 157 | index = total_steps - i - 1 158 | ts = torch.full((b,), step, device=device, dtype=torch.long) 159 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 160 | 161 | if mask is not None: 162 | assert x0 is not None 163 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 164 | img = img_orig * mask + (1. - mask) * img 165 | 166 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 167 | quantize_denoised=quantize_denoised, temperature=temperature, 168 | noise_dropout=noise_dropout, score_corrector=score_corrector, 169 | corrector_kwargs=corrector_kwargs, 170 | unconditional_guidance_scale=unconditional_guidance_scale, 171 | unconditional_conditioning=unconditional_conditioning, 172 | old_eps=old_eps, t_next=ts_next) 173 | img, pred_x0, e_t = outs 174 | old_eps.append(e_t) 175 | if len(old_eps) >= 4: 176 | old_eps.pop(0) 177 | if callback: callback(i) 178 | if img_callback: img_callback(pred_x0, i) 179 | 180 | if index % log_every_t == 0 or index == total_steps - 1: 181 | intermediates['x_inter'].append(img) 182 | intermediates['pred_x0'].append(pred_x0) 183 | 184 | return img, intermediates 185 | 186 | @torch.no_grad() 187 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 188 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 189 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): 190 | b, *_, device = *x.shape, x.device 191 | 192 | def get_model_output(x, t ,c): 193 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 194 | e_t = self.model.apply_model(x, t, c) 195 | else: 196 | x_in = torch.cat([x] * 2) 197 | t_in = torch.cat([t] * 2) 198 | c_in = torch.cat([unconditional_conditioning, c]) 199 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 200 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 201 | 202 | if score_corrector is not None: 203 | assert self.model.parameterization == "eps" 204 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 205 | 206 | return e_t 207 | 208 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 209 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 210 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 211 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 212 | 213 | def get_x_prev_and_pred_x0(x, e_t, index): 214 | # select parameters corresponding to the currently considered timestep 215 | param_shape = (b,) + (1,)*(x.ndim-1) 216 | a_t = torch.full(param_shape, alphas[index], device=device) 217 | a_prev = torch.full(param_shape, alphas_prev[index], device=device) 218 | sigma_t = torch.full(param_shape, sigmas[index], device=device) 219 | sqrt_one_minus_at = torch.full(param_shape, sqrt_one_minus_alphas[index],device=device) 220 | 221 | # current prediction for x_0 222 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 223 | if quantize_denoised: 224 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 225 | # direction pointing to x_t 226 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 227 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 228 | if noise_dropout > 0.: 229 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 230 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 231 | return x_prev, pred_x0 232 | 233 | e_t = get_model_output(x, t, c) 234 | if len(old_eps) == 0: 235 | # Pseudo Improved Euler (2nd order) 236 | x_prev, pred_x0 = get_x_prev_and_pred_x0(x, e_t, index) 237 | e_t_next = get_model_output(x_prev, t_next, c) 238 | e_t_prime = (e_t + e_t_next) / 2 239 | elif len(old_eps) == 1: 240 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 241 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 242 | elif len(old_eps) == 2: 243 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 244 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 245 | elif len(old_eps) >= 3: 246 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 247 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 248 | 249 | x_prev, pred_x0 = get_x_prev_and_pred_x0(x, e_t_prime, index) 250 | 251 | return x_prev, pred_x0, e_t -------------------------------------------------------------------------------- /SHADECast/Models/Sampler/utils.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
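#
# Illustrative usage sketch (editorial addition, not from the upstream code):
# the two helpers below are typically combined as
#
#     ddim_ts = make_ddim_timesteps('uniform', num_ddim_timesteps=50,
#                                   num_ddpm_timesteps=1000, verbose=False)
#     sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
#         alphacums=alphas_cumprod.cpu(), ddim_timesteps=ddim_ts,
#         eta=0.0, verbose=False)
#
# where alphas_cumprod is the trained model's cumulative alpha schedule;
# with eta=0 all sigmas are zero, which is the setting PLMSSampler requires
# (see PLMSSampler.make_schedule in PLMS.py for the actual call sites).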
9 | 10 | 11 | import torch 12 | import numpy as np 13 | 14 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 15 | if ddim_discr_method == 'uniform': 16 | c = num_ddpm_timesteps // num_ddim_timesteps 17 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 18 | elif ddim_discr_method == 'quad': 19 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 20 | else: 21 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 22 | 23 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 24 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 25 | steps_out = ddim_timesteps + 1 26 | if verbose: 27 | print(f'Selected timesteps for ddim sampler: {steps_out}') 28 | return steps_out 29 | 30 | 31 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 32 | # select alphas for computing the variance schedule 33 | alphas = alphacums[ddim_timesteps] 34 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 35 | 36 | # according to the formula provided in https://arxiv.org/abs/2010.02502 37 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 38 | if verbose: 39 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 40 | print(f'For the chosen value of eta, which is {eta}, ' 41 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 42 | return sigmas, alphas, alphas_prev 43 | 44 | 45 | def noise_like(shape, device, repeat=False): 46 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 47 | noise = lambda: torch.randn(shape, device=device) 48 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /SHADECast/Models/UNet/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnergyWeatherAI/GenerativeNowcasting/8542c84d782a949083360b0817c535b0457d7dcc/SHADECast/Models/UNet/.DS_Store -------------------------------------------------------------------------------- /SHADECast/Models/UNet/UNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/unet.py 3 | 4 | """ 5 | 6 | from abc import abstractmethod 7 | from functools import partial 8 | import math 9 | from typing import Iterable 10 | 11 | import numpy as np 12 | import torch as th 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from SHADECast.Models.UNet.utils import ( 17 | checkpoint, 18 | conv_nd, 19 | linear, 20 | avg_pool_nd, 21 | zero_module, 22 | normalization, 23 | timestep_embedding, 24 | ) 25 | from SHADECast.Blocks.AFNO import AFNOCrossAttentionBlock3d 26 | SpatialTransformer = type(None) 27 | 28 | 29 | class TimestepBlock(nn.Module): 30 | """ 31 | Any module where forward() takes timestep embeddings as a second argument. 32 | """ 33 | 34 | @abstractmethod 35 | def forward(self, x, emb): 36 | """ 37 | Apply the module to `x` given `emb` timestep embeddings. 
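        (For example, the ResBlock defined below implements this signature,
        which is how TimestepEmbedSequential knows to pass the embedding on.)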
38 | """ 39 | 40 | 41 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 42 | """ 43 | A sequential module that passes timestep embeddings to the children that 44 | support it as an extra input. 45 | """ 46 | 47 | def forward(self, x, emb, context=None): 48 | for layer in self: 49 | if isinstance(layer, TimestepBlock): 50 | x = layer(x, emb) 51 | elif isinstance(layer, AFNOCrossAttentionBlock3d): 52 | img_shape = tuple(x.shape[-2:]) 53 | x = layer(x, context[img_shape]) 54 | else: 55 | x = layer(x) 56 | return x 57 | 58 | 59 | class Upsample(nn.Module): 60 | """ 61 | An upsampling layer with an optional convolution. 62 | :param channels: channels in the inputs and outputs. 63 | :param use_conv: a bool determining if a convolution is applied. 64 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 65 | upsampling occurs in the inner-two dimensions. 66 | """ 67 | 68 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 69 | super().__init__() 70 | self.channels = channels 71 | self.out_channels = out_channels or channels 72 | self.use_conv = use_conv 73 | self.dims = dims 74 | if use_conv: 75 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) 76 | 77 | def forward(self, x): 78 | assert x.shape[1] == self.channels 79 | if self.dims == 3: 80 | x = F.interpolate( 81 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 82 | ) 83 | else: 84 | x = F.interpolate(x, scale_factor=2, mode="nearest") 85 | if self.use_conv: 86 | x = self.conv(x) 87 | return x 88 | 89 | 90 | class Downsample(nn.Module): 91 | """ 92 | A downsampling layer with an optional convolution. 93 | :param channels: channels in the inputs and outputs. 94 | :param use_conv: a bool determining if a convolution is applied. 95 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 96 | downsampling occurs in the inner-two dimensions. 97 | """ 98 | 99 | def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): 100 | super().__init__() 101 | self.channels = channels 102 | self.out_channels = out_channels or channels 103 | self.use_conv = use_conv 104 | self.dims = dims 105 | stride = 2 if dims != 3 else (1, 2, 2) 106 | if use_conv: 107 | self.op = conv_nd( 108 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 109 | ) 110 | else: 111 | assert self.channels == self.out_channels 112 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 113 | 114 | def forward(self, x): 115 | assert x.shape[1] == self.channels 116 | return self.op(x) 117 | 118 | 119 | class ResBlock(TimestepBlock): 120 | """ 121 | A residual block that can optionally change the number of channels. 122 | :param channels: the number of input channels. 123 | :param emb_channels: the number of timestep embedding channels. 124 | :param dropout: the rate of dropout. 125 | :param out_channels: if specified, the number of out channels. 126 | :param use_conv: if True and out_channels is specified, use a spatial 127 | convolution instead of a smaller 1x1 convolution to change the 128 | channels in the skip connection. 129 | :param dims: determines if the signal is 1D, 2D, or 3D. 130 | :param use_checkpoint: if True, use gradient checkpointing on this module. 131 | :param up: if True, use this block for upsampling. 132 | :param down: if True, use this block for downsampling. 
133 | """ 134 | 135 | def __init__( 136 | self, 137 | channels, 138 | emb_channels, 139 | dropout, 140 | out_channels=None, 141 | use_conv=False, 142 | use_scale_shift_norm=False, 143 | dims=2, 144 | use_checkpoint=False, 145 | up=False, 146 | down=False, 147 | ): 148 | super().__init__() 149 | self.channels = channels 150 | self.emb_channels = emb_channels 151 | self.dropout = dropout 152 | self.out_channels = out_channels or channels 153 | self.use_conv = use_conv 154 | self.use_checkpoint = use_checkpoint 155 | self.use_scale_shift_norm = use_scale_shift_norm 156 | 157 | self.in_layers = nn.Sequential( 158 | normalization(channels), 159 | nn.SiLU(), 160 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 161 | ) 162 | 163 | self.updown = up or down 164 | 165 | if up: 166 | self.h_upd = Upsample(channels, False, dims) 167 | self.x_upd = Upsample(channels, False, dims) 168 | elif down: 169 | self.h_upd = Downsample(channels, False, dims) 170 | self.x_upd = Downsample(channels, False, dims) 171 | else: 172 | self.h_upd = self.x_upd = nn.Identity() 173 | 174 | self.emb_layers = nn.Sequential( 175 | nn.SiLU(), 176 | linear( 177 | emb_channels, 178 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 179 | ), 180 | ) 181 | self.out_layers = nn.Sequential( 182 | normalization(self.out_channels), 183 | nn.SiLU(), 184 | nn.Dropout(p=dropout), 185 | zero_module( 186 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 187 | ), 188 | ) 189 | 190 | if self.out_channels == channels: 191 | self.skip_connection = nn.Identity() 192 | elif use_conv: 193 | self.skip_connection = conv_nd( 194 | dims, channels, self.out_channels, 3, padding=1 195 | ) 196 | else: 197 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 198 | 199 | def forward(self, x, emb): 200 | """ 201 | Apply the block to a Tensor, conditioned on a timestep embedding. 202 | :param x: an [N x C x ...] Tensor of features. 203 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 204 | :return: an [N x C x ...] Tensor of outputs. 205 | """ 206 | return checkpoint( 207 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 208 | ) 209 | 210 | 211 | def _forward(self, x, emb): 212 | if self.updown: 213 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 214 | h = in_rest(x) 215 | h = self.h_upd(h) 216 | x = self.x_upd(x) 217 | h = in_conv(h) 218 | else: 219 | h = self.in_layers(x) 220 | emb_out = self.emb_layers(emb).type(h.dtype) 221 | while len(emb_out.shape) < len(h.shape): 222 | emb_out = emb_out[..., None] 223 | if self.use_scale_shift_norm: 224 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 225 | scale, shift = th.chunk(emb_out, 2, dim=1) 226 | h = out_norm(h) * (1 + scale) + shift 227 | h = out_rest(h) 228 | else: 229 | h = h + emb_out 230 | h = self.out_layers(h) 231 | return self.skip_connection(x) + h 232 | 233 | 234 | class UNetModel(nn.Module): 235 | """ 236 | The full UNet model with attention and timestep embedding. 237 | :param in_channels: channels in the input Tensor. 238 | :param model_channels: base channel count for the model. 239 | :param out_channels: channels in the output Tensor. 240 | :param num_res_blocks: number of residual blocks per downsample. 241 | :param attention_resolutions: a collection of downsample rates at which 242 | attention will take place. May be a set, list, or tuple. 243 | For example, if this contains 4, then at 4x downsampling, attention 244 | will be used. 
245 | :param dropout: the dropout probability. 246 | :param channel_mult: channel multiplier for each level of the UNet. 247 | :param conv_resample: if True, use learned convolutions for upsampling and 248 | downsampling. 249 | :param dims: determines if the signal is 1D, 2D, or 3D. 250 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 251 | :param num_heads: the number of attention heads in each attention layer. 252 | :param num_heads_channels: if specified, ignore num_heads and instead use 253 | a fixed channel width per attention head. 254 | :param num_heads_upsample: works with num_heads to set a different number 255 | of heads for upsampling. Deprecated. 256 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 257 | :param resblock_updown: use residual blocks for up/downsampling. 258 | 259 | """ 260 | 261 | def __init__( 262 | self, 263 | model_channels, 264 | in_channels=1, 265 | out_channels=1, 266 | num_res_blocks=2, 267 | attention_resolutions=(1, 2, 4), 268 | context_ch=128, 269 | dropout=0, 270 | channel_mult=(1, 2, 4, 4), 271 | conv_resample=True, 272 | dims=3, 273 | use_checkpoint=False, 274 | use_fp16=False, 275 | num_heads=-1, 276 | num_head_channels=-1, 277 | num_heads_upsample=-1, 278 | use_scale_shift_norm=False, 279 | resblock_updown=False, 280 | legacy=True, 281 | num_timesteps=1 282 | ): 283 | super().__init__() 284 | 285 | if num_heads_upsample == -1: 286 | num_heads_upsample = num_heads 287 | 288 | if num_heads == -1: 289 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 290 | 291 | if num_head_channels == -1: 292 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 293 | 294 | self.in_channels = in_channels 295 | self.model_channels = model_channels 296 | self.out_channels = out_channels 297 | self.num_res_blocks = num_res_blocks 298 | self.attention_resolutions = attention_resolutions 299 | self.dropout = dropout 300 | self.channel_mult = channel_mult 301 | self.conv_resample = conv_resample 302 | self.use_checkpoint = use_checkpoint 303 | self.dtype = th.float16 if use_fp16 else th.float32 304 | self.num_heads = num_heads 305 | self.num_head_channels = num_head_channels 306 | self.num_heads_upsample = num_heads_upsample 307 | timesteps = th.arange(1, num_timesteps+1) 308 | 309 | time_embed_dim = model_channels * 4 310 | self.time_embed = nn.Sequential( 311 | linear(model_channels, time_embed_dim), 312 | nn.SiLU(), 313 | linear(time_embed_dim, time_embed_dim), 314 | ) 315 | 316 | self.input_blocks = nn.ModuleList( 317 | [ 318 | TimestepEmbedSequential( 319 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 320 | ) 321 | ] 322 | ) 323 | self._feature_size = model_channels 324 | input_block_chans = [model_channels] 325 | ch = model_channels 326 | ds = 1 327 | for level, mult in enumerate(channel_mult): 328 | for _ in range(num_res_blocks): 329 | layers = [ 330 | ResBlock( 331 | ch, 332 | time_embed_dim, 333 | dropout, 334 | out_channels=mult * model_channels, 335 | dims=dims, 336 | use_checkpoint=use_checkpoint, 337 | use_scale_shift_norm=use_scale_shift_norm, 338 | ) 339 | ] 340 | ch = mult * model_channels 341 | if ds in attention_resolutions: 342 | if num_head_channels == -1: 343 | dim_head = ch // num_heads 344 | else: 345 | num_heads = ch // num_head_channels 346 | dim_head = num_head_channels 347 | if legacy: 348 | dim_head = num_head_channels 349 | layers.append( 350 | AFNOCrossAttentionBlock3d( 351 | ch, context_dim=context_ch[level], 
num_blocks=num_heads, 352 | data_format="channels_first", timesteps=timesteps 353 | ) 354 | ) 355 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 356 | self._feature_size += ch 357 | input_block_chans.append(ch) 358 | if level != len(channel_mult) - 1: 359 | out_ch = ch 360 | self.input_blocks.append( 361 | TimestepEmbedSequential( 362 | ResBlock( 363 | ch, 364 | time_embed_dim, 365 | dropout, 366 | out_channels=out_ch, 367 | dims=dims, 368 | use_checkpoint=use_checkpoint, 369 | use_scale_shift_norm=use_scale_shift_norm, 370 | down=True, 371 | ) 372 | if resblock_updown 373 | else Downsample( 374 | ch, conv_resample, dims=dims, out_channels=out_ch 375 | ) 376 | ) 377 | ) 378 | ch = out_ch 379 | input_block_chans.append(ch) 380 | ds *= 2 381 | self._feature_size += ch 382 | 383 | if num_head_channels == -1: 384 | dim_head = ch // num_heads 385 | else: 386 | num_heads = ch // num_head_channels 387 | dim_head = num_head_channels 388 | if legacy: 389 | dim_head = num_head_channels 390 | self.middle_block = TimestepEmbedSequential( 391 | ResBlock( 392 | ch, 393 | time_embed_dim, 394 | dropout, 395 | dims=dims, 396 | use_checkpoint=use_checkpoint, 397 | use_scale_shift_norm=use_scale_shift_norm, 398 | ), 399 | AFNOCrossAttentionBlock3d( 400 | ch, context_dim=context_ch[-1], num_blocks=num_heads, 401 | data_format="channels_first", timesteps=timesteps 402 | ), 403 | ResBlock( 404 | ch, 405 | time_embed_dim, 406 | dropout, 407 | dims=dims, 408 | use_checkpoint=use_checkpoint, 409 | use_scale_shift_norm=use_scale_shift_norm, 410 | ), 411 | ) 412 | self._feature_size += ch 413 | 414 | self.output_blocks = nn.ModuleList([]) 415 | for level, mult in list(enumerate(channel_mult))[::-1]: 416 | for i in range(num_res_blocks + 1): 417 | ich = input_block_chans.pop() 418 | layers = [ 419 | ResBlock( 420 | ch + ich, 421 | time_embed_dim, 422 | dropout, 423 | out_channels=model_channels * mult, 424 | dims=dims, 425 | use_checkpoint=use_checkpoint, 426 | use_scale_shift_norm=use_scale_shift_norm, 427 | ) 428 | ] 429 | ch = model_channels * mult 430 | if ds in attention_resolutions: 431 | if num_head_channels == -1: 432 | dim_head = ch // num_heads 433 | else: 434 | num_heads = ch // num_head_channels 435 | dim_head = num_head_channels 436 | if legacy: 437 | #num_heads = 1 438 | dim_head = num_head_channels 439 | layers.append( 440 | AFNOCrossAttentionBlock3d( 441 | ch, context_dim=context_ch[level], num_blocks=num_heads, 442 | data_format="channels_first", timesteps=timesteps 443 | ) 444 | ) 445 | if level and i == num_res_blocks: 446 | out_ch = ch 447 | layers.append( 448 | ResBlock( 449 | ch, 450 | time_embed_dim, 451 | dropout, 452 | out_channels=out_ch, 453 | dims=dims, 454 | use_checkpoint=use_checkpoint, 455 | use_scale_shift_norm=use_scale_shift_norm, 456 | up=True, 457 | ) 458 | if resblock_updown 459 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 460 | ) 461 | ds //= 2 462 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 463 | self._feature_size += ch 464 | 465 | self.out = nn.Sequential( 466 | normalization(ch), 467 | nn.SiLU(), 468 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 469 | ) 470 | 471 | def forward(self, x, timesteps=None, context=None): 472 | """ 473 | Apply the model to an input batch. 474 | :param x: an [N x C x ...] Tensor of inputs. 475 | :param timesteps: a 1-D batch of timesteps. 476 | :param context: conditioning plugged in via crossattn 477 | :return: an [N x C x ...] Tensor of outputs. 
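        Illustrative call (editorial sketch): for a latent cube x of shape
        [N, C, T, H, W] the model can be invoked as

            out = unet(x, timesteps=t, context=cascade)

        where t is an [N]-shaped tensor of diffusion steps and cascade is a
        dict mapping spatial shapes (H, W), (H//2, W//2), ... to conditioning
        tensors, as produced by AFNONowcastNetCascade; each
        TimestepEmbedSequential looks its conditioning up via
        tuple(x.shape[-2:]).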
478 | """ 479 | hs = [] 480 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 481 | emb = self.time_embed(t_emb) 482 | 483 | h = x.type(self.dtype) 484 | for module in self.input_blocks: 485 | h = module(h, emb, context) 486 | hs.append(h) 487 | h = self.middle_block(h, emb, context) 488 | for module in self.output_blocks: 489 | h = th.cat([h, hs.pop()], dim=1) 490 | h = module(h, emb, context) 491 | h = h.type(x.dtype) 492 | return self.out(h) 493 | -------------------------------------------------------------------------------- /SHADECast/Models/UNet/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | def __init__( 9 | self, 10 | embedding_dim: tuple, 11 | dropout: float = 0.1, 12 | max_len: int = 1000, 13 | apply_dropout: bool = False, 14 | ): 15 | """Section 3.5 of attention is all you need paper. 16 | 17 | Extended slicing method is used to fill even and odd position of sin, cos with increment of 2. 18 | Ex, `[sin, cos, sin, cos, sin, cos]` for `embedding_dim = 6`. 19 | 20 | `max_len` is equivalent to number of noise steps or patches. `embedding_dim` must same as image 21 | embedding dimension of the model. 22 | 23 | Args: 24 | embedding_dim: `d_model` in given positional encoding formula. 25 | dropout: Dropout amount. 26 | max_len: Number of embeddings to generate. Here, equivalent to total noise steps. 27 | """ 28 | super(PositionalEncoding, self).__init__() 29 | self.dropout = nn.Dropout(p=dropout) 30 | self.apply_dropout = apply_dropout 31 | 32 | pos_encoding = torch.zeros(max_len, embedding_dim) 33 | position = torch.arange(start=0, end=max_len).unsqueeze(1) 34 | div_term = torch.exp(-math.log(10000.0) * torch.arange(0, embedding_dim, 2).float() / embedding_dim) 35 | 36 | pos_encoding[:, 0::2] = torch.sin(position * div_term) 37 | pos_encoding[:, 1::2] = torch.cos(position * div_term) 38 | self.register_buffer(name='pos_encoding', tensor=pos_encoding, persistent=False) 39 | 40 | def forward(self, t: torch.Tensor) -> torch.Tensor: 41 | """Get precalculated positional embedding at timestep t. Outputs same as video implementation 42 | code but embeddings are in [sin, cos, sin, cos] format instead of [sin, sin, cos, cos] in that code. 43 | Also batch dimension is added to final output. 44 | """ 45 | positional_encoding = self.pos_encoding[t].squeeze(1) 46 | if self.apply_dropout: 47 | return self.dropout(positional_encoding) 48 | return positional_encoding 49 | 50 | 51 | class DoubleConv(nn.Module): 52 | def __init__( 53 | self, 54 | in_channels: int, 55 | out_channels: int, 56 | mid_channels: int = None, 57 | residual: bool = False 58 | ): 59 | """Double convolutions as applied in the unet paper architecture. 
60 | """ 61 | super(DoubleConv, self).__init__() 62 | self.residual = residual 63 | if not mid_channels: 64 | mid_channels = out_channels 65 | 66 | self.double_conv = nn.Sequential( 67 | nn.Conv3d( 68 | in_channels=in_channels, out_channels=mid_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False 69 | ), 70 | nn.GroupNorm(num_groups=1, num_channels=mid_channels), 71 | nn.GELU(), 72 | nn.Conv3d( 73 | in_channels=mid_channels, out_channels=out_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1), 74 | bias=False, 75 | ), 76 | nn.GroupNorm(num_groups=1, num_channels=out_channels), 77 | ) 78 | 79 | def forward(self, x: torch.Tensor) -> torch.Tensor: 80 | if self.residual: 81 | return F.gelu(x + self.double_conv(x)) 82 | 83 | return self.double_conv(x) 84 | 85 | 86 | class Down(nn.Module): 87 | def __init__(self, in_channels: int, out_channels: int, emb_dim: int = 256): 88 | super(Down, self).__init__() 89 | self.maxpool_conv = nn.Sequential( 90 | nn.MaxPool3d(kernel_size=(1, 2, 2)), 91 | DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True), 92 | DoubleConv(in_channels=in_channels, out_channels=out_channels), 93 | ) 94 | 95 | self.emb_layer = nn.Sequential( 96 | nn.SiLU(), 97 | nn.Linear(in_features=emb_dim, out_features=out_channels), 98 | ) 99 | 100 | def forward(self, x: torch.Tensor, t_embedding: torch.Tensor) -> torch.Tensor: 101 | """Downsamples input tensor, calculates embedding and adds embedding channel wise. 102 | 103 | If, `x.shape == [4, 64, 64, 64]` and `out_channels = 128`, then max_conv outputs [4, 128, 32, 32] by 104 | downsampling in h, w and outputting specified amount of feature maps/channels. 105 | 106 | `t_embedding` is embedding of timestep of shape [batch, time_dim]. It is passed through embedding layer 107 | to output channel dimentsion equivalent to channel dimension of x tensor, so they can be summbed elementwise. 108 | 109 | Since emb_layer output needs to be summed its output is also `emb.shape == [4, 128]`. It needs to be converted 110 | to 4D tensor, [4, 128, 1, 1]. Then the channel dimension is duplicated in all of `H x W` dimension to get 111 | shape of [4, 128, 32, 32]. 128D vector is sample for each pixel position is image. Now the emb_layer output 112 | is summed with max_conv output. 
113 | """ 114 | x = self.maxpool_conv(x) 115 | emb = self.emb_layer(t_embedding) 116 | emb = emb.view(emb.shape[0], emb.shape[1], 1, 1, 1).repeat(1, 1, x.shape[-3], x.shape[-2], x.shape[-1]) 117 | return x + emb 118 | 119 | 120 | class Up(nn.Module): 121 | def __init__(self, in_channels: int, out_channels: int, emb_dim: int = 256): 122 | super(Up, self).__init__() 123 | self.up = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True) 124 | self.conv = nn.Sequential( 125 | DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True), 126 | DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2), 127 | ) 128 | 129 | self.emb_layer = nn.Sequential( 130 | nn.SiLU(), 131 | nn.Linear(in_features=emb_dim, out_features=out_channels), 132 | ) 133 | 134 | def forward(self, x: torch.Tensor, x_skip: torch.Tensor, t_embedding: torch.Tensor) -> torch.Tensor: 135 | x = self.up(x) 136 | x = torch.cat([x_skip, x], dim=1) 137 | x = self.conv(x) 138 | emb = self.emb_layer(t_embedding) 139 | emb = emb.view(emb.shape[0], emb.shape[1], 1, 1, 1).repeat(1, 1, x.shape[-3], x.shape[-2], x.shape[-1]) 140 | return x + emb 141 | 142 | 143 | # adopted from 144 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 145 | # and 146 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 147 | # and 148 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 149 | # 150 | # thanks! 151 | 152 | import os 153 | import math 154 | import torch 155 | import torch.nn as nn 156 | import numpy as np 157 | from einops import repeat 158 | 159 | 160 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 161 | if schedule == "linear": 162 | betas = ( 163 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 164 | ) 165 | 166 | elif schedule == "cosine": 167 | timesteps = ( 168 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 169 | ) 170 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 171 | alphas = torch.cos(alphas).pow(2) 172 | alphas = alphas / alphas[0] 173 | betas = 1 - alphas[1:] / alphas[:-1] 174 | betas = np.clip(betas, a_min=0, a_max=0.999) 175 | 176 | elif schedule == "sqrt_linear": 177 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 178 | elif schedule == "sqrt": 179 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 180 | else: 181 | raise ValueError(f"schedule '{schedule}' unknown.") 182 | return betas.numpy() 183 | 184 | 185 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 186 | if ddim_discr_method == 'uniform': 187 | c = num_ddpm_timesteps // num_ddim_timesteps 188 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 189 | elif ddim_discr_method == 'quad': 190 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 191 | else: 192 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 193 | 194 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 195 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 196 | steps_out = ddim_timesteps + 
1 197 | if verbose: 198 | print(f'Selected timesteps for ddim sampler: {steps_out}') 199 | return steps_out 200 | 201 | 202 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 203 | # select alphas for computing the variance schedule 204 | alphas = alphacums[ddim_timesteps] 205 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 206 | 207 | # according the the formula provided in https://arxiv.org/abs/2010.02502 208 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 209 | if verbose: 210 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 211 | print(f'For the chosen value of eta, which is {eta}, ' 212 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 213 | return sigmas, alphas, alphas_prev 214 | 215 | 216 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 217 | """ 218 | Create a beta schedule that discretizes the given alpha_t_bar function, 219 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 220 | :param num_diffusion_timesteps: the number of betas to produce. 221 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 222 | produces the cumulative product of (1-beta) up to that 223 | part of the diffusion process. 224 | :param max_beta: the maximum beta to use; use values lower than 1 to 225 | prevent singularities. 226 | """ 227 | betas = [] 228 | for i in range(num_diffusion_timesteps): 229 | t1 = i / num_diffusion_timesteps 230 | t2 = (i + 1) / num_diffusion_timesteps 231 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 232 | return np.array(betas) 233 | 234 | 235 | def extract_into_tensor(a, t, x_shape): 236 | b, *_ = t.shape 237 | out = a.gather(-1, t) 238 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 239 | 240 | 241 | def checkpoint(func, inputs, params, flag): 242 | """ 243 | Evaluate a function without caching intermediate activations, allowing for 244 | reduced memory at the expense of extra compute in the backward pass. 245 | :param func: the function to evaluate. 246 | :param inputs: the argument sequence to pass to `func`. 247 | :param params: a sequence of parameters `func` depends on but does not 248 | explicitly take as arguments. 249 | :param flag: if False, disable gradient checkpointing. 250 | """ 251 | if flag: 252 | args = tuple(inputs) + tuple(params) 253 | return CheckpointFunction.apply(func, len(inputs), *args) 254 | else: 255 | return func(*inputs) 256 | 257 | 258 | class CheckpointFunction(torch.autograd.Function): 259 | @staticmethod 260 | def forward(ctx, run_function, length, *args): 261 | ctx.run_function = run_function 262 | ctx.input_tensors = list(args[:length]) 263 | ctx.input_params = list(args[length:]) 264 | 265 | with torch.no_grad(): 266 | output_tensors = ctx.run_function(*ctx.input_tensors) 267 | return output_tensors 268 | 269 | @staticmethod 270 | def backward(ctx, *output_grads): 271 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 272 | with torch.enable_grad(): 273 | # Fixes a bug where the first op in run_function modifies the 274 | # Tensor storage in place, which is not allowed for detach()'d 275 | # Tensors. 
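            # (editorial note) view_as(x) returns new, non-leaf tensors that
            # share storage with the detached inputs, so in-place ops inside
            # run_function hit the views rather than the leaf tensors, while
            # gradients still flow back to ctx.input_tensors.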
276 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 277 | output_tensors = ctx.run_function(*shallow_copies) 278 | input_grads = torch.autograd.grad( 279 | output_tensors, 280 | ctx.input_tensors + ctx.input_params, 281 | output_grads, 282 | allow_unused=True, 283 | ) 284 | del ctx.input_tensors 285 | del ctx.input_params 286 | del output_tensors 287 | return (None, None) + input_grads 288 | 289 | 290 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 291 | """ 292 | Create sinusoidal timestep embeddings. 293 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 294 | These may be fractional. 295 | :param dim: the dimension of the output. 296 | :param max_period: controls the minimum frequency of the embeddings. 297 | :return: an [N x dim] Tensor of positional embeddings. 298 | """ 299 | if not repeat_only: 300 | half = dim // 2 301 | freqs = torch.exp( 302 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 303 | ).to(device=timesteps.device) 304 | args = timesteps[:, None].float() * freqs[None] 305 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 306 | if dim % 2: 307 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 308 | else: 309 | embedding = repeat(timesteps, 'b -> b d', d=dim) 310 | return embedding 311 | 312 | 313 | def zero_module(module): 314 | """ 315 | Zero out the parameters of a module and return it. 316 | """ 317 | for p in module.parameters(): 318 | p.detach().zero_() 319 | return module 320 | 321 | 322 | def scale_module(module, scale): 323 | """ 324 | Scale the parameters of a module and return it. 325 | """ 326 | for p in module.parameters(): 327 | p.detach().mul_(scale) 328 | return module 329 | 330 | 331 | def mean_flat(tensor): 332 | """ 333 | Take the mean over all non-batch dimensions. 334 | """ 335 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 336 | 337 | 338 | class GroupNorm32(nn.GroupNorm): 339 | def forward(self, x): 340 | return super().forward(x.float()).type(x.dtype) 341 | 342 | 343 | def normalization(channels): 344 | """ 345 | Make a standard normalization layer. 346 | :param channels: number of input channels. 347 | :return: an nn.Module for normalization. 348 | """ 349 | return nn.Identity() #GroupNorm32(32, channels) 350 | 351 | 352 | def noise_like(shape, device, repeat=False): 353 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 354 | noise = lambda: torch.randn(shape, device=device) 355 | return repeat_noise() if repeat else noise() 356 | 357 | 358 | def conv_nd(dims, *args, **kwargs): 359 | """ 360 | Create a 1D, 2D, or 3D convolution module. 361 | """ 362 | if dims == 1: 363 | return nn.Conv1d(*args, **kwargs) 364 | elif dims == 2: 365 | return nn.Conv2d(*args, **kwargs) 366 | elif dims == 3: 367 | return nn.Conv3d(*args, **kwargs) 368 | raise ValueError(f"unsupported dimensions: {dims}") 369 | 370 | 371 | def linear(*args, **kwargs): 372 | """ 373 | Create a linear module. 374 | """ 375 | return nn.Linear(*args, **kwargs) 376 | 377 | 378 | def avg_pool_nd(dims, *args, **kwargs): 379 | """ 380 | Create a 1D, 2D, or 3D average pooling module. 
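    For example (editorial): avg_pool_nd(3, kernel_size=(1, 2, 2)) returns an
    nn.AvgPool3d that halves only the spatial dimensions of a 5D input.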
381 | """ 382 | if dims == 1: 383 | return nn.AvgPool1d(*args, **kwargs) 384 | elif dims == 2: 385 | return nn.AvgPool2d(*args, **kwargs) 386 | elif dims == 3: 387 | return nn.AvgPool3d(*args, **kwargs) 388 | raise ValueError(f"unsupported dimensions: {dims}") 389 | -------------------------------------------------------------------------------- /SHADECast/Models/VAE/VariationalAutoEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | import numpy as np 5 | from SHADECast.Blocks.ResBlock3D import ResBlock3D 6 | from utils import sample_from_standard_normal, kl_from_standard_normal 7 | 8 | 9 | class Encoder(nn.Sequential): 10 | def __init__(self, in_dim=1, levels=2, min_ch=64, max_ch=64): 11 | sequence = [] 12 | channels = np.hstack((in_dim, np.arange(1, (levels + 1)) * min_ch)) 13 | channels[channels > max_ch] = max_ch 14 | for i in range(levels): 15 | in_channels = int(channels[i]) 16 | out_channels = int(channels[i + 1]) 17 | res_kernel_size = (3, 3, 3) if i == 0 else (1, 3, 3) 18 | res_block = ResBlock3D( 19 | in_channels, out_channels, 20 | kernel_size=res_kernel_size, 21 | norm_kwargs={"num_groups": 1} 22 | ) 23 | sequence.append(res_block) 24 | downsample = nn.Conv3d(out_channels, out_channels, 25 | kernel_size=(2, 2, 2), stride=(2, 2, 2)) 26 | sequence.append(downsample) 27 | 28 | super().__init__(*sequence) 29 | 30 | 31 | class Decoder(nn.Sequential): 32 | def __init__(self, in_dim=1, levels=2, min_ch=64, max_ch=64): 33 | sequence = [] 34 | channels = np.hstack((in_dim, np.arange(1, (levels + 1)) * min_ch)) 35 | channels[channels > max_ch] = max_ch 36 | for i in reversed(list(range(levels))): 37 | in_channels = int(channels[i + 1]) 38 | out_channels = int(channels[i]) 39 | upsample = nn.ConvTranspose3d(in_channels, in_channels, 40 | kernel_size=(2, 2, 2), stride=(2, 2, 2)) 41 | sequence.append(upsample) 42 | res_kernel_size = (3, 3, 3) if (i == 0) else (1, 3, 3) 43 | res_block = ResBlock3D( 44 | in_channels, out_channels, 45 | kernel_size=res_kernel_size, 46 | norm_kwargs={"num_groups": 1} 47 | ) 48 | sequence.append(res_block) 49 | 50 | super().__init__(*sequence) 51 | 52 | 53 | class VAE(pl.LightningModule): 54 | def __init__(self, 55 | encoder, 56 | decoder, 57 | kl_weight, 58 | encoded_channels, 59 | hidden_width, 60 | opt_patience, 61 | **kwargs): 62 | super().__init__(**kwargs) 63 | self.save_hyperparameters(ignore=['encoder', 'decoder']) 64 | self.encoder = encoder 65 | self.decoder = decoder 66 | self.hidden_width = hidden_width 67 | self.opt_patience = opt_patience 68 | self.to_moments = nn.Conv3d(encoded_channels, 2 * hidden_width, 69 | kernel_size=1) 70 | self.to_decoder = nn.Conv3d(hidden_width, encoded_channels, 71 | kernel_size=1) 72 | 73 | self.log_var = nn.Parameter(torch.zeros(size=())) 74 | self.kl_weight = kl_weight 75 | 76 | def encode(self, x): 77 | h = self.encoder(x) 78 | (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) 79 | return mean, log_var 80 | 81 | def decode(self, z): 82 | z = self.to_decoder(z) 83 | dec = self.decoder(z) 84 | return dec 85 | 86 | def forward(self, x, sample_posterior=True): 87 | (mean, log_var) = self.encode(x) 88 | if sample_posterior: 89 | z = sample_from_standard_normal(mean, log_var) 90 | else: 91 | z = mean 92 | dec = self.decode(z) 93 | return dec, mean, log_var 94 | 95 | def _loss(self, batch): 96 | x = batch 97 | 98 | (y_pred, mean, log_var) = self.forward(x) 99 | 100 | rec_loss = (x - 
y_pred).abs().mean() 101 | kl_loss = kl_from_standard_normal(mean, log_var) 102 | 103 | total_loss = (1 - self.kl_weight) * rec_loss + self.kl_weight * kl_loss 104 | 105 | return total_loss, rec_loss, kl_loss 106 | 107 | def training_step(self, batch, batch_idx): 108 | loss = self._loss(batch)[0] 109 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 110 | self.log('train_loss', loss, **log_params) 111 | return loss 112 | 113 | @torch.no_grad() 114 | def val_test_step(self, batch, batch_idx, split="val"): 115 | (total_loss, rec_loss, kl_loss) = self._loss(batch) 116 | log_params = {"on_step": False, "on_epoch": True, "prog_bar": True, "sync_dist": True} 117 | self.log(f"{split}_loss", total_loss, **log_params) 118 | self.log(f"{split}_rec_loss", rec_loss.mean(), **log_params) 119 | self.log(f"{split}_kl_loss", kl_loss, **log_params) 120 | 121 | def validation_step(self, batch, batch_idx): 122 | self.val_test_step(batch, batch_idx, split="val") 123 | 124 | def test_step(self, batch, batch_idx): 125 | self.val_test_step(batch, batch_idx, split="test") 126 | 127 | def configure_optimizers(self): 128 | optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, 129 | betas=(0.5, 0.9), weight_decay=1e-3) 130 | reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( 131 | optimizer, patience=self.opt_patience, factor=0.25, verbose=True 132 | ) 133 | return { 134 | "optimizer": optimizer, 135 | "lr_scheduler": { 136 | "scheduler": reduce_lr, 137 | "monitor": "val_rec_loss", 138 | "frequency": 1, 139 | }, 140 | } 141 | -------------------------------------------------------------------------------- /SHADECast/Training/Nowcast_training/IrradianceNetTraining_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchinfo import summary 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | from yaml import load, Loader 6 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 7 | # from pytorch_lightning import seed_everything 8 | # from Models.VAE.VariationalAutoEncoder import Encoder, Decoder, VAE 9 | # from Models.Nowcaster.Nowcast import AFNONowcastNet, Nowcaster 10 | from Benchmark.IrradianceNet import ConvLSTM_patch, IrradianceNet 11 | from Dataset.dataset import KIDataset 12 | from utils import save_pkl 13 | 14 | 15 | def get_dataloader(data_path, 16 | coordinate_data_path, 17 | n=12, 18 | min=0.05, 19 | max=1.2, 20 | length=100, 21 | norm_method='rescaling', 22 | num_workers=24, 23 | batch_size=64, 24 | shuffle=True, 25 | validation=False): 26 | dataset = KIDataset(data_path=data_path, 27 | n=n, 28 | min=min, 29 | max=max, 30 | length=length, 31 | norm_method=norm_method, 32 | coordinate_data_path=coordinate_data_path, 33 | return_all=False, 34 | forecast=True, 35 | validation=validation) 36 | dataloader = DataLoader(dataset, 37 | num_workers=num_workers, 38 | batch_size=batch_size, 39 | shuffle=shuffle) 40 | return dataloader 41 | 42 | 43 | def train(config, distributed=True): 44 | if distributed: 45 | num_nodes = int(os.environ['SLURM_NNODES']) 46 | rank = int(os.environ['SLURM_NODEID']) 47 | print(rank, num_nodes) 48 | else: 49 | rank = 0 50 | num_nodes = 1 51 | 52 | ID = config['ID'] 53 | save_pkl(config['Checkpoint']['dirpath'] + ID + '_config.pkl', config) 54 | if rank == 0: 55 | print(config) 56 | 57 | 58 | nowcaster_config = config['Nowcaster'] 59 | if rank == 0: 60 | print(nowcaster_config) 61 | 62 | nowcast_net = ConvLSTM_patch(in_chan=1, 
image_size=128, device='cuda', seq_len=8) 63 | irradiance_net = IrradianceNet(nowcast_net, 64 | opt_patience=nowcaster_config['opt_patience']) 65 | 66 | if rank == 0: 67 | print('All models built') 68 | summary(irradiance_net) 69 | 70 | ckpt_config = config['Checkpoint'] 71 | checkpoint_callback = ModelCheckpoint( 72 | monitor=ckpt_config['monitor'], 73 | dirpath=ckpt_config['dirpath'], 74 | filename=ID + '_' + ckpt_config['filename'], 75 | save_top_k=ckpt_config['save_top_k'], 76 | every_n_epochs=ckpt_config['every_n_epochs'] 77 | ) 78 | 79 | early_stop_callback = EarlyStopping(monitor=ckpt_config['monitor'], 80 | patience=config['EarlyStopping']['patience']) 81 | 82 | tr_config = config['Trainer'] 83 | trainer = pl.Trainer( 84 | default_root_dir=ckpt_config['dirpath'], 85 | accelerator=tr_config['accelerator'], 86 | devices=tr_config['devices'], 87 | num_nodes=num_nodes, 88 | max_epochs=tr_config['max_epochs'], 89 | callbacks=[checkpoint_callback, early_stop_callback], 90 | strategy=tr_config['strategy'], 91 | precision=tr_config['precision'], 92 | enable_progress_bar=(rank == 0), 93 | accumulate_grad_batches=tr_config['accumulate_grad_batches'] 94 | # deterministic=False 95 | ) 96 | if rank == 0: 97 | print('Trainer built') 98 | data_config = config['Dataset'] 99 | train_dataloader = get_dataloader(data_path=data_config['data_path'] + 'TrainingSet/KI/', 100 | coordinate_data_path=data_config['data_path'] + 'CoordinateData/', 101 | n=data_config['n_in']+data_config['n_out'], 102 | min=data_config['min'], 103 | max=data_config['max'], 104 | length=data_config['train_length'], 105 | num_workers=data_config['num_workers'], 106 | norm_method=data_config['norm_method'], 107 | batch_size=data_config['batch_size'], 108 | shuffle=True, 109 | validation=False) 110 | 111 | val_dataloader = get_dataloader(data_path=data_config['data_path'] + 'ValidationSet/KI/', 112 | coordinate_data_path=data_config['data_path'] + 'CoordinateData/', 113 | n=data_config['n_in']+data_config['n_out'], 114 | min=data_config['min'], 115 | max=data_config['max'], 116 | length=data_config['val_length'], 117 | num_workers=data_config['num_workers'], 118 | norm_method=data_config['norm_method'], 119 | batch_size=data_config['batch_size'], 120 | shuffle=False, 121 | validation=True) 122 | if rank == 0: 123 | print('Training started') 124 | resume_training = tr_config['resume_training'] 125 | if resume_training is None: 126 | trainer.fit(irradiance_net, train_dataloader, val_dataloader) 127 | else: 128 | trainer.fit(irradiance_net, train_dataloader, val_dataloader, 129 | ckpt_path=resume_training) 130 | 131 | 132 | if __name__ == '__main__': 133 | with open('/scratch/snx3000/acarpent/GenerativeNowcasting/SHADECast/Training/Nowcast_training/IrradianceNettrainingconf.yml', 134 | 'r') as o: 135 | config = load(o, Loader) 136 | 137 | # seed_everything(0, workers=0) 138 | 139 | train(config) 140 | -------------------------------------------------------------------------------- /SHADECast/Training/Nowcast_training/IrradianceNettrainingconf.yml: -------------------------------------------------------------------------------- 1 | ID: 'IN2' 2 | Dataset: 3 | data_path: '/scratch/snx3000/acarpent/HelioMontDataset/' 4 | n_in: 4 5 | n_out: 8 6 | batch_size: 10 7 | num_workers: 24 8 | train_length: 9 | val_length: 10 | norm_method: 'rescaling' 11 | min: 0.05 12 | max: 1.2 13 | mean: 0.6 14 | std: 0.3 15 | 16 | 17 | Nowcaster: 18 | opt_patience: 5 19 | 20 | EarlyStopping: 21 | patience: 10 22 | 23 | Checkpoint: 24 | dirpath: 
'/scratch/snx3000/acarpent/Logs/SHADECast/NowcasterTrainingLogs/' 25 | filename: '{epoch}-{val_loss:.5f}' 26 | monitor: 'val_loss' 27 | every_n_epochs: 1 28 | save_top_k: 3 29 | 30 | Trainer: 31 | accelerator: 'gpu' 32 | precision: 16 33 | devices: 1 34 | max_epochs: 1000 35 | strategy: 'ddp' 36 | accumulate_grad_batches: 1 37 | resume_training: '/scratch/snx3000/acarpent/Logs/SHADECast/NowcasterTrainingLogs/IN2_epoch=66-val_loss=0.08655.ckpt' 38 | -------------------------------------------------------------------------------- /SHADECast/Training/Nowcast_training/NowcasterTraining_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchinfo import summary 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | from yaml import load, Loader 6 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 7 | # from pytorch_lightning import seed_everything 8 | from Models.VAE.VariationalAutoEncoder import Encoder, Decoder, VAE 9 | from Models.Nowcaster.Nowcast import AFNONowcastNet, Nowcaster 10 | from Dataset.dataset import KIDataset 11 | from utils import save_pkl 12 | 13 | 14 | def get_dataloader(data_path, 15 | coordinate_data_path, 16 | n=12, 17 | min=0.05, 18 | max=1.2, 19 | length=None, 20 | norm_method='rescaling', 21 | num_workers=24, 22 | batch_size=64, 23 | shuffle=True, 24 | validation=False): 25 | dataset = KIDataset(data_path=data_path, 26 | n=n, 27 | min=min, 28 | max=max, 29 | length=length, 30 | norm_method=norm_method, 31 | coordinate_data_path=coordinate_data_path, 32 | return_all=False, 33 | forecast=True, 34 | validation=validation) 35 | dataloader = DataLoader(dataset, 36 | num_workers=num_workers, 37 | batch_size=batch_size, 38 | shuffle=shuffle) 39 | return dataloader 40 | 41 | 42 | 43 | def train(config, distributed=True): 44 | if distributed: 45 | num_nodes = int(os.environ['SLURM_NNODES']) 46 | rank = int(os.environ['SLURM_NODEID']) 47 | print(rank, num_nodes) 48 | else: 49 | rank = 0 50 | num_nodes = 1 51 | 52 | ID = config['ID'] 53 | save_pkl(config['Checkpoint']['dirpath'] + ID + '_config.pkl', config) 54 | if rank == 0: 55 | print(config) 56 | 57 | encoder_config = config['Encoder'] 58 | encoder = Encoder(in_dim=encoder_config['in_dim'], 59 | levels=encoder_config['levels'], 60 | min_ch=encoder_config['min_ch'], 61 | max_ch=encoder_config['max_ch']) 62 | if rank == 0: 63 | print('Encoder built') 64 | 65 | decoder_config = config['Decoder'] 66 | decoder = Decoder(in_dim=decoder_config['in_dim'], 67 | levels=decoder_config['levels'], 68 | min_ch=decoder_config['min_ch'], 69 | max_ch=decoder_config['max_ch']) 70 | if rank == 0: 71 | print('Decoder built') 72 | 73 | vae_config = config['VAE'] 74 | if vae_config['path'] is not None: 75 | vae = VAE.load_from_checkpoint(vae_config['path'], 76 | encoder=encoder, decoder=decoder) 77 | train_autoencoder = False 78 | else: 79 | vae = VAE(encoder, 80 | decoder, 81 | kl_weight=vae_config['kl_weight'], 82 | encoded_channels=encoder_config['max_ch'], 83 | hidden_width=vae_config['hidden_width']) 84 | train_autoencoder = True 85 | if rank == 0: 86 | print('VAE built') 87 | 88 | nowcaster_config = config['Nowcaster'] 89 | if rank == 0: 90 | print(nowcaster_config) 91 | 92 | nowcast_net = AFNONowcastNet(vae, 93 | train_autoenc=train_autoencoder, 94 | embed_dim=nowcaster_config['embed_dim'], 95 | embed_dim_out=nowcaster_config['embed_dim'], 96 | analysis_depth=nowcaster_config['analysis_depth'], 97 | 
forecast_depth=nowcaster_config['forecast_depth'], 98 | input_steps=nowcaster_config['input_steps'], 99 | output_steps=nowcaster_config['output_steps']) 100 | nowcaster = Nowcaster(nowcast_net=nowcast_net, 101 | opt_patience=nowcaster_config['opt_patience'], 102 | loss_type=nowcaster_config['loss_type']) 103 | 104 | if rank == 0: 105 | print('All models built') 106 | summary(nowcaster) 107 | 108 | ckpt_config = config['Checkpoint'] 109 | checkpoint_callback = ModelCheckpoint( 110 | monitor=ckpt_config['monitor'], 111 | dirpath=ckpt_config['dirpath'], 112 | filename=ID + '_' + ckpt_config['filename'], 113 | save_top_k=ckpt_config['save_top_k'], 114 | every_n_epochs=ckpt_config['every_n_epochs'] 115 | ) 116 | 117 | early_stop_callback = EarlyStopping(monitor=ckpt_config['monitor'], 118 | patience=config['EarlyStopping']['patience']) 119 | 120 | tr_config = config['Trainer'] 121 | trainer = pl.Trainer( 122 | default_root_dir=ckpt_config['dirpath'], 123 | accelerator=tr_config['accelerator'], 124 | devices=tr_config['devices'], 125 | num_nodes=num_nodes, 126 | max_epochs=tr_config['max_epochs'], 127 | callbacks=[checkpoint_callback, early_stop_callback], 128 | strategy=tr_config['strategy'], 129 | precision=tr_config['precision'], 130 | enable_progress_bar=(rank == 0), 131 | accumulate_grad_batches=tr_config['accumulate_grad_batches'] 132 | # deterministic=False 133 | ) 134 | if rank == 0: 135 | print('Trainer built') 136 | data_config = config['Dataset'] 137 | train_dataloader = get_dataloader(data_path=data_config['data_path'] + 'TrainingSet/KI/', 138 | coordinate_data_path=data_config['data_path'] + 'CoordinateData/', 139 | n=data_config['n_in'] + data_config['n_out'], 140 | min=data_config['min'], 141 | max=data_config['max'], 142 | length=data_config['train_length'], 143 | num_workers=data_config['num_workers'], 144 | norm_method=data_config['norm_method'], 145 | batch_size=data_config['batch_size'], 146 | shuffle=True, 147 | validation=False) 148 | 149 | val_dataloader = get_dataloader(data_path=data_config['data_path'] + 'ValidationSet/KI/', 150 | coordinate_data_path=data_config['data_path'] + 'CoordinateData/', 151 | n=data_config['n_in'] + data_config['n_out'], 152 | min=data_config['min'], 153 | max=data_config['max'], 154 | length=data_config['val_length'], 155 | num_workers=data_config['num_workers'], 156 | norm_method=data_config['norm_method'], 157 | batch_size=data_config['batch_size'], 158 | shuffle=False, 159 | validation=True) 160 | if rank == 0: 161 | print('Training started') 162 | resume_training = tr_config['resume_training'] 163 | if resume_training is None: 164 | trainer.fit(nowcaster, train_dataloader, val_dataloader) 165 | else: 166 | trainer.fit(nowcaster, train_dataloader, val_dataloader, 167 | ckpt_path=resume_training) 168 | 169 | 170 | if __name__ == '__main__': 171 | with open('Training/Nowcast_training/Nowcastertrainingconf.yml', 172 | 'r') as o: 173 | config = load(o, Loader) 174 | 175 | # seed_everything(0, workers=0) 176 | 177 | train(config) 178 | -------------------------------------------------------------------------------- /SHADECast/Training/Nowcast_training/Nowcastertrainingconf.yml: -------------------------------------------------------------------------------- 1 | ID: 'UN3-VAE3' 2 | Dataset: 3 | data_path: '/scratch/snx3000/acarpent/HelioMontDataset/' 4 | n_in: 4 5 | n_out: 8 6 | batch_size: 20 7 | num_workers: 24 8 | train_length: 9 | val_length: 10 | norm_method: 'rescaling' 11 | min: 0.05 12 | max: 1.2 13 | mean: 0.6 14 | std: 0.3 15 | 16 | 
Encoder: 17 | in_dim: 1 18 | levels: 2 19 | min_ch: 64 20 | max_ch: 128 21 | 22 | Decoder: # not used at the moment (symmetrical architecture) 23 | in_dim: 1 24 | levels: 2 25 | min_ch: 64 26 | max_ch: 128 27 | 28 | VAE: 29 | kl_weight: 0.01 30 | hidden_width: 32 # arbitrary value 31 | path: '/scratch/snx3000/acarpent/VAETrainingLogs/VAE3-CF=2-l=12_epoch=110-val_rec_loss=0.01699-val_kl_loss=1.88702.ckpt' 32 | 33 | Nowcaster: 34 | embed_dim: 256 35 | forecast_depth: 6 36 | analysis_depth: 6 37 | input_steps: 1 38 | output_steps: 4 39 | opt_patience: 5 40 | loss_type: 'latent' 41 | 42 | EarlyStopping: 43 | patience: 10 44 | 45 | Checkpoint: 46 | dirpath: '/scratch/snx3000/acarpent/NowcasterTrainingLogs/' 47 | filename: '{epoch}-{val_loss:.5f}' 48 | monitor: 'val_loss' 49 | every_n_epochs: 1 50 | save_top_k: 3 51 | 52 | Trainer: 53 | accelerator: 'gpu' 54 | precision: 16 55 | devices: 1 56 | max_epochs: 1000 57 | strategy: 'ddp' 58 | accumulate_grad_batches: 1 59 | resume_training: 60 | -------------------------------------------------------------------------------- /SHADECast/Training/SHADECastTraining.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from torchinfo import summary 6 | import pytorch_lightning as pl 7 | from pytorch_lightning import seed_everything 8 | from torch.utils.data import DataLoader 9 | from yaml import load, Loader 10 | from torchinfo import summary 11 | import os 12 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 13 | 14 | from utils import save_pkl 15 | from Dataset.dataset import KIDataset 16 | from Models.Nowcaster.Nowcast import AFNONowcastNetCascade, Nowcaster, AFNONowcastNet 17 | from Models.VAE.VariationalAutoEncoder import VAE, Encoder, Decoder 18 | from Models.UNet.UNet import UNetModel 19 | from Models.Diffusion.DiffusionModel import LatentDiffusion 20 | 21 | 22 | def get_dataloader(data_path, 23 | coordinate_data_path, 24 | n=12, 25 | min=0.05, 26 | max=1.2, 27 | length=None, 28 | norm_method='rescaling', 29 | num_workers=24, 30 | batch_size=64, 31 | shuffle=True, 32 | validation=False, 33 | return_t=False): 34 | dataset = KIDataset(data_path=data_path, 35 | n=n, 36 | min=min, 37 | max=max, 38 | length=length, 39 | norm_method=norm_method, 40 | coordinate_data_path=coordinate_data_path, 41 | return_all=False, 42 | forecast=True, 43 | validation=validation, 44 | return_t=return_t) 45 | dataloader = DataLoader(dataset, 46 | num_workers=num_workers, 47 | batch_size=batch_size, 48 | shuffle=shuffle) 49 | return dataloader 50 | 51 | 52 | def train(config, distributed=True): 53 | if distributed: 54 | num_nodes = int(os.environ['SLURM_NNODES']) 55 | rank = int(os.environ['SLURM_NODEID']) 56 | print(rank, num_nodes) 57 | else: 58 | rank = 0 59 | num_nodes = 1 60 | 61 | ID = config['ID'] 62 | save_pkl(config['Checkpoint']['dirpath'] + ID + '_config.pkl', config) 63 | if rank == 0: 64 | print(config) 65 | 66 | encoder_config = config['Encoder'] 67 | encoder = Encoder(in_dim=encoder_config['in_dim'], 68 | levels=encoder_config['levels'], 69 | min_ch=encoder_config['min_ch'], 70 | max_ch=encoder_config['max_ch']) 71 | if rank == 0: 72 | print('Encoder built') 73 | 74 | decoder_config = config['Decoder'] 75 | decoder = Decoder(in_dim=decoder_config['in_dim'], 76 | levels=decoder_config['levels'], 77 | min_ch=decoder_config['min_ch'], 78 | max_ch=decoder_config['max_ch']) 79 | if rank == 
0: 80 | print('Decoder built') 81 | 82 | vae_config = config['VAE'] 83 | vae = VAE.load_from_checkpoint(vae_config['path'], 84 | encoder=encoder, decoder=decoder, 85 | opt_patience=vae_config['opt_patience']) 86 | if rank == 0: 87 | print('VAE built') 88 | 89 | nowcaster_config = config['Nowcaster'] 90 | print(nowcaster_config['path']) 91 | if nowcaster_config['path'] is None: 92 | nowcast_net = AFNONowcastNet(vae, 93 | train_autoenc=False, 94 | embed_dim=nowcaster_config['embed_dim'], 95 | embed_dim_out=nowcaster_config['embed_dim'], 96 | analysis_depth=nowcaster_config['analysis_depth'], 97 | forecast_depth=nowcaster_config['forecast_depth'], 98 | input_steps=nowcaster_config['input_steps'], 99 | output_steps=nowcaster_config['output_steps'], 100 | # opt_patience=nowcaster_config['opt_patience'], 101 | # loss_type=nowcaster_config['loss_type'] 102 | ) 103 | train_nowcast = True 104 | else: 105 | nowcast_net = AFNONowcastNet(vae, 106 | train_autoenc=False, 107 | embed_dim=nowcaster_config['embed_dim'], 108 | embed_dim_out=nowcaster_config['embed_dim'], 109 | analysis_depth=nowcaster_config['analysis_depth'], 110 | forecast_depth=nowcaster_config['forecast_depth'], 111 | input_steps=nowcaster_config['input_steps'], 112 | output_steps=nowcaster_config['output_steps'], 113 | # opt_patience=nowcaster_config['opt_patience'], 114 | # loss_type=nowcaster_config['loss_type'] 115 | ) 116 | nowcaster = Nowcaster.load_from_checkpoint(nowcaster_config['path'], nowcast_net=nowcast_net, 117 | opt_patience=nowcaster_config['opt_patience'], 118 | loss_type=nowcaster_config['loss_type']) 119 | nowcast_net = nowcaster.nowcast_net 120 | train_nowcast = False 121 | 122 | print('Nowcaster built, train: ', nowcaster_config['path']) 123 | cascade_net = AFNONowcastNetCascade(nowcast_net=nowcast_net, 124 | cascade_depth=nowcaster_config['cascade_depth'], 125 | train_net=train_nowcast) 126 | if rank == 0: 127 | summary(nowcast_net) 128 | # if nowcaster_config['path'] is not None: 129 | # nowcaster = Nowcaster.load_from_checkpoint(nowcaster_config['path'], 130 | # nowcast_net=nowcast_net, 131 | # autoencoder=vae) 132 | if rank == 0: 133 | print('Nowcaster built') 134 | 135 | diffusion_config = config['Diffusion'] 136 | denoiser = UNetModel( 137 | in_channels=vae.hidden_width, 138 | model_channels=diffusion_config['model_channels'], 139 | out_channels=vae.hidden_width, 140 | num_res_blocks=diffusion_config['num_res_blocks'], 141 | attention_resolutions=diffusion_config['attention_resolutions'], 142 | dims=diffusion_config['dims'], 143 | channel_mult=diffusion_config['channel_mult'], 144 | num_heads=8, 145 | num_timesteps=2, 146 | context_ch=cascade_net.cascade_dims) 147 | 148 | ldm = LatentDiffusion(model=denoiser, 149 | autoencoder=vae, 150 | context_encoder=cascade_net, 151 | beta_schedule=diffusion_config['scheduler'], 152 | loss_type="l2", 153 | use_ema=diffusion_config['use_ema'], 154 | lr_warmup=0, 155 | linear_start=1e-4, 156 | linear_end=2e-2, 157 | cosine_s=8e-3, 158 | parameterization='eps', 159 | lr=diffusion_config['lr'], 160 | timesteps=diffusion_config['noise_steps'], 161 | opt_patience=diffusion_config['opt_patience'], 162 | get_t=config['Dataset']['get_t'], 163 | ) 164 | if rank == 0: 165 | print('All models built') 166 | summary(ldm) 167 | 168 | ckpt_config = config['Checkpoint'] 169 | checkpoint_callback = ModelCheckpoint( 170 | monitor=ckpt_config['monitor'], 171 | dirpath=ckpt_config['dirpath'], 172 | filename=ID + '_' + ckpt_config['filename'], 173 | save_top_k=ckpt_config['save_top_k'], 174 | 
every_n_epochs=ckpt_config['every_n_epochs'] 175 | ) 176 | 177 | early_stop_callback = EarlyStopping(monitor=ckpt_config['monitor'], 178 | patience=config['EarlyStopping']['patience']) 179 | 180 | tr_config = config['Trainer'] 181 | trainer = pl.Trainer( 182 | default_root_dir=ckpt_config['dirpath'], 183 | accelerator=tr_config['accelerator'], 184 | devices=tr_config['devices'], 185 | num_nodes=num_nodes, 186 | max_epochs=tr_config['max_epochs'], 187 | callbacks=[checkpoint_callback, early_stop_callback], 188 | strategy=tr_config['strategy'], 189 | precision=tr_config['precision'], 190 | enable_progress_bar=(rank == 0), 191 | deterministic=False, 192 | accumulate_grad_batches=tr_config['accumulate_grad_batches'] 193 | ) 194 | if rank == 0: 195 | print('Trainer built') 196 | data_config = config['Dataset'] 197 | train_dataloader = get_dataloader(data_path=data_config['data_path'] + 'TrainingSet/KI/', 198 | coordinate_data_path=data_config['data_path'] + 'CoordinateData/', 199 | n=data_config['n_in'] + data_config['n_out'], 200 | min=data_config['min'], 201 | max=data_config['max'], 202 | length=data_config['train_length'], 203 | num_workers=data_config['num_workers'], 204 | norm_method=data_config['norm_method'], 205 | batch_size=data_config['batch_size'], 206 | shuffle=True, 207 | validation=False) 208 | 209 | val_dataloader = get_dataloader(data_path=data_config['data_path'] + 'ValidationSet/KI/', 210 | coordinate_data_path=data_config['data_path'] + 'CoordinateData/', 211 | n=data_config['n_in'] + data_config['n_out'], 212 | min=data_config['min'], 213 | max=data_config['max'], 214 | length=data_config['val_length'], 215 | num_workers=data_config['num_workers'], 216 | norm_method=data_config['norm_method'], 217 | batch_size=data_config['batch_size'], 218 | shuffle=False, 219 | validation=True, 220 | return_t=data_config['get_t']) 221 | if rank == 0: 222 | print('Training started') 223 | resume_training = tr_config['resume_training'] 224 | torch.cuda.empty_cache() 225 | if resume_training is None: 226 | trainer.fit(ldm, train_dataloader, val_dataloader) 227 | else: 228 | # if tr_config['resume_training'] is False: 229 | trainer.fit(ldm, train_dataloader, val_dataloader, 230 | ckpt_path=resume_training) 231 | # else: 232 | 233 | 234 | 235 | if __name__ == '__main__': 236 | with open('SHADECastTrainingconf.yml', 237 | 'r') as o: 238 | config = load(o, Loader) 239 | seed = config['seed'] 240 | if seed is not None: 241 | seed_everything(int(seed), workers=True) 242 | 243 | train(config) 244 | -------------------------------------------------------------------------------- /SHADECast/Training/SHADECastTrainingconf.yml: -------------------------------------------------------------------------------- 1 | ID: 'SHADECast' # insert your model ID 2 | seed: 0 3 | Dataset: 4 | data_path: 'path to your dataset' 5 | n_in: 4 6 | n_out: 8 7 | batch_size: 4 8 | num_workers: 24 9 | train_length: 100 10 | val_length: 100 11 | norm_method: 'rescaling' 12 | min: 0.05 13 | max: 1.2 14 | mean: 0.6 15 | std: 0.3 16 | get_t: True 17 | 18 | Encoder: 19 | in_dim: 1 20 | levels: 2 21 | min_ch: 64 22 | max_ch: 128 23 | 24 | Decoder: # not used at the moment (symmetrical architecture) 25 | in_dim: 1 26 | levels: 2 27 | min_ch: 64 28 | max_ch: 128 29 | 30 | VAE: 31 | path: 'path to your pre trained autoencoder' 32 | hidden_width: 32 33 | opt_patience: 5 34 | 35 | Nowcaster: 36 | embed_dim: 256 37 | forecast_depth: 4 38 | analysis_depth: 4 39 | input_steps: 1 40 | output_steps: 2 41 | opt_patience: 5 42 | loss_type: 
'latent' 43 | cascade_depth: 3 44 | path: 'path to your pre trained nowcaster otherwise None will train the nowcaster with the diffusion model' 45 | 46 | Diffusion: 47 | model_channels: 256 48 | lr: 0.0001 49 | noise_steps: 1000 50 | scheduler: 'linear' 51 | use_ema: True 52 | opt_patience: 5 53 | num_res_blocks: 2 54 | attention_resolutions: [1, 2] 55 | dims: 3 56 | channel_mult: [1, 2, 2] 57 | 58 | EarlyStopping: 59 | patience: 10 60 | 61 | Checkpoint: 62 | dirpath: 63 | filename: 64 | monitor: 'val_loss_ema' 65 | every_n_epochs: 1 66 | save_top_k: 3 67 | 68 | Trainer: 69 | accelerator: 'gpu' 70 | precision: 16 71 | devices: 1 72 | max_epochs: 1000 73 | strategy: 'ddp' 74 | accumulate_grad_batches: 1 75 | resume_training: -------------------------------------------------------------------------------- /SHADECast/Training/VAE_training/VAETraining_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchinfo import summary 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | from yaml import load, Loader 6 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 7 | # from pytorch_lightning import seed_everything 8 | from Models.VAE.VariationalAutoEncoder import Encoder, Decoder, VAE 9 | from Dataset.dataset import KIDataset 10 | from utils import save_pkl 11 | 12 | 13 | def get_dataloader(data_path, 14 | coordinate_data_path, 15 | n=12, 16 | min=0.05, 17 | max=1.2, 18 | length=None, 19 | norm_method='rescaling', 20 | num_workers=24, 21 | batch_size=64, 22 | shuffle=True): 23 | dataset = KIDataset(data_path=data_path, 24 | n=n, 25 | min=min, 26 | max=max, 27 | length=length, 28 | norm_method=norm_method, 29 | coordinate_data_path=coordinate_data_path, 30 | return_all=False) 31 | dataloader = DataLoader(dataset, 32 | num_workers=num_workers, 33 | batch_size=batch_size, 34 | shuffle=shuffle) 35 | return dataloader 36 | 37 | 38 | def train(config): 39 | num_nodes = int(os.environ['SLURM_NNODES']) 40 | rank = int(os.environ['SLURM_NODEID']) 41 | print(rank, num_nodes) 42 | 43 | ID = config['ID'] 44 | save_pkl(config['Checkpoint']['dirpath'] + ID + '_config.pkl', config), 45 | print(config) 46 | 47 | encoder_config = config['Encoder'] 48 | encoder = Encoder(in_dim=encoder_config['in_dim'], 49 | levels=encoder_config['levels'], 50 | min_ch=encoder_config['min_ch'], 51 | max_ch=encoder_config['max_ch']) 52 | print('Encoder built') 53 | 54 | decoder_config = config['Decoder'] 55 | decoder = Decoder(in_dim=decoder_config['in_dim'], 56 | levels=decoder_config['levels'], 57 | min_ch=decoder_config['min_ch'], 58 | max_ch=decoder_config['max_ch']) 59 | print('Decoder built') 60 | 61 | vae_config = config['VAE'] 62 | vae = VAE(encoder, 63 | decoder, 64 | kl_weight=vae_config['kl_weight'], 65 | encoded_channels=encoder_config['max_ch'], 66 | hidden_width=vae_config['hidden_width'], 67 | opt_patience=vae_config['opt_patience']) 68 | print('All models built') 69 | 70 | batch_size = config['Dataset']['batch_size'] 71 | n_steps = config['Dataset']['n_steps'] 72 | if rank == 0: 73 | summary(vae, input_size=(batch_size, 1, n_steps, 128, 128)) 74 | 75 | ckpt_config = config['Checkpoint'] 76 | checkpoint_callback = ModelCheckpoint( 77 | monitor=ckpt_config['monitor'], 78 | dirpath=ckpt_config['dirpath'], 79 | filename=ID + '_' + ckpt_config['filename'], 80 | save_top_k=ckpt_config['save_top_k'], 81 | every_n_epochs=ckpt_config['every_n_epochs'] 82 | ) 83 | 84 | early_stop_callback = 
EarlyStopping(monitor=ckpt_config['monitor'], 85 | patience=config['EarlyStopping']['patience']) 86 | 87 | tr_config = config['Trainer'] 88 | trainer = pl.Trainer( 89 | default_root_dir=ckpt_config['dirpath'], 90 | accelerator=tr_config['accelerator'], 91 | devices=tr_config['devices'], 92 | num_nodes=num_nodes, 93 | max_epochs=tr_config['max_epochs'], 94 | callbacks=[checkpoint_callback, early_stop_callback], 95 | strategy=tr_config['strategy'], 96 | precision=tr_config['precision'], 97 | enable_progress_bar=(rank == 0), 98 | deterministic=True 99 | ) 100 | 101 | data_config = config['Dataset'] 102 | train_dataloader = get_dataloader(data_path=data_config['data_path'] + 'TrainingSet/KI/', 103 | coordinate_data_path=data_config['data_path']+'CoordinateData/', 104 | n=data_config['n_steps'], 105 | min=data_config['min'], 106 | max=data_config['max'], 107 | length=data_config['train_length'], 108 | num_workers=data_config['num_workers'], 109 | norm_method=data_config['norm_method'], 110 | batch_size=data_config['batch_size'], 111 | shuffle=True) 112 | 113 | val_dataloader = get_dataloader(data_path=data_config['data_path'] + 'ValidationSet/KI/', 114 | coordinate_data_path=data_config['data_path'] + 'CoordinateData/', 115 | n=data_config['n_steps'], 116 | min=data_config['min'], 117 | max=data_config['max'], 118 | length=data_config['val_length'], 119 | num_workers=data_config['num_workers'], 120 | norm_method=data_config['norm_method'], 121 | batch_size=data_config['batch_size'], 122 | shuffle=False) 123 | print('Training started') 124 | 125 | resume_training = tr_config['resume_training'] 126 | if resume_training is None: 127 | trainer.fit(vae, train_dataloader, val_dataloader) 128 | else: 129 | trainer.fit(vae, train_dataloader, val_dataloader, 130 | ckpt_path=resume_training) 131 | 132 | 133 | if __name__ == '__main__': 134 | with open('VAEtrainingconf.yml', 'r') as o: 135 | config = load(o, Loader) 136 | 137 | # seed_everything(0, workers=0) 138 | train(config) 139 | -------------------------------------------------------------------------------- /SHADECast/Training/VAE_training/VAEtrainingconf.yml: -------------------------------------------------------------------------------- 1 | ID: 'VAE7-CF=2-l=12' 2 | Dataset: 3 | data_path: '/scratch/snx3000/acarpent/HelioMontDataset/' 4 | n_steps: 12 5 | batch_size: 32 6 | num_workers: 24 7 | train_length: 8 | val_length: 9 | norm_method: 'rescaling' 10 | min: 0.05 11 | max: 1.2 12 | mean: 0.6 13 | std: 0.3 14 | 15 | Encoder: 16 | in_dim: 1 17 | levels: 2 18 | min_ch: 64 19 | max_ch: 128 20 | 21 | Decoder: # not used at the moment (symmetrical architecture) 22 | in_dim: 1 23 | levels: 2 24 | min_ch: 64 25 | max_ch: 128 26 | 27 | VAE: 28 | kl_weight: 0.01 29 | hidden_width: 32 # arbitrary value 30 | opt_patience: 5 31 | 32 | EarlyStopping: 33 | patience: 10 34 | 35 | Checkpoint: 36 | dirpath: '/scratch/snx3000/acarpent/VAETrainingLogs' 37 | filename: '{epoch}-{val_rec_loss:.5f}-{val_kl_loss:.5f}' 38 | monitor: 'val_rec_loss' 39 | every_n_epochs: 1 40 | save_top_k: 3 41 | 42 | Trainer: 43 | accelerator: 'gpu' 44 | devices: 1 45 | max_epochs: 1000 46 | strategy: 'ddp' 47 | precision: 16 48 | resume_training: 49 | -------------------------------------------------------------------------------- /Test/Test_IrrNet.py: -------------------------------------------------------------------------------- 1 | from validation_utils import get_diffusion_model 2 | from Benchmark.IrradianceNet import ConvLSTM_patch, IrradianceNet 3 | import numpy as np 4 | from 
utils import save_pkl, open_pkl, get_full_images, get_full_coordinates, remap 5 | import torch 6 | import os 7 | import sys 8 | from yaml import load, Loader 9 | 10 | print(sys.argv) 11 | start = int(sys.argv[1]) 12 | end = int(sys.argv[2]) 13 | test_name = sys.argv[3] 14 | model_config_path = sys.argv[4] 15 | model_path = sys.argv[5] 16 | 17 | def interpolate_yhat(yhat): 18 | yhat = yhat.detach() 19 | yhat[:, 125:131] = np.nan 20 | yhat[:, :, 125:131] = np.nan 21 | 22 | # rows 23 | for t in range(yhat.shape[0]): 24 | row_start_vals = yhat[t][124] 25 | row_end_vals = yhat[t][131] 26 | diff_interpolate = (row_start_vals - row_end_vals) / 7 27 | diff_interpolate = diff_interpolate.unsqueeze(0) # .repeat(6, 1) 28 | diff_interpolate = diff_interpolate.repeat(6, 1) 29 | vals = np.arange(1, 7) 30 | vals = vals[np.newaxis, :] 31 | vals = np.repeat(vals, diff_interpolate.shape[1], axis=0) 32 | 33 | interpol_values = diff_interpolate.detach() * vals.T 34 | interpol_values = row_start_vals.unsqueeze(0).repeat(6, 1) - interpol_values 35 | yhat[t, 125:131] = interpol_values 36 | 37 | col_start_vals = yhat[t][:, 124] 38 | col_end_vals = yhat[t][:, 131] 39 | diff_interpolate = (col_start_vals - col_end_vals) / 7 40 | diff_interpolate = diff_interpolate.unsqueeze(0) # .repeat(6, 1) 41 | diff_interpolate = diff_interpolate.repeat(6, 1) 42 | vals = np.arange(1, 7) 43 | vals = vals[np.newaxis, :] 44 | vals = np.repeat(vals, diff_interpolate.shape[1], axis=0) 45 | 46 | interpol_values = diff_interpolate.detach() * vals.T 47 | interpol_values = col_start_vals.unsqueeze(0).repeat(6, 1) - interpol_values 48 | yhat[t, :, 125:131] = interpol_values.T 49 | return yhat 50 | 51 | def main(): 52 | nowcast_net = ConvLSTM_patch(in_chan=1, image_size=128, device='cuda', seq_len=8) 53 | irradiance_net = IrradianceNet(nowcast_net, 54 | opt_patience=5).to('cuda') 55 | 56 | checkpoint = torch.load(model_path) 57 | irradiance_net.load_state_dict(checkpoint['state_dict']) 58 | model_config = open_pkl(model_config_path) 59 | model_id = model_config['ID'] 60 | 61 | with open('/scratch/snx3000/acarpent/Test_Results/{}/config.yml'.format(test_name), 62 | 'r') as o: 63 | test_config = load(o, Loader) 64 | 65 | data_path='/scratch/snx3000/acarpent/HelioMontDataset/{}/KI/'.format(test_config['dataset_name']) 66 | n_ens = test_config['n_ens'] 67 | ddim_steps = test_config['ddim_steps'] 68 | x_max = test_config['x_max'] 69 | x_min = test_config['x_min'] 70 | y_max = test_config['y_max'] 71 | y_min = test_config['y_min'] 72 | patches_idx = test_config['patches_idx'] 73 | 74 | date_idx_dict = open_pkl('/scratch/snx3000/acarpent/Test_Results/{}/Test_date_idx.pkl'.format(test_name)) 75 | test_days = list(date_idx_dict.keys()) 76 | print(test_days) 77 | forecast_dict = {} 78 | 79 | for date in test_days[start:end]: 80 | full_maps, idx_lst, t = get_full_images(date, data_path=data_path, patches_idx=patches_idx) 81 | # idx = np.random.choice(list(idx_lst), replace=False, size=n_per_day) 82 | idx = date_idx_dict[date] 83 | print(date) 84 | for i in idx: 85 | x = torch.Tensor(full_maps[i:i+4, y_min:y_max, x_min:x_max]) 86 | y = torch.Tensor(full_maps[i+4:i+12, y_min:y_max, x_min:x_max]) 87 | x,y = x.reshape(1,1,*x.shape).to('cuda'), y.reshape(1,1,*y.shape).to('cuda') 88 | 89 | yhat = torch.zeros((8, 256, 256)) 90 | for x_i in [0, 128]: 91 | for y_i in [0, 128]: 92 | yhat[:, x_i:x_i+128, y_i:y_i+128] = irradiance_net(x[:, :, :, x_i:x_i+128, y_i:y_i+128]).detach() 93 | 94 | yhat = interpolate_yhat(yhat) 95 | yhat[yhat<-1] = -1 96 | yhat[yhat>1] 
= 1 97 | forecast_dict[t[i]] = np.array(yhat.numpy()).astype(np.float32) 98 | save_pkl('/scratch/snx3000/acarpent/Test_Results/{}/{}-forecast_dict_{}.pkl'.format(test_name, model_id, date), forecast_dict) 99 | forecast_dict = {} 100 | print('###################################### SAVED ######################################') 101 | 102 | if __name__ == '__main__': 103 | main() -------------------------------------------------------------------------------- /Test/Test_SHADECast.py: -------------------------------------------------------------------------------- 1 | from validation_utils import get_diffusion_model 2 | from SHADECast.Models.Sampler.PLMS import PLMSSampler 3 | import numpy as np 4 | from utils import save_pkl, open_pkl, get_full_images, get_full_coordinates, remap 5 | import torch 6 | import os 7 | import sys 8 | from yaml import load, Loader 9 | 10 | print(sys.argv) 11 | start = int(sys.argv[1]) 12 | end = int(sys.argv[2]) 13 | test_name = sys.argv[3] 14 | model_config_path = sys.argv[4] 15 | model_path = sys.argv[5] 16 | 17 | def main(): 18 | ldm, model_config = get_diffusion_model(model_config_path, 19 | model_path) 20 | model_id = model_config['ID'] 21 | 22 | with open('/scratch/snx3000/acarpent/Test_Results/{}/config.yml'.format(test_name), 23 | 'r') as o: 24 | test_config = load(o, Loader) 25 | 26 | data_path='/scratch/snx3000/acarpent/HelioMontDataset/{}/KI/'.format(test_config['dataset_name']) 27 | n_ens = test_config['n_ens'] 28 | ddim_steps = test_config['ddim_steps'] 29 | x_max = test_config['x_max'] 30 | x_min = test_config['x_min'] 31 | y_max = test_config['y_max'] 32 | y_min = test_config['y_min'] 33 | patches_idx = test_config['patches_idx'] 34 | 35 | date_idx_dict = open_pkl('/scratch/snx3000/acarpent/Test_Results/{}/Test_date_idx.pkl'.format(test_name)) 36 | test_days = list(date_idx_dict.keys()) 37 | print(test_days) 38 | ldm = ldm.to('cuda') 39 | sampler = PLMSSampler(ldm, verbose=False) 40 | forecast_dict = {} 41 | 42 | print('Testing started') 43 | 44 | for date in test_days[start:end]: 45 | full_maps, idx_lst, t = get_full_images(date, data_path=data_path, patches_idx=patches_idx) 46 | # idx = np.random.choice(list(idx_lst), replace=False, size=n_per_day) 47 | idx = date_idx_dict[date] 48 | print(date) 49 | for i in idx: 50 | x = torch.Tensor(full_maps[i:i+4, y_min:y_max, x_min:x_max]) 51 | y = torch.Tensor(full_maps[i+4:i+12, y_min:y_max, x_min:x_max]) 52 | x,y = x.reshape(1,1,*x.shape).to('cuda'), y.reshape(1,1,*y.shape).to('cuda') 53 | # enc_x, _ = ldm.autoencoder.encode(x) 54 | enc_y, _ = ldm.autoencoder.encode(y) 55 | x = torch.cat([x for _ in range(n_ens)]).to('cuda') 56 | cond = ldm.context_encoder(x) 57 | samples_ddim, _ = sampler.sample(S=ddim_steps, 58 | conditioning=cond, 59 | batch_size=n_ens, 60 | shape=tuple(enc_y.shape[1:]), 61 | verbose=False, 62 | eta=0.) 
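# The PLMS sampler runs `ddim_steps` denoising iterations in the VAE latent space;
# with batch_size=n_ens and shape=tuple(enc_y.shape[1:]) it is expected to return
# one latent forecast per ensemble member. These latents are then decoded back to
# image space and clipped to the normalized [-1, 1] range below.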
63 | 64 | yhat = ldm.autoencoder.decode(samples_ddim.to('cuda')) 65 | yhat = yhat.to('cpu').detach().numpy()[:,0] 66 | yhat[yhat<-1] = -1 67 | yhat[yhat>1] = 1 68 | print(yhat.shape) 69 | # ens_members.append(yhat_) 70 | 71 | forecast_dict[t[i]] = np.array(yhat).astype(np.float32) 72 | save_pkl('/scratch/snx3000/acarpent/Test_Results/{}/{}_{}-ddim_forecast_dict_{}.pkl'.format(test_name, model_id, ddim_steps, date), forecast_dict) 73 | forecast_dict = {} 74 | print('###################################### SAVED ######################################') 75 | 76 | if __name__ == '__main__': 77 | main() -------------------------------------------------------------------------------- /compute_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import properscoring as ps 3 | from pysteps.verification.spatialscores import fss 4 | from pysteps.verification.detcatscores import det_cat_fct 5 | from pysteps.verification.probscores import reldiag_init, reldiag_accum 6 | from pysteps.postprocessing import ensemblestats 7 | from pysteps import verification 8 | import torch 9 | 10 | 11 | def compute_picp_pinaw( 12 | yhat_map, 13 | y_map, 14 | sample_size, 15 | ci=0.9): 16 | """ 17 | :param yhat_map: (n_ens, m, n) 18 | :param y_map: (m, n) 19 | :param sample_size: int 20 | :param ci: float 21 | :return: float 22 | """ 23 | nan_mask = ~np.isnan(yhat_map).any(axis=0) 24 | 25 | if sample_size == 'max': 26 | s = np.sum(nan_mask) 27 | else: 28 | if sample_size>np.sum(nan_mask): 29 | s = np.sum(nan_mask) 30 | else: 31 | s = sample_size 32 | 33 | sample_idx = np.random.choice(np.sum(nan_mask), 34 | replace=False, 35 | size=s) 36 | yhat_ = yhat_map[:, nan_mask][:, sample_idx] 37 | y_ = y_map[nan_mask][sample_idx] 38 | ub = np.nanquantile(yhat_, 1 - (1 - ci) / 2, axis=0) 39 | lb = np.nanquantile(yhat_, (1 - ci) / 2, axis=0) 40 | cond = (y_ <= ub) & (y_ >= lb) 41 | return np.sum(cond) / s, np.mean(ub - lb) 42 | 43 | 44 | def compute_CRPS(yhat_map, 45 | y_map): 46 | pred = yhat_map.reshape(yhat_map.shape[0], 47 | yhat_map.shape[1], 48 | -1).T 49 | 50 | obs = y_map.reshape(y_map.shape[0], 51 | -1).T 52 | crps = ps.crps_ensemble(obs, pred).T.reshape(y_map.shape) 53 | return crps 54 | 55 | 56 | def compute_fss(yhat_map, 57 | y_map, 58 | thresh, 59 | scale, 60 | inverse=None): 61 | if inverse is None: 62 | return fss(yhat_map, 63 | y_map, 64 | thr=thresh, 65 | scale=scale) 66 | else: 67 | return fss(inverse-yhat_map, 68 | inverse-y_map, 69 | thr=inverse-thresh, 70 | scale=scale) 71 | 72 | 73 | 74 | def compute_CSI(yhat_map, 75 | y_map, 76 | thresh, 77 | inverse=None): 78 | if inverse is None: 79 | return det_cat_fct(yhat_map, 80 | y_map, 81 | thr=thresh, 82 | scores='CSI')['CSI'] 83 | else: 84 | return det_cat_fct(inverse-yhat_map, 85 | inverse-y_map, 86 | thr=inverse-thresh, 87 | scores='CSI')['CSI'] 88 | 89 | 90 | def compute_rmse(yhat_map, 91 | y_map): 92 | return np.sqrt(np.nanmean((yhat_map - y_map) ** 2)) 93 | 94 | 95 | def compute_bias(yhat_map, 96 | y_map): 97 | diff = yhat_map - y_map 98 | return np.nanmean(diff), np.nanmax(diff), np.nanmin(diff) 99 | 100 | def compute_dist_distance(yhat, y, mmd, idx=None): 101 | # compute mmd distance for two sequences of images images 102 | if idx is not None: 103 | yhat = yhat[:, idx, idx] 104 | y = y[:, idx, idx] 105 | if not isinstance(yhat, torch.Tensor): 106 | yhat = torch.Tensor(yhat) 107 | y = torch.Tensor(y) 108 | mmd_lst = [] 109 | for yhat_, y_ in zip(yhat, y): 110 | 
mmd_lst.append(mmd(yhat_.view(-1,1), y_.view(-1,1)).detach().numpy()) 111 | return np.array(mmd_lst) 112 | 113 | def compute_ens_dist_distance(ens_yhat, y, mmd, idx=None): 114 | # compute mmd distance for an ensemble of forecasts 115 | d_lst = [] 116 | for yhat in ens_yhat: 117 | d = compute_dist_distance(yhat, y, mmd, idx) 118 | d_lst.append(d) 119 | return np.nanmean(d_lst, axis=0), np.nanstd(d_lst, axis=0) 120 | 121 | def compute_ensemble_metrics(yhat, 122 | real, 123 | metrics=['crps', 'picp-pinaw', 'rmse', 'csi-fss', 'mmd'], 124 | picp_sample_size=1000, 125 | confidence_interval=0.9, 126 | scale_lst=(1, 2, 4, 8, 16, 32, 64), 127 | threshold_lst=(0.3, 0.6, 0.9), 128 | inverse_lst=[1.2, None, None], 129 | mmd_idx=np.arange(0,128,2), 130 | mmd=None, 131 | rankhist_dict={}): 132 | result_dict = {} 133 | 134 | y = real.copy() 135 | y[np.isnan(yhat[0])] = np.nan 136 | # PICP and PINAW 137 | if 'picp-pinaw' in metrics: 138 | picp_pinaw = [compute_picp_pinaw(yhat[:, j], 139 | y[j], 140 | sample_size=picp_sample_size, 141 | ci=confidence_interval) for j in range(len(y))] 142 | picp = np.array(picp_pinaw)[:, 0] 143 | pinaw = np.array(picp_pinaw)[:, 1] 144 | result_dict['picp'] = picp 145 | result_dict['pinaw'] = pinaw 146 | 147 | if 'crps' in metrics: 148 | crps_maps = [compute_CRPS(yhat[:, j], 149 | y[j]) for j in range(len(y))] 150 | result_dict['crps_map'] = crps_maps 151 | result_dict['avg_crps'] = np.nanmean(crps_maps, axis=(1, 2)) 152 | 153 | if 'rmse' in metrics: 154 | rmse = np.sqrt(np.nanmean((np.nanmean(yhat, axis=0)-y)**2, axis=(1,2))) 155 | result_dict['rmse'] = np.array(rmse) 156 | 157 | if 'bias' in metrics: 158 | bias = np.array([compute_bias(np.nanmean(yhat[:, j], axis=0), 159 | y[j]) for j in range(len(y))]) 160 | result_dict['avg_bias'] = bias[:, 0] 161 | result_dict['max_bias'] = bias[:, 1] 162 | result_dict['min_bias'] = bias[:, 2] 163 | 164 | if 'csi' in metrics: 165 | csi_dict = {} 166 | for t,inv, in zip(threshold_lst, inverse_lst): 167 | csi_lst = [] 168 | for yhat_ in yhat: 169 | csi = np.array([compute_CSI(yhat_[j], 170 | y[j], 171 | t, 172 | inverse=inv) for j in range(len(y))]) 173 | csi_lst.append(csi) 174 | csi_dict[t] = (np.nanmean(csi_lst, axis=0), np.nanstd(csi_lst, axis=0)) 175 | result_dict['csi'] = csi_dict 176 | 177 | if 'fss' in metrics: 178 | fss_dict = {} 179 | for t,inv, in zip(threshold_lst, inverse_lst): 180 | fss_dict[t] = {} 181 | for scale in scale_lst: 182 | fss_lst = [] 183 | for yhat_ in yhat: 184 | fs_score = np.array([compute_fss(yhat_[j], 185 | y[j], 186 | t, 187 | inverse=inv, 188 | scale=scale) for j in range(len(y))]) 189 | fss_lst.append(fs_score) 190 | fss_dict[t][scale] = (np.nanmean(fss_lst, axis=0), np.nanstd(fss_lst, axis=0)) 191 | result_dict['fss'] = fss_dict 192 | 193 | if 'mmd' in metrics: 194 | mmd_loss = compute_ens_dist_distance(yhat, y, mmd, mmd_idx) 195 | result_dict['mmd'] = mmd_loss 196 | 197 | if 'rankhist' in metrics: 198 | for step in range(yhat.shape[1]): 199 | verification.rankhist_accum(rankhist_dict[step], yhat[:,step], y[step]) 200 | 201 | if 'spread-skill' in metrics: 202 | rmse = np.sqrt((np.nanmean(yhat, axis=0)-y)**2) 203 | sd = np.std(yhat, axis=0) 204 | skill = np.nanmean(rmse/sd, axis=(1,2)) 205 | result_dict['spread-skill'] = skill 206 | return result_dict 207 | 208 | def compute_rankhist(rankhist_dict, yhat, y): 209 | for step in range(yhat.shape[1]): 210 | verification.rankhist_accum(rankhist_dict[step], yhat[:,step], y[step]) 211 | 212 | def compute_det_metrics(yhat, 213 | y, 214 | scale_lst=(1, 2, 4, 8, 
16, 32, 64), 215 | threshold_lst=(0.3, 0.6, 0.9)): 216 | result_dict = {} 217 | rmse = [compute_rmse(yhat[j], 218 | y[j]) for j in range(len(y))] 219 | bias = np.array([compute_bias(yhat[j], 220 | y[j]) for j in range(len(y))]) 221 | result_dict['rmse'] = np.array(rmse) 222 | result_dict['avg_bias'] = bias[:, 0] 223 | result_dict['max_bias'] = bias[:, 1] 224 | result_dict['min_bias'] = bias[:, 2] 225 | 226 | csi_dict = {} 227 | fss_dict = {} 228 | for t in threshold_lst: 229 | csi = np.array([compute_CSI(yhat[j], 230 | y[j], 231 | t) for j in range(len(y))]) 232 | csi_dict[t] = csi 233 | fss_dict[t] = {} 234 | for scale in scale_lst: 235 | fs_score = np.array([compute_fss(yhat[j], 236 | y[j], 237 | t, 238 | scale) for j in range(len(y))]) 239 | fss_dict[t][scale] = fs_score 240 | result_dict['csi'] = csi_dict 241 | result_dict['fss'] = fss_dict 242 | return result_dict 243 | 244 | 245 | def init_reldiagrams(thresh_lst): 246 | reldiag_dict = {} 247 | for t in thresh_lst: 248 | reldiag_dict[t] = reldiag_init(t) 249 | return reldiag_dict 250 | 251 | 252 | def accum_reldiagrams(yhat, 253 | y, 254 | reldiag_dict): 255 | for t in reldiag_dict: 256 | prob = ensemblestats.excprob(yhat, t, ignore_nan=True) 257 | reldiag_accum(reldiag_dict[t], prob, y) 258 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | matplotlib==3.8.2 3 | numpy==1.20.1 4 | properscoring==0.1 5 | pysteps==1.7.4 6 | pytorch_lightning==1.9.0 7 | PyYAML==6.0.1 8 | torch==1.13.1+cu116 9 | torchinfo==1.8.0 10 | tqdm==4.66.1 11 | pickleshare==0.7.5 12 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import numpy as np 3 | import yaml 4 | import torch 5 | import torch.nn as nn 6 | 7 | ROOT_PATH = '/Users/cea3/Desktop/Projects/GenerativeModels/' 8 | 9 | def open_pkl(path: str): 10 | with open(path, 'rb') as o: 11 | pkl_file = pkl.load(o) 12 | return pkl_file 13 | 14 | 15 | def save_pkl(path: str, obj): 16 | with open(path, 'wb') as o: 17 | pkl.dump(obj, o) 18 | 19 | 20 | def open_yaml(path: str): 21 | with open(path) as o: 22 | yaml_file = yaml.load(o, Loader=yaml.FullLoader) 23 | return yaml_file 24 | 25 | 26 | def activation(act_type="swish"): 27 | act_dict = {"swish": nn.SiLU(), 28 | "gelu": nn.GELU(), 29 | "relu": nn.ReLU(), 30 | "tanh": nn.Tanh()} 31 | if act_type: 32 | if act_type in act_dict: 33 | return act_dict[act_type] 34 | else: 35 | raise NotImplementedError(act_type) 36 | elif not act_type: 37 | return nn.Identity() 38 | 39 | 40 | def normalization(channels, norm_type="group", num_groups=32): 41 | if norm_type == "batch": 42 | return nn.BatchNorm3d(channels) 43 | elif norm_type == "group": 44 | return nn.GroupNorm(num_groups=num_groups, num_channels=channels) 45 | elif (not norm_type) or (norm_type.lower() == 'none'): 46 | return nn.Identity() 47 | else: 48 | raise NotImplementedError(norm_type) 49 | 50 | 51 | def kl_from_standard_normal(mean, log_var): 52 | kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var) 53 | return kl.mean() 54 | 55 | 56 | def sample_from_standard_normal(mean, log_var, num=None): 57 | std = (0.5 * log_var).exp() 58 | shape = mean.shape 59 | if num is not None: 60 | # expand channel 1 to create several samples 61 | shape = shape[:1] + (num,) + shape[1:] 62 | mean = mean[:, None, ...] 
63 | std = std[:, None, ...] 64 | return mean + std * torch.randn(shape, device=mean.device) 65 | 66 | 67 | # def get_full_images(date, val_data_path, coordinate_data_path, n_patches=18): 68 | 69 | # patches = [] 70 | # times = [] 71 | # coords = [] 72 | # starting_idx_lst = [] 73 | # for i in range(n_patches): 74 | # patch = open_pkl(val_data_path+date+'_'+str(i)+'.pkl') 75 | # lat = open_pkl(coordinate_data_path+str(i)+'_lat.pkl') 76 | # lon = open_pkl(coordinate_data_path+str(i)+'_lon.pkl') 77 | # maps = 2 * ((patch['ki_maps'] - 0.05) / (1.2 - 0.05)) - 1 78 | # patches.append(maps) 79 | # t = 2 * ((patch['sza'] - 0) / (90 - 0)) - 1 80 | # times.append(t) 81 | # lon = 2 * ((lon - 0) / (90 - 0)) - 1 82 | # lat = 2 * ((lat - 0) / (90 - 0)) - 1 83 | # coords.append((lon, lat)) 84 | # starting_idx_lst.append(patch['starting_idx']) 85 | # common_starting_idx_lst = list(set.intersection(*map(set, starting_idx_lst))) 86 | # patches = np.array(patches) 87 | # patches = patches[:, 4:] 88 | # times = np.array(times) 89 | # times = np.nanmean(times[:, 4:], axis=0) 90 | 91 | # full_image = np.empty((patches.shape[1], 128*3, 128*6)) 92 | # full_lat = np.empty((128*3, 128*6)) 93 | # full_lon = np.empty((128*3, 128*6)) 94 | 95 | # k = 0 96 | # for i in range(3): 97 | # for j in range(6): 98 | # full_image[:, 128*i:128*(i+1), 128*j:128*(j+1)] = patches[k] 99 | # full_lat[128*i:128*(i+1), 128*j:128*(j+1)] = coords[k][1] 100 | # full_lon[128*i:128*(i+1), 128*j:128*(j+1)] = coords[k][0] 101 | # k += 1 102 | # return full_image, full_lat, full_lon, times, common_starting_idx_lst 103 | 104 | patch_dict = {0: ((0, 128), (0, 128)), 105 | 1: ((0, 128), (128, 256)), 106 | 2: ((0, 128), (256, 384)), 107 | 3: ((0, 128), (384, 512)), 108 | 4: ((0, 128), (512, 640)), 109 | 5: ((0, 128), (640, 768)), 110 | 6: ((128, 256), (0, 128)), 111 | 7: ((128, 256), (128, 256)), 112 | 8: ((128, 256), (256, 384)), 113 | 9: ((128, 256), (384, 512)), 114 | 10: ((128, 256), (512, 640)), 115 | 11: ((128, 256), (640, 768)), 116 | 12: ((256, 384), (0, 128)), 117 | 13: ((256, 384), (128, 256)), 118 | 14: ((256, 384), (256, 384)), 119 | 15: ((256, 384), (384, 512)), 120 | 16: ((256, 384), (512, 640)), 121 | 17: ((256, 384), (640, 768))} 122 | 123 | def get_full_images(date, 124 | data_path='/scratch/snx3000/acarpent/HelioMontDataset/TestSet/KI/', 125 | patches_idx=np.arange(18)): 126 | 127 | full_maps = np.empty((100, 128*3, 128*6))*np.nan 128 | patches_lst = [] 129 | starting_idx_lst = [] 130 | starting_idx_lst = set(np.arange(100)) 131 | 132 | for p in patches_idx: 133 | patch = open_pkl(data_path+date+'_'+str(p)+'.pkl') 134 | maps = 2 * ((patch['ki_maps'] - 0.05) / (1.2 - 0.05)) - 1 135 | full_maps[:len(maps), patch_dict[p][0][0]:patch_dict[p][0][1], 136 | patch_dict[p][1][0]:patch_dict[p][1][1]] = maps 137 | starting_idx_lst = starting_idx_lst.intersection(set(patch['starting_idx'])) 138 | 139 | 140 | time = patch['time'] 141 | full_maps = full_maps[:len(time)] 142 | x = ~np.isnan(full_maps).all(axis=(0, 2)) 143 | full_maps = full_maps[:, x] 144 | y = ~np.isnan(full_maps).all(axis=(0, 1)) 145 | full_maps = full_maps[:, :, y] 146 | 147 | return full_maps, starting_idx_lst, time 148 | 149 | def get_full_coordinates(data_path='/scratch/snx3000/acarpent/HelioMontDataset/CoordinateData/', 150 | patches_idx=np.arange(18), 151 | normalization=False): 152 | 153 | full_lat = np.empty((128*3, 128*6))*np.nan 154 | full_lon = np.empty((128*3, 128*6))*np.nan 155 | full_alt = np.empty((128*3, 128*6))*np.nan 156 | for p in patches_idx: 157 | lat = 
open_pkl(data_path+str(p)+'_lat.pkl') 158 |         lon = open_pkl(data_path+str(p)+'_lon.pkl') 159 |         alt = open_pkl(data_path+str(p)+'_alt.pkl') 160 |         full_lat[patch_dict[p][0][0]:patch_dict[p][0][1], 161 |                  patch_dict[p][1][0]:patch_dict[p][1][1]] = lat 162 |         full_lon[patch_dict[p][0][0]:patch_dict[p][0][1], 163 |                  patch_dict[p][1][0]:patch_dict[p][1][1]] = lon 164 |         full_alt[patch_dict[p][0][0]:patch_dict[p][0][1], 165 |                  patch_dict[p][1][0]:patch_dict[p][1][1]] = alt 166 | 167 |     x = ~np.isnan(full_lat).all(axis=(0)) 168 |     full_lat = full_lat[:, x] 169 |     full_lon = full_lon[:, x] 170 |     full_alt = full_alt[:, x] 171 | 172 |     y = ~np.isnan(full_lat).all(axis=(1)) 173 |     full_lat = full_lat[y, :] 174 |     full_lon = full_lon[y, :] 175 |     full_alt = full_alt[y, :] 176 | 177 |     if normalization: 178 |         full_lon = 2 * ((full_lon - 0) / (90 - 0)) - 1 179 |         full_lat = 2 * ((full_lat - 0) / (90 - 0)) - 1 180 |         full_alt = 2 * ((full_alt - (-13)) / (4294 - 0)) - 1 181 |     return full_lat, full_lon, full_alt 182 | 183 | 184 | def compute_prob(arr, thresh, mean=True): 185 |     x = arr.copy() 186 |     x[x<thresh] = 0 187 |     x[x>=thresh] = 1 188 |     if mean: 189 |         return np.nanmean(x, axis=0) 190 |     else: 191 |         return x 192 | 193 | 194 | def remap(x, max_value=1.2, min_value=0.05): 195 |     return ((x+1)/2)*(max_value-min_value) + min_value 196 | 197 | 198 | def nonparametric_cdf_transform(initial_array, target_array, alpha): 199 |     # flatten the arrays 200 |     arrayshape = initial_array.shape 201 |     target_array = target_array.flatten() 202 |     initial_array = initial_array.flatten() 203 |     # extra_array = extra_array.flatten() 204 | 205 |     # rank target values 206 |     order = target_array.argsort() 207 |     target_ranked = target_array[order] 208 | 209 |     # rank initial values order 210 |     orderin = initial_array.argsort() 211 |     ranks = np.empty(len(initial_array), int) 212 |     ranks[orderin] = np.arange(len(initial_array)) 213 | 214 |     # # rank extra array 215 |     orderex = initial_array.argsort() 216 |     extra_ranked = initial_array[orderex] 217 | 218 |     # get ranked values from target and rearrange with the initial order 219 |     ranked = alpha*extra_ranked + (1-alpha)*target_ranked 220 |     output_array = ranked[ranks] 221 | 222 |     # reshape to the original array dimensions 223 |     output_array = output_array.reshape(arrayshape) 224 |     return output_array -------------------------------------------------------------------------------- /validation_utils.py: -------------------------------------------------------------------------------- 1 | from Dataset.dataset import KIDataset 2 | from torch.utils.data import DataLoader 3 | from SHADECast.Models.Nowcaster.Nowcast import AFNONowcastNetCascade, Nowcaster, AFNONowcastNet 4 | 5 | from SHADECast.Models.VAE.VariationalAutoEncoder import VAE, Encoder, Decoder 6 | from SHADECast.Models.UNet.UNet import UNetModel 7 | from SHADECast.Models.Diffusion.DiffusionModel import LatentDiffusion 8 | from utils import open_pkl, save_pkl 9 | 10 | def get_dataloader(data_path, 11 |                    coordinate_data_path, 12 |                    n=12, 13 |                    min=0.05, 14 |                    max=1.2, 15 |                    length=None, 16 |                    norm_method='rescaling', 17 |                    num_workers=24, 18 |                    batch_size=64, 19 |                    shuffle=True, 20 |                    validation=False): 21 |     dataset = KIDataset(data_path=data_path, 22 |                         n=n, 23 |                         min=min, 24 |                         max=max, 25 |                         length=length, 26 |                         norm_method=norm_method, 27 |                         coordinate_data_path=coordinate_data_path, 28 |                         return_all=False, 29 |                         forecast=True, 30 |                         validation=validation) 31 |     dataloader = DataLoader(dataset, 32 |                             num_workers=num_workers, 33 |                             batch_size=batch_size, 34 |                             shuffle=shuffle) 35 |     return 
dataloader, dataset 36 | 37 | 38 | def get_diffusion_model(config_path, ldm_path): 39 | 40 | config = open_pkl(config_path) 41 | encoder_config = config['Encoder'] 42 | encoder = Encoder(in_dim=encoder_config['in_dim'], 43 | levels=encoder_config['levels'], 44 | min_ch=encoder_config['min_ch'], 45 | max_ch=encoder_config['max_ch']) 46 | print('Encoder built') 47 | 48 | decoder_config = config['Decoder'] 49 | decoder = Decoder(in_dim=decoder_config['in_dim'], 50 | levels=decoder_config['levels'], 51 | min_ch=decoder_config['min_ch'], 52 | max_ch=decoder_config['max_ch']) 53 | 54 | print('Decoder built') 55 | 56 | vae_config = config['VAE'] 57 | vae = VAE.load_from_checkpoint(vae_config['path'], 58 | encoder=encoder, decoder=decoder, 59 | opt_patience=5) 60 | 61 | print('VAE built') 62 | 63 | nowcaster_config = config['Nowcaster'] 64 | if nowcaster_config['path'] is None: 65 | nowcast_net = AFNONowcastNet(vae, 66 | train_autoenc=False, 67 | embed_dim=nowcaster_config['embed_dim'], 68 | embed_dim_out=nowcaster_config['embed_dim'], 69 | analysis_depth=nowcaster_config['analysis_depth'], 70 | forecast_depth=nowcaster_config['forecast_depth'], 71 | input_steps=nowcaster_config['input_steps'], 72 | output_steps=nowcaster_config['output_steps'], 73 | ) 74 | else: 75 | nowcast_net = AFNONowcastNet(vae, 76 | train_autoenc=False, 77 | embed_dim=nowcaster_config['embed_dim'], 78 | embed_dim_out=nowcaster_config['embed_dim'], 79 | analysis_depth=nowcaster_config['analysis_depth'], 80 | forecast_depth=nowcaster_config['forecast_depth'], 81 | input_steps=nowcaster_config['input_steps'], 82 | output_steps=nowcaster_config['output_steps'], 83 | ) 84 | nowcaster = Nowcaster.load_from_checkpoint(nowcaster_config['path'], nowcast_net=nowcast_net, 85 | opt_patience=nowcaster_config['opt_patience'], 86 | loss_type=nowcaster_config['loss_type']) 87 | nowcast_net = nowcaster.nowcast_net 88 | 89 | cascade_net = AFNONowcastNetCascade(nowcast_net=nowcast_net, 90 | cascade_depth=nowcaster_config['cascade_depth']) 91 | diffusion_config = config['Diffusion'] 92 | denoiser = UNetModel( 93 | in_channels=vae.hidden_width, 94 | model_channels=diffusion_config['model_channels'], 95 | out_channels=vae.hidden_width, 96 | num_res_blocks=diffusion_config['num_res_blocks'], 97 | attention_resolutions=diffusion_config['attention_resolutions'], 98 | dims=diffusion_config['dims'], 99 | channel_mult=diffusion_config['channel_mult'], 100 | num_heads=8, 101 | num_timesteps=2, 102 | context_ch=cascade_net.cascade_dims) 103 | 104 | ldm = LatentDiffusion.load_from_checkpoint(ldm_path, 105 | model=denoiser, 106 | autoencoder=vae, 107 | context_encoder=cascade_net, 108 | beta_schedule=diffusion_config['scheduler'], 109 | loss_type="l2", 110 | use_ema=diffusion_config['use_ema'], 111 | lr_warmup=0, 112 | linear_start=1e-4, 113 | linear_end=2e-2, 114 | cosine_s=8e-3, 115 | parameterization='eps', 116 | lr=diffusion_config['lr'], 117 | timesteps=diffusion_config['noise_steps'], 118 | opt_patience=diffusion_config['opt_patience'] 119 | ) 120 | return ldm, config --------------------------------------------------------------------------------
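All of the training and test configurations above use norm_method: 'rescaling' with min 0.05 and max 1.2. The short sketch below is not part of the repository: the helper name rescale and the sample values are introduced only for illustration. It shows the forward transform applied to ki_maps in utils.get_full_images together with its inverse, remap, from utils.py.

import numpy as np

def rescale(ki, min_value=0.05, max_value=1.2):
    # forward 'rescaling' transform, as applied to ki_maps in utils.get_full_images
    return 2 * ((ki - min_value) / (max_value - min_value)) - 1

def remap(x, max_value=1.2, min_value=0.05):
    # inverse transform, as defined in utils.py
    return ((x + 1) / 2) * (max_value - min_value) + min_value

ki = np.array([0.05, 0.6, 1.2])   # dummy KI map values
x = rescale(ki)                   # approximately [-1.0, -0.043, 1.0]
assert np.allclose(remap(x), ki)  # the round trip recovers the original values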