├── example.png ├── LICENSE ├── .gitignore ├── README.md ├── bouncing_mnist.py ├── test.py ├── generate_test_set.py ├── cnn.py ├── convgru.py ├── train_gru_predictor.py ├── convlstm.py └── train_lstm_predictor.py /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aserdega/convlstmgru/HEAD/example.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Andriy Serdega 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | #my stuff 7 | /outputs 8 | /logs 9 | /logs* 10 | /outputs* 11 | /~ 12 | *.jpg 13 | *.ipynb 14 | .DS_Store 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # celery beat schedule file
103 | celerybeat-schedule
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Codacy Badge](https://app.codacy.com/project/badge/Grade/8dd11ea5f7ea4e19b66f2c6072e97252)](https://www.codacy.com/manual/andriyserdega/convlstmgru?utm_source=github.com&utm_medium=referral&utm_content=aserdega/convlstmgru&utm_campaign=Badge_Grade)
2 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
3 | # ConvLSTM and ConvGRU | PyTorch
4 | **Implementations of Convolutional LSTM and Convolutional GRU in PyTorch**
5 |
6 | Inspired by [this](https://github.com/ndrplz/ConvLSTM_pytorch) repository, but refactored and extended with new features, such as a peephole option, and with usage examples: sequence-to-sequence video prediction models trained on the moving MNIST dataset.
7 |
8 | ## How to Use
9 | The `ConvLSTM` and `ConvGRU` modules inherit from `torch.nn.Module`.
10 |
11 | The `ConvLSTM` and `ConvGRU` wrappers support an arbitrary number of layers. You can specify the size of the hidden dimension (number of channels) and the kernel size for each layer. If a single kernel size is given for multiple layers, it is replicated across all of them; in the following snippet, all three layers share the same 3x3 kernel.
12 |
13 | Short usage example:
14 | ```python
15 | conv_lstm_encoder = ConvLSTM(
16 |     input_size=(hidden_spt,hidden_spt),
17 |     input_dim=hidden_dim,
18 |     hidden_dim=lstm_dims,
19 |     kernel_size=(3,3),
20 |     num_layers=3,
21 |     peephole=True,
22 |     batchnorm=False,
23 |     batch_first=True,
24 |     activation=F.tanh
25 | )
26 |
27 | hidden = conv_lstm_encoder.get_init_states(batch_size)
28 | output, encoder_state = conv_lstm_encoder(input, hidden)
29 | ```
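The `ConvGRU` wrapper mirrors the `ConvLSTM` interface, except that it has no `peephole` option and each layer carries a single hidden tensor instead of an `(h, c)` pair. A minimal sketch mirroring the snippet above (`gru_dims` and the other variables are placeholders, just like in the ConvLSTM example):

```python
conv_gru_encoder = ConvGRU(
    input_size=(hidden_spt,hidden_spt),
    input_dim=hidden_dim,
    hidden_dim=gru_dims,
    kernel_size=(3,3),
    num_layers=3,
    batchnorm=False,
    batch_first=True,
    activation=F.tanh
)

hidden = conv_gru_encoder.get_init_states(batch_size)
output, encoder_state = conv_gru_encoder(input, hidden)
```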
30 |
31 | # Project Structure
32 | ## Main Files
33 | - convlstm.py: contains the main classes, `ConvLSTMCell` (represents one "layer") and the `ConvLSTM` module
34 | - convgru.py: the same, for `ConvGRUCell` and `ConvGRU`
35 | ## Other
36 | - train_gru_predictor.py and train_lstm_predictor.py: train video prediction models based on ConvGRU and ConvLSTM respectively
37 | - cnn.py: simple convolutional networks for encoding frames and decoding frame representations
38 | - bouncing_mnist.py: contains a dataloader that generates the moving MNIST dataset from plain MNIST on the fly; use [this](https://www.dropbox.com/s/xt93tn9cstf85w1/mnist.h5?dl=0) raw MNIST dataset to reproduce the experiments.
39 | - generate_test_set.py: generates test data for trained models
40 | - test.py: contains the tester for trained models
41 |
42 | ## Prediction examples
43 | For every group of three rows, the first row shows the previous frames fed to the model, the second shows the predicted frames, and the third shows the ground-truth (GT) future frames:
44 |
45 | ![Predictions](example.png)
46 |
--------------------------------------------------------------------------------
/bouncing_mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import sys
4 | import h5py
5 | import numpy as np
6 |
7 |
8 | class BouncingMnist():
9 |     """Data Handler that creates Bouncing MNIST dataset on the fly."""
10 |     def __init__(self, mnistDataset='mnist.h5', mode='standard', transform=None, background='zeros', num_frames=20, batch_size=1, image_size=64, num_digits=2, step_length=0.1):
11 |         self.mode_ = mode
12 |         self.background_ = background
13 |         self.seq_length_ = num_frames
14 |         self.batch_size_ = batch_size
15 |         self.image_size_ = image_size
16 |         self.num_digits_ = num_digits
17 |         self.step_length_ = step_length
18 |         self.dataset_size_ = 20000  # The dataset is really infinite. This is just for validation.
19 |         self.digit_size_ = 28
20 |         self.frame_size_ = self.image_size_ ** 2
21 |         self.num_channels_ = 1
22 |         self.transform_ = transform
23 |
24 |         try:
25 |             f = h5py.File(mnistDataset, 'r')
26 |         except OSError:
27 |             print('Please set the correct path to the MNIST dataset')
28 |             sys.exit()
29 |
30 |         self.data_ = f['train'][()].reshape(-1, 28, 28)
31 |         #self.test = f['test'][()].reshape(-1, 28, 28)
32 |
33 |         f.close()
34 |         self.indices_ = np.arange(self.data_.shape[0])
35 |         self.row_ = 0
36 |         np.random.shuffle(self.indices_)
37 |
38 |     def __len__(self):
39 |         return self.dataset_size_
40 |
41 |     def __getitem__(self, idx):
42 |         item = self.get_batch()[0,:,:,:,:]
43 |         item_t = self.transform_(item)
44 |         return item_t
45 |
46 |     def GetRandomTrajectory(self, batch_size):
47 |         length = self.seq_length_
48 |         canvas_size = self.image_size_ - self.digit_size_
49 |
50 |         # Initial position uniform random inside the box.
51 |         y = np.random.rand(batch_size)
52 |         x = np.random.rand(batch_size)
53 |
54 |         # Choose a random velocity.
55 |         theta = np.random.rand(batch_size) * 2 * np.pi
56 |         v_y = np.sin(theta)
57 |         v_x = np.cos(theta)
58 |
59 |         start_y = np.zeros((length, batch_size))
60 |         start_x = np.zeros((length, batch_size))
61 |         for i in range(length):
62 |             # Take a step along velocity.
63 |             y += v_y * self.step_length_
64 |             x += v_x * self.step_length_
65 |
66 |             # Bounce off edges.
67 |             for j in range(batch_size):
68 |                 if x[j] <= 0:
69 |                     x[j] = 0
70 |                     v_x[j] = -v_x[j]
71 |                 if x[j] >= 1.0:
72 |                     x[j] = 1.0
73 |                     v_x[j] = -v_x[j]
74 |                 if y[j] <= 0:
75 |                     y[j] = 0
76 |                     v_y[j] = -v_y[j]
77 |                 if y[j] >= 1.0:
78 |                     y[j] = 1.0
79 |                     v_y[j] = -v_y[j]
80 |             start_y[i, :] = y
81 |             start_x[i, :] = x
82 |
83 |         # Scale to the size of the canvas.
84 | start_y = (canvas_size * start_y).astype(np.int32) 85 | start_x = (canvas_size * start_x).astype(np.int32) 86 | return start_y, start_x 87 | 88 | def Overlap(self, a, b): 89 | return np.maximum(a, b) 90 | 91 | def get_batch(self, verbose=False): 92 | start_y, start_x = self.GetRandomTrajectory(self.batch_size_ * self.num_digits_) 93 | 94 | # minibatch data 95 | if self.background_ == 'zeros': 96 | data = np.zeros((self.batch_size_, self.num_channels_, self.image_size_, self.image_size_, self.seq_length_), dtype=np.float32) 97 | elif self.background_ == 'rand': 98 | data = np.random.rand(self.batch_size_, self.num_channels_, self.image_size_, self.image_size_, self.seq_length_) 99 | 100 | for j in range(self.batch_size_): 101 | for n in range(self.num_digits_): 102 | 103 | # get random digit from dataset 104 | ind = self.indices_[self.row_] 105 | self.row_ += 1 106 | if self.row_ == self.data_.shape[0]: 107 | self.row_ = 0 108 | np.random.shuffle(self.indices_) 109 | digit_image = self.data_[ind, :, :] 110 | digit_size = self.digit_size_ 111 | 112 | if self.mode_ == 'squares': 113 | digit_size = np.random.randint(5,20) 114 | digit_image = np.ones((digit_size, digit_size), dtype=np.float32) 115 | 116 | # generate video 117 | for i in range(self.seq_length_): 118 | top = start_y[i, j * self.num_digits_ + n] 119 | left = start_x[i, j * self.num_digits_ + n] 120 | bottom = top + digit_size 121 | right = left + digit_size 122 | data[j, :, top:bottom, left:right, i] = self.Overlap(data[j, :, top:bottom, left:right, i], digit_image) 123 | 124 | dum = np.moveaxis(data, -1, 1) 125 | return dum 126 | 127 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | 7 | import torchvision.utils 8 | import torchvision.datasets as dset 9 | import torchvision.transforms as transforms 10 | 11 | import numpy as np 12 | 13 | from convlstm import ConvLSTM 14 | from cnn import ConvEncoder, ConvDecoder 15 | 16 | from bouncing_mnist import BouncingMnist 17 | 18 | DATA_PATH = "./mnist_test/mnist_test_500.npy" 19 | b_size = 32 20 | 21 | CUDA1 = "cuda:1" 22 | 23 | class TesterMnist: 24 | def __init__(self, data_path="./mnist_test/mnist_test_500.npy", out_dir="./test_out"): 25 | self.dataset = np.load(data_path) 26 | self.size = self.dataset.shape[0] 27 | self.iter = 0 28 | self.out_dir = out_dir 29 | 30 | def run_test(self, cnn_encoder, cnn_decoder, lstm_encoder, lstm_decoder): 31 | """ 32 | cnn_encoder = cnn_encoder.eval() 33 | cnn_decoder = cnn_decoder.eval() 34 | lstm_encoder = lstm_encoder.eval() 35 | lstm_decoder = lstm_decoder.eval() 36 | """ 37 | 38 | hidden_dim = lstm_encoder.input_dim 39 | hidden_spt = lstm_encoder.height 40 | 41 | mse_hist = [] 42 | bce_hist = [] 43 | pnsr_hist = torch.zeros(10).cuda()#, device=CUDA1)#divide by number of batches!!! 44 | fr_mse_hist = torch.zeros(10).cuda()#, device=CUDA1)#divide by number of batches!!! 
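        # Descriptive note (added): the loop below evaluates in the same regime
        # as training. The first 10 frames of each sequence are encoded
        # frame-by-frame by the CNN encoder and summarized by the LSTM encoder;
        # the LSTM decoder is then unrolled for 10 steps feeding back its own
        # predictions, and the CNN decoder maps the result back to 64x64 frames
        # that are scored against the ground-truth future frames.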
45 |
46 |         i = 0
47 |         batch_n = 0  # counts processed batches; used to average pnsr_hist/fr_mse_hist
48 |         for sl in range(0,self.size - b_size,b_size):
49 |             seqs = torch.from_numpy(self.dataset[sl:sl+b_size,:,:,:,:])
50 |             nextf_raw = seqs[:,10:,:,:,:].cuda()
51 |             #----cnn encoder----
52 |
53 |             prevf_raw = seqs[:,:10,:,:,:].contiguous().view(-1,1,64,64).cuda()
54 |             prevf_enc = cnn_encoder(prevf_raw).view(b_size,10,hidden_dim,hidden_spt,hidden_spt)
55 |
56 |             #----lstm encoder---
57 |
58 |             hidden = lstm_encoder.get_init_states(b_size)
59 |             _, encoder_state = lstm_encoder(prevf_enc, hidden)
60 |
61 |             #----lstm decoder---
62 |
63 |             decoder_output_list = []
64 |
65 |             for s in range(10):
66 |                 if s == 0:
67 |                     decoder_out, decoder_state = lstm_decoder(prevf_enc[:,-1:,:,:,:], encoder_state)
68 |                 else:
69 |                     decoder_out, decoder_state = lstm_decoder(decoder_out, decoder_state)
70 |                 decoder_output_list.append(decoder_out)
71 |
72 |             final_decoder_out = torch.cat(decoder_output_list, 1)
73 |
74 |             #----cnn decoder----
75 |
76 |             decoder_out_rs = final_decoder_out.view(-1,hidden_dim,hidden_spt,hidden_spt)
77 |
78 |             cnn_decoder_out_raw = torch.sigmoid(cnn_decoder(decoder_out_rs).detach())
79 |             cnn_decoder_out = cnn_decoder_out_raw.view(b_size,10,1,64,64)
80 |
81 |             #-----calculate mse and bce----------
82 |             slice_mse = F.mse_loss(cnn_decoder_out, nextf_raw).item()
83 |             slice_bce = F.binary_cross_entropy(cnn_decoder_out, nextf_raw).item()
84 |
85 |             mse_hist.append(slice_mse)
86 |             bce_hist.append(slice_bce)
87 |
88 |             cur_pnsr, cur_fr_mse = self.calculate_pnsr(cnn_decoder_out, nextf_raw)
89 |             pnsr_hist += cur_pnsr
90 |             fr_mse_hist += cur_fr_mse
91 |
92 |             batch_n += 1
93 |
94 |         pnsr_hist = pnsr_hist / batch_n
95 |         fr_mse_hist = fr_mse_hist / batch_n
96 |
97 |         mse_total = np.mean(mse_hist)
98 |         bce_total = np.mean(bce_hist)
99 |
100 |         l = []
101 |         for j in range(3):
102 |             l.append(seqs[j,:10,:,:,:])
103 |             l.append(cnn_decoder_out[j,:,:,:,:].data.cpu())
104 |             l.append(seqs[j,10:,:,:,:])
105 |         samples = torch.cat(l)
106 |         torchvision.utils.save_image(samples,
107 |             self.out_dir + "/{0:0>5}iter.png".format(self.iter), nrow=10)
108 |         self.iter += 1
109 |
110 |         """
111 |         cnn_encoder = cnn_encoder.train()
112 |         cnn_decoder = cnn_decoder.train()
113 |         lstm_encoder = lstm_encoder.train()
114 |         lstm_decoder = lstm_decoder.train()
115 |         """
116 |
117 |         return mse_total, bce_total, pnsr_hist, fr_mse_hist
118 |
119 |     # NOTE: fr_mse is a per-frame *sum* of squared errors over pixels, not a mean
120 |     def calculate_pnsr(self, pred, target):
121 |         elw_mse = F.mse_loss(pred, target, reduction='none').squeeze()
122 |         fr_mse = elw_mse.sum(-1).sum(-1).mean(0)
123 |         pnsr = 10 * torch.log10(1 / elw_mse.mean(-1).mean(-1).mean(0))
124 |         return pnsr, fr_mse
125 |
126 |
--------------------------------------------------------------------------------
/generate_test_set.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import sys
4 | import h5py
5 | import numpy as np
6 |
7 |
8 | class MovingMnistTestSetGenerator():
9 |     """Generates and saves a fixed Bouncing MNIST test set."""
10 |     def __init__(self, mnistDataset='mnist.h5', mode='standard', background='zeros', num_frames=20, save_dir="./mnist_test", batch_size=1000, image_size=64, num_digits=2, step_length=0.1):
11 |         self.mode_ = mode
12 |         self.background_ = background
13 |         self.seq_length_ = num_frames
14 |         self.batch_size_ = batch_size
15 |         self.image_size_ = image_size
16 |         self.num_digits_ = num_digits
17 |         self.step_length_ = step_length
18 |         self.dataset_size_ = 20000
19 |         self.digit_size_ = 28
20 | 
self.frame_size_ = self.image_size_ ** 2 21 | self.num_channels_ = 1 22 | self.save_dir = save_dir 23 | 24 | try: 25 | f = h5py.File(mnistDataset) 26 | except: 27 | print('Please set the correct path to MNIST dataset') 28 | sys.exit() 29 | 30 | self.data_ = f['test'][()].reshape(-1, 28, 28) 31 | 32 | f.close() 33 | 34 | self.indices_ = np.arange(self.data_.shape[0]) 35 | self.row_ = 0 36 | np.random.shuffle(self.indices_) 37 | 38 | def __len__(self): 39 | return self.dataset_size_ 40 | 41 | def GenerateTestSet(self): 42 | dataset = self.get_batch() 43 | 44 | np.save(self.save_dir + "/mnist_test_500.npy", dataset) 45 | """ 46 | np.save(self.save_dir + "/mnist_test1.npy", dataset[:2500,:,:,:,:]) 47 | np.save(self.save_dir + "/mnist_test2.npy", dataset[2500:5000,:,:,:,:]) 48 | np.save(self.save_dir + "/mnist_test3.npy", dataset[5000:7500,:,:,:,:]) 49 | np.save(self.save_dir + "/mnist_test4.npy", dataset[7500:,:,:,:,:]) 50 | """ 51 | return True 52 | 53 | def GetRandomTrajectory(self, batch_size): 54 | length = self.seq_length_ 55 | canvas_size = self.image_size_ - self.digit_size_ 56 | 57 | # Initial position uniform random inside the box. 58 | y = np.random.rand(batch_size) 59 | x = np.random.rand(batch_size) 60 | 61 | # Choose a random velocity. 62 | theta = np.random.rand(batch_size) * 2 * np.pi 63 | v_y = np.sin(theta) 64 | v_x = np.cos(theta) 65 | 66 | start_y = np.zeros((length, batch_size)) 67 | start_x = np.zeros((length, batch_size)) 68 | for i in range(length): 69 | # Take a step along velocity. 70 | y += v_y * self.step_length_ 71 | x += v_x * self.step_length_ 72 | 73 | # Bounce off edges. 74 | for j in range(batch_size): 75 | if x[j] <= 0: 76 | x[j] = 0 77 | v_x[j] = -v_x[j] 78 | if x[j] >= 1.0: 79 | x[j] = 1.0 80 | v_x[j] = -v_x[j] 81 | if y[j] <= 0: 82 | y[j] = 0 83 | v_y[j] = -v_y[j] 84 | if y[j] >= 1.0: 85 | y[j] = 1.0 86 | v_y[j] = -v_y[j] 87 | start_y[i, :] = y 88 | start_x[i, :] = x 89 | 90 | # Scale to the size of the canvas. 
91 | start_y = (canvas_size * start_y).astype(np.int32) 92 | start_x = (canvas_size * start_x).astype(np.int32) 93 | return start_y, start_x 94 | 95 | def Overlap(self, a, b): 96 | return np.maximum(a, b) 97 | 98 | def get_batch(self, verbose=False): 99 | start_y, start_x = self.GetRandomTrajectory(self.batch_size_ * self.num_digits_) 100 | 101 | # minibatch data 102 | if self.background_ == 'zeros': 103 | data = np.zeros((self.batch_size_, self.num_channels_, self.image_size_, self.image_size_, self.seq_length_), dtype=np.float32) 104 | elif self.background_ == 'rand': 105 | data = np.random.rand(self.batch_size_, self.num_channels_, self.image_size_, self.image_size_, self.seq_length_) 106 | 107 | for j in range(self.batch_size_): 108 | for n in range(self.num_digits_): 109 | 110 | # get random digit from dataset 111 | ind = self.indices_[self.row_] 112 | self.row_ += 1 113 | if self.row_ == self.data_.shape[0]: 114 | self.row_ = 0 115 | np.random.shuffle(self.indices_) 116 | digit_image = self.data_[ind, :, :] 117 | digit_size = self.digit_size_ 118 | 119 | if self.mode_ == 'squares': 120 | digit_size = np.random.randint(5,20) 121 | digit_image = np.ones((digit_size, digit_size), dtype=np.float32) 122 | 123 | # generate video 124 | for i in range(self.seq_length_): 125 | top = start_y[i, j * self.num_digits_ + n] 126 | left = start_x[i, j * self.num_digits_ + n] 127 | bottom = top + digit_size 128 | right = left + digit_size 129 | data[j, :, top:bottom, left:right, i] = self.Overlap(data[j, :, top:bottom, left:right, i], digit_image) 130 | 131 | dum = np.moveaxis(data, -1, 1) 132 | return dum 133 | 134 | 135 | generator = MovingMnistTestSetGenerator() 136 | generator.GenerateTestSet() 137 | print('Done') 138 | 139 | -------------------------------------------------------------------------------- /cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ''' 6 | class ConvEncoder(nn.Module): 7 | def __init__(self, out_dim=64): 8 | super(ConvEncoder, self).__init__() 9 | self.cnn_encoder = nn.Sequential( 10 | nn.Conv2d(1,32,3,stride=1,padding=1), 11 | nn.BatchNorm2d(32), 12 | nn.ELU(), 13 | nn.Conv2d(32,64,3,stride=2,padding=1), 14 | nn.BatchNorm2d(64), 15 | nn.ELU(), 16 | nn.Conv2d(64,128,3,stride=2,padding=1), 17 | nn.BatchNorm2d(128), 18 | nn.ELU(), 19 | nn.Conv2d(128,out_dim,3,stride=2,padding=1), 20 | nn.Tanh() 21 | ) 22 | 23 | self.out_dim = out_dim 24 | 25 | def forward(self, input): 26 | output = self.cnn_encoder(input) 27 | return output 28 | 29 | 30 | class ConvDecoder(nn.Module): 31 | def __init__(self, b_size=32, inp_dim=64): 32 | super(ConvDecoder, self).__init__() 33 | self.dconv1 = nn.ConvTranspose2d(inp_dim,128,3,stride=2, padding=1) 34 | self.bn1 = nn.BatchNorm2d(128) 35 | self.dconv2 = nn.ConvTranspose2d(128,64,3,stride=2, padding=1) 36 | self.bn2 = nn.BatchNorm2d(64) 37 | self.dconv3 = nn.ConvTranspose2d(64,32,3,stride=2, padding=1) 38 | self.bn3 = nn.BatchNorm2d(32) 39 | self.dconv4 = nn.ConvTranspose2d(32,1,3,stride=1, padding=1) 40 | 41 | self.size1 = torch.Size([b_size * 10, 128, 16, 16]) 42 | self.size2 = torch.Size([b_size * 10, 64, 32, 32]) 43 | self.size3 = torch.Size([b_size * 10, 32, 64, 64]) 44 | 45 | self.inp_dim = inp_dim 46 | 47 | def forward(self, input): 48 | h1 = self.bn1(self.dconv1(input, self.size1)) 49 | a1 = F.elu(h1) 50 | h2 = self.bn2(self.dconv2(a1, self.size2)) 51 | a2 = F.elu(h2) 52 | h3 = self.bn3(self.dconv3(a2, self.size3)) 
53 | a3 = F.elu(h3) 54 | h4 = self.dconv4(a3) 55 | return h4 56 | ''' 57 | 58 | class ConvEncoder(nn.Module): 59 | def __init__(self, out_dim=64): 60 | super(ConvEncoder, self).__init__() 61 | self.cnn_encoder = nn.Sequential( 62 | nn.Conv2d(1,16,3,stride=1,padding=1), 63 | nn.BatchNorm2d(16), 64 | nn.ELU(), 65 | nn.Conv2d(16,32,3,stride=2,padding=1), 66 | nn.BatchNorm2d(32), 67 | nn.ELU(), 68 | nn.Conv2d(32,64,3,stride=1,padding=1), 69 | nn.BatchNorm2d(64), 70 | nn.ELU(), 71 | nn.Conv2d(64,out_dim,3,stride=2,padding=1), 72 | nn.Tanh() 73 | ) 74 | 75 | self.out_dim = out_dim 76 | self.device = 'cpu' 77 | 78 | def cuda(self, device='cuda'): 79 | super(ConvEncoder, self).cuda(device) 80 | self.device = device 81 | 82 | def forward(self, input): 83 | output = self.cnn_encoder(input) 84 | return output 85 | 86 | 87 | class ConvDecoder(nn.Module): 88 | def __init__(self, b_size=32, inp_dim=64): 89 | super(ConvDecoder, self).__init__() 90 | self.dconv1 = nn.ConvTranspose2d(inp_dim,64,3,stride=2, padding=1) 91 | self.bn1 = nn.BatchNorm2d(64) 92 | self.dconv2 = nn.ConvTranspose2d(64,32,3,stride=1, padding=1) 93 | self.bn2 = nn.BatchNorm2d(32) 94 | self.dconv3 = nn.ConvTranspose2d(32,16,3,stride=2, padding=1) 95 | self.bn3 = nn.BatchNorm2d(16) 96 | self.dconv4 = nn.ConvTranspose2d(16,1,3,stride=1, padding=1) 97 | 98 | self.size1 = torch.Size([b_size * 10, 64, 32, 32]) 99 | self.size2 = torch.Size([b_size * 10, 16, 64, 64]) 100 | 101 | self.inp_dim = inp_dim 102 | 103 | def forward(self, input): 104 | h1 = self.bn1(self.dconv1(input, self.size1)) 105 | a1 = F.elu(h1) 106 | h2 = self.bn2(self.dconv2(a1)) 107 | a2 = F.elu(h2) 108 | h3 = self.bn3(self.dconv3(a2, self.size2)) 109 | a3 = F.elu(h3) 110 | h4 = self.dconv4(a3) 111 | return h4 112 | 113 | class MCConvEncoder(nn.Module): 114 | def __init__(self, out_dim=256): 115 | super(MCConvEncoder, self).__init__() 116 | self.encoder = nn.Sequential( 117 | nn.Conv2d(1,64,5,stride=1,padding=2), 118 | nn.BatchNorm2d(64), 119 | nn.ELU(), 120 | nn.MaxPool2d(2,stride=2), 121 | nn.Conv2d(64,128,5,stride=1,padding=2), 122 | nn.BatchNorm2d(128), 123 | nn.ELU(), 124 | nn.MaxPool2d(2,stride=2), 125 | nn.Conv2d(128,out_dim,7,stride=1,padding=3), 126 | nn.Tanh(), 127 | nn.MaxPool2d(2,stride=2), 128 | ) 129 | 130 | self.out_dim = out_dim 131 | 132 | def forward(self, input): 133 | output = self.encoder(input) 134 | return output 135 | 136 | 137 | class MCConvDecoder(nn.Module): 138 | def __init__(self, b_size=32, inp_dim=256): 139 | super(MCConvDecoder, self).__init__() 140 | self.decoder = nn.Sequential( 141 | nn.Upsample(scale_factor=2), 142 | nn.ConvTranspose2d(inp_dim,128,7,stride=1,padding=3), 143 | nn.BatchNorm2d(128), 144 | nn.ELU(), 145 | nn.Upsample(scale_factor=2), 146 | nn.ConvTranspose2d(128,64,5,stride=1,padding=2), 147 | nn.BatchNorm2d(64), 148 | nn.ELU(), 149 | nn.Upsample(scale_factor=2), 150 | nn.ConvTranspose2d(64,1,5,stride=1,padding=2), 151 | ) 152 | 153 | self.inp_dim = inp_dim 154 | 155 | def forward(self, input): 156 | output = self.decoder(input) 157 | return output 158 | 159 | -------------------------------------------------------------------------------- /convgru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | class ConvGRUCell(nn.Module): 8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias=True, activation=F.tanh, batchnorm=False): 9 | """ 10 | Initialize ConvGRU cell. 
11 | Parameters 12 | ---------- 13 | input_size: (int, int) 14 | Height and width of input tensor as (height, width). 15 | input_dim: int 16 | Number of channels of input tensor. 17 | hidden_dim: int 18 | Number of channels of hidden state. 19 | kernel_size: (int, int) 20 | Size of the convolutional kernel. 21 | bias: bool 22 | Whether or not to add the bias. 23 | """ 24 | super(ConvGRUCell, self).__init__() 25 | 26 | self.height, self.width = input_size 27 | self.input_dim = input_dim 28 | self.hidden_dim = hidden_dim 29 | 30 | self.kernel_size = kernel_size 31 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 32 | self.bias = bias 33 | self.activation = activation 34 | self.batchnorm = batchnorm 35 | 36 | 37 | self.conv_zr = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 38 | out_channels=2 * self.hidden_dim, 39 | kernel_size=self.kernel_size, 40 | padding=self.padding, 41 | bias=self.bias) 42 | 43 | self.conv_h1 = nn.Conv2d(in_channels=self.input_dim, 44 | out_channels=self.hidden_dim, 45 | kernel_size=self.kernel_size, 46 | padding=self.padding, 47 | bias=self.bias) 48 | 49 | self.conv_h2 = nn.Conv2d(in_channels=self.hidden_dim, 50 | out_channels=self.hidden_dim, 51 | kernel_size=self.kernel_size, 52 | padding=self.padding, 53 | bias=self.bias) 54 | 55 | self.reset_parameters() 56 | 57 | def forward(self, input, h_prev): 58 | combined = torch.cat((input, h_prev), dim=1) # concatenate along channel axis 59 | 60 | combined_conv = F.sigmoid(self.conv_zr(combined)) 61 | 62 | z, r = torch.split(combined_conv, self.hidden_dim, dim=1) 63 | 64 | h_ = self.activation(self.conv_h1(input) + r * self.conv_h2(h_prev)) 65 | 66 | h_cur = (1 - z) * h_ + z * h_prev 67 | 68 | return h_cur 69 | 70 | def init_hidden(self, batch_size, cuda=True): 71 | state = torch.zeros(batch_size, self.hidden_dim, self.height, self.width) 72 | if cuda: 73 | state = state.cuda() 74 | return state 75 | 76 | def reset_parameters(self): 77 | #self.conv.reset_parameters() 78 | nn.init.xavier_uniform_(self.conv_zr.weight, gain=nn.init.calculate_gain('tanh')) 79 | self.conv_zr.bias.data.zero_() 80 | nn.init.xavier_uniform_(self.conv_h1.weight, gain=nn.init.calculate_gain('tanh')) 81 | self.conv_h1.bias.data.zero_() 82 | nn.init.xavier_uniform_(self.conv_h2.weight, gain=nn.init.calculate_gain('tanh')) 83 | self.conv_h2.bias.data.zero_() 84 | 85 | if self.batchnorm: 86 | self.bn1.reset_parameters() 87 | self.bn2.reset_parameters() 88 | 89 | 90 | class ConvGRU(nn.Module): 91 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, batch_first=True, bias=True, activation=F.tanh, batchnorm=False): 92 | super(ConvGRU, self).__init__() 93 | 94 | self._check_kernel_size_consistency(kernel_size) 95 | 96 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 97 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 98 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 99 | activation = self._extend_for_multilayer(activation, num_layers) 100 | 101 | if not len(kernel_size) == len(hidden_dim) == len(activation) == num_layers: 102 | raise ValueError('Inconsistent list length.') 103 | 104 | self.height, self.width = input_size 105 | 106 | self.input_dim = input_dim 107 | self.hidden_dim = hidden_dim 108 | self.kernel_size = kernel_size 109 | self.num_layers = num_layers 110 | self.batch_first = batch_first 111 | self.bias = bias 112 | 113 | cell_list = [] 114 | for i in range(0, self.num_layers): 115 | cur_input_dim = self.input_dim if i 
== 0 else self.hidden_dim[i-1] 116 | 117 | cell_list.append(ConvGRUCell(input_size=(self.height, self.width), 118 | input_dim=cur_input_dim, 119 | hidden_dim=self.hidden_dim[i], 120 | kernel_size=self.kernel_size[i], 121 | bias=self.bias, 122 | activation=activation[i], 123 | batchnorm=batchnorm)) 124 | 125 | self.cell_list = nn.ModuleList(cell_list) 126 | 127 | self.reset_parameters() 128 | 129 | def forward(self, input, hidden_state): 130 | """ 131 | 132 | Parameters 133 | ---------- 134 | input_tensor: 135 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 136 | hidden_state: 137 | Returns 138 | ------- 139 | last_state_list, layer_output 140 | """ 141 | cur_layer_input = torch.unbind(input, dim=int(self.batch_first)) 142 | 143 | if not hidden_state: 144 | hidden_state = self.get_init_states(cur_layer_input[0].size(0)) 145 | 146 | seq_len = len(cur_layer_input) 147 | 148 | layer_output_list = [] 149 | last_state_list = [] 150 | 151 | for layer_idx in range(self.num_layers): 152 | h = hidden_state[layer_idx] 153 | output_inner = [] 154 | for t in range(seq_len): 155 | h = self.cell_list[layer_idx](input=cur_layer_input[t], 156 | h_prev=h) 157 | output_inner.append(h) 158 | 159 | cur_layer_input = output_inner 160 | last_state_list.append(h) 161 | 162 | layer_output = torch.stack(output_inner, dim=int(self.batch_first)) 163 | 164 | return layer_output, last_state_list 165 | 166 | def reset_parameters(self): 167 | for c in self.cell_list: 168 | c.reset_parameters() 169 | 170 | def get_init_states(self, batch_size, cuda=True): 171 | init_states = [] 172 | for i in range(self.num_layers): 173 | init_states.append(self.cell_list[i].init_hidden(batch_size, cuda)) 174 | return init_states 175 | 176 | @staticmethod 177 | def _check_kernel_size_consistency(kernel_size): 178 | if not (isinstance(kernel_size, tuple) or (isinstance(kernel_size, list) 179 | and all([isinstance(elem, tuple) for elem in kernel_size]))): 180 | raise ValueError('`kernel_size` must be tuple or list of tuples') 181 | 182 | @staticmethod 183 | def _extend_for_multilayer(param, num_layers): 184 | if not isinstance(param, list): 185 | param = [param] * num_layers 186 | return param 187 | 188 | -------------------------------------------------------------------------------- /train_gru_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | import torchvision.utils 11 | import torchvision.datasets as dset 12 | import torchvision.transforms as transforms 13 | 14 | 15 | from convgru import ConvGRU 16 | from cnn import ConvEncoder, ConvDecoder 17 | 18 | import test 19 | from bouncing_mnist import BouncingMnist 20 | 21 | 22 | #-------parameters---- 23 | 24 | b_size = 32 25 | hidden_dim = 128 26 | 27 | lstm_dims = [128,128,128] 28 | 29 | #----------create some dirs--------- 30 | 31 | root_log_dir = './results_gru_' + str(lstm_dims)[1:-1].replace(', ','x') 32 | 33 | exists = True 34 | while exists: 35 | exists = os.path.exists(root_log_dir) 36 | if not exists: 37 | os.makedirs(root_log_dir) 38 | else: 39 | root_log_dir += '_' 40 | 41 | train_out_dir = root_log_dir + '/train' 42 | test_out_dir = root_log_dir + '/test' 43 | 44 | chkpnt_dir = root_log_dir + '/checkpoint' 45 | 46 | os.makedirs(train_out_dir) 47 | os.makedirs(test_out_dir) 48 | os.makedirs(chkpnt_dir) 49 | 50 | log_file = root_log_dir + '/out.log' 51 | f = 
open(log_file,"w+") 52 | f.close() 53 | 54 | #-------some methods--------- 55 | 56 | def save_model(step, train_bce, test_bce): 57 | new_dir = chkpnt_dir + "/{0:0>5}iter | test bce:{1:.4f} | train bce:{2:.4f}".format(step, test_bce, train_bce) 58 | os.makedirs(new_dir) 59 | 60 | torch.save(lstm_encoder.state_dict(), new_dir + '/lstm_encoder.msd') 61 | torch.save(lstm_decoder.state_dict(), new_dir + '/lstm_decoder.msd') 62 | torch.save(cnn_encoder.state_dict(), new_dir + '/cnn_encoder.msd') 63 | torch.save(cnn_decoder.state_dict(), new_dir + '/cnn_decoder.msd') 64 | 65 | def get_sample_prob(step): 66 | alpha = 2450#1150 67 | beta = 8000 68 | return alpha / (alpha + np.exp((step + beta) / alpha)) 69 | 70 | #----------data--------- 71 | 72 | transform = transforms.Compose([ 73 | transforms.Lambda(lambda x: torch.from_numpy(x)) 74 | ]) 75 | 76 | raw_data = BouncingMnist(transform=transform) 77 | dloader = torch.utils.data.DataLoader(raw_data, batch_size=b_size, 78 | shuffle=True, drop_last=True, num_workers=4) 79 | 80 | tester = test.TesterMnist(out_dir=test_out_dir) 81 | 82 | #--------model--------- 83 | 84 | cnn_encoder = ConvEncoder(hidden_dim) 85 | cnn_decoder = ConvDecoder(b_size=b_size,inp_dim=hidden_dim) 86 | 87 | cnn_encoder.cuda() 88 | cnn_decoder.cuda() 89 | 90 | lstm_encoder = ConvGRU( 91 | input_size=(16,16), 92 | input_dim=hidden_dim, 93 | hidden_dim=lstm_dims, 94 | kernel_size=(3,3), 95 | num_layers=3, 96 | batchnorm=False, 97 | batch_first=True, 98 | activation=F.tanh 99 | ) 100 | 101 | lstm_decoder = ConvGRU( 102 | input_size=(16,16), 103 | input_dim=hidden_dim, 104 | hidden_dim=lstm_dims, 105 | kernel_size=(3,3), 106 | num_layers=3, 107 | batchnorm=False, 108 | batch_first=True, 109 | activation=F.tanh 110 | ) 111 | 112 | lstm_encoder.cuda() 113 | lstm_decoder.cuda() 114 | 115 | 116 | sigmoid = nn.Sigmoid() 117 | crit = nn.BCELoss() 118 | crit.cuda() 119 | 120 | 121 | params = list(cnn_encoder.parameters()) + list(cnn_decoder.parameters()) + \ 122 | list(lstm_encoder.parameters()) + list(lstm_decoder.parameters()) 123 | 124 | p_optimizer = optim.Adam(params) 125 | 126 | 127 | 128 | #--------train--------- 129 | 130 | i = 0 131 | 132 | for e in range(100): 133 | for _, batch in enumerate(dloader): 134 | p_optimizer.zero_grad() 135 | 136 | seqs = batch 137 | nextf_raw = seqs[:,10:,:,:,:].cuda() 138 | 139 | #----cnn encoder---- 140 | 141 | prevf_raw = seqs[:,:10,:,:,:].contiguous().view(-1,1,64,64).cuda() 142 | prevf_enc = cnn_encoder(prevf_raw).view(b_size,10,hidden_dim,16,16) 143 | 144 | nextf_enc = cnn_encoder(seqs[:,10:,:,:,:].contiguous().view(-1,1,64,64).cuda()).view(b_size,10,hidden_dim,16,16) 145 | 146 | #----lstm encoder--- 147 | 148 | hidden = lstm_encoder.get_init_states(b_size) 149 | _, encoder_state = lstm_encoder(prevf_enc, hidden) 150 | 151 | #----lstm decoder--- 152 | 153 | sample_prob = get_sample_prob(i) 154 | decoder_output_list = [] 155 | r_hist = [] 156 | for s in range(10): 157 | if s == 0: 158 | decoder_out, decoder_state = lstm_decoder(prevf_enc[:,-1:,:,:,:], encoder_state) 159 | else: 160 | r = np.random.rand() 161 | r_hist.append(int(r > sample_prob)) #debug 162 | if r > sample_prob: 163 | decoder_out, decoder_state = lstm_decoder(decoder_out, decoder_state) 164 | else: 165 | decoder_out, decoder_state = lstm_decoder(nextf_enc[:,s-1:s,:,:,:], decoder_state) 166 | decoder_output_list.append(decoder_out) 167 | 168 | final_decoder_out = torch.cat(decoder_output_list, 1) 169 | 170 | #----cnn decoder---- 171 | 172 | decoder_out_rs = 
final_decoder_out.view(-1,hidden_dim,16,16) 173 | 174 | cnn_decoder_out_raw = F.sigmoid(cnn_decoder(decoder_out_rs)) 175 | cnn_decoder_out = cnn_decoder_out_raw.view(b_size,10,1,64,64) 176 | 177 | #-----predictor loss---------- 178 | 179 | pred_loss = crit(cnn_decoder_out, nextf_raw) 180 | 181 | total_loss = pred_loss 182 | total_loss.backward() 183 | p_optimizer.step() 184 | 185 | #----ouputs--------- 186 | 187 | if i % 100 == 0: 188 | output_str = " | Epoch: {0} | Iter: {1} |\n".format(e, i) 189 | output_str += " TotalLoss: {0:.6f}\n".format(total_loss.item()) 190 | output_str += "-------------------------------------------\n" 191 | output_str += " Sampling prob: {0:.3f}\n".format(sample_prob) 192 | output_str += " Sampling:\n" 193 | output_str += " " + str(r_hist) + "\n" 194 | 195 | l = [] 196 | for j in range(3): 197 | l.append(seqs[j,:10,:,:,:]) 198 | l.append(cnn_decoder_out[j,:,:,:,:].data.cpu()) 199 | l.append(seqs[j,10:,:,:,:]) 200 | samples = torch.cat(l) 201 | torchvision.utils.save_image(samples, 202 | train_out_dir + "/{0:0>5}iter.png".format(i), nrow=10) 203 | 204 | if i % 200 == 0: 205 | #evaluate on the test set 206 | test_mse, test_bce, pnsr, fr_mse = tester.run_test(cnn_encoder,cnn_decoder, lstm_encoder, lstm_decoder) 207 | 208 | output_str += "-------------------------------------------\n" 209 | output_str += " Test BCE: {0:.6f}\n".format(test_bce) 210 | output_str += " Test MSE: {0:.6f}\n".format(test_mse) 211 | output_str += "-------------------------------------------\n" 212 | output_str += " Test PNSR:\n" 213 | output_str += " " + str(pnsr.data.cpu()) + "\n" 214 | output_str += " Frame mse:\n" 215 | output_str += " " + str(fr_mse.data.cpu()) + "\n" 216 | 217 | 218 | if i % 100 == 0: 219 | output_str += "\n=================================================\n" 220 | print(output_str) 221 | with open(log_file, "a") as lf: 222 | lf.write(output_str) 223 | 224 | if i % 400 == 0 and i > 17000: 225 | save_model(i, total_loss.item(), test_bce) 226 | 227 | 228 | #-------------------- 229 | 230 | i += 1 231 | if i >= 35000: 232 | exit() 233 | 234 | -------------------------------------------------------------------------------- /convlstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | class ConvLSTMCell(nn.Module): 8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias=True, activation=F.tanh, peephole=False, batchnorm=False): 9 | """ 10 | Initialize ConvLSTM cell. 11 | Parameters 12 | ---------- 13 | input_size: (int, int) 14 | Height and width of input tensor as (height, width). 15 | input_dim: int 16 | Number of channels of input tensor. 17 | hidden_dim: int 18 | Number of channels of hidden state. 19 | kernel_size: (int, int) 20 | Size of the convolutional kernel. 21 | bias: bool 22 | Whether or not to add the bias. 
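        activation: function
            Nonlinearity used for the candidate cell input and the cell-state
            output (F.tanh by default).
        peephole: bool
            Whether to add peephole connections (Wci, Wcf, Wco) from the cell
            state to the gates.
        batchnorm: bool
            Whether to use batch normalization. Note (added): the cell never
            creates self.bn1/self.bn2, so reset_parameters() raises an
            AttributeError when this flag is True.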
23 | """ 24 | super(ConvLSTMCell, self).__init__() 25 | 26 | self.height, self.width = input_size 27 | self.input_dim = input_dim 28 | self.hidden_dim = hidden_dim 29 | 30 | self.kernel_size = kernel_size 31 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 32 | self.bias = bias 33 | self.activation = activation 34 | self.peephole = peephole 35 | self.batchnorm = batchnorm 36 | 37 | if peephole: 38 | self.Wci = nn.Parameter(torch.FloatTensor(hidden_dim, self.height, self.width)) 39 | self.Wcf = nn.Parameter(torch.FloatTensor(hidden_dim, self.height, self.width)) 40 | self.Wco = nn.Parameter(torch.FloatTensor(hidden_dim, self.height, self.width)) 41 | 42 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 43 | out_channels=4 * self.hidden_dim, 44 | kernel_size=self.kernel_size, 45 | padding=self.padding, 46 | bias=self.bias) 47 | 48 | self.reset_parameters() 49 | 50 | def forward(self, input, prev_state): 51 | h_prev, c_prev = prev_state 52 | 53 | combined = torch.cat((input, h_prev), dim=1) # concatenate along channel axis 54 | combined_conv = self.conv(combined) 55 | 56 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 57 | 58 | if self.peephole: 59 | i = F.sigmoid(cc_i + self.Wci * c_prev) 60 | f = F.sigmoid(cc_f + self.Wcf * c_prev) 61 | else: 62 | i = F.sigmoid(cc_i) 63 | f = F.sigmoid(cc_f) 64 | 65 | g = self.activation(cc_g) 66 | c_cur = f * c_prev + i * g 67 | 68 | if self.peephole: 69 | o = F.sigmoid(cc_o + self.Wco * c_cur) 70 | else: 71 | o = F.sigmoid(cc_o) 72 | 73 | h_cur = o * self.activation(c_cur) 74 | 75 | return h_cur, c_cur 76 | 77 | def init_hidden(self, batch_size, cuda=True, device='cuda'): 78 | state = (torch.zeros(batch_size, self.hidden_dim, self.height, self.width), 79 | torch.zeros(batch_size, self.hidden_dim, self.height, self.width)) 80 | if cuda: 81 | state = (state[0].to(device), state[1].to(device)) 82 | return state 83 | 84 | def reset_parameters(self): 85 | #self.conv.reset_parameters() 86 | nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain('tanh')) 87 | self.conv.bias.data.zero_() 88 | 89 | if self.batchnorm: 90 | self.bn1.reset_parameters() 91 | self.bn2.reset_parameters() 92 | if self.peephole: 93 | std = 1. 
/ math.sqrt(self.hidden_dim) 94 | self.Wci.data.uniform_(0,1)#(std=std) 95 | self.Wcf.data.uniform_(0,1)#(std=std) 96 | self.Wco.data.uniform_(0,1)#(std=std) 97 | 98 | 99 | class ConvLSTM(nn.Module): 100 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, 101 | batch_first=False, bias=True, activation=F.tanh, peephole=False, batchnorm=False): 102 | super(ConvLSTM, self).__init__() 103 | 104 | self._check_kernel_size_consistency(kernel_size) 105 | 106 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 107 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 108 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 109 | activation = self._extend_for_multilayer(activation, num_layers) 110 | 111 | if not len(kernel_size) == len(hidden_dim) == len(activation) == num_layers: 112 | raise ValueError('Inconsistent list length.') 113 | 114 | self.height, self.width = input_size 115 | 116 | self.input_dim = input_dim 117 | self.hidden_dim = hidden_dim 118 | self.kernel_size = kernel_size 119 | self.num_layers = num_layers 120 | self.batch_first = batch_first 121 | self.bias = bias 122 | 123 | cell_list = [] 124 | for i in range(0, self.num_layers): 125 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1] 126 | 127 | cell_list.append(ConvLSTMCell(input_size=(self.height, self.width), 128 | input_dim=cur_input_dim, 129 | hidden_dim=self.hidden_dim[i], 130 | kernel_size=self.kernel_size[i], 131 | bias=self.bias, 132 | activation=activation[i], 133 | peephole=peephole, 134 | batchnorm=batchnorm)) 135 | 136 | self.cell_list = nn.ModuleList(cell_list) 137 | 138 | self.device = 'cpu' 139 | self.reset_parameters() 140 | 141 | def forward(self, input, hidden_state): 142 | """ 143 | 144 | Parameters 145 | ---------- 146 | input_tensor: 147 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 148 | hidden_state: 149 | Returns 150 | ------- 151 | last_state_list, layer_output 152 | """ 153 | cur_layer_input = torch.unbind(input, dim=int(self.batch_first)) 154 | 155 | if not hidden_state: 156 | hidden_state = self.get_init_states(cur_layer_input[0].size(int(not self.batch_first))) 157 | 158 | seq_len = len(cur_layer_input) 159 | 160 | layer_output_list = [] 161 | last_state_list = [] 162 | 163 | for layer_idx in range(self.num_layers): 164 | h, c = hidden_state[layer_idx] 165 | output_inner = [] 166 | for t in range(seq_len): 167 | h, c = self.cell_list[layer_idx](input=cur_layer_input[t], 168 | prev_state=[h, c]) 169 | output_inner.append(h) 170 | 171 | cur_layer_input = output_inner 172 | last_state_list.append((h, c)) 173 | 174 | layer_output = torch.stack(output_inner, dim=int(self.batch_first)) 175 | 176 | return layer_output, last_state_list 177 | 178 | def reset_parameters(self): 179 | for c in self.cell_list: 180 | c.reset_parameters() 181 | 182 | def get_init_states(self, batch_size, cuda=True, device='cuda'): 183 | init_states = [] 184 | for i in range(self.num_layers): 185 | init_states.append(self.cell_list[i].init_hidden(batch_size, cuda, device)) 186 | return init_states 187 | 188 | @staticmethod 189 | def _check_kernel_size_consistency(kernel_size): 190 | if not (isinstance(kernel_size, tuple) or (isinstance(kernel_size, list) 191 | and all([isinstance(elem, tuple) for elem in kernel_size]))): 192 | raise ValueError('`Kernel_size` must be tuple or list of tuples') 193 | 194 | @staticmethod 195 | def _extend_for_multilayer(param, num_layers): 196 | if not isinstance(param, 
list): 197 | param = [param] * num_layers 198 | return param 199 | -------------------------------------------------------------------------------- /train_lstm_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | import torchvision.utils 11 | import torchvision.datasets as dset 12 | import torchvision.transforms as transforms 13 | 14 | from convlstm import ConvLSTM 15 | from cnn import ConvEncoder, ConvDecoder 16 | 17 | import test 18 | from bouncing_mnist import BouncingMnist 19 | 20 | 21 | #-------parameters---- 22 | 23 | b_size = 32 24 | 25 | hidden_dim = 64 26 | hidden_spt = 16 27 | 28 | lstm_dims = [64,64,64] 29 | 30 | teacher_forcing = True 31 | 32 | torch.manual_seed(2) 33 | np.random.seed(2) 34 | torch.backends.cudnn.deterministic = True 35 | 36 | #----------create some dirs--------- 37 | 38 | root_log_dir = './results_verylong_' + ('' if teacher_forcing else 'ntf_') + \ 39 | str(lstm_dims)[1:-1].replace(', ','x') 40 | 41 | exists = True 42 | while exists: 43 | exists = os.path.exists(root_log_dir) 44 | if not exists: 45 | os.makedirs(root_log_dir) 46 | else: 47 | root_log_dir += '_' 48 | 49 | train_out_dir = root_log_dir + '/train' 50 | test_out_dir = root_log_dir + '/test' 51 | 52 | chkpnt_dir = root_log_dir + '/checkpoint' 53 | 54 | os.makedirs(train_out_dir) 55 | os.makedirs(test_out_dir) 56 | os.makedirs(chkpnt_dir) 57 | 58 | log_file = root_log_dir + '/out.log' 59 | f = open(log_file,"w+") 60 | f.close() 61 | 62 | #-------some methods--------- 63 | 64 | def save_model(step, train_bce, test_bce): 65 | new_dir = chkpnt_dir + "/{0:0>5}_iter|test_bce:{1:.4f}|train_bce:{2:.4f}".format(step, test_bce, train_bce) 66 | os.makedirs(new_dir) 67 | 68 | torch.save(lstm_encoder.state_dict(), new_dir + '/lstm_encoder.msd') 69 | torch.save(lstm_decoder.state_dict(), new_dir + '/lstm_decoder.msd') 70 | torch.save(cnn_encoder.state_dict(), new_dir + '/cnn_encoder.msd') 71 | torch.save(cnn_decoder.state_dict(), new_dir + '/cnn_decoder.msd') 72 | 73 | def get_sample_prob(step): 74 | alpha = 2450#1150 75 | beta = 8000 76 | return alpha / (alpha + np.exp((step + beta) / alpha)) 77 | 78 | #----------data--------- 79 | 80 | transform = transforms.Compose([ 81 | transforms.Lambda(lambda x: torch.from_numpy(x)) 82 | ]) 83 | 84 | raw_data = BouncingMnist(transform=transform) 85 | dloader = torch.utils.data.DataLoader(raw_data, batch_size=b_size, 86 | shuffle=True, drop_last=True, num_workers=4) 87 | 88 | tester = test.TesterMnist(out_dir=test_out_dir) 89 | 90 | #--------model--------- 91 | 92 | cnn_encoder = ConvEncoder(hidden_dim) 93 | cnn_decoder = ConvDecoder(b_size=b_size,inp_dim=hidden_dim) 94 | 95 | cnn_encoder.cuda() 96 | cnn_decoder.cuda() 97 | 98 | lstm_encoder = ConvLSTM( 99 | input_size=(hidden_spt,hidden_spt), 100 | input_dim=hidden_dim, 101 | hidden_dim=lstm_dims, 102 | kernel_size=(3,3), 103 | num_layers=3, 104 | peephole=True, 105 | batchnorm=False, 106 | batch_first=True, 107 | activation=F.tanh 108 | ) 109 | 110 | lstm_decoder = ConvLSTM( 111 | input_size=(hidden_spt,hidden_spt), 112 | input_dim=hidden_dim, 113 | hidden_dim=lstm_dims, 114 | kernel_size=(3,3), 115 | num_layers=3, 116 | peephole=True, 117 | batchnorm=False, 118 | batch_first=True, 119 | activation=F.tanh 120 | ) 121 | 122 | lstm_encoder.cuda() 123 | lstm_decoder.cuda() 124 | 125 | 126 | sigmoid = nn.Sigmoid() 127 | 
crit = nn.BCELoss() 128 | crit.cuda() 129 | 130 | 131 | params = list(cnn_encoder.parameters()) + list(cnn_decoder.parameters()) + \ 132 | list(lstm_encoder.parameters()) + list(lstm_decoder.parameters()) 133 | 134 | p_optimizer = optim.Adam(params) 135 | 136 | #--------train--------- 137 | 138 | i = 0 139 | 140 | for e in range(100): 141 | for _, batch in enumerate(dloader): 142 | p_optimizer.zero_grad() 143 | 144 | seqs = batch 145 | nextf_raw = seqs[:,10:,:,:,:].cuda() 146 | 147 | #----cnn encoder---- 148 | 149 | prevf_raw = seqs[:,:10,:,:,:].contiguous().view(-1,1,64,64).cuda() 150 | prevf_enc = cnn_encoder(prevf_raw).view(b_size,10,hidden_dim,hidden_spt,hidden_spt) 151 | 152 | if teacher_forcing: 153 | cnn_encoder_out = cnn_encoder(seqs[:,10:,:,:,:].contiguous().view(-1,1,64,64).cuda()) 154 | nextf_enc = cnn_encoder_out.view(b_size,10,hidden_dim,hidden_spt,hidden_spt) 155 | 156 | #----lstm encoder--- 157 | 158 | hidden = lstm_encoder.get_init_states(b_size) 159 | _, encoder_state = lstm_encoder(prevf_enc, hidden) 160 | 161 | #----lstm decoder--- 162 | 163 | sample_prob = get_sample_prob(i) if teacher_forcing else 0 164 | decoder_output_list = [] 165 | r_hist = [] 166 | 167 | for s in range(10): 168 | if s == 0: 169 | decoder_out, decoder_state = lstm_decoder(prevf_enc[:,-1:,:,:,:], encoder_state) 170 | else: 171 | r = np.random.rand() 172 | r_hist.append(int(r > sample_prob)) #debug 173 | 174 | if r > sample_prob: 175 | decoder_out, decoder_state = lstm_decoder(decoder_out, decoder_state) 176 | else: 177 | decoder_out, decoder_state = lstm_decoder(nextf_enc[:,s-1:s,:,:,:], decoder_state) 178 | 179 | decoder_output_list.append(decoder_out) 180 | 181 | final_decoder_out = torch.cat(decoder_output_list, 1) 182 | 183 | #----cnn decoder---- 184 | 185 | decoder_out_rs = final_decoder_out.view(-1,hidden_dim,hidden_spt,hidden_spt) 186 | 187 | cnn_decoder_out_raw = F.sigmoid(cnn_decoder(decoder_out_rs)) 188 | cnn_decoder_out = cnn_decoder_out_raw.view(b_size,10,1,64,64) 189 | 190 | #-----predictor loss---------- 191 | 192 | pred_loss = crit(cnn_decoder_out, nextf_raw) 193 | 194 | total_loss = pred_loss 195 | total_loss.backward() 196 | p_optimizer.step() 197 | 198 | #----ouputs--------- 199 | 200 | if i % 100 == 0: 201 | output_str = " | Epoch: {0} | Iter: {1} |\n".format(e, i) 202 | output_str += " TotalLoss: {0:.6f}\n".format(total_loss.item()) 203 | output_str += "-------------------------------------------\n" 204 | output_str += " Sampling prob: {0:.3f}\n".format(sample_prob) 205 | output_str += " Sampling:\n" 206 | output_str += " " + str(r_hist) + "\n" 207 | 208 | l = [] 209 | for j in range(3): 210 | l.append(seqs[j,:10,:,:,:]) 211 | l.append(cnn_decoder_out[j,:,:,:,:].data.cpu()) 212 | l.append(seqs[j,10:,:,:,:]) 213 | samples = torch.cat(l) 214 | torchvision.utils.save_image(samples, 215 | train_out_dir + "/{0:0>5}iter.png".format(i), nrow=10) 216 | 217 | if i % 200 == 0: 218 | #evaluate on the test set 219 | test_mse, test_bce, pnsr, fr_mse = tester.run_test(cnn_encoder,cnn_decoder, lstm_encoder, lstm_decoder) 220 | if i == 0: 221 | min_test_bce = test_bce 222 | 223 | output_str += "-------------------------------------------\n" 224 | output_str += " Test BCE: {0:.6f}\n".format(test_bce) 225 | output_str += " Test MSE: {0:.6f}\n".format(test_mse) 226 | output_str += "-------------------------------------------\n" 227 | output_str += " Test PNSR:\n" 228 | output_str += " " + str(pnsr.data.cpu()) + "\n" 229 | output_str += " Frame mse:\n" 230 | output_str += " " + 
str(fr_mse.data.cpu()) + "\n" 231 | 232 | if i % 100 == 0: 233 | output_str += "\n=================================================\n\n" 234 | print(output_str) 235 | with open(log_file, "a") as lf: 236 | lf.write(output_str) 237 | 238 | if i > 35000 and (i % 400 == 0 or test_bce < min_test_bce): 239 | min_test_bce = min(min_test_bce, test_bce) 240 | save_model(i, total_loss.item(), test_bce) 241 | 242 | #-------------------- 243 | 244 | i += 1 245 | if i >= 60000: 246 | exit() 247 | 248 | --------------------------------------------------------------------------------
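A minimal shape check for the `ConvLSTM` wrapper (an illustrative sketch, not one of the repository's own scripts; the file name and parameter values below are assumptions, and `torch.tanh` stands in for the deprecated `F.tanh`):
--------------------------------------------------------------------------------
/smoke_test.py (illustrative):
--------------------------------------------------------------------------------
import torch

from convlstm import ConvLSTM

# Illustrative dimensions: batch, time steps, channels, spatial size.
b, t, c, hw = 2, 10, 64, 16

model = ConvLSTM(
    input_size=(hw, hw),
    input_dim=c,
    hidden_dim=[64, 64, 64],
    kernel_size=(3, 3),
    num_layers=3,
    peephole=True,
    batchnorm=False,
    batch_first=True,
    activation=torch.tanh,  # F.tanh is deprecated in recent PyTorch
)

x = torch.randn(b, t, c, hw, hw)               # (batch, time, channels, height, width)
hidden = model.get_init_states(b, cuda=False)  # one (h, c) pair per layer, on CPU
output, last_states = model(x, hidden)

# The output stacks the top layer's hidden state over all time steps.
assert output.shape == (b, t, 64, hw, hw)
assert len(last_states) == model.num_layers
print("ConvLSTM forward OK:", tuple(output.shape))
--------------------------------------------------------------------------------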