├── README.md
├── core
│   ├── __init__.py
│   ├── data_provider
│   │   ├── __init__.py
│   │   ├── datasets_factory.py
│   │   ├── flip_rotate.py
│   │   └── human.py
│   ├── layers
│   │   ├── MotionGRU.py
│   │   ├── SpatioTemporalLSTMCell_Motion_Highway.py
│   │   └── __init__.py
│   ├── models
│   │   ├── MotionRNN_PredRNN.py
│   │   ├── __init__.py
│   │   └── model_factory.py
│   ├── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── metrics.py
│       └── preprocess.py
├── human_script
│   └── MotionRNN_PredRNN_human_train.sh
├── pic
│   ├── architecture.png
│   ├── motion_decomp.png
│   └── vis.png
└── run.py

/README.md:
--------------------------------------------------------------------------------
# MotionRNN (CVPR 2021)

MotionRNN: A Flexible Model for Video Prediction with Spacetime-Varying Motions

Different from previous models that focus on temporal state-transition modeling, MotionRNN models patch-wise motion explicitly. Concretely, MotionRNN is characterized by:

- **Explicit motion modeling.** MotionRNN learns the motion direction and distance for each patch, so it can respond quickly to complex spatiotemporal variations.
- **Motion decomposition.** We present the **MotionGRU** unit to capture both the transient variation and the motion trend.
- **Flexible framework.** MotionGRU can be embedded into previous models (e.g., ConvLSTM or PredRNN) with the help of the **Motion Highway**, which trades off the moving and unchanged parts.

![motion_decomp](./pic/motion_decomp.png)

## MotionRNN

Comparison between previous state-transition methods (left) and MotionRNN (right).

To tackle the challenge of modeling spacetime-varying motions, the MotionRNN framework incorporates the MotionGRU unit between the stacked layers as an operator, without changing the original state-transition flow.

![architecture](./pic/architecture.png)

## Get Started

1. Install Python 3.6 and PyTorch 1.9.0 for the main code.
2. Download data. You can download the Human dataset following the instructions from [here](https://github.com/Yunbo426/MIM). More datasets can be obtained from [here](https://github.com/thuml/predrnn-pytorch).
3. Train and evaluate the model.

```
cd human_script/
bash MotionRNN_PredRNN_human_train.sh
```

## Learned Motion Visualization

The center arrows show an upward movement with an anticlockwise rotation. The bottom arrows indicate the downward motion of a small tile of the cyclone.

![vis](./pic/vis.png)

## Citation

If you find this repo useful, please cite our paper.

```
@inproceedings{wu2021MotionRNN,
  title={MotionRNN: A Flexible Model for Video Prediction with Spacetime-Varying Motions},
  author={Haixu Wu and Zhiyu Yao and Jianmin Wang and Mingsheng Long},
  booktitle={CVPR},
  year={2021}
}
```

## Contact

If you have any questions or want to use the code, please contact wuhx23@mails.tsinghua.edu.cn.
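
## Test-Only Mode (Sketch)

`run.py` also supports evaluation without training: with `--is_training 0` it loads the checkpoint given by `--pretrained_model` and runs the test loop once. A minimal sketch, assuming a snapshot saved at iteration 80000 (point `--pretrained_model` at one of your own `model.ckpt-*` files):

```
python -u run.py \
    --is_training 0 \
    --device cuda \
    --dataset_name human \
    --train_data_paths ./data/human/human \
    --valid_data_paths ./data/human/human \
    --gen_frm_dir results/human_MotionRNN_PredRNN_eval \
    --model_name MotionRNN_PredRNN \
    --pretrained_model checkpoints/human_MotionRNN_PredRNN_check/model.ckpt-80000 \
    --img_height 128 \
    --img_width 128 \
    --img_channel 3 \
    --input_length 4 \
    --total_length 8 \
    --num_hidden 64,64,64,64 \
    --patch_size 4 \
    --layer_norm 0 \
    --batch_size 8
```

Note that `run.py` removes and recreates `--save_dir` and `--gen_frm_dir` at startup, so keep them separate from any training directories you want to preserve.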
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/core/__init__.py
--------------------------------------------------------------------------------
/core/data_provider/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/core/data_provider/__init__.py
--------------------------------------------------------------------------------
/core/data_provider/datasets_factory.py:
--------------------------------------------------------------------------------
from core.data_provider import human

datasets_map = {
    'human': human,
}


def data_provider(dataset_name, train_data_paths, valid_data_paths, batch_size,
                  img_width, seq_length, is_training=True):
    if dataset_name not in datasets_map:
        raise ValueError('Unknown dataset name: %s' % dataset_name)
    train_data_list = train_data_paths.split(',')
    valid_data_list = valid_data_paths.split(',')

    if dataset_name == 'human':
        input_param = {'paths': valid_data_list,
                       'image_width': img_width,
                       'minibatch_size': batch_size,
                       'seq_length': seq_length,
                       'channel': 3,
                       'input_data_type': 'float32',
                       'name': 'human'}
        input_handle = datasets_map[dataset_name].DataProcess(input_param)
        test_input_handle = input_handle.get_test_input_handle()
        test_input_handle.begin(do_shuffle=False)
        if is_training:
            train_input_handle = input_handle.get_train_input_handle()
            train_input_handle.begin(do_shuffle=True)
            return train_input_handle, test_input_handle
        else:
            return test_input_handle
--------------------------------------------------------------------------------
/core/data_provider/flip_rotate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File  : flip_rotate.py
# @Author: GZF
# @Date  : 2017/10/29
# @Desc  : Data augmentation by random flips and rotations.

import os
from PIL import Image
import numpy as np
import cv2
import random


def read_batch(img_path="../../input/1"):
    batch_size = 1
    seq_len = 10
    width = 100
    channel = 1
    input_batch = np.zeros((batch_size, seq_len, width, width, channel)).astype(np.float32)
    filename_list = os.listdir(img_path)
    for filename, index in zip(filename_list, range(seq_len)):
        file = Image.open(os.path.join(img_path, filename))
        file = np.array(file, dtype=np.float32)
        file = cv2.resize(file, (width, width))
        input_batch[0, index, :, :, 0] = file
        # plt.imshow(file, cmap="gray")
        # plt.show()
        print(file.shape)

    return input_batch


def augment_data(batch):
    rand = random.random()

    if rand < 0.5:
        batch = np.flip(batch, 1)
    elif rand < 0.6:
        w = batch.shape[2]
        angle = 90
        rotate_img(batch, w, angle)
    elif rand < 0.7:
        w = batch.shape[2]
        angle = 270
        rotate_img(batch, w, angle)
    elif rand < 0.8:
        flip_img(batch, 0)
    elif rand < 0.9:
        flip_img(batch, 1)
    elif rand < 1.0:
        flip_img(batch, -1)
    return batch


def flip_img(batch, flipCode):
    for batch_ind in range(batch.shape[0]):
        for seq_ind in range(batch.shape[1]):
            img_arr = batch[batch_ind, seq_ind, :, :, 0]
            batch[batch_ind, seq_ind, :, :, 0] = cv2.flip(img_arr, flipCode)


def rotate_img(batch, w, angle):
    center = (w / 2, w / 2)
    scale = 1.0
    M = cv2.getRotationMatrix2D(center, angle, scale)

    for batch_ind in range(batch.shape[0]):
        for seq_ind in range(batch.shape[1]):
            img_arr = batch[batch_ind, seq_ind, :, :, 0]
            batch[batch_ind, seq_ind, :, :, 0] = cv2.warpAffine(img_arr, M, (w, w))

# if __name__ == '__main__':
#     batch = read_batch()
#     batch = augment_data(batch)
#
#     for batch_ind in range(batch.shape[0]):
#         for seq_ind in range(batch.shape[1]):
#             img_arr = batch[batch_ind, seq_ind, :, :, 0]
#             plt.imshow(img_arr, cmap="gray")
#             plt.show()
--------------------------------------------------------------------------------
/core/data_provider/human.py:
--------------------------------------------------------------------------------
__author__ = 'gaozhifeng'
import numpy as np
import os
import cv2
from PIL import Image
import logging
import random

logger = logging.getLogger(__name__)


class InputHandle:
    def __init__(self, datas, indices, input_param):
        self.name = input_param['name']
        self.input_data_type = input_param.get('input_data_type', 'float32')
        self.minibatch_size = input_param['minibatch_size']
        self.image_width = input_param['image_width']
        self.channel = input_param['channel']
        self.datas = datas
        self.indices = indices
        self.current_position = 0
        self.current_batch_indices = []
        self.current_input_length = input_param['seq_length']
        self.interval = 2

    def total(self):
        return len(self.indices)

    def begin(self, do_shuffle=True):
        logger.info("Initialization for read data ")
        if do_shuffle:
            random.shuffle(self.indices)
        self.current_position = 0
        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]

    def next(self):
        self.current_position += self.minibatch_size
        if self.no_batch_left():
            return None
        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]

    def no_batch_left(self):
        return self.current_position + self.minibatch_size > self.total()

    def get_batch(self):
        if self.no_batch_left():
            logger.error(
                "There is no batch left in " + self.name +
                ". Consider calling iterators.begin() to rescan from the beginning of the iterator.")
            return None
        input_batch = np.zeros(
            (self.minibatch_size, self.current_input_length, self.image_width, self.image_width, self.channel)).astype(
            self.input_data_type)
        for i in range(self.minibatch_size):
            batch_ind = self.current_batch_indices[i]
            begin = batch_ind
            end = begin + self.current_input_length * self.interval
            data_slice = self.datas[begin:end:self.interval]
            input_batch[i, :self.current_input_length, :, :, :] = data_slice
            # logger.info('data_slice shape')
            # logger.info(data_slice.shape)
            # logger.info(input_batch.shape)
        input_batch = input_batch.astype(self.input_data_type)
        return input_batch

    def print_stat(self):
        logger.info("Iterator Name: " + self.name)
        logger.info("    current_position: " + str(self.current_position))
        logger.info("    Minibatch Size: " + str(self.minibatch_size))
        logger.info("    total Size: " + str(self.total()))
        logger.info("    current_input_length: " + str(self.current_input_length))
        logger.info("    Input Data Type: " + str(self.input_data_type))


class DataProcess:
    def __init__(self, input_param):
        self.input_param = input_param
        self.paths = input_param['paths']
        self.image_width = input_param['image_width']
        self.seq_len = input_param['seq_length']

    def load_data(self, paths, mode='train'):
        data_dir = paths[0]
        interval = 2

        frames_np = []
        scenarios = ['Walking']
        if mode == 'train':
            subjects = ['S1', 'S5', 'S6', 'S7', 'S8']
        elif mode == 'test':
            subjects = ['S9', 'S11']
        else:
            raise ValueError("Unknown mode: " + str(mode))
        _path = data_dir
        print('load data...', _path)
        filenames = os.listdir(_path)
        filenames.sort()
        print('data size ', len(filenames))
        frames_file_name = []
        for filename in filenames:
            fix = filename.split('.')
            fix = fix[0]
            subject = fix.split('_')
            scenario = subject[1]
            subject = subject[0]
            if subject not in subjects or scenario not in scenarios:
                continue
            file_path = os.path.join(_path, filename)
            image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB)
            # [1000, 1000, 3] -> crop the central half of the frame
            image = image[image.shape[0] // 4:-image.shape[0] // 4, image.shape[1] // 4:-image.shape[1] // 4, :]
            if self.image_width != image.shape[0]:
                image = cv2.resize(image, (self.image_width, self.image_width))
            # image = cv2.resize(image[100:-100, 100:-100, :], (self.image_width, self.image_width),
            #                    interpolation=cv2.INTER_LINEAR)
            frames_np.append(np.array(image, dtype=np.float32) / 255.0)
            frames_file_name.append(filename)
            # if len(frames_np) % 100 == 0: print(len(frames_np))
            # if len(frames_np) % 1000 == 0: break
        # is it a begin index of sequence
        indices = []
        index = 0
        print('gen index')
        while index + interval * self.seq_len - 1 < len(frames_file_name):
            # 'S11_Discussion_1.54138969_000471.jpg'
            # ['S11_Discussion_1', '54138969_000471', 'jpg']
            start_infos = frames_file_name[index].split('.')
            end_infos = frames_file_name[index + interval * (self.seq_len - 1)].split('.')
            if start_infos[0] != end_infos[0]:
                index += 1
                continue
            start_video_id, start_frame_id = start_infos[1].split('_')
            end_video_id, end_frame_id = end_infos[1].split('_')
            if start_video_id != end_video_id:
                index += 1
                continue

            if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * interval:
                indices.append(index)
            if mode == 'train':
                index += 10
            elif mode == 'test':
                index += 5
        print("there are " + str(len(indices)) + " sequences")
        # data = np.asarray(frames_np)
        data = frames_np
        print("there are " + str(len(data)) + " pictures")
        return data, indices

    def get_train_input_handle(self):
        train_data, train_indices = self.load_data(self.paths, mode='train')
        return InputHandle(train_data, train_indices, self.input_param)

    def get_test_input_handle(self):
        test_data, test_indices = self.load_data(self.paths, mode='test')
        return InputHandle(test_data, test_indices, self.input_param)


# The command-line entry point below is a leftover from an earlier script:
# `argparse` is not imported in this module and `partition_data` is not
# defined anywhere in the repository, so it is kept only as a commented stub.
# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("input_dir", type=str)
#     parser.add_argument("output_dir", type=str)
#     args = parser.parse_args()
#
#     partition_names = ['train', 'test']
#     partition_fnames = partition_data(args.input_dir)
#
#
# if __name__ == '__main__':
#     main()
--------------------------------------------------------------------------------
/core/layers/MotionGRU.py:
--------------------------------------------------------------------------------
__author__ = 'haixu'

import torch
import torch.nn as nn


class Warp(nn.Module):
    def __init__(self, inc, outc, neighbour=3):
        super(Warp, self).__init__()
        self.neighbour = neighbour
        self.zero_padding = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(inc, outc, kernel_size=neighbour, stride=neighbour, bias=False)
        self.warp_gate = nn.Conv2d(inc, neighbour * neighbour, kernel_size=3, padding=1, stride=1)
        nn.init.constant_(self.warp_gate.weight, 0)
        self.warp_gate.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        # Scale the warp gate's gradients by 0.1 (a smaller effective
        # learning rate); the returned tuple replaces grad_input.
        return tuple(g * 0.1 if g is not None else g for g in grad_input)

    def forward(self, info):
        x = info[0]
        offset = info[1]

        dtype = offset.data.type()
        N = self.neighbour * self.neighbour

        m = torch.sigmoid(self.warp_gate(x))
        x = self.zero_padding(x)

        ## Neighbourhood Warp Operation
        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)
        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
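
        # Each fractional sampling location p is interpolated from its four
        # integer neighbours with the bilinear weights above:
        #   x_warped(p) = g_lt * x(q_lt) + g_rb * x(q_rb) + g_lb * x(q_lb) + g_rt * x(q_rt)
        # so the warp stays differentiable with respect to the learned offsets.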
        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_warped = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        ## Warp Gate
        m = m.contiguous().permute(0, 2, 3, 1)
        m = m.unsqueeze(dim=1)
        m = torch.cat([m for _ in range(x_warped.size(1))], dim=1)
        x_warped *= m

        x_warped = self._reshape_x_warped(x_warped, self.neighbour)
        out = self.conv(x_warped)
        return out

    def _get_p_n(self, N, dtype):
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.neighbour - 1) // 2, (self.neighbour - 1) // 2 + 1),
            torch.arange(-(self.neighbour - 1) // 2, (self.neighbour - 1) // 2 + 1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(torch.arange(1, h + 1, 1), torch.arange(1, w + 1, 1))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)
        # (b, h, w, N)
        index = q[..., :N] * padded_w + q[..., N:]
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
        x_warped = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
        return x_warped

    @staticmethod
    def _reshape_x_warped(x_warped, neighbour):
        b, c, h, w, N = x_warped.size()
        x_warped = torch.cat(
            [x_warped[..., s:s + neighbour].contiguous().view(b, c, h, w * neighbour) for s in range(0, N, neighbour)],
            dim=-1)
        x_warped = x_warped.contiguous().view(b, c, h * neighbour, w * neighbour)
        return x_warped


class MotionGRU(nn.Module):
    def __init__(self, in_channel, motion_hidden, neighbour):
        super(MotionGRU, self).__init__()
        self.update = nn.Conv2d(in_channel + motion_hidden, motion_hidden, kernel_size=3, stride=1, padding=1)
        nn.init.constant_(self.update.weight, 0)
        self.update.register_backward_hook(self._set_lr)

        self.reset = nn.Conv2d(in_channel + motion_hidden, motion_hidden, kernel_size=3, stride=1, padding=1)
        nn.init.constant_(self.reset.weight, 0)
        self.reset.register_backward_hook(self._set_lr)

        self.output = nn.Conv2d(in_channel + motion_hidden, motion_hidden, kernel_size=3, stride=1, padding=1)
        nn.init.constant_(self.output.weight, 0)
        self.output.register_backward_hook(self._set_lr)

        self.warp = Warp(in_channel, in_channel, neighbour)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        # Scale the gate gradients by 0.1 (a smaller effective learning
        # rate); the returned tuple replaces grad_input.
        return tuple(g * 0.1 if g is not None else g for g in grad_input)

    def forward(self, x_t, pre_offset, mean):
        stacked_inputs = torch.cat([x_t, pre_offset], dim=1)
        update_gate = torch.sigmoid(self.update(stacked_inputs))
        reset_gate = torch.sigmoid(self.reset(stacked_inputs))
        offset = torch.tanh(self.output(torch.cat([x_t, pre_offset * reset_gate], dim=1)))
        # Transient variation: a GRU-style update of the offset state.
        offset = pre_offset * (1 - update_gate) + offset * update_gate
        # Motion trend: an exponential moving average of past offsets.
        mean = mean + 0.5 * (pre_offset - mean)
        offset = offset + mean

        x_t = self.warp([x_t, offset])
        return x_t, offset, mean
--------------------------------------------------------------------------------
/core/layers/SpatioTemporalLSTMCell_Motion_Highway.py:
--------------------------------------------------------------------------------
__author__ = 'haixu'

import torch
import torch.nn as nn


class SpatioTemporalLSTMCell(nn.Module):
    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        super(SpatioTemporalLSTMCell, self).__init__()

        self.num_hidden = num_hidden
        self._forget_bias = 1.0
        padding = filter_size // 2

        if layer_norm:
            self.conv_x = nn.Sequential(
                nn.Conv2d(in_channel, num_hidden * 7, filter_size, stride, padding, bias=False),
                nn.LayerNorm([num_hidden * 7, height, width])
            )
            self.conv_h = nn.Sequential(
                nn.Conv2d(num_hidden, num_hidden * 4, filter_size, stride, padding, bias=False),
                nn.LayerNorm([num_hidden * 4, height, width])
            )
            self.conv_m = nn.Sequential(
                nn.Conv2d(num_hidden, num_hidden * 3, filter_size, stride, padding, bias=False),
                nn.LayerNorm([num_hidden * 3, height, width])
            )
            self.conv_o = nn.Sequential(
                nn.Conv2d(num_hidden * 2, num_hidden, filter_size, stride, padding, bias=False),
                nn.LayerNorm([num_hidden, height, width])
            )
        else:
            self.conv_x = nn.Sequential(
                nn.Conv2d(in_channel, num_hidden * 7, filter_size, stride, padding, bias=False),
            )
            self.conv_h = nn.Sequential(
                nn.Conv2d(num_hidden, num_hidden * 4, filter_size, stride, padding, bias=False),
            )
            self.conv_m = nn.Sequential(
                nn.Conv2d(num_hidden, num_hidden * 3, filter_size, stride, padding, bias=False),
            )
            self.conv_o = nn.Sequential(
                nn.Conv2d(num_hidden * 2, num_hidden, filter_size, stride, padding, bias=False),
            )
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x_t, h_t, c_t, m_t, motion_highway):
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)

        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
        g_t = torch.tanh(g_x + g_h)

        c_new = f_t * c_t + i_t * g_t

        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)

        m_new = f_t_prime * m_t + i_t_prime * g_t_prime

        mem = torch.cat((c_new, m_new), 1)
        m_new_new = self.conv_last(mem)

        # Motion Highway
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(m_new_new) + (1 - o_t) * motion_highway
        motion_highway = h_new
        return h_new, c_new, m_new, motion_highway
--------------------------------------------------------------------------------
/core/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/core/layers/__init__.py
--------------------------------------------------------------------------------
/core/models/MotionRNN_PredRNN.py:
--------------------------------------------------------------------------------
__author__ = 'yunbo'

import torch
import torch.nn as nn
from core.layers.SpatioTemporalLSTMCell_Motion_Highway import SpatioTemporalLSTMCell
from core.layers.MotionGRU import MotionGRU


class RNN(nn.Module):
    def __init__(self, num_layers, num_hidden, configs):
        super(RNN, self).__init__()
        self.configs = configs
        self.patch_height = configs.img_height // configs.patch_size
        self.patch_width = configs.img_width // configs.patch_size
        self.patch_ch = configs.img_channel * (configs.patch_size ** 2)
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.neighbour = 3
        self.motion_hidden = 2 * self.neighbour * self.neighbour
        self.MSE_criterion = nn.MSELoss().to(self.configs.device)

        cell_list = []
        for i in range(num_layers):
            in_channel = self.patch_ch if i == 0 else num_hidden[i - 1]
            cell_list.append(
                SpatioTemporalLSTMCell(in_channel, num_hidden[i], self.patch_height, self.patch_width,
                                       configs.filter_size, configs.stride, configs.layer_norm),
            )
        enc_list = []
        for i in range(num_layers - 1):
            enc_list.append(
                nn.Conv2d(num_hidden[i], num_hidden[i] // 4, kernel_size=configs.filter_size, stride=2,
                          padding=configs.filter_size // 2),
            )
        motion_list = []
        for i in range(num_layers - 1):
            motion_list.append(
                MotionGRU(num_hidden[i] // 4, self.motion_hidden, self.neighbour)
            )
        dec_list = []
        for i in range(num_layers - 1):
            dec_list.append(
                nn.ConvTranspose2d(num_hidden[i] // 4, num_hidden[i], kernel_size=4, stride=2,
                                   padding=1),
            )
        gate_list = []
        for i in range(num_layers - 1):
            gate_list.append(
                nn.Conv2d(num_hidden[i] * 2, num_hidden[i], kernel_size=configs.filter_size, stride=1,
                          padding=configs.filter_size // 2),
            )
        self.gate_list = nn.ModuleList(gate_list)
        self.cell_list = nn.ModuleList(cell_list)
        self.motion_list = nn.ModuleList(motion_list)
        self.enc_list = nn.ModuleList(enc_list)
        self.dec_list = nn.ModuleList(dec_list)
        self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.patch_ch, 1, stride=1, padding=0, bias=False)
        self.conv_first_v = nn.Conv2d(self.patch_ch, num_hidden[0], 1, stride=1, padding=0, bias=False)

    def forward(self, all_frames, mask_true):
        # [batch, length, height, width, channel] -> [batch, length, channel, height, width]
        frames = all_frames.permute(0, 1, 4, 2, 3).contiguous()
        mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
        next_frames = []
        h_t = []
        c_t = []
        h_t_conv = []
        h_t_conv_offset = []
        mean = []

        for i in range(self.num_layers):
            zeros = torch.empty(
                [self.configs.batch_size, self.num_hidden[i], self.patch_height, self.patch_width]).to(
                self.configs.device)
            nn.init.xavier_normal_(zeros)
            h_t.append(zeros)
            c_t.append(zeros)

        for i in range(self.num_layers - 1):
            zeros = torch.empty(
                [self.configs.batch_size, self.num_hidden[i] // 4, self.patch_height // 2,
                 self.patch_width // 2]).to(
                self.configs.device)
            nn.init.xavier_normal_(zeros)
            h_t_conv.append(zeros)
            zeros = torch.empty(
                [self.configs.batch_size, self.motion_hidden, self.patch_height // 2, self.patch_width // 2]).to(
                self.configs.device)
            nn.init.xavier_normal_(zeros)
            h_t_conv_offset.append(zeros)
            mean.append(zeros)

        mem = torch.empty([self.configs.batch_size, self.num_hidden[0], self.patch_height, self.patch_width]).to(
            self.configs.device)
        motion_highway = torch.empty(
            [self.configs.batch_size, self.num_hidden[0], self.patch_height, self.patch_width]).to(
            self.configs.device)
        nn.init.xavier_normal_(mem)
        nn.init.xavier_normal_(motion_highway)

        for t in range(self.configs.total_length - 1):
            if t < self.configs.input_length:
                net = frames[:, t]
            else:
                # Scheduled sampling: mix ground truth with the previous prediction.
                net = mask_true[:, t - self.configs.input_length] * frames[:, t] + \
                      (1 - mask_true[:, t - self.configs.input_length]) * x_gen

            motion_highway = self.conv_first_v(net)
            h_t[0], c_t[0], mem, motion_highway = self.cell_list[0](net, h_t[0], c_t[0], mem, motion_highway)
            net = self.enc_list[0](h_t[0])
            h_t_conv[0], h_t_conv_offset[0], mean[0] = self.motion_list[0](net, h_t_conv_offset[0], mean[0])
            h_t_tmp = self.dec_list[0](h_t_conv[0])
            o_t = torch.sigmoid(self.gate_list[0](torch.cat([h_t_tmp, h_t[0]], dim=1)))
            h_t[0] = o_t * h_t_tmp + (1 - o_t) * h_t[0]

            for i in range(1, self.num_layers - 1):
                h_t[i], c_t[i], mem, motion_highway = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], mem, motion_highway)
                net = self.enc_list[i](h_t[i])
                h_t_conv[i], h_t_conv_offset[i], mean[i] = self.motion_list[i](net, h_t_conv_offset[i], mean[i])
                h_t_tmp = self.dec_list[i](h_t_conv[i])
                o_t = torch.sigmoid(self.gate_list[i](torch.cat([h_t_tmp, h_t[i]], dim=1)))
                h_t[i] = o_t * h_t_tmp + (1 - o_t) * h_t[i]

            h_t[self.num_layers - 1], c_t[self.num_layers - 1], mem, motion_highway = self.cell_list[
                self.num_layers - 1](
                h_t[self.num_layers - 2], h_t[self.num_layers - 1], c_t[self.num_layers - 1], mem, motion_highway)
            x_gen = self.conv_last(h_t[self.num_layers - 1])
            next_frames.append(x_gen)

        # [length, batch, channel, height, width] -> [batch, length, height, width, channel]
        next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
        loss = self.MSE_criterion(next_frames, all_frames[:, 1:])
        return next_frames, loss
--------------------------------------------------------------------------------
/core/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/core/models/__init__.py
--------------------------------------------------------------------------------
/core/models/model_factory.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.optim import Adam
from core.models import MotionRNN_PredRNN


class Model(object):
    def __init__(self, configs):
        self.configs = configs
        self.num_hidden = [int(x) for x in configs.num_hidden.split(',')]
        self.num_layers = len(self.num_hidden)
        networks_map = {
            'MotionRNN_PredRNN': MotionRNN_PredRNN.RNN,
        }

        if configs.model_name in networks_map:
            Network = networks_map[configs.model_name]
            self.network = Network(self.num_layers, self.num_hidden, configs).to(configs.device)
        else:
            raise ValueError('Unknown network name: %s' % configs.model_name)

        self.optimizer = Adam(self.network.parameters(), lr=configs.lr)

    def save(self, itr):
        stats = {}
        stats['net_param'] = self.network.state_dict()
        checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt' + '-' + str(itr))
        torch.save(stats, checkpoint_path)
        print("save model to %s" % checkpoint_path)

    def load(self, checkpoint_path):
        print('load model:', checkpoint_path)
        stats = torch.load(checkpoint_path)
        self.network.load_state_dict(stats['net_param'])

    def train(self, frames, mask):
        frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
        mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
        self.optimizer.zero_grad()
        next_frames, loss = self.network(frames_tensor, mask_tensor)
        loss.backward()
        self.optimizer.step()
        return loss.detach().cpu().numpy()

    def test(self, frames, mask):
        frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
        mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
        next_frames, _ = self.network(frames_tensor, mask_tensor)
        return next_frames.detach().cpu().numpy()
--------------------------------------------------------------------------------
/core/trainer.py:
--------------------------------------------------------------------------------
import os.path
import datetime
import cv2
import numpy as np
from skimage.measure import compare_ssim
from core.utils import metrics
from core.utils import preprocess


def train(model, ims, real_input_flag, configs, itr):
    cost = model.train(ims, real_input_flag)
    if configs.reverse_input:
        ims_rev = np.flip(ims, axis=1).copy()
        cost += model.train(ims_rev, real_input_flag)
        cost = cost / 2

    if itr % configs.display_interval == 0:
        print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr))
        print('training loss: ' + str(cost))


def test(model, test_input_handle, configs, itr):
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr))
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    os.mkdir(res_path)
    avg_mse = 0
    batch_id = 0
    img_mse, ssim = [], []
    csi20, csi30, csi40, csi50 = [], [], [], []

    for i in range(configs.total_length - configs.input_length):
        img_mse.append(0)
        ssim.append(0)
        if configs.dataset_name == 'echo' or configs.dataset_name == 'guangzhou':
            csi20.append(0)
            csi30.append(0)
            csi40.append(0)
            csi50.append(0)

    mask_input = configs.input_length

    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - mask_input - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))

    while not test_input_handle.no_batch_left():
        batch_id = batch_id + 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)

        img_gen = model.test(test_dat, real_input_flag)

        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        output_length = configs.total_length - configs.input_length
        img_gen_length = img_gen.shape[1]
        img_out = img_gen[:, -output_length:]

        # MSE per frame
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = img_out[:, i, :, :, :]
            gx = np.maximum(gx, 0)
            gx = np.minimum(gx, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse
            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            if configs.dataset_name == 'echo' or configs.dataset_name == 'guangzhou':
                csi20[i] += metrics.cal_csi(pred_frm, real_frm, 20)
                csi30[i] += metrics.cal_csi(pred_frm, real_frm, 30)
                csi40[i] += metrics.cal_csi(pred_frm, real_frm, 40)
                csi50[i] += metrics.cal_csi(pred_frm, real_frm, 50)

            for b in range(configs.batch_size):
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True, multichannel=True)
                ssim[i] += score

        # save prediction examples
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                file_name = os.path.join(path, name)
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(img_gen_length):
                name = 'pd' + str(i + 1 + configs.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = img_gen[0, i, :, :, :]
                img_pd = np.maximum(img_pd, 0)
                img_pd = np.minimum(img_pd, 1)
                img_pd = np.uint8(img_pd * 255)
                cv2.imwrite(file_name, img_pd)
        test_input_handle.next()

    avg_mse = avg_mse / (batch_id * configs.batch_size)
    print('mse per seq: ' + str(avg_mse))
    for i in range(configs.total_length - configs.input_length):
        print(img_mse[i] / (batch_id * configs.batch_size))

    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(configs.total_length - configs.input_length):
        print(ssim[i])

    if configs.dataset_name == 'echo' or configs.dataset_name == 'guangzhou':
        csi20 = np.asarray(csi20, dtype=np.float32) / batch_id
        csi30 = np.asarray(csi30, dtype=np.float32) / batch_id
        csi40 = np.asarray(csi40, dtype=np.float32) / batch_id
        csi50 = np.asarray(csi50, dtype=np.float32) / batch_id
        print('csi20 per frame: ' + str(np.mean(csi20)))
        for i in range(configs.total_length - configs.input_length):
            print(csi20[i])
        print('csi30 per frame: ' + str(np.mean(csi30)))
        for i in range(configs.total_length - configs.input_length):
            print(csi30[i])
        print('csi40 per frame: ' + str(np.mean(csi40)))
        for i in range(configs.total_length - configs.input_length):
            print(csi40[i])
        print('csi50 per frame: ' + str(np.mean(csi50)))
        for i in range(configs.total_length - configs.input_length):
            print(csi50[i])
--------------------------------------------------------------------------------
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/metrics.py:
--------------------------------------------------------------------------------
__author__ = 'yunbo'

import numpy as np
import copy


def cal_csi(pd, gt, level):
    # pd, gt: [w, h] frames; binarize at `level` after the (x + 30) / 2 rescaling.
    pdf = pd.astype(np.float32)
    gtf = gt.astype(np.float32)
    pd_ = np.zeros(pd.shape)
    gt_ = np.zeros(gt.shape)
    pd_[(pdf + 30) / 2 >= level] = 1
    gt_[(gtf + 30) / 2 >= level] = 1
    csi_ = pd_ + gt_
    if (csi_ >= 1).sum() == 0:
        return 0.0
    # CSI = hits / (hits + misses + false alarms)
    return float((csi_ == 2).sum()) / float((csi_ >= 1).sum())


def cal_far(pd, gt, level):
    # pd, gt: [w, h] frames; binarize at `level` after the (x + 30) / 2 rescaling.
    pdf = pd.astype(np.float32)
    gtf = gt.astype(np.float32)
    pd_ = np.zeros(pd.shape)
    gt_ = np.zeros(gt.shape)
    pd_[(pdf + 30) / 2 >= level] = 1
    gt_[(gtf + 30) / 2 >= level] = 1
    csi_ = pd_ + gt_
    tmp = copy.deepcopy(csi_)
    tmp[tmp == 2] = 1
    falsealarm = tmp - gt_
    if pd_.sum() == 0:
        return 0.0
    # FAR = false alarms / total predicted positives
    return float((falsealarm == 1).sum()) / float((pd_ >= 1).sum())
--------------------------------------------------------------------------------
/core/utils/preprocess.py:
--------------------------------------------------------------------------------
__author__ = 'yunbo'

import numpy as np


def reshape_patch(img_tensor, patch_size):
    assert 5 == img_tensor.ndim
    batch_size = np.shape(img_tensor)[0]
    seq_length = np.shape(img_tensor)[1]
    img_height = np.shape(img_tensor)[2]
    img_width = np.shape(img_tensor)[3]
    num_channels = np.shape(img_tensor)[4]
    a = np.reshape(img_tensor, [batch_size, seq_length,
                                img_height // patch_size, patch_size,
                                img_width // patch_size, patch_size,
                                num_channels])
    b = np.transpose(a, [0, 1, 2, 4, 3, 5, 6])
    patch_tensor = np.reshape(b, [batch_size, seq_length,
                                  img_height // patch_size,
                                  img_width // patch_size,
                                  patch_size * patch_size * num_channels])
    return patch_tensor


def reshape_patch_back(patch_tensor, patch_size):
    assert 5 == patch_tensor.ndim
    batch_size = np.shape(patch_tensor)[0]
    seq_length = np.shape(patch_tensor)[1]
    patch_height = np.shape(patch_tensor)[2]
    patch_width = np.shape(patch_tensor)[3]
    channels = np.shape(patch_tensor)[4]
    img_channels = channels // (patch_size * patch_size)
    a = np.reshape(patch_tensor, [batch_size, seq_length,
                                  patch_height, patch_width,
                                  patch_size, patch_size,
                                  img_channels])
    b = np.transpose(a, [0, 1, 2, 4, 3, 5, 6])
    img_tensor = np.reshape(b, [batch_size, seq_length,
                                patch_height * patch_size,
                                patch_width * patch_size,
                                img_channels])
    return img_tensor
--------------------------------------------------------------------------------
/human_script/MotionRNN_PredRNN_human_train.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=0
cd ..
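# Note: --train_data_paths / --valid_data_paths below should point to your
# local copy of the Human3.6M frames (the 'human' loader reads a single
# directory of JPEG frames).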
python -u run.py \
    --is_training 1 \
    --device cuda \
    --dataset_name human \
    --train_data_paths ./data/human/human \
    --valid_data_paths ./data/human/human \
    --save_dir checkpoints/human_MotionRNN_PredRNN_check \
    --gen_frm_dir results/human_MotionRNN_PredRNN_check \
    --model_name MotionRNN_PredRNN \
    --reverse_input 1 \
    --img_height 128 \
    --img_width 128 \
    --img_channel 3 \
    --input_length 4 \
    --total_length 8 \
    --num_hidden 64,64,64,64 \
    --filter_size 5 \
    --stride 1 \
    --patch_size 4 \
    --layer_norm 0 \
    --scheduled_sampling 1 \
    --sampling_stop_iter 50000 \
    --sampling_start_value 1.0 \
    --sampling_changing_rate 0.00002 \
    --lr 0.0003 \
    --batch_size 8 \
    --max_iterations 80000 \
    --display_interval 100 \
    --test_interval 5000 \
    --snapshot_interval 5000
--------------------------------------------------------------------------------
/pic/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/pic/architecture.png
--------------------------------------------------------------------------------
/pic/motion_decomp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/pic/motion_decomp.png
--------------------------------------------------------------------------------
/pic/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/MotionRNN/002415f4e7384f14fd20501b76febe8c5caaca7e/pic/vis.png
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
__author__ = 'yunbo'

import os
import shutil
import argparse
import numpy as np
from core.data_provider import datasets_factory
from core.models.model_factory import Model
from core.utils import preprocess
import core.trainer as trainer

# -----------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch video prediction model - MotionRNN (PredRNN backbone)')

# training/test
parser.add_argument('--is_training', type=int, default=1)
parser.add_argument('--device', type=str, default='cpu:0')

# data
parser.add_argument('--dataset_name', type=str, default='mnist')
parser.add_argument('--train_data_paths', type=str, default='data/moving-mnist-example/moving-mnist-train.npz')
parser.add_argument('--valid_data_paths', type=str, default='data/moving-mnist-example/moving-mnist-valid.npz')
parser.add_argument('--save_dir', type=str, default='checkpoints/mnist_predrnn')
parser.add_argument('--gen_frm_dir', type=str, default='results/mnist_predrnn')
parser.add_argument('--input_length', type=int, default=10)
parser.add_argument('--total_length', type=int, default=20)
parser.add_argument('--img_height', type=int, default=64)
parser.add_argument('--img_width', type=int, default=64)
parser.add_argument('--img_channel', type=int, default=1)

# model
parser.add_argument('--model_name', type=str, default='MotionRNN_PredRNN')
parser.add_argument('--pretrained_model', type=str, default='')
parser.add_argument('--num_hidden', type=str, default='64,64,64,64')
parser.add_argument('--filter_size', type=int, default=5)
parser.add_argument('--stride', type=int, default=1)
parser.add_argument('--patch_size', type=int, default=4)
parser.add_argument('--layer_norm', type=int, default=1)
parser.add_argument('--decouple_beta', type=float, default=0.1)

# scheduled sampling
parser.add_argument('--scheduled_sampling', type=int, default=1)
parser.add_argument('--sampling_stop_iter', type=int, default=50000)
parser.add_argument('--sampling_start_value', type=float, default=1.0)
parser.add_argument('--sampling_changing_rate', type=float, default=0.00002)

# optimization
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--reverse_input', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--max_iterations', type=int, default=80000)
parser.add_argument('--display_interval', type=int, default=100)
parser.add_argument('--test_interval', type=int, default=5000)
parser.add_argument('--snapshot_interval', type=int, default=5000)
parser.add_argument('--num_save_samples', type=int, default=10)
parser.add_argument('--n_gpu', type=int, default=1)

args = parser.parse_args()
print(args)


def schedule_sampling(eta, itr):
    zeros = np.zeros((args.batch_size,
                      args.total_length - args.input_length - 1,
                      args.img_width // args.patch_size,
                      args.img_width // args.patch_size,
                      args.patch_size ** 2 * args.img_channel))
    if not args.scheduled_sampling:
        return 0.0, zeros

    if itr < args.sampling_stop_iter:
        eta -= args.sampling_changing_rate
    else:
        eta = 0.0
    random_flip = np.random.random_sample(
        (args.batch_size, args.total_length - args.input_length - 1))
    true_token = (random_flip < eta)
    ones = np.ones((args.img_width // args.patch_size,
                    args.img_width // args.patch_size,
                    args.patch_size ** 2 * args.img_channel))
    zeros = np.zeros((args.img_width // args.patch_size,
                      args.img_width // args.patch_size,
                      args.patch_size ** 2 * args.img_channel))
    real_input_flag = []
    for i in range(args.batch_size):
        for j in range(args.total_length - args.input_length - 1):
            if true_token[i, j]:
                real_input_flag.append(ones)
            else:
                real_input_flag.append(zeros)
    real_input_flag = np.array(real_input_flag)
    real_input_flag = np.reshape(real_input_flag,
                                 (args.batch_size,
                                  args.total_length - args.input_length - 1,
                                  args.img_width // args.patch_size,
                                  args.img_width // args.patch_size,
                                  args.patch_size ** 2 * args.img_channel))
    return eta, real_input_flag


def train_wrapper(model):
    if args.pretrained_model:
        model.load(args.pretrained_model)
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width,
        seq_length=args.total_length, is_training=True)

    eta = args.sampling_start_value

    for itr in range(1, args.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, args.patch_size)

        eta, real_input_flag = schedule_sampling(eta, itr)
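        # eta decays linearly from sampling_start_value to 0 over
        # sampling_stop_iter iterations; real_input_flag marks, per future
        # time step, whether the model sees the ground-truth frame (1) or
        # its own previous prediction (0).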

        trainer.train(model, ims, real_input_flag, args, itr)

        if itr % args.snapshot_interval == 0:
            model.save(itr)

        if itr % args.test_interval == 0:
            trainer.test(model, test_input_handle, args, itr)

        train_input_handle.next()


def test_wrapper(model):
    model.load(args.pretrained_model)
    test_input_handle = datasets_factory.data_provider(
        args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width,
        seq_length=args.total_length, is_training=False)
    trainer.test(model, test_input_handle, args, 'test_result')


if os.path.exists(args.save_dir):
    shutil.rmtree(args.save_dir)
os.makedirs(args.save_dir)

if os.path.exists(args.gen_frm_dir):
    shutil.rmtree(args.gen_frm_dir)
os.makedirs(args.gen_frm_dir)

# gpu_list = np.asarray(os.environ.get('CUDA_VISIBLE_DEVICES', '-1').split(','), dtype=np.int32)
# args.n_gpu = len(gpu_list)
print('Initializing models')

model = Model(args)

if args.is_training:
    train_wrapper(model)
else:
    test_wrapper(model)
--------------------------------------------------------------------------------