├── model
│   ├── convLSTM_cell.py
│   ├── Encode2Decode.py
│   ├── seq2seq.py
│   └── sa_convLSTM_cell.py
├── utils
│   ├── utils.py
│   └── earlystopping.py
├── data_loader
│   └── data_loader_MovingMNIST.py
├── README.md
├── Min_train.py
└── Min_test.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# sa_convlstm

Baseline model

Contact: namgu007@umn.edu

Implementation of SA-ConvLSTM from: Self-Attention ConvLSTM for Spatiotemporal Prediction (https://doi.org/10.1609/aaai.v34i07.6819)

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import numpy as np
# compare_psnr / compare_ssim were removed from skimage.measure in
# scikit-image 0.18; the equivalents live in skimage.metrics.
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric


def mse(gt, y_hat):
    # Sum of squared errors, scaled by a fixed constant of 10,000.
    mse = np.square(gt - y_hat).sum()
    return mse / 10000
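
# The two wrappers below are an illustrative sketch (not in the original file):
# batch-averaged PSNR/SSIM over the imports above, assuming gt and y_hat are
# arrays of frames shaped (N, H, W) with pixel values spanning data_range.
def psnr(gt, y_hat, data_range=255.0):
    scores = [psnr_metric(g, p, data_range=data_range) for g, p in zip(gt, y_hat)]
    return float(np.mean(scores))


def ssim(gt, y_hat, data_range=255.0):
    scores = [ssim_metric(g, p, data_range=data_range) for g, p in zip(gt, y_hat)]
    return float(np.mean(scores))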
--------------------------------------------------------------------------------
/utils/earlystopping.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

--------------------------------------------------------------------------------
/model/convLSTM_cell.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch


class ConvLSTMCell(nn.Module):

    def __init__(self, in_channels, h_channels, kernel_size, bias=True):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        in_channels: int
            Number of channels of the input tensor.
        h_channels: int
            Number of channels of the hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
20 | """ 21 | 22 | super(ConvLSTMCell, self).__init__() 23 | 24 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | self.h_channels = h_channels 26 | padding = kernel_size[0] // 2, kernel_size[1] // 2 27 | 28 | self.conv = nn.Conv2d(in_channels=in_channels + h_channels, 29 | out_channels=4 * h_channels, 30 | kernel_size=kernel_size, 31 | padding=padding, 32 | bias=bias) 33 | 34 | def forward(self, input_tensor, cur_state): 35 | 36 | h_cur, c_cur = cur_state 37 | 38 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 39 | 40 | combined_conv = self.conv(combined) 41 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.h_channels, dim=1) 42 | i = torch.sigmoid(cc_i) 43 | f = torch.sigmoid(cc_f) 44 | o = torch.sigmoid(cc_o) 45 | g = torch.tanh(cc_g) 46 | 47 | c_next = f * c_cur + i * g 48 | h_next = o * torch.tanh(c_next) 49 | 50 | return h_next, c_next 51 | 52 | def init_hidden(self, batch_size, image_size): 53 | height, width = image_size 54 | 55 | return (torch.zeros(batch_size, self.h_channels, height, width).to(self.device), 56 | torch.zeros(batch_size, self.h_channels, height, width).to(self.device)) 57 | 58 | # return (torch.zeros(batch_size, self.h_channels, height, width, device=self.conv.weight.device), 59 | # torch.zeros(batch_size, self.h_channels, height, width, device=self.conv.weight.device)) 60 | 61 | -------------------------------------------------------------------------------- /model/Encode2Decode.py: -------------------------------------------------------------------------------- 1 | from model.sa_convLSTM_cell import SA_Convlstm_cell 2 | 3 | import torch 4 | import torch.nn as nn 5 | import random 6 | 7 | ### 8 | # 9 | # Author: Min Namgung 10 | # Contact: namgu007@umn.edu 11 | # 12 | # ### 13 | 14 | class Encode2Decode(nn.Module): # self-attention convlstm for spatiotemporal prediction model 15 | def __init__(self, params): 16 | super(Encode2Decode, self).__init__() 17 | # hyperparams 18 | self.batch_size = params['batch_size'] 19 | self.img_size = params['img_size'] 20 | self.cells, self.bns, self.decoderCells = [], [], [] 21 | self.n_layers = params['n_layers'] 22 | self.input_window_size = params['input_window_size'] 23 | self.output_window_size = params['output_dim'] 24 | 25 | # Written By Min 26 | self.img_encode = nn.Sequential( 27 | nn.Conv2d(in_channels=params['input_dim'], kernel_size=1, stride=1, padding=0, 28 | out_channels=params['hidden_dim']), 29 | nn.LeakyReLU(0.1), 30 | nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, 31 | out_channels=params['hidden_dim']), 32 | nn.LeakyReLU(0.1), 33 | nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, 34 | out_channels=params['hidden_dim']), 35 | nn.LeakyReLU(0.1) 36 | ) 37 | 38 | self.img_decode = nn.Sequential( 39 | nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1, 40 | out_channels=params['hidden_dim']), 41 | nn.LeakyReLU(0.1), 42 | nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1, 43 | out_channels=params['hidden_dim']), 44 | nn.LeakyReLU(0.1), 45 | nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=1, stride=1, padding=0, 46 | out_channels=params['input_dim']) 47 | ) 48 | 49 | 50 | for i in range(params['n_layers']): 51 | params['input_dim'] == params['hidden_dim'] if i == 0 else params['hidden_dim'] 52 | params['hidden_dim'] == params['hidden_dim'] 53 | 
            self.cells.append(SA_Convlstm_cell(params))
            self.bns.append(nn.LayerNorm((params['hidden_dim'], 16, 16)))  # Use layernorm

        self.cells = nn.ModuleList(self.cells)
        self.bns = nn.ModuleList(self.bns)
        self.decoderCells = nn.ModuleList(self.decoderCells)

        # Linear
        self.decoder_predict = nn.Conv2d(in_channels=params['hidden_dim'],
                                         out_channels=1,
                                         kernel_size=(1, 1),
                                         padding=(0, 0))

    def forward(self, x, y, teacher_forcing_rate=0.5, hidden=None):
        if hidden is None:
            hidden = self.init_hidden(batch_size=self.batch_size, img_size=self.img_size)

        b, seq_len, x_c, h, w = x.size()
        _, horizon, y_c, h, w = y.size()

        predict_temp_de = []

        frames = torch.cat([x, y], dim=1)

        # One step per frame transition: 19 steps for a 10-frame input window
        # plus a 10-frame horizon; step t predicts frame t + 1.
        for t in range(seq_len + horizon - 1):

            if t < self.input_window_size or random.random() < teacher_forcing_rate:
                x = frames[:, t, :, :, :]
            else:
                x = out

            x = self.img_encode(x)

            # Each cell receives the encoded frame directly; only the output
            # of the last cell is decoded back to image space.
            for i, cell in enumerate(self.cells):
                out, hidden[i] = cell(x, hidden[i])
                out = self.bns[i](out)

            out = self.img_decode(out)
            predict_temp_de.append(out)

        predict_temp_de = torch.stack(predict_temp_de, dim=1)

        # Keep only the predictions for the output window (the last horizon frames).
        predict_temp_de = predict_temp_de[:, seq_len - 1:, :, :, :]

        return predict_temp_de

    def init_hidden(self, batch_size, img_size):
        states = []
        for i in range(self.n_layers):
            states.append(self.cells[i].init_hidden(batch_size, img_size))

        return states

--------------------------------------------------------------------------------
/model/seq2seq.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import random

from model.convLSTM_cell import ConvLSTMCell


###
#
# Author: Min Namgung
# Contact: namgu007@umn.edu
#
# ###

class EncoderDecoderConvLSTM(nn.Module):
    def __init__(self, params):
        super(EncoderDecoderConvLSTM, self).__init__()

        self.in_chan = params['hidden_dim']
        self.h_chan = params['hidden_dim']
        self.out_chan = params['hidden_dim']
        self.input_window_size = params['input_window_size']
        self.n_layers = params['n_layers']
        self.img_size = params['img_size']
        self.batch_size = params['batch_size']
        self.device = params['device']
        self.cells = []
        self.h_t, self.c_t = [], []

        for i in range(self.n_layers):
            self.cells.append(ConvLSTMCell(in_channels=self.in_chan,
                                           h_channels=self.h_chan,
                                           kernel_size=(3, 3),
                                           bias=True))

        self.cells = nn.ModuleList(self.cells)

        self.img_encode = nn.Sequential(
            nn.Conv2d(in_channels=params['input_dim'], kernel_size=1, stride=1, padding=0,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1)
        )
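
        # Shape sketch (assuming 64x64 Moving MNIST frames): the 1x1 conv
        # lifts 1 channel to hidden_dim at full resolution, then the two
        # stride-2 convs halve the spatial size twice (64 -> 32 -> 16), so the
        # ConvLSTM cells operate on hidden_dim x 16 x 16 feature maps,
        # matching the img_size = (16, 16) used for init_hidden.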

        # Prediction layer
        self.img_decode = nn.Sequential(
            nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1,
                               out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1,
                               out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=1, stride=1, padding=0,
                      out_channels=params['input_dim'])
        )

        # Linear
        self.decoder_predict = nn.Conv2d(in_channels=params['hidden_dim'],
                                         out_channels=params['hidden_dim'],
                                         kernel_size=(1, 1),
                                         padding=(0, 0))

    def forward(self, x, y, teacher_forcing_rate=0.5, hidden=None):
        # Init hidden state
        if hidden is None:
            hidden = self.init_hidden(batch_size=self.batch_size, img_size=self.img_size)

        b, seq_len, x_c, h, w = x.size()
        _, future_seq, y_c, _, _ = y.size()

        frames = torch.cat([x, y], dim=1)
        predict_temp_de = []

        # Seq2seq: one step per frame transition (19 steps for 10 input
        # frames plus a 10-frame horizon)
        for t in range(seq_len + future_seq - 1):

            if t < self.input_window_size or random.random() < teacher_forcing_rate:
                x = frames[:, t, :, :, :]
            else:
                x = out

            x = self.img_encode(x)

            for i, cell in enumerate(self.cells):
                hidden[i] = cell(input_tensor=x, cur_state=hidden[i])
                out = self.decoder_predict(hidden[i][0])

            out = self.img_decode(out)
            predict_temp_de.append(out)

        predict_temp_de = torch.stack(predict_temp_de, dim=1)
        # Keep only the predictions for the output window (the last future_seq frames).
        final = predict_temp_de[:, seq_len - 1:, :, :, :]

        return final

    def init_hidden(self, batch_size, img_size):
        states = []
        for i in range(self.n_layers):
            states.append(self.cells[i].init_hidden(batch_size, img_size))

        return states

--------------------------------------------------------------------------------
/model/sa_convLSTM_cell.py:
--------------------------------------------------------------------------------
# Github ref: https://github.com/hyona-yu/SA-convlstm/blob/main/convlstm_att_hnver2.py

import torch
import numpy as np
from torch import optim
import torch.nn as nn


class self_attention_memory_module(nn.Module):  # SAM
    def __init__(self, input_dim, hidden_dim, device):
        super().__init__()
        # layer_q, layer_k, layer_v act on the hidden state h
        # layer_k2, layer_v2 act on the memory m
        # layer_z and layer_m are applied after concatenating attention_h and attention_m

        self.layer_q = nn.Conv2d(input_dim, hidden_dim, 1)
        self.layer_k = nn.Conv2d(input_dim, hidden_dim, 1)
        self.layer_k2 = nn.Conv2d(input_dim, hidden_dim, 1)
        self.layer_v = nn.Conv2d(input_dim, input_dim, 1)
        self.layer_v2 = nn.Conv2d(input_dim, input_dim, 1)
        self.layer_z = nn.Conv2d(input_dim * 2, input_dim * 2, 1)
        self.layer_m = nn.Conv2d(input_dim * 3, input_dim * 3, 1)
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim

    def forward(self, h, m):
        batch_size, channel, H, W = h.shape
        # feature aggregation
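        # Attention shapes: Q_h is (B, H*W, hidden_dim) after the transpose,
        # K_h is (B, hidden_dim, H*W), so A_h = softmax(Q_h @ K_h) is
        # (B, H*W, H*W): every spatial position attends to every other one.
        # Z_h = A_h @ V_h^T aggregates the values; the same query then attends
        # over the memory m to produce Z_m.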
        ##### hidden h attention #####
        K_h = self.layer_k(h)
        Q_h = self.layer_q(h)
        K_h = K_h.view(batch_size, self.hidden_dim, H * W)
        Q_h = Q_h.view(batch_size, self.hidden_dim, H * W)
        Q_h = Q_h.transpose(1, 2)

        A_h = torch.softmax(torch.bmm(Q_h, K_h), dim=-1)  # batch_size, H*W, H*W

        V_h = self.layer_v(h)
        V_h = V_h.view(batch_size, self.input_dim, H * W)
        Z_h = torch.matmul(A_h, V_h.permute(0, 2, 1))

        ##### memory m attention #####
        K_m = self.layer_k2(m)
        K_m = K_m.view(batch_size, self.hidden_dim, H * W)
        A_m = torch.softmax(torch.bmm(Q_h, K_m), dim=-1)
        V_m = self.layer_v2(m)
        V_m = V_m.view(batch_size, self.input_dim, H * W)
        Z_m = torch.matmul(A_m, V_m.permute(0, 2, 1))

        Z_h = Z_h.transpose(1, 2).view(batch_size, self.input_dim, H, W)
        Z_m = Z_m.transpose(1, 2).view(batch_size, self.input_dim, H, W)

        ### concat Z_h and Z_m (from attention), then fuse ####
        W_z = torch.cat([Z_h, Z_m], dim=1)
        Z = self.layer_z(W_z)

        ## Memory updating (Ref: SA-ConvLSTM)
        combined = self.layer_m(torch.cat([Z, h], dim=1))  # 3 * input_dim channels
        mo, mg, mi = torch.split(combined, self.input_dim, dim=1)
        mi = torch.sigmoid(mi)
        new_m = (1 - mi) * m + mi * torch.tanh(mg)
        new_h = torch.sigmoid(mo) * new_m

        return new_h, new_m


class SA_Convlstm_cell(nn.Module):
    def __init__(self, params):
        super().__init__()
        # hyperparams
        # The input channels equal hidden_dim because img_encode has already
        # lifted the raw input from 1 channel to hidden_dim channels.
        self.input_channels = params['hidden_dim']
        self.hidden_dim = params['hidden_dim']
        self.kernel_size = params['kernel_size']
        self.padding = params['padding']
        self.device = params['device']
        self.attention_layer = self_attention_memory_module(params['hidden_dim'], params['att_hidden_dim'], self.device)
        self.conv2d = nn.Sequential(
            nn.Conv2d(in_channels=self.input_channels + self.hidden_dim, out_channels=4 * self.hidden_dim,
                      kernel_size=self.kernel_size, padding=self.padding),
            nn.GroupNorm(4 * self.hidden_dim, 4 * self.hidden_dim))  # (num_groups, num_channels): one group per channel

    def forward(self, x, hidden):
        h, c, m = hidden

        combined = torch.cat([x, h], dim=1)  # (batch_size, input_dim + hidden_dim, img_size[0], img_size[1])
        combined_conv = self.conv2d(combined)  # (batch_size, 4 * hidden_dim, img_size[0], img_size[1])
        i, f, o, g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)
        # The standard ConvLSTM update ends here; the self-attention memory
        # module then refines h and updates the long-range memory m.
        h_next, m_next = self.attention_layer(h_next, m)

        return h_next, (h_next, c_next, m_next)

    def init_hidden(self, batch_size, img_size):  # initialize h, c, m
        h, w = img_size

        return (torch.zeros(batch_size, self.hidden_dim, h, w).to(self.device),
                torch.zeros(batch_size, self.hidden_dim, h, w).to(self.device),
                torch.zeros(batch_size, self.hidden_dim, h, w).to(self.device))

--------------------------------------------------------------------------------
/Min_train.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch import optim
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import random
from torch.utils.data import Subset

from utils.earlystopping import EarlyStopping
from data_loader.data_loader_MovingMNIST import MovingMNIST
from model.Encode2Decode import Encode2Decode
from model.seq2seq import EncoderDecoderConvLSTM


def split_train_val(dataset):
    idx = [i for i in range(len(dataset))]

    random.seed(1234)
    random.shuffle(idx)

    num_train = int(0.8 * len(idx))

    train_idx = idx[:num_train]
    val_idx = idx[num_train:]

    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)

    print(f'train index: {len(train_idx)}')
    print(f'val index: {len(val_idx)}')

    return train_dataset, val_dataset


def reshape_patch(images, patch_size):
    bs = images.size(0)
    nc = images.size(1)
    height = images.size(2)
    width = images.size(3)
    x = images.reshape(bs, nc, height // patch_size, patch_size, width // patch_size, patch_size)
    x = x.transpose(2, 5)
    x = x.transpose(4, 5)
    x = x.reshape(bs, nc * patch_size * patch_size, height // patch_size, width // patch_size)

    return x


def reshape_patch_back(images, patch_size):
    bs = images.size(0)
    nc = images.size(1) // (patch_size * patch_size)
    height = images.size(2)
    width = images.size(3)
    x = images.reshape(bs, nc, patch_size, patch_size, height, width)
    x = x.transpose(4, 5)
    x = x.transpose(2, 5)
    x = x.reshape(bs, nc, height * patch_size, width * patch_size)

    return x


class Model():
    def __init__(self, params, loading_path=None, set_device=4):
        if params['model_cell'] == 'sa_convlstm':
            self.model = Encode2Decode(params).to(params['device'])
        else:  # params['model_cell'] == 'convlstm'
            self.model = EncoderDecoderConvLSTM(params).to(params['device'])
        self.loss = params['loss']
        if self.loss == 'SSIM':
            # An SSIM criterion is not implemented; fall back to MSE.
            self.criterion = nn.MSELoss()
        elif self.loss == 'L2':
            self.criterion = nn.MSELoss()
        else:
            self.criterion = nn.L1Loss()
        self.output = params['output_dim']
        self.device = params['device']
        self.optim = optim.Adam(self.model.parameters(), lr=params['lr'])

    def train(self, train_dataset, valid_dataset, epochs, path):
        n_total_steps = len(train_dataset)

        early_stopping = EarlyStopping(patience=20, verbose=True)
        avg_train_losses = []
        avg_valid_losses = []

        for i in range(epochs):
            train_losses, val_losses = [], []
            self.model.train()
            epoch_loss = 0.0
            val_epoch_loss = 0.0

            for ite, data in enumerate(train_dataset):
                x, y = data

                x_train = x.to(self.device)
                y_train = y.to(self.device)
                self.optim.zero_grad()
                pred_train = self.model(x_train, y_train)
                loss = self.criterion(pred_train, y_train)
                loss.backward()
                self.optim.step()

                train_losses.append(loss.item())
                print(f'epoch {i + 1} / {epochs}, step {ite + 1}/{n_total_steps}, encode + decode loss = {loss.item():.4f}')

                epoch_loss += loss.item()

            with torch.no_grad():
                self.model.eval()
                for _, data in enumerate(valid_dataset):
                    x, y = data
                    x_val = x.to(self.device)
                    y_val = y.to(self.device)
                    pred_val = self.model(x_val, y_val, teacher_forcing_rate=0)
                    loss = self.criterion(pred_val, y_val)
                    val_losses.append(loss.item())
                    val_epoch_loss += loss.item()

            train_loss = np.average(train_losses)
            valid_loss = np.average(val_losses)
            avg_train_losses.append(train_loss)
            avg_valid_losses.append(valid_loss)

            print('{}th epoch: train loss {}, valid loss {}'.format(i, train_loss, valid_loss))

            torch.save(self.model.state_dict(), path)

            # Early stopping on the validation loss
            early_stopping(valid_loss, self.model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        plt.plot(avg_train_losses, '-o')
        plt.plot(avg_valid_losses, '-o')
        plt.xlabel('epoch')
        plt.ylabel('losses')
        plt.legend(['Train', 'Validation'])
        plt.title('(MSE) Avg Train vs Validation Losses')
        plt.savefig('./results/npy_file_save/test_new_trainer.png')
        plt.clf()


if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    root = '/panfs/roc/groups/14/yaoyi/namgu007/weather4cast-master/sa-convlstm-movingmnist'
    dataset = MovingMNIST(root, train=True)

    train_dataset, val_dataset = split_train_val(dataset)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    valid_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

    torch.manual_seed(42)
    BATCH_SIZE = 8
    img_size = (16, 16)  # spatial size after the stride-2 encoder, not the raw frame size
    input_window_size, output = 10, 10
    epochs = 2
    lr = 1e-3
    hid_dim = 64  # 16
    loss = 'L2'
    att_hid_dim = 64
    n_layers = 4
    bias = True

    # 1) sa_convlstm
    # 2) convlstm (anything else for now)
    model_name = 'sa_convlstm'

    params = {'input_dim': 1, 'batch_size': BATCH_SIZE, 'padding': 1, 'lr': lr, 'device': device,
              'att_hidden_dim': att_hid_dim, 'kernel_size': 3, 'img_size': img_size, 'hidden_dim': hid_dim,
              'n_layers': n_layers, 'output_dim': output, 'input_window_size': input_window_size, 'loss': loss,
              'model_cell': model_name, 'bias': bias}

    print(f'Moving MNIST frames are 64x64; the encoder reduces them to {img_size}')
    print('data has been loaded!')
    print('This is Train.py')
    print(f'Model name: {model_name}')

    model = Model(params)

    model.train(train_dataloader, valid_dataloader, epochs=epochs,
                path='./results/model_save/model_{}_{}to{}_BS{}_{}epochs_{}layers_{}atthid_{}loss_{}hid.pt'.format(
                    model_name, input_window_size, output, BATCH_SIZE, epochs, n_layers, att_hid_dim, loss, hid_dim))

--------------------------------------------------------------------------------
/Min_test.py:
--------------------------------------------------------------------------------

import matplotlib.pyplot as plt
import torch
import numpy as np
from torch import optim
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import random
from torch.utils.data import Subset

import utils.utils as utils
from skimage.metrics import structural_similarity as ssim
from data_loader.data_loader_MovingMNIST import MovingMNIST
from model.Encode2Decode import Encode2Decode
from model.seq2seq import EncoderDecoderConvLSTM


def split_train_val(dataset):
    idx = [i for i in range(len(dataset))]

    random.seed(1234)
    random.shuffle(idx)

    num_train = int(0.8 * len(idx))

    train_idx = idx[:num_train]
    val_idx = idx[num_train:]

    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)

    return train_dataset, val_dataset


def reshape_patch(images, patch_size):
    bs = images.size(0)
    nc = images.size(1)
    height = images.size(2)
    width = images.size(3)
    x = images.reshape(bs, nc, height // patch_size, patch_size, width // patch_size, patch_size)
    x = x.transpose(2, 5)
    x = x.transpose(4, 5)
    x = x.reshape(bs, nc * patch_size * patch_size, height // patch_size, width // patch_size)

    return x


class Model():
    def __init__(self, params, loading_path=None, set_device=4):
        if params['model_cell'] == 'sa_convlstm':
            self.model = Encode2Decode(params).to(params['device'])
        else:  # params['model_cell'] == 'convlstm'
            self.model = EncoderDecoderConvLSTM(params).to(params['device'])
        self.loss = params['loss']
        if self.loss == 'SSIM':
            # An SSIM criterion is not implemented; fall back to MSE.
            self.criterion = nn.MSELoss()
        elif self.loss == 'L2':
            self.criterion = nn.MSELoss()
        else:
            self.criterion = nn.L1Loss()
        self.output = params['output_dim']
        self.device = params['device']
        self.optim = optim.Adam(self.model.parameters(), lr=params['lr'])

    def evaluate(self, train_dataset, path):
        prediction = []

        self.model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))

        train_pred, train_gt = [], []

        with torch.no_grad():
            self.model.eval()
            for i, data in enumerate(train_dataset):
                x, y = data
                x = x.to(self.device)
                y = y.to(self.device)
                # Keep the model inputs in [0, 1]; rescale to [0, 255] only
                # for the saved arrays.
                pred = self.model(x, y, teacher_forcing_rate=0) * 255.0
                train_pred.append(pred.cpu().data.numpy())
                train_gt.append(y.cpu().data.numpy() * 255.0)

        train_pred = np.concatenate(train_pred)
        train_gt = np.concatenate(train_gt)

        path_pred = './npy_file_save/train_pred_Feb20.npy'
        path_gt = './npy_file_save/train_gt_Feb20.npy'

        np.save(path_pred, train_pred)
        np.save(path_gt, train_gt)

        return prediction

    def evaluate_test(self, test_dataset, path):
        prediction = []
        self.model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))

        test_pred, test_gt = [], []

        with torch.no_grad():
            self.model.eval()
            for i, data in enumerate(test_dataset):
                x, y = data
                x = x.to(self.device)
                y = y.to(self.device)
                # teacher_forcing_rate=0: a true multi-step rollout where the
                # decoder consumes its own predictions after the input frames.
                pred = self.model(x, y, teacher_forcing_rate=0)
                test_pred.append(pred.cpu().data.numpy())
                test_gt.append(y.cpu().data.numpy())

        test_pred = np.concatenate(test_pred)
        test_gt = np.concatenate(test_gt)

        mse = utils.mse(test_gt, test_pred)
        print('TEST Data loader - MSE = {:.6f}'.format(mse))

        # Frame-wise comparison in MSE and SSIM
        overall_mse = 0
        overall_ssim = 0
        frame_mse = np.zeros(test_gt.shape[1])
        frame_ssim = np.zeros(test_gt.shape[1])

        for i in range(test_gt.shape[1]):      # frames per sequence
            for j in range(test_gt.shape[0]):  # sequences
                mse_ = np.square(test_gt[j, i] - test_pred[j, i]).sum()
                test_gt_img = np.squeeze(test_gt[j, i])
                test_pred_img = np.squeeze(test_pred[j, i])
                # Float inputs in [0, 1] require an explicit data_range in
                # recent scikit-image releases.
                ssim_ = ssim(test_gt_img, test_pred_img, data_range=1.0)

                overall_mse += mse_
                overall_ssim += ssim_
                frame_mse[i] += mse_
                frame_ssim[i] += ssim_

        # Normalization constants: 10 output frames per sequence and 1000
        # sequences in the test split.
        overall_mse /= 10
        overall_ssim /= 10
        frame_mse /= 1000
        frame_ssim /= 1000
        print(f'overall_mse {overall_mse}')
        print(f'overall_ssim {overall_ssim}')
        print(f'frame_mse {frame_mse}')
        print(f'frame_ssim {frame_ssim}')

        path_pred = './results/npy_file_save/saconvlstm_test_pred_speedpt5.npy'
        path_gt = './results/npy_file_save/saconvlstm_test_gt_speedpt5.npy'

        np.save(path_pred, test_pred)
        np.save(path_gt, test_gt)

        return prediction


if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    root = '/panfs/roc/groups/14/yaoyi/namgu007/weather4cast-master/sa-convlstm-movingmnist'
    dataset = MovingMNIST(root, train=False)   # test split

    dataset2 = MovingMNIST(root, train=True)   # train/val split
    train_dataset, val_dataset = split_train_val(dataset2)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    valid_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)
    test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)  # shuffle=False

    torch.manual_seed(42)
    BATCH_SIZE = 8
    img_size = (16, 16)
    input_window_size, output = 10, 10
    epochs = 1
    lr = 1e-3
    hid_dim = 16  # must match the hidden_dim the checkpoint was trained with
    loss = 'L2'
    att_hid_dim = 64
    n_layers = 4
    bias = True

    # 1) sa_convlstm
    # 2) convlstm (anything else for now)
    model_name = 'sa_convlstm'

    params = {'input_dim': 1, 'batch_size': BATCH_SIZE, 'padding': 1, 'lr': lr, 'device': device,
              'att_hidden_dim': att_hid_dim, 'kernel_size': 3, 'img_size': img_size, 'hidden_dim': hid_dim,
              'n_layers': n_layers, 'output_dim': output, 'input_window_size': input_window_size, 'loss': loss,
              'model_cell': model_name, 'bias': bias}

    print(f'Moving MNIST frames are 64x64; the encoder reduces them to {img_size}')
    print('data has been loaded!')
    print('This is Test.py')
    print(f'Model name: {model_name}')

    model = Model(params)

    path = 'model_sa_convlstm_10to10_BS8_200epochs_4layers_64atthid_L2loss_16hid.pt'
    prediction_test = model.evaluate_test(test_dataloader, path)

--------------------------------------------------------------------------------
/data_loader/data_loader_MovingMNIST.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
from scipy.signal import convolve2d
import random
from torch.utils.data import Subset


class MovingMNIST(data.Dataset):
    """MovingMNIST Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/moving_mnist_train.pt``
            and ``processed/moving_mnist_test.pt`` exist.
        train (bool, optional): If True, creates dataset from the training file,
            otherwise from the test file.
        split (int, optional): Train/test split size. Number defines how many samples
            belong to the test set.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version. E.g., ``transforms.RandomCrop``
    """
    urls = [
        'https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'moving_mnist_train.pt'
    test_file = 'moving_mnist_test.pt'

    def __init__(
        self,
        root,
        train=True,
        split=1000,
        transform=None,
        target_transform=None,
        download=False,  # Set download=True for the first run
        seq_len=10,
        horizon=10,
        crop_size=None,
        downsample_size=None
    ):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.train = train  # training set or test set
        self.seq_len = seq_len
        self.horizon = horizon
        self.crop_size = crop_size
        self.downsample_size = downsample_size
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               ' You can use download=True to download it')
        if self.train:
            self.train_data = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def crop_image(self, img):
        """img: [T, C, H, W]"""
        if self.crop_size is not None:
            return img[..., :self.crop_size, :self.crop_size]
        else:
            return img

    def downsample_image(self, img):
        """img: [T, C, H, W]"""
        if self.downsample_size is not None:
            T, C, H, W = img.shape

            assert (
                H % self.downsample_size == 0 and W % self.downsample_size == 0
            ), "image size must be divisible by downsample_size"

            # Strided subsampling; an area-averaging variant would instead sum
            # over downsample_size x downsample_size blocks.
            return img[..., ::self.downsample_size, ::self.downsample_size]
        else:
            return img

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (seq, target) where the sampled sequence is split into a seq
            and a target part
        """

        # need to iterate over time
        def _transform_time(data):
            new_data = None
            for i in range(data.size(0)):
                img = Image.fromarray(data[i].numpy(), mode='L')
                new_data = self.transform(img) if new_data is None else torch.cat([self.transform(img), new_data],
                                                                                  dim=0)
            return new_data

        if self.train:
            seq = self.train_data[index, :self.seq_len]
            target = self.train_data[index, self.seq_len:(self.seq_len + self.horizon)]
        else:
            seq = self.test_data[index, :self.seq_len]
            target = self.test_data[index, self.seq_len:(self.seq_len + self.horizon)]
        if self.transform is not None:
            seq = _transform_time(seq)
        if self.target_transform is not None:
            target = _transform_time(target)

        seq = seq.unsqueeze(1)  # add channel dimension
        target = target.unsqueeze(1)

        return seq / 255.0, target / 255.0

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the Moving MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip
        if self._check_exists():
            return
        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise
        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)
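
        # mnist_test_seq.npy stores 10,000 sequences of 20 frames with shape
        # (20, 10000, 64, 64); swapaxes(0, 1) below makes it one row per
        # sequence, and the last `split` sequences become the test set.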
        # process and save as torch files
        print('Processing...')
        training_set = torch.from_numpy(
            np.load(os.path.join(self.root, self.raw_folder, 'mnist_test_seq.npy')).swapaxes(0, 1)[:-self.split]
        )
        test_set = torch.from_numpy(
            np.load(os.path.join(self.root, self.raw_folder, 'mnist_test_seq.npy')).swapaxes(0, 1)[-self.split:]
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)
        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Train/test: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

--------------------------------------------------------------------------------
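/smoke_test.py (an illustrative sketch; this file and its name are additions, not part of the original repo):
--------------------------------------------------------------------------------
# A minimal shape check for Encode2Decode, assuming the hyperparameters from
# Min_train.py. Random tensors stand in for data, so it runs without the
# Moving MNIST files on disk.
import torch

from model.Encode2Decode import Encode2Decode

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    params = {'input_dim': 1, 'batch_size': 2, 'padding': 1, 'lr': 1e-3, 'device': device,
              'att_hidden_dim': 64, 'kernel_size': 3, 'img_size': (16, 16), 'hidden_dim': 64,
              'n_layers': 4, 'output_dim': 10, 'input_window_size': 10, 'loss': 'L2',
              'model_cell': 'sa_convlstm', 'bias': True}

    model = Encode2Decode(params).to(device)
    x = torch.rand(2, 10, 1, 64, 64, device=device)  # 10 input frames
    y = torch.rand(2, 10, 1, 64, 64, device=device)  # 10 target frames

    # With teacher_forcing_rate=0 the decoder rolls out on its own predictions.
    pred = model(x, y, teacher_forcing_rate=0)
    assert pred.shape == y.shape, pred.shape
    print('Encode2Decode forward OK:', tuple(pred.shape))
--------------------------------------------------------------------------------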