├── .gitattributes
├── LICENSE
├── README.md
├── data_utils.py
├── debug.py
├── debug
│   ├── data_utils.py
│   ├── debug.py
│   ├── keras_utils.py
│   ├── prednet.py
│   ├── reproduce_bug.py
│   └── test_LSTM
│       └── convLSTM.py
├── evaluate.py
├── evaluate.sh
├── kitti_results
│   └── prediction_plots
│       ├── 0
│       │   ├── plot_0.png
│       │   ├── plot_1.png
│       │   ├── plot_2.png
│       │   ├── plot_3.png
│       │   ├── plot_4.png
│       │   ├── plot_5.png
│       │   ├── plot_6.png
│       │   └── plot_7.png
│       ├── 1
│       │   ├── plot_0.png
│       │   ├── plot_1.png
│       │   ├── plot_2.png
│       │   ├── plot_3.png
│       │   ├── plot_4.png
│       │   ├── plot_5.png
│       │   ├── plot_6.png
│       │   └── plot_7.png
│       ├── 2
│       │   ├── plot_0.png
│       │   ├── plot_1.png
│       │   ├── plot_2.png
│       │   ├── plot_3.png
│       │   ├── plot_4.png
│       │   ├── plot_5.png
│       │   ├── plot_6.png
│       │   └── plot_7.png
│       ├── 3
│       │   ├── plot_0.png
│       │   ├── plot_1.png
│       │   ├── plot_2.png
│       │   ├── plot_3.png
│       │   ├── plot_4.png
│       │   ├── plot_5.png
│       │   ├── plot_6.png
│       │   └── plot_7.png
│       ├── 4
│       │   ├── plot_0.png
│       │   ├── plot_1.png
│       │   ├── plot_2.png
│       │   ├── plot_3.png
│       │   ├── plot_4.png
│       │   ├── plot_5.png
│       │   ├── plot_6.png
│       │   └── plot_7.png
│       ├── 2.5
│       │   ├── plot_0.png
│       │   ├── plot_1.png
│       │   ├── plot_2.png
│       │   ├── plot_3.png
│       │   ├── plot_4.png
│       │   ├── plot_5.png
│       │   ├── plot_6.png
│       │   └── plot_7.png
│       └── use_pretrained_weights
│           ├── plot_0.png
│           ├── plot_1.png
│           ├── plot_2.png
│           ├── plot_3.png
│           ├── plot_4.png
│           ├── plot_5.png
│           ├── plot_6.png
│           └── plot_7.png
├── load_weights.py
├── model_data_keras2
│   └── prednet_kitti_weights.hdf5
├── prednet.py
├── train.py
├── train.sh
└── visualization.py
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Chenrui Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PredNet_pytorch

An implementation of PredNet in pytorch. See the paper [Deep predictive coding networks for video prediction and unsupervised learning](https://arxiv.org/abs/1605.08104) in ICLR 2017 for more details.

The [official code](https://github.com/coxlab/prednet) is implemented in Keras, and the project website can be found at [https://coxlab.github.io/prednet/](https://coxlab.github.io/prednet/).
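## Quick usage sketch

A minimal sketch of how the `PredNet` module in this repo is constructed and called, mirroring what `evaluate.py` does. The layer sizes are the KITTI configuration hard-coded in `evaluate.py`; the batch size, frame size and random input below are illustrative only:

```python
import torch
from torch.autograd import Variable
from prednet import PredNet

stack_sizes     = (3, 48, 96, 192)  # channels of the target (A) / prediction (Ahat) units per layer; first entry = input channels
R_stack_sizes   = stack_sizes       # channels of the representation (R) units per layer
A_filt_sizes    = (3, 3, 3)         # filter sizes for the target (A) convolutions
Ahat_filt_sizes = (3, 3, 3, 3)      # filter sizes for the prediction (Ahat) convolutions
R_filt_sizes    = (3, 3, 3, 3)      # filter sizes for the ConvLSTM (R) convolutions

prednet = PredNet(stack_sizes, R_stack_sizes, A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
                  output_mode='prediction', data_format='channels_first', return_sequences=True)
prednet.cuda()

# (batch_size, timesteps, channels, height, width) when data_format='channels_first'
input_shape = (8, 10, 3, 128, 160)
x = Variable(torch.randn(*input_shape).cuda())
initial_states = prednet.get_initial_states(input_shape)
predictions = prednet(x, initial_states)  # with output_mode='prediction': one predicted frame per timestep
```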
## Dataset
The preprocessed KITTI data can be obtained using `download_data.sh` in the [official code](https://github.com/coxlab/prednet).

## How to run
### Train model
```
sh train.sh
```
### Evaluate model
```
sh evaluate.sh
```

## Some example results
![example](./kitti_results/prediction_plots/use_pretrained_weights/plot_5.png)

--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
from torchvision import datasets, transforms

import numpy as np
import h5py
import re


class SequenceGenerator(data.Dataset):
    """
    Sequence Generator

    SequenceGenerator plays the same role as the ImageFolder class in pytorch.

    X_train.h5 contains 41396 images from 57 videos.
    X_test.h5 contains 832 images from 3 videos.
    X_val.h5 contains 154 images from 1 video.

    Args:
        - data_file:
            data path, e.g., '/media/sdb1/chenrui/kitti_data/h5/X_train.h5'
        - source_file:
            e.g., '/media/sdb1/chenrui/kitti_data/h5/sources_train.h5'
            source for each image, so that when creating sequences we can ensure that consecutive frames come from the same video.
            the content is like: 'road-2011_10_03_drive_0047_sync'
        - num_timeSteps:
            number of timesteps (frames) per sequence
        - seed:
            random seeding for data shuffling.
        - shuffle:
            shuffle or not
        - output_mode:
            `error` or `prediction`
        - sequence_start_mode:
            `all` or `unique`.
            `all`: allow for any possible sequence, starting from any frame.
42 | `unique`: create sequences where each unique frame is in at most one sequence 43 | - N_seq: 44 | TODO 45 | """ 46 | def __init__(self, data_file, source_file, num_timeSteps, shuffle = False, seed = None, 47 | output_mode = 'error', sequence_start_mode = 'all', N_seq = None, data_format = 'channels_first'): 48 | super(SequenceGenerator, self).__init__() 49 | pattern = re.compile(r'.*?h5/(.+?)\.h5') 50 | resList = re.findall(pattern, data_file) 51 | varName = resList[0] 52 | h5f = h5py.File(data_file, 'r') 53 | self.X = h5f[varName][:] # X will be like (n_images, cols, rows, channels) 54 | 55 | resList = re.findall(pattern, source_file) 56 | varName = resList[0] 57 | source_h5f = h5py.File(source_file, 'r') 58 | self.sources = source_h5f[varName][:] # list 59 | 60 | self.num_timeSteps = num_timeSteps 61 | self.shuffle = shuffle 62 | self.seed = seed 63 | assert output_mode in {'error', 'prediction'} 64 | self.output_mode = output_mode 65 | assert sequence_start_mode in {'all', 'unique'} 66 | self.sequence_start_mode = sequence_start_mode 67 | self.N_seq = N_seq 68 | self.data_format = data_format 69 | if self.data_format == 'channels_first': 70 | self.X = np.transpose(self.X, (0, 3, 1, 2)) 71 | self.img_shape = self.X[0].shape 72 | self.num_samples = self.X.shape[0] 73 | 74 | if self.sequence_start_mode == 'all': # allow for any possible sequence, starting from any frame (如果视频中任意一帧都可以作为起点,只需要确定加上序列长度后的小片段终点是否还属于同一个视频即可) 75 | self.possible_starts = np.array([i for i in range(self.num_samples - self.num_timeSteps) if self.sources[i] == self.sources[i + self.num_timeSteps - 1]]) 76 | elif self.sequence_start_mode == 'unique': # create sequences where each unique frame is in at most one sequence 77 | curr_location = 0 78 | possible_starts = [] 79 | while curr_location < self.num_samples - self.num_timeSteps + 1: 80 | if self.sources[curr_location] == self.sources[curr_location + self.num_timeSteps - 1]: 81 | possible_starts.append(curr_location) 82 | curr_location += self.num_timeSteps 83 | else: 84 | curr_location += 1 85 | self.possible_starts = possible_starts 86 | 87 | if shuffle: 88 | self.possible_starts = np.random.permutation(self.possible_starts) 89 | 90 | if N_seq is not None and len(self.possible_starts) > N_seq: # select a subset of sequences if want to 91 | self.possible_starts = self.possible_starts[:N_seq] 92 | self.N_sequences = len(self.possible_starts) # 所有可能的训练片段数 93 | 94 | def __getitem__(self, index): 95 | ''' 96 | Args: 97 | index (int): Index 98 | 99 | Returns: 100 | tuple: (stacked images, target) where target is NOT class_index of the target class 101 | BUT the order of frames in sorting task. 102 | ''' 103 | idx = self.possible_starts[index] 104 | image_group = self.preprocess(self.X[idx : (idx + self.num_timeSteps)]) 105 | 106 | if self.output_mode == 'error': 107 | target = 0. # model outputs errors, so y should be zeros 108 | elif self.output_mode == 'prediction': 109 | target = image_group # output actual pixels 110 | 111 | return image_group, target 112 | 113 | def preprocess(self, X): 114 | return X.astype(np.float32) / 255. 115 | 116 | def __len__(self): 117 | return self.N_sequences 118 | 119 | def create_all(self): 120 | '''等价于原代码中的create_all. 
为evaluate模式服务, 返回全部的测试数据.''' 121 | X_all = np.zeros((self.N_sequences, self.num_timeSteps) + self.img_shape, np.float32) 122 | for i, idx in enumerate(self.possible_starts): 123 | X_all[i] = self.preprocess(self.X[idx : (idx + self.num_timeSteps)]) 124 | return X_all 125 | 126 | 127 | class ZcrDataLoader(object): 128 | '''[DataLoader for video frame predictation]''' 129 | def __init__(self, data_file, source_file, output_mode, sequence_start_mode, N_seq, args): 130 | super(ZcrDataLoader, self).__init__() 131 | self.data_file = data_file 132 | self.source_file = source_file 133 | self.output_mode = output_mode 134 | self.sequence_start_mode = sequence_start_mode 135 | self.N_seq = N_seq 136 | self.args = args 137 | 138 | def dataLoader(self): 139 | image_dataset = SequenceGenerator(self.data_file, self.source_file, self.args.num_timeSteps, self.args.shuffle, None, self.output_mode, self.sequence_start_mode, self.N_seq, self.args.data_format) 140 | # NOTE: 将drop_last设置为True, 可以删除最后一个不完整的batch(e.g.,当数据集大小不能被batch_size整除时, 最后一个batch的样本数是不够一个batch_size的, 这可能会导致某些要用到上一次结果的代码因为旧size和新size不匹配而报错(PredNet就有这个问题, 故这里将drop_last设置为True)) 141 | dataloader = data.DataLoader(image_dataset, batch_size = self.args.batch_size, shuffle = False, num_workers = self.args.workers, drop_last = True) 142 | return dataloader 143 | 144 | 145 | if __name__ == '__main__': 146 | pass -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import numpy as np 5 | import h5py 6 | 7 | # dataDir = '../coxlab-prednet-cc76248/kitti_data/' 8 | # trainSet_path = os.path.join(dataDir, 'X_train.hkl') 9 | # train_sources = os.path.join(dataDir, 'sources_train.hkl') 10 | # testSet_path = os.path.join(dataDir, 'X_test.hkl') 11 | # test_sources = os.path.join(dataDir, 'sources_test.hkl') 12 | 13 | # @200.121 14 | dataDir = '/media/sdb1/chenrui/kitti_data/h5/' 15 | trainSet_path = os.path.join(dataDir, 'X_train.h5') 16 | train_sources = os.path.join(dataDir, 'sources_train.h5') 17 | testSet_path = os.path.join(dataDir, 'X_test.h5') 18 | test_sources = os.path.join(dataDir, 'sources_test.h5') 19 | 20 | 21 | h5f = h5py.File(testSet_path,'r') 22 | testSet = h5f['X_test'][:] 23 | 24 | # print(testSet) 25 | # print(type(testSet)) # 26 | # print(testSet.shape) # (832, 128, 160, 3) -------------------------------------------------------------------------------- /debug/data_utils.py: -------------------------------------------------------------------------------- 1 | # import hickle as hkl 2 | import h5py 3 | import numpy as np 4 | from keras import backend as K 5 | from keras.preprocessing.image import Iterator 6 | import re 7 | 8 | # Data generator that creates sequences for input into PredNet. 
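# --- Illustration (example values, not from the original code) ---------------
# Both this Keras generator and the PyTorch SequenceGenerator in data_utils.py
# choose sequence start indices in the same way. For example, with nt = 3 and
#     sources = ['vidA'] * 5 + ['vidB'] * 5        # 10 frames from 2 videos
# 'unique' mode walks through the frames and jumps nt frames ahead after each
# accepted start, so sequences never share a frame:
#     possible_starts = [0, 5]                     # frames [0,1,2] and [5,6,7]
# 'all' mode keeps every start whose whole nt-frame window lies inside a single
# video, so sequences may overlap:
#     possible_starts = [0, 1, 2, 5, 6]
#     (start 7 would also fit inside 'vidB', but the scan only covers
#      range(self.X.shape[0] - self.nt), so the last window is never considered)
# ------------------------------------------------------------------------------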
9 | class SequenceGenerator(Iterator): 10 | def __init__(self, data_file, source_file, nt, 11 | batch_size=8, shuffle=False, seed=None, 12 | output_mode='error', sequence_start_mode='all', N_seq=None, 13 | data_format=K.image_data_format()): 14 | # self.X = hkl.load(data_file) # X will be like (n_images, nb_cols, nb_rows, nb_channels) 15 | pattern = re.compile(r'.*?h5/(.+?)\.h5') 16 | resList = re.findall(pattern, data_file) 17 | varName = resList[0] 18 | h5f = h5py.File(data_file, 'r') 19 | self.X = h5f[varName][:] 20 | resList = re.findall(pattern, source_file) 21 | varName = resList[0] 22 | source_h5f = h5py.File(source_file, 'r') 23 | self.sources = source_h5f[varName][:] 24 | # self.sources = hkl.load(source_file) # source for each image so when creating sequences can assure that consecutive frames are from same video 25 | self.nt = nt 26 | self.batch_size = batch_size 27 | self.data_format = data_format 28 | assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}' 29 | self.sequence_start_mode = sequence_start_mode 30 | assert output_mode in {'error', 'prediction'}, 'output_mode must be in {error, prediction}' 31 | self.output_mode = output_mode 32 | 33 | if self.data_format == 'channels_first': 34 | self.X = np.transpose(self.X, (0, 3, 1, 2)) 35 | self.im_shape = self.X[0].shape 36 | 37 | if self.sequence_start_mode == 'all': # allow for any possible sequence, starting from any frame 38 | self.possible_starts = np.array([i for i in range(self.X.shape[0] - self.nt) if self.sources[i] == self.sources[i + self.nt - 1]]) 39 | elif self.sequence_start_mode == 'unique': #create sequences where each unique frame is in at most one sequence 40 | curr_location = 0 41 | possible_starts = [] 42 | while curr_location < self.X.shape[0] - self.nt + 1: 43 | if self.sources[curr_location] == self.sources[curr_location + self.nt - 1]: 44 | possible_starts.append(curr_location) 45 | curr_location += self.nt 46 | else: 47 | curr_location += 1 48 | self.possible_starts = possible_starts 49 | 50 | if shuffle: 51 | self.possible_starts = np.random.permutation(self.possible_starts) 52 | if N_seq is not None and len(self.possible_starts) > N_seq: # select a subset of sequences if want to 53 | self.possible_starts = self.possible_starts[:N_seq] 54 | self.N_sequences = len(self.possible_starts) 55 | super(SequenceGenerator, self).__init__(len(self.possible_starts), batch_size, shuffle, seed) 56 | 57 | def next(self): 58 | with self.lock: 59 | index_array, current_index, current_batch_size = next(self.index_generator) 60 | batch_x = np.zeros((current_batch_size, self.nt) + self.im_shape, np.float32) 61 | for i, idx in enumerate(index_array): 62 | idx = self.possible_starts[idx] 63 | batch_x[i] = self.preprocess(self.X[idx:idx+self.nt]) 64 | if self.output_mode == 'error': # model outputs errors, so y should be zeros 65 | batch_y = np.zeros(current_batch_size, np.float32) 66 | elif self.output_mode == 'prediction': # output actual pixels 67 | batch_y = batch_x 68 | return batch_x, batch_y 69 | 70 | def preprocess(self, X): 71 | return X.astype(np.float32) / 255 72 | 73 | def create_all(self): 74 | X_all = np.zeros((self.N_sequences, self.nt) + self.im_shape, np.float32) 75 | for i, idx in enumerate(self.possible_starts): 76 | X_all[i] = self.preprocess(self.X[idx:idx+self.nt]) 77 | return X_all 78 | -------------------------------------------------------------------------------- /debug/debug.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import numpy as np 5 | import h5py 6 | 7 | from prednet import PredNet 8 | 9 | # dataDir = '../coxlab-prednet-cc76248/kitti_data/' 10 | # trainSet_path = os.path.join(dataDir, 'X_train.hkl') 11 | # train_sources = os.path.join(dataDir, 'sources_train.hkl') 12 | # testSet_path = os.path.join(dataDir, 'X_test.hkl') 13 | # test_sources = os.path.join(dataDir, 'sources_test.hkl') 14 | 15 | # @200.121 16 | # dataDir = '/media/sdb1/chenrui/kitti_data/h5/' 17 | # trainSet_path = os.path.join(dataDir, 'X_train.h5') 18 | # train_sources = os.path.join(dataDir, 'sources_train.h5') 19 | # testSet_path = os.path.join(dataDir, 'X_test.h5') 20 | # test_sources = os.path.join(dataDir, 'sources_test.h5') 21 | 22 | 23 | # h5f = h5py.File(testSet_path,'r') 24 | # testSet = h5f['X_test'][:] 25 | 26 | # print(testSet) 27 | # print(type(testSet)) # 28 | # print(testSet.shape) # (832, 128, 160, 3) 29 | 30 | from data_utils import SequenceGenerator 31 | 32 | data_file = '/media/sdb1/chenrui/kitti_data/h5/X_test.h5' 33 | source_file = '/media/sdb1/chenrui/kitti_data/h5/sources_test.h5' 34 | nt = 10 35 | 36 | # sg = SequenceGenerator(data_file, source_file, nt) 37 | 38 | # print(next(sg)) 39 | 40 | n_channels = 3 41 | stack_sizes = (n_channels, 48, 96, 192) 42 | R_stack_sizes = stack_sizes 43 | A_filt_sizes = (3, 3, 3) 44 | Ahat_filt_sizes = (3, 3, 3, 3) 45 | R_filt_sizes = (3, 3, 3, 3) 46 | prednet = PredNet(stack_sizes, R_stack_sizes, A_filt_sizes, Ahat_filt_sizes, R_filt_sizes, output_mode='error', data_format = 'channels_first', return_sequences=True) 47 | 48 | input_shape = (8, 3, 128, 160) 49 | prednet.build(input_shape) 50 | print('\n'.join(['%s:%s' % item for item in prednet.__dict__.items()])) 51 | print('+' * 30) 52 | print(prednet.conv_layers['ahat'][1].strides) -------------------------------------------------------------------------------- /debug/keras_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from keras import backend as K 5 | from keras.legacy.interfaces import generate_legacy_interface, recurrent_args_preprocessor 6 | from keras.models import model_from_json 7 | 8 | legacy_prednet_support = generate_legacy_interface( 9 | allowed_positional_args=['stack_sizes', 'R_stack_sizes', 10 | 'A_filt_sizes', 'Ahat_filt_sizes', 'R_filt_sizes'], 11 | conversions=[('dim_ordering', 'data_format'), 12 | ('consume_less', 'implementation')], 13 | value_conversions={'dim_ordering': {'tf': 'channels_last', 14 | 'th': 'channels_first', 15 | 'default': None}, 16 | 'consume_less': {'cpu': 0, 17 | 'mem': 1, 18 | 'gpu': 2}}, 19 | preprocessor=recurrent_args_preprocessor) 20 | 21 | # Convert old Keras (1.2) json models and weights to Keras 2.0 22 | def convert_model_to_keras2(old_json_file, old_weights_file, new_json_file, new_weights_file): 23 | from prednet import PredNet 24 | # If using tensorflow, it doesn't allow you to load the old weights. 
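    # Note: the conversion below assumes the old weights were saved by Keras 1.x
    # under the Theano backend, where a conv kernel is stored as
    # (out_channels, in_channels, rows, cols). Keras 2 with the TensorFlow backend
    # expects (rows, cols, in_channels, out_channels), which is why the loop below
    # applies np.transpose(w, (2, 3, 1, 0)) to every 4-D weight tensor. The check
    # `weights[0].shape[0] == model.layers[1].stack_sizes[1]` appears to be a
    # heuristic that detects whether the kernels are still in the old layout
    # before transposing.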
25 | if K.backend() != 'theano': 26 | os.environ['KERAS_BACKEND'] = backend 27 | reload(K) 28 | 29 | f = open(old_json_file, 'r') 30 | json_string = f.read() 31 | f.close() 32 | model = model_from_json(json_string, custom_objects = {'PredNet': PredNet}) 33 | model.load_weights(old_weights_file) 34 | 35 | weights = model.layers[1].get_weights() 36 | if weights[0].shape[0] == model.layers[1].stack_sizes[1]: 37 | for i, w in enumerate(weights): 38 | if w.ndim == 4: 39 | weights[i] = np.transpose(w, (2, 3, 1, 0)) 40 | model.set_weights(weights) 41 | 42 | model.save_weights(new_weights_file) 43 | json_string = model.to_json() 44 | with open(new_json_file, "w") as f: 45 | f.write(json_string) 46 | 47 | 48 | if __name__ == '__main__': 49 | old_dir = './model_data/' 50 | new_dir = './model_data_keras2/' 51 | if not os.path.exists(new_dir): 52 | os.mkdir(new_dir) 53 | for w_tag in ['', '-Lall', '-extrapfinetuned']: 54 | m_tag = '' if w_tag == '-Lall' else w_tag 55 | convert_model_to_keras2(old_dir + 'prednet_kitti_model' + m_tag + '.json', 56 | old_dir + 'prednet_kitti_weights' + w_tag + '.hdf5', 57 | new_dir + 'prednet_kitti_model' + m_tag + '.json', 58 | new_dir + 'prednet_kitti_weights' + w_tag + '.hdf5') 59 | -------------------------------------------------------------------------------- /debug/prednet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from keras import backend as K 4 | from keras import activations 5 | from keras.layers import Recurrent 6 | from keras.layers import Conv2D, UpSampling2D, MaxPooling2D 7 | from keras.engine import InputSpec 8 | from keras_utils import legacy_prednet_support 9 | 10 | class PredNet(Recurrent): 11 | '''PredNet architecture - Lotter 2016. 12 | Stacked convolutional LSTM inspired by predictive coding principles. 13 | 14 | # Arguments 15 | stack_sizes: number of channels in targets (A) and predictions (Ahat) in each layer of the architecture. 16 | Length is the number of layers in the architecture. 17 | First element is the number of channels in the input. 18 | Ex. (3, 16, 32) would correspond to a 3 layer architecture that takes in RGB images and has 16 and 32 19 | channels in the second and third layers, respectively. 20 | R_stack_sizes: number of channels in the representation (R) modules. 21 | Length must equal length of stack_sizes, but the number of channels per layer can be different. 22 | A_filt_sizes: filter sizes for the target (A) modules. 23 | Has length of len(stack_sizes) - 1. 24 | Ex. (3, 3) would mean that targets for layers 2 and 3 are computed by a 3x3 convolution of the errors (E) 25 | from the layer below (followed by max-pooling) 26 | Ahat_filt_sizes: filter sizes for the prediction (Ahat) modules. 27 | Has length equal to length of stack_sizes. 28 | Ex. (3, 3, 3) would mean that the predictions for each layer are computed by a 3x3 convolution of the 29 | representation (R) modules at each layer. 30 | R_filt_sizes: filter sizes for the representation (R) modules. 31 | Has length equal to length of stack_sizes. 32 | Corresponds to the filter sizes for all convolutions in the LSTM. 33 | pixel_max: the maximum pixel value. 34 | Used to clip the pixel-layer prediction. 35 | error_activation: activation function for the error (E) units. 36 | A_activation: activation function for the target (A) and prediction (A_hat) units. 37 | LSTM_activation: activation function for the cell and hidden states of the LSTM. 
38 | LSTM_inner_activation: activation function for the gates in the LSTM. 39 | output_mode: either 'error', 'prediction', 'all' or layer specification (ex. R2, see below). 40 | Controls what is outputted by the PredNet. 41 | If 'error', the mean response of the error (E) units of each layer will be outputted. 42 | That is, the output shape will be (batch_size, nb_layers). 43 | If 'prediction', the frame prediction will be outputted. 44 | If 'all', the output will be the frame prediction concatenated with the mean layer errors. 45 | The frame prediction is flattened before concatenation. 46 | Nomenclature of 'all' is kept for backwards compatibility, but should not be confused with returning all of the layers of the model 47 | For returning the features of a particular layer, output_mode should be of the form unit_type + layer_number. 48 | For instance, to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specificied as 'R0'. 49 | The possible unit types are 'R', 'Ahat', 'A', and 'E' corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively. 50 | extrap_start_time: time step for which model will start extrapolating. 51 | Starting at this time step, the prediction from the previous time step will be treated as the "actual" 52 | data_format: 'channels_first' or 'channels_last'. 53 | It defaults to the `image_data_format` value found in your 54 | Keras config file at `~/.keras/keras.json`. 55 | 56 | # References 57 | - [Deep predictive coding networks for video prediction and unsupervised learning](https://arxiv.org/abs/1605.08104) 58 | - [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf) 59 | - [Convolutional LSTM network: a machine learning approach for precipitation nowcasting](http://arxiv.org/abs/1506.04214) 60 | - [Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects](http://www.nature.com/neuro/journal/v2/n1/pdf/nn0199_79.pdf) 61 | ''' 62 | @legacy_prednet_support 63 | def __init__(self, stack_sizes, R_stack_sizes, 64 | A_filt_sizes, Ahat_filt_sizes, R_filt_sizes, 65 | pixel_max=1., error_activation='relu', A_activation='relu', 66 | LSTM_activation='tanh', LSTM_inner_activation='hard_sigmoid', 67 | output_mode='error', extrap_start_time=None, 68 | data_format=K.image_data_format(), **kwargs): 69 | self.stack_sizes = stack_sizes 70 | self.nb_layers = len(stack_sizes) 71 | assert len(R_stack_sizes) == self.nb_layers, 'len(R_stack_sizes) must equal len(stack_sizes)' 72 | self.R_stack_sizes = R_stack_sizes 73 | assert len(A_filt_sizes) == (self.nb_layers - 1), 'len(A_filt_sizes) must equal len(stack_sizes) - 1' 74 | self.A_filt_sizes = A_filt_sizes 75 | assert len(Ahat_filt_sizes) == self.nb_layers, 'len(Ahat_filt_sizes) must equal len(stack_sizes)' 76 | self.Ahat_filt_sizes = Ahat_filt_sizes 77 | assert len(R_filt_sizes) == (self.nb_layers), 'len(R_filt_sizes) must equal len(stack_sizes)' 78 | self.R_filt_sizes = R_filt_sizes 79 | 80 | self.pixel_max = pixel_max 81 | self.error_activation = activations.get(error_activation) 82 | self.A_activation = activations.get(A_activation) 83 | self.LSTM_activation = activations.get(LSTM_activation) 84 | self.LSTM_inner_activation = activations.get(LSTM_inner_activation) 85 | 86 | default_output_modes = ['prediction', 'error', 'all'] 87 | layer_output_modes = [layer + str(n) for n in range(self.nb_layers) for layer in ['R', 'E', 'A', 'Ahat']] 88 | assert output_mode in 
default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode) 89 | self.output_mode = output_mode 90 | if self.output_mode in layer_output_modes: 91 | self.output_layer_type = self.output_mode[:-1] 92 | self.output_layer_num = int(self.output_mode[-1]) 93 | else: 94 | self.output_layer_type = None 95 | self.output_layer_num = None 96 | self.extrap_start_time = extrap_start_time 97 | 98 | assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {channels_last, channels_first}' 99 | self.data_format = data_format 100 | self.channel_axis = -3 if data_format == 'channels_first' else -1 101 | self.row_axis = -2 if data_format == 'channels_first' else -3 102 | self.column_axis = -1 if data_format == 'channels_first' else -2 103 | super(PredNet, self).__init__(**kwargs) 104 | self.input_spec = [InputSpec(ndim=5)] 105 | 106 | def compute_output_shape(self, input_shape): 107 | if self.output_mode == 'prediction': 108 | out_shape = input_shape[2:] 109 | elif self.output_mode == 'error': 110 | out_shape = (self.nb_layers,) 111 | elif self.output_mode == 'all': 112 | out_shape = (np.prod(input_shape[2:]) + self.nb_layers,) 113 | else: 114 | stack_str = 'R_stack_sizes' if self.output_layer_type == 'R' else 'stack_sizes' 115 | stack_mult = 2 if self.output_layer_type == 'E' else 1 116 | out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num] 117 | out_nb_row = input_shape[self.row_axis] / 2**self.output_layer_num 118 | out_nb_col = input_shape[self.column_axis] / 2**self.output_layer_num 119 | if self.data_format == 'channels_first': 120 | out_shape = (out_stack_size, out_nb_row, out_nb_col) 121 | else: 122 | out_shape = (out_nb_row, out_nb_col, out_stack_size) 123 | 124 | if self.return_sequences: 125 | return (input_shape[0], input_shape[1]) + out_shape # zcr: input_shape[1] is the timesteps 126 | else: 127 | return (input_shape[0],) + out_shape 128 | 129 | def get_initial_state(self, x): 130 | input_shape = self.input_spec[0].shape 131 | init_nb_row = input_shape[self.row_axis] 132 | init_nb_col = input_shape[self.column_axis] 133 | 134 | base_initial_state = K.zeros_like(x) # (samples, timesteps) + image_shape 135 | non_channel_axis = -1 if self.data_format == 'channels_first' else -2 136 | for _ in range(2): 137 | base_initial_state = K.sum(base_initial_state, axis=non_channel_axis) 138 | base_initial_state = K.sum(base_initial_state, axis=1) # (samples, nb_channels) 139 | 140 | initial_states = [] 141 | states_to_pass = ['r', 'c', 'e'] 142 | nlayers_to_pass = {u: self.nb_layers for u in states_to_pass} 143 | if self.extrap_start_time is not None: 144 | states_to_pass.append('ahat') # pass prediction in states so can use as actual for t+1 when extrapolating 145 | nlayers_to_pass['ahat'] = 1 146 | for u in states_to_pass: 147 | for l in range(nlayers_to_pass[u]): 148 | ds_factor = 2 ** l 149 | nb_row = init_nb_row // ds_factor 150 | nb_col = init_nb_col // ds_factor 151 | if u in ['r', 'c']: 152 | stack_size = self.R_stack_sizes[l] 153 | elif u == 'e': 154 | stack_size = 2 * self.stack_sizes[l] 155 | elif u == 'ahat': 156 | stack_size = self.stack_sizes[l] 157 | output_size = stack_size * nb_row * nb_col # flattened size 158 | 159 | reducer = K.zeros((input_shape[self.channel_axis], output_size)) # (nb_channels, output_size) 160 | initial_state = K.dot(base_initial_state, reducer) # (samples, output_size) 161 | if self.data_format == 'channels_first': 162 | output_shp = (-1, stack_size, nb_row, nb_col) 163 | else: 164 | 
output_shp = (-1, nb_row, nb_col, stack_size) 165 | initial_state = K.reshape(initial_state, output_shp) 166 | initial_states += [initial_state] 167 | 168 | if K._BACKEND == 'theano': 169 | from theano import tensor as T 170 | # There is a known issue in the Theano scan op when dealing with inputs whose shape is 1 along a dimension. 171 | # In our case, this is a problem when training on grayscale images, and the below line fixes it. 172 | initial_states = [T.unbroadcast(init_state, 0, 1) for init_state in initial_states] 173 | 174 | if self.extrap_start_time is not None: 175 | initial_states += [K.variable(0, int if K.backend() != 'tensorflow' else 'int32')] # the last state will correspond to the current timestep 176 | return initial_states 177 | 178 | def build(self, input_shape): 179 | self.input_spec = [InputSpec(shape=input_shape)] 180 | self.conv_layers = {c: [] for c in ['i', 'f', 'c', 'o', 'a', 'ahat']} 181 | 182 | for l in range(self.nb_layers): 183 | for c in ['i', 'f', 'c', 'o']: 184 | act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation 185 | self.conv_layers[c].append(Conv2D(self.R_stack_sizes[l], self.R_filt_sizes[l], padding='same', activation=act, data_format=self.data_format)) 186 | 187 | act = 'relu' if l == 0 else self.A_activation 188 | self.conv_layers['ahat'].append(Conv2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], padding='same', activation=act, data_format=self.data_format)) 189 | 190 | if l < self.nb_layers - 1: 191 | self.conv_layers['a'].append(Conv2D(self.stack_sizes[l+1], self.A_filt_sizes[l], padding='same', activation=self.A_activation, data_format=self.data_format)) 192 | 193 | self.upsample = UpSampling2D(data_format=self.data_format) 194 | self.pool = MaxPooling2D(data_format=self.data_format) 195 | 196 | self.trainable_weights = [] 197 | nb_row, nb_col = (input_shape[-2], input_shape[-1]) if self.data_format == 'channels_first' else (input_shape[-3], input_shape[-2]) 198 | for c in sorted(self.conv_layers.keys()): 199 | for l in range(len(self.conv_layers[c])): 200 | ds_factor = 2 ** l 201 | if c == 'ahat': 202 | nb_channels = self.R_stack_sizes[l] 203 | elif c == 'a': 204 | nb_channels = 2 * self.R_stack_sizes[l] 205 | else: 206 | nb_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[l] 207 | if l < self.nb_layers - 1: 208 | nb_channels += self.R_stack_sizes[l+1] 209 | in_shape = (input_shape[0], nb_channels, nb_row // ds_factor, nb_col // ds_factor) 210 | if self.data_format == 'channels_last': in_shape = (in_shape[0], in_shape[2], in_shape[3], in_shape[1]) 211 | with K.name_scope('layer_' + c + '_' + str(l)): 212 | self.conv_layers[c][l].build(in_shape) 213 | self.trainable_weights += self.conv_layers[c][l].trainable_weights 214 | 215 | self.states = [None] * self.nb_layers*3 216 | 217 | if self.extrap_start_time is not None: 218 | self.t_extrap = K.variable(self.extrap_start_time, int if K.backend() != 'tensorflow' else 'int32') 219 | self.states += [None] * 2 # [previous frame prediction, timestep] 220 | 221 | def step(self, a, states): 222 | r_tm1 = states[:self.nb_layers] 223 | c_tm1 = states[self.nb_layers:2*self.nb_layers] 224 | e_tm1 = states[2*self.nb_layers:3*self.nb_layers] 225 | 226 | if self.extrap_start_time is not None: 227 | t = states[-1] 228 | a = K.switch(t >= self.t_extrap, states[-2], a) # if past self.extrap_start_time, the previous prediction will be treated as the actual 229 | 230 | c = [] 231 | r = [] 232 | e = [] 233 | 234 | # Update R units starting from the top 235 | for l in 
reversed(range(self.nb_layers)): 236 | inputs = [r_tm1[l], e_tm1[l]] 237 | if l < self.nb_layers - 1: # zcr: 即不是最高一层 238 | inputs.append(r_up) 239 | 240 | inputs = K.concatenate(inputs, axis=self.channel_axis) 241 | i = self.conv_layers['i'][l].call(inputs) 242 | f = self.conv_layers['f'][l].call(inputs) 243 | o = self.conv_layers['o'][l].call(inputs) 244 | _c = f * c_tm1[l] + i * self.conv_layers['c'][l].call(inputs) 245 | _r = o * self.LSTM_activation(_c) 246 | c.insert(0, _c) 247 | r.insert(0, _r) 248 | 249 | if l > 0: 250 | r_up = self.upsample.call(_r) 251 | 252 | # Update feedforward path starting from the bottom 253 | for l in range(self.nb_layers): 254 | ahat = self.conv_layers['ahat'][l].call(r[l]) 255 | if l == 0: 256 | ahat = K.minimum(ahat, self.pixel_max) 257 | frame_prediction = ahat 258 | 259 | # compute errors 260 | e_up = self.error_activation(ahat - a) 261 | e_down = self.error_activation(a - ahat) 262 | 263 | e.append(K.concatenate((e_up, e_down), axis=self.channel_axis)) 264 | 265 | if self.output_layer_num == l: 266 | if self.output_layer_type == 'A': 267 | output = a 268 | elif self.output_layer_type == 'Ahat': 269 | output = ahat 270 | elif self.output_layer_type == 'R': 271 | output = r[l] 272 | elif self.output_layer_type == 'E': 273 | output = e[l] 274 | 275 | if l < self.nb_layers - 1: 276 | a = self.conv_layers['a'][l].call(e[l]) 277 | a = self.pool.call(a) # target for next layer 278 | 279 | if self.output_layer_type is None: 280 | if self.output_mode == 'prediction': 281 | output = frame_prediction 282 | else: 283 | for l in range(self.nb_layers): 284 | layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True) 285 | all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1) 286 | if self.output_mode == 'error': 287 | output = all_error 288 | else: 289 | output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1) 290 | 291 | states = r + c + e 292 | if self.extrap_start_time is not None: 293 | states += [frame_prediction, t + 1] 294 | return output, states 295 | 296 | def get_config(self): 297 | config = {'stack_sizes': self.stack_sizes, 298 | 'R_stack_sizes': self.R_stack_sizes, 299 | 'A_filt_sizes': self.A_filt_sizes, 300 | 'Ahat_filt_sizes': self.Ahat_filt_sizes, 301 | 'R_filt_sizes': self.R_filt_sizes, 302 | 'pixel_max': self.pixel_max, 303 | 'error_activation': self.error_activation.__name__, 304 | 'A_activation': self.A_activation.__name__, 305 | 'LSTM_activation': self.LSTM_activation.__name__, 306 | 'LSTM_inner_activation': self.LSTM_inner_activation.__name__, 307 | 'data_format': self.data_format, 308 | 'extrap_start_time': self.extrap_start_time, 309 | 'output_mode': self.output_mode} 310 | base_config = super(PredNet, self).get_config() 311 | return dict(list(base_config.items()) + list(config.items())) 312 | -------------------------------------------------------------------------------- /debug/reproduce_bug.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | # a = Variable(torch.from_numpy(np.arange(12).reshape(3, 4)).float(), requires_grad=True) 8 | L = [Variable(torch.from_numpy(np.arange(12).reshape(3, 4)).float(), requires_grad=True), 9 | Variable(torch.from_numpy(np.arange(1, 13).reshape(3, 4)).float(), requires_grad=True), 10 | Variable(torch.from_numpy(np.arange(2, 14).reshape(3, 4)).float(), requires_grad=True), 11 | 
Variable(torch.from_numpy(np.arange(3, 15).reshape(3, 4)).float(), requires_grad=True), 12 | ] 13 | w = Variable(torch.from_numpy(np.array([1., 0.1, 0.1, 0.1])).float()) 14 | # L = [a * i for i in range(1, 4)] 15 | error_list = [i * w for i in L] 16 | error_list = [e.sum() for e in L] 17 | ''' 18 | # print(L) 19 | >>> print(L) 20 | [Variable containing: 21 | 0 1 2 3 22 | 4 5 6 7 23 | 8 9 10 11 24 | [torch.LongTensor of size 3x4] 25 | , Variable containing: 26 | 0 2 4 6 27 | 8 10 12 14 28 | 16 18 20 22 29 | [torch.LongTensor of size 3x4] 30 | , Variable containing: 31 | 0 3 6 9 32 | 12 15 18 21 33 | 24 27 30 33 34 | [torch.LongTensor of size 3x4] 35 | ] 36 | 37 | >>> error_list 38 | [Variable containing: 39 | 66 40 | [torch.LongTensor of size 1] 41 | , Variable containing: 42 | 132 43 | [torch.LongTensor of size 1] 44 | , Variable containing: 45 | 198 46 | [torch.LongTensor of size 1] 47 | ] 48 | 49 | ''' 50 | total = error_list[0] 51 | for e in error_list[1:]: 52 | total = total + e 53 | 54 | total.backward() 55 | 56 | -------------------------------------------------------------------------------- /debug/test_LSTM/convLSTM.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch 4 | 5 | def weights_init(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv') != -1: 8 | m.weight.data.normal_(0.0, 0.02) 9 | elif classname.find('BatchNorm') != -1: 10 | m.weight.data.normal_(1.0, 0.02) 11 | m.bias.data.fill_(0) 12 | 13 | class CLSTM_cell(nn.Module): 14 | """Initialize a basic Conv LSTM cell. 15 | Args: 16 | shape: int tuple thats the height and width of the hidden states h and c() 17 | filter_size: int that is the height and width of the filters 18 | num_features: int thats the num of channels of the states, like hidden_size 19 | 20 | """ 21 | def __init__(self, shape, input_chans, filter_size, num_features): 22 | super(CLSTM_cell, self).__init__() 23 | 24 | self.shape = shape#H,W 25 | self.input_chans=input_chans 26 | self.filter_size=filter_size 27 | self.num_features = num_features # num_features也就是卷积层的out_channels数 28 | #self.batch_size=batch_size 29 | self.padding = int((filter_size - 1) / 2) #in this way the output has the same size 30 | self.conv = nn.Conv2d(self.input_chans + self.num_features, 4*self.num_features, self.filter_size, 1, self.padding) 31 | 32 | 33 | def forward(self, input, hidden_state): 34 | hidden,c=hidden_state#hidden and c are images with several channels 35 | print('*' * 20) 36 | print('hidden ',hidden.size()) 37 | print('input ',input.size()) 38 | combined = torch.cat((input, hidden), 1)#oncatenate in the channels 39 | print('combined',combined.size()) 40 | print('*' * 30) 41 | A=self.conv(combined) 42 | (ai,af,ao,ag)=torch.split(A,self.num_features,dim=1)#it should return 4 tensors 43 | i=torch.sigmoid(ai) 44 | f=torch.sigmoid(af) 45 | o=torch.sigmoid(ao) 46 | g=torch.tanh(ag) 47 | 48 | next_c=f*c+i*g 49 | next_h=o*torch.tanh(next_c) 50 | return next_h, next_c 51 | 52 | def init_hidden(self,batch_size): 53 | return (Variable(torch.zeros(batch_size,self.num_features,self.shape[0],self.shape[1])).cuda(),Variable(torch.zeros(batch_size,self.num_features,self.shape[0],self.shape[1])).cuda()) 54 | 55 | 56 | class CLSTM(nn.Module): 57 | """Initialize a basic Conv LSTM cell. 
58 | Args: 59 | shape: int tuple thats the height and width of the hidden states h and c() 60 | filter_size: int that is the height and width of the filters 61 | num_features: int thats the num of channels of the states, like hidden_size 62 | 63 | """ 64 | def __init__(self, shape, input_chans, filter_size, num_features,num_layers): 65 | super(CLSTM, self).__init__() 66 | 67 | self.shape = shape#H,W 68 | self.input_chans=input_chans 69 | self.filter_size=filter_size 70 | self.num_features = num_features 71 | self.num_layers=num_layers 72 | cell_list=[] 73 | cell_list.append(CLSTM_cell(self.shape, self.input_chans, self.filter_size, self.num_features).cuda())# the first one has a different number of input channels 74 | 75 | for idcell in range(1,self.num_layers): 76 | cell_list.append(CLSTM_cell(self.shape, self.num_features, self.filter_size, self.num_features).cuda()) 77 | self.cell_list=nn.ModuleList(cell_list) 78 | 79 | 80 | def forward(self, input, hidden_state): 81 | """ 82 | args: 83 | hidden_state:list of tuples, one for every layer, each tuple should be hidden_layer_i,c_layer_i 84 | input is the tensor of shape seq_len,Batch,Chans,H,W 85 | 86 | """ 87 | 88 | current_input = input.transpose(0, 1)#now is seq_len,B,C,H,W 89 | #current_input=input 90 | next_hidden=[]#hidden states(h and c) 91 | seq_len=current_input.size(0) 92 | 93 | 94 | for idlayer in range(self.num_layers):#loop for every layer 95 | 96 | hidden_c=hidden_state[idlayer]#hidden and c are images with several channels. zcr: 这里的hidden_c包括(hidden,c), 这个hidden_c其实就是特定层初始化的(hidden, c) 97 | all_output = [] 98 | output_inner = [] 99 | for t in range(seq_len):#loop for every step 100 | hidden_c=self.cell_list[idlayer](current_input[t,...],hidden_c)#cell_list is a list with different conv_lstms 1 for every layer 101 | 102 | output_inner.append(hidden_c[0]) # hidden_c是一个(next_hidden, next_c)的元组 103 | # print('&' * 10, hidden_c[0].size()) # torch.Size([16, 10, 25, 25] 104 | 105 | next_hidden.append(hidden_c) 106 | current_input = torch.cat(output_inner, 0).view(current_input.size(0), *output_inner[0].size())#seq_len,B,chans,H,W 107 | 108 | 109 | return next_hidden, current_input 110 | 111 | def init_hidden(self,batch_size): 112 | init_states=[]#this is a list of tuples 113 | for i in range(self.num_layers): 114 | init_states.append(self.cell_list[i].init_hidden(batch_size)) 115 | return init_states 116 | 117 | 118 | if __name__ == '__main__': 119 | ###########Usage####################################### 120 | print('hey') 121 | num_features=10 122 | filter_size=5 123 | batch_size=16 124 | shape=(25,25)#H,W 125 | inp_chans=3 126 | nlayers=2 127 | seq_len=4 128 | 129 | #If using this format, then we need to transpose in CLSTM 130 | input = Variable(torch.rand(batch_size,seq_len,inp_chans,shape[0],shape[1])).cuda() 131 | 132 | conv_lstm=CLSTM(shape, inp_chans, filter_size, num_features,nlayers) 133 | conv_lstm.apply(weights_init) 134 | conv_lstm.cuda() 135 | 136 | print('convlstm module:',conv_lstm) 137 | 138 | 139 | # print('params:') 140 | # params=conv_lstm.parameters() 141 | # for p in params: 142 | # print('param ',p.size()) 143 | # print('mean ',torch.mean(p)) 144 | 145 | 146 | hidden_state=conv_lstm.init_hidden(batch_size) 147 | # print('hidden states: ', hidden_state) 148 | print('hidden states length: ',len(hidden_state)) # 2 149 | # for i in range(len(hidden_state)): 150 | # print(i, len(hidden_state[i])) # 都是2 151 | # print('hidden_h shape ',hidden_state[0][0].size()) # torch.Size([16, 10, 25, 25]) 152 | # print('hidden_h 
shape ',hidden_state[0][1].size()) # torch.Size([16, 10, 25, 25]) 153 | # print('hidden_h shape ',hidden_state[1][0].size()) # torch.Size([16, 10, 25, 25]) 154 | # print('hidden_h shape ',hidden_state[1][1].size()) # torch.Size([16, 10, 25, 25]) 155 | out=conv_lstm(input,hidden_state) 156 | print('out shape',out[1].size()) # ([4, 16, 10, 25, 25]) 157 | print('len hidden ', len(out[0])) # 2 158 | print('next hidden',out[0][0][0].size()) # torch.Size([16, 10, 25, 25]) 159 | print('convlstm dict',conv_lstm.state_dict().keys()) 160 | 161 | 162 | L=torch.sum(out[1]) 163 | L.backward() -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import argparse 5 | import numpy as np 6 | 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | import matplotlib.gridspec as gridspec 11 | 12 | import torch 13 | from torch.autograd import Variable 14 | 15 | # zcr lib 16 | from prednet import PredNet 17 | from data_utils import ZcrDataLoader 18 | 19 | def arg_parse(): 20 | desc = "Video Frames Predicting Task via PredNet." 21 | parser = argparse.ArgumentParser(description = desc) 22 | 23 | parser.add_argument('--mode', default = 'train', type = str, 24 | help = 'train or evaluate (default: train)') 25 | parser.add_argument('--dataPath', default = '', type = str, metavar = 'PATH', 26 | help = 'path to video dataset (default: none)') 27 | parser.add_argument('--resultsPath', default = '', type = str, metavar = 'PATH', 28 | help = 'saving path to results of PredNet (default: none)') 29 | parser.add_argument('--checkpoint_file', default = '', type = str, 30 | help = 'checkpoint file for evaluating. (default: none)') 31 | parser.add_argument('--batch_size', default = 32, type = int, metavar = 'N', 32 | help = 'The size of batch') 33 | parser.add_argument('--num_plot', default = 40, type = int, metavar = 'N', 34 | help = 'how many images to plot') 35 | parser.add_argument('--num_timeSteps', default = 10, type = int, metavar = 'N', 36 | help = 'number of timesteps used for sequences in training (default: 10)') 37 | parser.add_argument('--workers', default = 4, type = int, metavar = 'N', 38 | help = 'number of data loading workers (default: 4)') 39 | parser.add_argument('--shuffle', default = True, type = bool, 40 | help = 'shuffle or not') 41 | parser.add_argument('--data_format', default = 'channels_last', type = str, 42 | help = '(c, h, w) or (h, w, c)?') 43 | parser.add_argument('--n_channels', default = 3, type = int, metavar = 'N', 44 | help = 'The number of input channels (default: 3)') 45 | parser.add_argument('--img_height', default = 128, type = int, metavar = 'N', 46 | help = 'The height of input frame (default: 128)') 47 | parser.add_argument('--img_width', default = 160, type = int, metavar = 'N', 48 | help = 'The width of input frame (default: 160)') 49 | # parser.add_argument('--stack_sizes', default = '', type = str, 50 | # help = 'Number of channels in targets (A) and predictions (Ahat) in each layer of the architecture.') 51 | # parser.add_argument('--R_stack_sizes', default = '', type = str, 52 | # help = 'Number of channels in the representation (R) modules.') 53 | # parser.add_argument('--A_filter_sizes', default = '', type = str, 54 | # help = 'Filter sizes for the target (A) modules. 
(except the target (A) in lowest layer (i.e., input image))') 55 | # parser.add_argument('--Ahat_filter_sizes', default = '', type = str, 56 | # help = 'Filter sizes for the prediction (Ahat) modules.') 57 | # parser.add_argument('--R_filter_sizes', default = '', type = str, 58 | # help = 'Filter sizes for the representation (R) modules.') 59 | 60 | args = parser.parse_args() 61 | return args 62 | 63 | def print_args(args): 64 | print('-' * 50) 65 | for arg, content in args.__dict__.items(): 66 | print("{}: {}".format(arg, content)) 67 | print('-' * 50) 68 | 69 | 70 | def evaluate(model, args): 71 | '''Evaluate PredNet on KITTI sequences''' 72 | prednet = model # Now prednet is the testing model (to output predictions) 73 | 74 | DATA_DIR = args.dataPath 75 | RESULTS_SAVE_DIR = args.resultsPath 76 | test_file = os.path.join(DATA_DIR, 'X_test.h5') 77 | test_sources = os.path.join(DATA_DIR, 'sources_test.h5') 78 | 79 | output_mode = 'prediction' 80 | sequence_start_mode = 'unique' 81 | N_seq = None 82 | dataLoader = ZcrDataLoader(test_file, test_sources, output_mode, sequence_start_mode, N_seq, args).dataLoader() 83 | X_test = dataLoader.dataset.create_all() 84 | # print('X_test.shape', X_test.shape) # (83, 10, 3, 128, 160) 85 | X_test = X_test[:8, ...] # to overcome `cuda runtime error: out of memory` 86 | batch_size = X_test.shape[0] 87 | X_groundTruth = np.transpose(X_test, (1, 0, 2, 3, 4)) # (timesteps, batch_size, 3, 128, 160) 88 | X_groundTruth_list = [] 89 | for t in range(X_groundTruth.shape[0]): 90 | X_groundTruth_list.append(np.squeeze(X_groundTruth[t, ...])) # (batch_size, 3, 128, 160) 91 | 92 | X_test = Variable(torch.from_numpy(X_test).float().cuda()) 93 | 94 | if prednet.data_format == 'channels_first': 95 | input_shape = (batch_size, args.num_timeSteps, n_channels, img_height, img_width) 96 | else: 97 | input_shape = (batch_size, args.num_timeSteps, img_height, img_width, n_channels) 98 | initial_states = prednet.get_initial_states(input_shape) 99 | predictions = prednet(X_test, initial_states) 100 | # print(predictions) 101 | # print(predictions[0].size()) # torch.Size([8, 3, 128, 160]) 102 | 103 | X_predict_list = [pred.data.cpu().numpy() for pred in predictions] # length of X_predict_list is timesteps. 每个元素shape是(batch_size, 3, H, W) 104 | 105 | # Compare MSE of PredNet predictions vs. using last frame. 
Write results to prediction_scores.txt 106 | # MSE_PredNet = np.mean((real_X[:, 1: ] - pred_X[:, 1:])**2) # look at all timesteps except the first 107 | # MSE_previous = np.mean((real_X[:, :-1] - real_X[:, 1:])**2) 108 | # if not os.path.exists(RESULTS_SAVE_DIR): 109 | # os.mkdir(RESULTS_SAVE_DIR) 110 | # score_file = os.path.join(RESULTS_SAVE_DIR, 'prediction_scores.txt') 111 | # with open(score_file, 'w') as f: 112 | # f.write("PredNet MSE: %f\n" % MSE_PredNet) 113 | # f.write("Previous Frame MSE: %f" % MSE_previous) 114 | 115 | # Plot some predictions 116 | if prednet.data_format == 'channels_first': 117 | X_groundTruth_list = [np.transpose(batch_img, (0, 2, 3, 1)) for batch_img in X_groundTruth_list] 118 | X_predict_list = [np.transpose(batch_img, (0, 2, 3, 1)) for batch_img in X_predict_list] 119 | assert len(X_groundTruth_list) == len(X_predict_list) == args.num_timeSteps 120 | timesteps = args.num_timeSteps 121 | total_num = X_groundTruth_list[0].shape[0] 122 | height = X_predict_list[0].shape[1] 123 | width = X_predict_list[0].shape[2] 124 | 125 | n_plot = args.num_plot 126 | if n_plot > total_num: 127 | n_plot = total_num 128 | aspect_ratio = float(height) / width 129 | plt.figure(figsize = (timesteps, (2 * aspect_ratio))) 130 | gs = gridspec.GridSpec(2, timesteps) 131 | gs.update(wspace = 0., hspace = 0.) 132 | plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots/') 133 | if not os.path.exists(plot_save_dir): 134 | os.mkdir(plot_save_dir) 135 | plot_idx = np.random.permutation(total_num)[:n_plot] 136 | for i in plot_idx: 137 | for t in range(timesteps): 138 | ## plot the ground truth. 139 | plt.subplot(gs[t]) 140 | plt.imshow(X_groundTruth_list[t][i, ...], interpolation = 'none') 141 | plt.tick_params(axis = 'both', which = 'both', bottom = 'off', top = 'off', left = 'off', right = 'off', labelbottom = 'off', labelleft = 'off') 142 | if t == 0: 143 | plt.ylabel('Actual', fontsize = 10) 144 | 145 | ## plot the predictions. 146 | plt.subplot(gs[t + timesteps]) 147 | plt.imshow(X_predict_list[t][i, ...], interpolation = 'none') 148 | plt.tick_params(axis = 'both', which = 'both', bottom = 'off', top = 'off', left = 'off', right = 'off', labelbottom = 'off', labelleft = 'off') 149 | if t == 0: 150 | plt.ylabel('Predicted', fontsize = 10) 151 | 152 | plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png') 153 | plt.clf() 154 | print('The plots are saved in "%s"! Have a nice day!' 
% plot_save_dir)


def checkpoint_loader(checkpoint_file):
    '''load the checkpoint for weights of PredNet.'''
    print('Loading...', end = '')
    checkpoint = torch.load(checkpoint_file)
    print('Done.')
    return checkpoint

def load_pretrained_weights(model, state_dict_file):
    '''Directly use the weights taken from the pretrained Keras PredNet model provided by the original authors.'''
    model.load_state_dict(torch.load(state_dict_file))  # load_state_dict() modifies the model in place
    print('weights loaded!')
    return model

if __name__ == '__main__':
    args = arg_parse()
    print_args(args)

    n_channels = args.n_channels
    img_height = args.img_height
    img_width = args.img_width

    # stack_sizes = eval(args.stack_sizes)
    # R_stack_sizes = eval(args.R_stack_sizes)
    # A_filter_sizes = eval(args.A_filter_sizes)
    # Ahat_filter_sizes = eval(args.Ahat_filter_sizes)
    # R_filter_sizes = eval(args.R_filter_sizes)

    stack_sizes = (n_channels, 48, 96, 192)
    R_stack_sizes = stack_sizes
    A_filter_sizes = (3, 3, 3)
    Ahat_filter_sizes = (3, 3, 3, 3)
    R_filter_sizes = (3, 3, 3, 3)

    prednet = PredNet(stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes,
                      output_mode = 'prediction', data_format = args.data_format, return_sequences = True)
    print(prednet)
    prednet.cuda()

    # print('\n'.join(['%s:%s' % item for item in prednet.__dict__.items()]))
    # print(type(prednet.state_dict()))  #
    # for k, v in prednet.state_dict().items():
    #     print(k, v.size())

    ## use the weights we trained ourselves
    checkpoint_file = args.checkpoint_file
    try:
        checkpoint = checkpoint_loader(checkpoint_file)
    except Exception:
        raise(RuntimeError('Cannot load the checkpoint file named %s!' % checkpoint_file))
    state_dict = checkpoint['state_dict']
    prednet.load_state_dict(state_dict)

    ## directly use the pretrained weights provided by the original authors
    # state_dict_file = './model_data_keras2/preTrained_weights_forPyTorch.pkl'
    # # prednet = load_pretrained_weights(prednet, state_dict_file)  # this way didn't work... why?
    # prednet.load_state_dict(torch.load(state_dict_file))

    assert args.mode == 'evaluate'
    evaluate(prednet, args)

--------------------------------------------------------------------------------
/evaluate.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# usage:
# ./evaluate.sh

echo "Evaluate..."
mode='evaluate'

# @200.121
DATA_DIR='/media/sdb1/chenrui/kitti_data/h5/'
# Where results (prediction plots and evaluation file) will be saved.
RESULTS_SAVE_DIR='./kitti_results/'
checkpoint_file='./checkpoint/checkpoint_epoch1_trLoss1342.3278.pkl'    # load weights from checkpoint file for evaluating.

batch_size=10
num_plot=40    # how many images to plot.
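# Note: evaluate.py currently keeps only the first 8 test sequences
# (X_test = X_test[:8, ...]) to avoid running out of GPU memory, so at most 8
# prediction plots are produced no matter how large num_plot is; the plots are
# written to ${RESULTS_SAVE_DIR}/prediction_plots/.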
17 | 18 | # number of timesteps used for sequences in evaluating 19 | num_timeSteps=10 20 | 21 | workers=4 22 | shuffle=false 23 | 24 | data_format='channels_first' 25 | n_channels=3 26 | img_height=128 27 | img_width=160 28 | 29 | CUDA_VISIBLE_DEVICES=2 python evaluate.py \ 30 | --mode ${mode} \ 31 | --dataPath ${DATA_DIR} \ 32 | --resultsPath ${RESULTS_SAVE_DIR} \ 33 | --checkpoint_file ${checkpoint_file} \ 34 | --batch_size ${batch_size} \ 35 | --num_plot ${num_plot} \ 36 | --num_timeSteps ${num_timeSteps} \ 37 | --workers ${workers} \ 38 | --shuffle ${shuffle} \ 39 | --data_format ${data_format} \ 40 | --n_channels ${n_channels} \ 41 | --img_height ${img_height} \ 42 | --img_width ${img_width} 43 | # --stack_sizes ${stack_sizes} \ 44 | # --R_stack_sizes ${R_stack_sizes} \ 45 | # --A_filter_sizes ${A_filter_sizes} \ 46 | # --Ahat_filter_sizes ${Ahat_filter_sizes} \ 47 | # --R_filter_sizes ${R_filter_sizes} \ -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_0.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_1.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_2.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_3.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_4.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_5.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_6.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/0/plot_7.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/0/plot_7.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_0.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_1.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_2.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_3.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_4.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_5.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_6.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/1/plot_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/1/plot_7.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_0.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_1.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_2.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_3.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_4.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_5.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_6.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2.5/plot_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2.5/plot_7.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_0.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_1.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_2.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_3.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_4.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_5.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_6.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/2/plot_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/2/plot_7.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_0.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_1.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_2.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_3.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_4.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_5.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_5.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_6.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/3/plot_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/3/plot_7.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_0.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_1.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_2.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_3.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_4.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_5.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_6.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/4/plot_7.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/4/plot_7.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_0.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_1.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_2.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_3.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_4.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_5.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_6.png -------------------------------------------------------------------------------- /kitti_results/prediction_plots/use_pretrained_weights/plot_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/kitti_results/prediction_plots/use_pretrained_weights/plot_7.png -------------------------------------------------------------------------------- /load_weights.py: -------------------------------------------------------------------------------- 1 | # 
-*- coding: utf-8 -*- 2 | 3 | '''将以hdf5形式保存的原Keras版本的PredNet模型的参数加载到zcr复现的pytorch版本的模型中.''' 4 | 5 | import os 6 | import numpy as np 7 | import h5py 8 | 9 | import torch 10 | # from torch.autograd import Variable 11 | 12 | 13 | weights_file = './model_data_keras2/prednet_kitti_weights.hdf5' 14 | weights_f = h5py.File(weights_file, 'r') 15 | 16 | pred_weights = weights_f['model_weights']['pred_net_1']['pred_net_1'] # contains 23 item: 4x4(i,f,c,o for 4 layers) + 4(Ahat for 4 layers) + 3(A for 4 layers) 17 | 18 | keras_items = ['bias:0', 'kernel:0'] 19 | pytorch_items = ['weight', 'bias'] 20 | 21 | keras_modules = ['a', 'ahat', 'c', 'f', 'i', 'o'] 22 | keras_modules = ['layer_' + m + '_' + str(i) for m in keras_modules for i in range(4)] 23 | keras_modules.remove('layer_a_3') 24 | assert len(keras_modules) == 4 * 4 + 4 + 3 25 | 26 | pytorch_modules_1 = ['A', 'Ahat'] 27 | pytorch_modules_2 = ['c', 'f', 'i', 'o'] 28 | pytorch_modules_1 = [m + '.' + str(2 * i) + '.' + item for m in pytorch_modules_1 for i in range(4) for item in pytorch_items] 29 | pytorch_modules_1.remove('A.6.weight') 30 | pytorch_modules_1.remove('A.6.bias') 31 | pytorch_modules_2 = [m + '.' + str(i) + '.' + item for m in pytorch_modules_2 for i in range(4) for item in pytorch_items] 32 | pytorch_modules = pytorch_modules_1 + pytorch_modules_2 33 | assert len(pytorch_modules) == (4 * 4 + 4 + 3) * 2 34 | 35 | weight_dict = dict() 36 | 37 | # 从h5文件加载过来的是类型的权重, 需要将其转换为cuda.Tensor 38 | for i in range(len(keras_modules)): 39 | weight_dict[pytorch_modules[i * 2 + 1]] = pred_weights[keras_modules[i]]['bias:0'][:] 40 | # weight_dict[pytorch_modules[i * 2 + 1]] = pred_weights[keras_modules[i]]['bias:0'] 41 | weight_dict[pytorch_modules[i * 2]] = np.transpose(pred_weights[keras_modules[i]]['kernel:0'][:], (3, 2, 1, 0)) 42 | # weight_dict[pytorch_modules[i * 2]] = pred_weights[keras_modules[i]]['kernel:0'] 43 | 44 | for k, v in weight_dict.items(): 45 | # print(k, v) 46 | # weight_dict[k] = Variable(torch.from_numpy(v).float().cuda()) 47 | weight_dict[k] = torch.from_numpy(v).float().cuda() 48 | 49 | fileName = './model_data_keras2/preTrained_weights_forPyTorch.pkl' 50 | weights_gift_from_keras = torch.save(weight_dict, fileName) -------------------------------------------------------------------------------- /model_data_keras2/prednet_kitti_weights.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zcrwind/PredNet_pytorch/ce1fd5d32035b1fcee5574bbab55fc2ea41cc4d6/model_data_keras2/prednet_kitti_weights.hdf5 -------------------------------------------------------------------------------- /prednet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | PredNet in PyTorch. 5 | ''' 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | 14 | 15 | def hard_sigmoid(x): 16 | ''' 17 | - hard sigmoid function by zcr. 18 | - Computes element-wise hard sigmoid of x. 19 | - what is hard sigmoid? 20 | Segment-wise linear approximation of sigmoid. Faster than sigmoid. 21 | Returns 0. if x < -2.5, 1. if x > 2.5. In -2.5 <= x <= 2.5, returns 0.2 * x + 0.5. 22 | - See e.g. 
https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279 23 | ''' 24 | slope = 0.2 25 | shift = 0.5 26 | x = (slope * x) + shift 27 | x = F.threshold(-x, -1, -1) 28 | x = F.threshold(-x, 0, 0) 29 | return x 30 | 31 | def get_activationFunc(act_str): 32 | act = act_str.lower() 33 | if act == 'relu': 34 | # return nn.ReLU(True) 35 | return nn.ReLU() 36 | elif act == 'tanh': 37 | # return F.tanh 38 | return nn.Tanh() 39 | # elif act == 'hard_sigmoid': 40 | # return hard_sigmoid 41 | else: 42 | raise(RuntimeError('cannot obtain the activation function named %s' % act_str)) 43 | 44 | def batch_flatten(x): 45 | ''' 46 | equal to the `batch_flatten` in keras. 47 | x is a Variable in pytorch 48 | ''' 49 | shape = [*x.size()] 50 | dim = np.prod(shape[1:]) 51 | dim = int(dim) # 不加这步的话, dim是类型, 不能在view中用. 加上这步转成类型. 52 | return x.view(-1, dim) 53 | 54 | 55 | 56 | class PredNet(nn.Module): 57 | """ 58 | PredNet realized by zcr. 59 | 60 | Args: 61 | stack_sizes: 62 | - Number of channels in targets (A) and predictions (Ahat) in each layer of the architecture. 63 | - Length of stack_size (i.e. len(stack_size) and we use `num_layers` to denote it) is the number of layers in the architecture. 64 | - First element is the number of channels in the input. 65 | - e.g., (3, 16, 32) would correspond to a 3 layer architecture that takes in RGB images and 66 | has 16 and 32 channels in the second and third layers, respectively. 67 | - 下标为(lay + 1)的值即为pytorch中第lay个卷积层的out_channels参数. 例如上述16对应到lay 0层(即输入层)的A和Ahat的out_channels是16. 68 | R_stack_sizes: 69 | - Number of channels in the representation (R) modules. 70 | - Length must equal length of stack_sizes, but the number of channels per layer can be different. 71 | - 即pytorch中卷积层的out_channels参数. 72 | A_filter_sizes: 73 | - Filter sizes for the target (A) modules. (except the target (A) in lowest layer (i.e., input image)) 74 | - Has length of len(stack_sizes) - 1. 75 | - e.g., (3, 3) would mean that targets for layers 2 and 3 are computed by a 3x3 convolution of 76 | the errors (E) from the layer below (followed by max-pooling) 77 | - 即pytorch中卷积层的kernel_size. 78 | Ahat_filter_sizes: 79 | - Filter sizes for the prediction (Ahat) modules. 80 | - Has length equal to length of stack_sizes. 81 | - e.g., (3, 3, 3) would mean that the predictions for each layer are computed by a 3x3 convolution 82 | of the representation (R) modules at each layer. 83 | - 即pytorch中卷积层的kernel_size. 84 | R_filter_sizes: 85 | - Filter sizes for the representation (R) modules. 86 | - Has length equal to length of stack_sizes. 87 | - Corresponds to the filter sizes for all convolutions in the LSTM. 88 | - 即pytorch中卷积层的kernel_size. 89 | pixel_max: 90 | - The maximum pixel value. 91 | - Used to clip the pixel-layer prediction. 92 | error_activation: 93 | - Activation function for the error (E) units. 94 | A_activation: 95 | - Activation function for the target (A) and prediction (A_hat) units. 96 | LSTM_activation: 97 | - Activation function for the cell and hidden states of the LSTM. 98 | LSTM_inner_activation: 99 | - Activation function for the gates in the LSTM. 100 | output_mode: 101 | - Either 'error', 'prediction', 'all' or layer specification (e.g., R2, see below). 102 | - Controls what is outputted by the PredNet. 103 | - if 'error': 104 | The mean response of the error (E) units of each layer will be outputted. 105 | That is, the output shape will be (batch_size, num_layers). 106 | - if 'prediction': 107 | The frame prediction will be outputted. 
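As a quick sanity check on the `hard_sigmoid` helper defined near the top of `prednet.py`, the double `F.threshold` trick is equivalent to clamping the linear segment; a minimal standalone reference (names are mine, not part of the repo):

```
import torch

def hard_sigmoid_ref(x):
    # clip(0.2 * x + 0.5, 0, 1): 0 below x = -2.5, 1 above x = 2.5, linear in between
    return torch.clamp(0.2 * x + 0.5, min=0.0, max=1.0)

print(hard_sigmoid_ref(torch.tensor([-3.0, -2.5, 0.0, 2.5, 3.0])))
# tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])
```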
108 | - if 'all': 109 | The output will be the frame prediction concatenated with the mean layer errors. 110 | The frame prediction is flattened before concatenation. 111 | Note that nomenclature of 'all' means all TYPE of the output (i.e., `error` and `prediction`), but should not be confused with returning all of the layers of the model. 112 | - For returning the features of a particular layer, output_mode should be of the form unit_type + layer_number. 113 | e.g., to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specificied as 'R0'. 114 | The possible unit types are 'R', 'Ahat', 'A', and 'E' corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively. 115 | extrap_start_time: 116 | - Time step for which model will start extrapolating. 117 | - Starting at this time step, the prediction from the previous time step will be treated as the "actual" 118 | data_format: 119 | - 'channels_first': (channel, Height, Width) 120 | - 'channels_last' : (Height, Width, channel) 121 | 122 | """ 123 | def __init__(self, stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes, 124 | pixel_max = 1.0, error_activation = 'relu', A_activation = 'relu', LSTM_activation = 'tanh', 125 | LSTM_inner_activation = 'hard_sigmoid', output_mode = 'error', 126 | extrap_start_time = None, data_format = 'channels_last', return_sequences = False): 127 | super(PredNet, self).__init__() 128 | self.stack_sizes = stack_sizes 129 | self.num_layers = len(stack_sizes) 130 | assert len(R_stack_sizes) == self.num_layers 131 | self.R_stack_sizes = R_stack_sizes 132 | assert len(A_filter_sizes) == self.num_layers - 1 133 | self.A_filter_sizes = A_filter_sizes 134 | assert len(Ahat_filter_sizes) == self.num_layers 135 | self.Ahat_filter_sizes = Ahat_filter_sizes 136 | assert len(R_filter_sizes) == self.num_layers 137 | self.R_filter_sizes = R_filter_sizes 138 | 139 | self.pixel_max = pixel_max 140 | self.error_activation = error_activation 141 | self.A_activation = A_activation 142 | self.LSTM_activation = LSTM_activation 143 | self.LSTM_inner_activation = LSTM_inner_activation 144 | 145 | default_output_modes = ['prediction', 'error', 'all'] 146 | layer_output_modes = [layer + str(n) for n in range(self.num_layers) for layer in ['R', 'E', 'A', 'Ahat']] 147 | assert output_mode in default_output_modes + layer_output_modes 148 | self.output_mode = output_mode 149 | if self.output_mode in layer_output_modes: 150 | self.output_layer_type = self.output_mode[:-1] 151 | self.output_layer_NO = int(self.output_mode[-1]) # suppose the number of layers is < 10 152 | else: 153 | self.output_layer_type = None 154 | self.output_layer_NO = None 155 | 156 | self.extrap_start_time = extrap_start_time 157 | assert data_format in ['channels_first', 'channels_last'] 158 | self.data_format = data_format 159 | if self.data_format == 'channels_first': 160 | self.channel_axis = -3 161 | self.row_axis = -2 162 | self.col_axis = -1 163 | else: 164 | self.channel_axis = -1 165 | self.row_axis = -3 166 | self.col_axis = -2 167 | 168 | self.return_sequences = return_sequences 169 | 170 | self.make_layers() 171 | 172 | 173 | def get_initial_states(self, input_shape): 174 | ''' 175 | input_shape is like: (batch_size, timeSteps, Height, Width, 3) 176 | or: (batch_size, timeSteps, 3, Height, Width) 177 | ''' 178 | init_height = input_shape[self.row_axis] # equal to `init_nb_rows` in original version 179 | init_width = input_shape[self.col_axis] # equal to 
`init_nb_cols` in original version 180 | 181 | base_initial_state = np.zeros(input_shape) 182 | non_channel_axis = -1 if self.data_format == 'channels_first' else -2 183 | for _ in range(2): 184 | base_initial_state = np.sum(base_initial_state, axis = non_channel_axis) 185 | base_initial_state = np.sum(base_initial_state, axis = 1) # (batch_size, 3) 186 | 187 | initial_states = [] 188 | states_to_pass = ['R', 'c', 'E'] # R is `representation`, c is Cell state in LSTM, E is `error`. 189 | layerNum_to_pass = {sta: self.num_layers for sta in states_to_pass} 190 | if self.extrap_start_time is not None: 191 | states_to_pass.append('Ahat') # pass prediction in states so can use as actual for t+1 when extrapolating 192 | layerNum_to_pass['Ahat'] = 1 193 | 194 | for sta in states_to_pass: 195 | for lay in range(layerNum_to_pass[sta]): 196 | downSample_factor = 2 ** lay # 下采样缩放因子 197 | row = init_height // downSample_factor 198 | col = init_width // downSample_factor 199 | if sta in ['R', 'c']: 200 | stack_size = self.R_stack_sizes[lay] 201 | elif sta == 'E': 202 | stack_size = self.stack_sizes[lay] * 2 203 | elif sta == 'Ahat': 204 | stack_size = self.stack_sizes[lay] 205 | output_size = stack_size * row * col # flattened size 206 | reducer = np.zeros((input_shape[self.channel_axis], output_size)) # (3, output_size) 207 | initial_state = np.dot(base_initial_state, reducer) # (batch_size, output_size) 208 | 209 | if self.data_format == 'channels_first': 210 | output_shape = (-1, stack_size, row, col) 211 | else: 212 | output_shape = (-1, row, col, stack_size) 213 | # initial_state = torch.from_numpy(np.reshape(initial_state, output_shape)).float().cuda() 214 | initial_state = Variable(torch.from_numpy(np.reshape(initial_state, output_shape)).float().cuda(), requires_grad = True) 215 | initial_states += [initial_state] 216 | 217 | if self.extrap_start_time is not None: 218 | # initial_states += [torch.IntTensor(1).zero_().cuda()] # the last state will correspond to the current timestep 219 | initial_states += [Variable(torch.IntTensor(1).zero_().cuda())] # the last state will correspond to the current timestep 220 | return initial_states 221 | 222 | 223 | # def compute_output_shape(self, input_shape): 224 | # if self.output_mode == 'prediction': 225 | # out_shape = input_shape[2:] 226 | # elif self.output_mode == 'error': # error模式输出为各层误差,每层一个标量 227 | # out_shape = (self.num_layers,) 228 | # elif self.output_mode == 'all': 229 | # out_shape = (np.prod(input_shape[2:]) + self.num_layers,) # np.prod 元素逐个相乘 230 | # else: 231 | # if self.output_layer_type == 'R': 232 | # stack_str = 'R_stack_sizes' 233 | # else: 234 | # stack_str = 'stack_sizes' 235 | 236 | # if self.output_layer_type == 'E': 237 | # stack_multi = 2 238 | # else: 239 | # stack_multi = 1 240 | 241 | # out_stack_size = stack_multi * getattr(self, stack_str)[self.output_layer_NO] 242 | # layer_out_row = input_shape[self.row_axis] / (2 ** self.output_layer_NO) 243 | # layer_out_col = input_shape[self.col_axis] / (2 ** self.output_layer_NO) 244 | # if self.data_format == 'channels_first': 245 | # out_shape = (out_stack_size, layer_out_row, layer_out_col) 246 | # else: 247 | # out_shape = (layer_out_row, layer_out_col, out_stack_size) 248 | 249 | # if self.return_sequences: 250 | # return (input_shape[0], input_shape[1]) + out_shape # input_shape[1] is the timesteps 251 | # else: 252 | # return (input_shape[0],) + out_shape 253 | 254 | 255 | def isNotTopestLayer(self, layerIndex): 256 | '''judge if the layerIndex is not the topest layer.''' 
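To make the bookkeeping in `get_initial_states` above concrete, here is a small sketch, assuming the KITTI configuration used by the scripts in this repo (batch size 8, 128x160 frames, `data_format='channels_first'`), that prints the per-layer shapes of the R/c and E states it creates:

```
stack_sizes   = (3, 48, 96, 192)   # A / Ahat channels per layer
R_stack_sizes = stack_sizes        # R / c channels per layer
batch_size, height, width = 8, 128, 160

for lay in range(len(stack_sizes)):
    h, w = height // (2 ** lay), width // (2 ** lay)
    print('layer %d | R/c: %s | E: %s' % (
        lay,
        (batch_size, R_stack_sizes[lay], h, w),
        (batch_size, 2 * stack_sizes[lay], h, w),   # E concatenates (Ahat - A) and (A - Ahat)
    ))
# layer 0 | R/c: (8, 3, 128, 160) | E: (8, 6, 128, 160)
# layer 1 | R/c: (8, 48, 64, 80) | E: (8, 96, 64, 80)
# layer 2 | R/c: (8, 96, 32, 40) | E: (8, 192, 32, 40)
# layer 3 | R/c: (8, 192, 16, 20) | E: (8, 384, 16, 20)
```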
257 | if layerIndex < self.num_layers - 1: 258 | return True 259 | else: 260 | return False 261 | 262 | 263 | def make_layers(self): 264 | ''' 265 | equal to the `build` method in original version. 266 | ''' 267 | # i: input, f: forget, c: cell, o: output 268 | self.conv_layers = {item: [] for item in ['i', 'f', 'c', 'o', 'A', 'Ahat']} 269 | lstm_list = ['i', 'f', 'c', 'o'] 270 | 271 | for item in sorted(self.conv_layers.keys()): 272 | for lay in range(self.num_layers): 273 | downSample_factor = 2 ** lay # 下采样缩放因子 274 | if item == 'Ahat': 275 | in_channels = self.R_stack_sizes[lay] # 因为Ahat是对R的输出进行卷积, 所以输入Ahat的channel数就是相同层中R的输出channel数. 276 | self.conv_layers['Ahat'].append(nn.Conv2d(in_channels = in_channels, 277 | out_channels = self.stack_sizes[lay], 278 | kernel_size = self.Ahat_filter_sizes[lay], 279 | stride = (1, 1), 280 | padding = int((self.Ahat_filter_sizes[lay] - 1) / 2) # the `SAME` mode (i.e.,(kernel_size - 1) / 2) 281 | )) 282 | act = 'relu' if lay == 0 else self.A_activation 283 | self.conv_layers['Ahat'].append(get_activationFunc(act)) 284 | 285 | elif item == 'A': 286 | if self.isNotTopestLayer(lay): # 这里只是控制一下层数(比其他如Ahat等少一层) 287 | # NOTE: 这里是从第二层(lay = 1)开始构建A的(因为整个网络的最低一层(layer0)的A就是原始图像(可以将layer0的A视为一个`恒等层`, 即输入图像, 输出原封不动的图像)) 288 | in_channels = self.R_stack_sizes[lay] * 2 # A卷积层输入特征数(in_channels)是对应层E的特征数,E包含(Ahat-A)和(A-Ahat)两部分,故x2. [从paper的Fig.1左图来看, E是Ahat的输出和A进行相减, 之后拼接.] 289 | self.conv_layers['A'].append(nn.Conv2d(in_channels = in_channels, 290 | out_channels = self.stack_sizes[lay + 1], 291 | kernel_size = self.A_filter_sizes[lay], 292 | stride = (1, 1), 293 | padding = int((self.A_filter_sizes[lay] - 1) / 2) # the `SAME` mode 294 | )) 295 | self.conv_layers['A'].append(get_activationFunc(self.A_activation)) 296 | 297 | elif item in lstm_list: # 构建R模块 298 | # R的输入特征数(in_channels): 同层的E、同层上一时刻的R(即R_t-1)、 同时刻上层的R(即R_l+1)这三者的特征数之和. 299 | # 如果该R模块位于顶层, 则没有来自上层的R. 其中: 300 | # - stack_sizes[lay] * 2 表示的是同层E的channel数 (因为E是将同层的A和Ahat在channel这一维度上拼接得到的, 故x2) 301 | # - R_stack_sizes[lay] 表示的是同层上一时刻的R的channel数 302 | # - R_stack_sizes[lay + 1] 表示的是同时刻上层的R的channel数 303 | in_channels = self.stack_sizes[lay] * 2 + self.R_stack_sizes[lay] 304 | if self.isNotTopestLayer(lay): 305 | in_channels += self.R_stack_sizes[lay + 1] 306 | # for j in lstm_list: # 严重的bug! 赶紧注释掉...下面的向前缩进4个空格... 307 | # LSTM中的i,f,c,o的非线性激活函数层放在forward中实现. (因为这里i,f,o要用hard_sigmoid函数, Keras中LSTM默认就是hard_sigmoid, 但是pytorch中需自己实现) 308 | # act = self.LSTM_activation if j == 'c' else self.LSTM_inner_activation 309 | # act = get_activationFunc(act) 310 | self.conv_layers[item].append(nn.Conv2d(in_channels = in_channels, 311 | out_channels = self.R_stack_sizes[lay], 312 | kernel_size = self.R_filter_sizes[lay], 313 | stride = (1, 1), 314 | padding = int((self.R_filter_sizes[lay] - 1) / 2) # the `SAME` mode 315 | )) 316 | 317 | for name, layerList in self.conv_layers.items(): 318 | self.conv_layers[name] = nn.ModuleList(layerList) 319 | setattr(self, name, self.conv_layers[name]) 320 | 321 | # see the source code in: 322 | # [PyTorch]: http://pytorch.org/docs/master/_modules/torch/nn/modules/upsampling.html 323 | # [Keras ]: keras-master/keras/layers/convolution.py/`class UpSampling2D(Layer)` 324 | # self.upSample = nn.Upsample(size = (2, 2), mode = 'nearest') # 是错误的! pytorch中的scale_factor参数对应到keras中的size参数. 
325 | self.upSample = nn.Upsample(scale_factor = 2, mode = 'nearest') 326 | # see the source code in: 327 | # [PyTorch]: http://pytorch.org/docs/master/_modules/torch/nn/modules/pooling.html#MaxPool2d 328 | # [Keras ]: keras-master/keras/layers/pooling.py/`` 329 | # `pool_size` in Keras is equal to `kernel_size` in pytorch. 330 | # [TODO] padding here is not very clear. Is `0` here is the `SAME` mode in Keras? 331 | self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0) 332 | 333 | 334 | def step(self, A, states): 335 | ''' 336 | 这个step函数是和原代码中的`step`函数是等价的. 是PredNet的核心逻辑所在. 337 | 类比于标准LSTM的实现方式, 这个step函数的角色相当于LSTMCell, 而下面的forward函数相当于LSTM类. 338 | 339 | Args: 340 | A: 4D tensor with the shape of (batch_size, 3, Height, Width). 就是从A_withTimeStep按照时间步抽取出来的数据. 341 | states 和 `forward`函数的`initial_states`的形式完全相同, 只是后者是初始化的PredNet状态, 而这里的states是在timesteps内运算时的PredNet参数. 342 | ''' 343 | n = self.num_layers 344 | R_current = states[ : (n)] 345 | c_current = states[ (n):(2 * n)] 346 | E_current = states[(2 * n):(3 * n)] 347 | 348 | if self.extrap_start_time is not None: 349 | timestep = states[-1] 350 | if timestep >= self.t_extrap: # if past self.extrap_start_time, the previous prediction will be treated as the actual. 351 | A = states[-2] 352 | else: 353 | A = A 354 | 355 | R_list = [] 356 | c_list = [] 357 | E_list = [] 358 | 359 | # Update R units starting from the top. 360 | for lay in reversed(range(self.num_layers)): 361 | inputs = [R_current[lay], E_current[lay]] # 如果是顶层, R_l的输入只有两个: E_l^t, R_l^(t-1). 即没有高层的R模块的输入项. 362 | if self.isNotTopestLayer(lay): # 如果不是顶层,R_l的输入就有三个: E_l^t, R_l^(t-1), R_(l+1)^t. R_up即为R_(l+1)^t 363 | inputs.append(R_up) 364 | 365 | inputs = torch.cat(inputs, dim = self.channel_axis) 366 | if not isinstance(inputs, Variable): # 第一个时间步内inputs还是Tensor类型, 但是过一遍网络之后, 以后的时间步中就都是Variable类型了. 367 | inputs = Variable(inputs, requires_grad = True) 368 | 369 | # print(lay, type(inputs), inputs.size()) # 正确的情况下, 举例如下: 370 | # lay3: torch.Size([8, 576, 16, 20]) [576 = 384(E_l^t) + 192(R_l^(t-1))] 371 | # lay2: torch.Size([8, 480, 32, 40]) [480 = 192(E_l^t) + 96(R_l^(t-1)) + 192(R_(l+1)^t)] 372 | # lay1: torch.Size([8, 240, 64, 80]) [240 = 96(E_l^t) + 48(R_l^(t-1)) + 96(R_(l+1)^t)] 373 | # lay0: torch.Size([8, 57, 160, 128]) [ 57 = 6(E_l^t) + 3(R_l^(t-1)) + 48(R_(l+1)^t)] 374 | 375 | # see https://github.com/huggingface/torchMoji/blob/master/torchmoji/lstm.py 376 | in_gate = hard_sigmoid(self.conv_layers['i'][lay](inputs)) 377 | forget_gate = hard_sigmoid(self.conv_layers['f'][lay](inputs)) 378 | cell_gate = F.tanh(self.conv_layers['c'][lay](inputs)) 379 | out_gate = hard_sigmoid(self.conv_layers['o'][lay](inputs)) 380 | 381 | # print(forget_gate.size()) # torch.Size([8, 192, 16, 20]) 382 | # print(c_current[lay].size()) # torch.Size([8, 192, 16, 20]) 383 | # print(in_gate.size()) # torch.Size([8, 192, 16, 20]) 384 | # print(cell_gate.size()) # torch.Size([8, 192, 16, 20]) 385 | # print(type(forget_gate)) # 386 | # print(type(c_current[lay])) # 387 | # print(type(Variable(c_current[lay]))) # 388 | # print(type(in_gate)) # 389 | # print(type(cell_gate)) # 390 | if not isinstance(c_current[lay], Variable): 391 | c_current[lay] = Variable(c_current[lay], requires_grad = True) 392 | c_next = (forget_gate * c_current[lay]) + (in_gate * cell_gate) # 对应元素相乘 393 | R_next = out_gate * F.tanh(c_next) # `R_next` here相当于标准LSTM中的hidden state. 这个就是视频的表征. 
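The gate updates above are a standard convolutional LSTM cell with `hard_sigmoid` on the gates; a condensed, self-contained sketch of one such update (argument names are mine), not the repo's own API:

```
import torch

def convlstm_cell(inputs, c_prev, conv_i, conv_f, conv_c, conv_o):
    '''One per-layer recurrent update as in PredNet.step().
    `inputs` is the channel-wise concat of E_l^t, R_l^(t-1) and, except for the
    top layer, the upsampled R_(l+1)^t.'''
    hard_sig = lambda x: torch.clamp(0.2 * x + 0.5, 0.0, 1.0)
    i = hard_sig(conv_i(inputs))       # input gate
    f = hard_sig(conv_f(inputs))       # forget gate
    o = hard_sig(conv_o(inputs))       # output gate
    g = torch.tanh(conv_c(inputs))     # candidate cell state
    c_next = f * c_prev + i * g
    R_next = o * torch.tanh(c_next)    # the layer's representation R_l^t
    return R_next, c_next
```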
394 | 395 | c_list.insert(0, c_next) 396 | R_list.insert(0, R_next) 397 | 398 | if lay > 0: 399 | # R_up = self.upSample(R_next).data # 注意: 这里出来的是Variable, 上面要append到inputs列表里的都是FloatTensor, 所以这里需要变成Tensor形式, 即加个`.data` 400 | R_up = self.upSample(R_next) # NOTE: 这个就是困扰好久, 导致loss.backward()报错的原因: torch.cat()中将Tensor和Variable混用导致的错误! 401 | # print(R_up.size()) # lay3: torch.Size([8, 192, 32, 40]) 402 | 403 | 404 | # Update feedforward path starting from the bottom. 405 | for lay in range(self.num_layers): 406 | Ahat = self.conv_layers['Ahat'][2 * lay](R_list[lay]) # Ahat是R的卷积, 故将同层同时刻的R输入. 这里千万注意: 每个`lay`其实对应的是两个组件: 卷积层+非线性激活层, 所以这里需要用(2 * lay)来索引`lay`对应的卷积层, 用(2 * lay + 1)来索引`lay`对应的非线性激活函数层. 下面对A的处理也是一样. 407 | Ahat = self.conv_layers['Ahat'][2 * lay + 1](Ahat) # 勿忘非线性激活.下面对A的处理也是一样. 408 | if lay == 0: 409 | # Ahat = torch.min(Ahat, self.pixel_max) # 错误(keras中的表示方式) 410 | Ahat[Ahat > self.pixel_max] = self.pixel_max # passed through a saturating non-linearity set at the maximum pixel value 411 | frame_prediction = Ahat # 最低一层的Ahat即为预测输出帧图像 412 | # if self.output_mode == 'prediction': 413 | # break 414 | 415 | # print('&' * 10, lay) 416 | # print('Ahat', Ahat.size()) # torch.Size([batch_size, 3, 128, 160]) 417 | # print('A', A.size()) # 原来A0直接用的是从dataloader中加载出来的数据, 所以打印的是torch.Size([batch_size, 10, 3, 128, 160]), 这就是问题所在: dataloader返回的数据是(batch_size, timesteps, (image_shape)), 而实际上在RNN中用的是将每个时间步分开的. 现在将核心逻辑解耦出来形成`step`函数, A0就变成torch.Size([batch_size, 3, 128, 160])这个维度了. 418 | # print('&' * 20) 419 | 420 | # compute errors 421 | if self.error_activation.lower() == 'relu': 422 | E_up = F.relu(Ahat - A) 423 | E_down = F.relu(A - Ahat) 424 | elif self.error_activation.lower() == 'tanh': 425 | E_up = F.tanh(Ahat - A) 426 | E_down = F.tanh(A - Ahat) 427 | else: 428 | raise(RuntimeError('cannot obtain the activation function named %s' % self.error_activation)) 429 | 430 | E_list.append(torch.cat((E_up, E_down), dim = self.channel_axis)) 431 | 432 | # 如果是想要获取特定的层中特定模块的输出: 433 | if self.output_layer_NO == lay: 434 | if self.output_layer_type == 'A': 435 | output = A 436 | elif self.output_layer_type == 'Ahat': 437 | output = Ahat 438 | elif self.output_layer_type == 'R': 439 | output = R_list[lay] 440 | elif self.output_layer_type == 'E': 441 | output = E_list[lay] 442 | 443 | if self.isNotTopestLayer(lay): 444 | A = self.conv_layers['A'][2 * lay](E_list[lay]) # 对E进行卷积+池化之后, 得到同时刻上一层的A, 如果该层已经是最顶层了, 就不用了 445 | A = self.conv_layers['A'][2 * lay + 1](A) # 勿忘非线性激活. 446 | A = self.pool(A) # target for next layer 447 | 448 | 449 | if self.output_layer_type is None: 450 | if self.output_mode == 'prediction': 451 | output = frame_prediction 452 | else: 453 | for lay in range(self.num_layers): 454 | layer_error = torch.mean(batch_flatten(E_list[lay]), dim = -1, keepdim = True) # batch_flatten函数是zcr依照Kears中同名函数实现的. 第0维是batch_size维度, 将除此维度之外的维度拉平 455 | all_error = layer_error if lay == 0 else torch.cat((all_error, layer_error), dim = -1) 456 | if self.output_mode == 'error': 457 | output = all_error 458 | else: 459 | output = torch.cat((batch_flatten(frame_prediction), all_error), dim = -1) 460 | 461 | states = R_list + c_list + E_list 462 | if self.extrap_start_time is not None: 463 | states += [frame_prediction, (timestep + 1)] 464 | return output, states 465 | 466 | 467 | def forward(self, A0_withTimeStep, initial_states): 468 | ''' 469 | A0_withTimeStep is the input from dataloader. Its shape is: (batch_size, timesteps, 3, Height, Width). 
470 |         Put simply, A0_withTimeStep is just the raw images loaded by the dataloader, i.e. the A of the lowest layer (layer 0), expanded along the batch_size and timestep dimensions.
471 |         initial_states is a list of pytorch-tensors. The states argument is just the initial state, since this forward function itself is not executed inside a loop.
472 | 
473 |         NOTE: this forward function plays the role of the `step` function in the original Keras version, but it is not quite the same. The original PredNet class
474 |               inherits from Keras's `Recurrent` class, so that parent class takes care of splitting the data loaded by the dataloader (SequenceGenerator in the original code),
475 |               shaped (batch_size, timesteps, 3, H, W), into (batch_size, 3, H, W) slices and looping over the timesteps.
476 |               Here, forward has to implement the loop over timesteps itself. The A here is the 5D tensor (batch_size, timesteps, 3, Height, Width) coming from the dataloader,
477 |               whereas the input `x` of the original step function is a 4D tensor (batch_size, 3, Height, Width).
478 |         '''
479 | 
480 |         # batch_first == True by default, i.e. the first dimension is batch_size and the second is timesteps.
481 |         A0_withTimeStep = A0_withTimeStep.transpose(0, 1) # (b, t, c, h, w) -> (t, b, c, h, w)
482 | 
483 |         num_timesteps = A0_withTimeStep.size()[0]
484 | 
485 |         hidden_states = initial_states # renamed to hidden_states so it can be reused cleanly in the loop below
486 |         output_list = [] # the outputs must be kept: in `error` mode they are weighted over layers and timesteps to form the final loss; in `prediction` mode one predicted frame per timestep is returned (e.g. 10 images for 10 timesteps)
487 |         for t in range(num_timesteps):
488 |             '''
489 |             A standard LSTM (or vanilla RNN) would need two nested loops:
490 |                 for lay in range(num_layers):
491 |                     for t in range(num_timesteps):
492 |                         pass
493 |             But as the note in the original Keras code explains, although PredNet has a notion of layers,
494 |             it is implemented as a single `super layer`, i.e. one layer overall, so there is no `for lay` loop here.
495 |             '''
496 |             A0 = A0_withTimeStep[t, ...]
497 |             output, hidden_states = self.step(A0, hidden_states)
498 |             output_list.append(output)
499 |             # hidden_states need not be stored; it is simply carried forward from one timestep to the next.
500 | 
501 |         if self.output_mode == 'error':
502 |             '''Weight the errors by layer and timestep. Unlike the original code, which adds Dense layers for this, the weighting could be done directly inside PredNet (in this if branch), or the per-layer errors of all timesteps can be returned and weighted in the main function. The latter is chosen here (consistent with the original code).'''
503 |             # print(len(output_list)) # 10, i.e. the number of timesteps
504 |             # print('output: ', output_list) # the `error` of each timestep is a (batch_size, num_layer) matrix of type Variable [torch.cuda.FloatTensor of size 8x4 (GPU 0)]; weighting it by layer and timestep gives the loss (the two layer-weighting schemes yield the so-called `L_0` and `L_all` losses)
505 |             # print('Got the `error` list with the length of len(timeSteps) and shape of each element in this list is: (batch_size, num_layer).')
506 |             return output_list
507 |         elif self.output_mode == 'prediction':
508 |             return output_list # output_list holds one predicted frame per timestep
509 |         elif self.output_mode == 'all':
510 |             pass
511 |         else:
512 |             raise(RuntimeError('Kidding?
Unknown output mode!')) 513 | 514 | 515 | if __name__ == '__main__': 516 | n_channels = 3 517 | img_height = 128 518 | img_width = 160 519 | 520 | stack_sizes = (n_channels, 48, 96, 192) 521 | R_stack_sizes = stack_sizes 522 | A_filter_sizes = (3, 3, 3) 523 | Ahat_filter_sizes = (3, 3, 3, 3) 524 | R_filter_sizes = (3, 3, 3, 3) 525 | 526 | prednet = PredNet(stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes, 527 | output_mode = 'error', return_sequences = True) 528 | 529 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os 5 | import numpy as np 6 | import argparse 7 | import time 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from torch.optim import lr_scheduler 14 | 15 | # zcr lib 16 | from prednet import PredNet 17 | from data_utils import ZcrDataLoader 18 | 19 | # os.environ['CUDA_LAUNCH_BLOCKING'] = 1 20 | # torch.backends.cudnn.benchmark = True 21 | 22 | def arg_parse(): 23 | desc = "Video Frames Predicting Task via PredNet." 24 | parser = argparse.ArgumentParser(description = desc) 25 | 26 | parser.add_argument('--mode', default = 'train', type = str, 27 | help = 'train or evaluate (default: train)') 28 | parser.add_argument('--dataPath', default = '', type = str, metavar = 'PATH', 29 | help = 'path to video dataset (default: none)') 30 | parser.add_argument('--checkpoint_savePath', default = '', type = str, metavar = 'PATH', 31 | help = 'path for saving checkpoint file (default: none)') 32 | parser.add_argument('--epochs', default = 20, type = int, metavar='N', 33 | help = 'number of total epochs to run') 34 | parser.add_argument('--batch_size', default = 32, type = int, metavar = 'N', 35 | help = 'The size of batch') 36 | parser.add_argument('--optimizer', default = 'SGD', type = str, 37 | help = 'which optimizer to use') 38 | parser.add_argument('--lr', default = 0.01, type = float, 39 | metavar = 'LR', help = 'initial learning rate') 40 | parser.add_argument('--momentum', default = 0.9, type = float, 41 | help = 'momentum for SGD') 42 | parser.add_argument('--beta1', default = 0.9, type = float, 43 | help = 'beta1 in Adam optimizer') 44 | parser.add_argument('--beta2', default = 0.99, type = float, 45 | help = 'beta2 in Adam optimizer') 46 | parser.add_argument('--workers', default = 4, type = int, metavar = 'N', 47 | help = 'number of data loading workers (default: 4)') 48 | parser.add_argument('--checkpoint_file', default = '', type = str, 49 | help = 'path to checkpoint file for restrating (default: none)') 50 | parser.add_argument('--printCircle', default = 100, type = int, metavar = 'N', 51 | help = 'how many steps to print the loss information') 52 | parser.add_argument('--data_format', default = 'channels_last', type = str, 53 | help = '(c, h, w) or (h, w, c)?') 54 | parser.add_argument('--n_channels', default = 3, type = int, metavar = 'N', 55 | help = 'The number of input channels (default: 3)') 56 | parser.add_argument('--img_height', default = 128, type = int, metavar = 'N', 57 | help = 'The height of input frame (default: 128)') 58 | parser.add_argument('--img_width', default = 160, type = int, metavar = 'N', 59 | help = 'The width of input frame (default: 160)') 60 | # parser.add_argument('--stack_sizes', default = '', type = str, 61 | # help = 'Number of channels in targets (A) and 
predictions (Ahat) in each layer of the architecture.') 62 | # parser.add_argument('--R_stack_sizes', default = '', type = str, 63 | # help = 'Number of channels in the representation (R) modules.') 64 | # parser.add_argument('--A_filter_sizes', default = '', type = str, 65 | # help = 'Filter sizes for the target (A) modules. (except the target (A) in lowest layer (i.e., input image))') 66 | # parser.add_argument('--Ahat_filter_sizes', default = '', type = str, 67 | # help = 'Filter sizes for the prediction (Ahat) modules.') 68 | # parser.add_argument('--R_filter_sizes', default = '', type = str, 69 | # help = 'Filter sizes for the representation (R) modules.') 70 | parser.add_argument('--layer_loss_weightsMode', default = 'L_0', type = str, 71 | help = 'L_0 or L_all for loss weights in PredNet') 72 | parser.add_argument('--num_timeSteps', default = 10, type = int, metavar = 'N', 73 | help = 'number of timesteps used for sequences in training (default: 10)') 74 | parser.add_argument('--shuffle', default = True, type = bool, 75 | help = 'shuffle or not') 76 | 77 | args = parser.parse_args() 78 | return args 79 | 80 | def print_args(args): 81 | print('-' * 50) 82 | for arg, content in args.__dict__.items(): 83 | print("{}: {}".format(arg, content)) 84 | print('-' * 50) 85 | 86 | def train(model, args): 87 | '''Train PredNet on KITTI sequences''' 88 | 89 | # print('layer_loss_weightsMode: ', args.layer_loss_weightsMode) 90 | prednet = model 91 | # frame data files 92 | DATA_DIR = args.dataPath 93 | train_file = os.path.join(DATA_DIR, 'X_train.h5') 94 | train_sources = os.path.join(DATA_DIR, 'sources_train.h5') 95 | val_file = os.path.join(DATA_DIR, 'X_val.h5') 96 | val_sources = os.path.join(DATA_DIR, 'sources_val.h5') 97 | 98 | output_mode = 'error' 99 | sequence_start_mode = 'all' 100 | N_seq = None 101 | dataLoader = ZcrDataLoader(train_file, train_sources, output_mode, sequence_start_mode, N_seq, args).dataLoader() 102 | 103 | if prednet.data_format == 'channels_first': 104 | input_shape = (args.batch_size, args.num_timeSteps, n_channels, img_height, img_width) 105 | else: 106 | input_shape = (args.batch_size, args.num_timeSteps, img_height, img_width, n_channels) 107 | 108 | optimizer = torch.optim.Adam(prednet.parameters(), lr = args.lr) 109 | lr_maker = lr_scheduler.StepLR(optimizer = optimizer, step_size = 75, gamma = 0.1) # decay the lr every 50 epochs by a factor of 0.1 110 | 111 | printCircle = args.printCircle 112 | for e in range(args.epochs): 113 | tr_loss = 0.0 114 | sum_trainLoss_in_epoch = 0.0 115 | min_trainLoss_in_epoch = float('inf') 116 | startTime_epoch = time.time() 117 | lr_maker.step() 118 | 119 | initial_states = prednet.get_initial_states(input_shape) # 原网络貌似不是stateful的, 故这里再每个epoch开始时重新初始化(如果是stateful的, 则只在全部的epoch开始时初始化一次) 120 | states = initial_states 121 | for step, (frameGroup, target) in enumerate(dataLoader): 122 | # print(frameGroup) # [torch.FloatTensor of size 16x12x80x80] 123 | batch_frames = Variable(frameGroup.cuda()) 124 | batch_y = Variable(target.cuda()) 125 | output = prednet(batch_frames, states) 126 | 127 | # '''进行按照timestep和layer对error进行加权.''' 128 | ## 1. 按layer加权(巧妙利用广播. NOTE: 这里的error列表里的每个元素是Variable类型的矩阵, 需要转成numpy矩阵类型才可以用切片.) 129 | num_layer = len(stack_sizes) 130 | # weighting for each layer in final loss 131 | if args.layer_loss_weightsMode == 'L_0': # e.g., [1., 0., 0., 0.] 132 | layer_weights = np.array([0. for _ in range(num_layer)]) 133 | layer_weights[0] = 1. 
134 | layer_weights = torch.from_numpy(layer_weights) 135 | # layer_weights = torch.from_numpy(np.array([1., 0., 0., 0.])) 136 | elif args.layer_loss_weightsMode == 'L_all': # e.g., [1., 1., 1., 1.] 137 | layer_weights = np.array([0.1 for _ in range(num_layer)]) 138 | layer_weights[0] = 1. 139 | layer_weights = torch.from_numpy(layer_weights) 140 | # layer_weights = torch.from_numpy(np.array([1., 0.1, 0.1, 0.1])) 141 | else: 142 | raise(RuntimeError('Unknown loss weighting mode! Please use `L_0` or `L_all`.')) 143 | # layer_weights = Variable(layer_weights.float().cuda(), requires_grad = False) # NOTE: layer_weights默认是DoubleTensor, 而下面的error是FloatTensor的Variable, 如果直接相乘会报错! 144 | layer_weights = Variable(layer_weights.float().cuda()) # NOTE: layer_weights默认是DoubleTensor, 而下面的error是FloatTensor的Variable, 如果直接相乘会报错! 145 | error_list = [batch_x_numLayer__error * layer_weights for batch_x_numLayer__error in output] # 利用广播实现加权 146 | 147 | ## 2. 按timestep进行加权. (paper: equally weight all timesteps except the first) 148 | num_timeSteps = args.num_timeSteps 149 | time_loss_weight = (1. / (num_timeSteps - 1)) 150 | time_loss_weight = Variable(torch.from_numpy(np.array([time_loss_weight])).float().cuda()) 151 | time_loss_weights = [time_loss_weight for _ in range(num_timeSteps - 1)] 152 | time_loss_weights.insert(0, Variable(torch.from_numpy(np.array([0.])).float().cuda())) 153 | 154 | error_list = [error_at_t.sum() for error_at_t in error_list] # 是一个Variable的列表 155 | total_error = error_list[0] * time_loss_weights[0] 156 | for err, time_weight in zip(error_list[1:], time_loss_weights[1:]): 157 | total_error = total_error + err * time_weight 158 | 159 | loss = total_error 160 | optimizer.zero_grad() 161 | loss.backward() 162 | optimizer.step() 163 | 164 | # if (step + 1) == 2500: 165 | # zcr_state_dict = { 166 | # 'epoch' : (e + 1), 167 | # 'tr_loss' : 0, 168 | # 'state_dict': prednet.state_dict(), 169 | # 'optimizer' : optimizer.state_dict() 170 | # } 171 | # saveCheckpoint(zcr_state_dict) 172 | 173 | # print('epoch: [%3d/%3d] | step: [%4d/%4d] loss: %.4f' % ((e + 1), args.epochs, (step + 1), len(dataLoader), loss.data[0])) 174 | 175 | tr_loss += loss.data[0] 176 | sum_trainLoss_in_epoch += loss.data[0] 177 | if step % printCircle == (printCircle - 1): 178 | print('epoch: [%3d/%3d] | [%4d/%4d] loss: %.4f lr: %.5lf' % ((e + 1), args.epochs, (step + 1), len(dataLoader), tr_loss / printCircle, optimizer.param_groups[0]['lr'])) 179 | tr_loss = 0.0 180 | 181 | endTime_epoch = time.time() 182 | print('Time Consumed within an epoch: %.2f (s)' % (endTime_epoch - startTime_epoch)) 183 | 184 | if sum_trainLoss_in_epoch < min_trainLoss_in_epoch: 185 | min_trainLoss_in_epoch = sum_trainLoss_in_epoch 186 | zcr_state_dict = { 187 | 'epoch' : (e + 1), 188 | 'tr_loss' : min_trainLoss_in_epoch, 189 | 'state_dict': prednet.state_dict(), 190 | 'optimizer' : optimizer.state_dict() 191 | } 192 | saveCheckpoint(zcr_state_dict) 193 | 194 | 195 | def saveCheckpoint(zcr_state_dict, fileName = './checkpoint/checkpoint_newest.pkl'): 196 | '''save the checkpoint for both restarting and evaluating.''' 197 | tr_loss = '%.4f' % zcr_state_dict['tr_loss'] 198 | # val_loss = '%.4f' % zcr_state_dict['val_loss'] 199 | epoch = zcr_state_dict['epoch'] 200 | # fileName = './checkpoint/checkpoint_epoch' + str(epoch) + '_trLoss' + tr_loss + '_valLoss' + val_loss + '.pkl' 201 | fileName = '/media/sdb1/chenrui/checkpoint/PredNet/checkpoint_epoch' + str(epoch) + '_trLoss' + tr_loss + '.pkl' 202 | torch.save(zcr_state_dict, fileName) 203 | 204 
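The weighting logic in the training loop above can be summarised in one function; a condensed sketch (the function name is mine) of how the per-timestep, per-layer errors returned in `'error'` mode are combined into the scalar loss, following the `L_0` / `L_all` schemes used in `train.py`:

```
def weighted_prednet_loss(errors, mode='L_0'):
    '''errors: list of length num_timesteps, each a (batch_size, num_layers)
    tensor as returned by PredNet in `error` output mode.'''
    num_timesteps = len(errors)
    num_layers = errors[0].size(1)

    layer_w = errors[0].new_zeros(num_layers)
    layer_w[0] = 1.0
    if mode == 'L_all':
        layer_w[1:] = 0.1              # upper layers contribute with weight 0.1

    # all timesteps except the first are weighted equally; t = 0 is ignored
    time_w = errors[0].new_full((num_timesteps,), 1.0 / (num_timesteps - 1))
    time_w[0] = 0.0

    return sum(time_w[t] * (errors[t] * layer_w).sum() for t in range(num_timesteps))
```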
| 205 | 206 | if __name__ == '__main__': 207 | args = arg_parse() 208 | print_args(args) 209 | 210 | # DATA_DIR = args.dataPath 211 | # data_file = os.path.join(DATA_DIR, 'X_test.h5') 212 | # source_file = os.path.join(DATA_DIR, 'sources_test.h5') 213 | # output_mode = 'error' 214 | # sequence_start_mode = 'all' 215 | # N_seq = None 216 | # dataLoader = ZcrDataLoader(data_file, source_file, output_mode, sequence_start_mode, N_seq, args).dataLoader() 217 | 218 | # images, target = next(iter(dataLoader)) 219 | # print(images) 220 | # print(target) 221 | 222 | n_channels = args.n_channels 223 | img_height = args.img_height 224 | img_width = args.img_width 225 | 226 | # stack_sizes = eval(args.stack_sizes) 227 | # R_stack_sizes = eval(args.R_stack_sizes) 228 | # A_filter_sizes = eval(args.A_filter_sizes) 229 | # Ahat_filter_sizes = eval(args.Ahat_filter_sizes) 230 | # R_filter_sizes = eval(args.R_filter_sizes) 231 | 232 | stack_sizes = (n_channels, 48, 96, 192) 233 | R_stack_sizes = stack_sizes 234 | A_filter_sizes = (3, 3, 3) 235 | Ahat_filter_sizes = (3, 3, 3, 3) 236 | R_filter_sizes = (3, 3, 3, 3) 237 | 238 | prednet = PredNet(stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes, 239 | output_mode = 'error', data_format = args.data_format, return_sequences = True) 240 | print(prednet) 241 | prednet.cuda() 242 | 243 | assert args.mode == 'train' 244 | train(prednet, args) 245 | 246 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # usage: 4 | # ./train.sh 5 | 6 | echo "Train..." 7 | mode='train' 8 | 9 | # @200.121 10 | DATA_DIR='/media/sdb1/chenrui/kitti_data/h5/' 11 | checkpoint_savePath='./checkpoint/' 12 | checkpoint_file='./checkpoint/' # checkpoint file name for restarting. 
13 | 
14 | epochs=1
15 | batch_size=8
16 | optimizer='Adam'
17 | learning_rate=0.001
18 | momentum=0.9
19 | beta1=0.9
20 | beta2=0.99
21 | 
22 | workers=4
23 | 
24 | # vital for restarting from a saved checkpoint
25 | checkpoint_file='./checkpoint/'
26 | printCircle=100
27 | 
28 | data_format='channels_first'
29 | n_channels=3
30 | img_height=128
31 | img_width=160
32 | 
33 | # stack_sizes="($n_channels, 48, 96, 192)"
34 | # R_stack_sizes=$stack_sizes
35 | # A_filter_sizes="(3, 3, 3)"
36 | # Ahat_filter_sizes="(3, 3, 3, 3)"
37 | # R_filter_sizes="(3, 3, 3, 3)"
38 | 
39 | layer_loss_weightsMode='L_0'
40 | # layer_loss='L_all'
41 | 
42 | # number of timesteps used for sequences in training
43 | num_timeSteps=10
44 | 
45 | shuffle=true
46 | 
47 | CUDA_VISIBLE_DEVICES=0 python train.py \
48 |     --mode ${mode} \
49 |     --dataPath ${DATA_DIR} \
50 |     --checkpoint_savePath ${checkpoint_savePath} \
51 |     --epochs ${epochs} \
52 |     --batch_size ${batch_size} \
53 |     --optimizer ${optimizer} \
54 |     --lr ${learning_rate} \
55 |     --momentum ${momentum} \
56 |     --beta1 ${beta1} \
57 |     --beta2 ${beta2} \
58 |     --workers ${workers} \
59 |     --checkpoint_file ${checkpoint_file} \
60 |     --printCircle ${printCircle} \
61 |     --data_format ${data_format} \
62 |     --n_channels ${n_channels} \
63 |     --img_height ${img_height} \
64 |     --img_width ${img_width} \
65 |     --layer_loss_weightsMode ${layer_loss_weightsMode} \
66 |     --num_timeSteps ${num_timeSteps} \
67 |     --shuffle ${shuffle}
68 |     # --stack_sizes ${stack_sizes} \
69 |     # --R_stack_sizes ${R_stack_sizes} \
70 |     # --A_filter_sizes ${A_filter_sizes} \
71 |     # --Ahat_filter_sizes ${Ahat_filter_sizes} \
72 |     # --R_filter_sizes ${R_filter_sizes} \
73 | 
--------------------------------------------------------------------------------
/visualization.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | 
4 | '''
5 | Usage:
6 |     python visualization.py
7 | '''
8 | import sys
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | plt.switch_backend('agg')
12 | 
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Variable
16 | 
17 | def sortByVariance(filtersData):
18 |     '''sort the filters by variance and keep a multiple of 10 of them.'''
19 |     sumedData = np.sum(filtersData, axis = 3)
20 |     flat = sumedData.reshape(sumedData.shape[0], sumedData.shape[1] * sumedData.shape[2])
21 |     std = np.std(flat, axis = 1)
22 |     order = np.argsort(std)   # ascending: lowest-variance filters come first
23 |     filterNum = int(order.shape[0] - (order.shape[0] % 10))   # e.g., 57 -> 50
24 |     sortedData = np.zeros((filterNum,) + filtersData.shape[1:])
25 |     for i in range(filterNum):
26 |         sortedData[i, :, :, :] = filtersData[order[i], :, :, :]
27 |     return sortedData
28 | 
29 | def visualize(filtersData, output_figName):
30 |     '''
31 |     visualize the conv1 filters
32 |     filtersData: (filters_num, height, width, 3)
33 |     '''
34 |     print(output_figName)
35 |     filtersData = np.squeeze(filtersData)
36 |     print('after squeeze: ', filtersData.shape)     # e.g., (57, 3, 3, 3) for the ConvLSTM gate kernels
37 | 
38 |     # normalize filtersData for display
39 |     filtersData = (filtersData - filtersData.min()) / (filtersData.max() - filtersData.min())
40 |     filtersData = sortByVariance(filtersData)
41 |     print('after sorting: ', filtersData.shape)     # e.g., (50, 3, 3, 3) after trimming to a multiple of 10
42 | 
43 |     filters_num = filtersData.shape[0]
44 |     # force the number of filters to be square
45 |     n = int(np.ceil(np.sqrt(filters_num)))   # unused: the grid below is hard-coded to 5 x 10
46 |     # add some space between filters
47 |     padding = (((0, 0), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
48 |     # padding = (((0, 64 - filters_num), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
49 |     print(padding)   # ((0, 0), (0, 1), (0, 1), (0, 0))
50 |     filtersData = np.pad(filtersData, padding, mode = 'constant', constant_values = 1)   # pad with ones (white)
51 |     print('after padding: ', filtersData.shape)     # e.g., (50, 4, 4, 3)
52 |     # tile the filters into an image
53 |     filtersData = filtersData.reshape((5, 10) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
54 |     print('after reshape1: ', filtersData.shape)    # e.g., (5, 4, 10, 4, 3)
55 |     # filtersData = filtersData.reshape((8, 8) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
56 |     filtersData = filtersData.reshape((5 * filtersData.shape[1], 10 * filtersData.shape[3]) + filtersData.shape[4:])
57 |     print('after reshape2: ', filtersData.shape)    # e.g., (20, 40, 3)
58 |     # filtersData = filtersData.reshape((8 * filtersData.shape[1], 8 * filtersData.shape[3]) + filtersData.shape[4:])
59 | 
60 |     plt.imshow(filtersData)
61 |     plt.axis('off')
62 |     plt.savefig(output_figName, bbox_inches = 'tight')
63 | 
64 | def get_filtersData(checkpoint_file):
65 |     '''get the filters data from a checkpoint file.'''
66 |     checkpoint = torch.load(checkpoint_file)
67 |     stateDict = checkpoint['state_dict']
68 |     ## debug
69 |     # for k, v in stateDict.items():
70 |     #     print(k)
71 |     conv1_filters = stateDict['feature.0.weight']
72 |     conv1_filters = conv1_filters.cpu().numpy()   # if no `.cpu()`: RuntimeError: can't convert CUDA tensor to numpy (it doesn't support GPU arrays). Use .cpu() to move the tensor to host memory first.
73 |     conv1_filters = conv1_filters.transpose(0, 2, 3, 1)
74 |     # print(conv1_filters.shape)   # (96, 11, 11, 12)
75 |     return conv1_filters
76 | 
77 | def visualize_layer2(filtersData, output_figName):
78 |     '''visualize a second-layer kernel, e.g. A.2.weight or Ahat.2.weight.'''
79 |     filtersData = np.squeeze(filtersData)
80 |     print('after squeeze: ', filtersData.shape)
81 | 
82 |     # normalize filtersData for display
83 |     filtersData = (filtersData - filtersData.min()) / (filtersData.max() - filtersData.min())
84 | 
85 |     sumedData = np.sum(filtersData, axis = 3)
86 |     flat = sumedData.reshape(sumedData.shape[0], sumedData.shape[1] * sumedData.shape[2])
87 |     std = np.std(flat, axis = 1)
88 |     order = np.argsort(std)
89 |     # filterNum = int(order.shape[0] - (order.shape[0] % 10))
90 |     sortedData = np.zeros(filtersData.shape)
91 |     for i in range(filtersData.shape[0]):
92 |         sortedData[i, :, :, :] = filtersData[order[i], :, :, :]
93 |     filtersData = sortedData
94 |     print('after sorting: ', filtersData.shape)
95 | 
96 |     filters_num = filtersData.shape[0]
97 |     # force the number of filters to be square
98 |     n = int(np.ceil(np.sqrt(filters_num)))   # unused: the grid below is hard-coded to 3 x 16
99 |     # add some space between filters
100 |     padding = (((0, 0), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
101 |     # padding = (((0, 64 - filters_num), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
102 |     print(padding)   # ((0, 0), (0, 1), (0, 1), (0, 0))
103 |     filtersData = np.pad(filtersData, padding, mode = 'constant', constant_values = 1)   # pad with ones (white)
104 |     print('after padding: ', filtersData.shape)     # e.g., (48, 13, 13, 3)
105 |     # tile the filters into an image
106 |     filtersData = filtersData.reshape((3, 16) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
107 |     print('after reshape1: ', filtersData.shape)    # e.g., (3, 13, 16, 13, 3)
108 |     # filtersData = filtersData.reshape((8, 8) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
109 |     filtersData = filtersData.reshape((3 * filtersData.shape[1], 16 * filtersData.shape[3]) + filtersData.shape[4:])
110 |     print('after reshape2: ', filtersData.shape)    # e.g., (39, 208, 3)
111 |     # filtersData = filtersData.reshape((8 * filtersData.shape[1], 8 * filtersData.shape[3]) + filtersData.shape[4:])
112 | 
113 |     plt.imshow(filtersData)
114 |     plt.axis('off')
115 |     plt.savefig(output_figName, bbox_inches = 'tight')
116 | 
117 | 
118 | 
119 | if __name__ == '__main__':
120 |     state_dict_file = './model_data_keras2/preTrained_weights_forPyTorch.pkl'
121 |     stateDict = torch.load(state_dict_file)
122 |     modules = ['A', 'Ahat', 'c', 'f', 'i', 'o']
123 |     # for m in modules:
124 |     #     # kernel = stateDict[m + '.0.weight'].cpu().numpy()
125 |     #     kernel = stateDict[m + '.0.weight'].cpu()
126 |     #     # print(kernel.shape)
127 |     #     # A: (48, 6, 3, 3)
128 |     #     # Ahat: (3, 3, 3, 3)
129 |     #     # c, f, i, o: (3, 57, 3, 3)
130 |     #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 2, mode = 'nearest')
131 |     #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 4, mode = 'nearest')
132 |     #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 2, mode = 'bilinear')
133 |     #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 4, mode = 'bilinear')
134 |     #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 2, mode = 'linear')   # not usable here: 'linear' mode only accepts 3D input
135 |     #     print(kernel.data.size())
136 |     #     kernel = kernel.data.numpy()
137 |     #     kernel = np.transpose(kernel, (1, 2, 3, 0))
138 |     #     if m in ['c', 'f', 'i', 'o']:
139 |     #         visualize(kernel, './conv1_filters/' + m + '.png')
140 | 
141 |     # kernel = stateDict['A.2.weight'].cpu()      # (96, 96, 3, 3)
142 |     kernel = stateDict['Ahat.2.weight'].cpu()     # (48, 48, 3, 3)
143 |     kernel = F.upsample(input = Variable(kernel), scale_factor = 4, mode = 'bilinear')
144 |     kernel = kernel.data.numpy()
145 |     kernel = np.transpose(kernel, (1, 2, 3, 0))[..., :3]   # the transposed kernel has 48 'RGB channels' (96 for A.2), which cannot be shown as an image, so only the first three are kept
146 |     print('before calling visualization func: ', kernel.shape)
147 |     visualize_layer2(kernel, './conv1_filters/Ahat.2.kernel.png')
--------------------------------------------------------------------------------
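For reference, a minimal usage sketch (not a file in this repository) of how the helpers above fit together. It mirrors the commented-out gate-kernel loop in `visualization.py`, and assumes the converted weight file exists at the path used there and that the code is run from the repository root so `visualization` is importable.

```python
# Hypothetical sketch: tile one set of PredNet ConvLSTM gate kernels with visualize().
import os
import numpy as np
import torch

from visualization import visualize   # tiling helper defined above

os.makedirs('./conv1_filters', exist_ok=True)   # visualize() writes its figure here

# converted Keras -> PyTorch PredNet weights (path taken from visualization.py)
state_dict = torch.load('./model_data_keras2/preTrained_weights_forPyTorch.pkl',
                        map_location='cpu')

kernel = state_dict['i.0.weight'].cpu().numpy()   # input-gate conv kernel, (3, 57, 3, 3) per the comments above
kernel = np.transpose(kernel, (1, 2, 3, 0))       # -> (57, 3, 3, 3): treat the 57 input channels as "filters"
# visualize() trims 57 -> 50 filters by variance and tiles them on a fixed 5 x 10 grid
visualize(kernel, './conv1_filters/i.png')
```

Because the 5 x 10 grid in `visualize()` is hard-coded, this only works for kernels whose filter count trims to 50; other layers go through `visualize_layer2()` or need the grid adjusted.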