├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── MovingMNIST.py
├── images
│   ├── epoch0_2500steps.png
│   ├── epoch0_500steps.png
│   └── mnist_gif.gif
├── main.py
├── models
│   ├── ConvLSTMCell.py
│   └── seq2seq_ConvLSTM.py
└── utils
    └── start_tensorboard.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
*.pyc

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Andreas Holm

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Video-Prediction-using-PyTorch
![Alt Text](/images/mnist_gif.gif)
Repository for frame prediction on the MovingMNIST dataset using a seq2seq ConvLSTM, following either of these guides:

[TDS](https://towardsdatascience.com/video-prediction-using-convlstm-with-pytorch-lightning-27b195fd21a2)
[Github pages](https://holmdk.github.io/2020/04/02/video_prediction.html)

## Libraries
Make sure you have the following libraries installed!

```
python=3.6.8
torch=1.1.0
torchvision=0.3.0
pytorch-lightning=0.7.1
matplotlib=3.1.3
tensorboard=1.15.0a20190708
```

## Getting started
1. Install the above libraries

2. Clone this repo

```bash
git clone https://github.com/holmdk/Video-Prediction-using-PyTorch.git
cd ./Video-Prediction-using-PyTorch
```

3. Run main.py
```bash
python main.py
```

4. Navigate to http://localhost:6006/ to visualize the results


## Results
The first row displays our predictions, the second row the ground truth, and the third row the per-pixel error between the two. The first 8 columns are the input frames, followed by the predicted output in the final 8 columns. This matches the output from the Tensorboard logging.

After some iterations, we notice that our model is actually generating images of all zeros! This is a common issue reported by people using ConvLSTM; however, do not be discouraged! Simply keep training the model, and you should start to see plausible future predictions.
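One cheap way to detect this all-zeros regime is to compare the mean intensity of the predictions with that of the targets. A minimal sketch (the helper name and threshold are ours, not part of this repo):

```python
import torch

def collapsed_to_zeros(y_hat: torch.Tensor, y: torch.Tensor, tol: float = 1e-3) -> bool:
    """Hypothetical check: a collapsed model predicts ~0 everywhere,
    so its mean absolute intensity sits far below the target's."""
    return y_hat.detach().abs().mean().item() < tol < y.detach().abs().mean().item()
```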
### Initial results (500 steps)
![Initial](/images/epoch0_500steps.png)


### After half an epoch (2500 steps)
Now we are starting to see actual predictions, however blurry they might be.
![halfepoch](/images/epoch0_2500steps.png)

## Todo:
- [ ] Add video of predictions by model
- [ ] Implement other video prediction methods (feel free to contribute!)
  - [ ] SVG
  - [ ] PredRNN+
  - [ ] E3D
  - [ ] MIM

--------------------------------------------------------------------------------
/data/MovingMNIST.py:
--------------------------------------------------------------------------------
import numpy as np
from torchvision import datasets, transforms

# from: https://github.com/edenton/svg/blob/master/data/moving_mnist.py

class MovingMNIST(object):
    """Data handler that creates the bouncing MNIST dataset on the fly."""

    def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64, deterministic=True):
        path = data_root
        self.seq_len = seq_len
        self.num_digits = num_digits
        self.image_size = image_size
        self.step_length = 0.1
        self.digit_size = 32
        self.deterministic = deterministic
        self.seed_is_set = False  # multi-threaded loading
        self.channels = 1

        self.data = datasets.MNIST(
            path,
            train=train,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(self.digit_size),
                 transforms.ToTensor()]))

        self.N = len(self.data)

    def set_seed(self, seed):
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        return self.N

    def __getitem__(self, index):
        self.set_seed(index)
        image_size = self.image_size
        digit_size = self.digit_size
        x = np.zeros((self.seq_len,
                      image_size,
                      image_size,
                      self.channels),
                     dtype=np.float32)
        for n in range(self.num_digits):
            idx = np.random.randint(self.N)
            digit, _ = self.data[idx]

            # random start position and velocity for each digit
            sx = np.random.randint(image_size - digit_size)
            sy = np.random.randint(image_size - digit_size)
            dx = np.random.randint(-4, 5)
            dy = np.random.randint(-4, 5)
            for t in range(self.seq_len):
                # bounce off the top/bottom edges
                if sy < 0:
                    sy = 0
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(1, 5)
                        dx = np.random.randint(-4, 5)
                elif sy >= image_size - digit_size:
                    sy = image_size - digit_size - 1
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(-4, 0)
                        dx = np.random.randint(-4, 5)

                # bounce off the left/right edges
                if sx < 0:
                    sx = 0
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(1, 5)
                        dy = np.random.randint(-4, 5)
                elif sx >= image_size - digit_size:
                    sx = image_size - digit_size - 1
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(-4, 0)
                        dy = np.random.randint(-4, 5)

                # paste the digit at its current position and advance it
                x[t, sy:sy + digit_size, sx:sx + digit_size, 0] += digit.numpy().squeeze()
                sy += dy
                sx += dx

        x[x > 1] = 1.
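        # x has shape (seq_len, image_size, image_size, 1); overlapping digits
        # can sum above 1, so values are clipped to [0, 1] above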
        return x

--------------------------------------------------------------------------------
/images/epoch0_2500steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/holmdk/Video-Prediction-using-PyTorch/edab22a5514be45e77a5bec793911a89d533d501/images/epoch0_2500steps.png

--------------------------------------------------------------------------------
/images/epoch0_500steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/holmdk/Video-Prediction-using-PyTorch/edab22a5514be45e77a5bec793911a89d533d501/images/epoch0_500steps.png

--------------------------------------------------------------------------------
/images/mnist_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/holmdk/Video-Prediction-using-PyTorch/edab22a5514be45e77a5bec793911a89d533d501/images/mnist_gif.gif

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# import libraries
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from multiprocessing import Process

from utils.start_tensorboard import run_tensorboard
from models.seq2seq_ConvLSTM import EncoderDecoderConvLSTM
from data.MovingMNIST import MovingMNIST

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--beta_1', type=float, default=0.9, help='decay rate 1')
parser.add_argument('--beta_2', type=float, default=0.98, help='decay rate 2')
parser.add_argument('--batch_size', default=12, type=int, help='batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train for')
parser.add_argument('--use_amp', action='store_true', help='enable mixed-precision training')
parser.add_argument('--n_gpus', type=int, default=1, help='number of GPUs')
parser.add_argument('--n_hidden_dim', type=int, default=64, help='number of hidden dims for ConvLSTM layers')

opt = parser.parse_args()


##########################
######### MODEL ##########
##########################

class MovingMNISTLightning(pl.LightningModule):

    def __init__(self, hparams=None, model=None):
        super(MovingMNISTLightning, self).__init__()

        # default config
        self.path = os.getcwd() + '/data'
        self.model = model

        # logging config
        self.log_images = True

        # Training config
        self.criterion = torch.nn.MSELoss()
        self.batch_size = opt.batch_size
        self.n_steps_past = 10
        self.n_steps_ahead = 10

    def create_video(self, x, y_hat, y):
        # predictions with input for illustration purposes
        preds = torch.cat([x.cpu(), y_hat.unsqueeze(2).cpu()], dim=1)[0]

        # entire input and ground truth
        y_plot = torch.cat([x.cpu(), y.unsqueeze(2).cpu()], dim=1)[0]

        # squared error between prediction and ground truth
        difference = (torch.pow(y_hat[0] - y[0], 2)).detach().cpu()
        zeros = torch.zeros(difference.shape)
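        # the input frames have no prediction, so their error columns stay zero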
        difference_plot = torch.cat([zeros.cpu().unsqueeze(0), difference.unsqueeze(0).cpu()], dim=1)[0].unsqueeze(1)

        # concat all images
        final_image = torch.cat([preds, y_plot, difference_plot], dim=0)

        # make them into a single grid image file
        grid = torchvision.utils.make_grid(final_image, nrow=self.n_steps_past + self.n_steps_ahead)

        return grid

    def forward(self, x):
        x = x.to(device='cuda')  # assumes GPU training

        output = self.model(x, future_seq=self.n_steps_ahead)

        return output

    def training_step(self, batch, batch_idx):
        x, y = batch[:, 0:self.n_steps_past, :, :, :], batch[:, self.n_steps_past:, :, :, :]
        x = x.permute(0, 1, 4, 2, 3)  # (b, t, h, w, c) -> (b, t, c, h, w)
        y = y.squeeze()

        y_hat = self.forward(x).squeeze()

        loss = self.criterion(y_hat, y)

        # save learning_rate
        lr_saved = self.trainer.optimizers[0].param_groups[-1]['lr']
        lr_saved = torch.scalar_tensor(lr_saved).cuda()

        # save predicted images every 250 global steps
        if self.log_images:
            if self.global_step % 250 == 0:
                final_image = self.create_video(x, y_hat, y)

                self.logger.experiment.add_image(
                    'epoch_' + str(self.current_epoch) + '_step' + str(self.global_step) + '_generated_images',
                    final_image, 0)
                plt.close()

        tensorboard_logs = {'train_mse_loss': loss,
                            'learning_rate': lr_saved}

        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': self.criterion(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=opt.lr, betas=(opt.beta_1, opt.beta_2))

    @pl.data_loader
    def train_dataloader(self):
        train_data = MovingMNIST(
            train=True,
            data_root=self.path,
            seq_len=self.n_steps_past + self.n_steps_ahead,
            image_size=64,
            deterministic=True,
            num_digits=2)

        train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.batch_size,
            shuffle=True)

        return train_loader

    @pl.data_loader
    def test_dataloader(self):
        test_data = MovingMNIST(
            train=False,
            data_root=self.path,
            seq_len=self.n_steps_past + self.n_steps_ahead,
            image_size=64,
            deterministic=True,
            num_digits=2)

        test_loader = DataLoader(
            dataset=test_data,
            batch_size=self.batch_size,
            shuffle=True)

        return test_loader


def run_trainer():
    conv_lstm_model = EncoderDecoderConvLSTM(nf=opt.n_hidden_dim, in_chan=1)

    model = MovingMNISTLightning(model=conv_lstm_model)

    trainer = Trainer(max_epochs=opt.epochs,
                      gpus=opt.n_gpus,
                      distributed_backend='dp',
                      early_stop_callback=False,
                      use_amp=opt.use_amp
                      )

    trainer.fit(model)


if __name__ == '__main__':
    p1 = Process(target=run_trainer)  # start trainer
    p1.start()
    p2 = Process(target=run_tensorboard, kwargs={'new_run': True})  # start tensorboard in its own process
    p2.start()
    p1.join()
    p2.join()
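# Usage sketch (values shown are the defaults defined above; adjust as needed):
#   python main.py --lr 1e-4 --batch_size 12 --n_hidden_dim 64 --epochs 300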

--------------------------------------------------------------------------------
/models/ConvLSTMCell.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch


class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2  # "same" padding
        self.bias = bias

        # a single convolution computes all four gates at once
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))

--------------------------------------------------------------------------------
/models/seq2seq_ConvLSTM.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from models.ConvLSTMCell import ConvLSTMCell


class EncoderDecoderConvLSTM(nn.Module):
    def __init__(self, nf, in_chan):
        super(EncoderDecoderConvLSTM, self).__init__()

        """ ARCHITECTURE

        # Encoder (ConvLSTM)
        # Encoder Vector (final hidden state of encoder)
        # Decoder (ConvLSTM) - takes Encoder Vector as input
        # Decoder (3D CNN) - produces regression predictions for our model

        """
        self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.encoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf,  # nf + 1
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_CNN = nn.Conv3d(in_channels=nf,
                                     out_channels=1,
                                     kernel_size=(1, 3, 3),
                                     padding=(0, 1, 1))

    def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):

        outputs = []

        # encoder
        for t in range(seq_len):
            h_t, c_t = self.encoder_1_convlstm(input_tensor=x[:, t, :, :],
                                               cur_state=[h_t, c_t])  # we could concat to provide a skip conn here
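            # the first layer's hidden state feeds the second, stacked ConvLSTM layer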
            h_t2, c_t2 = self.encoder_2_convlstm(input_tensor=h_t,
                                                 cur_state=[h_t2, c_t2])  # we could concat to provide a skip conn here

        # encoder_vector
        encoder_vector = h_t2

        # decoder
        for t in range(future_step):
            h_t3, c_t3 = self.decoder_1_convlstm(input_tensor=encoder_vector,
                                                 cur_state=[h_t3, c_t3])  # we could concat to provide a skip conn here
            h_t4, c_t4 = self.decoder_2_convlstm(input_tensor=h_t3,
                                                 cur_state=[h_t4, c_t4])  # we could concat to provide a skip conn here
            encoder_vector = h_t4
            outputs += [h_t4]  # predictions

        outputs = torch.stack(outputs, 1)
        outputs = outputs.permute(0, 2, 1, 3, 4)
        outputs = self.decoder_CNN(outputs)
        outputs = torch.sigmoid(outputs)

        return outputs

    def forward(self, x, future_seq=0, hidden_state=None):

        """
        Parameters
        ----------
        x:
            5-D Tensor of shape (b, t, c, h, w)  # batch, time, channel, height, width
        """

        # find size of different input dimensions
        b, seq_len, _, h, w = x.size()

        # initialize hidden states
        h_t, c_t = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t2, c_t2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t3, c_t3 = self.decoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t4, c_t4 = self.decoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))

        # autoencoder forward
        outputs = self.autoencoder(x, seq_len, future_seq, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4)

        return outputs

--------------------------------------------------------------------------------
/utils/start_tensorboard.py:
--------------------------------------------------------------------------------
"""Start TensorBoard on the newest (or next) lightning_logs version folder."""
import os
import glob
import time

from tensorboard import program

program.logger.setLevel('INFO')


def run_tensorboard(new_run):

    path = os.getcwd() + '/lightning_logs/'

    try:
        newest_folder = max(glob.glob(os.path.join(path, '*/')), key=os.path.getmtime)
        version_number = os.path.basename(os.path.normpath(newest_folder)).split('_')[1]

        if new_run:
            new_version_number = str(int(version_number) + 1)
            newest_folder = newest_folder.replace(version_number, new_version_number)

            version_number = new_version_number  # for print purposes

    except ValueError:
        version_number = '0'
        newest_folder = path + 'version_0'

    # wait until the trainer has created the log folder
    while not os.path.exists(newest_folder):
        time.sleep(1)

    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', newest_folder])
    url = tb.launch()

    print('\n')
    print('-' * 20)
    print('Starting tensorboard at URL %s version %s' % (url, version_number))
    print('-' * 20)
    print('\n')

    while True:  # keep the script alive
        time.sleep(2)


if __name__ == '__main__':
    run_tensorboard(new_run=False)

--------------------------------------------------------------------------------
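For reference, a minimal shape check for the model (a sketch run from the repo root, not a file in this repo): `EncoderDecoderConvLSTM.forward` takes a 5-D tensor and returns one channel per future frame.

```python
import torch
from models.seq2seq_ConvLSTM import EncoderDecoderConvLSTM

model = EncoderDecoderConvLSTM(nf=64, in_chan=1)   # nf matches the --n_hidden_dim default
x = torch.randn(2, 10, 1, 64, 64)                  # (batch, time, channel, height, width)
y_hat = model(x, future_seq=10)                    # predict 10 future frames
print(y_hat.shape)                                 # torch.Size([2, 1, 10, 64, 64])
```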