├── download_data.sh ├── download_models.sh ├── environment.yml ├── kitti_settings.py ├── License.txt ├── .gitignore ├── keras_utils.py ├── kitti_evaluate.py ├── kitti_extrap_finetune.py ├── kitti_train.py ├── data_utils.py ├── README.md ├── process_kitti.py └── prednet.py /download_data.sh: -------------------------------------------------------------------------------- 1 | savedir="kitti_data" 2 | mkdir -p -- "$savedir" 3 | wget https://www.dropbox.com/s/rpwlnn6j39jjme4/kitti_data.zip?dl=0 -O $savedir/prednet_kitti_data.zip 4 | unzip $savedir/prednet_kitti_data.zip -d $savedir 5 | -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | savedir="model_data_keras2" 2 | mkdir -p -- "$savedir" 3 | wget https://www.dropbox.com/s/iutxm0anhxqca0z/model_data_keras2.zip?dl=0 -O $savedir/model_data_keras2.zip 4 | unzip $savedir/model_data_keras2.zip -d $savedir 5 | rm $savedir/model_data_keras2.zip 6 | mv $savedir/model_data_keras2/* $savedir 7 | rm -r $savedir/model_data_keras2 8 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # This file allows you to configure your python environment quickly for prednet. 2 | # Please refer to the conda documentation for more info: 3 | # https://conda.io/docs/user-guide/tasks/manage-environments.html 4 | 5 | name: prednet 6 | 7 | dependencies: 8 | - python=2.7 9 | 10 | - pip: 11 | - keras==2.0.6 12 | - theano==0.9.0 13 | - tensorflow-gpu==1.2.1 14 | - hickle==2.1.0 15 | - matplotlib -------------------------------------------------------------------------------- /kitti_settings.py: -------------------------------------------------------------------------------- 1 | # Where KITTI data will be saved if you run process_kitti.py 2 | # If you directly download the processed data, change to the path of the data. 3 | DATA_DIR = './kitti_data/' 4 | 5 | # Where model weights and config will be saved if you run kitti_train.py 6 | # If you directly download the trained weights, change to appropriate path. 7 | WEIGHTS_DIR = './model_data_keras2/' 8 | 9 | # Where results (prediction plots and evaluation file) will be saved. 10 | RESULTS_SAVE_DIR = './kitti_results/' 11 | -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 coxlab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Video and data folders 10 | kitti_data/ 11 | model_data_keras*/ 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *,cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # IPython Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | -------------------------------------------------------------------------------- /keras_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from keras import backend as K 5 | from keras.legacy.interfaces import generate_legacy_interface, recurrent_args_preprocessor 6 | from keras.models import model_from_json 7 | 8 | legacy_prednet_support = generate_legacy_interface( 9 | allowed_positional_args=['stack_sizes', 'R_stack_sizes', 10 | 'A_filt_sizes', 'Ahat_filt_sizes', 'R_filt_sizes'], 11 | conversions=[('dim_ordering', 'data_format'), 12 | ('consume_less', 'implementation')], 13 | value_conversions={'dim_ordering': {'tf': 'channels_last', 14 | 'th': 'channels_first', 15 | 'default': None}, 16 | 'consume_less': {'cpu': 0, 17 | 'mem': 1, 18 | 'gpu': 2}}, 19 | preprocessor=recurrent_args_preprocessor) 20 | 21 | # Convert old Keras (1.2) json models and weights to Keras 2.0 22 | def convert_model_to_keras2(old_json_file, old_weights_file, new_json_file, new_weights_file): 23 | from prednet import PredNet 24 | # If using tensorflow, it doesn't allow you to load the old weights. 
25 | if K.backend() != 'theano': 26 | os.environ['KERAS_BACKEND'] = backend 27 | reload(K) 28 | 29 | f = open(old_json_file, 'r') 30 | json_string = f.read() 31 | f.close() 32 | model = model_from_json(json_string, custom_objects = {'PredNet': PredNet}) 33 | model.load_weights(old_weights_file) 34 | 35 | weights = model.layers[1].get_weights() 36 | if weights[0].shape[0] == model.layers[1].stack_sizes[1]: 37 | for i, w in enumerate(weights): 38 | if w.ndim == 4: 39 | weights[i] = np.transpose(w, (2, 3, 1, 0)) 40 | model.set_weights(weights) 41 | 42 | model.save_weights(new_weights_file) 43 | json_string = model.to_json() 44 | with open(new_json_file, "w") as f: 45 | f.write(json_string) 46 | 47 | 48 | if __name__ == '__main__': 49 | old_dir = './model_data/' 50 | new_dir = './model_data_keras2/' 51 | if not os.path.exists(new_dir): 52 | os.mkdir(new_dir) 53 | for w_tag in ['', '-Lall', '-extrapfinetuned']: 54 | m_tag = '' if w_tag == '-Lall' else w_tag 55 | convert_model_to_keras2(old_dir + 'prednet_kitti_model' + m_tag + '.json', 56 | old_dir + 'prednet_kitti_weights' + w_tag + '.hdf5', 57 | new_dir + 'prednet_kitti_model' + m_tag + '.json', 58 | new_dir + 'prednet_kitti_weights' + w_tag + '.hdf5') 59 | -------------------------------------------------------------------------------- /kitti_evaluate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Evaluate trained PredNet on KITTI sequences. 3 | Calculates mean-squared error and plots predictions. 4 | ''' 5 | 6 | import os 7 | import numpy as np 8 | from six.moves import cPickle 9 | import matplotlib 10 | matplotlib.use('Agg') 11 | import matplotlib.pyplot as plt 12 | import matplotlib.gridspec as gridspec 13 | 14 | from keras import backend as K 15 | from keras.models import Model, model_from_json 16 | from keras.layers import Input, Dense, Flatten 17 | 18 | from prednet import PredNet 19 | from data_utils import SequenceGenerator 20 | from kitti_settings import * 21 | 22 | 23 | n_plot = 40 24 | batch_size = 10 25 | nt = 10 26 | 27 | weights_file = os.path.join(WEIGHTS_DIR, 'tensorflow_weights/prednet_kitti_weights.hdf5') 28 | json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json') 29 | test_file = os.path.join(DATA_DIR, 'X_test.hkl') 30 | test_sources = os.path.join(DATA_DIR, 'sources_test.hkl') 31 | 32 | # Load trained model 33 | f = open(json_file, 'r') 34 | json_string = f.read() 35 | f.close() 36 | train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet}) 37 | train_model.load_weights(weights_file) 38 | 39 | # Create testing model (to output predictions) 40 | layer_config = train_model.layers[1].get_config() 41 | layer_config['output_mode'] = 'prediction' 42 | data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering'] 43 | test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config) 44 | input_shape = list(train_model.layers[0].batch_input_shape[1:]) 45 | input_shape[0] = nt 46 | inputs = Input(shape=tuple(input_shape)) 47 | predictions = test_prednet(inputs) 48 | test_model = Model(inputs=inputs, outputs=predictions) 49 | 50 | test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format) 51 | X_test = test_generator.create_all() 52 | X_hat = test_model.predict(X_test, batch_size) 53 | if data_format == 'channels_first': 54 | X_test = np.transpose(X_test, (0, 1, 3, 4, 2)) 55 | X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2)) 56 | 
57 | # Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt 58 | mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first 59 | mse_prev = np.mean( (X_test[:, :-1] - X_test[:, 1:])**2 ) 60 | if not os.path.exists(RESULTS_SAVE_DIR): os.mkdir(RESULTS_SAVE_DIR) 61 | f = open(RESULTS_SAVE_DIR + 'prediction_scores.txt', 'w') 62 | f.write("Model MSE: %f\n" % mse_model) 63 | f.write("Previous Frame MSE: %f" % mse_prev) 64 | f.close() 65 | 66 | # Plot some predictions 67 | aspect_ratio = float(X_hat.shape[2]) / X_hat.shape[3] 68 | plt.figure(figsize = (nt, 2*aspect_ratio)) 69 | gs = gridspec.GridSpec(2, nt) 70 | gs.update(wspace=0., hspace=0.) 71 | plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots/') 72 | if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir) 73 | plot_idx = np.random.permutation(X_test.shape[0])[:n_plot] 74 | for i in plot_idx: 75 | for t in range(nt): 76 | plt.subplot(gs[t]) 77 | plt.imshow(X_test[i,t], interpolation='none') 78 | plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off') 79 | if t==0: plt.ylabel('Actual', fontsize=10) 80 | 81 | plt.subplot(gs[t + nt]) 82 | plt.imshow(X_hat[i,t], interpolation='none') 83 | plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off') 84 | if t==0: plt.ylabel('Predicted', fontsize=10) 85 | 86 | plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png') 87 | plt.clf() 88 | -------------------------------------------------------------------------------- /kitti_extrap_finetune.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Fine-tune PredNet model trained for t+1 prediction for up to t+5 prediction. 3 | ''' 4 | 5 | import os 6 | import numpy as np 7 | np.random.seed(123) 8 | 9 | from keras import backend as K 10 | from keras.models import Model, model_from_json 11 | from keras.layers import Input 12 | from keras.callbacks import LearningRateScheduler, ModelCheckpoint 13 | 14 | from prednet import PredNet 15 | from data_utils import SequenceGenerator 16 | from kitti_settings import * 17 | 18 | # Define loss as MAE of frame predictions after t=0 19 | # It doesn't make sense to compute loss on error representation, since the error isn't wrt ground truth when extrapolating. 
20 | def extrap_loss(y_true, y_hat): 21 | y_true = y_true[:, 1:] 22 | y_hat = y_hat[:, 1:] 23 | return 0.5 * K.mean(K.abs(y_true - y_hat), axis=-1) # 0.5 to match scale of loss when trained in error mode (positive and negative errors split) 24 | 25 | nt = 15 26 | extrap_start_time = 10 # starting at this time step, the prediction from the previous time step will be treated as the actual input 27 | orig_weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5') # original t+1 weights 28 | orig_json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json') 29 | 30 | save_model = True 31 | extrap_weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights-extrapfinetuned.hdf5') # where new weights will be saved 32 | extrap_json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model-extrapfinetuned.json') 33 | 34 | # Data files 35 | train_file = os.path.join(DATA_DIR, 'X_train.hkl') 36 | train_sources = os.path.join(DATA_DIR, 'sources_train.hkl') 37 | val_file = os.path.join(DATA_DIR, 'X_val.hkl') 38 | val_sources = os.path.join(DATA_DIR, 'sources_val.hkl') 39 | 40 | # Training parameters 41 | nb_epoch = 150 42 | batch_size = 4 43 | samples_per_epoch = 500 44 | N_seq_val = 100 # number of sequences to use for validation 45 | 46 | # Load t+1 model 47 | f = open(orig_json_file, 'r') 48 | json_string = f.read() 49 | f.close() 50 | orig_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet}) 51 | orig_model.load_weights(orig_weights_file) 52 | 53 | layer_config = orig_model.layers[1].get_config() 54 | layer_config['output_mode'] = 'prediction' 55 | layer_config['extrap_start_time'] = extrap_start_time 56 | data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering'] 57 | prednet = PredNet(weights=orig_model.layers[1].get_weights(), **layer_config) 58 | 59 | input_shape = list(orig_model.layers[0].batch_input_shape[1:]) 60 | input_shape[0] = nt 61 | 62 | inputs = Input(input_shape) 63 | predictions = prednet(inputs) 64 | model = Model(inputs=inputs, outputs=predictions) 65 | model.compile(loss=extrap_loss, optimizer='adam') 66 | 67 | train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True, output_mode='prediction') 68 | val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val, output_mode='prediction') 69 | 70 | lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs 71 | callbacks = [LearningRateScheduler(lr_schedule)] 72 | if save_model: 73 | if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR) 74 | callbacks.append(ModelCheckpoint(filepath=extrap_weights_file, monitor='val_loss', save_best_only=True)) 75 | history = model.fit_generator(train_generator, samples_per_epoch / batch_size, nb_epoch, callbacks=callbacks, 76 | validation_data=val_generator, validation_steps=N_seq_val / batch_size) 77 | 78 | if save_model: 79 | json_string = model.to_json() 80 | with open(extrap_json_file, "w") as f: 81 | f.write(json_string) 82 | -------------------------------------------------------------------------------- /kitti_train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Train PredNet on KITTI sequences. (Geiger et al. 
2013, http://www.cvlibs.net/datasets/kitti/) 3 | ''' 4 | 5 | import os 6 | import numpy as np 7 | np.random.seed(123) 8 | from six.moves import cPickle 9 | 10 | from keras import backend as K 11 | from keras.models import Model 12 | from keras.layers import Input, Dense, Flatten 13 | from keras.layers import LSTM 14 | from keras.layers import TimeDistributed 15 | from keras.callbacks import LearningRateScheduler, ModelCheckpoint 16 | from keras.optimizers import Adam 17 | 18 | from prednet import PredNet 19 | from data_utils import SequenceGenerator 20 | from kitti_settings import * 21 | 22 | 23 | save_model = True # if weights will be saved 24 | weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5') # where weights will be saved 25 | json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json') 26 | 27 | # Data files 28 | train_file = os.path.join(DATA_DIR, 'X_train.hkl') 29 | train_sources = os.path.join(DATA_DIR, 'sources_train.hkl') 30 | val_file = os.path.join(DATA_DIR, 'X_val.hkl') 31 | val_sources = os.path.join(DATA_DIR, 'sources_val.hkl') 32 | 33 | # Training parameters 34 | nb_epoch = 150 35 | batch_size = 4 36 | samples_per_epoch = 500 37 | N_seq_val = 100 # number of sequences to use for validation 38 | 39 | # Model parameters 40 | n_channels, im_height, im_width = (3, 128, 160) 41 | input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels) 42 | stack_sizes = (n_channels, 48, 96, 192) 43 | R_stack_sizes = stack_sizes 44 | A_filt_sizes = (3, 3, 3) 45 | Ahat_filt_sizes = (3, 3, 3, 3) 46 | R_filt_sizes = (3, 3, 3, 3) 47 | layer_loss_weights = np.array([1., 0., 0., 0.]) # weighting for each layer in final loss; "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1] 48 | layer_loss_weights = np.expand_dims(layer_loss_weights, 1) 49 | nt = 10 # number of timesteps used for sequences in training 50 | time_loss_weights = 1./ (nt - 1) * np.ones((nt,1)) # equally weight all timesteps except the first 51 | time_loss_weights[0] = 0 52 | 53 | 54 | prednet = PredNet(stack_sizes, R_stack_sizes, 55 | A_filt_sizes, Ahat_filt_sizes, R_filt_sizes, 56 | output_mode='error', return_sequences=True) 57 | 58 | inputs = Input(shape=(nt,) + input_shape) 59 | errors = prednet(inputs) # errors will be (batch_size, nt, nb_layers) 60 | errors_by_time = TimeDistributed(Dense(1, trainable=False), weights=[layer_loss_weights, np.zeros(1)], trainable=False)(errors) # calculate weighted error by layer 61 | errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt) 62 | final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time 63 | model = Model(inputs=inputs, outputs=final_errors) 64 | model.compile(loss='mean_absolute_error', optimizer='adam') 65 | 66 | train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True) 67 | val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val) 68 | 69 | lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs 70 | callbacks = [LearningRateScheduler(lr_schedule)] 71 | if save_model: 72 | if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR) 73 | callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True)) 74 | 75 | history = model.fit_generator(train_generator, samples_per_epoch / batch_size, nb_epoch, callbacks=callbacks, 
76 | validation_data=val_generator, validation_steps=N_seq_val / batch_size) 77 | 78 | if save_model: 79 | json_string = model.to_json() 80 | with open(json_file, "w") as f: 81 | f.write(json_string) 82 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import hickle as hkl 2 | import numpy as np 3 | from keras import backend as K 4 | from keras.preprocessing.image import Iterator 5 | 6 | # Data generator that creates sequences for input into PredNet. 7 | class SequenceGenerator(Iterator): 8 | def __init__(self, data_file, source_file, nt, 9 | batch_size=8, shuffle=False, seed=None, 10 | output_mode='error', sequence_start_mode='all', N_seq=None, 11 | data_format=K.image_data_format()): 12 | self.X = hkl.load(data_file) # X will be like (n_images, nb_cols, nb_rows, nb_channels) 13 | self.sources = hkl.load(source_file) # source for each image so when creating sequences can assure that consecutive frames are from same video 14 | self.nt = nt 15 | self.batch_size = batch_size 16 | self.data_format = data_format 17 | assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}' 18 | self.sequence_start_mode = sequence_start_mode 19 | assert output_mode in {'error', 'prediction'}, 'output_mode must be in {error, prediction}' 20 | self.output_mode = output_mode 21 | 22 | if self.data_format == 'channels_first': 23 | self.X = np.transpose(self.X, (0, 3, 1, 2)) 24 | self.im_shape = self.X[0].shape 25 | 26 | if self.sequence_start_mode == 'all': # allow for any possible sequence, starting from any frame 27 | self.possible_starts = np.array([i for i in range(self.X.shape[0] - self.nt) if self.sources[i] == self.sources[i + self.nt - 1]]) 28 | elif self.sequence_start_mode == 'unique': #create sequences where each unique frame is in at most one sequence 29 | curr_location = 0 30 | possible_starts = [] 31 | while curr_location < self.X.shape[0] - self.nt + 1: 32 | if self.sources[curr_location] == self.sources[curr_location + self.nt - 1]: 33 | possible_starts.append(curr_location) 34 | curr_location += self.nt 35 | else: 36 | curr_location += 1 37 | self.possible_starts = possible_starts 38 | 39 | if shuffle: 40 | self.possible_starts = np.random.permutation(self.possible_starts) 41 | if N_seq is not None and len(self.possible_starts) > N_seq: # select a subset of sequences if want to 42 | self.possible_starts = self.possible_starts[:N_seq] 43 | self.N_sequences = len(self.possible_starts) 44 | super(SequenceGenerator, self).__init__(len(self.possible_starts), batch_size, shuffle, seed) 45 | 46 | def __getitem__(self, null): 47 | return self.next() 48 | 49 | def next(self): 50 | with self.lock: 51 | current_index = (self.batch_index * self.batch_size) % self.n 52 | index_array, current_batch_size = next(self.index_generator), self.batch_size 53 | batch_x = np.zeros((current_batch_size, self.nt) + self.im_shape, np.float32) 54 | for i, idx in enumerate(index_array): 55 | idx = self.possible_starts[idx] 56 | batch_x[i] = self.preprocess(self.X[idx:idx+self.nt]) 57 | if self.output_mode == 'error': # model outputs errors, so y should be zeros 58 | batch_y = np.zeros(current_batch_size, np.float32) 59 | elif self.output_mode == 'prediction': # output actual pixels 60 | batch_y = batch_x 61 | return batch_x, batch_y 62 | 63 | def preprocess(self, X): 64 | return X.astype(np.float32) / 255 65 | 66 | def create_all(self): 67 | X_all = 
np.zeros((self.N_sequences, self.nt) + self.im_shape, np.float32) 68 | for i, idx in enumerate(self.possible_starts): 69 | X_all[i] = self.preprocess(self.X[idx:idx+self.nt]) 70 | return X_all 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # prednet 2 | 3 | Code and models accompanying [Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning](https://arxiv.org/abs/1605.08104) by Bill Lotter, Gabriel Kreiman, and David Cox. 4 | 5 | The PredNet is a deep recurrent convolutional neural network that is inspired by the neuroscience concept of predictive coding (Rao and Ballard, 1999; Friston, 2005). 6 | **Check out example prediction videos [here](https://coxlab.github.io/prednet/).** 7 | 8 | The architecture is implemented as a custom layer1 in [Keras](http://keras.io/). 9 | Code and model data is compatible with Keras 2.0 and Python 2.7 and 3.6. 10 | The latest version has been tested on Keras 2.2.4 with Tensorflow 1.6. 11 | For previous versions of the code compatible with Keras 1.2.1, use fbcdc18. 12 | To convert old PredNet model files and weights for Keras 2.0 compatibility, see ```convert_model_to_keras2``` in `keras_utils.py`. 13 |
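An environment with the pinned versions from `environment.yml` can be created with `conda env create -f environment.yml`. As a reference for the conversion mentioned above, here is a minimal sketch of calling the converter; the source/destination paths mirror the defaults in the script's `__main__` and are only illustrative:

```python
# Sketch: convert a Keras 1.2 PredNet model definition + weights to the Keras 2.0 format.
# Paths are assumptions based on keras_utils.py's __main__; adjust them to your local layout.
import os
from keras_utils import convert_model_to_keras2

if not os.path.exists('./model_data_keras2/'):
    os.mkdir('./model_data_keras2/')

convert_model_to_keras2('./model_data/prednet_kitti_model.json',
                        './model_data/prednet_kitti_weights.hdf5',
                        './model_data_keras2/prednet_kitti_model.json',
                        './model_data_keras2/prednet_kitti_weights.hdf5')
```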
14 | 15 | ## KITTI Demo 16 | 17 | Code is included for training the PredNet on the raw [KITTI](http://www.cvlibs.net/datasets/kitti/) dataset. 18 | We include code for downloading and processing the data, as well as training and evaluating the model. 19 | The preprocessed data can also be downloaded directly using `download_data.sh`, and the **trained weights** by running `download_models.sh`. 20 | The model download will include the original weights trained for t+1 prediction, the fine-tuned weights trained to extrapolate predictions for multiple timesteps, and the "Lall" weights trained with a 0.1 loss weight on upper layers (see paper for details). 21 | 22 | ### Steps 23 | 1. **Download/process data** 24 | ```bash 25 | python process_kitti.py 26 | ``` 27 | This will scrape the KITTI website to download the raw data from the city, residential, and road categories (~165 GB) and then process the images (cropping, downsampling). 28 | Alternatively, the processed data (~3 GB) can be directly downloaded by executing `download_data.sh`. 29 |
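For example (the script simply creates `kitti_data/`, downloads the archive, and unzips it there):

```bash
# Assumes bash plus wget/unzip, as used inside download_data.sh
bash download_data.sh
```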
30 |
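Once the data is in place, a quick sanity check of the processed files (shapes follow `process_kitti.py`; the path assumes the default `DATA_DIR` from `kitti_settings.py`):

```python
# Minimal sketch: inspect the processed KITTI data produced by process_kitti.py / download_data.sh.
import hickle as hkl

X_train = hkl.load('./kitti_data/X_train.hkl')        # uint8 frames, shape (n_images, 128, 160, 3)
sources = hkl.load('./kitti_data/sources_train.hkl')  # recording name for each frame
print(X_train.shape, len(sources))
```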
31 | 32 | 2. **Train model** 33 | ```bash 34 | python kitti_train.py 35 | ``` 36 | This will train a PredNet model for t+1 prediction. 37 | See [Keras FAQ](http://keras.io/getting-started/faq/#how-can-i-run-keras-on-gpu) on how to run using a GPU. 38 | **To download pre-trained weights**, run `download_models.sh` 39 |
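For example:

```bash
# Fetches the trained weights and model json into ./model_data_keras2/ (see download_models.sh)
bash download_models.sh
```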
40 |
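The downloaded weights also include the "Lall" variant mentioned above; reproducing it only requires changing the layer weighting in `kitti_train.py`, per the comment in that script:

```python
# "L_0" (default) puts all of the loss on the lowest layer; "L_all" also weights the upper layers by 0.1.
layer_loss_weights = np.array([1., 0.1, 0.1, 0.1])
```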
41 | 42 | 3. **Evaluate model** 43 | ```bash 44 | python kitti_evaluate.py 45 | ``` 46 | This will output the mean-squared error for predictions as well as make plots comparing predictions to ground-truth. 47 | 48 | ### Feature Extraction 49 | Extracting the intermediate features for a given layer in the PredNet can be done using the appropriate ```output_mode``` argument. For example, to extract the hidden state of the LSTM (the "Representation" units) in the lowest layer, use ```output_mode = 'R0'```. More details can be found in the PredNet docstring. 50 | 51 | ### Multi-Step Prediction 52 | The PredNet argument ```extrap_start_time``` can be used to force multi-step prediction. Starting at this time step, the prediction from the previous time step will be treated as the actual input. For example, if the model is run on a sequence of 15 timesteps with ```extrap_start_time = 10```, the last output will correspond to a t+5 prediction. In the paper, we train in this setting starting from the original t+1 trained weights (see `kitti_extrap_finetune.py`), and the resulting fine-tuned weights are included in `download_models.sh`. Note that when training with extrapolation, the "errors" are no longer tied to ground truth, so the loss should be calculated on the pixel predictions themselves. This can be done by using ```output_mode = 'prediction'```, as illustrated in `kitti_extrap_finetune.py`. 53 | 54 | ### Additional Notes 55 | When training on a new dataset, the image size has to be divisible by 2^(nb of layers - 1) because of the cyclical 2x2 max-pooling and upsampling operations. 56 | 57 |
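As a concrete illustration of the Feature Extraction section above, the sketch below rebuilds a trained PredNet with ```output_mode = 'R0'```, following the same model-rebuilding pattern used in `kitti_evaluate.py` (the file paths are illustrative and assume the downloaded model files):

```python
# Minimal sketch: extract the lowest-layer representation (R) units from a trained PredNet.
import os
from keras.layers import Input
from keras.models import Model, model_from_json
from prednet import PredNet
from kitti_settings import WEIGHTS_DIR

nt = 10  # sequence length, as in kitti_evaluate.py
json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')
weights_file = os.path.join(WEIGHTS_DIR, 'tensorflow_weights/prednet_kitti_weights.hdf5')  # adjust to your weights location

# Load the trained model, then rebuild the PredNet layer with a feature-extraction output_mode.
train_model = model_from_json(open(json_file).read(), custom_objects={'PredNet': PredNet})
train_model.load_weights(weights_file)
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'R0'  # representation units of the lowest layer
feature_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)

input_shape = list(train_model.layers[0].batch_input_shape[1:])
input_shape[0] = nt
inputs = Input(shape=tuple(input_shape))
features = feature_prednet(inputs)  # (batch, nt, 128, 160, R_stack_sizes[0]) for channels_last
feature_model = Model(inputs=inputs, outputs=features)
```

The resulting `feature_model.predict(...)` can then be run on sequences built with `SequenceGenerator`, exactly as in `kitti_evaluate.py`.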
58 | 59 | 1 Note on implementation: PredNet inherits from the Recurrent layer class, i.e. it has an internal state and a step function. Given the top-down then bottom-up update sequence, it must currently be implemented in Keras as essentially a 'super' layer where all layers in the PredNet are in one PredNet 'layer'. This is less than ideal, but it seems like the most efficient way as of now. We welcome suggestions if anyone thinks of a better implementation. 60 | -------------------------------------------------------------------------------- /process_kitti.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code for downloading and processing KITTI data (Geiger et al. 2013, http://www.cvlibs.net/datasets/kitti/) 3 | ''' 4 | 5 | import os 6 | import requests 7 | from bs4 import BeautifulSoup 8 | import urllib.request 9 | import numpy as np 10 | from imageio import imread 11 | from scipy.misc import imresize 12 | import hickle as hkl 13 | from kitti_settings import * 14 | 15 | 16 | desired_im_sz = (128, 160) 17 | categories = ['city', 'residential', 'road'] 18 | 19 | # Recordings used for validation and testing. 20 | # Were initially chosen randomly such that one of the city recordings was used for validation and one of each category was used for testing. 21 | val_recordings = [('city', '2011_09_26_drive_0005_sync')] 22 | test_recordings = [('city', '2011_09_26_drive_0104_sync'), ('residential', '2011_09_26_drive_0079_sync'), ('road', '2011_09_26_drive_0070_sync')] 23 | 24 | if not os.path.exists(DATA_DIR): os.mkdir(DATA_DIR) 25 | 26 | # Download raw zip files by scraping KITTI website 27 | def download_data(): 28 | base_dir = os.path.join(DATA_DIR, 'raw/') 29 | if not os.path.exists(base_dir): os.mkdir(base_dir) 30 | for c in categories: 31 | url = "http://www.cvlibs.net/datasets/kitti/raw_data.php?type=" + c 32 | r = requests.get(url) 33 | soup = BeautifulSoup(r.content) 34 | drive_list = soup.find_all("h3") 35 | drive_list = [d.text[:d.text.find(' ')] for d in drive_list] 36 | print( "Downloading set: " + c) 37 | c_dir = base_dir + c + '/' 38 | if not os.path.exists(c_dir): os.mkdir(c_dir) 39 | for i, d in enumerate(drive_list): 40 | print( str(i+1) + '/' + str(len(drive_list)) + ": " + d) 41 | url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/" + d + "/" + d + "_sync.zip" 42 | urllib.request.urlretrieve(url, filename=c_dir + d + "_sync.zip") 43 | 44 | 45 | # unzip images 46 | def extract_data(): 47 | for c in categories: 48 | c_dir = os.path.join(DATA_DIR, 'raw/', c + '/') 49 | zip_files = list(os.walk(c_dir, topdown=False))[-1][-1]#.next() 50 | for f in zip_files: 51 | print( 'unpacking: ' + f) 52 | spec_folder = f[:10] + '/' + f[:-4] + '/image_03/data*' 53 | command = 'unzip -qq ' + c_dir + f + ' ' + spec_folder + ' -d ' + c_dir + f[:-4] 54 | os.system(command) 55 | 56 | 57 | # Create image datasets. 58 | # Processes images and saves them in train, val, test splits. 59 | def process_data(): 60 | splits = {s: [] for s in ['train', 'test', 'val']} 61 | splits['val'] = val_recordings 62 | splits['test'] = test_recordings 63 | not_train = splits['val'] + splits['test'] 64 | for c in categories: # Randomly assign recordings to training and testing. Cross-validation done across entire recordings. 
65 | c_dir = os.path.join(DATA_DIR, 'raw', c + '/') 66 | folders= list(os.walk(c_dir, topdown=False))[-1][-2] 67 | splits['train'] += [(c, f) for f in folders if (c, f) not in not_train] 68 | 69 | for split in splits: 70 | im_list = [] 71 | source_list = [] # corresponds to recording that image came from 72 | for category, folder in splits[split]: 73 | im_dir = os.path.join(DATA_DIR, 'raw/', category, folder, folder[:10], folder, 'image_03/data/') 74 | files = list(os.walk(im_dir, topdown=False))[-1][-1] 75 | im_list += [im_dir + f for f in sorted(files)] 76 | source_list += [category + '-' + folder] * len(files) 77 | 78 | print( 'Creating ' + split + ' data: ' + str(len(im_list)) + ' images') 79 | X = np.zeros((len(im_list),) + desired_im_sz + (3,), np.uint8) 80 | for i, im_file in enumerate(im_list): 81 | im = imread(im_file) 82 | X[i] = process_im(im, desired_im_sz) 83 | 84 | hkl.dump(X, os.path.join(DATA_DIR, 'X_' + split + '.hkl')) 85 | hkl.dump(source_list, os.path.join(DATA_DIR, 'sources_' + split + '.hkl')) 86 | 87 | 88 | # resize and crop image 89 | def process_im(im, desired_sz): 90 | target_ds = float(desired_sz[0])/im.shape[0] 91 | im = imresize(im, (desired_sz[0], int(np.round(target_ds * im.shape[1])))) 92 | d = int((im.shape[1] - desired_sz[1]) / 2) 93 | im = im[:, d:d+desired_sz[1]] 94 | return im 95 | 96 | 97 | if __name__ == '__main__': 98 | download_data() 99 | extract_data() 100 | process_data() 101 | -------------------------------------------------------------------------------- /prednet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from keras import backend as K 4 | from keras import activations 5 | from keras.layers import Recurrent 6 | from keras.layers import Conv2D, UpSampling2D, MaxPooling2D 7 | from keras.engine import InputSpec 8 | from keras_utils import legacy_prednet_support 9 | 10 | class PredNet(Recurrent): 11 | '''PredNet architecture - Lotter 2016. 12 | Stacked convolutional LSTM inspired by predictive coding principles. 13 | 14 | # Arguments 15 | stack_sizes: number of channels in targets (A) and predictions (Ahat) in each layer of the architecture. 16 | Length is the number of layers in the architecture. 17 | First element is the number of channels in the input. 18 | Ex. (3, 16, 32) would correspond to a 3 layer architecture that takes in RGB images and has 16 and 32 19 | channels in the second and third layers, respectively. 20 | R_stack_sizes: number of channels in the representation (R) modules. 21 | Length must equal length of stack_sizes, but the number of channels per layer can be different. 22 | A_filt_sizes: filter sizes for the target (A) modules. 23 | Has length of 1 - len(stack_sizes). 24 | Ex. (3, 3) would mean that targets for layers 2 and 3 are computed by a 3x3 convolution of the errors (E) 25 | from the layer below (followed by max-pooling) 26 | Ahat_filt_sizes: filter sizes for the prediction (Ahat) modules. 27 | Has length equal to length of stack_sizes. 28 | Ex. (3, 3, 3) would mean that the predictions for each layer are computed by a 3x3 convolution of the 29 | representation (R) modules at each layer. 30 | R_filt_sizes: filter sizes for the representation (R) modules. 31 | Has length equal to length of stack_sizes. 32 | Corresponds to the filter sizes for all convolutions in the LSTM. 33 | pixel_max: the maximum pixel value. 34 | Used to clip the pixel-layer prediction. 35 | error_activation: activation function for the error (E) units. 
36 | A_activation: activation function for the target (A) and prediction (A_hat) units. 37 | LSTM_activation: activation function for the cell and hidden states of the LSTM. 38 | LSTM_inner_activation: activation function for the gates in the LSTM. 39 | output_mode: either 'error', 'prediction', 'all' or layer specification (ex. R2, see below). 40 | Controls what is outputted by the PredNet. 41 | If 'error', the mean response of the error (E) units of each layer will be outputted. 42 | That is, the output shape will be (batch_size, nb_layers). 43 | If 'prediction', the frame prediction will be outputted. 44 | If 'all', the output will be the frame prediction concatenated with the mean layer errors. 45 | The frame prediction is flattened before concatenation. 46 | Nomenclature of 'all' is kept for backwards compatibility, but should not be confused with returning all of the layers of the model 47 | For returning the features of a particular layer, output_mode should be of the form unit_type + layer_number. 48 | For instance, to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specificied as 'R0'. 49 | The possible unit types are 'R', 'Ahat', 'A', and 'E' corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively. 50 | extrap_start_time: time step for which model will start extrapolating. 51 | Starting at this time step, the prediction from the previous time step will be treated as the "actual" 52 | data_format: 'channels_first' or 'channels_last'. 53 | It defaults to the `image_data_format` value found in your 54 | Keras config file at `~/.keras/keras.json`. 55 | 56 | # References 57 | - [Deep predictive coding networks for video prediction and unsupervised learning](https://arxiv.org/abs/1605.08104) 58 | - [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf) 59 | - [Convolutional LSTM network: a machine learning approach for precipitation nowcasting](http://arxiv.org/abs/1506.04214) 60 | - [Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects](http://www.nature.com/neuro/journal/v2/n1/pdf/nn0199_79.pdf) 61 | ''' 62 | @legacy_prednet_support 63 | def __init__(self, stack_sizes, R_stack_sizes, 64 | A_filt_sizes, Ahat_filt_sizes, R_filt_sizes, 65 | pixel_max=1., error_activation='relu', A_activation='relu', 66 | LSTM_activation='tanh', LSTM_inner_activation='hard_sigmoid', 67 | output_mode='error', extrap_start_time=None, 68 | data_format=K.image_data_format(), **kwargs): 69 | self.stack_sizes = stack_sizes 70 | self.nb_layers = len(stack_sizes) 71 | assert len(R_stack_sizes) == self.nb_layers, 'len(R_stack_sizes) must equal len(stack_sizes)' 72 | self.R_stack_sizes = R_stack_sizes 73 | assert len(A_filt_sizes) == (self.nb_layers - 1), 'len(A_filt_sizes) must equal len(stack_sizes) - 1' 74 | self.A_filt_sizes = A_filt_sizes 75 | assert len(Ahat_filt_sizes) == self.nb_layers, 'len(Ahat_filt_sizes) must equal len(stack_sizes)' 76 | self.Ahat_filt_sizes = Ahat_filt_sizes 77 | assert len(R_filt_sizes) == (self.nb_layers), 'len(R_filt_sizes) must equal len(stack_sizes)' 78 | self.R_filt_sizes = R_filt_sizes 79 | 80 | self.pixel_max = pixel_max 81 | self.error_activation = activations.get(error_activation) 82 | self.A_activation = activations.get(A_activation) 83 | self.LSTM_activation = activations.get(LSTM_activation) 84 | self.LSTM_inner_activation = activations.get(LSTM_inner_activation) 85 | 86 | default_output_modes 
= ['prediction', 'error', 'all'] 87 | layer_output_modes = [layer + str(n) for n in range(self.nb_layers) for layer in ['R', 'E', 'A', 'Ahat']] 88 | assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode) 89 | self.output_mode = output_mode 90 | if self.output_mode in layer_output_modes: 91 | self.output_layer_type = self.output_mode[:-1] 92 | self.output_layer_num = int(self.output_mode[-1]) 93 | else: 94 | self.output_layer_type = None 95 | self.output_layer_num = None 96 | self.extrap_start_time = extrap_start_time 97 | 98 | assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {channels_last, channels_first}' 99 | self.data_format = data_format 100 | self.channel_axis = -3 if data_format == 'channels_first' else -1 101 | self.row_axis = -2 if data_format == 'channels_first' else -3 102 | self.column_axis = -1 if data_format == 'channels_first' else -2 103 | super(PredNet, self).__init__(**kwargs) 104 | self.input_spec = [InputSpec(ndim=5)] 105 | 106 | def compute_output_shape(self, input_shape): 107 | if self.output_mode == 'prediction': 108 | out_shape = input_shape[2:] 109 | elif self.output_mode == 'error': 110 | out_shape = (self.nb_layers,) 111 | elif self.output_mode == 'all': 112 | out_shape = (np.prod(input_shape[2:]) + self.nb_layers,) 113 | else: 114 | stack_str = 'R_stack_sizes' if self.output_layer_type == 'R' else 'stack_sizes' 115 | stack_mult = 2 if self.output_layer_type == 'E' else 1 116 | out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num] 117 | out_nb_row = input_shape[self.row_axis] / 2**self.output_layer_num 118 | out_nb_col = input_shape[self.column_axis] / 2**self.output_layer_num 119 | if self.data_format == 'channels_first': 120 | out_shape = (out_stack_size, out_nb_row, out_nb_col) 121 | else: 122 | out_shape = (out_nb_row, out_nb_col, out_stack_size) 123 | 124 | if self.return_sequences: 125 | return (input_shape[0], input_shape[1]) + out_shape 126 | else: 127 | return (input_shape[0],) + out_shape 128 | 129 | def get_initial_state(self, x): 130 | input_shape = self.input_spec[0].shape 131 | init_nb_row = input_shape[self.row_axis] 132 | init_nb_col = input_shape[self.column_axis] 133 | 134 | base_initial_state = K.zeros_like(x) # (samples, timesteps) + image_shape 135 | non_channel_axis = -1 if self.data_format == 'channels_first' else -2 136 | for _ in range(2): 137 | base_initial_state = K.sum(base_initial_state, axis=non_channel_axis) 138 | base_initial_state = K.sum(base_initial_state, axis=1) # (samples, nb_channels) 139 | 140 | initial_states = [] 141 | states_to_pass = ['r', 'c', 'e'] 142 | nlayers_to_pass = {u: self.nb_layers for u in states_to_pass} 143 | if self.extrap_start_time is not None: 144 | states_to_pass.append('ahat') # pass prediction in states so can use as actual for t+1 when extrapolating 145 | nlayers_to_pass['ahat'] = 1 146 | for u in states_to_pass: 147 | for l in range(nlayers_to_pass[u]): 148 | ds_factor = 2 ** l 149 | nb_row = init_nb_row // ds_factor 150 | nb_col = init_nb_col // ds_factor 151 | if u in ['r', 'c']: 152 | stack_size = self.R_stack_sizes[l] 153 | elif u == 'e': 154 | stack_size = 2 * self.stack_sizes[l] 155 | elif u == 'ahat': 156 | stack_size = self.stack_sizes[l] 157 | output_size = stack_size * nb_row * nb_col # flattened size 158 | 159 | reducer = K.zeros((input_shape[self.channel_axis], output_size)) # (nb_channels, output_size) 160 | initial_state = K.dot(base_initial_state, reducer) # (samples, 
output_size) 161 | if self.data_format == 'channels_first': 162 | output_shp = (-1, stack_size, nb_row, nb_col) 163 | else: 164 | output_shp = (-1, nb_row, nb_col, stack_size) 165 | initial_state = K.reshape(initial_state, output_shp) 166 | initial_states += [initial_state] 167 | 168 | if K._BACKEND == 'theano': 169 | from theano import tensor as T 170 | # There is a known issue in the Theano scan op when dealing with inputs whose shape is 1 along a dimension. 171 | # In our case, this is a problem when training on grayscale images, and the below line fixes it. 172 | initial_states = [T.unbroadcast(init_state, 0, 1) for init_state in initial_states] 173 | 174 | if self.extrap_start_time is not None: 175 | initial_states += [K.variable(0, int if K.backend() != 'tensorflow' else 'int32')] # the last state will correspond to the current timestep 176 | return initial_states 177 | 178 | def build(self, input_shape): 179 | self.input_spec = [InputSpec(shape=input_shape)] 180 | self.conv_layers = {c: [] for c in ['i', 'f', 'c', 'o', 'a', 'ahat']} 181 | 182 | for l in range(self.nb_layers): 183 | for c in ['i', 'f', 'c', 'o']: 184 | act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation 185 | self.conv_layers[c].append(Conv2D(self.R_stack_sizes[l], self.R_filt_sizes[l], padding='same', activation=act, data_format=self.data_format)) 186 | 187 | act = 'relu' if l == 0 else self.A_activation 188 | self.conv_layers['ahat'].append(Conv2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], padding='same', activation=act, data_format=self.data_format)) 189 | 190 | if l < self.nb_layers - 1: 191 | self.conv_layers['a'].append(Conv2D(self.stack_sizes[l+1], self.A_filt_sizes[l], padding='same', activation=self.A_activation, data_format=self.data_format)) 192 | 193 | self.upsample = UpSampling2D(data_format=self.data_format) 194 | self.pool = MaxPooling2D(data_format=self.data_format) 195 | 196 | self.trainable_weights = [] 197 | nb_row, nb_col = (input_shape[-2], input_shape[-1]) if self.data_format == 'channels_first' else (input_shape[-3], input_shape[-2]) 198 | for c in sorted(self.conv_layers.keys()): 199 | for l in range(len(self.conv_layers[c])): 200 | ds_factor = 2 ** l 201 | if c == 'ahat': 202 | nb_channels = self.R_stack_sizes[l] 203 | elif c == 'a': 204 | nb_channels = 2 * self.stack_sizes[l] 205 | else: 206 | nb_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[l] 207 | if l < self.nb_layers - 1: 208 | nb_channels += self.R_stack_sizes[l+1] 209 | in_shape = (input_shape[0], nb_channels, nb_row // ds_factor, nb_col // ds_factor) 210 | if self.data_format == 'channels_last': in_shape = (in_shape[0], in_shape[2], in_shape[3], in_shape[1]) 211 | with K.name_scope('layer_' + c + '_' + str(l)): 212 | self.conv_layers[c][l].build(in_shape) 213 | self.trainable_weights += self.conv_layers[c][l].trainable_weights 214 | 215 | self.states = [None] * self.nb_layers*3 216 | 217 | if self.extrap_start_time is not None: 218 | self.t_extrap = K.variable(self.extrap_start_time, int if K.backend() != 'tensorflow' else 'int32') 219 | self.states += [None] * 2 # [previous frame prediction, timestep] 220 | 221 | def step(self, a, states): 222 | r_tm1 = states[:self.nb_layers] 223 | c_tm1 = states[self.nb_layers:2*self.nb_layers] 224 | e_tm1 = states[2*self.nb_layers:3*self.nb_layers] 225 | 226 | if self.extrap_start_time is not None: 227 | t = states[-1] 228 | a = K.switch(t >= self.t_extrap, states[-2], a) # if past self.extrap_start_time, the previous prediction will be treated as the actual 229 | 
230 | c = [] 231 | r = [] 232 | e = [] 233 | 234 | # Update R units starting from the top 235 | for l in reversed(range(self.nb_layers)): 236 | inputs = [r_tm1[l], e_tm1[l]] 237 | if l < self.nb_layers - 1: 238 | inputs.append(r_up) 239 | 240 | inputs = K.concatenate(inputs, axis=self.channel_axis) 241 | i = self.conv_layers['i'][l].call(inputs) 242 | f = self.conv_layers['f'][l].call(inputs) 243 | o = self.conv_layers['o'][l].call(inputs) 244 | _c = f * c_tm1[l] + i * self.conv_layers['c'][l].call(inputs) 245 | _r = o * self.LSTM_activation(_c) 246 | c.insert(0, _c) 247 | r.insert(0, _r) 248 | 249 | if l > 0: 250 | r_up = self.upsample.call(_r) 251 | 252 | # Update feedforward path starting from the bottom 253 | for l in range(self.nb_layers): 254 | ahat = self.conv_layers['ahat'][l].call(r[l]) 255 | if l == 0: 256 | ahat = K.minimum(ahat, self.pixel_max) 257 | frame_prediction = ahat 258 | 259 | # compute errors 260 | e_up = self.error_activation(ahat - a) 261 | e_down = self.error_activation(a - ahat) 262 | 263 | e.append(K.concatenate((e_up, e_down), axis=self.channel_axis)) 264 | 265 | if self.output_layer_num == l: 266 | if self.output_layer_type == 'A': 267 | output = a 268 | elif self.output_layer_type == 'Ahat': 269 | output = ahat 270 | elif self.output_layer_type == 'R': 271 | output = r[l] 272 | elif self.output_layer_type == 'E': 273 | output = e[l] 274 | 275 | if l < self.nb_layers - 1: 276 | a = self.conv_layers['a'][l].call(e[l]) 277 | a = self.pool.call(a) # target for next layer 278 | 279 | if self.output_layer_type is None: 280 | if self.output_mode == 'prediction': 281 | output = frame_prediction 282 | else: 283 | for l in range(self.nb_layers): 284 | layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True) 285 | all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1) 286 | if self.output_mode == 'error': 287 | output = all_error 288 | else: 289 | output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1) 290 | 291 | states = r + c + e 292 | if self.extrap_start_time is not None: 293 | states += [frame_prediction, t + 1] 294 | return output, states 295 | 296 | def get_config(self): 297 | config = {'stack_sizes': self.stack_sizes, 298 | 'R_stack_sizes': self.R_stack_sizes, 299 | 'A_filt_sizes': self.A_filt_sizes, 300 | 'Ahat_filt_sizes': self.Ahat_filt_sizes, 301 | 'R_filt_sizes': self.R_filt_sizes, 302 | 'pixel_max': self.pixel_max, 303 | 'error_activation': self.error_activation.__name__, 304 | 'A_activation': self.A_activation.__name__, 305 | 'LSTM_activation': self.LSTM_activation.__name__, 306 | 'LSTM_inner_activation': self.LSTM_inner_activation.__name__, 307 | 'data_format': self.data_format, 308 | 'extrap_start_time': self.extrap_start_time, 309 | 'output_mode': self.output_mode} 310 | base_config = super(PredNet, self).get_config() 311 | return dict(list(base_config.items()) + list(config.items())) 312 | --------------------------------------------------------------------------------