├── docs
│   ├── figure_1.png
│   └── figure_2.png
├── requirements.txt
├── models
│   ├── __init__.py
│   ├── ResNetAutoEncoder.py
│   ├── criterion.py
│   └── submodules.py
├── utils
│   ├── misc.py
│   ├── __init__.py
│   ├── read_BAIR_tfrecords.py
│   ├── convert_tf_pretrained.py
│   ├── metrics.py
│   ├── position_encoding.py
│   ├── pre_processing.py
│   ├── fvd.py
│   ├── pytorch_i3d.py
│   └── train_summary.py
├── LICENSE
├── train_AutoEncoder_lightning.py
├── train_Predictor_lightning.py
├── configs
│   ├── config_KTH_VFI_NPVP-S.yaml
│   ├── config_KTH_Autoencoder.yaml
│   ├── config_KTH_VFI_NPVP-D.yaml
│   ├── config_KTH_VFP_NPVP-D.yaml
│   ├── config_KTH_VFP_NPVP-S.yaml
│   ├── config_KTH_Unified_NPVP-D.yaml
│   ├── config_KTH_Unified_NPVP-S.yaml
│   ├── config_BAIR_VFP_NPVP-D.yaml
│   ├── config_BAIR_VFP_NPVP-S.yaml
│   ├── config_SMMNIST_Autoencoder.yaml
│   ├── config_SMMNIST_VFI_NPVP-D.yaml
│   ├── config_SMMNIST_VFI_NPVP-S.yaml
│   ├── config_SMMNIST_VFP_NPVP-D.yaml
│   ├── config_SMMNIST_VFP_NPVP-S.yaml
│   ├── config_BAIR_Autoencoder.yaml
│   ├── config_KITTI_Autoencoder.yaml
│   ├── config_KITTI_VFP_NPVP-D.yaml
│   ├── config_KITTI_VFP_NPVP-S.yaml
│   ├── config_Cityscapes_VFP_NPVP-D.yaml
│   ├── config_Cityscapes_VFP_NPVP-S.yaml
│   └── config_Cityscapes_Autoencoder.yaml
└── README.md

/docs/figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiYe20/NPVP/HEAD/docs/figure_1.png
--------------------------------------------------------------------------------
/docs/figure_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiYe20/NPVP/HEAD/docs/figure_2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.3.2
2 | hydra-core==1.2.0
3 | numpy
4 | opencv-python-headless
5 | Pillow
6 | 
7 | pytorch-lightning==1.6.3
8 | tensorflow==2.7.0
9 | timm==0.5.4
10 | torch==1.9.0
11 | torchaudio==0.9.0
12 | torchmetrics==0.8.2
13 | torchvision==0.10.0
14 | tqdm==4.64.0
15 | setuptools==58.2.0
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .criterion import GDL, MSELoss, BiPatchNCE, L1Loss, GANLoss, TemporalDiff, Div_KL, GradientPanelty
2 | from .ResNetAutoEncoder import ResnetEncoder, ResnetDecoder, LitAE
3 | from .VidHRFormer import VidHRformerDecoderNAR, VidHRFormerEncoder
4 | from .submodules import CoorGenerator, NRMLP, PosFeatFuser, FutureFrameQueryGenerator, EventEncoder
5 | from .Predictor import Predictor, Discriminator, LitPredictor
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | 
5 | from torch import Tensor
6 | from typing import Optional
7 | 
8 | def set_seed(seed):
9 |     random.seed(seed)
10 |     np.random.seed(seed)
11 |     torch.manual_seed(seed)
12 |     torch.cuda.manual_seed_all(seed)
13 | 
14 | class NestedTensor(object):
15 |     def __init__(self, tensors, mask: Optional[Tensor]):
16 |         self.tensors = tensors
17 |         self.mask = mask
18 | 
19 |     def to(self, device):
20 |         # type: (Device) -> NestedTensor # noqa
21 |         cast_tensor = self.tensors.to(device)
22 |         mask = self.mask
23 |         if mask is not None:
24 |             assert mask is not None
25 |             cast_mask = mask.to(device)
26 |         else:
27 |             cast_mask = None
28 |         return NestedTensor(cast_tensor, cast_mask)
29 | 
30 |     def decompose(self):
31 |         return self.tensors, self.mask
32 | 
33 |     def __repr__(self):
34 |         return str(self.tensors)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import KTHDataset, VidCenterCrop, VidPad, VidResize, BAIRDataset, VidCrop, MovingMNISTDataset, ClipDataset, KITTIDataset
2 | from .dataset import VidRandomHorizontalFlip, VidRandomVerticalFlip, StochasticMovingMNIST, mean_std_compute
3 | from .dataset import VidToTensor, VidNormalize, VidReNormalize, get_dataloader, LitDataModule
4 | from .misc import NestedTensor, set_seed
5 | from .train_summary import save_ckpt, load_ckpt, init_loss_dict, write_summary, resume_training, write_code_files, show_AE_samples, show_predictor_samples
6 | from .train_summary import visualize_batch_clips, parameters_count, AverageMeters, init_loss_dict, write_summary, BatchAverageMeter, gather_AverageMeters
7 | from .train_summary import read_code_files, VisCallbackAE, VisCallbackPredictor, save_code_cfg
8 | from .metrics import PSNR, SSIM, pred_ave_metrics, MSEScore
9 | from .position_encoding import PositionEmbeddding2D, PositionEmbeddding1D, PositionEmbeddding3D
10 | from .fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Shawn Ye
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/train_AutoEncoder_lightning.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from pytorch_lightning.callbacks import ModelCheckpoint
3 | from pytorch_lightning import loggers as pl_loggers
4 | 
5 | from models import LitAE
6 | from utils import LitDataModule, VisCallbackAE, save_code_cfg
7 | 
8 | import hydra
9 | from hydra import compose, initialize
10 | from omegaconf import DictConfig, OmegaConf
11 | import argparse
12 | from pathlib import Path
13 | 
14 | def parse_args():
15 |     parser = argparse.ArgumentParser(description=globals()['__doc__'])
16 |     parser.add_argument('--config_path', type=str, required=True, help='Path to configuration file')
17 |     args = parser.parse_args()
18 |     return args.config_path
19 | 
20 | #@hydra.main(version_base=None, config_path=".", config_name="config")
21 | def main(cfg : DictConfig) -> None:
22 |     #save the code and config
23 |     #save_code_cfg(cfg, cfg.AE.ckpt_save_dir)
24 | 
25 |     pl.seed_everything(cfg.Env.rand_seed, workers=True)
26 |     #init model and dataloader
27 |     data_module = LitDataModule(cfg)
28 |     AE = LitAE(cfg)
29 | 
30 |     #init logger and all callbacks
31 |     checkpoint_callback = ModelCheckpoint(dirpath=cfg.AE.ckpt_save_dir, every_n_epochs = cfg.AE.log_per_epochs,
32 |                             save_top_k= 50, monitor = 'L1_loss_valid', filename= "AE-{epoch:02d}")
33 |     #callbacks = [VisCallbackAE(), checkpoint_callback]
34 |     if cfg.Env.visual_callback:
35 |         callbacks = [VisCallbackAE(), checkpoint_callback]
36 |     else:
37 |         callbacks = [checkpoint_callback]
38 |     tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.AE.tensorboard_save_dir)
39 | 
40 |     trainer = pl.Trainer(accelerator="gpu", devices=cfg.Env.world_size,
41 |                 max_epochs=cfg.AE.epochs, enable_progress_bar=True, sync_batchnorm=True,
42 |                 callbacks = callbacks, logger=tb_logger, strategy = cfg.Env.strategy)
43 |     trainer.fit(AE, data_module, ckpt_path=cfg.AE.resume_ckpt)
44 | 
45 | if __name__ == '__main__':
46 |     config_path = Path(parse_args())
47 |     initialize(version_base=None, config_path=str(config_path.parent))
48 |     cfg = compose(config_name=str(config_path.name))
49 | 
50 |     main(cfg)
--------------------------------------------------------------------------------
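Note: both Lightning entry points read every setting from a YAML file passed via --config_path. A minimal usage sketch, assuming the commands are run from the repository root; the config filenames are just examples taken from the configs/ directory listed above, and stage one produces the autoencoder checkpoint that Predictor.resume_AE_ckpt points to in stage two:

    # Stage 1: train the frame autoencoder
    python train_AutoEncoder_lightning.py --config_path configs/config_KTH_Autoencoder.yaml
    # Stage 2: train the NP-based predictor on top of the stage-one checkpoint
    python train_Predictor_lightning.py --config_path configs/config_KTH_VFP_NPVP-S.yaml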
/train_Predictor_lightning.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from pytorch_lightning.callbacks import ModelCheckpoint
3 | from pytorch_lightning.callbacks import Callback
4 | from pytorch_lightning import loggers as pl_loggers
5 | from pytorch_lightning.plugins import PrecisionPlugin
6 | 
7 | from models import LitPredictor
8 | from utils import LitDataModule, VisCallbackPredictor
9 | 
10 | import hydra
11 | from hydra import compose, initialize
12 | from omegaconf import DictConfig, OmegaConf
13 | import argparse
14 | from pathlib import Path
15 | 
16 | def parse_args():
17 |     parser = argparse.ArgumentParser(description=globals()['__doc__'])
18 |     parser.add_argument('--config_path', type=str, required=True, help='Path to configuration file')
19 |     args = parser.parse_args()
20 |     return args.config_path
21 | 
22 | #@hydra.main(version_base=None, config_path=".", config_name="config")
23 | def main(cfg : DictConfig) -> None:
24 |     #save the code and config
25 |     #save_code_cfg(cfg, cfg.Predictor.ckpt_save_dir)
26 | 
27 |     pl.seed_everything(cfg.Env.rand_seed, workers=True)
28 |     #init model and dataloader
29 |     data_module = LitDataModule(cfg)
30 |     predictor = LitPredictor(cfg)
31 | 
32 |     #init logger and all callbacks
33 |     checkpoint_callback = ModelCheckpoint(dirpath=cfg.Predictor.ckpt_save_dir, every_n_epochs = cfg.Predictor.log_per_epochs,
34 |                             save_top_k= cfg.Predictor.epochs, monitor = 'loss_val', filename= "Predictor-{epoch:02d}")
35 |     if cfg.Env.visual_callback:
36 |         callbacks = [VisCallbackPredictor(), checkpoint_callback]
37 |     else:
38 |         callbacks = [checkpoint_callback]
39 |     tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.Predictor.tensorboard_save_dir)
40 |     trainer = pl.Trainer(accelerator="gpu", devices=cfg.Env.world_size,
41 |                 max_epochs=cfg.Predictor.epochs, enable_progress_bar=True, sync_batchnorm=True,
42 |                 callbacks = callbacks, logger=tb_logger, strategy = cfg.Env.strategy)
43 |     if cfg.Predictor.init_det_ckpt_for_vae is not None and cfg.Predictor.resume_ckpt is None:
44 |         predictor = predictor.load_from_checkpoint(cfg = cfg, checkpoint_path=cfg.Predictor.init_det_ckpt_for_vae,
45 |                             strict = False)
46 | 
47 |         trainer.fit(predictor, data_module)
48 |     else:
49 |         trainer.fit(predictor, data_module, ckpt_path=cfg.Predictor.resume_ckpt)
50 | 
51 | if __name__ == '__main__':
52 |     config_path = Path(parse_args())
53 |     initialize(version_base=None, config_path=str(config_path.parent))
54 |     cfg = compose(config_name=str(config_path.name))
55 | 
56 |     main(cfg)
--------------------------------------------------------------------------------
/utils/read_BAIR_tfrecords.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from PIL import Image
4 | from pathlib import Path
5 | import shutil
6 | import os
7 | from tqdm import tqdm
8 | 
9 | #Requirements: tensorflow 2.6.0
10 | def read_BAIR_tf2_record(records_dir, save_dir):
11 |     """
12 |     Args:
13 |         records_dir: directory containing the BAIR .tfrecords files
14 |         save_dir: directory where the decoded frames are saved as PNG images,
15 |                   one sub-folder per example
16 |     """
17 |     ORIGINAL_HEIGHT = 64
18 |     ORIGINAL_WIDTH = 64
19 |     COLOR_CHAN = 3
20 | 
21 |     records_path = Path(records_dir)
22 |     tf_record_files = sorted(list(records_path.glob('*.tfrecords')))
23 |     dataset = tf.data.TFRecordDataset(tf_record_files)
24 | 
25 |     pgbar = tqdm(total = 256*len(tf_record_files), desc = 'Processing...')
26 |     for example_id, example in enumerate(dataset):
27 |         example_dir = Path(save_dir).joinpath(f'example_{example_id}')
28 |         if example_dir.exists():
29 |             shutil.rmtree(example_dir.absolute())
30 |         example_dir.mkdir(parents=True, exist_ok=True)
31 | 
32 |         for i in range(0, 30):
33 |             image_main_name = str(i) + '/image_main/encoded'
34 |             image_aux1_name = str(i) + '/image_aux1/encoded'
35 | 
36 |             features = {image_aux1_name: tf.io.FixedLenFeature([1], tf.string),
37 |                         image_main_name: tf.io.FixedLenFeature([1], tf.string)}
38 | 
39 |             features = tf.io.parse_single_example(example, features=features)
40 | 
41 |             image_aux1 = tf.io.decode_raw(features[image_aux1_name], tf.uint8)
42 |             image_aux1 = tf.reshape(image_aux1, shape=[1, ORIGINAL_HEIGHT * ORIGINAL_WIDTH * COLOR_CHAN])
43 |             image_aux1 = tf.reshape(image_aux1, shape=[ORIGINAL_HEIGHT, ORIGINAL_WIDTH, COLOR_CHAN])
44 | 
45 |             frame_name = example_dir.joinpath(f'{i:04n}.png')
46 | 
47 |             frame = Image.fromarray(image_aux1.numpy(), 'RGB')
48 |             frame.save(frame_name.absolute().as_posix())
49 |             #frame = tf.image.encode_png(image_aux1)
50 |             #with open(frame_name.absolute().as_posix(), 'wb') as f:
51 |             #    f.write(frame)
52 |         pgbar.update(1)
53 | 
54 | def resize_im(features, image_name, conf, height = None):
55 |     COLOR_CHAN = 3
56 |     if '128x128' in conf:
57 |         ORIGINAL_WIDTH = 128
58 |         ORIGINAL_HEIGHT = 128
59 |         IMG_WIDTH = 128
60 |         IMG_HEIGHT = 128
61 |     elif height is not None:
62 |         ORIGINAL_WIDTH = height
63 |         ORIGINAL_HEIGHT = height
64 |         IMG_WIDTH = height
65 |         IMG_HEIGHT = height
66 |     else:
67 |         ORIGINAL_WIDTH = 64
68 |         ORIGINAL_HEIGHT = 64
69 |         IMG_WIDTH = 64
70 |         IMG_HEIGHT = 64
71 | 
72 |     image = tf.io.decode_raw(features[image_name], tf.uint8)
73 |     image = tf.reshape(image, shape=[1, ORIGINAL_HEIGHT * ORIGINAL_WIDTH * COLOR_CHAN])
74 |     image = tf.reshape(image, shape=[ORIGINAL_HEIGHT, ORIGINAL_WIDTH, COLOR_CHAN])
75 |     if IMG_HEIGHT != IMG_WIDTH:
76 |         raise ValueError('Unequal height and width unsupported')
77 |     crop_size = min(ORIGINAL_HEIGHT, ORIGINAL_WIDTH)
78 |     image = tf.image.resize_with_crop_or_pad(image, crop_size, crop_size)
79 |     image = tf.reshape(image, [1, crop_size, crop_size, COLOR_CHAN])
80 |     image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH], method='bicubic')
81 |     image = tf.cast(image, tf.float32) / 255.0
82 | 
83 |     return image
84 | 
85 | if __name__ == '__main__':
86 |     """
87 |     read_BAIR_tf2_record('/store/travail/xiyex/BAIR/softmotion30_44k/test/traj_0_to_255.tfrecords',
88 |                          '/store/travail/xiyex/BAIR/softmotion30_44k/test')
89 |     """
90 | 
91 |     read_BAIR_tf2_record('/store/travail/xiyex/BAIR/softmotion30_44k/train', '/store/travail/xiyex/BAIR/softmotion30_44k/train')
92 | 
--------------------------------------------------------------------------------
/configs/config_KTH_VFI_NPVP-S.yaml:
--------------------------------------------------------------------------------
1 | Env:
2 |   world_size: 1
3 |   rand_seed: 3047
4 |   port: '12355'
5 |   strategy: 'ddp_find_unused_parameters_false'
6 |   visual_callback: True ###!!!!! Set this to False for Multi-GPU training, otherwise the training will get stuck
7 | 
8 | Dataset:
9 |   name: "KTH" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI'
10 |   dir: "./KTH" #Dataset Folder
11 |   dev_set_size: null #number of examples for dev set
12 |   num_workers: 16
13 |   img_channels: 1
14 |   num_past_frames: 10 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset
15 |   num_future_frames: 10
16 |   test_num_past_frames: 10
17 |   test_num_future_frames: 20
18 |   batch_size: 8
19 |   phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set
20 | 
21 | #Configuration for the autoencoder
22 | AE:
23 |   ckpt_save_dir: "./NPVP_ckpts/KTH_ResnetAE" #autoencoder checkpoint save dir
24 |   tensorboard_save_dir: "./NPVP_ckpts/KTH_ResnetAE_tensorboard"
25 |   resume_ckpt: null #null or path string for the resume checkpoint
26 |   start_epoch: 0
27 | 
28 |   epochs: 500
29 |   AE_lr: 1e-4
30 |   ngf: 64
31 |   n_downsampling: 3
32 |   num_res_blocks: 2
33 |   out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid'
34 |   learn_3d: False #if True, violates permutation invariance
35 | 
36 |   log_per_epochs: 2 #training log frequency
37 | 
38 | #Configuration for the NP-based predictor
39 | Predictor:
40 |   ckpt_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-S"
41 |   tensorboard_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-S_tensorboard"
42 |   resume_ckpt: null #null or path string for the resume checkpoint
43 |   init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of the stochastic model
44 |   resume_AE_ckpt: "./NPVP_ckpts/KTH_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder from stage one.
45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 4 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 16 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: True #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 10 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 20 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_KTH_Autoencoder.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KTH" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KTH" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 10 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 10 17 | test_num_future_frames: 20 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KTH_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KTH_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KTH_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 4 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 16 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 10 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 20 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_KTH_VFI_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KTH" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KTH" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 10 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 10 17 | test_num_future_frames: 20 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KTH_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KTH_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-D" 41 | tensorboard_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-D_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KTH_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 4 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 16 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: True #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 10 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 20 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_KTH_VFP_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KTH" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KTH" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 10 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 10 17 | test_num_future_frames: 20 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KTH_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KTH_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-D" 41 | tensorboard_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-D_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KTH_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 4 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 16 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 10 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 20 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_KTH_VFP_NPVP-S.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KTH" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KTH" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 10 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 10 17 | test_num_future_frames: 20 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KTH_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KTH_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/KTH_Predictor_VFP_NPVP-S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KTH_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 4 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 16 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 10 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 20 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_KTH_Unified_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KTH" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KTH" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 10 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 10 17 | test_num_future_frames: 20 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KTH_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KTH_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KTH_Predictor_Unified_NPVP-D" 41 | tensorboard_save_dir: "./NPVP_ckpts/KTH_Predictor_Unified_NPVP-D_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KTH_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: True #use random context for the learning (i.e.,for unified model) 51 | min_lo: 4 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 16 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 10 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 20 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_KTH_Unified_NPVP-S.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KTH" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KTH" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 10 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 10 17 | test_num_future_frames: 20 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KTH_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KTH_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KTH_Predictor_Unified_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/KTH_Predictor_Unified_NPVP-S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KTH_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: True #use random context for the learning (i.e.,for unified model) 51 | min_lo: 4 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 16 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 10 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 20 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_BAIR_VFP_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "BAIR" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../BAIR" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 2 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 2 17 | test_num_future_frames: 28 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/BAIR_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/BAIR_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/BAIR_Predictor_VFP_NPVP-D" 41 | tensorboard_save_dir: "./NPVP_ckpts/BAIR_Predictor_VFP_NPVP-D_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/BAIR_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 2 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 8 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 12 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_BAIR_VFP_NPVP-S.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "BAIR" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../BAIR" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 2 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 2 17 | test_num_future_frames: 28 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/BAIR_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/BAIR_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/BAIR_Predictor_VFP_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/BAIR_Predictor_VFP_NPVP_S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/BAIR_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 2 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 8 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 12 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_SMMNIST_Autoencoder.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "SMMNIST" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../SMMNIST" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 5 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 5 17 | test_num_future_frames: 10 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Sigmoid' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFP_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFP_NPVP-S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/SMMNIST_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 5 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 15 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_SMMNIST_VFI_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "SMMNIST" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../SMMNIST" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 5 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 5 17 | test_num_future_frames: 10 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Sigmoid' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFI_NPVP-D" 41 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFI_NPVP-D_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/SMMNIST_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 5 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: True #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 15 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
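# NOTE: the two GAN-related values below (lam_gan, ndf) are presumably only read when use_gan is True,
# so with use_gan: False they can be left as-is.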
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_SMMNIST_VFI_NPVP-S.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "SMMNIST" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../SMMNIST" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 5 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 5 17 | test_num_future_frames: 10 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Sigmoid' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFI_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFI_NPVP-S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/SMMNIST_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 5 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: True #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 15 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_SMMNIST_VFP_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "SMMNIST" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../SMMNIST" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 5 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 5 17 | test_num_future_frames: 10 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Sigmoid' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFP_NPVP-D" 41 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFP_NPVP-D_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/SMMNIST_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 5 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 15 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_SMMNIST_VFP_NPVP-S.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "SMMNIST" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../SMMNIST" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 1 14 | num_past_frames: 5 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 5 17 | test_num_future_frames: 10 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Sigmoid' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFP_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/SMMNIST_Predictor_VFP_NPVP-S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/SMMNIST_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 600 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 5 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 5 #number of past frames 56 | context_num_f: 5 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 15 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf -------------------------------------------------------------------------------- /configs/config_BAIR_Autoencoder.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "BAIR" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../BAIR" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 2 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 2 17 | test_num_future_frames: 28 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/BAIR_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/BAIR_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 64 31 | n_downsampling: 3 32 | num_res_blocks: 2 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/BAIR_Predictor_VFP_stochastic_4to5" 41 | tensorboard_save_dir: "./NPVP_ckpts/BAIR_Predictor_VFP_stochastic_4to5_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/BAIR_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 2 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 8 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 12 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_KITTI_Autoencoder.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KITTI" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KITTI_Processed" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 4 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 5 16 | test_num_past_frames: 4 17 | test_num_future_frames: 5 18 | batch_size: 16 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KITTI_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KITTI_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 32 31 | n_downsampling: 4 32 | num_res_blocks: 3 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KITTI_Predictor_VFP_stochastic_4to5" 41 | tensorboard_save_dir: "./NPVP_ckpts/KITTI_Predictor_VFP_stochastic_4to5_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KITTI_ResnetAE/AE-epoch=499.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 3 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 6 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 9 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_KITTI_VFP_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KITTI" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KITTI_Processed" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 4 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 5 16 | test_num_past_frames: 4 17 | test_num_future_frames: 5 18 | batch_size: 16 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KITTI_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KITTI_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 32 31 | n_downsampling: 4 32 | num_res_blocks: 3 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KITTI_Predictor_VFP_stochastic_4to5" 41 | tensorboard_save_dir: "./NPVP_ckpts/KITTI_Predictor_VFP_stochastic_4to5_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KITTI_ResnetAE/AE-epoch=499.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 3 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 6 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 9 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_KITTI_VFP_NPVP-S.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "KITTI" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "./KITTI_Processed" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 4 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 5 16 | test_num_past_frames: 4 17 | test_num_future_frames: 5 18 | batch_size: 16 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/KITTI_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/KITTI_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 32 31 | n_downsampling: 4 32 | num_res_blocks: 3 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/KITTI_Predictor_VFP_stochastic_4to5" 41 | tensorboard_save_dir: "./NPVP_ckpts/KITTI_Predictor_VFP_stochastic_4to5_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/KITTI_ResnetAE/AE-epoch=499.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 3 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 6 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 5 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 9 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_Cityscapes_VFP_NPVP-D.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "CityScapes" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'CityScapes', 'KITTI' 10 | dir: "../CityScapes" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 2 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 2 17 | test_num_future_frames: 28 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/Cityscapes_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/Cityscapes_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 32 31 | n_downsampling: 4 32 | num_res_blocks: 3 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/Cityscapes_Predictor_VFP_NPVP-D" 41 | tensorboard_save_dir: "./NPVP_ckpts/Cityscapes_VFP_NPVP-D_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/Cityscapes_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 2 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 8 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 12 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: False #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: False 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 
82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_Cityscapes_VFP_NPVP-S.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "CityScapes" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'CityScapes', 'KITTI' 10 | dir: "../CityScapes" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 2 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 2 17 | test_num_future_frames: 28 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/Cityscapes_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/Cityscapes_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 32 31 | n_downsampling: 4 32 | num_res_blocks: 3 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/Cityscapes_Predictor_VFP_NPVP-S" 41 | tensorboard_save_dir: "./NPVP_ckpts/Cityscapes_Predictor_VFP_NPVP-S_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/Cityscapes_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 2 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 8 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 12 #!! 
equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-8 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/config_Cityscapes_Autoencoder.yaml: -------------------------------------------------------------------------------- 1 | Env: 2 | world_size: 1 3 | rand_seed: 3047 4 | port: '12355' 5 | strategy: 'ddp_find_unused_parameters_false' 6 | visual_callback: True ###!!!!!Set this to be False for Multi-GPU training, otherwise the training would stuck 7 | 8 | Dataset: 9 | name: "CityScapes" #Name of the dataset, 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI' 10 | dir: "../CityScapes" #Dataset Folder 11 | dev_set_size: null #number of examples for dev set 12 | num_workers: 16 13 | img_channels: 3 14 | num_past_frames: 2 #For all experiments, we take a video sample with the length of (num_past_frames + num_future_frames) from dataset 15 | num_future_frames: 10 16 | test_num_past_frames: 2 17 | test_num_future_frames: 28 18 | batch_size: 8 19 | phase: 'deploy' #'debug' phase, split train/val; 'deploy' phase, no val set 20 | 21 | #Configuration for the autoencoder 22 | AE: 23 | ckpt_save_dir: "./NPVP_ckpts/Cityscapes_ResnetAE" #autoencoder checkpoinrt save dir 24 | tensorboard_save_dir: "./NPVP_ckpts/Cityscapes_ResnetAE_tensorboard" 25 | resume_ckpt: null #null or path string for the resume checkpoint 26 | start_epoch: 0 27 | 28 | epochs: 500 29 | AE_lr: 1e-4 30 | ngf: 32 31 | n_downsampling: 4 32 | num_res_blocks: 3 33 | out_layer: 'Tanh' #'Tanh' for all datasets, except SMMNIST; for SMMNIST, set to be 'Sigmoid' 34 | learn_3d: False #if True, violates permutation invariant 35 | 36 | log_per_epochs: 2 #training log frequency 37 | 38 | #Configuration for the NP-based predictor 39 | Predictor: 40 | ckpt_save_dir: "./NPVP_ckpts/Cityscapes_Predictor_VFP_stochastic_4to5" 41 | tensorboard_save_dir: "./NPVP_ckpts/Cityscapes_Predictor_VFP_stochastic_4to5_tensorboard" 42 | resume_ckpt: null #null or path string for the resume checkpoint 43 | init_det_ckpt_for_vae: null #null or path string for a trained deterministic model, which serves as the initialization of stochastic model 44 | resume_AE_ckpt: "./NPVP_ckpts/Cityscapes_ResnetAE/AE-epoch=299.ckpt" #path string for the trained autoencoder in stage one. 
45 | start_epoch: 0 46 | 47 | epochs: 500 48 | log_per_epochs: 5 #training log frequency 49 | 50 | rand_context: False #use random context for the learning (i.e.,for unified model) 51 | min_lo: 2 #Minimum length of the observed clip, not used if rand_context is False 52 | max_lo: 10 #maximum length of the observed clip, not used if rand_context is False 53 | 54 | VFI: False #video frame interpolation training mode 55 | context_num_p: 2 #number of past frames 56 | context_num_f: 2 #number of future frames 57 | num_interpolate: 8 #number of frames to interpolate, context_num_p + context_num_f + num_interpolate == cfg.Dataset.num_past_frames + cfg.Dataset.num_future_frames 58 | 59 | max_H: 8 #Height for the frame visual feature extracted by the CNN encoder 60 | max_W: 8 #Width for the frame visual feature extracted by the CNN encoder 61 | max_T: 12 #!! equals to (num_past_frames + num_future_frames) in the Dataset configuration 62 | 63 | embed_dim: 512 #Channels for the frame visual feature extracted by the CNN encoder 64 | fuse_method: 'Add' 65 | param_free_norm_type: 'layer' 66 | evt_former: True #if use VidHRFormerEncoder to learn event coding (other than mean) 67 | evt_former_num_layers: 4 #number of Transformer block for event encoding VidHRFormerEncoder, not used if evt_former is False 68 | evt_hidden_channels: 256 #number of channels for event coding 69 | stochastic: True #True for NPVP-S (stochastic), False for NPVP-D (deterministic) 70 | transformer_layers: 8 #number of Transformer block for Transformer decoder 71 | 72 | predictor_lr: 1e-4 73 | max_grad_norm: 1.0 74 | use_cosine_scheduler: True 75 | scheduler_eta_min: 1e-7 76 | scheduler_T0: 150 #Epochs for each cycle of cosine learning rate schedule 77 | 78 | lam_PF_L1: 0.01 #weight for the predicted feature l1 loss 79 | KL_beta: 1e-6 #1e-6 for SMMNIST and BAIR, 1e-8 for all other dataset 80 | 81 | use_gan: False #GAN loss is Deprecated, not used in the experiments. 82 | lam_gan: 0.001 83 | ndf: 64 #Discriminator ndf 84 | 85 | 86 | -------------------------------------------------------------------------------- /utils/convert_tf_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import tensorflow_hub as hub 4 | import torch 5 | 6 | from src_pytorch.fvd.pytorch_i3d import InceptionI3d 7 | 8 | 9 | def convert_name(name): 10 | mapping = { 11 | 'conv_3d': 'conv3d', 12 | 'batch_norm': 'bn', 13 | 'w:0': 'weight', 14 | 'b:0': 'bias', 15 | 'moving_mean:0': 'running_mean', 16 | 'moving_variance:0': 'running_var', 17 | 'beta:0': 'bias' 18 | } 19 | 20 | segs = name.split('/') 21 | new_segs = [] 22 | i = 0 23 | while i < len(segs): 24 | seg = segs[i] 25 | if 'Mixed' in seg: 26 | new_segs.append(seg) 27 | elif 'Conv' in seg and 'Mixed' not in name: 28 | new_segs.append(seg) 29 | elif 'Branch' in seg: 30 | branch_i = int(seg.split('_')[-1]) 31 | i += 1 32 | seg = segs[i] 33 | 34 | # special case due to typo in original code 35 | if 'Mixed_5b' in name and branch_i == 2: 36 | if '1x1' in seg: 37 | new_segs.append(f'b{branch_i}a') 38 | elif '3x3' in seg: 39 | new_segs.append(f'b{branch_i}b') 40 | else: 41 | raise Exception() 42 | # Either Conv3d_{i}a_... or Conv3d_{i}b_... 
43 | elif 'a' in seg: 44 | if branch_i == 0: 45 | new_segs.append('b0') 46 | else: 47 | new_segs.append(f'b{branch_i}a') 48 | elif 'b' in seg: 49 | new_segs.append(f'b{branch_i}b') 50 | else: 51 | raise Exception 52 | elif seg == 'Logits': 53 | new_segs.append('logits') 54 | i += 1 55 | elif seg in mapping: 56 | new_segs.append(mapping[seg]) 57 | else: 58 | raise Exception(f"No match found for seg {seg} in name {name}") 59 | 60 | i += 1 61 | return '.'.join(new_segs) 62 | 63 | def convert_tensor(tensor): 64 | tensor_dim = len(tensor.shape) 65 | if tensor_dim == 5: # conv or bn 66 | if all([t == 1 for t in tensor.shape[:-1]]): 67 | tensor = tensor.squeeze() 68 | else: 69 | tensor = tensor.permute(4, 3, 0, 1, 2).contiguous() 70 | elif tensor_dim == 1: # conv bias 71 | pass 72 | else: 73 | raise Exception(f"Invalid shape {tensor.shape}") 74 | return tensor 75 | 76 | n_class = int(sys.argv[1]) # 600 or 400 77 | assert n_class in [400, 600] 78 | 79 | # Converts model from https://github.com/google-research/google-research/tree/master/frechet_video_distance 80 | # to pytorch version for loading 81 | model_url = f"https://tfhub.dev/deepmind/i3d-kinetics-{n_class}/1" 82 | i3d = hub.load(model_url) 83 | name_prefix = 'RGB/inception_i3d/' 84 | 85 | print('Creating state_dict...') 86 | all_names = [] 87 | state_dict = OrderedDict() 88 | for var in i3d.variables: 89 | name = var.name[len(name_prefix):] 90 | new_name = convert_name(name) 91 | all_names.append(new_name) 92 | 93 | tensor = torch.FloatTensor(var.value().numpy()) 94 | new_tensor = convert_tensor(tensor) 95 | 96 | state_dict[new_name] = new_tensor 97 | 98 | if 'bn.bias' in new_name: 99 | new_name = new_name[:-4] + 'weight' # bn.weight 100 | new_tensor = torch.ones_like(new_tensor).float() 101 | state_dict[new_name] = new_tensor 102 | 103 | print(f'Complete state_dict with {len(state_dict)} entries') 104 | 105 | s = dict() 106 | for i, n in enumerate(all_names): 107 | s[n] = s.get(n, []) + [i] 108 | 109 | for k, v in s.items(): 110 | if len(v) > 1: 111 | print('dup', k) 112 | for i in v: 113 | print('\t', i3d.variables[i].name) 114 | 115 | print('Testing load_state_dict...') 116 | print('Creating model...') 117 | 118 | i3d = InceptionI3d(n_class, in_channels=3) 119 | 120 | print('Loading state_dict...') 121 | i3d.load_state_dict(state_dict) 122 | 123 | print(f'Saving state_dict as fvd/i3d_pretrained_{n_class}.pt') 124 | torch.save(state_dict, f'fvd/i3d_pretrained_{n_class}.pt') 125 | 126 | print('Done') 127 | 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Alt text](./docs/figure_1.png?raw=true "Overall Framework") 2 | ![Alt text](./docs/figure_2.png?raw=true "NPVP") 3 | 4 | # NPVP: A unified model for continuous conditional video prediction 5 | https://openaccess.thecvf.com/content/CVPR2023W/Precognition/html/Ye_A_Unified_Model_for_Continuous_Conditional_Video_Prediction_CVPRW_2023_paper.html 6 | 7 | ### Preparing Datasets 8 | Processed KTH dataset: https://drive.google.com/file/d/1RbJyGrYdIp4ROy8r0M-lLAbAMxTRQ-sd/view?usp=sharing \ 9 | SM-MNIST: https://drive.google.com/file/d/1eSpXRojBjvE4WoIgeplUznFyRyI3X64w/view?usp=drive_link 10 | 11 | For other datasets, please download them from the official website. Here we show the dataset folder structure. 
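Whichever dataset you prepare, make sure the `Dataset.dir` entry of the corresponding config file under `./configs` points to that folder. A minimal excerpt, taken from `configs/config_BAIR_Autoencoder.yaml` (the relative path is only the default used in the provided configs; adjust it to wherever you store the data):

```yaml
Dataset:
  name: "BAIR"   # one of 'KTH', 'SMMNIST', 'BAIR', 'Cityscapes', 'KITTI'
  dir: "../BAIR" # path to the dataset folder described below
```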
12 | 13 | #### BAIR 14 | Please download the original BAIR dataset and utilize the "/utils/read_BAIR_tfrecords.py" script to convert it into frames as follows: 15 | 16 | /BAIR \ 17 |      test/ \ 18 |          example_0/ \ 19 |             0000.png \ 20 |             0001.png \ 21 |             ... \ 22 |          example_1/ \ 23 |             0000.png \ 24 |             0001.png \ 25 |             ... \ 26 |          example_... \ 27 |      train/ \ 28 |          example_0/ \ 29 |             0000.png \ 30 |             0001.png \ 31 |             ... \ 32 |          example_... 33 | 34 | #### Cityscapes 35 | Please download "leftImg8bit_sequence_trainvaltest.zip" from the official website. Center crop and resize all the frames to the size of 128X128. Save all the frames as follows: 36 | 37 | /Cityscapes \ 38 |      test/ \ 39 |          berlin/ \ 40 |             berlin_000000_000000_leftImg8bit.png \ 41 |             berlin_000000_000001_leftImg8bit.png \ 42 |             ... \ 43 |          bielefeld/ \ 44 |             bielefeld_000000_000302_leftImg8bit.png \ 45 |             bielefeld_000000_000302_leftImg8bit.png \ 46 |             ... \ 47 |          ... \ 48 |      train/\ 49 |          aachen/ \ 50 |             .... \ 51 |          bochum/ \ 52 |             .... \ 53 |          ... \ 54 |      val/\ 55 |             .... 56 | 57 | #### KITTI 58 | Please download the raw data (synced+rectified) from KITTI official website. Center crop and resize all the frames to the resolution of 128X128. 59 | Save all the frames as follows: 60 | 61 | /KITTI \ 62 |      2011_09_26_drive_0001_sync/ \ 63 |             0000000000.png \ 64 |             0000000001.png \ 65 |             ... \ 66 |      2011_09_26_drive_0002_sync/ \ 67 |             ... \ 68 |       ... 69 | 70 | ## Training 71 | ### Stage 1: CNN autoencoder training 72 | Train the autoencoder and save the checkpoint. Configuration files for Stage 1 training are located in the "./configs" directory, with filenames ending in "*_Autoencoder.yaml". Before training, review the configuration file and adjust the dataset directory, checkpoint saving, TensorBoard log saving, etc., as needed. 73 | 74 | Usage example: 75 | ``` 76 | python train_AutoEncoder_lightning.py --config_path ./configs/config_KTH_Autoencoder.yaml 77 | ``` 78 | 79 | ### Stage 2: NP-based Predictor training 80 | 81 | With a trained Autoencoder from stage 1, we can load it for the training of the NP-based Predictor in stage 2. Configuration files for Stage 2 training are located in the "./configs" directory, with filenames ending in "*_NPVP-D.yaml" or "_NPVP-S.yaml". Prior to training, review the configuration file and adjust the dataset directory, checkpoint saving, TensorBoard log saving, etc., according to your specific requirements. 82 | 83 | Usage example: 84 | ``` 85 | python train_Predictor_lightning.py --config_path ./configs/config_KTH_Unified_NPVP-S.yaml 86 | ``` 87 | 88 | ## Inference 89 | Please read the inference.ipynb for the inference example of a KTH unified model. 90 | 91 | Step 1: Download the process KTH dataset 92 | 93 | Step 2: Download the Autoencoder checkpoint: https://drive.google.com/drive/folders/1eji1SxfT8do8TnWNPZqmhuOqxQZuaEpo?usp=sharing 94 | 95 | Step 3: Download the Unified_NPVP-S checkpoint: https://drive.google.com/drive/folders/1knqw-KuWDSx6E-tG8jiOEG1G3BYMJJIf?usp=sharing 96 | 97 | Step 4: Read and run "inference.ipynb". 98 | 99 | ### Citing 100 | 101 | Please cite the paper if you find our work is helpful. 
102 | ``` 103 | @inproceedings{ye2023unified, 104 | title={A Unified Model for Continuous Conditional Video Prediction}, 105 | author={Ye, Xi and Bilodeau, Guillaume-Alexandre}, 106 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 107 | pages={3603--3612}, 108 | year={2023} 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from typing import Union 6 | from .train_summary import load_ckpt 7 | 8 | import numpy as np 9 | from math import exp 10 | 11 | 12 | def PSNR(x: Tensor, y: Tensor, data_range: Union[float, int] = 1.0, mean_flag: bool = True) -> Tensor: 13 | """ 14 | Comput the average PSNR between two batch of images. 15 | x: input image, Tensor with shape (N, C, H, W) 16 | y: input image, Tensor with shape (N, C, H, W) 17 | data_range: the maximum pixel value range of input images, used to normalize 18 | pixel values to [0,1], default is 1.0 19 | """ 20 | 21 | EPS = 1e-8 22 | x = x/float(data_range) 23 | y = y/float(data_range) 24 | 25 | mse = torch.mean((x-y)**2, dim = (1, 2, 3)) 26 | score = -10*torch.log10(mse + EPS) 27 | if mean_flag: 28 | return torch.mean(score).item() 29 | else: 30 | return score 31 | 32 | def MSEScore(x: Tensor, y: Tensor, mean_flag: bool = True) -> Tensor: 33 | """ 34 | Comput the average PSNR between two batch of images. 35 | x: input image, Tensor with shape (N, C, H, W) 36 | y: input image, Tensor with shape (N, C, H, W) 37 | data_range: the maximum pixel value range of input images, used to normalize 38 | pixel values to [0,1], default is 1.0 39 | """ 40 | mse = torch.sum((x-y)**2, dim = (1, 2, 3)) 41 | if mean_flag: 42 | return torch.mean(mse).item() 43 | else: 44 | return mse 45 | 46 | 47 | class SSIM(torch.nn.Module): 48 | def __init__(self, window_size = 11): 49 | super(SSIM, self).__init__() 50 | self.window_size = window_size 51 | self.channel = 1 52 | self.window = self.create_window(window_size, self.channel) 53 | self.__name__ = 'SSIM' 54 | 55 | def forward(self, img1: Tensor, img2: Tensor, mean_flag: bool = True) -> float: 56 | """ 57 | img1: (N, C, H, W) 58 | img2: (N, C, H, W) 59 | Return: 60 | batch average ssim_index: float 61 | """ 62 | (_, channel, _, _) = img1.size() 63 | 64 | if channel == self.channel and self.window.data.type() == img1.data.type(): 65 | window = self.window 66 | else: 67 | window = self.create_window(self.window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | self.window = window 74 | self.channel = channel 75 | 76 | 77 | return self._ssim(img1, img2, window, self.window_size, channel, mean_flag) 78 | 79 | def gaussian(self, window_size, sigma): 80 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 81 | return gauss/gauss.sum() 82 | 83 | def create_window(self, window_size, channel): 84 | _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1) 85 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 86 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 87 | 88 | return window 89 | 90 | def _ssim(self, img1, img2, window, window_size, channel, mean_flag): 91 | mu1 = F.conv2d(img1, window, padding = 
window_size//2, groups = channel) 92 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 93 | 94 | mu1_sq = mu1.pow(2) 95 | mu2_sq = mu2.pow(2) 96 | mu1_mu2 = mu1*mu2 97 | 98 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 99 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 100 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 101 | 102 | C1 = 0.01**2 103 | C2 = 0.03**2 104 | 105 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 106 | if mean_flag: 107 | return ssim_map.mean() 108 | else: 109 | return torch.mean(ssim_map, dim=(1,2,3)) 110 | 111 | def pred_ave_metrics(model, data_loader, metric_func, renorm_transform, num_future_frames, ckpt = None, device = 'cuda:0'): 112 | if ckpt is not None: 113 | _, _, _, VPTR_state_dict, _, _ = load_ckpt(ckpt) 114 | model.load_state_dict(VPTR_state_dict) 115 | model = model.eval() 116 | ave_metric = np.zeros(num_future_frames) 117 | sample_num = 0 118 | 119 | with torch.no_grad(): 120 | for idx, sample in enumerate(data_loader, 0): 121 | past_frames, future_frames = sample 122 | past_frames = past_frames.to(device) 123 | future_frames = future_frames.to(device) 124 | mask = None 125 | pred = model(past_frames,future_frames, mask)[0] 126 | 127 | for i in range(num_future_frames): 128 | pred_t = pred[:, i, ...] 129 | future_frames_t = future_frames[:, i, ...] 130 | 131 | renorm_pred = renorm_transform(pred_t) 132 | renorm_future_frames = renorm_transform(future_frames_t) 133 | 134 | m = metric_func(renorm_pred, renorm_future_frames)*pred_t.shape[0] 135 | ave_metric[i] += m 136 | 137 | sample_num += pred.shape[0] 138 | 139 | ave_metric = ave_metric / sample_num 140 | return ave_metric 141 | 142 | if __name__ == '__main__': 143 | ssim = SSIM() 144 | 145 | random_img1 = torch.randn(4, 3, 256, 256) 146 | random_img2 = torch.randn(4, 3, 256, 256) 147 | ssim_index = ssim(random_img1, random_img2, mean_flag = False) 148 | print(ssim_index) 149 | 150 | psnr = PSNR(random_img1, random_img2, mean_flag = False) 151 | print(psnr) 152 | """ 153 | import torchvision.transforms as transforms 154 | from PIL import Image 155 | 156 | img1 = transforms.ToTensor()(Image.open('./einstein.png').convert('L')) 157 | img1 = img1.unsqueeze(0) 158 | 159 | img2 = img1.clone() 160 | ssim_index = ssim(img1, img2) 161 | print(ssim_index) 162 | 163 | ssim_index = ssim(img1, torch.randn(1, 1, 256, 256)) 164 | print(ssim_index) 165 | """ -------------------------------------------------------------------------------- /utils/position_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import numpy as np 5 | from utils.misc import NestedTensor 6 | 7 | """ 8 | 1D position encoding and 2D postion encoding 9 | The code is modified based on DETR of Facebook: 10 | https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 11 | """ 12 | 13 | class PositionEmbeddding1D(nn.Module): 14 | """ 15 | 1D position encoding 16 | Based on Attetion is all you need paper and DETR PositionEmbeddingSine class 17 | """ 18 | def __init__(self, temperature = 10000, normalize = False, scale = None): 19 | super().__init__() 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize 
should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, L: int, N: int, E: int): 30 | """ 31 | Args: 32 | L for length, N for batch size, E for embedding size (dimension of transformer). 33 | 34 | Returns: 35 | pos: position encoding, with shape [L, N, E] 36 | """ 37 | pos_embed = torch.ones(N, L, dtype = torch.float32).cumsum(axis = 1) 38 | dim_t = torch.arange(E, dtype = torch.float32) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / E) 40 | if self.normalize: 41 | eps = 1e-6 42 | pos_embed = pos_embed / (L + eps) * self.scale 43 | 44 | pos_embed = pos_embed[:, :, None] / dim_t 45 | pos_embed = torch.stack((pos_embed[:, :, 0::2].sin(), pos_embed[:, :, 1::2].cos()), dim = 3).flatten(2) 46 | pos_embed = pos_embed.permute(1, 0, 2) 47 | pos_embed.requires_grad_(False) 48 | 49 | return pos_embed 50 | 51 | class PositionEmbeddding2D(nn.Module): 52 | """ 53 | 2D position encoding, borrowed from DETR PositionEmbeddingSine class 54 | https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 55 | """ 56 | def __init__(self, temperature=10000, normalize=False, scale=None, device = torch.device('cuda:0')): 57 | super().__init__() 58 | self.temperature = temperature 59 | self.normalize = normalize 60 | self.device = device 61 | if scale is not None and normalize is False: 62 | raise ValueError("normalize should be True if scale is passed") 63 | if scale is None: 64 | scale = 2 * math.pi 65 | self.scale = scale 66 | 67 | def forward(self, N: int, E: int, H: int, W: int): 68 | """ 69 | Args: 70 | N for batch size, E for embedding size (channel of feature), H for height, W for width 71 | 72 | Returns: 73 | pos_embed: positional encoding with shape (N, E, H, W) 74 | """ 75 | assert E % 2 == 0, "Embedding size should be even number" 76 | 77 | y_embed = torch.ones(N, H, W, dtype=torch.float32, device = self.device).cumsum(dim = 1) 78 | x_embed = torch.ones(N, H, W, dtype=torch.float32, device = self.device).cumsum(dim = 2) 79 | if self.normalize: 80 | eps = 1e-6 81 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 82 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 83 | 84 | dim_t = torch.arange(E//2, dtype=torch.float32, device=self.device) 85 | dim_t = self.temperature ** (2 * (dim_t // 2) / (E//2)) 86 | 87 | pos_x = x_embed[:, :, :, None] / dim_t 88 | pos_y = y_embed[:, :, :, None] / dim_t 89 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 90 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 91 | pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 92 | pos_embed.requires_grad_(False) 93 | return pos_embed 94 | 95 | 96 | class PositionEmbeddding3D(nn.Module): 97 | """ 98 | 3D position encoding 99 | """ 100 | def __init__(self, E, T, temperature=10000, normalize=False, scale=None, device = torch.device('cuda:0')): 101 | """ 102 | E: embedding size, i.e. 
pos feature length 103 | T: video clip length 104 | """ 105 | super().__init__() 106 | self.E = E 107 | self.T = T 108 | self.temperature = temperature 109 | self.normalize = normalize 110 | self.device = device 111 | if scale is not None and normalize is False: 112 | raise ValueError("normalize should be True if scale is passed") 113 | if scale is None: 114 | scale = 2 * math.pi 115 | self.scale = scale 116 | 117 | def forward(self, tensorlist: NestedTensor): 118 | """ 119 | Args: 120 | tensorlist: NestedTensor which includes feature maps X and corresponding mask 121 | X: tensor with shape (N*T, C, H, W), N for batch size, C for channel of features, T for time, H for height, W for width 122 | mask: None or tensor with shape (N*T, H, W) 123 | Returns: 124 | pos_embed: positional encoding with shape (N, E, T, H, W) 125 | """ 126 | NT, C, H, W= tensorlist.tensors.shape 127 | N = NT//self.T 128 | mask = tensorlist.mask 129 | assert self.E % 3 == 0, "Embedding size should be divisible by 3" 130 | 131 | if mask is None: 132 | t_embed = torch.ones(N, self.T, H, W, dtype = torch.float32, device = self.device).cumsum(dim = 1) 133 | y_embed = torch.ones(N, self.T, H, W, dtype=torch.float32, device = self.device).cumsum(dim = 2) 134 | x_embed = torch.ones(N, self.T, H, W, dtype=torch.float32, device = self.device).cumsum(dim = 3) 135 | else: 136 | mask = mask.reshape(N, self.T, H, W) 137 | #binary mask, 1 for the image area, 0 for the padding area 138 | t_embed = mask.cumsum(dim = 1, dtype = torch.float32).to(self.device) 139 | y_embed = mask.cumsum(dim = 2, dtype = torch.float32).to(self.device) 140 | x_embed = mask.cumsum(dim = 3, dtype = torch.float32).to(self.device) 141 | if self.normalize: 142 | eps = 1e-6 143 | t_embed = t_embed / (t_embed[:, :-1, :, :] + eps) * self.scale 144 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale 145 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale 146 | 147 | dim_t = torch.arange(self.E//3, dtype=torch.float32, device=self.device) 148 | dim_t = self.temperature ** (2 * (dim_t // 2) / (self.E//3)) 149 | 150 | pos_t = t_embed[:, :, :, :, None] / dim_t 151 | pos_x = x_embed[:, :, :, :, None] / dim_t 152 | pos_y = y_embed[:, :, :, :, None] / dim_t 153 | 154 | pos_t = torch.stack((pos_t[:, :, :, :, 0::2].sin(), pos_t[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 155 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 156 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 157 | 158 | pos_embed = torch.cat((pos_t, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3) 159 | pos_embed.requires_grad_(False) 160 | 161 | return pos_embed -------------------------------------------------------------------------------- /utils/pre_processing.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import subprocess 3 | import shutil 4 | from PIL import Image 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | import os 8 | import cv2 9 | """ 10 | import detectron2.data.transforms as T 11 | from detectron2.data import ( 12 | MetadataCatalog, 13 | build_detection_test_loader, 14 | build_detection_train_loader, 15 | ) 16 | from detectron2.checkpoint import DetectionCheckpointer 17 | from detectron2.modeling import build_model 18 | from detectron2.config import get_cfg 19 | import torch 20 | from detectron2 import model_zoo 21 | """ 22 | 23 | from itertools import groupby 24 
| from operator import itemgetter 25 | from tqdm import tqdm 26 | import zipfile 27 | """ 28 | ffmpeg command: 29 | ffmpeg -i person01_boxing_d1_uncomp.avi -s 64x64 test%03d.png 30 | ffmpeg -i test.avi out.mp4 31 | """ 32 | 33 | def unzip(file_dir, output_dir): 34 | with zipfile.ZipFile(file_dir, 'r') as zip_ref: 35 | zip_ref.extractall(output_dir) 36 | 37 | def vid2frames(dataset_dir, root_dir, frame_size = None): 38 | videos_path = Path(dataset_dir) 39 | vid_file_list = list(videos_path.glob("*.mp4")) 40 | print(len(vid_file_list)) 41 | for vid in vid_file_list: 42 | dir_str = '' 43 | for s in vid.name.strip().split('.')[:-1]: 44 | dir_str += s 45 | dir_str = dir_str.replace(" ", "") 46 | img_dir = Path(root_dir).joinpath(Path(dir_str)) 47 | 48 | if img_dir.exists(): 49 | shutil.rmtree(img_dir.absolute()) 50 | 51 | img_dir.mkdir(parents=True, exist_ok=True) 52 | 53 | if frame_size is not None: 54 | command = ['ffmpeg', '-i', vid.absolute().as_posix(), '-f', 'image2', '-s', f'{frame_size}x{frame_size}', f'{img_dir.absolute().as_posix()}/image_%04d_{frame_size}x{frame_size}.png'] 55 | else: 56 | #command = ['ffmpeg', '-i', vid.absolute().as_posix(), '-f', 'image2', f'{img_dir.absolute().as_posix()}/image_%04d.png'] 57 | command = ['ffmpeg', '-i', vid.absolute().as_posix(), '-r', '10', '-s', f'64x64', f'{img_dir.absolute().as_posix()}/image_%04d.png'] 58 | 59 | process = subprocess.Popen(command, stdout = subprocess.PIPE) 60 | 61 | def frames2vid(frames_dir, out_file, frame_size, fps = 25): 62 | frames_path = Path(frames_dir) 63 | out_file_path = Path(out_file) 64 | 65 | command = ['ffmpeg', '-f', 'image2', '-r', f'{fps}', '-i', f'{frames_path.absolute().as_posix()}/image_%04d_{frame_size}x{frame_size}.png', 66 | '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', f'{out_file_path.absolute().as_posix()}'] 67 | """ 68 | command = ['ffmpeg', '-f', 'image2', '-r', f'{fps}', '-i', f'{frames_path.absolute().as_posix()}/image_%04d.png', 69 | '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', f'{out_file_path.absolute().as_posix()}'] 70 | """ 71 | process = subprocess.Popen(command, stdout = subprocess.PIPE) 72 | 73 | def subsample(frames_dir, factor = 5): 74 | frames_path = Path(frames_dir) 75 | frame_list = sorted(list(frames_path.glob(f'*.png'))) 76 | keep_list = frame_list[::factor] 77 | delete_list = [f for f in frame_list if f not in keep_list] 78 | for f in delete_list: 79 | f.unlink() 80 | 81 | class detectron_detector(object): 82 | """ 83 | Custom batch predicting function based on original detectron2.DefaultPredictor 84 | Ref: https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/defaults.py 85 | """ 86 | def __init__(self, cfg): 87 | self.cfg = cfg.clone() 88 | self.model = build_model(self.cfg) 89 | self.model.eval() 90 | self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0]) 91 | 92 | checkpointer = DetectionCheckpointer(self.model) 93 | checkpointer.load(cfg.MODEL.WEIGHTS) 94 | 95 | self.transform_gen = T.ResizeShortestEdge( 96 | [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST 97 | ) 98 | 99 | self.input_format = cfg.INPUT.FORMAT 100 | assert self.input_format in ["RGB", "BGR"], self.input_format 101 | 102 | def __call__(self, input_img): 103 | """ 104 | input_img --- numpy array with shape (H, W, C) 105 | """ 106 | with torch.no_grad(): 107 | # Apply pre-processing to image. 
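# As in detectron2's DefaultPredictor (on which this class is based), the channel
# order of the incoming (H, W, C) array is reversed whenever the configured input
# format is "RGB"; the ResizeShortestEdge transform is then applied and the image
# is packed into the dict format expected by the detectron2 model.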
108 | if self.input_format == "RGB": 109 | # whether the model expects BGR inputs or RGB 110 | input_img = input_img[:, :, ::-1] 111 | height, width, _ = input_img.shape 112 | 113 | image= self.transform_gen.get_transform(input_img).apply_image(input_img) 114 | image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) 115 | inputs = {"image": image, "height": height, "width": width} 116 | 117 | predictions = self.model([inputs]) 118 | 119 | return predictions[0] 120 | 121 | def human_detector(original_frames_path, save_dir, clip_length = 20): 122 | """ 123 | Remove the empty KTH frames. 124 | original_frames_path: path for the original frames 125 | save_dir: save directory for the frames with human detected 126 | """ 127 | all_vid_files = sorted(list(Path(original_frames_path).absolute().glob('*.avi'))) 128 | #each person has 4 videos 129 | cfg = get_cfg() 130 | cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")) 131 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml") 132 | cfg.INPUT.FORMAT = 'RGB' 133 | detector = detectron_detector(cfg) 134 | frames_folder = [] 135 | 136 | for f in all_vid_files: 137 | person_id = int(str(f.name).strip().split('_')[0][-2:]) 138 | frame_folder = f.parent.joinpath(f.name.strip().split('.')[0]) 139 | img_files = sorted(list(frame_folder.glob('*'))) 140 | 141 | scores = [] 142 | pgbar = tqdm(total = len(img_files), desc = f'Detecting {f}') 143 | for img_path in img_files: 144 | img = Image.open(img_path.absolute().as_posix()).convert('RGB') 145 | img = np.asarray(img) 146 | preds = detector(img) 147 | scores.append(preds["instances"].scores.cpu().numpy()) 148 | #print(preds["instances"].pred_classes, preds["instances"].scores) 149 | pgbar.update(1) 150 | 151 | #label the frame without human as -1, frame with human as frame id 152 | human_frame_ids = [] 153 | for i in range(len(scores)): 154 | s = scores[i] 155 | if len(s) > 0: 156 | if s[0] > 0.5: #person class id of COCO-dectection dataset is 0 157 | human_frame_ids.append(i) 158 | else: 159 | human_frame_ids.append(-1) 160 | else: 161 | human_frame_ids.append(-1) 162 | 163 | #get the subset of all conscutive frames with human detected 164 | consecutive_frames = [] 165 | for k, g in groupby(enumerate(human_frame_ids), lambda x: x[0] - x[1]): 166 | cf = list(map(itemgetter(1), g)) 167 | if len(cf) >= clip_length: 168 | consecutive_frames.append(cf) 169 | 170 | #copy all the frames with human detected to the new folder 171 | if not Path(save_dir).exists(): 172 | Path(save_dir).mkdir(parents=True, exist_ok=True) 173 | for idx, cf in enumerate(consecutive_frames): 174 | new_folder = Path(save_dir).joinpath(frame_folder.name + f'_no_empty_{idx}') 175 | for f_id in cf: 176 | src = img_files[f_id] 177 | if not Path(new_folder).exists(): 178 | Path(new_folder).mkdir(parents=True, exist_ok=True) 179 | shutil.copy(src.absolute().as_posix(), new_folder) 180 | 181 | def process_cityscapes(seq_dir, out_dir, img_size = (128, 128)): 182 | """ 183 | Center crop the original frames and then resize them to be img_size, save into the out_dir 184 | Args: 185 | seq_dir: the extracted sequence dir of the cityscapes, i.e., leftimg8bit_sequence 186 | out_dir: the directory to save the processed frames 187 | img_size: the desired size 188 | """ 189 | def center_crop(image): 190 | h, w, c = image.shape 191 | new_h, new_w = h if h < w else w, w if w < h else h 192 | r_min, r_max = h//2 - new_h//2, h//2 + new_h//2 193 | 
c_min, c_max = w//2 - new_w//2, w//2 + new_w//2 194 | return image[r_min:r_max, c_min:c_max, :] 195 | 196 | 197 | seq_path = Path(seq_dir).absolute() 198 | out_path = Path(out_dir).absolute() 199 | frames_folders = os.listdir(seq_path) 200 | frames_folders = [seq_path.joinpath(s) for s in frames_folders] 201 | out_frames_folders = [out_path.joinpath(s.name) for s in frames_folders] 202 | 203 | 204 | for idx, folder in enumerate(frames_folders): 205 | img_folders = [folder.joinpath(s) for s in os.listdir(folder)] 206 | out_img_folders = [out_frames_folders[idx].joinpath(Path(s).name) for s in os.listdir(folder)] 207 | for idx, img_folder in enumerate(img_folders): 208 | img_files = sorted(list(img_folder.glob('*'))) 209 | out_img_files = [out_img_folders[idx].joinpath(Path(s).name) for s in img_files] 210 | for idx, f in enumerate(img_files): 211 | frame = cv2.imread(f.absolute().as_posix()) 212 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 213 | img_cc = center_crop(img) 214 | pil_im = Image.fromarray(img_cc) 215 | pil_im_rsz = pil_im.resize(img_size, Image.LANCZOS) 216 | try: 217 | pil_im_rsz.save(out_img_files[idx]) 218 | except Exception as err: 219 | parent_dir = out_img_files[idx].parent 220 | parent_dir.mkdir(parents=True, exist_ok=True) 221 | pil_im_rsz.save(out_img_files[idx]) 222 | 223 | 224 | 225 | 226 | 227 | 228 | if __name__ == '__main__': 229 | #unzip('boxing.zip', './boxing') 230 | #vid2frames('/store/travail/xiyex/KTH/boxing', None) 231 | #frames2vid('/home/travail/xiyex/VideoFramePrediction/dataset/KTH/walking/person01_walking_d1_uncomp', './out.mp4', 64) 232 | """ 233 | consecutive_frames = human_detector('/store/travail/xiyex/KTH/jogging', 234 | '/store/travail/xiyex/KTH/jogging_no_empty') 235 | """ 236 | #vid2frames('/store/travail/xiyex/Human36Moriginal/train/S11/Videos', '/store/travail/xiyex/Human36M/train/S11') 237 | seq_dir = '/home/travail/xiyex/cityscapes/seqs/leftImg8bit_sequence' 238 | out_dir = '/home/travail/xiyex/cityscapes/processed_seq' 239 | process_cityscapes(seq_dir, out_dir, img_size = (128, 128)) -------------------------------------------------------------------------------- /utils/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import os.path as osp 5 | import math 6 | import torch.nn.functional as F 7 | 8 | # try: 9 | # from torchvision.models.utils import load_state_dict_from_url 10 | # except ImportError: 11 | # from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | 13 | # i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI" 14 | 15 | # def load_i3d_pretrained(device=torch.device('cpu')): 16 | # from .pytorch_i3d import InceptionI3d 17 | # i3d = InceptionI3d(400, in_channels=3).to(device) 18 | # try: # can't access internet from compute canada, so need a local version 19 | # filepath = 'models/i3d_pretrained_400.pt' 20 | # i3d.load_state_dict(torch.load(filepath, map_location=device)) 21 | # except: 22 | # state_dict = load_state_dict_from_url(i3D_WEIGHTS_URL, progress=True, map_location=device) 23 | # i3d.load_state_dict(state_dict) 24 | # i3d = torch.nn.DataParallel(i3d) 25 | # i3d.eval() 26 | # return i3d 27 | 28 | 29 | # https://github.com/universome/fvd-comparison 30 | i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt" 31 | 32 | def load_i3d_pretrained(device=torch.device('cpu')): 33 | filepath = 
os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt') 34 | if not os.path.exists(filepath): 35 | os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") 36 | i3d = torch.jit.load(filepath).eval().to(device) 37 | i3d = torch.nn.DataParallel(i3d) 38 | return i3d 39 | 40 | 41 | def get_feats(videos, detector, device, bs=10): 42 | # videos : torch.tensor BCTHW [0, 1] 43 | detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer. 44 | feats = np.empty((0, 400)) 45 | device = torch.device("cuda:0") if device is not torch.device("cpu") else device 46 | with torch.no_grad(): 47 | for i in range((len(videos)-1)//bs + 1): 48 | temp = torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device) 49 | feats = np.vstack([feats, detector(temp, **detector_kwargs).detach().cpu().numpy()]) 50 | return feats 51 | 52 | 53 | def get_fvd_feats(videos, i3d, device, bs=10): 54 | # videos in [0, 1] as torch tensor BCTHW 55 | # videos = [preprocess_single(video) for video in videos] 56 | embeddings = get_feats(videos, i3d, device, bs) 57 | return embeddings 58 | 59 | # """ 60 | # Copy-pasted from Copy-pasted from https://github.com/NVlabs/stylegan2-ada-pytorch 61 | # """ 62 | 63 | # import ctypes 64 | # import fnmatch 65 | # import importlib 66 | # import inspect 67 | # import numpy as np 68 | # import os 69 | # import shutil 70 | # import sys 71 | # import types 72 | # import io 73 | # import pickle 74 | # import re 75 | # import requests 76 | # import html 77 | # import hashlib 78 | # import glob 79 | # import tempfile 80 | # import urllib 81 | # import urllib.request 82 | # import uuid 83 | 84 | # from distutils.util import strtobool 85 | # from typing import Any, List, Tuple, Union, Dict 86 | 87 | # def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: 88 | # """Download the given URL and return a binary-mode file object to access the data.""" 89 | # assert num_attempts >= 1 90 | 91 | # # Doesn't look like an URL scheme so interpret it as a local filename. 92 | # if not re.match('^[a-z]+://', url): 93 | # return url if return_filename else open(url, "rb") 94 | 95 | # # Handle file URLs. This code handles unusual file:// patterns that 96 | # # arise on Windows: 97 | # # 98 | # # file:///c:/foo.txt 99 | # # 100 | # # which would translate to a local '/c:/foo.txt' filename that's 101 | # # invalid. Drop the forward slash for such pathnames. 102 | # # 103 | # # If you touch this code path, you should test it on both Linux and 104 | # # Windows. 105 | # # 106 | # # Some internet resources suggest using urllib.request.url2pathname() but 107 | # # but that converts forward slashes to backslashes and this causes 108 | # # its own set of problems. 109 | # if url.startswith('file://'): 110 | # filename = urllib.parse.urlparse(url).path 111 | # if re.match(r'^/[a-zA-Z]:', filename): 112 | # filename = filename[1:] 113 | # return filename if return_filename else open(filename, "rb") 114 | 115 | # url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 116 | 117 | # # Download. 118 | # url_name = None 119 | # url_data = None 120 | # with requests.Session() as session: 121 | # if verbose: 122 | # print("Downloading %s ..." 
% url, end="", flush=True) 123 | # for attempts_left in reversed(range(num_attempts)): 124 | # try: 125 | # with session.get(url) as res: 126 | # res.raise_for_status() 127 | # if len(res.content) == 0: 128 | # raise IOError("No data received") 129 | 130 | # if len(res.content) < 8192: 131 | # content_str = res.content.decode("utf-8") 132 | # if "download_warning" in res.headers.get("Set-Cookie", ""): 133 | # links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 134 | # if len(links) == 1: 135 | # url = requests.compat.urljoin(url, links[0]) 136 | # raise IOError("Google Drive virus checker nag") 137 | # if "Google Drive - Quota exceeded" in content_str: 138 | # raise IOError("Google Drive download quota exceeded -- please try again later") 139 | 140 | # match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 141 | # url_name = match[1] if match else url 142 | # url_data = res.content 143 | # if verbose: 144 | # print(" done") 145 | # break 146 | # except KeyboardInterrupt: 147 | # raise 148 | # except: 149 | # if not attempts_left: 150 | # if verbose: 151 | # print(" failed") 152 | # raise 153 | # if verbose: 154 | # print(".", end="", flush=True) 155 | 156 | # # Return data as file object. 157 | # assert not return_filename 158 | # return io.BytesIO(url_data) 159 | 160 | 161 | def preprocess_single(video, resolution=224, sequence_length=None): 162 | # video: CTHW, [0, 1] 163 | c, t, h, w = video.shape 164 | 165 | # temporal crop 166 | if sequence_length is not None: 167 | assert sequence_length <= t 168 | video = video[:, :sequence_length] 169 | 170 | # scale shorter side to resolution 171 | scale = resolution / min(h, w) 172 | if h < w: 173 | target_size = (resolution, math.ceil(w * scale)) 174 | else: 175 | target_size = (math.ceil(h * scale), resolution) 176 | video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False) 177 | 178 | # center crop 179 | c, t, h, w = video.shape 180 | w_start = (w - resolution) // 2 181 | h_start = (h - resolution) // 2 182 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 183 | 184 | # [0, 1] -> [-1, 1] 185 | video = (video - 0.5) * 2 186 | return video.contiguous() 187 | 188 | 189 | def get_logits(i3d, videos, device): 190 | #assert videos.shape[0] % 2 == 0 191 | logits = torch.empty(0, 400) 192 | with torch.no_grad(): 193 | for i in range(len(videos)): 194 | # logits.append(i3d(preprocess_single(videos[i]).unsqueeze(0).to(device)).detach().cpu()) 195 | logits = torch.vstack([logits, i3d(preprocess_single(videos[i]).unsqueeze(0).to(device)).detach().cpu()]) 196 | # logits = torch.cat(logits, dim=0) 197 | return logits 198 | 199 | 200 | def get_fvd_logits(videos, i3d, device): 201 | # videos in [0, 1] as torch tensor BCTHW 202 | # videos = [preprocess_single(video) for video in videos] 203 | embeddings = get_logits(i3d, videos, device) 204 | return embeddings 205 | 206 | 207 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 208 | def _symmetric_matrix_square_root(mat, eps=1e-10): 209 | u, s, v = torch.linalg.svd(mat) 210 | si = torch.where(s < eps, s, torch.sqrt(s)) 211 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 212 | 213 | 214 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 215 | def trace_sqrt_product(sigma, sigma_v): 216 | sqrt_sigma 
= _symmetric_matrix_square_root(sigma) 217 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 218 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 219 | 220 | 221 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 222 | def cov(m, rowvar=False): 223 | '''Estimate a covariance matrix given data. 224 | 225 | Covariance indicates the level to which two variables vary together. 226 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 227 | then the covariance matrix element `C_{ij}` is the covariance of 228 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 229 | 230 | Args: 231 | m: A 1-D or 2-D array containing multiple variables and observations. 232 | Each row of `m` represents a variable, and each column a single 233 | observation of all those variables. 234 | rowvar: If `rowvar` is True, then each row represents a 235 | variable, with observations in the columns. Otherwise, the 236 | relationship is transposed: each column represents a variable, 237 | while the rows contain observations. 238 | 239 | Returns: 240 | The covariance matrix of the variables. 241 | ''' 242 | if m.dim() > 2: 243 | raise ValueError('m has more than 2 dimensions') 244 | if m.dim() < 2: 245 | m = m.view(1, -1) 246 | if not rowvar and m.size(0) != 1: 247 | m = m.t() 248 | 249 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 250 | m -= torch.mean(m, dim=1, keepdim=True) 251 | mt = m.t() # if complex: mt = m.t().conj() 252 | return fact * m.matmul(mt).squeeze() 253 | 254 | 255 | # def frechet_distance(x1, x2): 256 | # x1 = x1.flatten(start_dim=1) 257 | # x2 = x2.flatten(start_dim=1) 258 | # m, m_w = x1.mean(dim=0), x2.mean(dim=0) 259 | # sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 260 | # sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 261 | # trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 262 | # mean = torch.sum((m - m_w) ** 2) 263 | # fd = trace + mean 264 | # return fd 265 | 266 | 267 | """ 268 | Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py 269 | """ 270 | from typing import Tuple 271 | from scipy.linalg import sqrtm 272 | import numpy as np 273 | 274 | 275 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 276 | mu = feats.mean(axis=0) # [d] 277 | sigma = np.cov(feats, rowvar=False) # [d, d] 278 | return mu, sigma 279 | 280 | 281 | def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: 282 | mu_gen, sigma_gen = compute_stats(feats_fake) 283 | mu_real, sigma_real = compute_stats(feats_real) 284 | m = np.square(mu_gen - mu_real).sum() 285 | s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 286 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 287 | return float(fid) 288 | -------------------------------------------------------------------------------- /utils/pytorch_i3d.py: -------------------------------------------------------------------------------- 1 | # Original code from https://github.com/piergiaj/pytorch-i3d 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | class MaxPool3dSamePadding(nn.MaxPool3d): 8 | 9 | def compute_pad(self, dim, s): 10 | if s % self.stride[dim] == 0: 11 | return max(self.kernel_size[dim] - self.stride[dim], 0) 12 | else: 13 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 14 | 15 | def forward(self, x): 
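# TensorFlow-style "SAME" pooling: for each of (t, h, w), compute_pad() picks a total
# padding so the output size equals ceil(input_size / stride); the pad is split between
# the front and back of each dimension (any odd remainder goes to the back) and applied
# with F.pad before calling the standard MaxPool3d forward.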
16 | # compute 'same' padding 17 | (batch, channel, t, h, w) = x.size() 18 | out_t = np.ceil(float(t) / float(self.stride[0])) 19 | out_h = np.ceil(float(h) / float(self.stride[1])) 20 | out_w = np.ceil(float(w) / float(self.stride[2])) 21 | pad_t = self.compute_pad(0, t) 22 | pad_h = self.compute_pad(1, h) 23 | pad_w = self.compute_pad(2, w) 24 | 25 | pad_t_f = pad_t // 2 26 | pad_t_b = pad_t - pad_t_f 27 | pad_h_f = pad_h // 2 28 | pad_h_b = pad_h - pad_h_f 29 | pad_w_f = pad_w // 2 30 | pad_w_b = pad_w - pad_w_f 31 | 32 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 33 | x = F.pad(x, pad) 34 | return super(MaxPool3dSamePadding, self).forward(x) 35 | 36 | 37 | class Unit3D(nn.Module): 38 | 39 | def __init__(self, in_channels, 40 | output_channels, 41 | kernel_shape=(1, 1, 1), 42 | stride=(1, 1, 1), 43 | padding=0, 44 | activation_fn=F.relu, 45 | use_batch_norm=True, 46 | use_bias=False, 47 | name='unit_3d'): 48 | 49 | """Initializes Unit3D module.""" 50 | super(Unit3D, self).__init__() 51 | 52 | self._output_channels = output_channels 53 | self._kernel_shape = kernel_shape 54 | self._stride = stride 55 | self._use_batch_norm = use_batch_norm 56 | self._activation_fn = activation_fn 57 | self._use_bias = use_bias 58 | self.name = name 59 | self.padding = padding 60 | 61 | self.conv3d = nn.Conv3d(in_channels=in_channels, 62 | out_channels=self._output_channels, 63 | kernel_size=self._kernel_shape, 64 | stride=self._stride, 65 | padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function 66 | bias=self._use_bias) 67 | 68 | if self._use_batch_norm: 69 | self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001) 70 | 71 | def compute_pad(self, dim, s): 72 | if s % self._stride[dim] == 0: 73 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 74 | else: 75 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 76 | 77 | 78 | def forward(self, x): 79 | # compute 'same' padding 80 | (batch, channel, t, h, w) = x.size() 81 | out_t = np.ceil(float(t) / float(self._stride[0])) 82 | out_h = np.ceil(float(h) / float(self._stride[1])) 83 | out_w = np.ceil(float(w) / float(self._stride[2])) 84 | pad_t = self.compute_pad(0, t) 85 | pad_h = self.compute_pad(1, h) 86 | pad_w = self.compute_pad(2, w) 87 | 88 | pad_t_f = pad_t // 2 89 | pad_t_b = pad_t - pad_t_f 90 | pad_h_f = pad_h // 2 91 | pad_h_b = pad_h - pad_h_f 92 | pad_w_f = pad_w // 2 93 | pad_w_b = pad_w - pad_w_f 94 | 95 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 96 | x = F.pad(x, pad) 97 | 98 | x = self.conv3d(x) 99 | if self._use_batch_norm: 100 | x = self.bn(x) 101 | if self._activation_fn is not None: 102 | x = self._activation_fn(x) 103 | return x 104 | 105 | 106 | 107 | class InceptionModule(nn.Module): 108 | def __init__(self, in_channels, out_channels, name): 109 | super(InceptionModule, self).__init__() 110 | 111 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 112 | name=name+'/Branch_0/Conv3d_0a_1x1') 113 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 114 | name=name+'/Branch_1/Conv3d_0a_1x1') 115 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], 116 | name=name+'/Branch_1/Conv3d_0b_3x3') 117 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 118 | 
name=name+'/Branch_2/Conv3d_0a_1x1') 119 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], 120 | name=name+'/Branch_2/Conv3d_0b_3x3') 121 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 122 | stride=(1, 1, 1), padding=0) 123 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 124 | name=name+'/Branch_3/Conv3d_0b_1x1') 125 | self.name = name 126 | 127 | def forward(self, x): 128 | b0 = self.b0(x) 129 | b1 = self.b1b(self.b1a(x)) 130 | b2 = self.b2b(self.b2a(x)) 131 | b3 = self.b3b(self.b3a(x)) 132 | return torch.cat([b0,b1,b2,b3], dim=1) 133 | 134 | 135 | class InceptionI3d(nn.Module): 136 | """Inception-v1 I3D architecture. 137 | The model is introduced in: 138 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset 139 | Joao Carreira, Andrew Zisserman 140 | https://arxiv.org/pdf/1705.07750v1.pdf. 141 | See also the Inception architecture, introduced in: 142 | Going deeper with convolutions 143 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 144 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 145 | http://arxiv.org/pdf/1409.4842v1.pdf. 146 | """ 147 | 148 | # Endpoints of the model in order. During construction, all the endpoints up 149 | # to a designated `final_endpoint` are returned in a dictionary as the 150 | # second return value. 151 | VALID_ENDPOINTS = ( 152 | 'Conv3d_1a_7x7', 153 | 'MaxPool3d_2a_3x3', 154 | 'Conv3d_2b_1x1', 155 | 'Conv3d_2c_3x3', 156 | 'MaxPool3d_3a_3x3', 157 | 'Mixed_3b', 158 | 'Mixed_3c', 159 | 'MaxPool3d_4a_3x3', 160 | 'Mixed_4b', 161 | 'Mixed_4c', 162 | 'Mixed_4d', 163 | 'Mixed_4e', 164 | 'Mixed_4f', 165 | 'MaxPool3d_5a_2x2', 166 | 'Mixed_5b', 167 | 'Mixed_5c', 168 | 'Logits', 169 | 'Predictions', 170 | ) 171 | 172 | def __init__(self, num_classes=400, spatial_squeeze=True, 173 | final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5): 174 | """Initializes I3D model instance. 175 | Args: 176 | num_classes: The number of outputs in the logit layer (default 400, which 177 | matches the Kinetics dataset). 178 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits 179 | before returning (default True). 180 | final_endpoint: The model contains many possible endpoints. 181 | `final_endpoint` specifies the last endpoint for the model to be built 182 | up to. In addition to the output at `final_endpoint`, all the outputs 183 | at endpoints up to `final_endpoint` will also be returned, in a 184 | dictionary. `final_endpoint` must be one of 185 | InceptionI3d.VALID_ENDPOINTS (default 'Logits'). 186 | name: A string (optional). The name of this module. 187 | Raises: 188 | ValueError: if `final_endpoint` is not recognized. 
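Example (a minimal sketch; the checkpoint path is hypothetical and the input
follows the (N, C, T, H, W) layout expected by forward()):
    i3d = InceptionI3d(num_classes=400, in_channels=3)
    # i3d.load_state_dict(torch.load('i3d_pretrained_400.pt'))  # optional Kinetics weights
    logits = i3d(torch.randn(2, 3, 16, 224, 224))  # -> (2, 400)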
189 | """ 190 | 191 | if final_endpoint not in self.VALID_ENDPOINTS: 192 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 193 | 194 | super(InceptionI3d, self).__init__() 195 | self._num_classes = num_classes 196 | self._spatial_squeeze = spatial_squeeze 197 | self._final_endpoint = final_endpoint 198 | self.logits = None 199 | 200 | if self._final_endpoint not in self.VALID_ENDPOINTS: 201 | raise ValueError('Unknown final endpoint %s' % self._final_endpoint) 202 | 203 | self.end_points = {} 204 | end_point = 'Conv3d_1a_7x7' 205 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], 206 | stride=(2, 2, 2), padding=(3,3,3), name=name+end_point) 207 | if self._final_endpoint == end_point: return 208 | 209 | end_point = 'MaxPool3d_2a_3x3' 210 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 211 | padding=0) 212 | if self._final_endpoint == end_point: return 213 | 214 | end_point = 'Conv3d_2b_1x1' 215 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 216 | name=name+end_point) 217 | if self._final_endpoint == end_point: return 218 | 219 | end_point = 'Conv3d_2c_3x3' 220 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, 221 | name=name+end_point) 222 | if self._final_endpoint == end_point: return 223 | 224 | end_point = 'MaxPool3d_3a_3x3' 225 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 226 | padding=0) 227 | if self._final_endpoint == end_point: return 228 | 229 | end_point = 'Mixed_3b' 230 | self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) 231 | if self._final_endpoint == end_point: return 232 | 233 | end_point = 'Mixed_3c' 234 | self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) 235 | if self._final_endpoint == end_point: return 236 | 237 | end_point = 'MaxPool3d_4a_3x3' 238 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), 239 | padding=0) 240 | if self._final_endpoint == end_point: return 241 | 242 | end_point = 'Mixed_4b' 243 | self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) 244 | if self._final_endpoint == end_point: return 245 | 246 | end_point = 'Mixed_4c' 247 | self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) 248 | if self._final_endpoint == end_point: return 249 | 250 | end_point = 'Mixed_4d' 251 | self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) 252 | if self._final_endpoint == end_point: return 253 | 254 | end_point = 'Mixed_4e' 255 | self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) 256 | if self._final_endpoint == end_point: return 257 | 258 | end_point = 'Mixed_4f' 259 | self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) 260 | if self._final_endpoint == end_point: return 261 | 262 | end_point = 'MaxPool3d_5a_2x2' 263 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), 264 | padding=0) 265 | if self._final_endpoint == end_point: return 266 | 267 | end_point = 'Mixed_5b' 268 | self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) 269 | if 
self._final_endpoint == end_point: return 270 | 271 | end_point = 'Mixed_5c' 272 | self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) 273 | if self._final_endpoint == end_point: return 274 | 275 | end_point = 'Logits' 276 | self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], 277 | stride=(1, 1, 1)) 278 | self.dropout = nn.Dropout(dropout_keep_prob) 279 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 280 | kernel_shape=[1, 1, 1], 281 | padding=0, 282 | activation_fn=None, 283 | use_batch_norm=False, 284 | use_bias=True, 285 | name='logits') 286 | 287 | self.build() 288 | 289 | 290 | def replace_logits(self, num_classes): 291 | self._num_classes = num_classes 292 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 293 | kernel_shape=[1, 1, 1], 294 | padding=0, 295 | activation_fn=None, 296 | use_batch_norm=False, 297 | use_bias=True, 298 | name='logits') 299 | 300 | 301 | def build(self): 302 | for k in self.end_points.keys(): 303 | self.add_module(k, self.end_points[k]) 304 | 305 | def forward(self, x): 306 | for end_point in self.VALID_ENDPOINTS: 307 | if end_point in self.end_points: 308 | x = self._modules[end_point](x) # use _modules to work with dataparallel 309 | 310 | x = self.logits(self.dropout(self.avg_pool(x))) 311 | if self._spatial_squeeze: 312 | logits = x.squeeze(3).squeeze(3) 313 | logits = logits.mean(dim=2) 314 | # logits is batch X time X classes, which is what we want to work with 315 | return logits 316 | 317 | 318 | def extract_features(self, x): 319 | for end_point in self.VALID_ENDPOINTS: 320 | if end_point in self.end_points: 321 | x = self._modules[end_point](x) 322 | return self.avg_pool(x) 323 | -------------------------------------------------------------------------------- /models/ResNetAutoEncoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import functools 4 | from torch.nn import init 5 | import pytorch_lightning as pl 6 | from .criterion import L1Loss 7 | """ 8 | Modified based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 9 | """ 10 | 11 | from .submodules import NonLocalAttenion2D, NonLocalAttenion1D, Factorized3DConvAttn 12 | 13 | class LitAE(pl.LightningModule): 14 | def __init__(self, cfg): 15 | super().__init__() 16 | self.VPTR_Enc = ResnetEncoder(cfg.Dataset.img_channels, ngf=cfg.AE.ngf, n_downsampling = cfg.AE.n_downsampling, 17 | num_res_blocks = cfg.AE.num_res_blocks, norm_layer=nn.BatchNorm2d, 18 | norm_layer1d=nn.BatchNorm1d, learn_3d = cfg.AE.learn_3d) 19 | self.VPTR_Dec = ResnetDecoder(cfg.Dataset.img_channels, ngf=cfg.AE.ngf, n_downsampling = cfg.AE.n_downsampling, 20 | out_layer = cfg.AE.out_layer, norm_layer=nn.BatchNorm2d) 21 | self.cfg = cfg 22 | self.l1_loss = L1Loss() 23 | 24 | def forward(self, x): 25 | encoding = self.VPTR_Enc(x) 26 | rec_frames = self.VPTR_Dec(encoding) 27 | return rec_frames, encoding 28 | 29 | def training_step(self, batch, batch_idx): 30 | loss = self.shared_step(batch, batch_idx) 31 | self.log('L1_loss_train', loss) 32 | return loss 33 | 34 | def validation_step(self, batch, batch_idx): 35 | loss = self.shared_step(batch, batch_idx) 36 | self.log('L1_loss_valid', loss) 37 | return loss 38 | 39 | def shared_step(self, batch, batch_idx): 40 | past_frames, future_frames = batch 41 | x = torch.cat([past_frames, future_frames], dim = 1) 42 | rec_frames = 
self.VPTR_Dec(self.VPTR_Enc(x)) 43 | loss = self.l1_loss(rec_frames, x) 44 | return loss 45 | 46 | def configure_optimizers(self): 47 | optimizer = torch.optim.Adam(params = list(self.VPTR_Enc.parameters()) + list(self.VPTR_Dec.parameters()), 48 | lr=self.cfg.AE.AE_lr, betas = (0.5, 0.999)) 49 | return optimizer 50 | 51 | class ResnetEncoder(nn.Module): 52 | def __init__(self, input_nc, ngf=64, n_downsampling = 3, num_res_blocks = 2, norm_layer=nn.BatchNorm2d, norm_layer1d=nn.BatchNorm1d, use_dropout=False, padding_type='reflect', learn_3d = True): 53 | """Construct a Resnet-based Encoder 54 | Parameters: 55 | input_nc (int) -- the number of channels in input images 56 | ngf (int) -- the number of filters in the last conv layer 57 | norm_layer -- normalization layer 58 | use_dropout (bool) -- if use dropout layers 59 | n_blocks (int) -- the number of ResNet blocks 60 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 61 | """ 62 | 63 | super().__init__() 64 | if type(norm_layer) == functools.partial: 65 | use_bias = norm_layer.func == nn.InstanceNorm2d 66 | else: 67 | use_bias = norm_layer == nn.InstanceNorm2d 68 | self.n_downsampling = n_downsampling 69 | self.num_res_blocks = num_res_blocks 70 | self.block0 = nn.Sequential(nn.ReflectionPad2d(3), 71 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 72 | norm_layer(ngf), 73 | nn.ReLU(True)) 74 | 75 | self.block1 = nn.Sequential(nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1, bias=use_bias), 76 | norm_layer(ngf*2), 77 | nn.ReLU(True)) 78 | ngf = ngf*2 79 | for i in range(1, n_downsampling): 80 | setattr(self, f'block{i+1}_3dConvAttn', Factorized3DConvAttn(in_channels=ngf, 81 | norm_layer_2d = norm_layer, 82 | norm_layer_1d = norm_layer1d, 83 | activ_func = nn.ReLU(True), 84 | learn_3d = learn_3d)) 85 | setattr(self, f'block{i+1}_conv', nn.Sequential(nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1, bias=use_bias), 86 | norm_layer(ngf*2), 87 | nn.ReLU(True))) 88 | ngf = ngf*2 89 | """ 90 | self.block2_3dConvAttn = Factorized3DConvAttn(in_channels=ngf*2, 91 | norm_layer_2d = norm_layer, 92 | norm_layer_1d = norm_layer1d, 93 | activ_func = nn.ReLU(True), 94 | learn_3d = learn_3d) 95 | self.block2_conv = nn.Sequential(nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1, bias=use_bias), 96 | norm_layer(ngf * 4), 97 | nn.ReLU(True)) 98 | 99 | self.block3_3dConvAttn = Factorized3DConvAttn(in_channels=ngf*4, 100 | norm_layer_2d = norm_layer, 101 | norm_layer_1d = norm_layer1d, 102 | activ_func = nn.ReLU(True), 103 | learn_3d = learn_3d) 104 | self.block3_conv = nn.Sequential(nn.Conv2d(ngf*4, ngf*8, kernel_size=3, stride=2, padding=1, bias=use_bias), 105 | norm_layer(ngf * 8), 106 | nn.ReLU(True)) 107 | """ 108 | 109 | #9 resnet-blocks 110 | for i in range(num_res_blocks): # add ResNet blocks 111 | setattr(self, f'res_3dConvAttn_{i}', Factorized3DConvAttn(in_channels=ngf, 112 | norm_layer_2d = norm_layer, 113 | norm_layer_1d = norm_layer1d, 114 | activ_func = nn.ReLU(True), 115 | learn_3d = learn_3d)) 116 | setattr(self, f'res_conv_{i}', ResnetBlock(ngf, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)) 117 | 118 | self.out_act = nn.ReLU() 119 | 120 | def forward(self, x): 121 | """ 122 | x: (N, T, C, H, W) 123 | """ 124 | N, T, _, _, _ = x.shape 125 | x = x.flatten(0, 1) 126 | 127 | x = self.block0(x) 128 | x = self.block1(x) 129 | for i in range(1, self.n_downsampling): 130 | x = getattr(self, 
f'block{i+1}_3dConvAttn')(x, T) 131 | x = getattr(self, f'block{i+1}_conv')(x) 132 | """ 133 | x = self.block2_3dConvAttn(x, T) 134 | x = self.block2_conv(x) 135 | x = self.block3_3dConvAttn(x, T) 136 | x = self.block3_conv(x) 137 | """ 138 | for i in range(self.num_res_blocks): 139 | x = getattr(self, f'res_3dConvAttn_{i}')(x, T) 140 | x = getattr(self, f'res_conv_{i}')(x) 141 | 142 | x = self.out_act(x) 143 | _, C, H, W = x.shape 144 | x = x.reshape(N, T, C, H, W) 145 | 146 | return x 147 | 148 | class ResnetDecoder(nn.Module): 149 | def __init__(self, output_nc, ngf=64, n_downsampling = 2, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type='reflect', out_layer = 'Tanh'): 150 | """Construct a Resnet-based Encoder 151 | Parameters: 152 | output_nc (int) -- the number of channels in output images 153 | ngf (int) -- the number of filters in the last conv layer 154 | norm_layer -- normalization layer 155 | use_dropout (bool) -- if use dropout layers 156 | n_blocks (int) -- the number of ResNet blocks 157 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 158 | """ 159 | super().__init__() 160 | if type(norm_layer) == functools.partial: 161 | use_bias = norm_layer.func == nn.InstanceNorm2d 162 | else: 163 | use_bias = norm_layer == nn.InstanceNorm2d 164 | 165 | model = [] 166 | 167 | #The first up-sampling layer 168 | mult = 2**n_downsampling 169 | model += [nn.ConvTranspose2d(ngf*mult, int(ngf * mult / 2), 170 | kernel_size=3, stride=2, 171 | padding=1, output_padding=1, 172 | bias=use_bias), 173 | norm_layer(int(ngf * mult / 2)), 174 | nn.ReLU(True)] 175 | 176 | for i in range(1, n_downsampling): # add upsampling layers 177 | mult = 2 ** (n_downsampling - i) 178 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 179 | kernel_size=3, stride=2, 180 | padding=1, output_padding=1, 181 | bias=use_bias), 182 | norm_layer(int(ngf * mult / 2)), 183 | nn.ReLU(True)] 184 | model += [nn.ReflectionPad2d(3)] 185 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 186 | if out_layer == 'Tanh': 187 | model += [nn.Tanh()] 188 | elif out_layer == 'Sigmoid': 189 | model += [nn.Sigmoid()] 190 | else: 191 | raise ValueError("Unsupported output layer") 192 | 193 | self.model = nn.Sequential(*model) 194 | 195 | def forward(self, x): 196 | """ 197 | x: (N, T, C, H, W) 198 | """ 199 | N, T, _, _, _ = x.shape 200 | x = x.flatten(0, 1) 201 | x = self.model(x) 202 | NT, C, H, W = x.shape 203 | x = x.reshape(NT//T, T, C, H, W) 204 | return x 205 | 206 | 207 | class ResnetBlock(nn.Module): 208 | """Define a Resnet block""" 209 | 210 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 211 | """Initialize the Resnet block 212 | A resnet block is a conv block with skip connections 213 | We construct a conv block with build_conv_block function, 214 | and implement skip connections in function. 215 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 216 | """ 217 | super(ResnetBlock, self).__init__() 218 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 219 | 220 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 221 | """Construct a convolutional block. 222 | Parameters: 223 | dim (int) -- the number of channels in the conv layer. 224 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 225 | norm_layer -- normalization layer 226 | use_dropout (bool) -- if use dropout layers. 
227 | use_bias (bool) -- if the conv layer uses bias or not 228 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 229 | """ 230 | conv_block = [] 231 | p = 0 232 | if padding_type == 'reflect': 233 | conv_block += [nn.ReflectionPad2d(1)] 234 | elif padding_type == 'replicate': 235 | conv_block += [nn.ReplicationPad2d(1)] 236 | elif padding_type == 'zero': 237 | p = 1 238 | else: 239 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 240 | 241 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 242 | if use_dropout: 243 | conv_block += [nn.Dropout(0.5)] 244 | 245 | p = 0 246 | if padding_type == 'reflect': 247 | conv_block += [nn.ReflectionPad2d(1)] 248 | elif padding_type == 'replicate': 249 | conv_block += [nn.ReplicationPad2d(1)] 250 | elif padding_type == 'zero': 251 | p = 1 252 | else: 253 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 254 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 255 | 256 | return nn.Sequential(*conv_block) 257 | 258 | def forward(self, x): 259 | """Forward function (with skip connections)""" 260 | out = x + self.conv_block(x) # add skip connections 261 | return out 262 | 263 | def init_weights(net, init_type='normal', init_gain=0.02): 264 | """Initialize network weights. 265 | Parameters: 266 | net (network) -- network to be initialized 267 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 268 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 269 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 270 | work better for some applications. Feel free to try yourself. 271 | """ 272 | def init_func(m): # define the initialization function 273 | classname = m.__class__.__name__ 274 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 275 | if init_type == 'normal': 276 | init.normal_(m.weight.data, 0.0, init_gain) 277 | elif init_type == 'xavier': 278 | init.xavier_normal_(m.weight.data, gain=init_gain) 279 | elif init_type == 'kaiming': 280 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 281 | elif init_type == 'orthogonal': 282 | init.orthogonal_(m.weight.data, gain=init_gain) 283 | else: 284 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 285 | if hasattr(m, 'bias') and m.bias is not None: 286 | init.constant_(m.bias.data, 0.0) 287 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
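# For affine BatchNorm2d layers the scale is drawn from N(1.0, init_gain) and the bias is zeroed: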
288 | init.normal_(m.weight.data, 1.0, init_gain) 289 | init.constant_(m.bias.data, 0.0) 290 | 291 | print('initialize network with %s' % init_type) 292 | net.apply(init_func) # apply the initialization function 293 | 294 | 295 | if __name__ == '__main__': 296 | x = torch.randn(16, 3, 64, 64) 297 | enc = ResnetEncoder(input_nc = 3) 298 | dec = ResnetDecoder(output_nc = 3) 299 | out = dec(enc(x)) 300 | 301 | print(out.shape) -------------------------------------------------------------------------------- /models/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from einops import rearrange, repeat 6 | import random 7 | 8 | class GANLoss(nn.Module): 9 | """Define different GAN objectives. 10 | The GANLoss class abstracts away the need to create the target label tensor 11 | that has the same size as the input. 12 | """ 13 | 14 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, lam_gan = 1.0): 15 | """ Initialize the GANLoss class. 16 | Parameters: 17 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 18 | target_real_label (bool) - - label for a real image 19 | target_fake_label (bool) - - label of a fake image 20 | Note: Do not use sigmoid as the last layer of Discriminator. 21 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 22 | """ 23 | super(GANLoss, self).__init__() 24 | self.register_buffer('real_label', torch.tensor(target_real_label)) 25 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 26 | self.gan_mode = gan_mode 27 | if gan_mode == 'lsgan': 28 | self.loss = nn.MSELoss() 29 | elif gan_mode == 'vanilla': 30 | self.loss = nn.BCEWithLogitsLoss() 31 | elif gan_mode in ['wgangp']: 32 | self.loss = None 33 | else: 34 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 35 | 36 | self.lam_gan = lam_gan 37 | 38 | def get_target_tensor(self, prediction, target_is_real): 39 | """Create label tensors with the same size as the input. 40 | Parameters: 41 | prediction (tensor) - - tpyically the prediction from a discriminator 42 | target_is_real (bool) - - if the ground truth label is for real images or fake images 43 | Returns: 44 | A label tensor filled with ground truth label, and with the size of the input 45 | """ 46 | 47 | if target_is_real: 48 | target_tensor = self.real_label 49 | else: 50 | target_tensor = self.fake_label 51 | return target_tensor.expand_as(prediction) 52 | 53 | def __call__(self, prediction, target_is_real): 54 | """Calculate loss given Discriminator's output and grount truth labels. 55 | Parameters: 56 | prediction (tensor) - - tpyically the prediction output from a discriminator 57 | target_is_real (bool) - - if the ground truth label is for real images or fake images 58 | Returns: 59 | the calculated loss. 
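Example (a minimal sketch; `d_real` and `d_fake` stand for raw, pre-sigmoid
discriminator outputs and are not names defined elsewhere in this repo):
    criterion = GANLoss('vanilla')
    loss_d = criterion(d_real, True) + criterion(d_fake.detach(), False)
    loss_g = criterion(d_fake, True)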
60 | """ 61 | if self.gan_mode in ['lsgan', 'vanilla']: 62 | target_tensor = self.get_target_tensor(prediction, target_is_real) 63 | loss = self.loss(prediction, target_tensor) 64 | elif self.gan_mode == 'wgangp': 65 | if target_is_real: 66 | loss = -prediction.mean() 67 | else: 68 | loss = prediction.mean() 69 | return loss*self.lam_gan 70 | 71 | 72 | class GradientPanelty(nn.Module): 73 | def __init__(self, lam_wgan, lam_gp = 10.0): 74 | super().__init__() 75 | self.lam_gan = lam_wgan 76 | self.lam_gp = lam_gp 77 | 78 | def __call__(self, real_x, fake_x, critic): 79 | """ 80 | real_x: (N, C, H, W) 81 | fake_x: (N, C, H, W) 82 | """ 83 | batchsize = fake_x.shape[0] 84 | t = torch.rand(batchsize, 1, 1, 1, device = fake_x.device) 85 | interpolate_x = t*real_x + (1-t)*fake_x 86 | interpolate_x.requires_grad_(True) 87 | 88 | critic_out = critic(interpolate_x) 89 | 90 | gradients = torch.autograd.grad(outputs=critic_out, inputs=interpolate_x, 91 | grad_outputs=torch.ones(critic_out.size()).to(fake_x.device), 92 | create_graph=True, retain_graph=True, only_inputs=True) 93 | gradients = gradients[0].view(batchsize, -1) 94 | gradients_norm = (gradients+1e-16).norm(2, dim = 1) 95 | gradient_panelty = torch.clamp(gradients_norm - 1, min = 0.) 96 | return torch.square(gradient_panelty).mean() * self.lam_gp * self.lam_gan 97 | 98 | 99 | class L1Loss(nn.Module): 100 | def __init__(self, norm_dim = None, lam = 1.0): 101 | """ 102 | Args: 103 | norm_dim: dimensionality for normalizing the input features, default no normalization 104 | lam: the weight for the L1 loss 105 | """ 106 | super().__init__() 107 | self.norm_dim = norm_dim 108 | self.lam = lam 109 | 110 | def __call__(self, gt, pred): 111 | """ 112 | pred --- tensor with shape (B, T, ...) 113 | gt --- tensor with shape (B, T, ...) 114 | """ 115 | if self.norm_dim is not None: 116 | gt = F.normalize(gt, p = 2, dim = self.norm_dim) 117 | pred = F.normalize(pred, p = 2, dim = self.norm_dim) 118 | 119 | se = torch.abs(pred - gt) 120 | mse = se.mean() 121 | return mse*self.lam 122 | 123 | class TemporalDiff(nn.Module): 124 | def __init__(self, lam = 1.0): 125 | super().__init__() 126 | self.lam = lam #weight for the temporal difference loss 127 | 128 | def forward(self, gt, pred): 129 | """ 130 | pred --- tensor with shape (B, T, ...) 131 | gt --- tensor with shape (B, T, ...) 132 | """ 133 | shuffle_pred, shuffle_gt = self.random_shuffle(pred, gt) 134 | diff_pred = pred - shuffle_pred 135 | diff_gt = gt - shuffle_gt 136 | 137 | loss = torch.abs(diff_pred - diff_gt).mean() 138 | 139 | return self.lam*loss 140 | 141 | def random_shuffle(self, pred, gt): 142 | #Shuffle along the temporal axis to get the product of two marginal distributions 143 | T = pred.shape[1] 144 | 145 | rand_shift = random.randint(1, T-1) 146 | return torch.roll(pred, rand_shift, 1), torch.roll(gt, rand_shift, 1) 147 | 148 | class MSELoss(nn.Module): 149 | def __init__(self, temporal_weight = None, norm_dim = None): 150 | """ 151 | Args: 152 | temporal_weight: penalty for loss at different time step, Tensor with length T 153 | """ 154 | super().__init__() 155 | self.temporal_weight = temporal_weight 156 | self.norm_dim = norm_dim 157 | 158 | def __call__(self, gt, pred): 159 | """ 160 | pred --- tensor with shape (B, T, ...) 161 | gt --- tensor with shape (B, T, ...) 
162 | """ 163 | if self.norm_dim is not None: 164 | gt = F.normalize(gt, p = 2, dim = self.norm_dim) 165 | pred = F.normalize(pred, p = 2, dim = self.norm_dim) 166 | 167 | se = torch.square(pred - gt) 168 | if self.temporal_weight is not None: 169 | w = self.temporal_weight.to(se.device) 170 | if len(se.shape) == 5: 171 | se = se * w[None, :, None, None, None] 172 | elif len(se.shape) == 6: 173 | se = se * w[None, :, None, None, None, None] #for warped frames, (N, num_future_frames, num_past_frames, C, H, W) 174 | mse = se.mean() 175 | return mse 176 | 177 | class GDL(nn.Module): 178 | def __init__(self, alpha = 1, temporal_weight = None): 179 | """ 180 | Args: 181 | alpha: hyper parameter of GDL loss, float 182 | temporal_weight: penalty for loss at different time step, Tensor with length T 183 | """ 184 | super().__init__() 185 | self.alpha = alpha 186 | self.temporal_weight = temporal_weight 187 | 188 | def __call__(self, gt, pred): 189 | """ 190 | pred --- tensor with shape (B, T, ...) 191 | gt --- tensor with shape (B, T, ...) 192 | """ 193 | gt_shape = gt.shape 194 | if len(gt_shape) == 5: 195 | B, T, _, _, _ = gt.shape 196 | elif len(gt_shape) == 6: #for warped frames, (N, num_future_frames, num_past_frames, C, H, W) 197 | B, T, TP, _, _, _ = gt.shape 198 | gt = gt.flatten(0, -4) 199 | pred = pred.flatten(0, -4) 200 | 201 | gt_i1 = gt[:, :, 1:, :] 202 | gt_i2 = gt[:, :, :-1, :] 203 | gt_j1 = gt[:, :, :, :-1] 204 | gt_j2 = gt[:, :, :, 1:] 205 | 206 | pred_i1 = pred[:, :, 1:, :] 207 | pred_i2 = pred[:, :, :-1, :] 208 | pred_j1 = pred[:, :, :, :-1] 209 | pred_j2 = pred[:, :, :, 1:] 210 | 211 | term1 = torch.abs(gt_i1 - gt_i2) 212 | term2 = torch.abs(pred_i1 - pred_i2) 213 | term3 = torch.abs(gt_j1 - gt_j2) 214 | term4 = torch.abs(pred_j1 - pred_j2) 215 | 216 | if self.alpha != 1: 217 | gdl1 = torch.pow(torch.abs(term1 - term2), self.alpha) 218 | gdl2 = torch.pow(torch.abs(term3 - term4), self.alpha) 219 | else: 220 | gdl1 = torch.abs(term1 - term2) 221 | gdl2 = torch.abs(term3 - term4) 222 | 223 | if self.temporal_weight is not None: 224 | assert self.temporal_weight.shape[0] == T, "Mismatch between temporal_weight and predicted sequence length" 225 | w = self.temporal_weight.to(gdl1.device) 226 | _, C, H, W = gdl1.shape 227 | _, C2, H2, W2= gdl2.shape 228 | if len(gt_shape) == 5: 229 | gdl1 = gdl1.reshape(B, T, C, H, W) 230 | gdl2 = gdl2.reshape(B, T, C2, H2, W2) 231 | gdl1 = gdl1 * w[None, :, None, None, None] 232 | gdl2 = gdl2 * w[None, :, None, None, None] 233 | elif len(gt_shape) == 6: 234 | gdl1 = gdl1.reshape(B, T, TP, C, H, W) 235 | gdl2 = gdl2.reshape(B, T, TP, C2, H2, W2) 236 | gdl1 = gdl1 * w[None, :, None, None, None, None] 237 | gdl2 = gdl2 * w[None, :, None, None, None, None] 238 | 239 | #gdl1 = gdl1.sum(-1).sum(-1) 240 | #gdl2 = gdl2.sum(-1).sum(-1) 241 | 242 | #gdl_loss = torch.mean(gdl1 + gdl2) 243 | gdl1 = gdl1.mean() 244 | gdl2 = gdl2.mean() 245 | gdl_loss = gdl1 + gdl2 246 | 247 | return gdl_loss 248 | 249 | class BiPatchNCE(nn.Module): 250 | """ 251 | Bidirectional patchwise contrastive loss 252 | Implemented Based on https://github.com/alexandonian/contrastive-feature-loss/blob/main/models/networks/nce.py 253 | """ 254 | def __init__(self, N, T, h, w, temperature = 0.07, lam=1.0): 255 | """ 256 | T: number of frames 257 | N: batch size 258 | h: feature height 259 | w: feature width 260 | temporal_weight: penalty for loss at different time step, Tensor with length T 261 | """ 262 | super().__init__() 263 | 264 | #mask meaning; 1 for postive pairs, 0 for negative 
pairs 265 | mask = torch.eye(h*w).long() 266 | mask = mask.unsqueeze(0).repeat(N*T, 1, 1).requires_grad_(False) #(N*T, h*w, h*w) 267 | self.register_buffer('mask', mask) 268 | self.temperature = temperature 269 | self.lam = lam 270 | 271 | def forward(self, gt_f, pred_f): 272 | """ 273 | gt_f: ground truth feature/images, with shape (N, T, C, h, w) 274 | pred_f: predicted feature/images, with shape (N, T, C, h, w) 275 | """ 276 | mask = self.mask 277 | 278 | gt_f = rearrange(gt_f, "N T C h w -> (N T) (h w) C") 279 | pred_f = rearrange(pred_f, "N T C h w -> (N T) (h w) C") 280 | 281 | #direction 1, decompose the matmul to two steps, Stop gradient for the negative pairs 282 | score1_diag = torch.matmul(gt_f, pred_f.transpose(1, 2)) * mask 283 | score1_non_diag = torch.matmul(gt_f, pred_f.detach().transpose(1, 2)) * (1.0 - mask) 284 | score1 = score1_diag + score1_non_diag #(N*T, h*w, h*w) 285 | score1 = torch.div(score1, self.temperature) 286 | 287 | #direction 2 288 | score2_diag = torch.matmul(pred_f, gt_f.transpose(1, 2)) * mask 289 | score2_non_diag = torch.matmul(pred_f, gt_f.detach().transpose(1, 2)) * (1.0 - mask) 290 | score2 = score2_diag + score2_non_diag 291 | score2 = torch.div(score2, self.temperature) 292 | 293 | target = (mask == 1).int() 294 | target = target.to(score1.device) 295 | target.requires_grad = False 296 | target = target.flatten(0, 1) #(N*T*h*w, h*w) 297 | target = torch.argmax(target, dim = 1) 298 | 299 | loss1 = nn.CrossEntropyLoss()(score1.flatten(0, 1), target) 300 | loss2 = nn.CrossEntropyLoss()(score2.flatten(0, 1), target) 301 | loss = (loss1 + loss2)*0.5 302 | 303 | return loss*self.lam 304 | 305 | 306 | class NoamOpt: 307 | """ 308 | defatult setup from attention is all you need: 309 | factor = 2 310 | optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9) 311 | Optim wrapper that implements rate. 
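    lrate = factor * model_size**(-0.5) * min(step**(-0.5), step * warmup**(-1.5)), i.e. a linear warmup over the first `warmup` steps followed by inverse-square-root decay, where warmup = len(train_loader) * warmup_epochs (see rate() below).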
312 | """ 313 | def __init__(self, model_size, factor, train_loader, warmup_epochs, optimizer): 314 | self.optimizer = optimizer 315 | self._step = 0 316 | self.warmup = len(train_loader)*warmup_epochs 317 | self.factor = factor 318 | self.model_size = model_size 319 | self._rate = 0 320 | 321 | def step(self): 322 | "Update parameters and rate" 323 | self._step += 1 324 | rate = self.rate() 325 | for p in self.optimizer.param_groups: 326 | p['lr'] = rate 327 | self._rate = rate 328 | 329 | def rate(self, step = None): 330 | "Implement `lrate` above" 331 | if step is None: 332 | step = self._step 333 | return self.factor * \ 334 | (self.model_size ** (-0.5) * 335 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 336 | 337 | def reset_step(self, init_epoch, train_loader): 338 | print("!!!!Learning rate warmup warning: If you are resume training, keep the same Batchsize as before!!!!") 339 | self._step = len(train_loader) * init_epoch 340 | 341 | class Div_KL(nn.Module): 342 | def __init__(self, beta): 343 | super().__init__() 344 | self.beta = beta 345 | 346 | def forward(self, mu1, logvar1, mu2, logvar2): 347 | # KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2)) = 348 | # log( sqrt( 349 | # 350 | N = mu1.shape[0] 351 | sigma1 = logvar1.mul(0.5).exp() 352 | sigma2 = logvar2.mul(0.5).exp() 353 | kld = torch.log(sigma2/sigma1) + (torch.exp(logvar1) + (mu1 - mu2)**2)/(2*torch.exp(logvar2)) - 1/2 354 | return self.beta * kld.sum() / N 355 | 356 | if __name__ == '__main__': 357 | """ 358 | w = temporal_weight_func(10) 359 | gdl_loss = GDL(temporal_weight = w) 360 | a = torch.randn(4, 10, 3, 64, 64) 361 | b = torch.randn(4, 10, 3, 64, 64) 362 | gdl = gdl_loss(a, b) 363 | 364 | print(gdl) 365 | 366 | mse_loss = MSELoss() 367 | mse1 = mse_loss(a, b) 368 | 369 | mse2 = nn.MSELoss()(a, b) 370 | print((mse1 - mse2).sum()) 371 | 372 | a = torch.randn(4, 10, 384) 373 | b = torch.randn(10, 4, 384) 374 | pc_ssl = PCSSL(10, 4, w) 375 | pc_ssl1 = PCSSL(10, 4) 376 | 377 | print(pc_ssl1(a, b)) 378 | print(pc_ssl(a, b)) 379 | 380 | tc_ssl1 = TCSSL(10, 4, 3, w) 381 | tc_ssl2 = TCSSL(10, 4, 3) 382 | print(tc_ssl1(a), tc_ssl2(a)) 383 | 384 | 385 | a = torch.randn(4, 10, 384, 2, 2) 386 | b = torch.randn(4, 10, 384, 2, 2) 387 | dpc_ssl1 = DPCSSL(10, 4, 2, 2, temporal_weight = w) 388 | dpc_ssl2 = DPCSSL(10, 4, 2, 2) 389 | 390 | print(dpc_ssl1(a, b), dpc_ssl2(a, b)) 391 | 392 | m = BiPatchNCE(4, 10, 8, 8) 393 | gt = torch.randn(4, 10, 256, 8, 8).requires_grad_(True) 394 | pred = torch.randn(4, 10, 256, 8, 8).requires_grad_(True) 395 | loss, score1,score2 = m(gt, pred) 396 | loss.backward() 397 | print(loss, score1.shape, score2.shape) 398 | #print(score1.grad) 399 | #print(score2.grad) 400 | 401 | 402 | feat = torch.randn(1, 5, 384, 5, 5) 403 | m = TermporalPairwiseMSE(norm_dim = 2) 404 | loss= m(feat) 405 | print(loss) 406 | 407 | 408 | 409 | nce = BiPatchNCE(4, 10, 8, 8) 410 | feat_q = torch.randn(4, 10, 512, 8, 8) 411 | feat_q = feat_q.to('cuda') 412 | feat_k = feat_q 413 | 414 | loss1, score1 = nce(F.normalize(feat_q, p=2.0, dim=2), F.normalize(feat_q, p=2.0, dim=2)) 415 | print(loss1.mean()) 416 | """ 417 | gt = torch.randn(4, 64, 64, 3) 418 | pred = torch.randn(4, 64, 64, 3) 419 | l1_loss = L1Loss() 420 | loss1 = l1_loss(gt, pred) 421 | print(loss1) 422 | loss2 = l1_loss.patch_l1(gt, pred) 423 | print(loss2) 424 | -------------------------------------------------------------------------------- /utils/train_summary.py: -------------------------------------------------------------------------------- 1 | from PIL 
import Image 2 | from pathlib import Path 3 | import torch 4 | import torchvision.transforms as transforms 5 | from pathlib import Path 6 | import shutil 7 | from collections import OrderedDict 8 | 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.callbacks import Callback 11 | 12 | def save_code_cfg(cfg, ckpt_dir): 13 | #Save the code and config.ymal 14 | if not Path(ckpt_dir).exists(): 15 | Path(ckpt_dir).mkdir(parents=True, exist_ok=True) 16 | code_files = read_code_files() 17 | torch.save({ 18 | 'cfg': cfg, 19 | 'code': code_files 20 | }, Path(ckpt_dir).joinpath('code_cfg.tar').absolute().as_posix()) 21 | 22 | class VisCallbackAE(Callback): 23 | @pl.utilities.rank_zero_only 24 | def on_train_epoch_end(self, trainer, pl_module): 25 | if trainer.current_epoch % pl_module.cfg.AE.log_per_epochs == 0: 26 | print(pl_module.device) 27 | pl_module = pl_module.eval() 28 | renorm_transform = trainer.datamodule.renorm_transform 29 | 30 | dataloader = trainer.datamodule.train_dataloader() 31 | save_dir = Path(pl_module.cfg.AE.ckpt_save_dir).joinpath(f'train_gifs_epoch{trainer.current_epoch}') 32 | self.vis_samples(dataloader, renorm_transform, pl_module, save_dir) 33 | 34 | dataloader = trainer.datamodule.val_dataloader() 35 | save_dir = Path(pl_module.cfg.AE.ckpt_save_dir).joinpath(f'val_gifs_epoch{trainer.current_epoch}') 36 | self.vis_samples(dataloader, renorm_transform, pl_module, save_dir) 37 | 38 | def vis_samples(self, dataloader, renorm_transform, pl_module, save_dir): 39 | past_frames, future_frames = next(iter(dataloader)) 40 | past_frames, future_frames = past_frames.to(pl_module.device), future_frames.to(pl_module.device) 41 | with torch.no_grad(): 42 | rec_past_frames, _ = pl_module(past_frames) 43 | rec_future_frames, _ = pl_module(future_frames) 44 | 45 | gt_frames = torch.cat([past_frames, future_frames], dim = 1) 46 | rec_frames = torch.cat([rec_past_frames, rec_future_frames], dim = 1) 47 | 48 | N = future_frames.shape[0] 49 | idx = min(N, 4) 50 | 51 | visualize_batch_clips(gt_frames[0:idx, :, ...], rec_frames[0:idx, :, ...], rec_frames[0:idx, :, ...], save_dir, renorm_transform, desc = 'ae') 52 | 53 | class VisCallbackPredictor(Callback): 54 | @pl.utilities.rank_zero_only 55 | def on_train_epoch_end(self, trainer, pl_module): 56 | if trainer.current_epoch % pl_module.cfg.Predictor.log_per_epochs == 0: 57 | pl_module = pl_module.eval() 58 | renorm_transform = trainer.datamodule.renorm_transform 59 | 60 | dataloader = trainer.datamodule.train_dataloader() 61 | save_dir = Path(pl_module.cfg.Predictor.ckpt_save_dir).joinpath(f'train_gifs_epoch{trainer.current_epoch}') 62 | self.vis_samples(dataloader, renorm_transform, pl_module, save_dir) 63 | 64 | dataloader = trainer.datamodule.val_dataloader() 65 | save_dir = Path(pl_module.cfg.Predictor.ckpt_save_dir).joinpath(f'val_gifs_epoch{trainer.current_epoch}') 66 | self.vis_samples(dataloader, renorm_transform, pl_module, save_dir) 67 | 68 | def vis_samples(self, dataloader, renorm_transform, pl_module, save_dir): 69 | batch = next(iter(dataloader)) 70 | batch = pl_module.batch_process_fn(batch) 71 | past_frames, future_frames = batch 72 | past_frames, future_frames = past_frames.to(pl_module.device), future_frames.to(pl_module.device) 73 | with torch.no_grad(): 74 | rec_past_frames, rec_future_frames, pred_future_frames = pl_module(past_frames, future_frames) 75 | 76 | N = pred_future_frames.shape[0] 77 | idx = min(N, 4) 78 | 79 | visualize_batch_clips(past_frames[0:idx, :, ...], rec_past_frames[0:idx, :, ...], 
rec_future_frames[0:idx, :, ...], save_dir, renorm_transform, desc = 'ae') 80 | visualize_batch_clips(past_frames[0:idx, :, ...], future_frames[0:idx, :, ...], pred_future_frames[0:idx, :, ...], save_dir, renorm_transform, desc = 'pred') 81 | 82 | def resume_training(module_dict, optimizer_dict, resume_ckpt, loss_name_list = None, map_location = None): 83 | modules_state_dict, optimizers_state_dict, start_epoch, history_loss_dict, _ = load_ckpt(resume_ckpt, map_location) 84 | for k, m in module_dict.items(): 85 | state_dict = modules_state_dict[k] 86 | try: 87 | m.load_state_dict(state_dict) 88 | except RuntimeError: #load the model trained by data distributed parallel 89 | new_state_dict = OrderedDict() 90 | for sk, sv in state_dict.items(): 91 | nk = sk[7:] # remove `module.` 92 | new_state_dict[nk] = sv 93 | m.load_state_dict(new_state_dict) 94 | for k, m in optimizer_dict.items(): 95 | state_dict = optimizers_state_dict[k] 96 | try: 97 | m.load_state_dict(state_dict) 98 | except RuntimeError: 99 | print('Optimizer statedict with module.') 100 | new_state_dict = OrderedDict() 101 | for sk, sv in state_dict.items(): 102 | nk = sk[7:] # remove `module.` 103 | new_state_dict[nk] = sv 104 | m.load_state_dict(new_state_dict) 105 | 106 | if map_location is None: 107 | loss_dict = init_loss_dict(loss_name_list, history_loss_dict) 108 | return loss_dict, start_epoch 109 | else: 110 | return history_loss_dict, start_epoch 111 | 112 | 113 | class AverageMeters(object): 114 | def __init__(self, loss_name_list): 115 | self.loss_name_list = loss_name_list 116 | self.meters = {} 117 | for name in loss_name_list: 118 | self.meters[name] = BatchAverageMeter(name, ':.10e') 119 | 120 | def iter_update(self, iter_loss_dict): 121 | for k, v in iter_loss_dict.items(): 122 | self.meters[k].update(v) 123 | 124 | def epoch_update(self, loss_dict, epoch, train_flag = True): 125 | if train_flag: 126 | for k, v in loss_dict.items(): 127 | try: 128 | v.train.append(self.meters[k].avg) 129 | except AttributeError: 130 | pass 131 | except KeyError: 132 | v.train.append(0) 133 | else: 134 | for k, v in loss_dict.items(): 135 | try: 136 | v.val.append(self.meters[k].avg) 137 | except AttributeError: 138 | pass 139 | except KeyError: 140 | v.val.append(0) 141 | loss_dict['epochs'] = epoch 142 | 143 | return loss_dict 144 | 145 | def gather_AverageMeters(aveMeter_list): 146 | """ 147 | average the avg value from different rank 148 | Args: 149 | aveMeter_list: list of AverageMeters objects 150 | """ 151 | AM0 = aveMeter_list[0] 152 | name_list = AM0.loss_name_list 153 | 154 | return_AM = AverageMeters(name_list) 155 | for name in name_list: 156 | avg_val = 0 157 | for am in aveMeter_list: 158 | rank_avg = am.meters[name].avg 159 | avg_val += rank_avg 160 | avg_val = avg_val/len(aveMeter_list) 161 | return_AM.meters[name].avg = avg_val 162 | 163 | return return_AM 164 | 165 | 166 | class Loss_tuple(object): 167 | def __init__(self): 168 | self.train = [] 169 | self.val = [] 170 | 171 | def init_loss_dict(loss_name_list, history_loss_dict = None): 172 | loss_dict = {} 173 | for name in loss_name_list: 174 | loss_dict[name] = Loss_tuple() 175 | loss_dict['epochs'] = 0 176 | 177 | if history_loss_dict is not None: 178 | for k, v in history_loss_dict.items(): 179 | loss_dict[k] = v 180 | 181 | for k, v in loss_dict.items(): 182 | if k not in history_loss_dict: 183 | lt = Loss_tuple() 184 | lt.train = [0] * history_loss_dict['epochs'] 185 | lt.val = [0] * history_loss_dict['epochs'] 186 | loss_dict[k] = lt 187 | 188 | return 
loss_dict 189 | 190 | def write_summary(summary_writer, in_loss_dict, train_flag = True): 191 | loss_dict = in_loss_dict.copy() 192 | del loss_dict['epochs'] 193 | if train_flag: 194 | for k, v in loss_dict.items(): 195 | for i in range(len(v.train)): 196 | summary_writer.add_scalars(k, {'train': v.train[i]}, i+1) 197 | else: 198 | for k, v in loss_dict.items(): 199 | for i in range(len(v.val)): 200 | summary_writer.add_scalars(k, {'val': v.val[i]}, i+1) 201 | 202 | def save_ckpt(Modules_dict, Optimizers_dict, epoch, loss_dict, ckpt_codes, save_dir): 203 | #Save checkpoints every epoch 204 | if not Path(save_dir).exists(): 205 | Path(save_dir).mkdir(parents=True, exist_ok=True) 206 | ckpt_file = Path(save_dir).joinpath(f"epoch_{epoch}.tar") 207 | 208 | module_state_dict = {} 209 | for k, m in Modules_dict.items(): 210 | module_state_dict[k] = m.state_dict() 211 | optim_state_dict = {} 212 | for k, m in Optimizers_dict.items(): 213 | optim_state_dict[k] = m.state_dict() 214 | torch.save({ 215 | 'epoch': epoch, 216 | 'loss_dict': loss_dict, #{loss_name: [train_loss_list, val_loss_list]} 217 | 'Module_state_dict': module_state_dict, 218 | 'optimizer_state_dict': optim_state_dict, 219 | 'code': ckpt_codes 220 | }, ckpt_file.absolute().as_posix()) 221 | 222 | def load_ckpt(ckpt_file, map_location = None): 223 | ckpt = torch.load(ckpt_file, map_location = map_location) 224 | 225 | epoch = ckpt["epoch"] 226 | loss_dict = ckpt["loss_dict"] 227 | Modules_state_dict = ckpt['Module_state_dict'] 228 | Optimizers_state_dict = ckpt['optimizer_state_dict'] 229 | code = ckpt['code'] 230 | 231 | return Modules_state_dict, Optimizers_state_dict, epoch, loss_dict, code 232 | 233 | def visualize_batch_clips(gt_past_frames_batch, gt_future_frames_batch, pred_frames_batch, file_dir, renorm_transform = None, desc = None): 234 | """ 235 | pred_frames_batch: tensor with shape (N, future_clip_length, C, H, W) 236 | gt_future_frames_batch: tensor with shape (N, future_clip_length, C, H, W) 237 | gt_past_frames_batch: tensor with shape (N, past_clip_length, C, H, W) 238 | """ 239 | if not Path(file_dir).exists(): 240 | Path(file_dir).mkdir(parents=True, exist_ok=True) 241 | def save_clip(clip, file_name): 242 | imgs = [] 243 | if renorm_transform is not None: 244 | clip = renorm_transform(clip) 245 | clip = torch.clamp(clip, min = 0., max = 1.0) 246 | for i in range(clip.shape[0]): 247 | img = transforms.ToPILImage()(clip[i, ...]) 248 | imgs.append(img) 249 | 250 | imgs[0].save(str(Path(file_name).absolute()), save_all = True, append_images = imgs[1:], loop = 0) 251 | 252 | def append_frames(batch, max_clip_length): 253 | d = max_clip_length - batch.shape[1] 254 | batch = torch.cat([batch, batch[:, -2:-1, :, :, :].repeat(1, d, 1, 1, 1)], dim = 1) 255 | return batch 256 | max_length = max(gt_future_frames_batch.shape[1], gt_past_frames_batch.shape[1]) 257 | max_length = max(max_length, pred_frames_batch.shape[1]) 258 | if gt_past_frames_batch.shape[1] < max_length: 259 | gt_past_frames_batch = append_frames(gt_past_frames_batch, max_length) 260 | if gt_future_frames_batch.shape[1] < max_length: 261 | gt_future_frames_batch = append_frames(gt_future_frames_batch, max_length) 262 | if pred_frames_batch.shape[1] < max_length: 263 | pred_frames_batch = append_frames(pred_frames_batch, max_length) 264 | 265 | batch = torch.cat([gt_past_frames_batch, gt_future_frames_batch, pred_frames_batch], dim = -1) #shape (N, clip_length, C, H, 3W) 266 | batch = batch.cpu() 267 | N = batch.shape[0] 268 | for n in range(N): 269 | clip = 
batch[n, ...] 270 | file_name = file_dir.joinpath(f'{desc}_clip_{n}.gif') 271 | save_clip(clip, file_name) 272 | 273 | def read_code_files(): 274 | """ 275 | Read all the files under VideoFramePrediction into bytes, and return a dictionary 276 | key of the dict is file name (do not include root dir) 277 | value of the dict is bytes of each file 278 | """ 279 | proj_folder = Path(__file__).resolve().parents[1].absolute() 280 | code_files = [] 281 | for file in proj_folder.rglob('*'): 282 | file_str = str(file) 283 | if '.git' not in file_str and '__pycache__' not in file_str and '.ipynb_checkpoints' not in file_str: 284 | code_files.append(file) 285 | 286 | code_file_dict = {} 287 | for file_name in code_files: 288 | try: 289 | with open(file_name, 'rb') as f: 290 | str_name = str(file_name).strip().split('VideoFramePrediction') 291 | str_name = 'VideoFramePrediction' + str_name[-1] 292 | code_file_dict[str_name] = f.read() 293 | except IsADirectoryError: 294 | pass 295 | 296 | return code_file_dict 297 | 298 | def write_code_files(code_file_dict, parent_dir): 299 | """ 300 | Write the saved code file dictionary to disk 301 | parent_dir: directory to place all the saved code files 302 | """ 303 | for k, v in code_file_dict.items(): 304 | file_path = Path(parent_dir).joinpath(k) 305 | if not file_path.exists(): 306 | file_path.parent.mkdir(parents = True, exist_ok=True) 307 | with open(file_path, 'ab') as f: 308 | f.write(v) 309 | 310 | class BatchAverageMeter(object): 311 | """Computes and stores the average and current value 312 | https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L363 313 | """ 314 | def __init__(self, name, fmt=':f'): 315 | self.name = name 316 | self.fmt = fmt 317 | self.reset() 318 | 319 | def reset(self): 320 | self.val = 0 321 | self.avg = 0 322 | self.sum = 0 323 | self.count = 0 324 | 325 | def update(self, val, n=1): 326 | self.val = val 327 | self.sum += val * n 328 | self.count += n 329 | self.avg = self.sum / self.count 330 | 331 | def __str__(self): 332 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 333 | return fmtstr.format(**self.__dict__) 334 | 335 | def parameters_count(model): 336 | """ 337 | for name, param in model.named_parameters(): 338 | print(name, param.size()) 339 | """ 340 | count = sum(p.numel() for p in model.parameters() if p.requires_grad) 341 | print(f"Number of trainable parameters are {float(count)/1e6} Million") 342 | return count 343 | 344 | def show_AE_samples(VPTR_Enc, VPTR_Dec, sample, renorm_transform, device, save_dir): 345 | VPTR_Enc = VPTR_Enc.eval() 346 | VPTR_Dec = VPTR_Dec.eval() 347 | with torch.no_grad(): 348 | past_frames, future_frames = sample 349 | past_frames = past_frames.to(device) 350 | future_frames = future_frames.to(device) 351 | 352 | past_gt_feats = VPTR_Enc(past_frames) 353 | future_gt_feats = VPTR_Enc(future_frames) 354 | 355 | rec_past_frames = VPTR_Dec(past_gt_feats) 356 | rec_future_frames = VPTR_Dec(future_gt_feats) 357 | 358 | gt_frames = torch.cat([past_frames, future_frames], dim = 1) 359 | rec_frames = torch.cat([rec_past_frames, rec_future_frames], dim = 1) 360 | 361 | N = future_frames.shape[0] 362 | idx = min(N, 4) 363 | visualize_batch_clips(gt_frames[0:idx, :, ...], rec_frames[0:idx, :, ...], rec_frames[0:idx, :, ...], save_dir, renorm_transform, desc = 'ae') 364 | 365 | def show_predictor_samples(VPTR_Enc, VPTR_Dec, predictor, sample, save_dir, device, renorm_transform): 366 | predictor = predictor.eval() 367 | with 
torch.no_grad(): 368 | past_frames, future_frames = sample 369 | past_frames = past_frames.to(device) 370 | future_frames = future_frames.to(device) 371 | 372 | past_gt_feats = VPTR_Enc(past_frames) 373 | future_gt_feats = VPTR_Enc(future_frames) 374 | 375 | rec_past_frames = VPTR_Dec(past_gt_feats) 376 | rec_future_frames = VPTR_Dec(future_gt_feats) 377 | try: 378 | stochastic_flag = predictor.stochastic 379 | except AttributeError as err: 380 | stochastic_flag = predictor.module.stochastic 381 | if stochastic_flag: 382 | pred_future_feats, _, _, _, _ = predictor(past_gt_feats, future_gt_feats) 383 | else: 384 | pred_future_feats = predictor(past_gt_feats, future_gt_feats) 385 | 386 | pred_future_frames = VPTR_Dec(pred_future_feats) 387 | 388 | N = pred_future_frames.shape[0] 389 | idx = min(N, 4) 390 | visualize_batch_clips(past_frames[0:idx, :, ...], future_frames[0:idx, :, ...], pred_future_frames[0:idx, :, ...], save_dir, renorm_transform, desc = 'pred') 391 | visualize_batch_clips(past_frames[0:idx, :, ...], rec_future_frames[0:idx, :, ...], rec_past_frames[0:idx, :, ...], save_dir, renorm_transform, desc = 'ae') 392 | 393 | 394 | if __name__ == '__main__': 395 | #code_file_dict = read_code_files() 396 | modules_state_dict, optimizers_state_dict, start_epoch, history_loss_dict, code_file_dict = load_ckpt('/home/travail/xiyex/VPTR_ckpts/MNIST_GPT_MSEGDL_ckpt/epoch_193.tar') 397 | write_code_files(code_file_dict, '/home/travail/xiyex/Copy1VideoFramePrediction') 398 | -------------------------------------------------------------------------------- /models/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | import random 7 | 8 | 9 | class Factorized3DConvAttn(nn.Module): 10 | def __init__(self, 11 | in_channels, 12 | atten_channels_downsample_ratio = 8, #ratio for projecting input channels to key/query feature channels 13 | value_channels_downsample_ratio = 2, #ratio for projecting input channels to value feature channels 14 | use_bias = True, #use bias for the key/query projection 15 | learn_gamma = True, #Learning the gamma of skip connection 16 | norm_layer_2d = nn.BatchNorm2d, 17 | norm_layer_1d = nn.BatchNorm1d, 18 | activ_func = nn.ReLU(), 19 | conv_first = True, #True for conv2d-attn2d-conv1d-attn1d, False for attn2d-conv2d-attn1d-conv1d 20 | learn_3d = True 21 | ): 22 | super().__init__() 23 | self.in_channels = in_channels 24 | 25 | self.spatial_conv = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=use_bias), 26 | norm_layer_2d(in_channels), 27 | activ_func) 28 | self.attn2d = NonLocalAttenion2D(in_channels, atten_channels_downsample_ratio, value_channels_downsample_ratio, True, learn_gamma, 29 | norm_layer_2d(in_channels), activ_func=activ_func) 30 | self.learn_3d = learn_3d 31 | self.temporal_conv = None 32 | self.attn1d = None 33 | if self.learn_3d: 34 | self.temporal_conv = nn.Sequential(nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, padding='same', bias=use_bias), 35 | norm_layer_1d(in_channels), 36 | activ_func) 37 | self.attn1d = NonLocalAttenion1D(in_channels, atten_channels_downsample_ratio, value_channels_downsample_ratio, True, learn_gamma, 38 | norm_layer_1d(in_channels), activ_func=activ_func) 39 | 40 | self.conv_first = conv_first 41 | self.forward_func = self.conv_forward 42 | if not self.conv_first: 43 | self.forward_func = self.attn_forward 
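    # Factorized 3D attention: each frame is first processed over its (H, W) grid by a 3x3 Conv2d
    # plus 2D non-local attention; when learn_3d is True, the features are then reshaped to
    # (N*H*W, C, T) for a Conv1d plus 1D non-local attention along the temporal axis.
    # conv_first only decides whether the convolution or the attention runs first within each stage.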
44 | 45 | def forward(self, x, T): 46 | return self.forward_func(x, T) 47 | 48 | def conv_forward(self, x, T): 49 | """ 50 | x: (N*T, C, H, W) 51 | N: batch_size 52 | T: video clip length 53 | out: (N*T, C, H, W) 54 | """ 55 | NT, C, H, W = x.size() 56 | N = NT//T 57 | 58 | skip = x 59 | x = self.spatial_conv(x) + x #(N*T, C, H, W) 60 | x = self.attn2d(x) 61 | 62 | if self.learn_3d: 63 | x = x.reshape(N, T, C, H, W).permute(0, 3, 4, 2, 1).flatten(0, 2) #(N*H*W, C, T) 64 | x = self.temporal_conv(x) + x 65 | x = self.attn1d(x) 66 | x = x.reshape(N, H, W, C, T).permute(0, 4, 3, 1, 2).flatten(0, 1) 67 | 68 | x = x + skip 69 | 70 | return x 71 | 72 | def attn_forward(self, x, T): 73 | """ 74 | x: (N*T, C, H, W) 75 | N: batch_size 76 | T: video clip length 77 | out: (N*T, C, H, W) 78 | """ 79 | NT, C, H, W = x.size() 80 | N = NT//T 81 | 82 | skip = x 83 | x = self.attn2d(x) 84 | x = self.spatial_conv(x) + x #(N*T, C, H, W) 85 | 86 | if self.learn_3d: 87 | x = x.reshape(N, T, C, H, W).permute(0, 3, 4, 2, 1).flatten(0, 2) #(N*H*W, C, T) 88 | x = self.attn1d(x) 89 | x = self.temporal_conv(x) + x 90 | 91 | x = x.reshape(N, H, W, C, T).permute(0, 4, 3, 1, 2).flatten(0, 1) 92 | 93 | x = x + skip 94 | 95 | return x 96 | 97 | 98 | class NonLocalAttenion2D(nn.Module): 99 | """ 100 | Based on https://github.com/brain-research/self-attention-gan/blob/master/non_local.py 101 | """ 102 | def __init__(self, 103 | in_channels, #Number of input feature channels 104 | atten_channels_downsample_ratio = 8, #ratio for projecting input channels to key/query feature channels 105 | value_channels_downsample_ratio = 2, #ratio for projecting input channels to value feature channels 106 | bias = True, #use bias for the key/query projection 107 | learn_gamma = True, #Learning the gamma of skip connection 108 | norm_func = None, 109 | activ_func = None 110 | ): 111 | super().__init__() 112 | self.bias = bias 113 | self.in_channels = in_channels 114 | self.attn_dim = in_channels//atten_channels_downsample_ratio 115 | self.value_dim = in_channels//value_channels_downsample_ratio 116 | 117 | self.Wq = nn.Linear(in_channels, self.attn_dim, bias = bias) 118 | self.Wk = nn.Linear(in_channels, self.attn_dim, bias = bias) 119 | self.Wv = nn.Linear(in_channels, self.value_dim, bias = bias) 120 | self.out_proj = nn.Linear(self.value_dim, in_channels, bias=bias) 121 | 122 | self.max_pool = nn.MaxPool2d((2, 2), stride = 2) 123 | self.learn_gamma = learn_gamma 124 | if learn_gamma: 125 | self.gamma = nn.Parameter(torch.tensor(0., dtype=torch.float32)) 126 | else: 127 | self.gamma = 1.0 128 | 129 | self.norm_func = nn.Identity() 130 | if norm_func is not None: 131 | self.norm_func = norm_func 132 | self.activ_func = nn.Identity() 133 | if activ_func is not None: 134 | self.activ_func = activ_func 135 | 136 | self._reset_parameters() 137 | 138 | def forward(self, x): 139 | """ 140 | x: (N, C, H, W) 141 | out: (N, C, H, W) 142 | """ 143 | N, C, H, W = x.size() 144 | skip = x 145 | 146 | x = x.flatten(2, 3).permute(0, 2, 1) #(N, H*W, C) 147 | 148 | query = self.Wq(x) #(N, H*W, self.attn_dim) 149 | key = self.Wk(x).reshape(N, H, W, self.attn_dim).permute(0, 3, 1, 2) #(N, self.attn_dim, H, W) 150 | #Downsample key length to be H/2*W/2 151 | key = self.max_pool(key).flatten(2, 3) #(N, self.attn_dim, H*W/4) 152 | 153 | attn_score = torch.matmul(query, key) #(N, H*W, H*W/4) 154 | attn_score = F.softmax(attn_score, dim = -1) 155 | 156 | value = self.Wv(x).reshape(N, H, W, self.value_dim).permute(0, 3, 1, 2) #(N, self.value_dim, H, W) 157 | 
#Downsample value length to be H/2*W/2 158 | value = self.max_pool(value).flatten(2, 3).permute(0, 2, 1) #(N, H*W/4, self.value_dim) 159 | 160 | attn_out = torch.matmul(attn_score, value) #attention operation, (N, H*W, self.value_dim) 161 | 162 | out = self.out_proj(attn_out) #output projection, (N, H*W, self.in_channels) 163 | out = out.reshape(N, H, W, self.in_channels).permute(0, 3, 1, 2) 164 | 165 | out = self.activ_func(self.norm_func(out)) 166 | out = skip + self.gamma * out #Skip connection 167 | 168 | return out 169 | 170 | def _reset_parameters(self): 171 | if self.bias: 172 | nn.init.constant_(self.Wq.bias, 0.) 173 | nn.init.constant_(self.Wk.bias, 0.) 174 | nn.init.constant_(self.Wv.bias, 0.) 175 | nn.init.constant_(self.out_proj.bias, 0.) 176 | 177 | nn.init.xavier_uniform_(self.Wq.weight) 178 | nn.init.xavier_uniform_(self.Wk.weight) 179 | nn.init.xavier_uniform_(self.Wv.weight) 180 | nn.init.xavier_uniform_(self.out_proj.weight) 181 | 182 | class NonLocalAttenion1D(nn.Module): 183 | """ 184 | Based on https://github.com/brain-research/self-attention-gan/blob/master/non_local.py 185 | """ 186 | def __init__(self, 187 | in_channels, #Number of input feature channels 188 | atten_channels_downsample_ratio = 8, #ratio for projecting input channels to key/query feature channels 189 | value_channels_downsample_ratio = 2, #ratio for projecting input channels to value feature channels 190 | bias = True, #use bias for the key/query projection 191 | learn_gamma = True, 192 | norm_func = None, 193 | activ_func = None 194 | ): 195 | super().__init__() 196 | self.bias = bias 197 | self.in_channels = in_channels 198 | self.attn_dim = in_channels//atten_channels_downsample_ratio 199 | self.value_dim = in_channels//value_channels_downsample_ratio 200 | 201 | self.Wq = nn.Linear(in_channels, self.attn_dim, bias = bias) 202 | self.Wk = nn.Linear(in_channels, self.attn_dim, bias = bias) 203 | self.Wv = nn.Linear(in_channels, self.value_dim, bias = bias) 204 | self.out_proj = nn.Linear(self.value_dim, in_channels, bias=bias) 205 | 206 | self.learn_gamma = learn_gamma 207 | if learn_gamma: 208 | self.gamma = nn.Parameter(torch.tensor(0., dtype=torch.float32)) 209 | else: 210 | self.gamma = 1.0 211 | 212 | self.norm_func = nn.Identity() 213 | if norm_func is not None: 214 | self.norm_func = norm_func 215 | self.activ_func = nn.Identity() 216 | if activ_func is not None: 217 | self.activ_func = activ_func 218 | 219 | self._reset_parameters() 220 | 221 | def forward(self, x): 222 | """ 223 | x: (N, C, T) 224 | out: (N, C, T) 225 | """ 226 | x = x.permute(0, 2, 1) 227 | N, T, C = x.size() 228 | 229 | query = self.Wq(x) #(N, T, self.attn_dim) 230 | key = self.Wk(x) #(N, T, self.attn_dim) 231 | 232 | attn_score = torch.matmul(query, key.permute(0, 2, 1)) #(N, T, T) 233 | attn_score = F.softmax(attn_score, dim = -1) 234 | 235 | value = self.Wv(x) #(N, T, self.value_dim) 236 | 237 | attn_out = torch.matmul(attn_score, value) #attention operation, (N, T, self.value_dim) 238 | out = self.out_proj(attn_out) #output projection, (N, T, self.in_channels) 239 | out = out.permute(0, 2, 1) 240 | out = self.activ_func(self.norm_func(out)) 241 | out = x.permute(0, 2, 1) + self.gamma * out #Skip connection 242 | 243 | return out 244 | 245 | def _reset_parameters(self): 246 | if self.bias: 247 | nn.init.constant_(self.Wq.bias, 0.) 248 | nn.init.constant_(self.Wk.bias, 0.) 249 | nn.init.constant_(self.Wv.bias, 0.) 250 | nn.init.constant_(self.out_proj.bias, 0.) 
251 | 252 | nn.init.xavier_uniform_(self.Wq.weight) 253 | nn.init.xavier_uniform_(self.Wk.weight) 254 | nn.init.xavier_uniform_(self.Wv.weight) 255 | nn.init.xavier_uniform_(self.out_proj.weight) 256 | 257 | 258 | class NRMLP(nn.Module): 259 | def __init__(self, out_channels, dim_x = 3, d_model = 256, MLP_layers = 4, scale = 10, fix_B = False, fuse_method = 'SPADE'): 260 | """ 261 | Modified based on https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb 262 | The output layer is moved to the "PosFeatFuser" 263 | """ 264 | super().__init__() 265 | self.scale = scale 266 | self.dim_x = dim_x 267 | self.out_channels = out_channels 268 | self.MLP_layers = MLP_layers 269 | self.d_model = d_model 270 | self.fix_B = fix_B 271 | 272 | self.MLP = [] 273 | 274 | self.mapping_fn = self.gaussian_mapping 275 | self.MLP.append(nn.Linear(2*self.d_model, self.d_model)) 276 | if self.fix_B: 277 | self.register_buffer('B', torch.normal(mean = 0, std = 1.0, size = (self.d_model, self.dim_x)) * self.scale) 278 | else: 279 | #Default init for Linear is uniform distribution, would not produce a result as good as gaussian initialization 280 | #self.B = nn.Linear(self.dim_x, self.d_model, bias = False) 281 | self.B = nn.Parameter(torch.normal(mean = 0, std = 1.0, size = (self.d_model, self.dim_x)) * self.scale, requires_grad = True) 282 | 283 | #Init B with normal distribution or constant would produce much different result. 284 | #self.B = nn.Parameter(torch.ones(self.d_model, self.dim_x), requires_grad = True) 285 | 286 | self.MLP.append(nn.ReLU()) 287 | for i in range(self.MLP_layers - 2): 288 | self.MLP.append(nn.Linear(self.d_model, self.d_model)) 289 | self.MLP.append(nn.ReLU()) 290 | 291 | self.MLP = nn.Sequential(*self.MLP) 292 | 293 | self.fuse_method = fuse_method 294 | 295 | self.mlp_beta = nn.Linear(self.d_model, out_channels) 296 | if self.fuse_method == 'SPADE': 297 | self.mlp_gamma = nn.Linear(self.d_model, out_channels) 298 | 299 | def forward(self, x): 300 | """ 301 | Args: 302 | x: (N, d), N denotes the number of elements (coordinates) 303 | Return: 304 | out: (N, out_channels) 305 | """ 306 | x = self.mapping_fn(x) 307 | x = self.MLP(x) 308 | beta = self.mlp_beta(x) 309 | if self.fuse_method == 'SPADE': 310 | gamma = self.mlp_gamma(x) 311 | else: 312 | gamma = torch.zeros_like(beta) 313 | 314 | return beta, gamma 315 | 316 | 317 | def gaussian_mapping(self, x): 318 | """ 319 | Args: 320 | x: (N, d), N denotes the number of elements (coordinates) 321 | B: (m, d) 322 | """ 323 | proj = (2. * float(math.pi) * x) @ self.B.T 324 | #proj = self.B(2. 
* float(math.pi) * x) 325 | out = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1) 326 | 327 | return out 328 | 329 | class CoorGenerator(nn.Module): 330 | def __init__(self, max_H, max_W, max_T): 331 | """ 332 | Normalize the coordinated to be [0,1] 333 | """ 334 | super().__init__() 335 | self.max_H = max_H 336 | self.max_W = max_W 337 | self.max_T = max_T 338 | 339 | def forward(self, t_list, h_list, w_list): 340 | """ 341 | The h/w/t index starts with 0 342 | Args: 343 | h_list: list of h coordinates, Tensor with shape (H,) 344 | w_list: list of w coordinates, Tensor with shape (W,) 345 | t_list: list of t coordinates, Tensor with shape (T,) 346 | Returns: 347 | coor: Tensor with shape (T*H*W, 3), for the last dim, the coordinate order is (t, h, w) 348 | """ 349 | assert torch.max(h_list) <= self.max_H and torch.min(h_list) >= 0., "Invalid H coordinates" 350 | assert torch.max(w_list) <= self.max_W and torch.min(w_list) >= 0., "Invalid W coordinates" 351 | assert torch.max(t_list) <= self.max_T and torch.min(t_list) >= 0., "Invalid T coordinates" 352 | 353 | norm_h_list = h_list/self.max_H 354 | norm_w_list = w_list/self.max_W 355 | norm_t_list = t_list/self.max_T 356 | 357 | hvv, wvv = torch.meshgrid(norm_h_list, norm_w_list) 358 | s_coor = torch.stack([hvv, wvv], dim=-1) 359 | t_coor = torch.ones_like(hvv)[None, :, :] * norm_t_list[:, None, None] 360 | 361 | s_coor = s_coor.unsqueeze(0).repeat(norm_t_list.shape[0], 1, 1, 1) 362 | coor = torch.cat([t_coor.unsqueeze(-1), s_coor], dim = -1) 363 | 364 | coor = coor.flatten(0, 2) 365 | 366 | return coor 367 | 368 | class EventEncoder(nn.Module): 369 | def __init__(self, in_channels, hidden_channels, n_layers, stochastic): 370 | super().__init__() 371 | self.stochastic = stochastic 372 | self.n_layers = n_layers 373 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels), 374 | nn.BatchNorm2d(in_channels), 375 | nn.ReLU(True)) 376 | self.conv2 = nn.Sequential(nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1, bias=False), 377 | nn.BatchNorm2d(hidden_channels), 378 | nn.ReLU(True)) 379 | 380 | for i in range(n_layers): 381 | setattr(self, f'MLP_{i}', nn.Sequential(nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1, stride=1, bias=False), 382 | nn.BatchNorm2d(hidden_channels), 383 | nn.ReLU(True))) 384 | self.mu_net = nn.Conv2d(hidden_channels, in_channels, kernel_size=1, stride=1, bias=True) 385 | if self.stochastic: 386 | self.logvar_net = nn.Conv2d(hidden_channels, in_channels, kernel_size=1, stride=1, bias=True) 387 | 388 | def forward(self, x): 389 | """ 390 | x: (N, C, H, W) #the event coding 391 | Return: 392 | mu: (N, C, H, W) 393 | logvar: (N, C, H, W) 394 | z: (N, C, H, W) 395 | """ 396 | x = self.conv2(self.conv1(x)) 397 | for i in range(self.n_layers): 398 | x = getattr(self, f'MLP_{i}')(x) 399 | 400 | mu = self.mu_net(x) 401 | 402 | if self.stochastic: 403 | logvar = self.logvar_net(x) 404 | return self.reparameterize(mu, logvar), mu, logvar 405 | else: 406 | return mu 407 | 408 | def reparameterize(self, mu, logvar): 409 | eps = torch.randn(mu.shape, device = mu.device) 410 | return mu + torch.exp(0.5*logvar) * eps 411 | 412 | class PosFeatFuser(nn.Module): 413 | def __init__(self, x_channels, param_free_norm_type = 'layer'): 414 | """ 415 | Modified from https://github.com/NVlabs/SPADE/blob/master/models/networks/normalization.py 416 | There is no learned parameters in this module 417 | """ 418 | super().__init__() 
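        # forward() applies SPADE-style modulation: out = param_free_norm(x) * (1 + pos_gamma) + pos_beta,
        # where pos_beta and pos_gamma are the per-coordinate outputs of NRMLP; this constructor only
        # selects which parameter-free normalization layer to use.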
419 | 420 | if param_free_norm_type == 'instance': 421 | self.param_free_norm = nn.InstanceNorm2d(x_channels, affine=False) 422 | elif param_free_norm_type == 'syncbatch': 423 | self.param_free_norm = nn.SyncBatchNorm(x_channels, affine=False) 424 | elif param_free_norm_type == 'batch': 425 | self.param_free_norm = nn.BatchNorm2d(x_channels, affine=False) 426 | elif param_free_norm_type == 'layer': 427 | self.param_free_norm = nn.GroupNorm(1, x_channels, affine=False) #equivalent to layernorm 428 | else: 429 | raise ValueError('%s is not a recognized param-free norm type in SPADE' 430 | % param_free_norm_type) 431 | 432 | def forward(self, x, pos_beta, pos_gamma): 433 | """ 434 | Args: 435 | x: (N, T, H, W, C) 436 | pos_gamma: (T*H*W, C) 437 | pos_beta: (T*H*W, C) 438 | 439 | Return: 440 | out: (N, T, H, W, C) 441 | """ 442 | 443 | # Part 1. generate parameter-free normalized activations 444 | x = x.permute(0, 1, 4, 2, 3) #(N, T, C, H, W) 445 | N, T, C, H, W = x.shape 446 | normalized = self.param_free_norm(x.flatten(0, 1)).reshape(N, T, C, H, W) 447 | 448 | # apply scale and bias 449 | pos_gamma = pos_gamma.reshape(T, H, W, C).permute(0, 3, 1, 2) 450 | pos_beta = pos_beta.reshape(T, H, W, C).permute(0, 3, 1, 2) 451 | 452 | out = normalized * (1 + pos_gamma) + pos_beta 453 | 454 | return out.permute(0, 1, 3, 4, 2) 455 | 456 | class FutureFrameQueryGenerator(nn.Module): 457 | def __init__(self, T): 458 | """ 459 | T: number of queries to generate 460 | """ 461 | super().__init__() 462 | self.T = T 463 | 464 | def forward(self, evt, pos_beta, pos_gamma, pos_fuser): 465 | """ 466 | Args: 467 | evt: (N, C, H, W) 468 | pos_gamma: (T*H*W, C) 469 | pos_beta: (T*H*W, C) 470 | pos_fuser: instance of PosFeatFuser 471 | Return: 472 | out: (N, T, C, H, W) 473 | """ 474 | out = evt.unsqueeze(1).repeat(1, self.T, 1, 1, 1) 475 | out = pos_fuser(out, pos_beta, pos_gamma) 476 | 477 | return out 478 | 479 | if __name__ == '__main__': 480 | """ 481 | coor_generator = CoorGenerator(8, 8, 10) 482 | h_list, w_list = torch.linspace(0, 7, 8), torch.linspace(0, 7, 8) 483 | t_list = torch.linspace(0, 9, 10) 484 | coor = coor_generator(t_list, h_list, w_list) 485 | 486 | nrmlp = NRMLP(out_channels = 512, fuse_method='SPADE').to('cuda:0') 487 | nrmlp_beta, nrmlp_gamma = nrmlp(coor.to('cuda:0')) 488 | 489 | x = torch.randn(64, 10, 512, 8, 8).to('cuda:0') 490 | fuser = PosFeatFuser(x_channels=512, param_free_norm_type = 'instance').to('cuda:0') 491 | out = fuser(x, nrmlp_beta, nrmlp_gamma) 492 | print(out.shape) 493 | 494 | ent = EventEncoder(512, 256, 1).to('cuda:0') 495 | z, mu, logvar = ent(x) 496 | 497 | query_generator = FutureFrameQueryGenerator(T = 10) 498 | future_queries = query_generator(z, nrmlp_beta, nrmlp_gamma, fuser) 499 | print(future_queries.shape) 500 | """ --------------------------------------------------------------------------------