├── download_data.sh
├── download_models.sh
├── environment.yml
├── kitti_settings.py
├── License.txt
├── .gitignore
├── keras_utils.py
├── kitti_evaluate.py
├── kitti_extrap_finetune.py
├── kitti_train.py
├── data_utils.py
├── README.md
├── process_kitti.py
└── prednet.py
/download_data.sh:
--------------------------------------------------------------------------------
1 | savedir="kitti_data"
2 | mkdir -p -- "$savedir"
3 | wget https://www.dropbox.com/s/rpwlnn6j39jjme4/kitti_data.zip?dl=0 -O $savedir/prednet_kitti_data.zip
4 | unzip $savedir/prednet_kitti_data.zip -d $savedir
5 |
--------------------------------------------------------------------------------
/download_models.sh:
--------------------------------------------------------------------------------
1 | savedir="model_data_keras2"
2 | mkdir -p -- "$savedir"
3 | wget https://www.dropbox.com/s/iutxm0anhxqca0z/model_data_keras2.zip?dl=0 -O $savedir/model_data_keras2.zip
4 | unzip $savedir/model_data_keras2.zip -d $savedir
5 | rm $savedir/model_data_keras2.zip
6 | mv $savedir/model_data_keras2/* $savedir
7 | rm -r $savedir/model_data_keras2
8 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # This file allows you to configure your python environment quickly for prednet.
2 | # Please refer to the conda documentation for more info:
3 | # https://conda.io/docs/user-guide/tasks/manage-environments.html
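#
# To create and activate the environment with conda:
#   conda env create -f environment.yml
#   source activate prednet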
4 |
5 | name: prednet
6 |
7 | dependencies:
8 | - python=2.7
9 |
10 | - pip:
11 | - keras==2.0.6
12 | - theano==0.9.0
13 | - tensorflow-gpu==1.2.1
14 | - hickle==2.1.0
15 | - matplotlib
--------------------------------------------------------------------------------
/kitti_settings.py:
--------------------------------------------------------------------------------
1 | # Where KITTI data will be saved if you run process_kitti.py
2 | # If you directly download the processed data, change to the path of the data.
3 | DATA_DIR = './kitti_data/'
4 |
5 | # Where model weights and config will be saved if you run kitti_train.py
6 | # If you directly download the trained weights, change to appropriate path.
7 | WEIGHTS_DIR = './model_data_keras2/'
8 |
9 | # Where results (prediction plots and evaluation file) will be saved.
10 | RESULTS_SAVE_DIR = './kitti_results/'
11 |
--------------------------------------------------------------------------------
/License.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 coxlab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Video and data folders
10 | kitti_data/
11 | model_data_keras*/
12 |
13 | # Distribution / packaging
14 | .Python
15 | env/
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *,cover
50 | .hypothesis/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # IPython Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | venv/
87 | ENV/
88 |
89 | # Spyder project settings
90 | .spyderproject
91 |
92 | # Rope project settings
93 | .ropeproject
94 |
--------------------------------------------------------------------------------
/keras_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from keras import backend as K
5 | from keras.legacy.interfaces import generate_legacy_interface, recurrent_args_preprocessor
6 | from keras.models import model_from_json
7 |
8 | legacy_prednet_support = generate_legacy_interface(
9 | allowed_positional_args=['stack_sizes', 'R_stack_sizes',
10 | 'A_filt_sizes', 'Ahat_filt_sizes', 'R_filt_sizes'],
11 | conversions=[('dim_ordering', 'data_format'),
12 | ('consume_less', 'implementation')],
13 | value_conversions={'dim_ordering': {'tf': 'channels_last',
14 | 'th': 'channels_first',
15 | 'default': None},
16 | 'consume_less': {'cpu': 0,
17 | 'mem': 1,
18 | 'gpu': 2}},
19 | preprocessor=recurrent_args_preprocessor)
20 |
21 | # Convert old Keras (1.2) json models and weights to Keras 2.0
22 | def convert_model_to_keras2(old_json_file, old_weights_file, new_json_file, new_weights_file):
23 | from prednet import PredNet
24 | # If using tensorflow, it doesn't allow you to load the old weights.
25 | if K.backend() != 'theano':
26 | os.environ['KERAS_BACKEND'] = 'theano'  # fall back to the Theano backend so the old weights can be loaded
27 | reload(K)
28 |
29 | f = open(old_json_file, 'r')
30 | json_string = f.read()
31 | f.close()
32 | model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
33 | model.load_weights(old_weights_file)
34 |
35 | weights = model.layers[1].get_weights()
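# Old Keras 1 / Theano-format conv kernels are stored as (out_channels, in_channels, rows, cols);
# if the first kernel's leading axis matches its output-filter count, transpose every 4D kernel
# to the Keras 2 (rows, cols, in_channels, out_channels) layout.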
36 | if weights[0].shape[0] == model.layers[1].stack_sizes[1]:
37 | for i, w in enumerate(weights):
38 | if w.ndim == 4:
39 | weights[i] = np.transpose(w, (2, 3, 1, 0))
40 | model.set_weights(weights)
41 |
42 | model.save_weights(new_weights_file)
43 | json_string = model.to_json()
44 | with open(new_json_file, "w") as f:
45 | f.write(json_string)
46 |
47 |
48 | if __name__ == '__main__':
49 | old_dir = './model_data/'
50 | new_dir = './model_data_keras2/'
51 | if not os.path.exists(new_dir):
52 | os.mkdir(new_dir)
53 | for w_tag in ['', '-Lall', '-extrapfinetuned']:
54 | m_tag = '' if w_tag == '-Lall' else w_tag
55 | convert_model_to_keras2(old_dir + 'prednet_kitti_model' + m_tag + '.json',
56 | old_dir + 'prednet_kitti_weights' + w_tag + '.hdf5',
57 | new_dir + 'prednet_kitti_model' + m_tag + '.json',
58 | new_dir + 'prednet_kitti_weights' + w_tag + '.hdf5')
59 |
--------------------------------------------------------------------------------
/kitti_evaluate.py:
--------------------------------------------------------------------------------
1 | '''
2 | Evaluate trained PredNet on KITTI sequences.
3 | Calculates mean-squared error and plots predictions.
4 | '''
5 |
6 | import os
7 | import numpy as np
8 | from six.moves import cPickle
9 | import matplotlib
10 | matplotlib.use('Agg')
11 | import matplotlib.pyplot as plt
12 | import matplotlib.gridspec as gridspec
13 |
14 | from keras import backend as K
15 | from keras.models import Model, model_from_json
16 | from keras.layers import Input, Dense, Flatten
17 |
18 | from prednet import PredNet
19 | from data_utils import SequenceGenerator
20 | from kitti_settings import *
21 |
22 |
23 | n_plot = 40
24 | batch_size = 10
25 | nt = 10
26 |
27 | weights_file = os.path.join(WEIGHTS_DIR, 'tensorflow_weights/prednet_kitti_weights.hdf5')
28 | json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')
29 | test_file = os.path.join(DATA_DIR, 'X_test.hkl')
30 | test_sources = os.path.join(DATA_DIR, 'sources_test.hkl')
31 |
32 | # Load trained model
33 | f = open(json_file, 'r')
34 | json_string = f.read()
35 | f.close()
36 | train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
37 | train_model.load_weights(weights_file)
38 |
39 | # Create testing model (to output predictions)
40 | layer_config = train_model.layers[1].get_config()
41 | layer_config['output_mode'] = 'prediction'
42 | data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']
43 | test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)
44 | input_shape = list(train_model.layers[0].batch_input_shape[1:])
45 | input_shape[0] = nt
46 | inputs = Input(shape=tuple(input_shape))
47 | predictions = test_prednet(inputs)
48 | test_model = Model(inputs=inputs, outputs=predictions)
49 |
50 | test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format)
51 | X_test = test_generator.create_all()
52 | X_hat = test_model.predict(X_test, batch_size)
53 | if data_format == 'channels_first':
54 | X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
55 | X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))
56 |
57 | # Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt
58 | mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first
59 | mse_prev = np.mean( (X_test[:, :-1] - X_test[:, 1:])**2 )
60 | if not os.path.exists(RESULTS_SAVE_DIR): os.mkdir(RESULTS_SAVE_DIR)
61 | f = open(RESULTS_SAVE_DIR + 'prediction_scores.txt', 'w')
62 | f.write("Model MSE: %f\n" % mse_model)
63 | f.write("Previous Frame MSE: %f" % mse_prev)
64 | f.close()
65 |
66 | # Plot some predictions
67 | aspect_ratio = float(X_hat.shape[2]) / X_hat.shape[3]
68 | plt.figure(figsize = (nt, 2*aspect_ratio))
69 | gs = gridspec.GridSpec(2, nt)
70 | gs.update(wspace=0., hspace=0.)
71 | plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots/')
72 | if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir)
73 | plot_idx = np.random.permutation(X_test.shape[0])[:n_plot]
74 | for i in plot_idx:
75 | for t in range(nt):
76 | plt.subplot(gs[t])
77 | plt.imshow(X_test[i,t], interpolation='none')
78 | plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')
79 | if t==0: plt.ylabel('Actual', fontsize=10)
80 |
81 | plt.subplot(gs[t + nt])
82 | plt.imshow(X_hat[i,t], interpolation='none')
83 | plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')
84 | if t==0: plt.ylabel('Predicted', fontsize=10)
85 |
86 | plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png')
87 | plt.clf()
88 |
--------------------------------------------------------------------------------
/kitti_extrap_finetune.py:
--------------------------------------------------------------------------------
1 | '''
2 | Fine-tune PredNet model trained for t+1 prediction for up to t+5 prediction.
3 | '''
4 |
5 | import os
6 | import numpy as np
7 | np.random.seed(123)
8 |
9 | from keras import backend as K
10 | from keras.models import Model, model_from_json
11 | from keras.layers import Input
12 | from keras.callbacks import LearningRateScheduler, ModelCheckpoint
13 |
14 | from prednet import PredNet
15 | from data_utils import SequenceGenerator
16 | from kitti_settings import *
17 |
18 | # Define loss as MAE of frame predictions after t=0
19 | # It doesn't make sense to compute loss on error representation, since the error isn't wrt ground truth when extrapolating.
20 | def extrap_loss(y_true, y_hat):
21 | y_true = y_true[:, 1:]
22 | y_hat = y_hat[:, 1:]
23 | return 0.5 * K.mean(K.abs(y_true - y_hat), axis=-1) # 0.5 to match scale of loss when trained in error mode (positive and negative errors split)
24 |
25 | nt = 15
26 | extrap_start_time = 10 # starting at this time step, the prediction from the previous time step will be treated as the actual input
27 | orig_weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5') # original t+1 weights
28 | orig_json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')
29 |
30 | save_model = True
31 | extrap_weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights-extrapfinetuned.hdf5') # where new weights will be saved
32 | extrap_json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model-extrapfinetuned.json')
33 |
34 | # Data files
35 | train_file = os.path.join(DATA_DIR, 'X_train.hkl')
36 | train_sources = os.path.join(DATA_DIR, 'sources_train.hkl')
37 | val_file = os.path.join(DATA_DIR, 'X_val.hkl')
38 | val_sources = os.path.join(DATA_DIR, 'sources_val.hkl')
39 |
40 | # Training parameters
41 | nb_epoch = 150
42 | batch_size = 4
43 | samples_per_epoch = 500
44 | N_seq_val = 100 # number of sequences to use for validation
45 |
46 | # Load t+1 model
47 | f = open(orig_json_file, 'r')
48 | json_string = f.read()
49 | f.close()
50 | orig_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
51 | orig_model.load_weights(orig_weights_file)
52 |
53 | layer_config = orig_model.layers[1].get_config()
54 | layer_config['output_mode'] = 'prediction'
55 | layer_config['extrap_start_time'] = extrap_start_time
56 | data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']
57 | prednet = PredNet(weights=orig_model.layers[1].get_weights(), **layer_config)
58 |
59 | input_shape = list(orig_model.layers[0].batch_input_shape[1:])
60 | input_shape[0] = nt
61 |
62 | inputs = Input(shape=tuple(input_shape))
63 | predictions = prednet(inputs)
64 | model = Model(inputs=inputs, outputs=predictions)
65 | model.compile(loss=extrap_loss, optimizer='adam')
66 |
67 | train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True, output_mode='prediction')
68 | val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val, output_mode='prediction')
69 |
70 | lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs
71 | callbacks = [LearningRateScheduler(lr_schedule)]
72 | if save_model:
73 | if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)
74 | callbacks.append(ModelCheckpoint(filepath=extrap_weights_file, monitor='val_loss', save_best_only=True))
75 | history = model.fit_generator(train_generator, samples_per_epoch // batch_size, nb_epoch, callbacks=callbacks,
76 | validation_data=val_generator, validation_steps=N_seq_val // batch_size)
77 |
78 | if save_model:
79 | json_string = model.to_json()
80 | with open(extrap_json_file, "w") as f:
81 | f.write(json_string)
82 |
--------------------------------------------------------------------------------
/kitti_train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Train PredNet on KITTI sequences. (Geiger et al. 2013, http://www.cvlibs.net/datasets/kitti/)
3 | '''
4 |
5 | import os
6 | import numpy as np
7 | np.random.seed(123)
8 | from six.moves import cPickle
9 |
10 | from keras import backend as K
11 | from keras.models import Model
12 | from keras.layers import Input, Dense, Flatten
13 | from keras.layers import LSTM
14 | from keras.layers import TimeDistributed
15 | from keras.callbacks import LearningRateScheduler, ModelCheckpoint
16 | from keras.optimizers import Adam
17 |
18 | from prednet import PredNet
19 | from data_utils import SequenceGenerator
20 | from kitti_settings import *
21 |
22 |
23 | save_model = True # if weights will be saved
24 | weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5') # where weights will be saved
25 | json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')
26 |
27 | # Data files
28 | train_file = os.path.join(DATA_DIR, 'X_train.hkl')
29 | train_sources = os.path.join(DATA_DIR, 'sources_train.hkl')
30 | val_file = os.path.join(DATA_DIR, 'X_val.hkl')
31 | val_sources = os.path.join(DATA_DIR, 'sources_val.hkl')
32 |
33 | # Training parameters
34 | nb_epoch = 150
35 | batch_size = 4
36 | samples_per_epoch = 500
37 | N_seq_val = 100 # number of sequences to use for validation
38 |
39 | # Model parameters
40 | n_channels, im_height, im_width = (3, 128, 160)
41 | input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels)
42 | stack_sizes = (n_channels, 48, 96, 192)
43 | R_stack_sizes = stack_sizes
44 | A_filt_sizes = (3, 3, 3)
45 | Ahat_filt_sizes = (3, 3, 3, 3)
46 | R_filt_sizes = (3, 3, 3, 3)
47 | layer_loss_weights = np.array([1., 0., 0., 0.]) # weighting for each layer in final loss; "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1]
48 | layer_loss_weights = np.expand_dims(layer_loss_weights, 1)
49 | nt = 10 # number of timesteps used for sequences in training
50 | time_loss_weights = 1./ (nt - 1) * np.ones((nt,1)) # equally weight all timesteps except the first
51 | time_loss_weights[0] = 0
52 |
53 |
54 | prednet = PredNet(stack_sizes, R_stack_sizes,
55 | A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
56 | output_mode='error', return_sequences=True)
57 |
58 | inputs = Input(shape=(nt,) + input_shape)
59 | errors = prednet(inputs) # errors will be (batch_size, nt, nb_layers)
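# The two frozen Dense layers below just apply fixed linear weightings: layer_loss_weights over the
# per-layer errors at each timestep, then time_loss_weights over timesteps, yielding one scalar per
# sequence that the MAE loss pushes toward the generator's all-zero targets.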
60 | errors_by_time = TimeDistributed(Dense(1, trainable=False), weights=[layer_loss_weights, np.zeros(1)], trainable=False)(errors) # calculate weighted error by layer
61 | errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt)
62 | final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time
63 | model = Model(inputs=inputs, outputs=final_errors)
64 | model.compile(loss='mean_absolute_error', optimizer='adam')
65 |
66 | train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)
67 | val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val)
68 |
69 | lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs
70 | callbacks = [LearningRateScheduler(lr_schedule)]
71 | if save_model:
72 | if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)
73 | callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True))
74 |
75 | history = model.fit_generator(train_generator, samples_per_epoch // batch_size, nb_epoch, callbacks=callbacks,
76 | validation_data=val_generator, validation_steps=N_seq_val // batch_size)
77 |
78 | if save_model:
79 | json_string = model.to_json()
80 | with open(json_file, "w") as f:
81 | f.write(json_string)
82 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import hickle as hkl
2 | import numpy as np
3 | from keras import backend as K
4 | from keras.preprocessing.image import Iterator
5 |
6 | # Data generator that creates sequences for input into PredNet.
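# Typical usage (as in kitti_train.py):
#   SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)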
7 | class SequenceGenerator(Iterator):
8 | def __init__(self, data_file, source_file, nt,
9 | batch_size=8, shuffle=False, seed=None,
10 | output_mode='error', sequence_start_mode='all', N_seq=None,
11 | data_format=K.image_data_format()):
12 | self.X = hkl.load(data_file) # X will be like (n_images, nb_rows, nb_cols, nb_channels)
13 | self.sources = hkl.load(source_file) # source (recording) for each image, so that when creating sequences we can ensure consecutive frames come from the same video
14 | self.nt = nt
15 | self.batch_size = batch_size
16 | self.data_format = data_format
17 | assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}'
18 | self.sequence_start_mode = sequence_start_mode
19 | assert output_mode in {'error', 'prediction'}, 'output_mode must be in {error, prediction}'
20 | self.output_mode = output_mode
21 |
22 | if self.data_format == 'channels_first':
23 | self.X = np.transpose(self.X, (0, 3, 1, 2))
24 | self.im_shape = self.X[0].shape
25 |
26 | if self.sequence_start_mode == 'all': # allow for any possible sequence, starting from any frame
27 | self.possible_starts = np.array([i for i in range(self.X.shape[0] - self.nt) if self.sources[i] == self.sources[i + self.nt - 1]])
28 | elif self.sequence_start_mode == 'unique': # create sequences where each unique frame is in at most one sequence
29 | curr_location = 0
30 | possible_starts = []
31 | while curr_location < self.X.shape[0] - self.nt + 1:
32 | if self.sources[curr_location] == self.sources[curr_location + self.nt - 1]:
33 | possible_starts.append(curr_location)
34 | curr_location += self.nt
35 | else:
36 | curr_location += 1
37 | self.possible_starts = possible_starts
38 |
39 | if shuffle:
40 | self.possible_starts = np.random.permutation(self.possible_starts)
41 | if N_seq is not None and len(self.possible_starts) > N_seq: # select a subset of sequences if desired
42 | self.possible_starts = self.possible_starts[:N_seq]
43 | self.N_sequences = len(self.possible_starts)
44 | super(SequenceGenerator, self).__init__(len(self.possible_starts), batch_size, shuffle, seed)
45 |
46 | def __getitem__(self, null):
47 | return self.next()
48 |
49 | def next(self):
50 | with self.lock:
51 | current_index = (self.batch_index * self.batch_size) % self.n
52 | index_array, current_batch_size = next(self.index_generator), self.batch_size
53 | batch_x = np.zeros((current_batch_size, self.nt) + self.im_shape, np.float32)
54 | for i, idx in enumerate(index_array):
55 | idx = self.possible_starts[idx]
56 | batch_x[i] = self.preprocess(self.X[idx:idx+self.nt])
57 | if self.output_mode == 'error': # model outputs errors, so y should be zeros
58 | batch_y = np.zeros(current_batch_size, np.float32)
59 | elif self.output_mode == 'prediction': # output actual pixels
60 | batch_y = batch_x
61 | return batch_x, batch_y
62 |
63 | def preprocess(self, X):
64 | return X.astype(np.float32) / 255
65 |
66 | def create_all(self):
67 | X_all = np.zeros((self.N_sequences, self.nt) + self.im_shape, np.float32)
68 | for i, idx in enumerate(self.possible_starts):
69 | X_all[i] = self.preprocess(self.X[idx:idx+self.nt])
70 | return X_all
71 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # prednet
2 |
3 | Code and models accompanying [Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning](https://arxiv.org/abs/1605.08104) by Bill Lotter, Gabriel Kreiman, and David Cox.
4 |
5 | The PredNet is a deep recurrent convolutional neural network that is inspired by the neuroscience concept of predictive coding (Rao and Ballard, 1999; Friston, 2005).
6 | **Check out example prediction videos [here](https://coxlab.github.io/prednet/).**
7 |
8 | The architecture is implemented as a custom layer<sup>1</sup> in [Keras](http://keras.io/).
9 | Code and model data are compatible with Keras 2.0 and with Python 2.7 and 3.6.
10 | The latest version has been tested on Keras 2.2.4 with TensorFlow 1.6.
11 | For previous versions of the code compatible with Keras 1.2.1, use fbcdc18.
12 | To convert old PredNet model files and weights for Keras 2.0 compatibility, see ```convert_model_to_keras2``` in `keras_utils.py`.
13 |
14 |
15 | ## KITTI Demo
16 |
17 | Code is included for training the PredNet on the raw [KITTI](http://www.cvlibs.net/datasets/kitti/) dataset.
18 | We include code for downloading and processing the data, as well as training and evaluating the model.
19 | The preprocessed data can also be downloaded directly using `download_data.sh`, and the **trained weights** by running `download_models.sh`.
20 | The model download includes the original weights trained for t+1 prediction, the fine-tuned weights trained to extrapolate predictions for multiple timesteps, and the "Lall" weights trained with a 0.1 loss weight on the upper layers (see paper for details).
21 |
22 | ### Steps
23 | 1. **Download/process data**
24 | ```bash
25 | python process_kitti.py
26 | ```
27 | This will scrape the KITTI website to download the raw data from the city, residential, and road categories (~165 GB) and then process the images (cropping, downsampling).
28 | Alternatively, the processed data (~3 GB) can be directly downloaded by executing `download_data.sh`.
29 |
30 |
31 |
32 | 2. **Train model**
33 | ```bash
34 | python kitti_train.py
35 | ```
36 | This will train a PredNet model for t+1 prediction.
37 | See [Keras FAQ](http://keras.io/getting-started/faq/#how-can-i-run-keras-on-gpu) on how to run using a GPU.
38 | **To download pre-trained weights**, run `download_models.sh`.
39 |
40 |
41 |
42 | 3. **Evaluate model**
43 | ```bash
44 | python kitti_evaluate.py
45 | ```
46 | This will output the mean-squared error for predictions as well as make plots comparing predictions to ground-truth.
47 |
48 | ### Feature Extraction
49 | Extracting the intermediate features for a given layer in the PredNet can be done using the appropriate ```output_mode``` argument. For example, to extract the hidden state of the LSTM (the "Representation" units) in the lowest layer, use ```output_mode = 'R0'```. More details can be found in the PredNet docstring.
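A minimal sketch of R0 feature extraction, following the model-rebuilding pattern in `kitti_evaluate.py` (file paths are illustrative; point them at your `WEIGHTS_DIR`):
```python
from keras.layers import Input
from keras.models import Model, model_from_json
from prednet import PredNet

# Load the trained t+1 model, then rebuild the PredNet layer with the same weights
# but output_mode='R0' (representation units of the lowest layer).
train_model = model_from_json(open('model_data_keras2/prednet_kitti_model.json').read(),
                              custom_objects={'PredNet': PredNet})
train_model.load_weights('model_data_keras2/tensorflow_weights/prednet_kitti_weights.hdf5')
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'R0'
feature_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)

input_shape = list(train_model.layers[0].batch_input_shape[1:])  # (nt, rows, cols, channels)
inputs = Input(shape=tuple(input_shape))
features = feature_prednet(inputs)  # one R0 feature map per timestep (return_sequences=True)
feature_model = Model(inputs=inputs, outputs=features)
```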
50 |
51 | ### Multi-Step Prediction
52 | The PredNet argument ```extrap_start_time``` can be used to force multi-step prediction. Starting at this time step, the prediction from the previous time step will be treated as the actual input. For example, if the model is run on a sequence of 15 timesteps with ```extrap_start_time = 10```, the last output will correspond to a t+5 prediction. In the paper, we train in this setting starting from the original t+1 trained weights (see `kitti_extrap_finetune.py`), and the resulting fine-tuned weights are included in `download_models.sh`. Note that when training with extrapolation, the "errors" are no longer tied to ground truth, so the loss should be calculated on the pixel predictions themselves. This can be done by using ```output_mode = 'prediction'```, as illustrated in `kitti_extrap_finetune.py`.
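As a sketch, a test-time extrapolation model can also be built from an already-loaded t+1 model (reusing `train_model` and the imports from the feature-extraction sketch above) without any fine-tuning:
```python
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'prediction'
layer_config['extrap_start_time'] = 10  # feed predictions back in as input from t=10 onward
extrap_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)

nt = 15  # with extrap_start_time=10, the final output is a t+5 prediction
input_shape = list(train_model.layers[0].batch_input_shape[1:])
input_shape[0] = nt
inputs = Input(shape=tuple(input_shape))
extrap_model = Model(inputs=inputs, outputs=extrap_prednet(inputs))
```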
53 |
54 | ### Additional Notes
55 | When training on a new dataset, the image height and width have to be divisible by 2^(nb of layers - 1) because of the cyclical 2x2 max-pooling and upsampling operations. For example, the default 4-layer KITTI model requires frame dimensions divisible by 8, which the 128x160 processed frames satisfy.
56 |
57 |
58 |
59 | <sup>1</sup> Note on implementation: PredNet inherits from the Recurrent layer class, i.e. it has an internal state and a step function. Given the top-down then bottom-up update sequence, it must currently be implemented in Keras as essentially a 'super' layer where all layers in the PredNet are in one PredNet 'layer'. This is less than ideal, but it seems like the most efficient way as of now. We welcome suggestions if anyone thinks of a better implementation.
60 |
--------------------------------------------------------------------------------
/process_kitti.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code for downloading and processing KITTI data (Geiger et al. 2013, http://www.cvlibs.net/datasets/kitti/)
3 | '''
4 |
5 | import os
6 | import requests
7 | from bs4 import BeautifulSoup
8 | import urllib.request
9 | import numpy as np
10 | from imageio import imread
11 | from scipy.misc import imresize
12 | import hickle as hkl
13 | from kitti_settings import *
14 |
15 |
16 | desired_im_sz = (128, 160)
17 | categories = ['city', 'residential', 'road']
18 |
19 | # Recordings used for validation and testing.
20 | # Were initially chosen randomly such that one of the city recordings was used for validation and one of each category was used for testing.
21 | val_recordings = [('city', '2011_09_26_drive_0005_sync')]
22 | test_recordings = [('city', '2011_09_26_drive_0104_sync'), ('residential', '2011_09_26_drive_0079_sync'), ('road', '2011_09_26_drive_0070_sync')]
23 |
24 | if not os.path.exists(DATA_DIR): os.mkdir(DATA_DIR)
25 |
26 | # Download raw zip files by scraping KITTI website
27 | def download_data():
28 | base_dir = os.path.join(DATA_DIR, 'raw/')
29 | if not os.path.exists(base_dir): os.mkdir(base_dir)
30 | for c in categories:
31 | url = "http://www.cvlibs.net/datasets/kitti/raw_data.php?type=" + c
32 | r = requests.get(url)
33 | soup = BeautifulSoup(r.content, 'html.parser')
34 | drive_list = soup.find_all("h3")
35 | drive_list = [d.text[:d.text.find(' ')] for d in drive_list]
36 | print( "Downloading set: " + c)
37 | c_dir = base_dir + c + '/'
38 | if not os.path.exists(c_dir): os.mkdir(c_dir)
39 | for i, d in enumerate(drive_list):
40 | print( str(i+1) + '/' + str(len(drive_list)) + ": " + d)
41 | url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/" + d + "/" + d + "_sync.zip"
42 | urllib.request.urlretrieve(url, filename=c_dir + d + "_sync.zip")
43 |
44 |
45 | # unzip images
46 | def extract_data():
47 | for c in categories:
48 | c_dir = os.path.join(DATA_DIR, 'raw/', c + '/')
49 | zip_files = list(os.walk(c_dir, topdown=False))[-1][-1]
50 | for f in zip_files:
51 | print( 'unpacking: ' + f)
52 | spec_folder = f[:10] + '/' + f[:-4] + '/image_03/data*'
53 | command = 'unzip -qq ' + c_dir + f + ' ' + spec_folder + ' -d ' + c_dir + f[:-4]
54 | os.system(command)
55 |
56 |
57 | # Create image datasets.
58 | # Processes images and saves them in train, val, test splits.
59 | def process_data():
60 | splits = {s: [] for s in ['train', 'test', 'val']}
61 | splits['val'] = val_recordings
62 | splits['test'] = test_recordings
63 | not_train = splits['val'] + splits['test']
64 | for c in categories: # Randomly assign recordings to training and testing. Cross-validation done across entire recordings.
65 | c_dir = os.path.join(DATA_DIR, 'raw', c + '/')
66 | folders = list(os.walk(c_dir, topdown=False))[-1][-2]
67 | splits['train'] += [(c, f) for f in folders if (c, f) not in not_train]
68 |
69 | for split in splits:
70 | im_list = []
71 | source_list = [] # corresponds to recording that image came from
72 | for category, folder in splits[split]:
73 | im_dir = os.path.join(DATA_DIR, 'raw/', category, folder, folder[:10], folder, 'image_03/data/')
74 | files = list(os.walk(im_dir, topdown=False))[-1][-1]
75 | im_list += [im_dir + f for f in sorted(files)]
76 | source_list += [category + '-' + folder] * len(files)
77 |
78 | print( 'Creating ' + split + ' data: ' + str(len(im_list)) + ' images')
79 | X = np.zeros((len(im_list),) + desired_im_sz + (3,), np.uint8)
80 | for i, im_file in enumerate(im_list):
81 | im = imread(im_file)
82 | X[i] = process_im(im, desired_im_sz)
83 |
84 | hkl.dump(X, os.path.join(DATA_DIR, 'X_' + split + '.hkl'))
85 | hkl.dump(source_list, os.path.join(DATA_DIR, 'sources_' + split + '.hkl'))
86 |
87 |
88 | # resize and crop image
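# (scales the frame so its height equals desired_sz[0], then center-crops the width to desired_sz[1])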
89 | def process_im(im, desired_sz):
90 | target_ds = float(desired_sz[0])/im.shape[0]
91 | im = imresize(im, (desired_sz[0], int(np.round(target_ds * im.shape[1]))))
92 | d = int((im.shape[1] - desired_sz[1]) / 2)
93 | im = im[:, d:d+desired_sz[1]]
94 | return im
95 |
96 |
97 | if __name__ == '__main__':
98 | download_data()
99 | extract_data()
100 | process_data()
101 |
--------------------------------------------------------------------------------
/prednet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from keras import backend as K
4 | from keras import activations
5 | from keras.layers import Recurrent
6 | from keras.layers import Conv2D, UpSampling2D, MaxPooling2D
7 | from keras.engine import InputSpec
8 | from keras_utils import legacy_prednet_support
9 |
10 | class PredNet(Recurrent):
11 | '''PredNet architecture - Lotter 2016.
12 | Stacked convolutional LSTM inspired by predictive coding principles.
13 |
14 | # Arguments
15 | stack_sizes: number of channels in targets (A) and predictions (Ahat) in each layer of the architecture.
16 | Length is the number of layers in the architecture.
17 | First element is the number of channels in the input.
18 | Ex. (3, 16, 32) would correspond to a 3 layer architecture that takes in RGB images and has 16 and 32
19 | channels in the second and third layers, respectively.
20 | R_stack_sizes: number of channels in the representation (R) modules.
21 | Length must equal length of stack_sizes, but the number of channels per layer can be different.
22 | A_filt_sizes: filter sizes for the target (A) modules.
23 | Has length of len(stack_sizes) - 1.
24 | Ex. (3, 3) would mean that targets for layers 2 and 3 are computed by a 3x3 convolution of the errors (E)
25 | from the layer below (followed by max-pooling)
26 | Ahat_filt_sizes: filter sizes for the prediction (Ahat) modules.
27 | Has length equal to length of stack_sizes.
28 | Ex. (3, 3, 3) would mean that the predictions for each layer are computed by a 3x3 convolution of the
29 | representation (R) modules at each layer.
30 | R_filt_sizes: filter sizes for the representation (R) modules.
31 | Has length equal to length of stack_sizes.
32 | Corresponds to the filter sizes for all convolutions in the LSTM.
33 | pixel_max: the maximum pixel value.
34 | Used to clip the pixel-layer prediction.
35 | error_activation: activation function for the error (E) units.
36 | A_activation: activation function for the target (A) and prediction (A_hat) units.
37 | LSTM_activation: activation function for the cell and hidden states of the LSTM.
38 | LSTM_inner_activation: activation function for the gates in the LSTM.
39 | output_mode: either 'error', 'prediction', 'all' or layer specification (ex. R2, see below).
40 | Controls what is outputted by the PredNet.
41 | If 'error', the mean response of the error (E) units of each layer will be outputted.
42 | That is, the output shape will be (batch_size, nb_layers).
43 | If 'prediction', the frame prediction will be outputted.
44 | If 'all', the output will be the frame prediction concatenated with the mean layer errors.
45 | The frame prediction is flattened before concatenation.
46 | Nomenclature of 'all' is kept for backwards compatibility, but should not be confused with returning all of the layers of the model
47 | For returning the features of a particular layer, output_mode should be of the form unit_type + layer_number.
48 | For instance, to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specified as 'R0'.
49 | The possible unit types are 'R', 'Ahat', 'A', and 'E' corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively.
50 | extrap_start_time: time step at which the model will start extrapolating.
51 | Starting at this time step, the prediction from the previous time step will be treated as the "actual" input.
52 | data_format: 'channels_first' or 'channels_last'.
53 | It defaults to the `image_data_format` value found in your
54 | Keras config file at `~/.keras/keras.json`.
55 |
56 | # References
57 | - [Deep predictive coding networks for video prediction and unsupervised learning](https://arxiv.org/abs/1605.08104)
58 | - [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf)
59 | - [Convolutional LSTM network: a machine learning approach for precipitation nowcasting](http://arxiv.org/abs/1506.04214)
60 | - [Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects](http://www.nature.com/neuro/journal/v2/n1/pdf/nn0199_79.pdf)
61 | '''
62 | @legacy_prednet_support
63 | def __init__(self, stack_sizes, R_stack_sizes,
64 | A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
65 | pixel_max=1., error_activation='relu', A_activation='relu',
66 | LSTM_activation='tanh', LSTM_inner_activation='hard_sigmoid',
67 | output_mode='error', extrap_start_time=None,
68 | data_format=K.image_data_format(), **kwargs):
69 | self.stack_sizes = stack_sizes
70 | self.nb_layers = len(stack_sizes)
71 | assert len(R_stack_sizes) == self.nb_layers, 'len(R_stack_sizes) must equal len(stack_sizes)'
72 | self.R_stack_sizes = R_stack_sizes
73 | assert len(A_filt_sizes) == (self.nb_layers - 1), 'len(A_filt_sizes) must equal len(stack_sizes) - 1'
74 | self.A_filt_sizes = A_filt_sizes
75 | assert len(Ahat_filt_sizes) == self.nb_layers, 'len(Ahat_filt_sizes) must equal len(stack_sizes)'
76 | self.Ahat_filt_sizes = Ahat_filt_sizes
77 | assert len(R_filt_sizes) == (self.nb_layers), 'len(R_filt_sizes) must equal len(stack_sizes)'
78 | self.R_filt_sizes = R_filt_sizes
79 |
80 | self.pixel_max = pixel_max
81 | self.error_activation = activations.get(error_activation)
82 | self.A_activation = activations.get(A_activation)
83 | self.LSTM_activation = activations.get(LSTM_activation)
84 | self.LSTM_inner_activation = activations.get(LSTM_inner_activation)
85 |
86 | default_output_modes = ['prediction', 'error', 'all']
87 | layer_output_modes = [layer + str(n) for n in range(self.nb_layers) for layer in ['R', 'E', 'A', 'Ahat']]
88 | assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
89 | self.output_mode = output_mode
90 | if self.output_mode in layer_output_modes:
91 | self.output_layer_type = self.output_mode[:-1]
92 | self.output_layer_num = int(self.output_mode[-1])
93 | else:
94 | self.output_layer_type = None
95 | self.output_layer_num = None
96 | self.extrap_start_time = extrap_start_time
97 |
98 | assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {channels_last, channels_first}'
99 | self.data_format = data_format
100 | self.channel_axis = -3 if data_format == 'channels_first' else -1
101 | self.row_axis = -2 if data_format == 'channels_first' else -3
102 | self.column_axis = -1 if data_format == 'channels_first' else -2
103 | super(PredNet, self).__init__(**kwargs)
104 | self.input_spec = [InputSpec(ndim=5)]
105 |
106 | def compute_output_shape(self, input_shape):
107 | if self.output_mode == 'prediction':
108 | out_shape = input_shape[2:]
109 | elif self.output_mode == 'error':
110 | out_shape = (self.nb_layers,)
111 | elif self.output_mode == 'all':
112 | out_shape = (np.prod(input_shape[2:]) + self.nb_layers,)
113 | else:
114 | stack_str = 'R_stack_sizes' if self.output_layer_type == 'R' else 'stack_sizes'
115 | stack_mult = 2 if self.output_layer_type == 'E' else 1
116 | out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num]
117 | out_nb_row = input_shape[self.row_axis] // 2**self.output_layer_num
118 | out_nb_col = input_shape[self.column_axis] // 2**self.output_layer_num
119 | if self.data_format == 'channels_first':
120 | out_shape = (out_stack_size, out_nb_row, out_nb_col)
121 | else:
122 | out_shape = (out_nb_row, out_nb_col, out_stack_size)
123 |
124 | if self.return_sequences:
125 | return (input_shape[0], input_shape[1]) + out_shape
126 | else:
127 | return (input_shape[0],) + out_shape
128 |
129 | def get_initial_state(self, x):
130 | input_shape = self.input_spec[0].shape
131 | init_nb_row = input_shape[self.row_axis]
132 | init_nb_col = input_shape[self.column_axis]
133 |
134 | base_initial_state = K.zeros_like(x) # (samples, timesteps) + image_shape
135 | non_channel_axis = -1 if self.data_format == 'channels_first' else -2
136 | for _ in range(2):
137 | base_initial_state = K.sum(base_initial_state, axis=non_channel_axis)
138 | base_initial_state = K.sum(base_initial_state, axis=1) # (samples, nb_channels)
139 |
140 | initial_states = []
141 | states_to_pass = ['r', 'c', 'e']
142 | nlayers_to_pass = {u: self.nb_layers for u in states_to_pass}
143 | if self.extrap_start_time is not None:
144 | states_to_pass.append('ahat') # pass prediction in states so can use as actual for t+1 when extrapolating
145 | nlayers_to_pass['ahat'] = 1
146 | for u in states_to_pass:
147 | for l in range(nlayers_to_pass[u]):
148 | ds_factor = 2 ** l
149 | nb_row = init_nb_row // ds_factor
150 | nb_col = init_nb_col // ds_factor
151 | if u in ['r', 'c']:
152 | stack_size = self.R_stack_sizes[l]
153 | elif u == 'e':
154 | stack_size = 2 * self.stack_sizes[l]
155 | elif u == 'ahat':
156 | stack_size = self.stack_sizes[l]
157 | output_size = stack_size * nb_row * nb_col # flattened size
158 |
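# Multiplying by an all-zeros (nb_channels, output_size) matrix is just a backend-friendly way to
# build an all-zero initial state of the required flattened size while staying tied to the batch dimension.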
159 | reducer = K.zeros((input_shape[self.channel_axis], output_size)) # (nb_channels, output_size)
160 | initial_state = K.dot(base_initial_state, reducer) # (samples, output_size)
161 | if self.data_format == 'channels_first':
162 | output_shp = (-1, stack_size, nb_row, nb_col)
163 | else:
164 | output_shp = (-1, nb_row, nb_col, stack_size)
165 | initial_state = K.reshape(initial_state, output_shp)
166 | initial_states += [initial_state]
167 |
168 | if K.backend() == 'theano':
169 | from theano import tensor as T
170 | # There is a known issue in the Theano scan op when dealing with inputs whose shape is 1 along a dimension.
171 | # In our case, this is a problem when training on grayscale images, and the below line fixes it.
172 | initial_states = [T.unbroadcast(init_state, 0, 1) for init_state in initial_states]
173 |
174 | if self.extrap_start_time is not None:
175 | initial_states += [K.variable(0, int if K.backend() != 'tensorflow' else 'int32')] # the last state will correspond to the current timestep
176 | return initial_states
177 |
178 | def build(self, input_shape):
179 | self.input_spec = [InputSpec(shape=input_shape)]
180 | self.conv_layers = {c: [] for c in ['i', 'f', 'c', 'o', 'a', 'ahat']}
181 |
182 | for l in range(self.nb_layers):
183 | for c in ['i', 'f', 'c', 'o']:
184 | act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation
185 | self.conv_layers[c].append(Conv2D(self.R_stack_sizes[l], self.R_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))
186 |
187 | act = 'relu' if l == 0 else self.A_activation
188 | self.conv_layers['ahat'].append(Conv2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))
189 |
190 | if l < self.nb_layers - 1:
191 | self.conv_layers['a'].append(Conv2D(self.stack_sizes[l+1], self.A_filt_sizes[l], padding='same', activation=self.A_activation, data_format=self.data_format))
192 |
193 | self.upsample = UpSampling2D(data_format=self.data_format)
194 | self.pool = MaxPooling2D(data_format=self.data_format)
195 |
196 | self.trainable_weights = []
197 | nb_row, nb_col = (input_shape[-2], input_shape[-1]) if self.data_format == 'channels_first' else (input_shape[-3], input_shape[-2])
198 | for c in sorted(self.conv_layers.keys()):
199 | for l in range(len(self.conv_layers[c])):
200 | ds_factor = 2 ** l
201 | if c == 'ahat':
202 | nb_channels = self.R_stack_sizes[l]
203 | elif c == 'a':
204 | nb_channels = 2 * self.stack_sizes[l]
205 | else:
206 | nb_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[l]
207 | if l < self.nb_layers - 1:
208 | nb_channels += self.R_stack_sizes[l+1]
209 | in_shape = (input_shape[0], nb_channels, nb_row // ds_factor, nb_col // ds_factor)
210 | if self.data_format == 'channels_last': in_shape = (in_shape[0], in_shape[2], in_shape[3], in_shape[1])
211 | with K.name_scope('layer_' + c + '_' + str(l)):
212 | self.conv_layers[c][l].build(in_shape)
213 | self.trainable_weights += self.conv_layers[c][l].trainable_weights
214 |
215 | self.states = [None] * self.nb_layers*3
216 |
217 | if self.extrap_start_time is not None:
218 | self.t_extrap = K.variable(self.extrap_start_time, int if K.backend() != 'tensorflow' else 'int32')
219 | self.states += [None] * 2 # [previous frame prediction, timestep]
220 |
221 | def step(self, a, states):
222 | r_tm1 = states[:self.nb_layers]
223 | c_tm1 = states[self.nb_layers:2*self.nb_layers]
224 | e_tm1 = states[2*self.nb_layers:3*self.nb_layers]
225 |
226 | if self.extrap_start_time is not None:
227 | t = states[-1]
228 | a = K.switch(t >= self.t_extrap, states[-2], a) # if past self.extrap_start_time, the previous prediction will be treated as the actual
229 |
230 | c = []
231 | r = []
232 | e = []
233 |
234 | # Update R units starting from the top
235 | for l in reversed(range(self.nb_layers)):
236 | inputs = [r_tm1[l], e_tm1[l]]
237 | if l < self.nb_layers - 1:
238 | inputs.append(r_up)
239 |
240 | inputs = K.concatenate(inputs, axis=self.channel_axis)
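# Standard ConvLSTM update: input (i), forget (f), and output (o) gates plus the candidate cell state,
# each computed by its own convolution over the concatenated [R_l(t-1), E_l(t-1), upsampled R_{l+1}(t)].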
241 | i = self.conv_layers['i'][l].call(inputs)
242 | f = self.conv_layers['f'][l].call(inputs)
243 | o = self.conv_layers['o'][l].call(inputs)
244 | _c = f * c_tm1[l] + i * self.conv_layers['c'][l].call(inputs)
245 | _r = o * self.LSTM_activation(_c)
246 | c.insert(0, _c)
247 | r.insert(0, _r)
248 |
249 | if l > 0:
250 | r_up = self.upsample.call(_r)
251 |
252 | # Update feedforward path starting from the bottom
253 | for l in range(self.nb_layers):
254 | ahat = self.conv_layers['ahat'][l].call(r[l])
255 | if l == 0:
256 | ahat = K.minimum(ahat, self.pixel_max)
257 | frame_prediction = ahat
258 |
259 | # compute errors
260 | e_up = self.error_activation(ahat - a)
261 | e_down = self.error_activation(a - ahat)
262 |
263 | e.append(K.concatenate((e_up, e_down), axis=self.channel_axis))
264 |
265 | if self.output_layer_num == l:
266 | if self.output_layer_type == 'A':
267 | output = a
268 | elif self.output_layer_type == 'Ahat':
269 | output = ahat
270 | elif self.output_layer_type == 'R':
271 | output = r[l]
272 | elif self.output_layer_type == 'E':
273 | output = e[l]
274 |
275 | if l < self.nb_layers - 1:
276 | a = self.conv_layers['a'][l].call(e[l])
277 | a = self.pool.call(a) # target for next layer
278 |
279 | if self.output_layer_type is None:
280 | if self.output_mode == 'prediction':
281 | output = frame_prediction
282 | else:
283 | for l in range(self.nb_layers):
284 | layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True)
285 | all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)
286 | if self.output_mode == 'error':
287 | output = all_error
288 | else:
289 | output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)
290 |
291 | states = r + c + e
292 | if self.extrap_start_time is not None:
293 | states += [frame_prediction, t + 1]
294 | return output, states
295 |
296 | def get_config(self):
297 | config = {'stack_sizes': self.stack_sizes,
298 | 'R_stack_sizes': self.R_stack_sizes,
299 | 'A_filt_sizes': self.A_filt_sizes,
300 | 'Ahat_filt_sizes': self.Ahat_filt_sizes,
301 | 'R_filt_sizes': self.R_filt_sizes,
302 | 'pixel_max': self.pixel_max,
303 | 'error_activation': self.error_activation.__name__,
304 | 'A_activation': self.A_activation.__name__,
305 | 'LSTM_activation': self.LSTM_activation.__name__,
306 | 'LSTM_inner_activation': self.LSTM_inner_activation.__name__,
307 | 'data_format': self.data_format,
308 | 'extrap_start_time': self.extrap_start_time,
309 | 'output_mode': self.output_mode}
310 | base_config = super(PredNet, self).get_config()
311 | return dict(list(base_config.items()) + list(config.items()))
312 |
--------------------------------------------------------------------------------