├── src ├── __init__.py ├── layers │ ├── __init__.py │ ├── TensorLayerNorm.py │ ├── MIMN.py │ ├── SpatioTemporalLSTMCellv2.py │ └── MIMBlock.py ├── models │ ├── __init__.py │ ├── mim.py │ └── model_factory.py ├── utils │ ├── __init__.py │ ├── metrics.py │ ├── optimizer.py │ └── preprocess.py ├── data_provider │ ├── __init__.py │ ├── datasets_factory.py │ ├── mnist.py │ ├── human.py │ └── taxibj.py └── trainer.py ├── data └── human36m.sh ├── README.md └── run.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/human36m.sh: -------------------------------------------------------------------------------- 1 | # Download H36M images 2 | mkdir human 3 | cd human 4 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S1.tar 5 | tar -xf S1.tar 6 | rm S1.tar 7 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S5.tar 8 | tar -xf S5.tar 9 | rm S5.tar 10 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S6.tar 11 | tar -xf S6.tar 12 | rm S6.tar 13 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S7.tar 14 | tar -xf S7.tar 15 | rm S7.tar 16 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S8.tar 17 | tar -xf S8.tar 18 | rm S8.tar 19 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S9.tar 20 | tar -xf S9.tar 21 | rm S9.tar 22 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S11.tar 23 | tar -xf S11.tar 24 | rm S11.tar 25 | cd .. 26 | -------------------------------------------------------------------------------- /src/layers/TensorLayerNorm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | EPSILON = 0.00001 4 | 5 | 6 | def tensor_layer_norm(x, state_name): 7 | x_shape = x.get_shape() 8 | dims = x_shape.ndims 9 | params_shape = x_shape[-1:] 10 | if dims == 4: 11 | m, v = tf.nn.moments(x, [1,2,3], keep_dims=True) 12 | elif dims == 5: 13 | m, v = tf.nn.moments(x, [1,2,3,4], keep_dims=True) 14 | elif dims == 2: 15 | m, v = tf.nn.moments(x, [1], keep_dims=True) 16 | else: 17 | raise ValueError('input tensor for layer normalization must be rank 4 or 5.') 18 | b = tf.get_variable(state_name+'b',initializer=tf.zeros(params_shape)) 19 | s = tf.get_variable(state_name+'s',initializer=tf.ones(params_shape)) 20 | x_tln = tf.nn.batch_normalization(x, m, v, b, s, EPSILON) 21 | return x_tln 22 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import numpy as np 4 | from scipy.signal import convolve2d 5 | 6 | 7 | def batch_mae_frame_float(gen_frames, gt_frames): 8 | # [batch, width, height] or [batch, width, height, channel] 9 | if gen_frames.ndim == 3: 10 | axis = (1, 2) 11 | elif gen_frames.ndim == 4: 12 | axis = (1, 2, 3) 13 | x = np.float32(gen_frames) 14 | y = np.float32(gt_frames) 15 | mae = np.sum(np.absolute(x - y), axis=axis, dtype=np.float32) 16 | return np.mean(mae) 17 | 18 | 19 | def batch_psnr(gen_frames, gt_frames): 20 | # [batch, width, height] or [batch, width, height, channel] 21 | if gen_frames.ndim == 3: 22 | axis = (1, 2) 23 | elif gen_frames.ndim == 4: 24 | axis = (1, 2, 3) 25 | x = np.int32(gen_frames) 26 | y = np.int32(gt_frames) 27 | num_pixels = float(np.size(gen_frames[0])) 28 | mse = np.sum((x - y) ** 2, axis=axis, dtype=np.float32) / num_pixels 29 | psnr = 20 * np.log10(255) - 10 * np.log10(mse) 30 | return np.mean(psnr) -------------------------------------------------------------------------------- /src/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999): 5 | updates = [] 6 | if type(cost_or_grads) is not list: 7 | grads = tf.gradients(cost_or_grads, params) 8 | else: 9 | grads = cost_or_grads 10 | t = tf.Variable(1., 'adam_t') 11 | for p, g in zip(params, grads): 12 | mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg') 13 | if mom1 > 0: 14 | v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v') 15 | v_t = mom1 * v + (1. - mom1) * g 16 | v_hat = v_t / (1. - tf.pow(mom1, t)) 17 | updates.append(v.assign(v_t)) 18 | else: 19 | v_hat = g 20 | mg_t = mom2 * mg + (1. - mom2) * tf.square(g) 21 | mg_hat = mg_t / (1. - tf.pow(mom2, t)) 22 | g_t = v_hat / tf.sqrt(mg_hat + 1e-8) 23 | p_t = p - lr * g_t 24 | updates.append(mg.assign(mg_t)) 25 | updates.append(p.assign(p_t)) 26 | updates.append(t.assign_add(1)) 27 | return tf.group(*updates) 28 | 29 | -------------------------------------------------------------------------------- /src/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import numpy as np 4 | 5 | 6 | def reshape_patch(img_tensor, patch_size): 7 | assert 5 == img_tensor.ndim 8 | batch_size = np.shape(img_tensor)[0] 9 | seq_length = np.shape(img_tensor)[1] 10 | img_height = np.shape(img_tensor)[2] 11 | img_width = np.shape(img_tensor)[3] 12 | num_channels = np.shape(img_tensor)[4] 13 | a = np.reshape(img_tensor, [batch_size, seq_length, 14 | img_height//patch_size, patch_size, 15 | img_width//patch_size, patch_size, 16 | num_channels]) 17 | b = np.transpose(a, [0,1,2,4,3,5,6]) 18 | patch_tensor = np.reshape(b, [batch_size, seq_length, 19 | img_height//patch_size, 20 | img_width//patch_size, 21 | patch_size*patch_size*num_channels]) 22 | return patch_tensor 23 | 24 | 25 | def reshape_patch_back(patch_tensor, patch_size): 26 | assert 5 == patch_tensor.ndim 27 | batch_size = np.shape(patch_tensor)[0] 28 | seq_length = np.shape(patch_tensor)[1] 29 | patch_height = np.shape(patch_tensor)[2] 30 | patch_width = np.shape(patch_tensor)[3] 31 | channels = np.shape(patch_tensor)[4] 32 | img_channels = channels // (patch_size*patch_size) 33 | a = np.reshape(patch_tensor, [batch_size, seq_length, 34 | patch_height, patch_width, 35 | patch_size, patch_size, 36 | img_channels]) 37 | b = np.transpose(a, [0,1,2,4,3,5,6]) 38 | img_tensor = np.reshape(b, [batch_size, seq_length, 39 | patch_height * patch_size, 40 | patch_width * patch_size, 41 | img_channels]) 42 | return img_tensor 43 | 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Memory In Memory Networks 2 | 3 | MIM is a neural network for video prediction and spatiotemporal modeling. It is based on the paper [Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics](https://arxiv.org/pdf/1811.07490.pdf) to be presented at CVPR 2019. 4 | 5 | ## Abstract 6 | 7 | Natural spatiotemporal processes can be highly non-stationary in many ways, e.g. the low-level non-stationarity such as spatial correlations or temporal dependencies of local pixel values; and the high-level non-stationarity such as the accumulation, deformation or dissipation of radar echoes in precipitation forecasting. 8 | 9 | We try to stationalize and approximate the non-stationary processes by modeling the differential signals with the MIM recurrent blocks. By stacking multiple MIM blocks, we could potentially handle higher-order non-stationarity. Our model achieves the state-of-the-art results on three spatiotemporal prediction tasks across both synthetic and real-world data. 10 | 11 | ![model](https://github.com/ZJianjin/mim_images/blob/master/readme_structure.png) 12 | 13 | ## Pre-trained Models and Datasets 14 | 15 | All pre-trained MIM models have been uploaded to [DROPBOX](https://www.dropbox.com/s/7kd82ijezk4lkmp/mim-lib.zip?dl=0) and [BAIDU YUN](https://pan.baidu.com/s/1O07H7l1NTWmAkx3UCDVMLA) (password: srhv). 16 | 17 | It also includes our pre-processed training/testing data for Moving MNIST, Color-Changing Moving MNIST, and TaxiBJ. 18 | 19 | For Human3.6M, you may download it using data/human36m.sh. 20 | 21 | ## Generation Results 22 | 23 | #### Moving MNIST 24 | 25 | ![mnist1](https://github.com/ZJianjin/mim_images/blob/master/mnist1.gif) 26 | 27 | ![mnist2](https://github.com/ZJianjin/mim_images/blob/master/mnist4.gif) 28 | 29 | ![mnist2](https://github.com/ZJianjin/mim_images/blob/master/mnist5.gif) 30 | 31 | #### Color-Changing Moving MNIST 32 | 33 | ![mnistc1](https://github.com/ZJianjin/mim_images/blob/master/mnistc2.gif) 34 | 35 | ![mnistc2](https://github.com/ZJianjin/mim_images/blob/master/mnistc3.gif) 36 | 37 | ![mnistc2](https://github.com/ZJianjin/mim_images/blob/master/mnistc4.gif) 38 | 39 | #### Radar Echos 40 | 41 | ![radar1](https://github.com/ZJianjin/mim_images/blob/master/radar9.gif) 42 | 43 | ![radar2](https://github.com/ZJianjin/mim_images/blob/master/radar3.gif) 44 | 45 | ![radar3](https://github.com/ZJianjin/mim_images/blob/master/radar7.gif) 46 | 47 | #### Human3.6M 48 | 49 | ![human1](https://github.com/ZJianjin/mim_images/blob/master/human3.gif) 50 | 51 | ![human2](https://github.com/ZJianjin/mim_images/blob/master/human5.gif) 52 | 53 | ![human3](https://github.com/ZJianjin/mim_images/blob/master/human10.gif) 54 | 55 | ## BibTeX 56 | ``` 57 | @article{wang2018memory, 58 | title={Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics}, 59 | author={Wang, Yunbo and Zhang, Jianjin and Zhu, Hongyu and Long, Mingsheng and Wang, Jianmin and Yu, Philip S}, 60 | journal={arXiv preprint arXiv:1811.07490}, 61 | year={2019} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /src/layers/MIMN.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.layers.TensorLayerNorm import tensor_layer_norm 3 | 4 | class MIMN(): 5 | def __init__(self, layer_name, filter_size, num_hidden, seq_shape, tln=True, initializer=0.001): 6 | """Initialize the basic Conv LSTM cell. 7 | Args: 8 | layer_name: layer names for different convlstm layers. 9 | filter_size: int tuple thats the height and width of the filter. 10 | num_hidden: number of units in output tensor. 11 | tln: whether to apply tensor layer normalization. 12 | """ 13 | self.layer_name = layer_name 14 | self.filter_size = filter_size 15 | self.num_hidden = num_hidden 16 | self.layer_norm = tln 17 | self.batch = seq_shape[0] 18 | self.height = seq_shape[2] 19 | self.width = seq_shape[3] 20 | self._forget_bias = 1.0 21 | if initializer == -1: 22 | self.initializer = None 23 | else: 24 | self.initializer = tf.random_uniform_initializer(-initializer,initializer) 25 | 26 | def init_state(self): 27 | shape = [self.batch, self.height, self.width, self.num_hidden] 28 | return tf.zeros(shape, dtype=tf.float32) 29 | 30 | def __call__(self, x, h_t, c_t): 31 | if h_t is None: 32 | h_t = self.init_state() 33 | if c_t is None: 34 | c_t = self.init_state() 35 | with tf.variable_scope(self.layer_name): 36 | h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4, 37 | self.filter_size, 1, padding='same', 38 | kernel_initializer=self.initializer, 39 | name='state_to_state') 40 | if self.layer_norm: 41 | h_concat = tensor_layer_norm(h_concat, 'state_to_state') 42 | i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3) 43 | 44 | ct_weight = tf.get_variable( 45 | 'c_t_weight', [self.height,self.width,self.num_hidden*2]) 46 | ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight) 47 | i_c, f_c = tf.split(ct_activation, 2, 3) 48 | 49 | i_ = i_h + i_c 50 | f_ = f_h + f_c 51 | g_ = g_h 52 | o_ = o_h 53 | 54 | if x != None: 55 | x_concat = tf.layers.conv2d(x, self.num_hidden * 4, 56 | self.filter_size, 1, 57 | padding='same', 58 | kernel_initializer=self.initializer, 59 | name='input_to_state') 60 | if self.layer_norm: 61 | x_concat = tensor_layer_norm(x_concat, 'input_to_state') 62 | i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3) 63 | 64 | i_ += i_x 65 | f_ += f_x 66 | g_ += g_x 67 | o_ += o_x 68 | 69 | i_ = tf.nn.sigmoid(i_) 70 | f_ = tf.nn.sigmoid(f_ + self._forget_bias) 71 | c_new = f_ * c_t + i_ * tf.nn.tanh(g_) 72 | 73 | oc_weight = tf.get_variable( 74 | 'oc_weight', [self.height,self.width,self.num_hidden]) 75 | o_c = tf.multiply(c_new, oc_weight) 76 | 77 | h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new) 78 | 79 | return h_new, c_new 80 | 81 | -------------------------------------------------------------------------------- /src/layers/SpatioTemporalLSTMCellv2.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | from src.layers.TensorLayerNorm import tensor_layer_norm 5 | 6 | class SpatioTemporalLSTMCell(): 7 | def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden, 8 | seq_shape, tln=False, initializer=None): 9 | """Initialize the basic Conv LSTM cell. 10 | Args: 11 | layer_name: layer names for different convlstm layers. 12 | filter_size: int tuple thats the height and width of the filter. 13 | num_hidden: number of units in output tensor. 14 | forget_bias: float, The bias added to forget gates (see above). 15 | tln: whether to apply tensor layer normalization 16 | """ 17 | self.layer_name = layer_name 18 | self.filter_size = filter_size 19 | self.num_hidden_in = num_hidden_in 20 | self.num_hidden = num_hidden 21 | self.batch = seq_shape[0] 22 | self.height = seq_shape[2] 23 | self.width = seq_shape[3] 24 | self.layer_norm = tln 25 | self._forget_bias = 1.0 26 | 27 | def w_initializer(dim_in, dim_out): 28 | random_range = math.sqrt(6.0 / (dim_in + dim_out)) 29 | return tf.random_uniform_initializer(-random_range, random_range) 30 | if initializer is None or initializer == -1: 31 | self.initializer = w_initializer 32 | else: 33 | self.initializer = tf.random_uniform_initializer(-initializer, initializer) 34 | 35 | def init_state(self): 36 | return tf.zeros([self.batch, self.height, self.width, self.num_hidden], 37 | dtype=tf.float32) 38 | 39 | def __call__(self, x, h, c, m): 40 | if h is None: 41 | h = self.init_state() 42 | if c is None: 43 | c = self.init_state() 44 | if m is None: 45 | m = self.init_state() 46 | 47 | with tf.variable_scope(self.layer_name): 48 | t_cc = tf.layers.conv2d( 49 | h, self.num_hidden*4, 50 | self.filter_size, 1, padding='same', 51 | kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4), 52 | name='time_state_to_state') 53 | s_cc = tf.layers.conv2d( 54 | m, self.num_hidden*4, 55 | self.filter_size, 1, padding='same', 56 | kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4), 57 | name='spatio_state_to_state') 58 | x_shape_in = x.get_shape().as_list()[-1] 59 | x_cc = tf.layers.conv2d( 60 | x, self.num_hidden*4, 61 | self.filter_size, 1, padding='same', 62 | kernel_initializer=self.initializer(x_shape_in, self.num_hidden*4), 63 | name='input_to_state') 64 | if self.layer_norm: 65 | t_cc = tensor_layer_norm(t_cc, 'time_state_to_state') 66 | s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state') 67 | x_cc = tensor_layer_norm(x_cc, 'input_to_state') 68 | 69 | i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3) 70 | i_t, g_t, f_t, o_t = tf.split(t_cc, 4, 3) 71 | i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3) 72 | 73 | i = tf.nn.sigmoid(i_x + i_t) 74 | i_ = tf.nn.sigmoid(i_x + i_s) 75 | g = tf.nn.tanh(g_x + g_t) 76 | g_ = tf.nn.tanh(g_x + g_s) 77 | f = tf.nn.sigmoid(f_x + f_t + self._forget_bias) 78 | f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias) 79 | o = tf.nn.sigmoid(o_x + o_t + o_s) 80 | new_m = f_ * m + i_ * g_ 81 | new_c = f * c + i * g 82 | cell = tf.concat([new_c, new_m],3) 83 | cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1, padding='same', 84 | kernel_initializer=self.initializer(self.num_hidden*2, self.num_hidden), 85 | name='cell_reduce') 86 | new_h = o * tf.nn.tanh(cell) 87 | 88 | return new_h, new_c, new_m 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /src/data_provider/datasets_factory.py: -------------------------------------------------------------------------------- 1 | from src.data_provider import mnist, human, taxibj 2 | 3 | datasets_map = { 4 | 'mnist': mnist, 5 | 'taxibj': taxibj, 6 | 'human': human 7 | } 8 | 9 | 10 | def data_provider(dataset_name, train_data_paths, valid_data_paths, batch_size, 11 | img_width, seq_length=20, is_training=True): 12 | '''Given a dataset name and returns a Dataset. 13 | Args: 14 | dataset_name: String, the name of the dataset. 15 | train_data_paths: List, [train_data_path1, train_data_path2...] 16 | valid_data_paths: List, [val_data_path1, val_data_path2...] 17 | batch_size: Int 18 | img_width: Int 19 | is_training: Bool 20 | Returns: 21 | if is_training: 22 | Two dataset instances for both training and evaluation. 23 | else: 24 | One dataset instance for evaluation. 25 | Raises: 26 | ValueError: If `dataset_name` is unknown. 27 | ''' 28 | if dataset_name not in datasets_map: 29 | raise ValueError('Name of dataset unknown %s' % dataset_name) 30 | train_data_list = train_data_paths.split(',') 31 | valid_data_list = valid_data_paths.split(',') 32 | if dataset_name == 'mnist': 33 | test_input_param = {'paths': valid_data_list, 34 | 'minibatch_size': batch_size, 35 | 'input_data_type': 'float32', 36 | 'is_output_sequence': True, 37 | 'name': dataset_name + 'test iterator'} 38 | test_input_handle = datasets_map[dataset_name].InputHandle(test_input_param) 39 | test_input_handle.begin(do_shuffle=False) 40 | if is_training: 41 | train_input_param = {'paths': train_data_list, 42 | 'minibatch_size': batch_size, 43 | 'input_data_type': 'float32', 44 | 'is_output_sequence': True, 45 | 'name': dataset_name + ' train iterator'} 46 | train_input_handle = datasets_map[dataset_name].InputHandle(train_input_param) 47 | train_input_handle.begin(do_shuffle=True) 48 | return train_input_handle, test_input_handle 49 | else: 50 | return test_input_handle 51 | 52 | if dataset_name == 'human': 53 | input_param = {'paths': valid_data_list, 54 | 'image_width': img_width, 55 | 'minibatch_size': batch_size, 56 | 'seq_length': seq_length, 57 | 'channel': 3, 58 | 'input_data_type': 'float32', 59 | 'name': 'human'} 60 | input_handle = datasets_map[dataset_name].DataProcess(input_param) 61 | test_input_handle = input_handle.get_test_input_handle() 62 | test_input_handle.begin(do_shuffle=False) 63 | if is_training: 64 | train_input_handle = input_handle.get_train_input_handle() 65 | train_input_handle.begin(do_shuffle=True) 66 | return train_input_handle, test_input_handle 67 | else: 68 | return test_input_handle 69 | 70 | if dataset_name == 'taxibj': 71 | input_param = {'paths': valid_data_list, 72 | 'image_width': img_width, 73 | 'minibatch_size': batch_size, 74 | 'seq_length': seq_length, 75 | 'input_data_type': 'float32', 76 | 'name': dataset_name + ' iterator'} 77 | input_handle = datasets_map[dataset_name].DataProcess(input_param) 78 | if is_training: 79 | train_input_handle = input_handle.get_train_input_handle() 80 | train_input_handle.begin(do_shuffle=True) 81 | test_input_handle = input_handle.get_test_input_handle() 82 | test_input_handle.begin(do_shuffle=False) 83 | return train_input_handle, test_input_handle 84 | else: 85 | test_input_handle = input_handle.get_test_input_handle() 86 | test_input_handle.begin(do_shuffle=False) 87 | return test_input_handle 88 | -------------------------------------------------------------------------------- /src/models/mim.py: -------------------------------------------------------------------------------- 1 | __author__ = 'jianjin' 2 | 3 | import tensorflow as tf 4 | from src.layers.SpatioTemporalLSTMCellv2 import SpatioTemporalLSTMCell as stlstm 5 | from src.layers.MIMBlock import MIMBlock as mimblock 6 | from src.layers.MIMN import MIMN as mimn 7 | import math 8 | 9 | 10 | def w_initializer(dim_in, dim_out): 11 | random_range = math.sqrt(6.0 / (dim_in + dim_out)) 12 | return tf.random_uniform_initializer(-random_range, random_range) 13 | 14 | 15 | def mim(images, params, schedual_sampling_bool, num_layers, num_hidden, filter_size, 16 | stride=1, total_length=20, input_length=10, tln=True): 17 | gen_images = [] 18 | stlstm_layer = [] 19 | stlstm_layer_diff = [] 20 | cell_state = [] 21 | hidden_state = [] 22 | cell_state_diff = [] 23 | hidden_state_diff = [] 24 | shape = images.get_shape().as_list() 25 | output_channels = shape[-1] 26 | 27 | for i in range(num_layers): 28 | if i == 0: 29 | num_hidden_in = num_hidden[num_layers - 1] 30 | else: 31 | num_hidden_in = num_hidden[i - 1] 32 | if i < 1: 33 | new_stlstm_layer = stlstm('stlstm_' + str(i + 1), 34 | filter_size, 35 | num_hidden_in, 36 | num_hidden[i], 37 | shape, 38 | tln=tln) 39 | else: 40 | new_stlstm_layer = mimblock('stlstm_' + str(i + 1), 41 | filter_size, 42 | num_hidden_in, 43 | num_hidden[i], 44 | shape, 45 | tln=tln) 46 | stlstm_layer.append(new_stlstm_layer) 47 | cell_state.append(None) 48 | hidden_state.append(None) 49 | 50 | for i in range(num_layers - 1): 51 | new_stlstm_layer = mimn('stlstm_diff' + str(i + 1), 52 | filter_size, 53 | num_hidden[i + 1], 54 | shape, 55 | tln=tln) 56 | stlstm_layer_diff.append(new_stlstm_layer) 57 | cell_state_diff.append(None) 58 | hidden_state_diff.append(None) 59 | 60 | st_memory = None 61 | 62 | for time_step in range(total_length - 1): 63 | reuse = bool(gen_images) 64 | with tf.variable_scope('predrnn', reuse=reuse): 65 | if time_step < input_length: 66 | x_gen = images[:,time_step] 67 | else: 68 | x_gen = schedual_sampling_bool[:,time_step-input_length]*images[:,time_step] + \ 69 | (1-schedual_sampling_bool[:,time_step-input_length])*x_gen 70 | preh = hidden_state[0] 71 | hidden_state[0], cell_state[0], st_memory = stlstm_layer[0]( 72 | x_gen, hidden_state[0], cell_state[0], st_memory) 73 | for i in range(1, num_layers): 74 | if time_step > 0: 75 | if i == 1: 76 | hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1]( 77 | hidden_state[i - 1] - preh, hidden_state_diff[i - 1], cell_state_diff[i - 1]) 78 | else: 79 | hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1]( 80 | hidden_state_diff[i - 2], hidden_state_diff[i - 1], cell_state_diff[i - 1]) 81 | else: 82 | stlstm_layer_diff[i - 1](tf.zeros_like(hidden_state[i - 1]), None, None) 83 | preh = hidden_state[i] 84 | hidden_state[i], cell_state[i], st_memory = stlstm_layer[i]( 85 | hidden_state[i - 1], hidden_state_diff[i - 1], hidden_state[i], cell_state[i], st_memory) 86 | x_gen = tf.layers.conv2d(hidden_state[num_layers - 1], 87 | filters=output_channels, 88 | kernel_size=1, 89 | strides=1, 90 | padding='same', 91 | kernel_initializer=w_initializer(num_hidden[num_layers - 1], output_channels), 92 | name="back_to_pixel") 93 | gen_images.append(x_gen) 94 | 95 | gen_images = tf.stack(gen_images, axis=1) 96 | loss = tf.nn.l2_loss(gen_images - images[:, 1:]) 97 | 98 | return [gen_images, loss] 99 | -------------------------------------------------------------------------------- /src/layers/MIMBlock.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.layers.TensorLayerNorm import tensor_layer_norm 3 | import math 4 | 5 | 6 | class MIMBlock(): 7 | def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden, 8 | seq_shape, tln=False, initializer=None): 9 | """Initialize the basic Conv LSTM cell. 10 | Args: 11 | layer_name: layer names for different convlstm layers. 12 | filter_size: int tuple thats the height and width of the filter. 13 | num_hidden: number of units in output tensor. 14 | forget_bias: float, The bias added to forget gates (see above). 15 | tln: whether to apply tensor layer normalization 16 | """ 17 | self.layer_name = layer_name 18 | self.filter_size = filter_size 19 | self.num_hidden_in = num_hidden_in 20 | self.num_hidden = num_hidden 21 | self.convlstm_c = None 22 | self.batch = seq_shape[0] 23 | self.height = seq_shape[2] 24 | self.width = seq_shape[3] 25 | self.layer_norm = tln 26 | self._forget_bias = 1.0 27 | 28 | def w_initializer(dim_in, dim_out): 29 | random_range = math.sqrt(6.0 / (dim_in + dim_out)) 30 | return tf.random_uniform_initializer(-random_range, random_range) 31 | if initializer is None or initializer == -1: 32 | self.initializer = w_initializer 33 | else: 34 | self.initializer = tf.random_uniform_initializer(-initializer, initializer) 35 | 36 | def init_state(self): 37 | return tf.zeros([self.batch, self.height, self.width, self.num_hidden], 38 | dtype=tf.float32) 39 | 40 | def MIMS(self, x, h_t, c_t): 41 | if h_t is None: 42 | h_t = self.init_state() 43 | if c_t is None: 44 | c_t = self.init_state() 45 | with tf.variable_scope(self.layer_name): 46 | h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4, 47 | self.filter_size, 1, padding='same', 48 | kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4), 49 | name='state_to_state') 50 | if self.layer_norm: 51 | h_concat = tensor_layer_norm(h_concat, 'state_to_state') 52 | i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3) 53 | 54 | ct_weight = tf.get_variable( 55 | 'c_t_weight', [self.height,self.width,self.num_hidden*2]) 56 | ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight) 57 | i_c, f_c = tf.split(ct_activation, 2, 3) 58 | 59 | i_ = i_h + i_c 60 | f_ = f_h + f_c 61 | g_ = g_h 62 | o_ = o_h 63 | 64 | if x != None: 65 | x_concat = tf.layers.conv2d(x, self.num_hidden * 4, 66 | self.filter_size, 1, 67 | padding='same', 68 | kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4), 69 | name='input_to_state') 70 | if self.layer_norm: 71 | x_concat = tensor_layer_norm(x_concat, 'input_to_state') 72 | i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3) 73 | 74 | i_ += i_x 75 | f_ += f_x 76 | g_ += g_x 77 | o_ += o_x 78 | 79 | i_ = tf.nn.sigmoid(i_) 80 | f_ = tf.nn.sigmoid(f_ + self._forget_bias) 81 | c_new = f_ * c_t + i_ * tf.nn.tanh(g_) 82 | 83 | oc_weight = tf.get_variable( 84 | 'oc_weight', [self.height,self.width,self.num_hidden]) 85 | o_c = tf.multiply(c_new, oc_weight) 86 | 87 | h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new) 88 | 89 | return h_new, c_new 90 | 91 | def __call__(self, x, diff_h, h, c, m): 92 | if h is None: 93 | h = self.init_state() 94 | if c is None: 95 | c = self.init_state() 96 | if m is None: 97 | m = self.init_state() 98 | if diff_h is None: 99 | diff_h = tf.zeros_like(h) 100 | 101 | with tf.variable_scope(self.layer_name): 102 | t_cc = tf.layers.conv2d( 103 | h, self.num_hidden * 3, 104 | self.filter_size, 1, padding='same', 105 | kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 3), 106 | name='time_state_to_state') 107 | s_cc = tf.layers.conv2d( 108 | m, self.num_hidden * 4, 109 | self.filter_size, 1, padding='same', 110 | kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4), 111 | name='spatio_state_to_state') 112 | x_shape_in = x.get_shape().as_list()[-1] 113 | x_cc = tf.layers.conv2d( 114 | x, self.num_hidden * 4, 115 | self.filter_size, 1, padding='same', 116 | kernel_initializer=self.initializer(x_shape_in, self.num_hidden * 4), 117 | name='input_to_state') 118 | if self.layer_norm: 119 | t_cc = tensor_layer_norm(t_cc, 'time_state_to_state') 120 | s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state') 121 | x_cc = tensor_layer_norm(x_cc, 'input_to_state') 122 | 123 | i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3) 124 | i_t, g_t, o_t = tf.split(t_cc, 3, 3) 125 | i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3) 126 | 127 | i = tf.nn.sigmoid(i_x + i_t) 128 | i_ = tf.nn.sigmoid(i_x + i_s) 129 | g = tf.nn.tanh(g_x + g_t) 130 | g_ = tf.nn.tanh(g_x + g_s) 131 | f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias) 132 | o = tf.nn.sigmoid(o_x + o_t + o_s) 133 | new_m = f_ * m + i_ * g_ 134 | c, self.convlstm_c = self.MIMS(diff_h, c, self.convlstm_c) 135 | new_c = c + i * g 136 | cell = tf.concat([new_c, new_m], 3) 137 | cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1, 138 | padding='same', name='cell_reduce') 139 | new_h = o * tf.nn.tanh(cell) 140 | 141 | return new_h, new_c, new_m 142 | -------------------------------------------------------------------------------- /src/data_provider/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | class InputHandle: 5 | def __init__(self, input_param): 6 | self.paths = input_param['paths'] 7 | self.num_paths = len(input_param['paths']) 8 | self.name = input_param['name'] 9 | self.input_data_type = input_param.get('input_data_type', 'float32') 10 | self.output_data_type = input_param.get('output_data_type', 'float32') 11 | self.minibatch_size = input_param['minibatch_size'] 12 | self.is_output_sequence = input_param['is_output_sequence'] 13 | self.data = {} 14 | self.indices = {} 15 | self.current_position = 0 16 | self.current_batch_size = 0 17 | self.current_batch_indices = [] 18 | self.current_input_length = 0 19 | self.current_output_length = 0 20 | self.load() 21 | 22 | def load(self): 23 | dat_1 = np.load(self.paths[0]) 24 | for key in dat_1.keys(): 25 | self.data[key] = dat_1[key] 26 | if self.num_paths == 2: 27 | dat_2 = np.load(self.paths[1]) 28 | num_clips_1 = dat_1['clips'].shape[1] 29 | dat_2['clips'][:,:,0] += num_clips_1 30 | self.data['clips'] = np.concatenate( 31 | (dat_1['clips'], dat_2['clips']), axis=1) 32 | self.data['input_raw_data'] = np.concatenate( 33 | (dat_1['input_raw_data'], dat_2['input_raw_data']), axis=0) 34 | self.data['output_raw_data'] = np.concatenate( 35 | (dat_1['output_raw_data'], dat_2['output_raw_data']), axis=0) 36 | for key in self.data.keys(): 37 | print(key) 38 | print(self.data[key].shape) 39 | 40 | def total(self): 41 | return self.data['clips'].shape[1] 42 | 43 | def begin(self, do_shuffle = True): 44 | self.indices = np.arange(self.total(),dtype="int32") 45 | if do_shuffle: 46 | random.shuffle(self.indices) 47 | self.current_position = 0 48 | if self.current_position + self.minibatch_size <= self.total(): 49 | self.current_batch_size = self.minibatch_size 50 | else: 51 | self.current_batch_size = self.total() - self.current_position 52 | self.current_batch_indices = self.indices[ 53 | self.current_position:self.current_position + self.current_batch_size] 54 | self.current_input_length = max(self.data['clips'][0, ind, 1] for ind 55 | in self.current_batch_indices) 56 | self.current_output_length = max(self.data['clips'][1, ind, 1] for ind 57 | in self.current_batch_indices) 58 | 59 | def next(self): 60 | self.current_position += self.current_batch_size 61 | if self.no_batch_left(): 62 | return None 63 | if self.current_position + self.minibatch_size <= self.total(): 64 | self.current_batch_size = self.minibatch_size 65 | else: 66 | self.current_batch_size = self.total() - self.current_position 67 | self.current_batch_indices = self.indices[ 68 | self.current_position:self.current_position + self.current_batch_size] 69 | self.current_input_length = max(self.data['clips'][0, ind, 1] for ind 70 | in self.current_batch_indices) 71 | self.current_output_length = max(self.data['clips'][1, ind, 1] for ind 72 | in self.current_batch_indices) 73 | 74 | def no_batch_left(self): 75 | if self.current_position >= self.total() - self.current_batch_size: 76 | return True 77 | else: 78 | return False 79 | 80 | def input_batch(self): 81 | if self.no_batch_left(): 82 | return None 83 | input_batch = np.zeros( 84 | (self.current_batch_size, self.current_input_length) + 85 | tuple(self.data['dims'][0])).astype(self.input_data_type) 86 | input_batch = np.transpose(input_batch,(0,1,3,4,2)) 87 | for i in range(self.current_batch_size): 88 | batch_ind = self.current_batch_indices[i] 89 | begin = self.data['clips'][0, batch_ind, 0] 90 | end = self.data['clips'][0, batch_ind, 0] + \ 91 | self.data['clips'][0, batch_ind, 1] 92 | data_slice = self.data['input_raw_data'][begin:end, :, :, :] 93 | data_slice = np.transpose(data_slice,(0,2,3,1)) 94 | input_batch[i, :self.current_input_length, :, :, :] = data_slice 95 | input_batch = input_batch.astype(self.input_data_type) 96 | return input_batch 97 | 98 | def output_batch(self): 99 | if self.no_batch_left(): 100 | return None 101 | if(2 ,3) == self.data['dims'].shape: 102 | raw_dat = self.data['output_raw_data'] 103 | else: 104 | raw_dat = self.data['input_raw_data'] 105 | if self.is_output_sequence: 106 | if (1, 3) == self.data['dims'].shape: 107 | output_dim = self.data['dims'][0] 108 | else: 109 | output_dim = self.data['dims'][1] 110 | output_batch = np.zeros( 111 | (self.current_batch_size,self.current_output_length) + 112 | tuple(output_dim)) 113 | else: 114 | output_batch = np.zeros((self.current_batch_size, ) + 115 | tuple(self.data['dims'][1])) 116 | for i in range(self.current_batch_size): 117 | batch_ind = self.current_batch_indices[i] 118 | begin = self.data['clips'][1, batch_ind, 0] 119 | end = self.data['clips'][1, batch_ind, 0] + \ 120 | self.data['clips'][1, batch_ind, 1] 121 | if self.is_output_sequence: 122 | data_slice = raw_dat[begin:end, :, :, :] 123 | output_batch[i, : data_slice.shape[0], :, :, :] = data_slice 124 | else: 125 | data_slice = raw_dat[begin, :, :, :] 126 | output_batch[i,:, :, :] = data_slice 127 | output_batch = output_batch.astype(self.output_data_type) 128 | output_batch = np.transpose(output_batch, [0,1,3,4,2]) 129 | return output_batch 130 | 131 | def get_batch(self): 132 | input_seq = self.input_batch() 133 | output_seq = self.output_batch() 134 | batch = np.concatenate((input_seq, output_seq), axis=1) 135 | return batch 136 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import datetime 3 | import cv2 4 | import numpy as np 5 | from skimage.measure import compare_ssim 6 | from src.utils import metrics 7 | from src.utils import preprocess 8 | 9 | 10 | def train(model, ims, real_input_flag, configs, itr, ims_reverse=None): 11 | ims = ims[:, :configs.total_length] 12 | ims_list = np.split(ims, configs.n_gpu) 13 | cost = model.train(ims_list, configs.lr, real_input_flag) 14 | 15 | flag = 1 16 | 17 | if configs.reverse_img: 18 | ims_rev = np.split(ims_reverse, configs.n_gpu) 19 | cost += model.train(ims_rev, configs.lr, real_input_flag) 20 | flag += 1 21 | 22 | if configs.reverse_input: 23 | ims_rev = np.split(ims[:, ::-1], configs.n_gpu) 24 | cost += model.train(ims_rev, configs.lr, real_input_flag) 25 | flag += 1 26 | if configs.reverse_img: 27 | ims_rev = np.split(ims_reverse[:, ::-1], configs.n_gpu) 28 | cost += model.train(ims_rev, configs.lr, real_input_flag) 29 | flag += 1 30 | 31 | cost = cost / flag 32 | 33 | if itr % configs.display_interval == 0: 34 | print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr)) 35 | print('training loss: ' + str(cost)) 36 | 37 | 38 | def test(model, test_input_handle, configs, save_name): 39 | print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...') 40 | test_input_handle.begin(do_shuffle=False) 41 | res_path = os.path.join(configs.gen_frm_dir, str(save_name)) 42 | os.mkdir(res_path) 43 | avg_mse = 0 44 | batch_id = 0 45 | img_mse, ssim, psnr, fmae, sharp = [], [], [], [], [] 46 | 47 | for i in range(configs.total_length - configs.input_length): 48 | img_mse.append(0) 49 | ssim.append(0) 50 | psnr.append(0) 51 | fmae.append(0) 52 | sharp.append(0) 53 | 54 | if configs.img_height > 0: 55 | height = configs.img_height 56 | else: 57 | height = configs.img_width 58 | 59 | real_input_flag = np.zeros( 60 | (configs.batch_size, 61 | configs.total_length - configs.input_length - 1, 62 | configs.img_width // configs.patch_size, 63 | height // configs.patch_size, 64 | configs.patch_size ** 2 * configs.img_channel)) 65 | 66 | while not test_input_handle.no_batch_left(): 67 | batch_id = batch_id + 1 68 | if save_name != 'test_result': 69 | if batch_id > 100: break 70 | test_ims = test_input_handle.get_batch() 71 | test_ims = test_ims[:, :configs.total_length] 72 | if len(test_ims.shape) > 3: 73 | test_dat = preprocess.reshape_patch(test_ims, configs.patch_size) 74 | else: 75 | test_dat = test_ims 76 | test_dat = np.split(test_dat, configs.n_gpu) 77 | img_gen, debug = model.test(test_dat, real_input_flag) 78 | 79 | # concat outputs of different gpus along batch 80 | img_gen = np.concatenate(img_gen) 81 | if len(img_gen.shape) > 3: 82 | img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size) 83 | # MSE per frame 84 | for i in range(configs.total_length - configs.input_length): 85 | x = test_ims[:, i + configs.input_length, :, :, :] 86 | x = x[:configs.batch_size * configs.n_gpu] 87 | x = x - np.where(x > 10000, np.floor_divide(x, 10000) * 10000, np.zeros_like(x)) 88 | gx = img_gen[:, i, :, :, :] 89 | fmae[i] += metrics.batch_mae_frame_float(gx, x) 90 | gx = np.maximum(gx, 0) 91 | gx = np.minimum(gx, 1) 92 | mse = np.square(x - gx).sum() 93 | img_mse[i] += mse 94 | avg_mse += mse 95 | real_frm = np.uint8(x * 255) 96 | pred_frm = np.uint8(gx * 255) 97 | psnr[i] += metrics.batch_psnr(pred_frm, real_frm) 98 | for b in range(configs.batch_size): 99 | sharp[i] += np.max( 100 | cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3))) 101 | 102 | score, _ = compare_ssim(gx[b], x[b], full=True, multichannel=True) 103 | ssim[i] += score 104 | 105 | # save prediction examples 106 | if batch_id <= configs.num_save_samples: 107 | path = os.path.join(res_path, str(batch_id)) 108 | os.mkdir(path) 109 | if len(debug) != 0: 110 | np.save(os.path.join(path, "f.npy"), debug) 111 | for i in range(configs.total_length): 112 | name = 'gt' + str(i + 1) + '.png' 113 | file_name = os.path.join(path, name) 114 | img_gt = np.uint8(test_ims[0, i, :, :, :] * 255) 115 | if configs.img_channel == 2: 116 | img_gt = img_gt[:, :, :1] 117 | cv2.imwrite(file_name, img_gt) 118 | for i in range(configs.total_length - configs.input_length): 119 | name = 'pd' + str(i + 1 + configs.input_length) + '.png' 120 | file_name = os.path.join(path, name) 121 | img_pd = img_gen[0, i, :, :, :] 122 | if configs.img_channel == 2: 123 | img_pd = img_pd[:, :, :1] 124 | img_pd = np.maximum(img_pd, 0) 125 | img_pd = np.minimum(img_pd, 1) 126 | img_pd = np.uint8(img_pd * 255) 127 | cv2.imwrite(file_name, img_pd) 128 | test_input_handle.next() 129 | 130 | avg_mse = avg_mse / (batch_id * configs.batch_size * configs.n_gpu) 131 | print('mse per seq: ' + str(avg_mse)) 132 | for i in range(configs.total_length - configs.input_length): 133 | print(img_mse[i] / (batch_id * configs.batch_size * configs.n_gpu)) 134 | 135 | psnr = np.asarray(psnr, dtype=np.float32) / batch_id 136 | fmae = np.asarray(fmae, dtype=np.float32) / batch_id 137 | ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id) 138 | sharp = np.asarray(sharp, dtype=np.float32) / (configs.batch_size * batch_id) 139 | 140 | print('psnr per frame: ' + str(np.mean(psnr))) 141 | for i in range(configs.total_length - configs.input_length): 142 | print(psnr[i]) 143 | print('fmae per frame: ' + str(np.mean(fmae))) 144 | for i in range(configs.total_length - configs.input_length): 145 | print(fmae[i]) 146 | print('ssim per frame: ' + str(np.mean(ssim))) 147 | for i in range(configs.total_length - configs.input_length): 148 | print(ssim[i]) 149 | print('sharpness per frame: ' + str(np.mean(sharp))) 150 | for i in range(configs.total_length - configs.input_length): 151 | print(sharp[i]) 152 | -------------------------------------------------------------------------------- /src/models/model_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from src.utils import optimizer 6 | from src.models import mim 7 | 8 | 9 | class Model(object): 10 | def __init__(self, configs): 11 | self.configs = configs 12 | # inputs 13 | if configs.img_height > 0: 14 | height = configs.img_height 15 | else: 16 | height = configs.img_width 17 | self.x = [tf.placeholder(tf.float32, 18 | [self.configs.batch_size, 19 | self.configs.total_length, 20 | self.configs.img_width // self.configs.patch_size, 21 | height // self.configs.patch_size, 22 | self.configs.patch_size * self.configs.patch_size * self.configs.img_channel]) 23 | for i in range(self.configs.n_gpu)] 24 | 25 | self.real_input_flag = tf.placeholder(tf.float32, 26 | [self.configs.batch_size, 27 | self.configs.total_length - self.configs.input_length - 1, 28 | self.configs.img_width // self.configs.patch_size, 29 | height // self.configs.patch_size, 30 | self.configs.patch_size * self.configs.patch_size * self.configs.img_channel]) 31 | 32 | grads = [] 33 | loss_train = [] 34 | self.pred_seq = [] 35 | self.tf_lr = tf.placeholder(tf.float32, shape=[]) 36 | self.params = dict() 37 | self.params.update(self.configs.__dict__['__flags']) 38 | num_hidden = [int(x) for x in self.configs.num_hidden.split(',')] 39 | num_layers = len(num_hidden) 40 | for i in range(self.configs.n_gpu): 41 | with tf.device('/gpu:%d' % i): 42 | with tf.variable_scope(tf.get_variable_scope(), 43 | reuse=True if i > 0 else None): 44 | # define a model 45 | output_list = self.construct_model( 46 | self.configs.model_name, 47 | self.x[i], 48 | self.params, 49 | self.real_input_flag, 50 | num_layers, 51 | num_hidden, 52 | self.configs.filter_size, 53 | self.configs.stride, 54 | self.configs.total_length, 55 | self.configs.input_length, 56 | self.configs.layer_norm) 57 | 58 | gen_ims = output_list[0] 59 | loss = output_list[1] 60 | if len(output_list) > 2: 61 | self.debug = output_list[2] 62 | else: 63 | self.debug = [] 64 | pred_ims = gen_ims[:, self.configs.input_length - self.configs.total_length:] 65 | loss_train.append(loss / self.configs.batch_size) 66 | # gradients 67 | all_params = tf.trainable_variables() 68 | grads.append(tf.gradients(loss, all_params)) 69 | self.pred_seq.append(pred_ims) 70 | 71 | # add losses and gradients together and get training updates 72 | with tf.device('/gpu:0'): 73 | for i in range(1, self.configs.n_gpu): 74 | loss_train[0] += loss_train[i] 75 | for j in range(len(grads[0])): 76 | grads[0][j] += grads[i][j] 77 | # keep track of moving average 78 | ema = tf.train.ExponentialMovingAverage(decay=0.9995) 79 | maintain_averages_op = tf.group(ema.apply(all_params)) 80 | self.train_op = tf.group(optimizer.adam_updates( 81 | all_params, grads[0], lr=self.tf_lr, mom1=0.95, mom2=0.9995), 82 | maintain_averages_op) 83 | 84 | self.loss_train = loss_train[0] / self.configs.n_gpu 85 | 86 | # session 87 | variables = tf.global_variables() 88 | self.saver = tf.train.Saver(variables) 89 | init = tf.global_variables_initializer() 90 | configProt = tf.ConfigProto() 91 | configProt.gpu_options.allow_growth = configs.allow_gpu_growth 92 | configProt.allow_soft_placement = True 93 | self.sess = tf.Session(config=configProt) 94 | self.sess.run(init) 95 | if self.configs.pretrained_model: 96 | self.saver.restore(self.sess, self.configs.pretrained_model) 97 | 98 | def train(self, inputs, lr, real_input_flag): 99 | feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)} 100 | feed_dict.update({self.tf_lr: lr}) 101 | feed_dict.update({self.real_input_flag: real_input_flag}) 102 | loss, _, debug = self.sess.run((self.loss_train, self.train_op, self.debug), feed_dict) 103 | return loss 104 | 105 | def test(self, inputs, real_input_flag): 106 | feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)} 107 | feed_dict.update({self.real_input_flag: real_input_flag}) 108 | gen_ims, debug = self.sess.run((self.pred_seq, self.debug), feed_dict) 109 | return gen_ims, debug 110 | 111 | def save(self, itr): 112 | checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt') 113 | self.saver.save(self.sess, checkpoint_path, global_step=itr) 114 | print('saved to ' + self.configs.save_dir) 115 | 116 | def load(self, checkpoint_path): 117 | print('load model:', checkpoint_path) 118 | self.saver.restore(self.sess, checkpoint_path) 119 | 120 | def construct_model(self, name, images, model_params, real_input_flag, num_layers, num_hidden, 121 | filter_size, stride, total_length, input_length, tln): 122 | '''Returns a sequence of generated frames 123 | Args: 124 | name: [predrnn_pp] 125 | params: dict for extra parameters of some models 126 | real_input_flag: for schedualed sampling. 127 | num_hidden: number of units in a lstm layer. 128 | filter_size: for convolutions inside lstm. 129 | stride: for convolutions inside lstm. 130 | total_length: including ins and outs. 131 | input_length: for inputs. 132 | tln: whether to apply tensor layer normalization. 133 | Returns: 134 | gen_images: a seq of frames. 135 | loss: [l2 / l1+l2]. 136 | Raises: 137 | ValueError: If network `name` is not recognized. 138 | ''' 139 | 140 | networks_map = { 141 | 'mim': mim.mim, 142 | } 143 | 144 | params = dict(mask=real_input_flag, num_layers=num_layers, num_hidden=num_hidden, filter_size=filter_size, 145 | stride=stride, total_length=total_length, input_length=input_length, is_training=True) 146 | params.update(model_params) 147 | if name in networks_map: 148 | func = networks_map[name] 149 | return func(images, params, real_input_flag, num_layers, num_hidden, filter_size, 150 | stride, total_length, input_length, tln) 151 | else: 152 | raise ValueError('Name of network unknown %s' % name) 153 | -------------------------------------------------------------------------------- /src/data_provider/human.py: -------------------------------------------------------------------------------- 1 | __author__ = 'jianjin' 2 | import numpy as np 3 | import os 4 | import cv2 5 | from PIL import Image 6 | import logging 7 | import random 8 | import tensorflow as tf 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class InputHandle: 13 | def __init__(self, datas, indices, input_param): 14 | self.name = input_param['name'] 15 | self.input_data_type = input_param.get('input_data_type', 'float32') 16 | self.minibatch_size = input_param['minibatch_size'] 17 | self.image_width = input_param['image_width'] 18 | self.channel = input_param['channel'] 19 | self.datas = datas 20 | self.indices = indices 21 | self.current_position = 0 22 | self.current_batch_indices = [] 23 | self.current_input_length = input_param['seq_length'] 24 | self.interval = 2 25 | 26 | def total(self): 27 | return len(self.indices) 28 | 29 | def begin(self, do_shuffle=True): 30 | logger.info("Initialization for read data ") 31 | if do_shuffle: 32 | random.shuffle(self.indices) 33 | self.current_position = 0 34 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 35 | 36 | def next(self): 37 | self.current_position += self.minibatch_size 38 | if self.no_batch_left(): 39 | return None 40 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 41 | 42 | def no_batch_left(self): 43 | if self.current_position + self.minibatch_size > self.total(): 44 | return True 45 | else: 46 | return False 47 | 48 | def get_batch(self): 49 | if self.no_batch_left(): 50 | logger.error( 51 | "There is no batch left in " + self.name + ". Consider to user iterators.begin() to rescan from the beginning of the iterators") 52 | return None 53 | input_batch = np.zeros( 54 | (self.minibatch_size, self.current_input_length, self.image_width, self.image_width, self.channel)).astype( 55 | self.input_data_type) 56 | for i in range(self.minibatch_size): 57 | batch_ind = self.current_batch_indices[i] 58 | begin = batch_ind 59 | end = begin + self.current_input_length * self.interval 60 | data_slice = self.datas[begin:end:self.interval] 61 | input_batch[i, :self.current_input_length, :, :, :] = data_slice 62 | # logger.info('data_slice shape') 63 | # logger.info(data_slice.shape) 64 | # logger.info(input_batch.shape) 65 | input_batch = input_batch.astype(self.input_data_type) 66 | return input_batch 67 | 68 | def print_stat(self): 69 | logger.info("Iterator Name: " + self.name) 70 | logger.info(" current_position: " + str(self.current_position)) 71 | logger.info(" Minibatch Size: " + str(self.minibatch_size)) 72 | logger.info(" total Size: " + str(self.total())) 73 | logger.info(" current_input_length: " + str(self.current_input_length)) 74 | logger.info(" Input Data Type: " + str(self.input_data_type)) 75 | 76 | class DataProcess: 77 | def __init__(self, input_param): 78 | self.input_param = input_param 79 | self.paths = input_param['paths'] 80 | self.image_width = input_param['image_width'] 81 | self.seq_len = input_param['seq_length'] 82 | 83 | def load_data(self, paths, mode='train'): 84 | data_dir = paths[0] 85 | intervel = 2 86 | 87 | frames_np = [] 88 | scenarios = ['Walking'] 89 | if mode == 'train': 90 | subjects = ['S1', 'S5', 'S6', 'S7', 'S8'] 91 | elif mode == 'test': 92 | subjects = ['S9', 'S11'] 93 | else: 94 | print ("MODE ERROR") 95 | _path = data_dir 96 | print ('load data...', _path) 97 | filenames = os.listdir(_path) 98 | filenames.sort() 99 | print ('data size ', len(filenames)) 100 | frames_file_name = [] 101 | for filename in filenames: 102 | fix = filename.split('.') 103 | fix = fix[0] 104 | subject = fix.split('_') 105 | scenario = subject[1] 106 | subject = subject[0] 107 | if subject not in subjects or scenario not in scenarios: 108 | continue 109 | file_path = os.path.join(_path, filename) 110 | image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB) 111 | #[1000,1000,3] 112 | image = image[image.shape[0]//4:-image.shape[0]//4, image.shape[1]//4:-image.shape[1]//4, :] 113 | if self.image_width != image.shape[0]: 114 | image = cv2.resize(image, (self.image_width, self.image_width)) 115 | #image = cv2.resize(image[100:-100,100:-100,:], (self.image_width, self.image_width), 116 | # interpolation=cv2.INTER_LINEAR) 117 | frames_np.append(np.array(image, dtype=np.float32) / 255.0) 118 | frames_file_name.append(filename) 119 | # if len(frames_np) % 100 == 0: print len(frames_np) 120 | #if len(frames_np) % 1000 == 0: break 121 | # is it a begin index of sequence 122 | indices = [] 123 | index = 0 124 | print ('gen index') 125 | while index + intervel * self.seq_len - 1 < len(frames_file_name): 126 | # 'S11_Discussion_1.54138969_000471.jpg' 127 | # ['S11_Discussion_1', '54138969_000471', 'jpg'] 128 | start_infos = frames_file_name[index].split('.') 129 | end_infos = frames_file_name[index+intervel*(self.seq_len-1)].split('.') 130 | if start_infos[0] != end_infos[0]: 131 | index += 1 132 | continue 133 | start_video_id, start_frame_id = start_infos[1].split('_') 134 | end_video_id, end_frame_id = end_infos[1].split('_') 135 | if start_video_id != end_video_id: 136 | index += 1 137 | continue 138 | if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * intervel: 139 | indices.append(index) 140 | if mode == 'train': 141 | index += 10 142 | elif mode == 'test': 143 | index += 5 144 | print("there are " + str(len(indices)) + " sequences") 145 | # data = np.asarray(frames_np) 146 | data = frames_np 147 | print("there are " + str(len(data)) + " pictures") 148 | return data, indices 149 | 150 | def get_train_input_handle(self): 151 | train_data, train_indices = self.load_data(self.paths, mode='train') 152 | return InputHandle(train_data, train_indices, self.input_param) 153 | 154 | def get_test_input_handle(self): 155 | test_data, test_indices = self.load_data(self.paths, mode='test') 156 | return InputHandle(test_data, test_indices, self.input_param) 157 | 158 | 159 | def main(): 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument("input_dir", type=str) 162 | parser.add_argument("output_dir", type=str) 163 | args = parser.parse_args() 164 | 165 | partition_names = ['train', 'test'] 166 | partition_fnames = partition_data(args.input_dir) 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import os 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | from time import time 8 | 9 | from src.data_provider import datasets_factory 10 | from src.models.model_factory import Model 11 | from src.utils import preprocess 12 | import src.trainer as trainer 13 | 14 | # ----------------------------------------------------------------------------- 15 | FLAGS = tf.app.flags.FLAGS 16 | 17 | # os.environ["CUDA_VISIBLE_DEVICES"] = "2" 18 | 19 | # mode 20 | tf.app.flags.DEFINE_boolean('is_training', True, 'training or testing') 21 | 22 | # data I/O 23 | tf.app.flags.DEFINE_string('dataset_name', 'mnist', 24 | 'The name of dataset.') 25 | tf.app.flags.DEFINE_string('train_data_paths', 26 | 'data/moving-mnist-example/moving-mnist-train.npz', 27 | 'train data paths.') 28 | tf.app.flags.DEFINE_string('valid_data_paths', 29 | 'data/moving-mnist-example/moving-mnist-valid.npz', 30 | 'validation data paths.') 31 | tf.app.flags.DEFINE_string('save_dir', 'checkpoints/mnist_predrnn_pp', 32 | 'dir to store trained net.') 33 | tf.app.flags.DEFINE_string('gen_frm_dir', 'results/mnist_predrnn_pp', 34 | 'dir to store result.') 35 | tf.app.flags.DEFINE_integer('input_length', 10, 36 | 'encoder hidden states.') 37 | tf.app.flags.DEFINE_integer('total_length', 20, 38 | 'total input and output length.') 39 | tf.app.flags.DEFINE_integer('img_width', 64, 40 | 'input image width.') 41 | tf.app.flags.DEFINE_integer('img_channel', 1, 42 | 'number of image channel.') 43 | # model[convlstm, predcnn, predrnn, predrnn_pp] 44 | tf.app.flags.DEFINE_string('model_name', 'convlstm_net', 45 | 'The name of the architecture.') 46 | tf.app.flags.DEFINE_string('pretrained_model', '', 47 | 'file of a pretrained model to initialize from.') 48 | tf.app.flags.DEFINE_string('num_hidden', '64,64,64,64', 49 | 'COMMA separated number of units in a convlstm layer.') 50 | tf.app.flags.DEFINE_integer('filter_size', 5, 51 | 'filter of a convlstm layer.') 52 | tf.app.flags.DEFINE_integer('stride', 1, 53 | 'stride of a convlstm layer.') 54 | tf.app.flags.DEFINE_integer('patch_size', 1, 55 | 'patch size on one dimension.') 56 | tf.app.flags.DEFINE_boolean('layer_norm', True, 57 | 'whether to apply tensor layer norm.') 58 | # scheduled sampling 59 | tf.app.flags.DEFINE_boolean('scheduled_sampling', True, 'for scheduled sampling') 60 | tf.app.flags.DEFINE_integer('sampling_stop_iter', 50000, 'for scheduled sampling.') 61 | tf.app.flags.DEFINE_float('sampling_start_value', 1.0, 'for scheduled sampling.') 62 | tf.app.flags.DEFINE_float('sampling_changing_rate', 0.00002, 'for scheduled sampling.') 63 | # optimization 64 | tf.app.flags.DEFINE_float('lr', 0.001, 65 | 'base learning rate.') 66 | tf.app.flags.DEFINE_boolean('reverse_input', True, 67 | 'whether to reverse the input frames while training.') 68 | tf.app.flags.DEFINE_boolean('reverse_img', False, 69 | 'whether to reverse the input images while training.') 70 | tf.app.flags.DEFINE_integer('batch_size', 8, 71 | 'batch size for training.') 72 | tf.app.flags.DEFINE_integer('max_iterations', 80000, 73 | 'max num of steps.') 74 | tf.app.flags.DEFINE_integer('display_interval', 1, 75 | 'number of iters showing training loss.') 76 | tf.app.flags.DEFINE_integer('test_interval', 1000, 77 | 'number of iters for test.') 78 | tf.app.flags.DEFINE_integer('snapshot_interval', 1000, 79 | 'number of iters saving models.') 80 | tf.app.flags.DEFINE_integer('num_save_samples', 10, 81 | 'number of sequences to be saved.') 82 | tf.app.flags.DEFINE_integer('n_gpu', 1, 83 | 'how many GPUs to distribute the training across.') 84 | # gpu 85 | tf.app.flags.DEFINE_boolean('allow_gpu_growth', False, 86 | 'allow gpu growth') 87 | 88 | tf.app.flags.DEFINE_integer('img_height', 0, 89 | 'input image height.') 90 | 91 | 92 | def main(argv=None): 93 | if tf.gfile.Exists(FLAGS.save_dir): 94 | tf.gfile.DeleteRecursively(FLAGS.save_dir) 95 | tf.gfile.MakeDirs(FLAGS.save_dir) 96 | if tf.gfile.Exists(FLAGS.gen_frm_dir): 97 | tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir) 98 | tf.gfile.MakeDirs(FLAGS.gen_frm_dir) 99 | 100 | gpu_list = np.asarray(os.environ.get('CUDA_VISIBLE_DEVICES', '-1').split(',') ,dtype=np.int32) 101 | FLAGS.n_gpu = len(gpu_list) 102 | print('Initializing models') 103 | 104 | model = Model(FLAGS) 105 | 106 | if FLAGS.is_training: 107 | train_wrapper(model) 108 | else: 109 | start = time() 110 | test_wrapper(model) 111 | stop = time() 112 | print("Time used: " + str(stop - start) + "s") 113 | 114 | 115 | def schedule_sampling(eta, itr): 116 | if FLAGS.img_height > 0: 117 | height = FLAGS.img_height 118 | else: 119 | height = FLAGS.img_width 120 | zeros = np.zeros((FLAGS.batch_size, 121 | FLAGS.total_length - FLAGS.input_length - 1, 122 | FLAGS.img_width // FLAGS.patch_size, 123 | height // FLAGS.patch_size, 124 | FLAGS.patch_size ** 2 * FLAGS.img_channel)) 125 | if not FLAGS.scheduled_sampling: 126 | return 0.0, zeros 127 | 128 | if itr < FLAGS.sampling_stop_iter: 129 | eta -= FLAGS.sampling_changing_rate 130 | else: 131 | eta = 0.0 132 | random_flip = np.random.random_sample( 133 | (FLAGS.batch_size, FLAGS.total_length - FLAGS.input_length - 1)) 134 | true_token = (random_flip < eta) 135 | ones = np.ones((FLAGS.img_width // FLAGS.patch_size, 136 | height // FLAGS.patch_size, 137 | FLAGS.patch_size ** 2 * FLAGS.img_channel)) 138 | zeros = np.zeros((FLAGS.img_width // FLAGS.patch_size, 139 | height // FLAGS.patch_size, 140 | FLAGS.patch_size ** 2 * FLAGS.img_channel)) 141 | real_input_flag = [] 142 | for i in range(FLAGS.batch_size): 143 | for j in range(FLAGS.total_length - FLAGS.input_length - 1): 144 | if true_token[i, j]: 145 | real_input_flag.append(ones) 146 | else: 147 | real_input_flag.append(zeros) 148 | real_input_flag = np.array(real_input_flag) 149 | real_input_flag = np.reshape(real_input_flag, 150 | (FLAGS.batch_size, 151 | FLAGS.total_length - FLAGS.input_length - 1, 152 | FLAGS.img_width // FLAGS.patch_size, 153 | height // FLAGS.patch_size, 154 | FLAGS.patch_size ** 2 * FLAGS.img_channel)) 155 | return eta, real_input_flag 156 | 157 | 158 | def train_wrapper(model): 159 | if FLAGS.pretrained_model: 160 | model.load(FLAGS.pretrained_model) 161 | # load data 162 | train_input_handle, test_input_handle = datasets_factory.data_provider( 163 | FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths, 164 | FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=True) 165 | 166 | eta = FLAGS.sampling_start_value 167 | 168 | for itr in range(1, FLAGS.max_iterations + 1): 169 | if train_input_handle.no_batch_left(): 170 | train_input_handle.begin(do_shuffle=True) 171 | ims = train_input_handle.get_batch() 172 | ims_reverse = None 173 | if FLAGS.reverse_img: 174 | ims_reverse = ims[:, :, :, ::-1] 175 | ims_reverse = preprocess.reshape_patch(ims_reverse, FLAGS.patch_size) 176 | ims = preprocess.reshape_patch(ims, FLAGS.patch_size) 177 | 178 | eta, real_input_flag = schedule_sampling(eta, itr) 179 | 180 | trainer.train(model, ims, real_input_flag, FLAGS, itr, ims_reverse) 181 | 182 | if itr % FLAGS.snapshot_interval == 0: 183 | model.save(itr) 184 | 185 | if itr % FLAGS.test_interval == 0: 186 | trainer.test(model, test_input_handle, FLAGS, itr) 187 | 188 | train_input_handle.next() 189 | 190 | 191 | def test_wrapper(model): 192 | model.load(FLAGS.pretrained_model) 193 | test_input_handle = datasets_factory.data_provider( 194 | FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths, 195 | FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=False) 196 | trainer.test(model, test_input_handle, FLAGS, 'test_result') 197 | 198 | 199 | if __name__ == '__main__': 200 | tf.app.run() 201 | 202 | -------------------------------------------------------------------------------- /src/data_provider/taxibj.py: -------------------------------------------------------------------------------- 1 | __author__ = 'jianjin' 2 | 3 | import random 4 | import os.path 5 | import logging 6 | import os 7 | from copy import copy 8 | import numpy as np 9 | import h5py 10 | import pandas as pd 11 | from datetime import datetime 12 | import time 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def string2timestamp(strings, T=48): 18 | timestamps = [] 19 | 20 | time_per_slot = 24.0 / T 21 | num_per_T = T // 24 22 | for t in strings: 23 | year, month, day, slot = int(t[:4]), int(t[4:6]), int(t[6:8]), int(t[8:])-1 24 | timestamps.append(pd.Timestamp(datetime(year, month, day, hour=int(slot * time_per_slot), 25 | minute=(slot % num_per_T) * int(60.0 * time_per_slot)))) 26 | 27 | return timestamps 28 | 29 | 30 | class STMatrix(object): 31 | """docstring for STMatrix""" 32 | 33 | def __init__(self, data, timestamps, T=48, CheckComplete=True): 34 | super(STMatrix, self).__init__() 35 | assert len(data) == len(timestamps) 36 | self.data = data 37 | self.timestamps = timestamps 38 | self.T = T 39 | self.pd_timestamps = string2timestamp(timestamps, T=self.T) 40 | if CheckComplete: 41 | self.check_complete() 42 | # index 43 | self.make_index() 44 | 45 | def make_index(self): 46 | self.get_index = dict() 47 | for i, ts in enumerate(self.pd_timestamps): 48 | self.get_index[ts] = i 49 | 50 | def check_complete(self): 51 | missing_timestamps = [] 52 | offset = pd.DateOffset(minutes=24 * 60 // self.T) 53 | pd_timestamps = self.pd_timestamps 54 | i = 1 55 | while i < len(pd_timestamps): 56 | if pd_timestamps[i-1] + offset != pd_timestamps[i]: 57 | missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i-1], pd_timestamps[i])) 58 | i += 1 59 | for v in missing_timestamps: 60 | print(v) 61 | assert len(missing_timestamps) == 0 62 | 63 | def get_matrix(self, timestamp): 64 | return self.data[self.get_index[timestamp]] 65 | 66 | def save(self, fname): 67 | pass 68 | 69 | def check_it(self, depends): 70 | for d in depends: 71 | if d not in self.get_index.keys(): 72 | return False 73 | return True 74 | 75 | def create_dataset(self, len_closeness=20): 76 | """current version 77 | """ 78 | # offset_week = pd.DateOffset(days=7) 79 | offset_frame = pd.DateOffset(minutes=24 * 60 // self.T) 80 | XC = [] 81 | timestamps_Y = [] 82 | depends = [range(1, len_closeness+1)] 83 | 84 | i = len_closeness 85 | while i < len(self.pd_timestamps): 86 | Flag = True 87 | for depend in depends: 88 | if Flag is False: 89 | break 90 | Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend]) 91 | 92 | if Flag is False: 93 | i += 1 94 | continue 95 | x_c = [np.transpose(self.get_matrix(self.pd_timestamps[i] - j * offset_frame), [1, 2, 0]) for j in depends[0]] 96 | if len_closeness > 0: 97 | XC.append(np.stack(x_c, axis=0)) 98 | timestamps_Y.append(self.timestamps[i]) 99 | i += 1 100 | XC = np.stack(XC, axis=0) 101 | return XC, timestamps_Y 102 | 103 | 104 | def load_stdata(fname): 105 | f = h5py.File(fname, 'r') 106 | data = f['data'].value 107 | timestamps = f['date'].value 108 | f.close() 109 | return data, timestamps 110 | 111 | 112 | def stat(fname): 113 | def get_nb_timeslot(f): 114 | s = f['date'][0] 115 | e = f['date'][-1] 116 | year, month, day = map(int, [s[:4], s[4:6], s[6:8]]) 117 | ts = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d") 118 | year, month, day = map(int, [e[:4], e[4:6], e[6:8]]) 119 | te = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d") 120 | nb_timeslot = (time.mktime(te) - time.mktime(ts)) / (0.5 * 3600) + 48 121 | ts_str, te_str = time.strftime("%Y-%m-%d", ts), time.strftime("%Y-%m-%d", te) 122 | return nb_timeslot, ts_str, te_str 123 | 124 | with h5py.File(fname, 'r') as f: 125 | nb_timeslot, ts_str, te_str = get_nb_timeslot(f) 126 | nb_day = int(nb_timeslot / 48) 127 | mmax = f['data'].value.max() 128 | mmin = f['data'].value.min() 129 | stat = '=' * 5 + 'stat' + '=' * 5 + '\n' + \ 130 | 'data shape: %s\n' % str(f['data'].shape) + \ 131 | '# of days: %i, from %s to %s\n' % (nb_day, ts_str, te_str) + \ 132 | '# of timeslots: %i\n' % int(nb_timeslot) + \ 133 | '# of timeslots (available): %i\n' % f['date'].shape[0] + \ 134 | 'missing ratio of timeslots: %.1f%%\n' % ((1. - float(f['date'].shape[0] / nb_timeslot)) * 100) + \ 135 | 'max: %.3f, min: %.3f\n' % (mmax, mmin) + \ 136 | '=' * 5 + 'stat' + '=' * 5 137 | print(stat) 138 | 139 | 140 | class MinMaxNormalization(object): 141 | '''MinMax Normalization --> [-1, 1] 142 | x = (x - min) / (max - min). 143 | x = x * 2 - 1 144 | ''' 145 | 146 | def __init__(self): 147 | pass 148 | 149 | def fit(self, X): 150 | self._min = X.min() 151 | self._max = X.max() 152 | print("min:", self._min, "max:", self._max) 153 | 154 | def transform(self, X): 155 | X = 1. * (X - self._min) / (self._max - self._min) 156 | # X = X * 2. - 1. 157 | return X 158 | 159 | def fit_transform(self, X): 160 | self.fit(X) 161 | return self.transform(X) 162 | 163 | def inverse_transform(self, X): 164 | X = (X + 1.) / 2. 165 | X = 1. * X * (self._max - self._min) + self._min 166 | return X 167 | 168 | 169 | def timestamp2vec(timestamps): 170 | # tm_wday range [0, 6], Monday is 0 171 | # vec = [time.strptime(str(t[:8], encoding='utf-8'), '%Y%m%d').tm_wday for t in timestamps] # python3 172 | vec = [time.strptime(t[:8], '%Y%m%d').tm_wday for t in timestamps] # python2 173 | ret = [] 174 | for i in vec: 175 | v = [0 for _ in range(7)] 176 | v[i] = 1 177 | if i >= 5: 178 | v.append(0) # weekend 179 | else: 180 | v.append(1) # weekday 181 | ret.append(v) 182 | return np.asarray(ret) 183 | 184 | 185 | def remove_incomplete_days(data, timestamps, T=48): 186 | # remove a certain day which has not 48 timestamps 187 | days = [] # available days: some day only contain some seqs 188 | days_incomplete = [] 189 | i = 0 190 | while i < len(timestamps): 191 | if int(timestamps[i][8:]) != 1: 192 | i += 1 193 | elif i+T-1 < len(timestamps) and int(timestamps[i+T-1][8:]) == T: 194 | days.append(timestamps[i][:8]) 195 | i += T 196 | else: 197 | days_incomplete.append(timestamps[i][:8]) 198 | i += 1 199 | print("incomplete days: ", days_incomplete) 200 | days = set(days) 201 | idx = [] 202 | for i, t in enumerate(timestamps): 203 | if t[:8] in days: 204 | idx.append(i) 205 | 206 | data = data[idx] 207 | timestamps = [timestamps[i] for i in idx] 208 | return data, timestamps 209 | 210 | 211 | class InputHandle: 212 | def __init__(self, datas, indices, input_param): 213 | self.name = input_param['name'] 214 | self.input_data_type = input_param.get('input_data_type', 'float32') 215 | self.minibatch_size = input_param['minibatch_size'] 216 | self.image_width = input_param['image_width'] 217 | self.datas = datas 218 | self.indices = indices 219 | self.current_position = 0 220 | self.current_batch_indices = [] 221 | self.current_input_length = input_param['seq_length'] 222 | 223 | def total(self): 224 | return len(self.indices) 225 | 226 | def begin(self, do_shuffle=True): 227 | logger.info("Initialization for read data ") 228 | if do_shuffle: 229 | random.shuffle(self.indices) 230 | self.current_position = 0 231 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 232 | 233 | def next(self): 234 | self.current_position += self.minibatch_size 235 | if self.no_batch_left(): 236 | return None 237 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 238 | 239 | def no_batch_left(self): 240 | if self.current_position + self.minibatch_size >= self.total(): 241 | return True 242 | else: 243 | return False 244 | 245 | def get_batch(self): 246 | if self.no_batch_left(): 247 | logger.error( 248 | "There is no batch left in " + self.name + ". Consider to user iterators.begin() to rescan from the beginning of the iterators") 249 | return None 250 | input_batch = self.datas[self.current_batch_indices, :, :, :] 251 | input_batch = input_batch.astype(self.input_data_type) 252 | return input_batch 253 | 254 | def print_stat(self): 255 | logger.info("Iterator Name: " + self.name) 256 | logger.info(" current_position: " + str(self.current_position)) 257 | logger.info(" Minibatch Size: " + str(self.minibatch_size)) 258 | logger.info(" total Size: " + str(self.total())) 259 | logger.info(" current_input_length: " + str(self.current_input_length)) 260 | logger.info(" Input Data Type: " + str(self.input_data_type)) 261 | 262 | 263 | class DataProcess: 264 | def __init__(self, input_param): 265 | self.paths = input_param['paths'] 266 | self.image_width = input_param['image_width'] 267 | 268 | self.input_param = input_param 269 | self.seq_len = input_param['seq_length'] 270 | self.train_data, self.test_data, _, _, _ = self.load_data(self.paths, len_closeness=input_param['seq_length']) 271 | self.train_indices = list(range(self.train_data.shape[0])) 272 | self.test_indices = list(range(self.test_data.shape[0])) 273 | 274 | def load_data(self, datapath, T=48, nb_flow=2, len_closeness=None, len_test=48 * 7 * 4): 275 | """ 276 | """ 277 | assert (len_closeness > 0) 278 | # load data 279 | # 13 - 16 280 | data_all = [] 281 | timestamps_all = list() 282 | for year in range(13, 17): 283 | fname = os.path.join( 284 | datapath[0], 'BJ{}_M32x32_T30_InOut.h5'.format(year)) 285 | print("file name: ", fname) 286 | stat(fname) 287 | data, timestamps = load_stdata(fname) 288 | # print(timestamps) 289 | # remove a certain day which does not have 48 timestamps 290 | data, timestamps = remove_incomplete_days(data, timestamps, T) 291 | data = data[:, :nb_flow] 292 | data[data < 0] = 0. 293 | data_all.append(data) 294 | timestamps_all.append(timestamps) 295 | print("\n") 296 | 297 | # minmax_scale 298 | data_train = np.vstack(copy(data_all))[:-len_test] 299 | print('train_data shape: ', data_train.shape) 300 | mmn = MinMaxNormalization() 301 | mmn.fit(data_train) 302 | data_all_mmn = [mmn.transform(d) for d in data_all] 303 | 304 | XC = [] 305 | timestamps_Y = [] 306 | for data, timestamps in zip(data_all_mmn, timestamps_all): 307 | # instance-based dataset --> sequences with format as (X, Y) where X is 308 | # a sequence of images and Y is an image. 309 | st = STMatrix(data, timestamps, T, CheckComplete=False) 310 | _XC, _timestamps_Y = st.create_dataset(len_closeness=len_closeness) 311 | XC.append(_XC) 312 | timestamps_Y += _timestamps_Y 313 | XC = np.concatenate(XC, axis=0) 314 | print("XC shape: ", XC.shape) 315 | 316 | XC_train = XC[:-len_test] 317 | XC_test = XC[-len_test:] 318 | timestamp_train, timestamp_test = timestamps_Y[:-len_test], timestamps_Y[-len_test:] 319 | 320 | X_train = XC_train 321 | X_test = XC_test 322 | print('train shape:', XC_train.shape, 323 | 'test shape: ', XC_test.shape) 324 | 325 | return X_train, X_test, mmn, timestamp_train, timestamp_test 326 | 327 | def get_train_input_handle(self): 328 | return InputHandle(self.train_data, self.train_indices, self.input_param) 329 | 330 | def get_test_input_handle(self): 331 | return InputHandle(self.test_data, self.test_indices, self.input_param) 332 | --------------------------------------------------------------------------------