├── README.md
├── bair_script
│   ├── predrnn_bair_train.sh
│   └── predrnn_v2_bair_train.sh
├── core
│   ├── __init__.py
│   ├── data_provider
│   │   ├── __init__.py
│   │   ├── bair.py
│   │   ├── datasets_factory.py
│   │   ├── kth_action.py
│   │   └── mnist.py
│   ├── layers
│   │   ├── SpatioTemporalLSTMCell.py
│   │   ├── SpatioTemporalLSTMCell_action.py
│   │   ├── SpatioTemporalLSTMCell_v2.py
│   │   ├── SpatioTemporalLSTMCell_v2_action.py
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── action_cond_predrnn.py
│   │   ├── action_cond_predrnn_v2.py
│   │   ├── model_factory.py
│   │   ├── predrnn.py
│   │   └── predrnn_v2.py
│   ├── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── metrics.py
│       ├── preprocess.py
│       └── tsne.py
├── kth_script
│   ├── predrnn_kth_train.sh
│   └── predrnn_v2_kth_train.sh
├── mnist_script
│   ├── predrnn_mnist_train.sh
│   └── predrnn_v2_mnist_train.sh
├── pic
│   ├── BAIR_results.png
│   ├── Traffic4Cast.png
│   ├── action_based.png
│   ├── bair.png
│   ├── decouple.png
│   ├── kth.png
│   ├── mnist.png
│   ├── network.png
│   ├── radar.png
│   ├── response.png
│   └── rss.png
└── run.py
/README.md:
--------------------------------------------------------------------------------
1 | # PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning (TPAMI 2022)
2 | 
3 | The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems.
4 | 
5 | ## Initial version at NeurIPS 2017
6 | 
7 | This repo first provides a PyTorch implementation of **PredRNN** (2017) [[paper](https://papers.nips.cc/paper/6689-predrnn-recurrent-neural-networks-for-predictive-learning-using-spatiotemporal-lstms)], a recurrent network with a pair of memory cells that operate in nearly independent transitions and are finally combined into unified representations of the complex environment.
8 | 
9 | Concretely, besides the original memory cell of the LSTM, this network features a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the visual dynamics learned at different levels of the RNN to communicate.
10 | 
11 | ## New in PredRNN-V2 at TPAMI 2022
12 | 
13 | This repo also includes the implementation of **PredRNN-V2** [[paper](https://arxiv.org/pdf/2103.09504.pdf)], which improves PredRNN in the following three aspects.
14 | 
15 | 
16 | #### 1. Memory-Decoupled ST-LSTM
17 | 
18 | We find that the pair of memory cells in PredRNN contain undesirable, redundant features, and thus present a memory decoupling loss to encourage them to learn modular structures of visual dynamics (a short code sketch of this loss is given at the end of this section).
19 | 
20 | ![decouple](./pic/decouple.png)
21 | 
22 | #### 2. Reverse Scheduled Sampling
23 | 
24 | Reverse scheduled sampling is a new curriculum learning strategy for seq-to-seq RNNs. As opposed to conventional scheduled sampling, it gradually changes the training process of the PredRNN encoder from using the previously generated frame to using the previous ground-truth frame. **Benefit:** It forces the model to learn long-term dynamics from the context frames.
25 | 
26 | [comment]: 
27 | 
28 | #### 3. Action-Conditioned Video Prediction
29 | 
30 | We further extend PredRNN to action-conditioned video prediction. By fusing the actions with the hidden states, PredRNN and PredRNN-V2 show highly competitive performance in long-term forecasting, and they can potentially serve as the base dynamics model in model-based visual control.
31 | 
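As a concrete illustration of the memory decoupling loss introduced in item 1 above, here is a minimal sketch (the function and argument names are illustrative, not the repo's exact training code; the corresponding computation lives inside the V2 model classes, e.g. `core/models/action_cond_predrnn_v2.py`). The per-step increments of the two memory cells are projected by a shared 1×1 convolution, flattened over space, and penalized for being cosine-similar:

```python
import torch
import torch.nn.functional as F

def memory_decoupling_loss(delta_c, delta_m, adapter):
    """delta_c / delta_m: increments i_t * g_t and i_t' * g_t' of the C and M memory cells,
    shaped [batch, channels, height, width]; adapter: a shared 1x1 nn.Conv2d projection."""
    dc = F.normalize(adapter(delta_c).flatten(start_dim=2), dim=2)  # unit norm over spatial positions
    dm = F.normalize(adapter(delta_m).flatten(start_dim=2), dim=2)
    # Penalizing |cos| pushes the two memories toward decorrelated ("decoupled") features;
    # the training objective adds this term to the prediction loss with weight --decouple_beta.
    return torch.abs(F.cosine_similarity(dc, dm, dim=2)).mean()
```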
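For the action-conditioned setting, the sketch below shows one way to pack per-frame actions as extra input channels (the 64×64 resolution and 4-dimensional action follow the BAIR dataloader in `core/data_provider/bair.py`; the variable names and batch sizes are illustrative). The action-conditioned models later split these channels off again and fuse them with the hidden states:

```python
import numpy as np

batch, length = 4, 12                                             # illustrative sizes
frames  = np.zeros((batch, length, 64, 64, 3), dtype=np.float32)  # RGB frames scaled to [0, 1]
actions = np.zeros((batch, length, 4), dtype=np.float32)          # one 4-D action per time step

# Broadcast each action over the 64x64 grid and append it as extra channels, giving the
# [batch, length, height, width, image + action channels] layout the models expect.
action_maps = np.tile(actions[:, :, None, None, :], (1, 1, 64, 64, 1))
model_input = np.concatenate([frames, action_maps], axis=-1)      # (4, 12, 64, 64, 7)
```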
32 | We show quantitative results on the BAIR robot pushing dataset for predicting 28 future frames from 2 observations.
33 | 
34 | ![action](./pic/action_based.png)
35 | 
36 | ## Showcases
37 | 
38 | Moving MNIST
39 | 
40 | ![mnist](./pic/mnist.png)
41 | 
42 | KTH
43 | 
44 | ![kth](./pic/kth.png)
45 | 
46 | BAIR (we zoom in on the area in the red box)
47 | 
48 | ![bair](./pic/bair.png)
49 | 
50 | Traffic4Cast
51 | 
52 | ![Traffic4Cast](./pic/Traffic4Cast.png)
53 | 
54 | Radar echoes
55 | 
56 | ![radar](./pic/radar.png)
57 | 
58 | ## Quantitative results on Moving MNIST and KTH in LPIPS
59 | 
60 | LPIPS correlates more closely with human perceptual judgments; lower is better.
61 | 
62 | | | Moving MNIST | KTH action |
63 | | ---- | ---- | ---- |
64 | | PredRNN | 0.109 | 0.204 |
65 | | PredRNN-V2 | 0.071 | 0.139 |
66 | 
67 | ## Quantitative results on Traffic4Cast (Berlin)
68 | 
69 | | | MSE (×10^{-3}) |
70 | | ---------------- | --------------------- |
71 | | U-Net | 6.992 |
72 | | CrevNet | 6.789 |
73 | | U-Net+PredRNN-V2 | **5.135** |
74 | 
75 | [comment]:<## Quantitative results on the action-conditioned BAIR dataset>
76 | 
77 | [comment]: 
78 | 
79 | [comment]: 
80 | 
81 | 
82 | ## Get Started
83 | 
84 | 1. Install Python 3.6 and PyTorch 1.9.0 for the main code. Also install TensorFlow 2.1.0 for the BAIR dataloader.
85 | 
86 | 2. Download data. This repo contains code for three datasets: the [Moving MNIST dataset](https://onedrive.live.com/?authkey=%21AGzXjcOlzTQw158&id=FF7F539F0073B9E2%21124&cid=FF7F539F0073B9E2), the [KTH action dataset](https://drive.google.com/drive/folders/1_M1O4TuQOhYcNdXXuNoNjYyzGrSM9pBF?usp=sharing), and the BAIR dataset (30.1GB), which can be obtained by:
87 | 
88 | ```
89 | wget http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar
90 | ```
91 | 
92 | 3. Train the model. You can use the following bash scripts to train the models. The learned model will be saved in the `--save_dir` folder.
93 | The generated future frames will be saved in the `--gen_frm_dir` folder.
94 | 
95 | 4. You can get **pretrained models** from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/72241e0046a74f81bf29/) or [Google Drive](https://drive.google.com/drive/folders/1jaEHcxo_UgvgwEWKi0ygX1SbODGz6PWw).
96 | ```
97 | cd mnist_script/
98 | sh predrnn_mnist_train.sh
99 | sh predrnn_v2_mnist_train.sh
100 | 
101 | cd kth_script/
102 | sh predrnn_kth_train.sh
103 | sh predrnn_v2_kth_train.sh
104 | 
105 | cd bair_script/
106 | sh predrnn_bair_train.sh
107 | sh predrnn_v2_bair_train.sh
108 | ```
109 | 
110 | ## Citation
111 | 
112 | If you find this repo useful, please cite the following papers.
113 | ```
114 | @inproceedings{wang2017predrnn,
115 | title={{PredRNN}: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal {LSTM}s},
116 | author={Wang, Yunbo and Long, Mingsheng and Wang, Jianmin and Gao, Zhifeng and Yu, Philip S},
117 | booktitle={Advances in Neural Information Processing Systems},
118 | pages={879--888},
119 | year={2017}
120 | }
121 | 
122 | @misc{wang2021predrnn,
123 | title={{PredRNN}: A Recurrent Neural Network for Spatiotemporal Predictive Learning},
124 | author={Wang, Yunbo and Wu, Haixu and Zhang, Jianjin and Gao, Zhifeng and Wang, Jianmin and Yu, Philip S and Long, Mingsheng},
125 | year={2021},
126 | eprint={2103.09504},
127 | archivePrefix={arXiv},
128 | }
129 | ```
130 | 
131 | 
--------------------------------------------------------------------------------
/bair_script/predrnn_bair_train.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | cd ..
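# Notes on the flags below (semantics taken from the README and core/models/action_cond_predrnn.py):
#   --input_length 2 / --total_length 12 : 2 observed frames in each 12-frame training sequence
#   --num_hidden 64,64,64,64             : four stacked ST-LSTM layers with 64 channels each
#   --conv_on_input 1 / --res_on_conv 1  : strided-conv encoder and transposed-conv decoder around the
#                                          recurrent core, with residual links between them
#   --save_dir / --gen_frm_dir           : output folders for checkpoints and generated frames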
3 | python -u run.py \ 4 | --is_training 1 \ 5 | --device cuda \ 6 | --dataset_name bair \ 7 | --train_data_paths /data/Action-BAIR/ \ 8 | --valid_data_paths /data/Action-BAIR/ \ 9 | --save_dir checkpoints/bair_action_cond_predrnn \ 10 | --gen_frm_dir results/bair_action_cond_predrnn \ 11 | --model_name action_cond_predrnn \ 12 | --reverse_input 1 \ 13 | --img_width 64 \ 14 | --img_channel 3 \ 15 | --input_length 2 \ 16 | --total_length 12 \ 17 | --num_hidden 64,64,64,64 \ 18 | --filter_size 5 \ 19 | --stride 1 \ 20 | --patch_size 1 \ 21 | --layer_norm 0 \ 22 | --decouple_beta 0.1 \ 23 | --reverse_scheduled_sampling 1 \ 24 | --r_sampling_step_1 25000 \ 25 | --r_sampling_step_2 50000 \ 26 | --r_exp_alpha 2500 \ 27 | --lr 0.0001 \ 28 | --batch_size 16 \ 29 | --max_iterations 80000 \ 30 | --display_interval 100 \ 31 | --test_interval 5000 \ 32 | --snapshot_interval 5000 \ 33 | --conv_on_input 1 \ 34 | --res_on_conv 1 -------------------------------------------------------------------------------- /bair_script/predrnn_v2_bair_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | cd .. 3 | python -u run.py \ 4 | --is_training 1 \ 5 | --device cuda \ 6 | --dataset_name bair \ 7 | --train_data_paths /data/Action-BAIR/ \ 8 | --valid_data_paths /data/Action-BAIR/ \ 9 | --save_dir checkpoints/bair_action_cond_predrnn_v2 \ 10 | --gen_frm_dir results/bair_action_cond_predrnn_v2 \ 11 | --model_name action_cond_predrnn_v2 \ 12 | --reverse_input 1 \ 13 | --img_width 64 \ 14 | --img_channel 3 \ 15 | --input_length 2 \ 16 | --total_length 12 \ 17 | --num_hidden 128,128,128,128 \ 18 | --filter_size 5 \ 19 | --stride 1 \ 20 | --patch_size 1 \ 21 | --layer_norm 0 \ 22 | --decouple_beta 0.1 \ 23 | --reverse_scheduled_sampling 1 \ 24 | --r_sampling_step_1 25000 \ 25 | --r_sampling_step_2 50000 \ 26 | --r_exp_alpha 2500 \ 27 | --lr 0.0001 \ 28 | --batch_size 16 \ 29 | --max_iterations 80000 \ 30 | --display_interval 100 \ 31 | --test_interval 5000 \ 32 | --snapshot_interval 5000 \ 33 | --conv_on_input 1 \ 34 | --res_on_conv 1 \ 35 | # --pretrained_model ./checkpoints/bair_predrnn_v2/bair_model.ckpt -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/core/__init__.py -------------------------------------------------------------------------------- /core/data_provider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/core/data_provider/__init__.py -------------------------------------------------------------------------------- /core/data_provider/bair.py: -------------------------------------------------------------------------------- 1 | __author__ = 'jianjin' 2 | 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | import tensorflow as tf 7 | import logging 8 | import random 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class InputHandle: 14 | def __init__(self, datas, indices, configs): 15 | self.name = configs['name'] + ' iterator' 16 | self.minibatch_size = configs['batch_size'] 17 | self.image_height = configs['image_height'] 18 | self.image_width = configs['image_width'] 19 | self.datas = datas 20 | self.indices = indices 21 | 
self.current_position = 0 22 | self.current_batch_indices = [] 23 | self.current_input_length = configs['seq_length'] 24 | self.injection_action = configs['injection_action'] 25 | 26 | def total(self): 27 | return len(self.indices) 28 | 29 | def begin(self, do_shuffle=True): 30 | logger.info("Initialization for read data ") 31 | if do_shuffle: 32 | random.shuffle(self.indices) 33 | self.current_position = 0 34 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 35 | 36 | def next(self): 37 | self.current_position += self.minibatch_size 38 | if self.no_batch_left(): 39 | return None 40 | else: 41 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 42 | 43 | def no_batch_left(self): 44 | if self.current_position + self.minibatch_size >= self.total(): 45 | return True 46 | else: 47 | return False 48 | 49 | def get_batch(self): 50 | if self.no_batch_left(): 51 | logger.error( 52 | "There is no batch left in " + self.name + ". Consider to user iterators.begin() to rescan from the beginning of the iterators") 53 | return None 54 | input_batch = np.zeros( 55 | (self.minibatch_size, self.current_input_length, self.image_height, self.image_width, 7)).astype(np.float32) 56 | for i in range(self.minibatch_size): 57 | batch_ind = self.current_batch_indices[i] 58 | begin = batch_ind[-1] 59 | end = begin + self.current_input_length 60 | k = 0 61 | for serialized_example in tf.compat.v1.python_io.tf_record_iterator(batch_ind[0]): 62 | if k == batch_ind[1]: 63 | example = tf.train.Example() 64 | example.ParseFromString(serialized_example) 65 | break 66 | k += 1 67 | for j in range(begin, end): 68 | action_name = str(j) + '/action' 69 | action_value = np.array(example.features.feature[action_name].float_list.value) 70 | if action_value.shape == (0,): # End of frames/data 71 | print("error! 
" + str(batch_ind)) 72 | input_batch[i, j - begin, :, :, 3:] = np.stack([np.ones([64, 64]) * i for i in action_value], axis=2) 73 | 74 | # endeffector_pos_name = str(j) + '/endeffector_pos' 75 | # endeffector_pos_value = list(example.features.feature[endeffector_pos_name].float_list.value) 76 | # endeffector_positions = np.vstack((endeffector_positions, endeffector_pos_value)) 77 | 78 | aux1_image_name = str(j) + '/image_aux1/encoded' 79 | aux1_byte_str = example.features.feature[aux1_image_name].bytes_list.value[0] 80 | aux1_img = Image.frombytes('RGB', (64, 64), aux1_byte_str) 81 | aux1_arr = np.array(aux1_img.getdata()).reshape((aux1_img.size[1], aux1_img.size[0], 3)) 82 | 83 | # main_image_name = str(j) + '/image_main/encoded' 84 | # main_byte_str = example.features.feature[main_image_name].bytes_list.value[0] 85 | # main_img = Image.frombytes('RGB', (64, 64), main_byte_str) 86 | # main_arr = np.array(main_img.getdata()).reshape((main_img.size[1], main_img.size[0], 3)) 87 | 88 | input_batch[i, j - begin, :, :, :3] = aux1_arr.reshape(64, 64, 3) / 255 89 | input_batch = input_batch.astype(np.float32) 90 | return input_batch 91 | 92 | def print_stat(self): 93 | logger.info("Iterator Name: " + self.name) 94 | logger.info(" current_position: " + str(self.current_position)) 95 | logger.info(" Minibatch Size: " + str(self.minibatch_size)) 96 | logger.info(" total Size: " + str(self.total())) 97 | logger.info(" current_input_length: " + str(self.current_input_length)) 98 | 99 | 100 | class DataProcess: 101 | def __init__(self, configs): 102 | self.configs = configs 103 | self.train_data_path = configs['train_data_paths'] 104 | self.valid_data_path = configs['valid_data_paths'] 105 | self.image_height = configs['image_height'] 106 | self.image_width = configs['image_width'] 107 | self.seq_len = configs['seq_length'] 108 | 109 | def load_data(self, path, mode='train'): 110 | path = os.path.join(path[0], 'softmotion30_44k') 111 | if mode == 'train': 112 | path = os.path.join(path, 'train') 113 | elif mode == 'test': 114 | path = os.path.join(path, 'test') 115 | else: 116 | print("ERROR!") 117 | print('begin load data' + str(path)) 118 | 119 | video_fullpaths = [] 120 | indices = [] 121 | 122 | tfrecords = os.listdir(path) 123 | tfrecords.sort() 124 | num_pictures = 0 125 | 126 | for tfrecord in tfrecords: 127 | filepath = os.path.join(path, tfrecord) 128 | video_fullpaths.append(filepath) 129 | k = 0 130 | for serialized_example in tf.compat.v1.python_io.tf_record_iterator(os.path.join(path, tfrecord)): 131 | example = tf.train.Example() 132 | example.ParseFromString(serialized_example) 133 | i = 0 134 | while True: 135 | action_name = str(i) + '/action' 136 | action_value = np.array(example.features.feature[action_name].float_list.value) 137 | if action_value.shape == (0,): # End of frames/data 138 | break 139 | i += 1 140 | num_pictures += i 141 | for j in range(i - self.seq_len + 1): 142 | indices.append((filepath, k, j)) 143 | k += 1 144 | print("there are " + str(num_pictures) + " pictures") 145 | print("there are " + str(len(indices)) + " sequences") 146 | return video_fullpaths, indices 147 | 148 | def get_train_input_handle(self): 149 | train_data, train_indices = self.load_data(self.train_data_path, mode='train') 150 | return InputHandle(train_data, train_indices, self.configs) 151 | 152 | def get_test_input_handle(self): 153 | test_data, test_indices = self.load_data(self.valid_data_path, mode='test') 154 | return InputHandle(test_data, test_indices, self.configs) 155 | 
-------------------------------------------------------------------------------- /core/data_provider/datasets_factory.py: -------------------------------------------------------------------------------- 1 | from core.data_provider import kth_action, mnist, bair 2 | 3 | datasets_map = { 4 | 'mnist': mnist, 5 | 'action': kth_action, 6 | 'bair': bair, 7 | } 8 | 9 | 10 | def data_provider(dataset_name, train_data_paths, valid_data_paths, batch_size, 11 | img_width, seq_length, injection_action, is_training=True): 12 | if dataset_name not in datasets_map: 13 | raise ValueError('Name of dataset unknown %s' % dataset_name) 14 | train_data_list = train_data_paths.split(',') 15 | valid_data_list = valid_data_paths.split(',') 16 | if dataset_name == 'mnist': 17 | test_input_param = {'paths': valid_data_list, 18 | 'minibatch_size': batch_size, 19 | 'input_data_type': 'float32', 20 | 'is_output_sequence': True, 21 | 'name': dataset_name + 'test iterator'} 22 | test_input_handle = datasets_map[dataset_name].InputHandle(test_input_param) 23 | test_input_handle.begin(do_shuffle=False) 24 | if is_training: 25 | train_input_param = {'paths': train_data_list, 26 | 'minibatch_size': batch_size, 27 | 'input_data_type': 'float32', 28 | 'is_output_sequence': True, 29 | 'name': dataset_name + ' train iterator'} 30 | train_input_handle = datasets_map[dataset_name].InputHandle(train_input_param) 31 | train_input_handle.begin(do_shuffle=True) 32 | return train_input_handle, test_input_handle 33 | else: 34 | return test_input_handle 35 | 36 | if dataset_name == 'action': 37 | input_param = {'paths': valid_data_list, 38 | 'image_width': img_width, 39 | 'minibatch_size': batch_size, 40 | 'seq_length': seq_length, 41 | 'input_data_type': 'float32', 42 | 'name': dataset_name + ' iterator'} 43 | input_handle = datasets_map[dataset_name].DataProcess(input_param) 44 | if is_training: 45 | train_input_handle = input_handle.get_train_input_handle() 46 | train_input_handle.begin(do_shuffle=True) 47 | test_input_handle = input_handle.get_test_input_handle() 48 | test_input_handle.begin(do_shuffle=False) 49 | return train_input_handle, test_input_handle 50 | else: 51 | test_input_handle = input_handle.get_test_input_handle() 52 | test_input_handle.begin(do_shuffle=False) 53 | return test_input_handle 54 | 55 | if dataset_name == 'bair': 56 | test_input_param = {'valid_data_paths': valid_data_list, 57 | 'train_data_paths': train_data_list, 58 | 'batch_size': batch_size, 59 | 'image_width': img_width, 60 | 'image_height': img_width, 61 | 'seq_length': seq_length, 62 | 'injection_action': injection_action, 63 | 'input_data_type': 'float32', 64 | 'name': dataset_name + 'test iterator'} 65 | input_handle_test = datasets_map[dataset_name].DataProcess(test_input_param) 66 | test_input_handle = input_handle_test.get_test_input_handle() 67 | test_input_handle.begin(do_shuffle=False) 68 | if is_training: 69 | train_input_param = {'valid_data_paths': valid_data_list, 70 | 'train_data_paths': train_data_list, 71 | 'image_width': img_width, 72 | 'image_height': img_width, 73 | 'batch_size': batch_size, 74 | 'seq_length': seq_length, 75 | 'injection_action': injection_action, 76 | 'input_data_type': 'float32', 77 | 'name': dataset_name + ' train iterator'} 78 | input_handle_train = datasets_map[dataset_name].DataProcess(train_input_param) 79 | train_input_handle = input_handle_train.get_train_input_handle() 80 | train_input_handle.begin(do_shuffle=True) 81 | return train_input_handle, test_input_handle 82 | else: 83 | return 
test_input_handle -------------------------------------------------------------------------------- /core/data_provider/kth_action.py: -------------------------------------------------------------------------------- 1 | __author__ = 'gaozhifeng' 2 | import numpy as np 3 | import os 4 | import cv2 5 | from PIL import Image 6 | import logging 7 | import random 8 | from typing import Iterable, List 9 | from dataclasses import dataclass 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class InputHandle: 14 | def __init__(self, datas, indices, input_param): 15 | self.name = input_param['name'] 16 | self.input_data_type = input_param.get('input_data_type', 'float32') 17 | self.minibatch_size = input_param['minibatch_size'] 18 | self.image_width = input_param['image_width'] 19 | self.datas = datas 20 | self.indices = indices 21 | self.current_position = 0 22 | self.current_batch_indices = [] 23 | self.current_input_length = input_param['seq_length'] 24 | 25 | def total(self): 26 | return len(self.indices) 27 | 28 | def begin(self, do_shuffle=True): 29 | logger.info("Initialization for read data ") 30 | if do_shuffle: 31 | random.shuffle(self.indices) 32 | self.current_position = 0 33 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 34 | 35 | def next(self): 36 | self.current_position += self.minibatch_size 37 | if self.no_batch_left(): 38 | return None 39 | self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] 40 | 41 | def no_batch_left(self): 42 | if self.current_position + self.minibatch_size >= self.total(): 43 | return True 44 | else: 45 | return False 46 | 47 | def get_batch(self): 48 | if self.no_batch_left(): 49 | logger.error( 50 | "There is no batch left in " + self.name + ". Consider to user iterators.begin() to rescan from the beginning of the iterators") 51 | return None 52 | # Create batch of N videos of length L, i.e. 
shape = (N, L, w, h, c) 53 | # where w x h is the resolution and c the number of color channels 54 | input_batch = np.zeros( 55 | (self.minibatch_size, self.current_input_length, self.image_width, self.image_width, 1)).astype( 56 | self.input_data_type) 57 | for i in range(self.minibatch_size): 58 | batch_ind = self.current_batch_indices[i] 59 | begin = batch_ind 60 | end = begin + self.current_input_length 61 | data_slice = self.datas[begin:end, :, :, :] 62 | input_batch[i, :self.current_input_length, :, :, :] = data_slice 63 | 64 | input_batch = input_batch.astype(self.input_data_type) 65 | return input_batch 66 | 67 | def print_stat(self): 68 | logger.info("Iterator Name: " + self.name) 69 | logger.info(" current_position: " + str(self.current_position)) 70 | logger.info(" Minibatch Size: " + str(self.minibatch_size)) 71 | logger.info(" total Size: " + str(self.total())) 72 | logger.info(" current_input_length: " + str(self.current_input_length)) 73 | logger.info(" Input Data Type: " + str(self.input_data_type)) 74 | 75 | 76 | @dataclass 77 | class ActionFrameInfo: 78 | file_name: str 79 | file_path: str 80 | person_mark: int 81 | category_flag: int 82 | 83 | 84 | class DataProcess: 85 | def __init__(self, input_param): 86 | self.paths = input_param['paths'] # path to parent folder containing category dirs 87 | self.category_1 = ['boxing', 'handclapping', 'handwaving', 'walking'] 88 | self.category_2 = ['jogging', 'running'] 89 | self.categories = self.category_1 + self.category_2 90 | self.image_width = input_param['image_width'] 91 | 92 | # Hard coded training and test persons (prevent same person occurring in train - test set) 93 | self.train_person = ['01', '02', '03', '04', '05', '06', '07', '08', 94 | '09', '10', '11', '12', '13', '14', '15', '16'] 95 | self.test_person = ['17', '18', '19', '20', '21', '22', '23', '24', '25'] 96 | 97 | self.input_param = input_param 98 | self.seq_len = input_param['seq_length'] 99 | 100 | 101 | def generate_frames(self, root_path, person_ids: List[int]) -> Iterable[ActionFrameInfo]: 102 | """Generate frame info for all frames. 
103 | 104 | Parameters: 105 | person_ids: persons to include 106 | """ 107 | person_mark = 0 108 | for cat_dir in self.categories: # handwaving 109 | if cat_dir in self.category_1: 110 | frame_category_flag = 1 # 20 step 111 | elif cat_dir in self.category_2: 112 | frame_category_flag = 2 # 3 step 113 | else: 114 | raise Exception("category error!!!") 115 | 116 | cat_dir_path = os.path.join(root_path, cat_dir) 117 | cat_videos = os.listdir(cat_dir_path) 118 | 119 | for person_direction_video in cat_videos: 120 | person_id = person_direction_video[6:8] # chars 6-8 contan number 121 | if person_id not in person_ids: 122 | continue 123 | person_mark += 1 # identify all stored frames as belonging to this person + direction 124 | dir_path = os.path.join(cat_dir_path, person_direction_video) 125 | filelist = os.listdir(dir_path) 126 | filelist.sort() 127 | for frame_name in filelist: 128 | if frame_name.startswith('image') == False: 129 | continue 130 | yield ActionFrameInfo( 131 | file_name=frame_name, 132 | file_path=os.path.join(dir_path, frame_name), 133 | person_mark=person_mark, 134 | category_flag=frame_category_flag 135 | ) 136 | 137 | def load_data(self, paths, mode='train'): 138 | ''' 139 | frame -- action -- person_seq(a dir) 140 | :param paths: action_path list 141 | :return: 142 | ''' 143 | 144 | path = paths[0] 145 | if mode == 'train': 146 | mode_person_ids = self.train_person 147 | elif mode == 'test': 148 | mode_person_ids = self.test_person 149 | else: 150 | raise Exception("Unexpected mode: " + mode) 151 | print('begin load data' + str(path)) 152 | 153 | frames_file_name = [] 154 | frames_person_mark = [] # for each frame in the joint array, mark the person ID 155 | frames_category = [] 156 | 157 | # First count total number of frames 158 | # Do it without creating massive array: 159 | # all_frame_info = self.generate_frames(path, mode_person_ids) 160 | tot_num_frames = sum((1 for _ in self.generate_frames(path, mode_person_ids))) 161 | print(f"Preparing to load {tot_num_frames} video frames.") 162 | 163 | # Target array containing ALL RESIZED frames 164 | data = np.empty((tot_num_frames, self.image_width, self.image_width , 1), 165 | dtype=np.float32) # np.float32 166 | 167 | # Read, resize, and store video frames 168 | for i, frame in enumerate(self.generate_frames(path, mode_person_ids)): 169 | frame_im = Image.open(frame.file_path).convert('L') # int8 2D array 170 | 171 | # input type must be float32 for default interpolation method cv2.INTER_AREA 172 | frame_np = np.array(frame_im, dtype=np.uint8) # (1000, 1000) numpy array 173 | data[i,:,:,0] = (cv2.resize( 174 | frame_np, (self.image_width,self.image_width))/255.0).astype(np.float32) 175 | 176 | frames_file_name.append(frame.file_name) 177 | frames_person_mark.append(frame.person_mark) 178 | frames_category.append(frame.category_flag) 179 | 180 | # identify sequences of within the same video 181 | indices = [] 182 | seq_end_idx = len(frames_person_mark) - 1 183 | while seq_end_idx >= self.seq_len - 1: 184 | seq_start_idx = seq_end_idx - self.seq_len + 1 185 | if frames_person_mark[seq_end_idx] == frames_person_mark[seq_start_idx]: 186 | # Get person ID at the start and end of this sequence (of seq_len) 187 | end = int(frames_file_name[seq_end_idx][6:10]) 188 | start = int(frames_file_name[seq_start_idx][6:10]) 189 | 190 | # TODO: mode == 'test' 191 | if end - start == self.seq_len - 1: 192 | # Save index into OUT data array indicating start point of sequence 193 | indices.append(seq_start_idx) 194 | 195 | # The step 
size depends on the category 196 | if frames_category[seq_end_idx] == 1: 197 | seq_end_idx -= self.seq_len - 1 198 | elif frames_category[seq_end_idx] == 2: 199 | seq_end_idx -= 2 200 | else: 201 | raise Exception("category error 2 !!!") 202 | 203 | seq_end_idx -= 1 204 | 205 | print("there are " + str(data.shape[0]) + " pictures") 206 | print("there are " + str(len(indices)) + " sequences") 207 | return data, indices 208 | 209 | def get_train_input_handle(self): 210 | train_data, train_indices = self.load_data(self.paths, mode='train') 211 | return InputHandle(train_data, train_indices, self.input_param) 212 | 213 | def get_test_input_handle(self): 214 | test_data, test_indices = self.load_data(self.paths, mode='test') 215 | return InputHandle(test_data, test_indices, self.input_param) 216 | 217 | -------------------------------------------------------------------------------- /core/data_provider/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | class InputHandle: 5 | def __init__(self, input_param): 6 | self.paths = input_param['paths'] 7 | self.num_paths = len(input_param['paths']) 8 | self.name = input_param['name'] 9 | self.input_data_type = input_param.get('input_data_type', 'float32') 10 | self.output_data_type = input_param.get('output_data_type', 'float32') 11 | self.minibatch_size = input_param['minibatch_size'] 12 | self.is_output_sequence = input_param['is_output_sequence'] 13 | self.data = {} 14 | self.indices = {} 15 | self.current_position = 0 16 | self.current_batch_size = 0 17 | self.current_batch_indices = [] 18 | self.current_input_length = 0 19 | self.current_output_length = 0 20 | self.load() 21 | 22 | def load(self): 23 | dat_1 = np.load(self.paths[0]) 24 | for key in dat_1.keys(): 25 | self.data[key] = dat_1[key] 26 | if self.num_paths == 2: 27 | dat_2 = np.load(self.paths[1]) 28 | num_clips_1 = dat_1['clips'].shape[1] 29 | dat_2['clips'][:,:,0] += num_clips_1 30 | self.data['clips'] = np.concatenate( 31 | (dat_1['clips'], dat_2['clips']), axis=1) 32 | self.data['input_raw_data'] = np.concatenate( 33 | (dat_1['input_raw_data'], dat_2['input_raw_data']), axis=0) 34 | self.data['output_raw_data'] = np.concatenate( 35 | (dat_1['output_raw_data'], dat_2['output_raw_data']), axis=0) 36 | for key in self.data.keys(): 37 | print(key) 38 | print(self.data[key].shape) 39 | 40 | def total(self): 41 | return self.data['clips'].shape[1] 42 | 43 | def begin(self, do_shuffle = True): 44 | self.indices = np.arange(self.total(),dtype="int32") 45 | if do_shuffle: 46 | random.shuffle(self.indices) 47 | self.current_position = 0 48 | if self.current_position + self.minibatch_size <= self.total(): 49 | self.current_batch_size = self.minibatch_size 50 | else: 51 | self.current_batch_size = self.total() - self.current_position 52 | self.current_batch_indices = self.indices[ 53 | self.current_position:self.current_position + self.current_batch_size] 54 | self.current_input_length = max(self.data['clips'][0, ind, 1] for ind 55 | in self.current_batch_indices) 56 | self.current_output_length = max(self.data['clips'][1, ind, 1] for ind 57 | in self.current_batch_indices) 58 | 59 | def next(self): 60 | self.current_position += self.current_batch_size 61 | if self.no_batch_left(): 62 | return None 63 | if self.current_position + self.minibatch_size <= self.total(): 64 | self.current_batch_size = self.minibatch_size 65 | else: 66 | self.current_batch_size = self.total() - self.current_position 67 | 
self.current_batch_indices = self.indices[ 68 | self.current_position:self.current_position + self.current_batch_size] 69 | self.current_input_length = max(self.data['clips'][0, ind, 1] for ind 70 | in self.current_batch_indices) 71 | self.current_output_length = max(self.data['clips'][1, ind, 1] for ind 72 | in self.current_batch_indices) 73 | 74 | def no_batch_left(self): 75 | if self.current_position >= self.total() - self.current_batch_size: 76 | return True 77 | else: 78 | return False 79 | 80 | def input_batch(self): 81 | if self.no_batch_left(): 82 | return None 83 | input_batch = np.zeros( 84 | (self.current_batch_size, self.current_input_length) + 85 | tuple(self.data['dims'][0])).astype(self.input_data_type) 86 | input_batch = np.transpose(input_batch,(0,1,3,4,2)) 87 | for i in range(self.current_batch_size): 88 | batch_ind = self.current_batch_indices[i] 89 | begin = self.data['clips'][0, batch_ind, 0] 90 | end = self.data['clips'][0, batch_ind, 0] + \ 91 | self.data['clips'][0, batch_ind, 1] 92 | data_slice = self.data['input_raw_data'][begin:end, :, :, :] 93 | data_slice = np.transpose(data_slice,(0,2,3,1)) 94 | input_batch[i, :self.current_input_length, :, :, :] = data_slice 95 | input_batch = input_batch.astype(self.input_data_type) 96 | return input_batch 97 | 98 | def output_batch(self): 99 | if self.no_batch_left(): 100 | return None 101 | if(2 ,3) == self.data['dims'].shape: 102 | raw_dat = self.data['output_raw_data'] 103 | else: 104 | raw_dat = self.data['input_raw_data'] 105 | if self.is_output_sequence: 106 | if (1, 3) == self.data['dims'].shape: 107 | output_dim = self.data['dims'][0] 108 | else: 109 | output_dim = self.data['dims'][1] 110 | output_batch = np.zeros( 111 | (self.current_batch_size,self.current_output_length) + 112 | tuple(output_dim)) 113 | else: 114 | output_batch = np.zeros((self.current_batch_size, ) + 115 | tuple(self.data['dims'][1])) 116 | for i in range(self.current_batch_size): 117 | batch_ind = self.current_batch_indices[i] 118 | begin = self.data['clips'][1, batch_ind, 0] 119 | end = self.data['clips'][1, batch_ind, 0] + \ 120 | self.data['clips'][1, batch_ind, 1] 121 | if self.is_output_sequence: 122 | data_slice = raw_dat[begin:end, :, :, :] 123 | output_batch[i, : data_slice.shape[0], :, :, :] = data_slice 124 | else: 125 | data_slice = raw_dat[begin, :, :, :] 126 | output_batch[i,:, :, :] = data_slice 127 | output_batch = output_batch.astype(self.output_data_type) 128 | output_batch = np.transpose(output_batch, [0,1,3,4,2]) 129 | return output_batch 130 | 131 | def get_batch(self): 132 | input_seq = self.input_batch() 133 | output_seq = self.output_batch() 134 | batch = np.concatenate((input_seq, output_seq), axis=1) 135 | return batch 136 | -------------------------------------------------------------------------------- /core/layers/SpatioTemporalLSTMCell.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class SpatioTemporalLSTMCell(nn.Module): 7 | def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm): 8 | super(SpatioTemporalLSTMCell, self).__init__() 9 | 10 | self.num_hidden = num_hidden 11 | self.padding = filter_size // 2 12 | self._forget_bias = 1.0 13 | if layer_norm: 14 | self.conv_x = nn.Sequential( 15 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 16 | nn.LayerNorm([num_hidden * 7, width, width]) 17 | ) 18 | 
self.conv_h = nn.Sequential( 19 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 20 | nn.LayerNorm([num_hidden * 4, width, width]) 21 | ) 22 | self.conv_m = nn.Sequential( 23 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 24 | nn.LayerNorm([num_hidden * 3, width, width]) 25 | ) 26 | self.conv_o = nn.Sequential( 27 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 28 | nn.LayerNorm([num_hidden, width, width]) 29 | ) 30 | else: 31 | self.conv_x = nn.Sequential( 32 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 33 | ) 34 | self.conv_h = nn.Sequential( 35 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 36 | ) 37 | self.conv_m = nn.Sequential( 38 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 39 | ) 40 | self.conv_o = nn.Sequential( 41 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 42 | ) 43 | self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0, bias=False) 44 | 45 | def forward(self, x_t, h_t, c_t, m_t): 46 | x_concat = self.conv_x(x_t) 47 | h_concat = self.conv_h(h_t) 48 | m_concat = self.conv_m(m_t) 49 | i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1) 50 | i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1) 51 | i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1) 52 | 53 | i_t = torch.sigmoid(i_x + i_h) 54 | f_t = torch.sigmoid(f_x + f_h + self._forget_bias) 55 | g_t = torch.tanh(g_x + g_h) 56 | 57 | c_new = f_t * c_t + i_t * g_t 58 | 59 | i_t_prime = torch.sigmoid(i_x_prime + i_m) 60 | f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias) 61 | g_t_prime = torch.tanh(g_x_prime + g_m) 62 | 63 | m_new = f_t_prime * m_t + i_t_prime * g_t_prime 64 | 65 | mem = torch.cat((c_new, m_new), 1) 66 | o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem)) 67 | h_new = o_t * torch.tanh(self.conv_last(mem)) 68 | 69 | return h_new, c_new, m_new 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /core/layers/SpatioTemporalLSTMCell_action.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class SpatioTemporalLSTMCell(nn.Module): 8 | def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm): 9 | super(SpatioTemporalLSTMCell, self).__init__() 10 | 11 | self.num_hidden = num_hidden 12 | self.padding = filter_size // 2 13 | self._forget_bias = 1.0 14 | if layer_norm: 15 | self.conv_x = nn.Sequential( 16 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding), 17 | nn.LayerNorm([num_hidden * 7, width, width]) 18 | ) 19 | self.conv_h = nn.Sequential( 20 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 21 | nn.LayerNorm([num_hidden * 4, width, width]) 22 | ) 23 | self.conv_a = nn.Sequential( 24 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 25 | 
nn.LayerNorm([num_hidden * 4, width, width]) 26 | ) 27 | self.conv_m = nn.Sequential( 28 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding), 29 | nn.LayerNorm([num_hidden * 3, width, width]) 30 | ) 31 | self.conv_o = nn.Sequential( 32 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding), 33 | nn.LayerNorm([num_hidden, width, width]) 34 | ) 35 | else: 36 | self.conv_x = nn.Sequential( 37 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding), 38 | ) 39 | self.conv_h = nn.Sequential( 40 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 41 | ) 42 | self.conv_a = nn.Sequential( 43 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 44 | ) 45 | self.conv_m = nn.Sequential( 46 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding), 47 | ) 48 | self.conv_o = nn.Sequential( 49 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding), 50 | ) 51 | self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0) 52 | 53 | def forward(self, x_t, h_t, c_t, m_t, a_t): 54 | x_concat = self.conv_x(x_t) 55 | h_concat = self.conv_h(h_t) 56 | a_concat = self.conv_a(a_t) 57 | m_concat = self.conv_m(m_t) 58 | i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1) 59 | i_h, f_h, g_h, o_h = torch.split(h_concat * a_concat, self.num_hidden, dim=1) 60 | i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1) 61 | 62 | i_t = torch.sigmoid(i_x + i_h) 63 | f_t = torch.sigmoid(f_x + f_h + self._forget_bias) 64 | g_t = torch.tanh(g_x + g_h) 65 | 66 | c_new = f_t * c_t + i_t * g_t 67 | 68 | i_t_prime = torch.sigmoid(i_x_prime + i_m) 69 | f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias) 70 | g_t_prime = torch.tanh(g_x_prime + g_m) 71 | 72 | m_new = f_t_prime * m_t + i_t_prime * g_t_prime 73 | 74 | mem = torch.cat((c_new, m_new), 1) 75 | o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem)) 76 | h_new = o_t * torch.tanh(self.conv_last(mem)) 77 | 78 | return h_new, c_new, m_new 79 | -------------------------------------------------------------------------------- /core/layers/SpatioTemporalLSTMCell_v2.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class SpatioTemporalLSTMCell(nn.Module): 7 | def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm): 8 | super(SpatioTemporalLSTMCell, self).__init__() 9 | 10 | self.num_hidden = num_hidden 11 | self.padding = filter_size // 2 12 | self._forget_bias = 1.0 13 | if layer_norm: 14 | self.conv_x = nn.Sequential( 15 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 16 | nn.LayerNorm([num_hidden * 7, width, width]) 17 | ) 18 | self.conv_h = nn.Sequential( 19 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 20 | nn.LayerNorm([num_hidden * 4, width, width]) 21 | ) 22 | self.conv_m = nn.Sequential( 23 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 24 | nn.LayerNorm([num_hidden * 3, width, width]) 25 | ) 26 | self.conv_o = 
nn.Sequential( 27 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 28 | nn.LayerNorm([num_hidden, width, width]) 29 | ) 30 | else: 31 | self.conv_x = nn.Sequential( 32 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 33 | ) 34 | self.conv_h = nn.Sequential( 35 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 36 | ) 37 | self.conv_m = nn.Sequential( 38 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 39 | ) 40 | self.conv_o = nn.Sequential( 41 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding, bias=False), 42 | ) 43 | self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0, bias=False) 44 | 45 | 46 | def forward(self, x_t, h_t, c_t, m_t): 47 | x_concat = self.conv_x(x_t) 48 | h_concat = self.conv_h(h_t) 49 | m_concat = self.conv_m(m_t) 50 | i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1) 51 | i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1) 52 | i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1) 53 | 54 | i_t = torch.sigmoid(i_x + i_h) 55 | f_t = torch.sigmoid(f_x + f_h + self._forget_bias) 56 | g_t = torch.tanh(g_x + g_h) 57 | 58 | delta_c = i_t * g_t 59 | c_new = f_t * c_t + delta_c 60 | 61 | i_t_prime = torch.sigmoid(i_x_prime + i_m) 62 | f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias) 63 | g_t_prime = torch.tanh(g_x_prime + g_m) 64 | 65 | delta_m = i_t_prime * g_t_prime 66 | m_new = f_t_prime * m_t + delta_m 67 | 68 | mem = torch.cat((c_new, m_new), 1) 69 | o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem)) 70 | h_new = o_t * torch.tanh(self.conv_last(mem)) 71 | 72 | return h_new, c_new, m_new, delta_c, delta_m 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /core/layers/SpatioTemporalLSTMCell_v2_action.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class SpatioTemporalLSTMCell(nn.Module): 8 | def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm): 9 | super(SpatioTemporalLSTMCell, self).__init__() 10 | 11 | self.num_hidden = num_hidden 12 | self.padding = filter_size // 2 13 | self._forget_bias = 1.0 14 | if layer_norm: 15 | self.conv_x = nn.Sequential( 16 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding), 17 | nn.LayerNorm([num_hidden * 7, width, width]) 18 | ) 19 | self.conv_h = nn.Sequential( 20 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 21 | nn.LayerNorm([num_hidden * 4, width, width]) 22 | ) 23 | self.conv_a = nn.Sequential( 24 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 25 | nn.LayerNorm([num_hidden * 4, width, width]) 26 | ) 27 | self.conv_m = nn.Sequential( 28 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding), 29 | nn.LayerNorm([num_hidden * 3, width, width]) 30 | ) 31 | self.conv_o = nn.Sequential( 32 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, 
padding=self.padding), 33 | nn.LayerNorm([num_hidden, width, width]) 34 | ) 35 | else: 36 | self.conv_x = nn.Sequential( 37 | nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size, stride=stride, padding=self.padding), 38 | ) 39 | self.conv_h = nn.Sequential( 40 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 41 | ) 42 | self.conv_a = nn.Sequential( 43 | nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size, stride=stride, padding=self.padding), 44 | ) 45 | self.conv_m = nn.Sequential( 46 | nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size, stride=stride, padding=self.padding), 47 | ) 48 | self.conv_o = nn.Sequential( 49 | nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding), 50 | ) 51 | self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0) 52 | 53 | def forward(self, x_t, h_t, c_t, m_t, a_t): 54 | x_concat = self.conv_x(x_t) 55 | h_concat = self.conv_h(h_t) 56 | a_concat = self.conv_a(a_t) 57 | m_concat = self.conv_m(m_t) 58 | i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1) 59 | i_h, f_h, g_h, o_h = torch.split(h_concat * a_concat, self.num_hidden, dim=1) 60 | i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1) 61 | 62 | i_t = torch.sigmoid(i_x + i_h) 63 | f_t = torch.sigmoid(f_x + f_h + self._forget_bias) 64 | g_t = torch.tanh(g_x + g_h) 65 | 66 | delta_c = i_t * g_t 67 | c_new = f_t * c_t + delta_c 68 | 69 | i_t_prime = torch.sigmoid(i_x_prime + i_m) 70 | f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias) 71 | g_t_prime = torch.tanh(g_x_prime + g_m) 72 | 73 | delta_m = i_t_prime * g_t_prime 74 | m_new = f_t_prime * m_t + delta_m 75 | 76 | mem = torch.cat((c_new, m_new), 1) 77 | o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem)) 78 | h_new = o_t * torch.tanh(self.conv_last(mem)) 79 | 80 | return h_new, c_new, m_new, delta_c, delta_m 81 | -------------------------------------------------------------------------------- /core/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/core/layers/__init__.py -------------------------------------------------------------------------------- /core/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/core/models/__init__.py -------------------------------------------------------------------------------- /core/models/action_cond_predrnn.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | from core.layers.SpatioTemporalLSTMCell_action import SpatioTemporalLSTMCell 6 | 7 | 8 | class RNN(nn.Module): 9 | def __init__(self, num_layers, num_hidden, configs): 10 | super(RNN, self).__init__() 11 | 12 | self.configs = configs 13 | self.conv_on_input = self.configs.conv_on_input 14 | self.res_on_conv = self.configs.res_on_conv 15 | self.patch_height = configs.img_width // configs.patch_size 16 | self.patch_width = configs.img_width // configs.patch_size 17 | self.patch_ch = configs.img_channel * (configs.patch_size ** 2) 18 | self.action_ch = configs.num_action_ch 19 | self.rnn_height = self.patch_height 20 | self.rnn_width = 
self.patch_width 21 | 22 | if self.configs.conv_on_input == 1: 23 | self.rnn_height = self.patch_height // 4 24 | self.rnn_width = self.patch_width // 4 25 | self.conv_input1 = nn.Conv2d(self.patch_ch, num_hidden[0] // 2, 26 | configs.filter_size, 27 | stride=2, padding=configs.filter_size // 2, bias=False) 28 | self.conv_input2 = nn.Conv2d(num_hidden[0] // 2, num_hidden[0], configs.filter_size, stride=2, 29 | padding=configs.filter_size // 2, bias=False) 30 | self.action_conv_input1 = nn.Conv2d(self.action_ch, num_hidden[0] // 2, 31 | configs.filter_size, 32 | stride=2, padding=configs.filter_size // 2, bias=False) 33 | self.action_conv_input2 = nn.Conv2d(num_hidden[0] // 2, num_hidden[0], configs.filter_size, stride=2, 34 | padding=configs.filter_size // 2, bias=False) 35 | self.deconv_output1 = nn.ConvTranspose2d(num_hidden[num_layers - 1], num_hidden[num_layers - 1] // 2, 36 | configs.filter_size, stride=2, padding=configs.filter_size // 2, 37 | bias=False) 38 | self.deconv_output2 = nn.ConvTranspose2d(num_hidden[num_layers - 1] // 2, self.patch_ch, 39 | configs.filter_size, stride=2, padding=configs.filter_size // 2, 40 | bias=False) 41 | self.num_layers = num_layers 42 | self.num_hidden = num_hidden 43 | cell_list = [] 44 | self.beta = configs.decouple_beta 45 | self.MSE_criterion = nn.MSELoss().cuda() 46 | self.norm_criterion = nn.SmoothL1Loss().cuda() 47 | 48 | for i in range(num_layers): 49 | if i == 0: 50 | in_channel = self.patch_ch + self.action_ch if self.configs.conv_on_input == 0 else num_hidden[0] 51 | else: 52 | in_channel = num_hidden[i - 1] 53 | cell_list.append( 54 | SpatioTemporalLSTMCell(in_channel, num_hidden[i], self.rnn_width, 55 | configs.filter_size, configs.stride, configs.layer_norm) 56 | ) 57 | self.cell_list = nn.ModuleList(cell_list) 58 | if self.configs.conv_on_input == 0: 59 | self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.patch_ch + self.action_ch, 1, stride=1, 60 | padding=0, bias=False) 61 | 62 | def forward(self, all_frames, mask_true): 63 | # [batch, length, height, width, channel] -> [batch, length, channel, height, width] 64 | frames = all_frames.permute(0, 1, 4, 2, 3).contiguous() 65 | input_frames = frames[:, :, :self.patch_ch, :, :] 66 | input_actions = frames[:, :, self.patch_ch:, :, :] 67 | mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous() 68 | 69 | next_frames = [] 70 | h_t = [] 71 | c_t = [] 72 | 73 | for i in range(self.num_layers): 74 | zeros = torch.zeros( 75 | [self.configs.batch_size, self.num_hidden[i], self.rnn_height, self.rnn_width]).cuda() 76 | h_t.append(zeros) 77 | c_t.append(zeros) 78 | 79 | memory = torch.zeros([self.configs.batch_size, self.num_hidden[0], self.rnn_height, self.rnn_width]).cuda() 80 | 81 | for t in range(self.configs.total_length - 1): 82 | if t == 0: 83 | net = input_frames[:, t] 84 | else: 85 | net = mask_true[:, t - 1] * input_frames[:, t] + \ 86 | (1 - mask_true[:, t - 1]) * x_gen 87 | action = input_actions[:, t] 88 | 89 | if self.conv_on_input == 1: 90 | net_shape1 = net.size() 91 | net = self.conv_input1(net) 92 | if self.res_on_conv == 1: 93 | input_net1 = net 94 | net_shape2 = net.size() 95 | net = self.conv_input2(net) 96 | if self.res_on_conv == 1: 97 | input_net2 = net 98 | action = self.action_conv_input1(action) 99 | action = self.action_conv_input2(action) 100 | 101 | h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory, action) 102 | 103 | for i in range(1, self.num_layers): 104 | h_t[i], c_t[i], memory = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory, 
action) 105 | 106 | if self.conv_on_input == 1: 107 | if self.res_on_conv == 1: 108 | x_gen = self.deconv_output1(h_t[self.num_layers - 1] + input_net2, output_size=net_shape2) 109 | x_gen = self.deconv_output2(x_gen + input_net1, output_size=net_shape1) 110 | else: 111 | x_gen = self.deconv_output1(h_t[self.num_layers - 1], output_size=net_shape2) 112 | x_gen = self.deconv_output2(x_gen, output_size=net_shape1) 113 | else: 114 | x_gen = self.conv_last(h_t[self.num_layers - 1]) 115 | next_frames.append(x_gen) 116 | 117 | # [length, batch, channel, height, width] -> [batch, length, height, width, channel] 118 | next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous() 119 | loss = self.MSE_criterion(next_frames, all_frames[:, 1:, :, :, :next_frames.shape[4]]) 120 | next_frames = next_frames[:, :, :, :, :self.patch_ch] 121 | return next_frames, loss 122 | -------------------------------------------------------------------------------- /core/models/action_cond_predrnn_v2.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | from core.layers.SpatioTemporalLSTMCell_v2_action import SpatioTemporalLSTMCell 6 | import torch.nn.functional as F 7 | 8 | 9 | class RNN(nn.Module): 10 | def __init__(self, num_layers, num_hidden, configs): 11 | super(RNN, self).__init__() 12 | 13 | self.configs = configs 14 | self.conv_on_input = self.configs.conv_on_input 15 | self.res_on_conv = self.configs.res_on_conv 16 | self.patch_height = configs.img_width // configs.patch_size 17 | self.patch_width = configs.img_width // configs.patch_size 18 | self.patch_ch = configs.img_channel * (configs.patch_size ** 2) 19 | self.action_ch = configs.num_action_ch 20 | self.rnn_height = self.patch_height 21 | self.rnn_width = self.patch_width 22 | 23 | if self.configs.conv_on_input == 1: 24 | self.rnn_height = self.patch_height // 4 25 | self.rnn_width = self.patch_width // 4 26 | self.conv_input1 = nn.Conv2d(self.patch_ch, num_hidden[0] // 2, 27 | configs.filter_size, 28 | stride=2, padding=configs.filter_size // 2, bias=False) 29 | self.conv_input2 = nn.Conv2d(num_hidden[0] // 2, num_hidden[0], configs.filter_size, stride=2, 30 | padding=configs.filter_size // 2, bias=False) 31 | self.action_conv_input1 = nn.Conv2d(self.action_ch, num_hidden[0] // 2, 32 | configs.filter_size, 33 | stride=2, padding=configs.filter_size // 2, bias=False) 34 | self.action_conv_input2 = nn.Conv2d(num_hidden[0] // 2, num_hidden[0], configs.filter_size, stride=2, 35 | padding=configs.filter_size // 2, bias=False) 36 | self.deconv_output1 = nn.ConvTranspose2d(num_hidden[num_layers - 1], num_hidden[num_layers - 1] // 2, 37 | configs.filter_size, stride=2, padding=configs.filter_size // 2, 38 | bias=False) 39 | self.deconv_output2 = nn.ConvTranspose2d(num_hidden[num_layers - 1] // 2, self.patch_ch, 40 | configs.filter_size, stride=2, padding=configs.filter_size // 2, 41 | bias=False) 42 | self.num_layers = num_layers 43 | self.num_hidden = num_hidden 44 | cell_list = [] 45 | self.beta = configs.decouple_beta 46 | self.MSE_criterion = nn.MSELoss().cuda() 47 | self.norm_criterion = nn.SmoothL1Loss().cuda() 48 | 49 | for i in range(num_layers): 50 | if i == 0: 51 | in_channel = self.patch_ch + self.action_ch if self.configs.conv_on_input == 0 else num_hidden[0] 52 | else: 53 | in_channel = num_hidden[i - 1] 54 | cell_list.append( 55 | SpatioTemporalLSTMCell(in_channel, num_hidden[i], self.rnn_width, 56 | configs.filter_size, 
configs.stride, configs.layer_norm) 57 | ) 58 | self.cell_list = nn.ModuleList(cell_list) 59 | if self.configs.conv_on_input == 0: 60 | self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.patch_ch + self.action_ch, 1, stride=1, 61 | padding=0, bias=False) 62 | self.adapter = nn.Conv2d(num_hidden[num_layers - 1], num_hidden[num_layers - 1], 1, stride=1, padding=0, 63 | bias=False) 64 | 65 | def forward(self, all_frames, mask_true): 66 | # [batch, length, height, width, channel] -> [batch, length, channel, height, width] 67 | frames = all_frames.permute(0, 1, 4, 2, 3).contiguous() 68 | input_frames = frames[:, :, :self.patch_ch, :, :] 69 | input_actions = frames[:, :, self.patch_ch:, :, :] 70 | mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous() 71 | 72 | next_frames = [] 73 | h_t = [] 74 | c_t = [] 75 | delta_c_list = [] 76 | delta_m_list = [] 77 | 78 | for i in range(self.num_layers): 79 | zeros = torch.zeros( 80 | [self.configs.batch_size, self.num_hidden[i], self.rnn_height, self.rnn_width]).cuda() 81 | h_t.append(zeros) 82 | c_t.append(zeros) 83 | delta_c_list.append(zeros) 84 | delta_m_list.append(zeros) 85 | 86 | decouple_loss = [] 87 | memory = torch.zeros([self.configs.batch_size, self.num_hidden[0], self.rnn_height, self.rnn_width]).cuda() 88 | 89 | for t in range(self.configs.total_length - 1): 90 | if t == 0: 91 | net = input_frames[:, t] 92 | else: 93 | net = mask_true[:, t - 1] * input_frames[:, t] + \ 94 | (1 - mask_true[:, t - 1]) * x_gen 95 | action = input_actions[:, t] 96 | 97 | if self.conv_on_input == 1: 98 | net_shape1 = net.size() 99 | net = self.conv_input1(net) 100 | if self.res_on_conv == 1: 101 | input_net1 = net 102 | net_shape2 = net.size() 103 | net = self.conv_input2(net) 104 | if self.res_on_conv == 1: 105 | input_net2 = net 106 | action = self.action_conv_input1(action) 107 | action = self.action_conv_input2(action) 108 | 109 | h_t[0], c_t[0], memory, delta_c, delta_m = self.cell_list[0](net, h_t[0], c_t[0], memory, action) 110 | delta_c_list[0] = F.normalize(self.adapter(delta_c).view(delta_c.shape[0], delta_c.shape[1], -1), dim=2) 111 | delta_m_list[0] = F.normalize(self.adapter(delta_m).view(delta_m.shape[0], delta_m.shape[1], -1), dim=2) 112 | 113 | for i in range(1, self.num_layers): 114 | h_t[i], c_t[i], memory, delta_c, delta_m = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory, action) 115 | delta_c_list[i] = F.normalize(self.adapter(delta_c).view(delta_c.shape[0], delta_c.shape[1], -1), dim=2) 116 | delta_m_list[i] = F.normalize(self.adapter(delta_m).view(delta_m.shape[0], delta_m.shape[1], -1), dim=2) 117 | 118 | for i in range(0, self.num_layers): 119 | decouple_loss.append(torch.mean(torch.abs( 120 | torch.cosine_similarity(delta_c_list[i], delta_m_list[i], dim=2)))) 121 | if self.conv_on_input == 1: 122 | if self.res_on_conv == 1: 123 | x_gen = self.deconv_output1(h_t[self.num_layers - 1] + input_net2, output_size=net_shape2) 124 | x_gen = self.deconv_output2(x_gen + input_net1, output_size=net_shape1) 125 | else: 126 | x_gen = self.deconv_output1(h_t[self.num_layers - 1], output_size=net_shape2) 127 | x_gen = self.deconv_output2(x_gen, output_size=net_shape1) 128 | else: 129 | x_gen = self.conv_last(h_t[self.num_layers - 1]) 130 | next_frames.append(x_gen) 131 | 132 | decouple_loss = torch.mean(torch.stack(decouple_loss, dim=0)) 133 | # [length, batch, channel, height, width] -> [batch, length, height, width, channel] 134 | next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous() 135 | loss = 
self.MSE_criterion(next_frames, 136 | all_frames[:, 1:, :, :, :next_frames.shape[4]]) + self.beta * decouple_loss 137 | next_frames = next_frames[:, :, :, :, :self.patch_ch] 138 | return next_frames, loss 139 | -------------------------------------------------------------------------------- /core/models/model_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.optim import Adam 4 | from core.models import predrnn, predrnn_v2, action_cond_predrnn, action_cond_predrnn_v2 5 | 6 | class Model(object): 7 | def __init__(self, configs): 8 | self.configs = configs 9 | self.num_hidden = [int(x) for x in configs.num_hidden.split(',')] 10 | self.num_layers = len(self.num_hidden) 11 | networks_map = { 12 | 'predrnn': predrnn.RNN, 13 | 'predrnn_v2': predrnn_v2.RNN, 14 | 'action_cond_predrnn': action_cond_predrnn.RNN, 15 | 'action_cond_predrnn_v2': action_cond_predrnn_v2.RNN, 16 | } 17 | 18 | if configs.model_name in networks_map: 19 | Network = networks_map[configs.model_name] 20 | self.network = Network(self.num_layers, self.num_hidden, configs).to(configs.device) 21 | else: 22 | raise ValueError('Name of network unknown %s' % configs.model_name) 23 | 24 | self.optimizer = Adam(self.network.parameters(), lr=configs.lr) 25 | 26 | def save(self, itr): 27 | stats = {} 28 | stats['net_param'] = self.network.state_dict() 29 | checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt'+'-'+str(itr)) 30 | torch.save(stats, checkpoint_path) 31 | print("save model to %s" % checkpoint_path) 32 | 33 | def load(self, checkpoint_path): 34 | print('load model:', checkpoint_path) 35 | stats = torch.load(checkpoint_path) 36 | self.network.load_state_dict(stats['net_param']) 37 | 38 | def train(self, frames, mask): 39 | frames_tensor = torch.FloatTensor(frames).to(self.configs.device) 40 | mask_tensor = torch.FloatTensor(mask).to(self.configs.device) 41 | self.optimizer.zero_grad() 42 | next_frames, loss = self.network(frames_tensor, mask_tensor) 43 | loss.backward() 44 | self.optimizer.step() 45 | return loss.detach().cpu().numpy() 46 | 47 | def test(self, frames, mask): 48 | frames_tensor = torch.FloatTensor(frames).to(self.configs.device) 49 | mask_tensor = torch.FloatTensor(mask).to(self.configs.device) 50 | next_frames, _ = self.network(frames_tensor, mask_tensor) 51 | return next_frames.detach().cpu().numpy() -------------------------------------------------------------------------------- /core/models/predrnn.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | from core.layers.SpatioTemporalLSTMCell import SpatioTemporalLSTMCell 6 | 7 | 8 | class RNN(nn.Module): 9 | def __init__(self, num_layers, num_hidden, configs): 10 | super(RNN, self).__init__() 11 | 12 | self.configs = configs 13 | self.frame_channel = configs.patch_size * configs.patch_size * configs.img_channel 14 | self.num_layers = num_layers 15 | self.num_hidden = num_hidden 16 | cell_list = [] 17 | 18 | width = configs.img_width // configs.patch_size 19 | self.MSE_criterion = nn.MSELoss() 20 | 21 | for i in range(num_layers): 22 | in_channel = self.frame_channel if i == 0 else num_hidden[i - 1] 23 | cell_list.append( 24 | SpatioTemporalLSTMCell(in_channel, num_hidden[i], width, configs.filter_size, 25 | configs.stride, configs.layer_norm) 26 | ) 27 | self.cell_list = nn.ModuleList(cell_list) 28 | self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], 
self.frame_channel, 29 | kernel_size=1, stride=1, padding=0, bias=False) 30 | 31 | def forward(self, frames_tensor, mask_true): 32 | # [batch, length, height, width, channel] -> [batch, length, channel, height, width] 33 | frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous() 34 | mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous() 35 | 36 | batch = frames.shape[0] 37 | height = frames.shape[3] 38 | width = frames.shape[4] 39 | 40 | next_frames = [] 41 | h_t = [] 42 | c_t = [] 43 | 44 | for i in range(self.num_layers): 45 | zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device) 46 | h_t.append(zeros) 47 | c_t.append(zeros) 48 | 49 | memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device) 50 | 51 | for t in range(self.configs.total_length - 1): 52 | # reverse schedule sampling 53 | if self.configs.reverse_scheduled_sampling == 1: 54 | if t == 0: 55 | net = frames[:, t] 56 | else: 57 | net = mask_true[:, t - 1] * frames[:, t] + (1 - mask_true[:, t - 1]) * x_gen 58 | else: 59 | if t < self.configs.input_length: 60 | net = frames[:, t] 61 | else: 62 | net = mask_true[:, t - self.configs.input_length] * frames[:, t] + \ 63 | (1 - mask_true[:, t - self.configs.input_length]) * x_gen 64 | 65 | h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory) 66 | 67 | for i in range(1, self.num_layers): 68 | h_t[i], c_t[i], memory = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory) 69 | 70 | x_gen = self.conv_last(h_t[self.num_layers - 1]) 71 | next_frames.append(x_gen) 72 | 73 | # [length, batch, channel, height, width] -> [batch, length, height, width, channel] 74 | next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous() 75 | loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:]) 76 | return next_frames, loss 77 | -------------------------------------------------------------------------------- /core/models/predrnn_v2.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import torch 4 | import torch.nn as nn 5 | from core.layers.SpatioTemporalLSTMCell_v2 import SpatioTemporalLSTMCell 6 | import torch.nn.functional as F 7 | from core.utils.tsne import visualization 8 | 9 | 10 | class RNN(nn.Module): 11 | def __init__(self, num_layers, num_hidden, configs): 12 | super(RNN, self).__init__() 13 | 14 | self.configs = configs 15 | self.visual = self.configs.visual 16 | self.visual_path = self.configs.visual_path 17 | 18 | self.frame_channel = configs.patch_size * configs.patch_size * configs.img_channel 19 | self.num_layers = num_layers 20 | self.num_hidden = num_hidden 21 | cell_list = [] 22 | 23 | width = configs.img_width // configs.patch_size 24 | self.MSE_criterion = nn.MSELoss() 25 | 26 | for i in range(num_layers): 27 | in_channel = self.frame_channel if i == 0 else num_hidden[i - 1] 28 | cell_list.append( 29 | SpatioTemporalLSTMCell(in_channel, num_hidden[i], width, configs.filter_size, 30 | configs.stride, configs.layer_norm) 31 | ) 32 | self.cell_list = nn.ModuleList(cell_list) 33 | self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.frame_channel, kernel_size=1, stride=1, padding=0, 34 | bias=False) 35 | # shared adapter 36 | adapter_num_hidden = num_hidden[0] 37 | self.adapter = nn.Conv2d(adapter_num_hidden, adapter_num_hidden, 1, stride=1, padding=0, bias=False) 38 | 39 | def forward(self, frames_tensor, mask_true): 40 | # [batch, length, height, width, channel] -> [batch, length, channel, height, 
width] 41 | frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous() 42 | mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous() 43 | 44 | batch = frames.shape[0] 45 | height = frames.shape[3] 46 | width = frames.shape[4] 47 | 48 | next_frames = [] 49 | h_t = [] 50 | c_t = [] 51 | delta_c_list = [] 52 | delta_m_list = [] 53 | if self.visual: 54 | delta_c_visual = [] 55 | delta_m_visual = [] 56 | 57 | decouple_loss = [] 58 | 59 | for i in range(self.num_layers): 60 | zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device) 61 | h_t.append(zeros) 62 | c_t.append(zeros) 63 | delta_c_list.append(zeros) 64 | delta_m_list.append(zeros) 65 | 66 | memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device) 67 | 68 | for t in range(self.configs.total_length - 1): 69 | 70 | if self.configs.reverse_scheduled_sampling == 1: 71 | # reverse schedule sampling 72 | if t == 0: 73 | net = frames[:, t] 74 | else: 75 | net = mask_true[:, t - 1] * frames[:, t] + (1 - mask_true[:, t - 1]) * x_gen 76 | else: 77 | # schedule sampling 78 | if t < self.configs.input_length: 79 | net = frames[:, t] 80 | else: 81 | net = mask_true[:, t - self.configs.input_length] * frames[:, t] + \ 82 | (1 - mask_true[:, t - self.configs.input_length]) * x_gen 83 | 84 | h_t[0], c_t[0], memory, delta_c, delta_m = self.cell_list[0](net, h_t[0], c_t[0], memory) 85 | delta_c_list[0] = F.normalize(self.adapter(delta_c).view(delta_c.shape[0], delta_c.shape[1], -1), dim=2) 86 | delta_m_list[0] = F.normalize(self.adapter(delta_m).view(delta_m.shape[0], delta_m.shape[1], -1), dim=2) 87 | if self.visual: 88 | delta_c_visual.append(delta_c.view(delta_c.shape[0], delta_c.shape[1], -1)) 89 | delta_m_visual.append(delta_m.view(delta_m.shape[0], delta_m.shape[1], -1)) 90 | 91 | for i in range(1, self.num_layers): 92 | h_t[i], c_t[i], memory, delta_c, delta_m = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory) 93 | delta_c_list[i] = F.normalize(self.adapter(delta_c).view(delta_c.shape[0], delta_c.shape[1], -1), dim=2) 94 | delta_m_list[i] = F.normalize(self.adapter(delta_m).view(delta_m.shape[0], delta_m.shape[1], -1), dim=2) 95 | if self.visual: 96 | delta_c_visual.append(delta_c.view(delta_c.shape[0], delta_c.shape[1], -1)) 97 | delta_m_visual.append(delta_m.view(delta_m.shape[0], delta_m.shape[1], -1)) 98 | 99 | x_gen = self.conv_last(h_t[self.num_layers - 1]) 100 | next_frames.append(x_gen) 101 | # decoupling loss 102 | for i in range(0, self.num_layers): 103 | decouple_loss.append( 104 | torch.mean(torch.abs(torch.cosine_similarity(delta_c_list[i], delta_m_list[i], dim=2)))) 105 | 106 | if self.visual: 107 | # visualization of delta_c and delta_m 108 | delta_c_visual = torch.stack(delta_c_visual, dim=0) 109 | delta_m_visual = torch.stack(delta_m_visual, dim=0) 110 | visualization(self.configs.total_length, self.num_layers, delta_c_visual, delta_m_visual, self.visual_path) 111 | self.visual = 0 112 | 113 | decouple_loss = torch.mean(torch.stack(decouple_loss, dim=0)) 114 | # [length, batch, channel, height, width] -> [batch, length, height, width, channel] 115 | next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous() 116 | loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:]) + self.configs.decouple_beta * decouple_loss 117 | return next_frames, loss 118 | -------------------------------------------------------------------------------- /core/trainer.py: -------------------------------------------------------------------------------- 1 | import 
os.path 2 | import datetime 3 | import cv2 4 | import numpy as np 5 | from skimage.metrics import structural_similarity as compare_ssim 6 | from core.utils import preprocess, metrics 7 | import lpips 8 | import torch 9 | 10 | loss_fn_alex = lpips.LPIPS(net='alex') 11 | 12 | 13 | def train(model, ims, real_input_flag, configs, itr): 14 | cost = model.train(ims, real_input_flag) 15 | if configs.reverse_input: 16 | ims_rev = np.flip(ims, axis=1).copy() 17 | cost += model.train(ims_rev, real_input_flag) 18 | cost = cost / 2 19 | 20 | if itr % configs.display_interval == 0: 21 | print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr)) 22 | print('training loss: ' + str(cost)) 23 | 24 | 25 | def test(model, test_input_handle, configs, itr): 26 | print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...') 27 | test_input_handle.begin(do_shuffle=False) 28 | res_path = os.path.join(configs.gen_frm_dir, str(itr)) 29 | os.mkdir(res_path) 30 | avg_mse = 0 31 | batch_id = 0 32 | img_mse, ssim, psnr = [], [], [] 33 | lp = [] 34 | 35 | for i in range(configs.total_length - configs.input_length): 36 | img_mse.append(0) 37 | ssim.append(0) 38 | psnr.append(0) 39 | lp.append(0) 40 | 41 | # reverse schedule sampling 42 | if configs.reverse_scheduled_sampling == 1: 43 | mask_input = 1 44 | else: 45 | mask_input = configs.input_length 46 | 47 | real_input_flag = np.zeros( 48 | (configs.batch_size, 49 | configs.total_length - mask_input - 1, 50 | configs.img_width // configs.patch_size, 51 | configs.img_width // configs.patch_size, 52 | configs.patch_size ** 2 * configs.img_channel)) 53 | 54 | if configs.reverse_scheduled_sampling == 1: 55 | real_input_flag[:, :configs.input_length - 1, :, :] = 1.0 56 | 57 | while (test_input_handle.no_batch_left() == False): 58 | batch_id = batch_id + 1 59 | test_ims = test_input_handle.get_batch() 60 | test_dat = preprocess.reshape_patch(test_ims, configs.patch_size) 61 | test_ims = test_ims[:, :, :, :, :configs.img_channel] 62 | img_gen = model.test(test_dat, real_input_flag) 63 | 64 | img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size) 65 | output_length = configs.total_length - configs.input_length 66 | img_out = img_gen[:, -output_length:] 67 | 68 | # MSE per frame 69 | for i in range(output_length): 70 | x = test_ims[:, i + configs.input_length, :, :, :] 71 | gx = img_out[:, i, :, :, :] 72 | gx = np.maximum(gx, 0) 73 | gx = np.minimum(gx, 1) 74 | mse = np.square(x - gx).sum() 75 | img_mse[i] += mse 76 | avg_mse += mse 77 | # cal lpips 78 | img_x = np.zeros([configs.batch_size, 3, configs.img_width, configs.img_width]) 79 | if configs.img_channel == 3: 80 | img_x[:, 0, :, :] = x[:, :, :, 0] 81 | img_x[:, 1, :, :] = x[:, :, :, 1] 82 | img_x[:, 2, :, :] = x[:, :, :, 2] 83 | else: 84 | img_x[:, 0, :, :] = x[:, :, :, 0] 85 | img_x[:, 1, :, :] = x[:, :, :, 0] 86 | img_x[:, 2, :, :] = x[:, :, :, 0] 87 | img_x = torch.FloatTensor(img_x) 88 | img_gx = np.zeros([configs.batch_size, 3, configs.img_width, configs.img_width]) 89 | if configs.img_channel == 3: 90 | img_gx[:, 0, :, :] = gx[:, :, :, 0] 91 | img_gx[:, 1, :, :] = gx[:, :, :, 1] 92 | img_gx[:, 2, :, :] = gx[:, :, :, 2] 93 | else: 94 | img_gx[:, 0, :, :] = gx[:, :, :, 0] 95 | img_gx[:, 1, :, :] = gx[:, :, :, 0] 96 | img_gx[:, 2, :, :] = gx[:, :, :, 0] 97 | img_gx = torch.FloatTensor(img_gx) 98 | lp_loss = loss_fn_alex(img_x, img_gx) 99 | lp[i] += torch.mean(lp_loss).item() 100 | 101 | real_frm = np.uint8(x * 255) 102 | pred_frm = np.uint8(gx * 255) 103 | 104 | psnr[i] += 
metrics.batch_psnr(pred_frm, real_frm) 105 | for b in range(configs.batch_size): 106 | score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True, multichannel=True) 107 | ssim[i] += score 108 | 109 | # save prediction examples 110 | if batch_id <= configs.num_save_samples: 111 | path = os.path.join(res_path, str(batch_id)) 112 | os.mkdir(path) 113 | for i in range(configs.total_length): 114 | name = 'gt' + str(i + 1) + '.png' 115 | file_name = os.path.join(path, name) 116 | img_gt = np.uint8(test_ims[0, i, :, :, :] * 255) 117 | cv2.imwrite(file_name, img_gt) 118 | for i in range(output_length): 119 | name = 'pd' + str(i + 1 + configs.input_length) + '.png' 120 | file_name = os.path.join(path, name) 121 | img_pd = img_out[0, i, :, :, :] 122 | img_pd = np.maximum(img_pd, 0) 123 | img_pd = np.minimum(img_pd, 1) 124 | img_pd = np.uint8(img_pd * 255) 125 | cv2.imwrite(file_name, img_pd) 126 | test_input_handle.next() 127 | 128 | avg_mse = avg_mse / (batch_id * configs.batch_size) 129 | print('mse per seq: ' + str(avg_mse)) 130 | for i in range(configs.total_length - configs.input_length): 131 | print(img_mse[i] / (batch_id * configs.batch_size)) 132 | 133 | ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id) 134 | print('ssim per frame: ' + str(np.mean(ssim))) 135 | for i in range(configs.total_length - configs.input_length): 136 | print(ssim[i]) 137 | 138 | psnr = np.asarray(psnr, dtype=np.float32) / batch_id 139 | print('psnr per frame: ' + str(np.mean(psnr))) 140 | for i in range(configs.total_length - configs.input_length): 141 | print(psnr[i]) 142 | 143 | lp = np.asarray(lp, dtype=np.float32) / batch_id 144 | print('lpips per frame: ' + str(np.mean(lp))) 145 | for i in range(configs.total_length - configs.input_length): 146 | print(lp[i]) 147 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/metrics.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import numpy as np 4 | 5 | def batch_psnr(gen_frames, gt_frames): 6 | if gen_frames.ndim == 3: 7 | axis = (1, 2) 8 | elif gen_frames.ndim == 4: 9 | axis = (1, 2, 3) 10 | x = np.int32(gen_frames) 11 | y = np.int32(gt_frames) 12 | num_pixels = float(np.size(gen_frames[0])) 13 | mse = np.sum((x - y) ** 2, axis=axis, dtype=np.float32) / num_pixels 14 | psnr = 20 * np.log10(255) - 10 * np.log10(mse) 15 | return np.mean(psnr) -------------------------------------------------------------------------------- /core/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import numpy as np 4 | 5 | def reshape_patch(img_tensor, patch_size): 6 | assert 5 == img_tensor.ndim 7 | batch_size = np.shape(img_tensor)[0] 8 | seq_length = np.shape(img_tensor)[1] 9 | img_height = np.shape(img_tensor)[2] 10 | img_width = np.shape(img_tensor)[3] 11 | num_channels = np.shape(img_tensor)[4] 12 | a = np.reshape(img_tensor, [batch_size, seq_length, 13 | img_height//patch_size, patch_size, 14 | img_width//patch_size, patch_size, 15 | num_channels]) 16 | b = np.transpose(a, [0,1,2,4,3,5,6]) 17 | patch_tensor = np.reshape(b, [batch_size, seq_length, 18 | 
img_height//patch_size, 19 | img_width//patch_size, 20 | patch_size*patch_size*num_channels]) 21 | return patch_tensor 22 | 23 | def reshape_patch_back(patch_tensor, patch_size): 24 | assert 5 == patch_tensor.ndim 25 | batch_size = np.shape(patch_tensor)[0] 26 | seq_length = np.shape(patch_tensor)[1] 27 | patch_height = np.shape(patch_tensor)[2] 28 | patch_width = np.shape(patch_tensor)[3] 29 | channels = np.shape(patch_tensor)[4] 30 | img_channels = channels // (patch_size*patch_size) 31 | a = np.reshape(patch_tensor, [batch_size, seq_length, 32 | patch_height, patch_width, 33 | patch_size, patch_size, 34 | img_channels]) 35 | b = np.transpose(a, [0,1,2,4,3,5,6]) 36 | img_tensor = np.reshape(b, [batch_size, seq_length, 37 | patch_height * patch_size, 38 | patch_width * patch_size, 39 | img_channels]) 40 | return img_tensor 41 | 42 | -------------------------------------------------------------------------------- /core/utils/tsne.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from sklearn.manifold import TSNE 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torch 6 | import os 7 | import shutil 8 | 9 | 10 | def scatter(x, colors, file_name, class_num): 11 | f = plt.figure(figsize=(226 / 15, 212 / 15)) 12 | ax = plt.subplot(aspect='equal') 13 | color_pen = ['black', 'r'] 14 | # relabel 15 | my_legend = ['Delta_C', 'Delta_M'] 16 | label_set = [] 17 | label_set.append(colors[0]) 18 | for i in range(1, len(colors)): 19 | flag = 1 20 | for j in range(len(label_set)): 21 | if label_set[j] == colors[i]: 22 | flag = 0 23 | break 24 | if flag: 25 | label_set.append(colors[i]) 26 | # draw 27 | for i in range(class_num): 28 | ax.scatter(x[colors == label_set[i], 0], x[colors == label_set[i], 1], lw=0, s=70, c=color_pen[i], 29 | label=str(my_legend[i])) 30 | ax.set_xticks([]) 31 | ax.set_yticks([]) 32 | ax.axis('tight') 33 | ax.legend(loc='upper right') 34 | f.savefig(file_name + ".png", bbox_inches='tight') 35 | print(file_name + ' save finished') 36 | 37 | 38 | def plot_TSNE(data, label, path, title, class_num): 39 | colors = label 40 | all_features_np = data 41 | tsne_features = TSNE(random_state=20190129).fit_transform(all_features_np) 42 | scatter(tsne_features, colors, os.path.join(path, title), class_num) 43 | 44 | 45 | def visualization(length, layers, c, m, path, elements=10): 46 | ''' 47 | visualization of memory cells decoupling 48 | :param length: sequence length 49 | :param layers: stacked predictive layers 50 | :param c: variables 51 | :param m: variables 52 | :param path: save path 53 | :param elements: select top k element to visualization 54 | :return: 55 | ''' 56 | if os.path.exists(path): 57 | shutil.rmtree(path) 58 | os.makedirs(path) 59 | for t in range(length - 1): 60 | for i in range(layers): 61 | data = [] 62 | label = [] 63 | for j in range(c[layers * t + i].shape[0]): 64 | for k in range(c[layers * t + i].shape[1]): 65 | # choose the most dominated variables to the similarity 66 | value1, index1 = torch.topk(c[layers * t + i, j, k], elements) 67 | value2, index2 = torch.topk(m[layers * t + i, j, k], elements) 68 | # c [c_topk, elements in m_topk pos] 69 | c_key = F.normalize(torch.cat([value1, c[layers * t + i, j, k, index2]], dim=0), 70 | dim=0).detach().cpu().numpy().tolist() 71 | data.append(c_key) 72 | label.append(0) 73 | # m [elements in c_topk pos, m_topk] 74 | m_key = F.normalize(torch.cat([m[layers * t + i, j, k, index1], value2], dim=0), 75 | 
dim=0).detach().cpu().numpy().tolist() 76 | data.append(m_key) 77 | label.append(1) 78 | plot_TSNE(np.array(data), np.array(label), path, 'case_' + str(j) + '_tsne_' + str(i) + '_' + str(t), 2) 79 | -------------------------------------------------------------------------------- /kth_script/predrnn_kth_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | cd .. 3 | python -u run.py \ 4 | --is_training 1 \ 5 | --device cuda \ 6 | --dataset_name action \ 7 | --train_data_paths /workspace/wuhaixu/predrnn/data/kth_action \ 8 | --valid_data_paths /workspace/wuhaixu/predrnn/data/kth_action \ 9 | --save_dir checkpoints/kth_predrnn \ 10 | --gen_frm_dir results/kth_predrnn \ 11 | --model_name predrnn \ 12 | --reverse_input 1 \ 13 | --img_width 128 \ 14 | --img_channel 1 \ 15 | --input_length 10 \ 16 | --total_length 20 \ 17 | --num_hidden 128,128,128,128 \ 18 | --filter_size 5 \ 19 | --stride 1 \ 20 | --patch_size 4 \ 21 | --layer_norm 0 \ 22 | --scheduled_sampling 1 \ 23 | --sampling_stop_iter 50000 \ 24 | --sampling_start_value 1.0 \ 25 | --sampling_changing_rate 0.00002 \ 26 | --lr 0.0003 \ 27 | --batch_size 4 \ 28 | --max_iterations 80000 \ 29 | --display_interval 100 \ 30 | --test_interval 5000 \ 31 | --snapshot_interval 5000 -------------------------------------------------------------------------------- /kth_script/predrnn_v2_kth_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | cd .. 3 | python -u run.py \ 4 | --is_training 1 \ 5 | --device cuda \ 6 | --dataset_name action \ 7 | --train_data_paths /workspace/wuhaixu/predrnn/data/kth_action \ 8 | --valid_data_paths /workspace/wuhaixu/predrnn/data/kth_action \ 9 | --save_dir checkpoints/kth_predrnn_v2 \ 10 | --gen_frm_dir results/kth_predrnn_v2 \ 11 | --model_name predrnn_v2 \ 12 | --visual 0 \ 13 | --reverse_input 1 \ 14 | --img_width 128 \ 15 | --img_channel 1 \ 16 | --input_length 10 \ 17 | --total_length 20 \ 18 | --num_hidden 128,128,128,128 \ 19 | --filter_size 5 \ 20 | --stride 1 \ 21 | --patch_size 4 \ 22 | --layer_norm 0 \ 23 | --decouple_beta 0.01 \ 24 | --reverse_scheduled_sampling 1 \ 25 | --r_sampling_step_1 5000 \ 26 | --r_sampling_step_2 50000 \ 27 | --r_exp_alpha 2000 \ 28 | --lr 0.0001 \ 29 | --batch_size 4 \ 30 | --max_iterations 80000 \ 31 | --display_interval 100 \ 32 | --test_interval 5000 \ 33 | --snapshot_interval 5000 \ 34 | # --pretrained_model ./checkpoints/kth_predrnn_v2/kth_model.ckpt -------------------------------------------------------------------------------- /mnist_script/predrnn_mnist_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | cd .. 
3 | python -u run.py \ 4 | --is_training 1 \ 5 | --device cuda \ 6 | --dataset_name mnist \ 7 | --train_data_paths /workspace/wuhaixu/predrnn/data/moving-mnist-example/moving-mnist-train.npz \ 8 | --valid_data_paths /workspace/wuhaixu/predrnn/data/moving-mnist-example/moving-mnist-valid.npz \ 9 | --save_dir checkpoints/mnist_predrnn \ 10 | --gen_frm_dir results/mnist_predrnn \ 11 | --model_name predrnn \ 12 | --reverse_input 1 \ 13 | --img_width 64 \ 14 | --img_channel 1 \ 15 | --input_length 10 \ 16 | --total_length 20 \ 17 | --num_hidden 128,128,128,128 \ 18 | --filter_size 5 \ 19 | --stride 1 \ 20 | --patch_size 4 \ 21 | --layer_norm 0 \ 22 | --scheduled_sampling 1 \ 23 | --sampling_stop_iter 50000 \ 24 | --sampling_start_value 1.0 \ 25 | --sampling_changing_rate 0.00002 \ 26 | --lr 0.0003 \ 27 | --batch_size 8 \ 28 | --max_iterations 80000 \ 29 | --display_interval 100 \ 30 | --test_interval 5000 \ 31 | --snapshot_interval 5000 -------------------------------------------------------------------------------- /mnist_script/predrnn_v2_mnist_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | cd .. 3 | python -u run.py \ 4 | --is_training 1 \ 5 | --device cuda \ 6 | --dataset_name mnist \ 7 | --train_data_paths /workspace/wuhaixu/predrnn/data/moving-mnist-example/moving-mnist-train.npz \ 8 | --valid_data_paths /workspace/wuhaixu/predrnn/data/moving-mnist-example/moving-mnist-valid.npz \ 9 | --save_dir checkpoints/mnist_predrnn_v2 \ 10 | --gen_frm_dir results/mnist_predrnn_v2 \ 11 | --model_name predrnn_v2 \ 12 | --reverse_input 1 \ 13 | --img_width 64 \ 14 | --img_channel 1 \ 15 | --input_length 10 \ 16 | --total_length 20 \ 17 | --num_hidden 128,128,128,128 \ 18 | --filter_size 5 \ 19 | --stride 1 \ 20 | --patch_size 4 \ 21 | --layer_norm 0 \ 22 | --decouple_beta 0.1 \ 23 | --reverse_scheduled_sampling 1 \ 24 | --r_sampling_step_1 25000 \ 25 | --r_sampling_step_2 50000 \ 26 | --r_exp_alpha 2500 \ 27 | --lr 0.0001 \ 28 | --batch_size 8 \ 29 | --max_iterations 80000 \ 30 | --display_interval 100 \ 31 | --test_interval 5000 \ 32 | --snapshot_interval 5000 \ 33 | # --pretrained_model ./checkpoints/mnist_predrnn_v2/mnist_model.ckpt -------------------------------------------------------------------------------- /pic/BAIR_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/BAIR_results.png -------------------------------------------------------------------------------- /pic/Traffic4Cast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/Traffic4Cast.png -------------------------------------------------------------------------------- /pic/action_based.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/action_based.png -------------------------------------------------------------------------------- /pic/bair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/bair.png -------------------------------------------------------------------------------- /pic/decouple.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/decouple.png -------------------------------------------------------------------------------- /pic/kth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/kth.png -------------------------------------------------------------------------------- /pic/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/mnist.png -------------------------------------------------------------------------------- /pic/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/network.png -------------------------------------------------------------------------------- /pic/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/radar.png -------------------------------------------------------------------------------- /pic/response.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/response.png -------------------------------------------------------------------------------- /pic/rss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/predrnn-pytorch/49d3b0f3a2757ad9c299716ec455a96868455d67/pic/rss.png -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yunbo' 2 | 3 | import os 4 | import shutil 5 | import argparse 6 | import numpy as np 7 | import math 8 | from core.data_provider import datasets_factory 9 | from core.models.model_factory import Model 10 | from core.utils import preprocess 11 | import core.trainer as trainer 12 | 13 | # ----------------------------------------------------------------------------- 14 | parser = argparse.ArgumentParser(description='PyTorch video prediction model - PredRNN') 15 | 16 | # training/test 17 | parser.add_argument('--is_training', type=int, default=1) 18 | parser.add_argument('--device', type=str, default='cpu:0') 19 | 20 | # data 21 | parser.add_argument('--dataset_name', type=str, default='mnist') 22 | parser.add_argument('--train_data_paths', type=str, default='data/moving-mnist-example/moving-mnist-train.npz') 23 | parser.add_argument('--valid_data_paths', type=str, default='data/moving-mnist-example/moving-mnist-valid.npz') 24 | parser.add_argument('--save_dir', type=str, default='checkpoints/mnist_predrnn') 25 | parser.add_argument('--gen_frm_dir', type=str, default='results/mnist_predrnn') 26 | parser.add_argument('--input_length', type=int, default=10) 27 | parser.add_argument('--total_length', type=int, default=20) 28 | parser.add_argument('--img_width', type=int, default=64) 29 | parser.add_argument('--img_channel', type=int, default=1) 30 | 31 | # model 32 | parser.add_argument('--model_name', type=str, 
default='predrnn') 33 | parser.add_argument('--pretrained_model', type=str, default='') 34 | parser.add_argument('--num_hidden', type=str, default='64,64,64,64') 35 | parser.add_argument('--filter_size', type=int, default=5) 36 | parser.add_argument('--stride', type=int, default=1) 37 | parser.add_argument('--patch_size', type=int, default=4) 38 | parser.add_argument('--layer_norm', type=int, default=1) 39 | parser.add_argument('--decouple_beta', type=float, default=0.1) 40 | 41 | # reverse scheduled sampling 42 | parser.add_argument('--reverse_scheduled_sampling', type=int, default=0) 43 | parser.add_argument('--r_sampling_step_1', type=float, default=25000) 44 | parser.add_argument('--r_sampling_step_2', type=int, default=50000) 45 | parser.add_argument('--r_exp_alpha', type=int, default=5000) 46 | # scheduled sampling 47 | parser.add_argument('--scheduled_sampling', type=int, default=1) 48 | parser.add_argument('--sampling_stop_iter', type=int, default=50000) 49 | parser.add_argument('--sampling_start_value', type=float, default=1.0) 50 | parser.add_argument('--sampling_changing_rate', type=float, default=0.00002) 51 | 52 | # optimization 53 | parser.add_argument('--lr', type=float, default=0.001) 54 | parser.add_argument('--reverse_input', type=int, default=1) 55 | parser.add_argument('--batch_size', type=int, default=8) 56 | parser.add_argument('--max_iterations', type=int, default=80000) 57 | parser.add_argument('--display_interval', type=int, default=100) 58 | parser.add_argument('--test_interval', type=int, default=5000) 59 | parser.add_argument('--snapshot_interval', type=int, default=5000) 60 | parser.add_argument('--num_save_samples', type=int, default=10) 61 | parser.add_argument('--n_gpu', type=int, default=1) 62 | 63 | # visualization of memory decoupling 64 | parser.add_argument('--visual', type=int, default=0) 65 | parser.add_argument('--visual_path', type=str, default='./decoupling_visual') 66 | 67 | # action-based predrnn 68 | parser.add_argument('--injection_action', type=str, default='concat') 69 | parser.add_argument('--conv_on_input', type=int, default=0, help='conv on input') 70 | parser.add_argument('--res_on_conv', type=int, default=0, help='res on conv') 71 | parser.add_argument('--num_action_ch', type=int, default=4, help='num action ch') 72 | 73 | args = parser.parse_args() 74 | print(args) 75 | 76 | 77 | def reserve_schedule_sampling_exp(itr): 78 | if itr < args.r_sampling_step_1: 79 | r_eta = 0.5 80 | elif itr < args.r_sampling_step_2: 81 | r_eta = 1.0 - 0.5 * math.exp(-float(itr - args.r_sampling_step_1) / args.r_exp_alpha) 82 | else: 83 | r_eta = 1.0 84 | 85 | if itr < args.r_sampling_step_1: 86 | eta = 0.5 87 | elif itr < args.r_sampling_step_2: 88 | eta = 0.5 - (0.5 / (args.r_sampling_step_2 - args.r_sampling_step_1)) * (itr - args.r_sampling_step_1) 89 | else: 90 | eta = 0.0 91 | 92 | r_random_flip = np.random.random_sample( 93 | (args.batch_size, args.input_length - 1)) 94 | r_true_token = (r_random_flip < r_eta) 95 | 96 | random_flip = np.random.random_sample( 97 | (args.batch_size, args.total_length - args.input_length - 1)) 98 | true_token = (random_flip < eta) 99 | 100 | ones = np.ones((args.img_width // args.patch_size, 101 | args.img_width // args.patch_size, 102 | args.patch_size ** 2 * args.img_channel)) 103 | zeros = np.zeros((args.img_width // args.patch_size, 104 | args.img_width // args.patch_size, 105 | args.patch_size ** 2 * args.img_channel)) 106 | 107 | real_input_flag = [] 108 | for i in range(args.batch_size): 109 | for j in 
range(args.total_length - 2): 110 | if j < args.input_length - 1: 111 | if r_true_token[i, j]: 112 | real_input_flag.append(ones) 113 | else: 114 | real_input_flag.append(zeros) 115 | else: 116 | if true_token[i, j - (args.input_length - 1)]: 117 | real_input_flag.append(ones) 118 | else: 119 | real_input_flag.append(zeros) 120 | 121 | real_input_flag = np.array(real_input_flag) 122 | real_input_flag = np.reshape(real_input_flag, 123 | (args.batch_size, 124 | args.total_length - 2, 125 | args.img_width // args.patch_size, 126 | args.img_width // args.patch_size, 127 | args.patch_size ** 2 * args.img_channel)) 128 | return real_input_flag 129 | 130 | 131 | def schedule_sampling(eta, itr): 132 | zeros = np.zeros((args.batch_size, 133 | args.total_length - args.input_length - 1, 134 | args.img_width // args.patch_size, 135 | args.img_width // args.patch_size, 136 | args.patch_size ** 2 * args.img_channel)) 137 | if not args.scheduled_sampling: 138 | return 0.0, zeros 139 | 140 | if itr < args.sampling_stop_iter: 141 | eta -= args.sampling_changing_rate 142 | else: 143 | eta = 0.0 144 | random_flip = np.random.random_sample( 145 | (args.batch_size, args.total_length - args.input_length - 1)) 146 | true_token = (random_flip < eta) 147 | ones = np.ones((args.img_width // args.patch_size, 148 | args.img_width // args.patch_size, 149 | args.patch_size ** 2 * args.img_channel)) 150 | zeros = np.zeros((args.img_width // args.patch_size, 151 | args.img_width // args.patch_size, 152 | args.patch_size ** 2 * args.img_channel)) 153 | real_input_flag = [] 154 | for i in range(args.batch_size): 155 | for j in range(args.total_length - args.input_length - 1): 156 | if true_token[i, j]: 157 | real_input_flag.append(ones) 158 | else: 159 | real_input_flag.append(zeros) 160 | real_input_flag = np.array(real_input_flag) 161 | real_input_flag = np.reshape(real_input_flag, 162 | (args.batch_size, 163 | args.total_length - args.input_length - 1, 164 | args.img_width // args.patch_size, 165 | args.img_width // args.patch_size, 166 | args.patch_size ** 2 * args.img_channel)) 167 | return eta, real_input_flag 168 | 169 | 170 | def train_wrapper(model): 171 | if args.pretrained_model: 172 | model.load(args.pretrained_model) 173 | # load data 174 | train_input_handle, test_input_handle = datasets_factory.data_provider( 175 | args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width, 176 | seq_length=args.total_length, injection_action=args.injection_action, is_training=True) 177 | 178 | eta = args.sampling_start_value 179 | 180 | for itr in range(1, args.max_iterations + 1): 181 | if train_input_handle.no_batch_left(): 182 | train_input_handle.begin(do_shuffle=True) 183 | ims = train_input_handle.get_batch() 184 | ims = preprocess.reshape_patch(ims, args.patch_size) 185 | 186 | if args.reverse_scheduled_sampling == 1: 187 | real_input_flag = reserve_schedule_sampling_exp(itr) 188 | else: 189 | eta, real_input_flag = schedule_sampling(eta, itr) 190 | 191 | trainer.train(model, ims, real_input_flag, args, itr) 192 | 193 | if itr % args.snapshot_interval == 0: 194 | model.save(itr) 195 | 196 | if itr % args.test_interval == 0: 197 | trainer.test(model, test_input_handle, args, itr) 198 | 199 | train_input_handle.next() 200 | 201 | 202 | def test_wrapper(model): 203 | model.load(args.pretrained_model) 204 | test_input_handle = datasets_factory.data_provider( 205 | args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width, 206 | 
seq_length=args.total_length, injection_action=args.injection_action, is_training=False) 207 | trainer.test(model, test_input_handle, args, 'test_result') 208 | 209 | 210 | if os.path.exists(args.save_dir): 211 | shutil.rmtree(args.save_dir) 212 | os.makedirs(args.save_dir) 213 | 214 | if os.path.exists(args.gen_frm_dir): 215 | shutil.rmtree(args.gen_frm_dir) 216 | os.makedirs(args.gen_frm_dir) 217 | 218 | print('Initializing models') 219 | 220 | model = Model(args) 221 | 222 | if args.is_training: 223 | train_wrapper(model) 224 | else: 225 | test_wrapper(model) 226 | --------------------------------------------------------------------------------
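Evaluation with a pretrained checkpoint uses the same `run.py` interface as the training scripts: passing `--is_training 0` routes execution to `test_wrapper`, which loads `--pretrained_model` and writes metrics and predicted frames under `<gen_frm_dir>/test_result`. The snippet below is a minimal, hedged sketch for Moving MNIST, run from the repository root; the `.npz` paths and checkpoint filename are placeholders, and the architecture flags (`--num_hidden`, `--filter_size`, `--patch_size`, `--layer_norm`) must match those used when the checkpoint was trained, since `model.load` restores the full `state_dict`.

```
# Hypothetical evaluation run; adjust the data paths and checkpoint path to your setup.
export CUDA_VISIBLE_DEVICES=0
python -u run.py \
    --is_training 0 \
    --device cuda \
    --dataset_name mnist \
    --train_data_paths data/moving-mnist-example/moving-mnist-train.npz \
    --valid_data_paths data/moving-mnist-example/moving-mnist-valid.npz \
    --pretrained_model ./checkpoints/mnist_predrnn_v2/mnist_model.ckpt \
    --save_dir checkpoints/mnist_predrnn_v2_eval \
    --gen_frm_dir results/mnist_predrnn_v2_eval \
    --model_name predrnn_v2 \
    --reverse_scheduled_sampling 1 \
    --img_width 64 \
    --img_channel 1 \
    --input_length 10 \
    --total_length 20 \
    --num_hidden 128,128,128,128 \
    --filter_size 5 \
    --stride 1 \
    --patch_size 4 \
    --layer_norm 0 \
    --decouple_beta 0.1 \
    --batch_size 8
```

Note that `run.py` removes and recreates both `--save_dir` and `--gen_frm_dir` at startup, so point `--save_dir` at a fresh directory rather than at the folder holding the checkpoint you are about to load.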