├── LICENSE ├── README.md ├── dataset.py ├── generate.py ├── get-moving-mnist.sh ├── imgs ├── 7000_cap.png ├── 7001.png ├── 7002.png ├── 7003.png └── 7004.png ├── network.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvLSTM 2 | 3 | Convolutional LSTM implemented with chainer 4 | 5 | `python 3.5.2` + `chainer 3.0.0` 6 | 7 | ## Getting dataset 8 | 9 | ``` 10 | $ ./get-moving-mnist.sh 11 | ``` 12 | 13 | ## Training 14 | 15 | ``` 16 | $ python3 ./train.py -g 0 --epoch 10 --inf 3 --outf 3 --batch 16 17 | ``` 18 | 19 | ## Generating 20 | 21 | ``` 22 | $ mkdir img 23 | $ python3 ./generate.py --model results/model --id 7000 --inf 3 --outf 3 24 | ``` 25 | 26 | Then, the images are generated in `img/`. 27 | 28 | ``` 29 | $ mkdir img 30 | $ wget -O model https://www.dropbox.com/s/pthwoljp8qq3o30/model?dl=0 31 | $ python3 ./generate.py --model model --id 7000 --inf 3 --outf 3 32 | ``` 33 | 34 | Or, you can try ConvLSTM with the pre-trained model. 35 | 36 | ## Gallery 37 | 38 | ![7000_cap](https://github.com/joisino/ConvLSTM/blob/master/imgs/7000_cap.png) 39 | 40 | ![7001](https://github.com/joisino/ConvLSTM/blob/master/imgs/7001.png) 41 | 42 | ![7002](https://github.com/joisino/ConvLSTM/blob/master/imgs/7002.png) 43 | 44 | ![7003](https://github.com/joisino/ConvLSTM/blob/master/imgs/7003.png) 45 | 46 | ![7004](https://github.com/joisino/ConvLSTM/blob/master/imgs/7004.png) 47 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | 4 | class MovingMnistDataset(chainer.dataset.DatasetMixin): 5 | def __init__(self, l, r, inn, outn, path="./mnist_test_seq.npy"): 6 | self.l = l 7 | self.r = r 8 | self.inn = inn 9 | self.outn = outn 10 | self.data = np.load(path) 11 | self.data[self.data < 128] = 0 12 | self.data[self.data >= 128] = 1 13 | 14 | def __len__(self): 15 | return self.r - self.l 16 | 17 | def get_example(self, i): 18 | ind = self.l + i 19 | return self.data[:self.inn, ind, :, :].astype(np.int32), self.data[self.inn:self.inn+self.outn, ind, :, :].astype(np.int32) 20 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import numpy as np 5 | import argparse 6 | import chainer 7 | from chainer import serializers 8 | from chainer import Variable 9 | from chainer import cuda 10 | import dataset 11 | import network 12 | from PIL import Image 13 | 14 | def generate(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--gpu', '-g', type=int, default=-1) 17 | parser.add_argument('--model', '-m', type=str, default=None) 18 | parser.add_argument('--id', '-i', type=int, default=0) 19 | parser.add_argument('--inf', type=int, default=10) 20 | parser.add_argument('--outf', type=int, default=10) 21 | args = parser.parse_args() 22 | 23 | test = dataset.MovingMnistDataset(0, 10000, args.inf, args.outf) 24 | 25 | model = network.MovingMnistNetwork(sz=[128, 64, 64], n=2, directory="img/") 26 | 27 | if args.model != None: 28 | print( "loading model from " + args.model ) 29 | serializers.load_npz(args.model, model) 30 | 31 | x, t = test[args.id] 32 | 33 | x = np.expand_dims(x, 0) 34 | t = np.expand_dims(t, 0) 35 | 36 | if args.gpu >= 0: 37 | cuda.get_device_from_id(0).use() 38 | model.to_gpu() 39 | x = cuda.cupy.array(x) 40 | t = cuda.cupy.array(t) 41 | 42 | res = model(Variable(x), Variable(t)) 43 | 44 | 45 | if __name__ == '__main__': 46 | generate() 47 | -------------------------------------------------------------------------------- /get-moving-mnist.sh: -------------------------------------------------------------------------------- 1 | curl -o mnist_test_seq.npy http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy 2 | -------------------------------------------------------------------------------- /imgs/7000_cap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joisino/ConvLSTM/dc8dcac4c9d44b3e93566bc0ee7a85ffb3a17b7f/imgs/7000_cap.png -------------------------------------------------------------------------------- /imgs/7001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joisino/ConvLSTM/dc8dcac4c9d44b3e93566bc0ee7a85ffb3a17b7f/imgs/7001.png -------------------------------------------------------------------------------- /imgs/7002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joisino/ConvLSTM/dc8dcac4c9d44b3e93566bc0ee7a85ffb3a17b7f/imgs/7002.png -------------------------------------------------------------------------------- /imgs/7003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joisino/ConvLSTM/dc8dcac4c9d44b3e93566bc0ee7a85ffb3a17b7f/imgs/7003.png -------------------------------------------------------------------------------- /imgs/7004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joisino/ConvLSTM/dc8dcac4c9d44b3e93566bc0ee7a85ffb3a17b7f/imgs/7004.png -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | from chainer import Variable 4 | from chainer import variable 5 | from chainer import reporter 6 | from chainer import initializers 7 | from chainer import Link, Chain 8 | import chainer.functions as F 9 | import chainer.links as L 10 | from PIL import Image 11 | 12 | class ConvLSTM(Chain): 13 | def __init__(self, inp = 256, mid = 128, sz = 3): 14 | super(ConvLSTM, self).__init__( 15 | Wxi = L.Convolution2D(inp, mid, sz, pad = sz//2), 16 | Whi = L.Convolution2D(mid, mid, sz, pad = sz//2, nobias = True), 17 | Wxf = L.Convolution2D(inp, mid, sz, pad = sz//2), 18 | Whf = L.Convolution2D(mid, mid, sz, pad = sz//2, nobias = True), 19 | Wxc = L.Convolution2D(inp, mid, sz, pad = sz//2), 20 | Whc = L.Convolution2D(mid, mid, sz, pad = sz//2, nobias = True), 21 | Wxo = L.Convolution2D(inp, mid, sz, pad = sz//2), 22 | Who = L.Convolution2D(mid, mid, sz, pad = sz//2, nobias = True) 23 | ) 24 | 25 | self.inp = inp 26 | self.mid = mid 27 | 28 | self.pc = None 29 | self.ph = None 30 | 31 | with self.init_scope(): 32 | Wci_initializer = initializers.Zero() 33 | self.Wci = variable.Parameter(Wci_initializer) 34 | Wcf_initializer = initializers.Zero() 35 | self.Wcf = variable.Parameter(Wcf_initializer) 36 | Wco_initializer = initializers.Zero() 37 | self.Wco = variable.Parameter(Wco_initializer) 38 | 39 | def reset_state(self, pc = None, ph = None): 40 | self.pc = pc 41 | self.ph = ph 42 | 43 | def initialize_params(self, shape): 44 | self.Wci.initialize((self.mid, shape[2], shape[3])) 45 | self.Wcf.initialize((self.mid, shape[2], shape[3])) 46 | self.Wco.initialize((self.mid, shape[2], shape[3])) 47 | 48 | def initialize_state(self, shape): 49 | self.pc = Variable(self.xp.zeros((shape[0], self.mid, shape[2], shape[3]), dtype = self.xp.float32)) 50 | self.ph = Variable(self.xp.zeros((shape[0], self.mid, shape[2], shape[3]), dtype = self.xp.float32)) 51 | 52 | def __call__(self, x): 53 | if self.Wci.data is None: 54 | self.initialize_params(x.data.shape) 55 | 56 | if self.pc is None: 57 | self.initialize_state(x.data.shape) 58 | 59 | ci = F.sigmoid(self.Wxi(x) + self.Whi(self.ph) + F.scale(self.pc, self.Wci, 1)) 60 | cf = F.sigmoid(self.Wxf(x) + self.Whf(self.ph) + F.scale(self.pc, self.Wcf, 1)) 61 | cc = cf * self.pc + ci * F.tanh(self.Wxc(x) + self.Whc(self.ph)) 62 | co = F.sigmoid(self.Wxo(x) + self.Who(self.ph) + F.scale(cc, self.Wco, 1)) 63 | ch = co * F.tanh(cc) 64 | 65 | self.pc = cc 66 | self.ph = ch 67 | 68 | return ch 69 | 70 | 71 | class MovingMnistNetwork(Chain): 72 | def __init__(self, sz=[256, 128, 128], n=256, directory=None): 73 | super(MovingMnistNetwork, self).__init__( 74 | e1 = ConvLSTM(n, sz[0], 5), 75 | e2 = ConvLSTM(sz[0], sz[1], 5), 76 | e3 = ConvLSTM(sz[1], sz[2], 5), 77 | p1 = ConvLSTM(n, sz[0], 5), 78 | p2 = ConvLSTM(sz[0], sz[1], 5), 79 | p3 = ConvLSTM(sz[1], sz[2], 5), 80 | last = L.Convolution2D(sum(sz), n, 1) 81 | ) 82 | 83 | self.n = n 84 | self.directory = directory 85 | 86 | def save_image(self, arr, filename): 87 | img = chainer.cuda.to_cpu(arr) 88 | img = img * 255 89 | img = Image.fromarray(img) 90 | if img.mode != 'RGB': 91 | img = img.convert('RGB') 92 | img.save(filename) 93 | 94 | def __call__(self, x, t): 95 | self.e1.reset_state() 96 | self.e2.reset_state() 97 | self.e3.reset_state() 98 | 99 | We = self.xp.array([[i == j for i in range(self.n)] for j in range(self.n)], dtype=self.xp.float32) 100 | for i in range(x.shape[1]): 101 | 102 | # save input images 103 | if self.directory is not None: 104 | for j in range(x.shape[0]): 105 | filename = self.directory + "input" + str(j) + "-" + str(i) + ".png" 106 | self.save_image(x[j, i, :, :].data, filename) 107 | 108 | xi = F.embed_id(x[:, i, :, :], We) 109 | xi = F.transpose(xi, (0, 3, 1, 2)) 110 | 111 | h1 = self.e1(xi) 112 | h2 = self.e2(h1) 113 | self.e3(h2) 114 | 115 | self.p1.reset_state(self.e1.pc, self.e1.ph) 116 | self.p2.reset_state(self.e2.pc, self.e2.ph) 117 | self.p3.reset_state(self.e3.pc, self.e3.ph) 118 | 119 | loss = None 120 | 121 | for i in range(t.shape[1]): 122 | xs = x.shape 123 | 124 | h1 = self.p1(Variable(self.xp.zeros((xs[0], self.n, xs[2], xs[3]), dtype=self.xp.float32))) 125 | h2 = self.p2(h1) 126 | h3 = self.p3(h2) 127 | 128 | h = F.concat((h1, h2, h3)) 129 | ans = self.last(h) 130 | 131 | # save output and teacher images 132 | if self.directory is not None: 133 | for j in range(t.shape[0]): 134 | filename = self.directory + "truth" + str(j) + "-" + str(i) + ".png" 135 | self.save_image(t[j, i, :, :].data, filename) 136 | filename = self.directory + "output" + str(j) + "-" + str(i) + ".png" 137 | self.save_image(self.xp.argmax(ans[j, :, :, :].data, 0).astype(np.int32), filename) 138 | 139 | cur_loss = F.softmax_cross_entropy(ans, t[:, i, :, :]) 140 | loss = cur_loss if loss is None else loss + cur_loss 141 | 142 | reporter.report({'loss': loss}, self) 143 | 144 | return loss 145 | 146 | 147 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import numpy as np 5 | import argparse 6 | import chainer 7 | from chainer import training 8 | from chainer import iterators, optimizers, serializers 9 | from chainer import cuda 10 | from chainer.training import extensions 11 | import dataset 12 | import network 13 | 14 | def train(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--gpu', '-g', type=int, default=-1) 17 | parser.add_argument('--model', '-m', type=str, default=None) 18 | parser.add_argument('--opt', type=str, default=None) 19 | parser.add_argument('--epoch', '-e', type=int, default=3) 20 | parser.add_argument('--lr', '-l', type=float, default=0.001) 21 | parser.add_argument('--inf', type=int, default=10) 22 | parser.add_argument('--outf', type=int, default=10) 23 | parser.add_argument('--batch', '-b', type=int, default=8) 24 | args = parser.parse_args() 25 | 26 | train = dataset.MovingMnistDataset(0, 7000, args.inf, args.outf) 27 | train_iter = iterators.SerialIterator(train, batch_size=args.batch, shuffle=True) 28 | test = dataset.MovingMnistDataset(7000, 10000, args.inf, args.outf) 29 | test_iter = iterators.SerialIterator(test, batch_size=args.batch, repeat=False, shuffle=False) 30 | 31 | model = network.MovingMnistNetwork(sz=[128,64,64], n=2) 32 | 33 | if args.model != None: 34 | print( "loading model from " + args.model ) 35 | serializers.load_npz(args.model, model) 36 | 37 | if args.gpu >= 0: 38 | cuda.get_device_from_id(0).use() 39 | model.to_gpu() 40 | 41 | opt = optimizers.Adam(alpha=args.lr) 42 | opt.setup(model) 43 | 44 | if args.opt != None: 45 | print( "loading opt from " + args.opt ) 46 | serializers.load_npz(args.opt, opt) 47 | 48 | updater = training.StandardUpdater(train_iter, opt, device=args.gpu) 49 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out='results') 50 | 51 | trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu)) 52 | trainer.extend(extensions.LogReport(trigger=(10, 'iteration'))) 53 | trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss'])) 54 | trainer.extend(extensions.ProgressBar(update_interval=1)) 55 | 56 | trainer.run() 57 | 58 | modelname = "./results/model" 59 | print( "saving model to " + modelname ) 60 | serializers.save_npz(modelname, model) 61 | 62 | optname = "./results/opt" 63 | print( "saving opt to " + optname ) 64 | serializers.save_npz(optname, opt) 65 | 66 | if __name__ == '__main__': 67 | train() 68 | --------------------------------------------------------------------------------