├── .gitignore ├── README.md ├── assets ├── demo.gif └── logo.png ├── config ├── AE │ ├── bair.yaml │ ├── cityscapes.yaml │ ├── kth.yaml │ ├── smmnist.yaml │ └── ucf.yaml └── DM │ ├── bair.yaml │ ├── cityscapes.yaml │ ├── kth.yaml │ ├── smmnist.yaml │ └── ucf.yaml ├── data ├── BAIR │ ├── 01_bair_download.sh │ └── bair_convert.py ├── KTH │ ├── 01_kth_download.sh │ ├── 02_kth_train_val_test_split.py │ ├── 03_kth_convert.py │ ├── kth_actions_frames.py │ ├── sequences.txt │ ├── train.txt │ └── valid.txt ├── SMMNIST │ ├── 01_mnist_download_and_convert.py │ ├── stochastic_moving_mnist.py │ └── stochastic_moving_mnist_edited.py ├── augmentation.py ├── base.py ├── h5.py ├── two_frames_dataset.py └── video_dataset.py ├── extdm.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt ├── metrics ├── calculate_fvd.py ├── calculate_lpips.py ├── calculate_psnr.py ├── calculate_ssim.py ├── demo.py ├── fvd.py ├── i3d_torchscript.pt └── pytorch_i3d.py ├── model ├── BaseDM_adaptor │ ├── DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada.py │ ├── DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada_u22.py │ ├── DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_u12.py │ ├── DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_u22.py │ ├── DenoiseNet_STWAtt_w_wo_ref_adaptor_cross_multi.py │ ├── Diffusion.py │ ├── VideoFlowDiffusion_multi.py │ ├── VideoFlowDiffusion_multi1248.py │ ├── VideoFlowDiffusion_multi_w_ref.py │ ├── VideoFlowDiffusion_multi_w_ref_u22.py │ └── text.py └── LFAE │ ├── __init__.py │ ├── bg_motion_predictor.py │ ├── flow_autoenc.py │ ├── generator.py │ ├── model.py │ ├── pixelwise_flow_predictor.py │ ├── region_predictor.py │ ├── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py │ └── util.py ├── scripts ├── AE │ ├── run.py │ ├── train.py │ ├── train_AE_bair.sh │ ├── train_AE_cityscapes.sh │ ├── train_AE_kth.sh │ ├── train_AE_smmnist.sh │ ├── train_AE_ucf.sh │ ├── valid.py │ ├── valid_AE_bair.sh │ ├── valid_AE_cityscapes.sh │ ├── valid_AE_kth.sh │ ├── valid_AE_smmnist.sh │ └── valid_AE_ucf.sh └── DM │ ├── run.py │ ├── train.py │ ├── train_DM_bair.sh │ ├── train_DM_cityscapes.sh │ ├── train_DM_kth.sh │ ├── train_DM_smmnist.sh │ ├── train_DM_ucf.sh │ ├── valid.py │ ├── valid_DM_bair.sh │ ├── valid_DM_cityscapes.sh │ ├── valid_DM_kth.sh │ ├── valid_DM_smmnist.sh │ └── valid_with_generate_flow_and_conf.py ├── setup.py ├── utils ├── logger.py ├── lr_scheduler.py ├── meter.py ├── misc.py ├── seed.py └── visualize.py └── vis ├── save_new.py ├── save_visualization_and_metrics_one_by_one.py ├── save_visualization_and_metrics_one_by_one_LFDM.py ├── test_flowae_run_groundtruth.py ├── test_flowae_run_groundtruth_flow_conf.py ├── test_flowae_run_our_result.py ├── test_flowae_run_video2video.py └── vis copy.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs_training/AE/* 2 | logs_training/DM/* 3 | logs_validation/AE/* 4 | logs_validation/DM/* 5 | logs_validation/pretrained_DM/* 6 | **/__pycache__/** 7 | wandb 8 | slurm-*.out 9 | flow_output 10 | .__dpc* 11 | video2video/* 12 | metrics_carla/ 13 | output_tensor/* 14 | output_model/* -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nku-zhichengzhang/ExtDM/daaae01e926b9b021c81676cbd4a17555b722aa9/assets/demo.gif 
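The config/AE and config/DM YAML files that follow all share the same top-level schema: dataset_params (h5 root directory, frame_shape, augmentation, and cond_frames/pred_frames per split), flow_params (settings for the flow autoencoder in model/LFAE plus its own train_params), and, in the DM configs only, an extra diffusion_params block; visualizer_params only controls plotting. A minimal sketch, assuming PyYAML is installed and the repository root is the working directory, of loading one of these files and reading a few of those fields (the print statements are purely illustrative, not part of the repo):

import yaml  # assumes PyYAML

with open("config/AE/kth.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["experiment_name"])                             # kth64
print(cfg["dataset_params"]["frame_shape"])               # 64
print(cfg["dataset_params"]["train_params"])              # cond_frames / pred_frames used for training
print(cfg["flow_params"]["model_params"]["num_regions"])  # 20
print(cfg["flow_params"]["train_params"]["batch_size"])   # 256

The DM configs below extend this same layout with diffusion_params.model_params and diffusion_params.train_params, while keeping dataset_params and flow_params identical in structure.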
-------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nku-zhichengzhang/ExtDM/daaae01e926b9b021c81676cbd4a17555b722aa9/assets/logo.png -------------------------------------------------------------------------------- /config/AE/bair.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 7 | 8 | experiment_name: bair64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/bair_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: False 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 10 23 | max_frame_distance: 30 24 | train_params: 25 | type: train 26 | cond_frames: 2 27 | pred_frames: 2 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 2 32 | pred_frames: 28 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 10 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.5 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.5 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 100 70 | num_repeats: 100 71 | scheduler_param: 72 | milestones: [50000] 73 | gamma: 0.5 74 | lr: 2.0e-4 75 | batch_size: 100 76 | valid_batch_size: 256 77 | dataloader_workers: 6 78 | print_freq: 100 79 | save_img_freq: 1000 80 | update_ckpt_freq: 1000 81 | save_ckpt_freq: 1000 82 | scales: [1, 0.5, 0.25] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | visualizer_params: 93 | kp_size: 2 94 | draw_border: True 95 | colormap: 'gist_rainbow' 96 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/AE/cityscapes.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. 
for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 7 | 8 | experiment_name: cityscapes128 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/cityscapes_h5 12 | frame_shape: 128 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: True 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 30 24 | train_params: 25 | type: train 26 | cond_frames: 2 27 | pred_frames: 5 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 2 32 | pred_frames: 28 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 20 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'perspective' # 'zero', 'shift', 'affine', 'perspective' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.25 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.25 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 432 70 | num_repeats: 20 71 | scheduler_param: 72 | milestones: [20000, 40000] 73 | gamma: 0.5 74 | lr: 2.0e-4 75 | batch_size: 128 76 | valid_batch_size: 256 77 | dataloader_workers: 8 78 | print_freq: 500 79 | save_img_freq: 1000 80 | update_ckpt_freq: 1000 81 | save_ckpt_freq: 1000 82 | scales: [1, 0.5, 0.25, 0.125] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | visualizer_params: 93 | kp_size: 2 94 | draw_border: True 95 | colormap: 'gist_rainbow' 96 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/AE/kth.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 
4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 7 | 8 | experiment_name: kth64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/kth_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: True 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 20 24 | train_params: 25 | type: train 26 | cond_frames: 10 27 | pred_frames: 5 28 | valid_params: 29 | total_videos: 256 30 | type: valid 31 | cond_frames: 10 32 | pred_frames: 40 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 20 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.5 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.5 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 1072 70 | num_repeats: 25 71 | scheduler_param: 72 | milestones: [50000] 73 | gamma: 0.5 74 | lr: 2.0e-4 75 | batch_size: 256 76 | valid_batch_size: 256 77 | dataloader_workers: 8 78 | print_freq: 500 79 | save_img_freq: 1000 80 | update_ckpt_freq: 1000 81 | save_ckpt_freq: 1000 82 | scales: [1, 0.5, 0.25] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | visualizer_params: 93 | kp_size: 2 94 | draw_border: True 95 | colormap: 'gist_rainbow' 96 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/AE/smmnist.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 
7 | 8 | experiment_name: smmnist64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/smmnist_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: False 16 | time_flip: True 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 10 24 | train_params: 25 | type: train 26 | cond_frames: 10 27 | pred_frames: 5 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 10 32 | pred_frames: 10 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 10 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.5 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.5 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 100 70 | num_repeats: 10 71 | scheduler_param: 72 | milestones: [50000] 73 | gamma: 0.5 74 | lr: 2.0e-4 75 | batch_size: 100 76 | valid_batch_size: 256 77 | dataloader_workers: 6 78 | print_freq: 100 79 | save_img_freq: 1000 80 | update_ckpt_freq: 1000 81 | save_ckpt_freq: 1000 82 | scales: [1, 0.5, 0.25] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | visualizer_params: 93 | kp_size: 2 94 | draw_border: True 95 | colormap: 'gist_rainbow' 96 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/AE/ucf.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 
7 | 8 | experiment_name: ucf101_64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/UCF101_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: True 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 20 24 | train_params: 25 | type: train 26 | cond_frames: 4 27 | pred_frames: 8 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 4 32 | pred_frames: 16 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 64 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.5 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.5 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 21 # 21 for 100, 54 for bs256 70 | num_repeats: 100 71 | scheduler_param: 72 | milestones: [100000, 150000] 73 | gamma: 0.8 74 | lr: 2.0e-4 75 | batch_size: 100 # 100, 256 76 | valid_batch_size: 256 77 | dataloader_workers: 8 78 | print_freq: 100 79 | save_img_freq: 1000 80 | update_ckpt_freq: 1000 81 | save_ckpt_freq: 1000 82 | scales: [1, 0.5, 0.25] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | visualizer_params: 93 | kp_size: 2 94 | draw_border: True 95 | colormap: 'gist_rainbow' 96 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/DM/bair.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 
7 | 8 | experiment_name: bair64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/bair_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: False 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 30 24 | train_params: 25 | type: train 26 | cond_frames: 2 27 | pred_frames: 10 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 2 32 | pred_frames: 28 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 10 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.5 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.5 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 100 70 | num_repeats: 100 71 | scheduler_param: 72 | milestones: [80000] 73 | gamma: 0.5 74 | lr: 2.0e-4 75 | batch_size: 100 76 | valid_batch_size: 256 77 | dataloader_workers: 4 78 | print_freq: 100 79 | save_img_freq: 2500 80 | update_ckpt_freq: 2500 81 | save_ckpt_freq: 1000 82 | scales: [1, 0.5, 0.25] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | diffusion_params: 93 | model_params: 94 | null_cond_prob: 0.0 95 | use_residual_flow: False 96 | only_use_flow: False 97 | sampling_timesteps: 10 98 | loss_type: 'l2' 99 | ada_layers: 'auto' 100 | train_params: 101 | max_epochs: 148 102 | num_repeats: 2 103 | scheduler_param: 104 | milestones: [70000, 150000] 105 | gamma: 0.8 106 | lr: 1.5e-4 107 | batch_size: 48 108 | valid_batch_size: 48 109 | dataloader_workers: 32 110 | print_freq: 100 111 | save_img_freq: 1000 112 | save_vid_freq: 1000 113 | update_ckpt_freq: 1000 114 | save_ckpt_freq: 1000 115 | 116 | visualizer_params: 117 | kp_size: 2 118 | draw_border: True 119 | colormap: 'gist_rainbow' 120 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/DM/cityscapes.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 
7 | 8 | experiment_name: cityscapes128 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/cityscapes_h5 12 | frame_shape: 128 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: True 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 30 24 | train_params: 25 | type: train 26 | cond_frames: 2 27 | pred_frames: 5 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 2 32 | pred_frames: 28 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 20 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'perspective' # 'zero', 'shift', 'affine', 'perspective' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.25 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.25 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 432 70 | num_repeats: 20 71 | scheduler_param: 72 | milestones: [20000, 40000] 73 | gamma: 0.5 74 | lr: 2.0e-4 75 | batch_size: 128 76 | valid_batch_size: 256 77 | dataloader_workers: 8 78 | print_freq: 500 79 | save_img_freq: 2500 80 | update_ckpt_freq: 2500 81 | save_ckpt_freq: 1000 82 | scales: [1, 0.5, 0.25, 0.125] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | diffusion_params: 93 | model_params: 94 | null_cond_prob: 0.0 95 | use_residual_flow: False 96 | only_use_flow: False 97 | sampling_timesteps: 10 98 | loss_type: 'l2' 99 | ada_layers: 'auto' 100 | train_params: 101 | max_epochs: 100000 102 | num_repeats: 1 103 | scheduler_param: 104 | milestones: [100000, 150000] 105 | gamma: 0.75 106 | lr: 1.2e-4 107 | batch_size: 32 108 | valid_batch_size: 8 109 | dataloader_workers: 32 110 | print_freq: 100 111 | save_img_freq: 1000 112 | save_vid_freq: 1000 113 | update_ckpt_freq: 1000 114 | save_ckpt_freq: 1000 115 | 116 | visualizer_params: 117 | kp_size: 2 118 | draw_border: True 119 | colormap: 'gist_rainbow' 120 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/DM/kth.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. 
be liable for any damages or losses of any kind arising from the sample code or your use thereof. 7 | 8 | experiment_name: kth64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/kth_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: True 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 50 24 | train_params: 25 | type: train 26 | cond_frames: 10 27 | pred_frames: 20 28 | valid_params: 29 | total_videos: 256 30 | type: valid 31 | cond_frames: 10 32 | pred_frames: 40 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 20 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.50 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.50 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 1070 70 | num_repeats: 100 71 | scheduler_param: 72 | milestones: [150000] 73 | gamma: 0.5 74 | lr: 2.0e-4 75 | batch_size: 256 76 | valid_batch_size: 256 77 | dataloader_workers: 32 78 | print_freq: 500 79 | save_img_freq: 2500 80 | update_ckpt_freq: 2500 81 | save_ckpt_freq: 5000 82 | scales: [1, 0.5, 0.25] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | 92 | diffusion_params: 93 | model_params: 94 | null_cond_prob: 0.0 95 | use_residual_flow: False 96 | only_use_flow: False 97 | sampling_timesteps: 10 98 | loss_type: 'l2' 99 | ada_layers: 'auto' 100 | train_params: 101 | max_epochs: 26000 # 13368 102 | num_repeats: 1 103 | scheduler_param: 104 | milestones: [80000, 150000] 105 | gamma: 0.75 106 | lr: 2.0e-4 107 | batch_size: 36 108 | valid_batch_size: 16 109 | dataloader_workers: 32 110 | print_freq: 100 111 | save_img_freq: 1000 112 | save_vid_freq: 1000 113 | update_ckpt_freq: 1000 114 | save_ckpt_freq: 1000 115 | 116 | visualizer_params: 117 | kp_size: 2 118 | draw_border: True 119 | colormap: 'gist_rainbow' 120 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/DM/smmnist.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 
6 | #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 7 | 8 | experiment_name: smmnist64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/smmnist_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: False 16 | time_flip: True 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 10 24 | train_params: 25 | type: train 26 | cond_frames: 10 27 | pred_frames: 5 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 10 32 | pred_frames: 10 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 10 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.5 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.5 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 100 70 | num_repeats: 10 71 | epoch_milestones: [60, 90] 72 | lr: 2.0e-4 73 | batch_size: 100 74 | valid_batch_size: 256 75 | dataloader_workers: 6 76 | print_freq: 100 77 | save_img_freq: 1000 78 | update_ckpt_freq: 1000 79 | save_ckpt_freq: 5000 80 | scales: [1, 0.5, 0.25] 81 | transform_params: 82 | sigma_affine: 0.05 83 | sigma_tps: 0.005 84 | points_tps: 5 85 | loss_weights: 86 | perceptual: [10, 10, 10, 10, 10] 87 | equivariance_shift: 10 88 | equivariance_affine: 10 89 | diffusion_params: 90 | model_params: 91 | null_cond_prob: 0.0 92 | use_residual_flow: False 93 | only_use_flow: False 94 | sampling_timesteps: 10 95 | loss_type: 'l2' 96 | ada_layers: 'auto' 97 | train_params: 98 | max_epochs: 134 99 | num_repeats: 1 100 | scheduler_param: 101 | milestones: [100000, 150000] 102 | gamma: 0.8 103 | lr: 2.0e-4 104 | batch_size: 40 105 | valid_batch_size: 16 106 | dataloader_workers: 32 107 | print_freq: 100 108 | save_img_freq: 1000 109 | save_vid_freq: 1000 110 | update_ckpt_freq: 1000 111 | save_ckpt_freq: 1000 112 | 113 | visualizer_params: 114 | kp_size: 2 115 | draw_border: True 116 | colormap: 'gist_rainbow' 117 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /config/DM/ucf.yaml: -------------------------------------------------------------------------------- 1 | #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 2 | #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 3 | #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 4 | #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 5 | #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 6 | #In no event will Snap Inc. 
be liable for any damages or losses of any kind arising from the sample code or your use thereof. 7 | 8 | experiment_name: ucf101_64 9 | 10 | dataset_params: 11 | root_dir: /home/ubuntu/zzc/data/video_prediction/dataset_h5/UCF101_h5 12 | frame_shape: 64 13 | augmentation_params: 14 | flip_param: 15 | horizontal_flip: True 16 | time_flip: False 17 | jitter_param: 18 | brightness: 0.1 19 | contrast: 0.1 20 | saturation: 0.1 21 | hue: 0.1 22 | min_frame_distance: 0 23 | max_frame_distance: 20 24 | train_params: 25 | type: train 26 | cond_frames: 4 27 | pred_frames: 8 28 | valid_params: 29 | total_videos: 256 30 | type: test 31 | cond_frames: 4 32 | pred_frames: 16 33 | 34 | flow_params: 35 | model_params: 36 | num_regions: 64 37 | num_channels: 3 38 | estimate_affine: True 39 | revert_axis_swap: True 40 | bg_predictor_params: 41 | block_expansion: 32 42 | max_features: 1024 43 | num_blocks: 5 44 | bg_type: 'affine' 45 | region_predictor_params: 46 | temperature: 0.1 47 | block_expansion: 32 48 | max_features: 1024 49 | scale_factor: 0.5 50 | num_blocks: 5 51 | pca_based: True 52 | pad: 0 53 | fast_svd: False 54 | generator_params: 55 | block_expansion: 64 56 | max_features: 512 57 | num_down_blocks: 2 58 | num_bottleneck_blocks: 6 59 | skips: True 60 | pixelwise_flow_predictor_params: 61 | block_expansion: 64 62 | max_features: 1024 63 | num_blocks: 5 64 | scale_factor: 0.5 65 | use_deformed_source: True 66 | use_covar_heatmap: True 67 | estimate_occlusion_map: True 68 | train_params: 69 | max_epochs: 21 # 21 for 100, 54 for bs256 70 | num_repeats: 100 71 | scheduler_param: 72 | milestones: [100000, 150000] 73 | gamma: 0.8 74 | lr: 2.0e-4 75 | batch_size: 100 # 100, 256 76 | valid_batch_size: 256 77 | dataloader_workers: 8 78 | print_freq: 100 79 | save_img_freq: 5000 80 | update_ckpt_freq: 5000 81 | save_ckpt_freq: 5000 82 | scales: [1, 0.5, 0.25] 83 | transform_params: 84 | sigma_affine: 0.05 85 | sigma_tps: 0.005 86 | points_tps: 5 87 | loss_weights: 88 | perceptual: [10, 10, 10, 10, 10] 89 | equivariance_shift: 10 90 | equivariance_affine: 10 91 | diffusion_params: 92 | model_params: 93 | null_cond_prob: 0.0 94 | use_residual_flow: False 95 | only_use_flow: False 96 | sampling_timesteps: 10 97 | loss_type: 'l2' 98 | ada_layers: 'auto' 99 | train_params: 100 | max_epochs: 21 101 | num_repeats: 32 102 | scheduler_param: 103 | milestones: [100000, 150000] 104 | gamma: 0.8 105 | lr: 2.0e-4 106 | batch_size: 32 107 | valid_batch_size: 64 108 | dataloader_workers: 8 109 | print_freq: 100 110 | save_img_freq: 1000 111 | save_vid_freq: 1000 112 | update_ckpt_freq: 1000 113 | save_ckpt_freq: 1000 114 | 115 | visualizer_params: 116 | kp_size: 2 117 | draw_border: True 118 | colormap: 'gist_rainbow' 119 | region_bg_color: [1, 1, 1] -------------------------------------------------------------------------------- /data/BAIR/01_bair_download.sh: -------------------------------------------------------------------------------- 1 | TARGET_DIR=$1 2 | if [ -z $TARGET_DIR ] 3 | then 4 | echo "Must specify target directory" 5 | else 6 | # mkdir $TARGET_DIR/ 7 | # URL=http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar 8 | # wget $URL -P $TARGET_DIR 9 | tar -xvf $TARGET_DIR/bair_robot_pushing_dataset_v0.tar -C $TARGET_DIR 10 | fi 11 | 12 | # Example: 13 | # cd /home/ubuntu/zzc/code/videoprediction/edm-neurips23/data/BAIR 14 | # bash 01_bair_download.sh /mnt/hdd/zzc/data/video_prediction/BAIR -------------------------------------------------------------------------------- 
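bair_convert.py, which follows, iterates the extracted softmotion30_44k TFRecords and hands each 30-frame clip to HDF5Maker; the KTH and SMMNIST converters further down do the same for decoded video frames, and their KTH_HDF5Maker.add_video_data shows the resulting layout: clip k is stored as a group "k" with one dataset per frame, and its frame count under "len/k". A minimal read-back sketch for that layout, assuming h5py; the shard path used here is a placeholder, since the actual shard naming is handled by data/h5.py, which is not included in this listing:

import h5py  # assumes h5py is installed

# Placeholder path: the real shard name/location is decided by data/h5.py (not shown here).
with h5py.File("kth_h5/train/shard_0.hdf5", "r") as f:
    n_frames = int(f["len"]["0"][()])                       # number of frames stored for clip 0
    clip = [f["0"][str(i)][()] for i in range(n_frames)]    # per-frame arrays (HxW uint8 for KTH/SMMNIST)
    print(n_frames, clip[0].shape, clip[0].dtype)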
/data/BAIR/bair_convert.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/convert_bair.py 2 | import argparse 3 | import glob 4 | import imageio 5 | import io 6 | import numpy as np 7 | import os 8 | import sys 9 | import tensorflow as tf 10 | 11 | from PIL import Image 12 | from tensorflow.python.platform import gfile 13 | from tqdm import tqdm 14 | 15 | from h5 import HDF5Maker 16 | 17 | 18 | def get_seq(data_dir, dname): 19 | data_dir = '%s/softmotion30_44k/%s' % (data_dir, dname) 20 | 21 | filenames = gfile.Glob(os.path.join(data_dir, '*')) 22 | if not filenames: 23 | raise RuntimeError('No data files found.') 24 | 25 | for f in filenames: 26 | k = 0 27 | # tf.enable_eager_execution() 28 | for serialized_example in tf.python_io.tf_record_iterator(f): 29 | example = tf.train.Example() 30 | example.ParseFromString(serialized_example) 31 | image_seq = [] 32 | for i in range(30): 33 | image_name = str(i) + '/image_aux1/encoded' 34 | byte_str = example.features.feature[image_name].bytes_list.value[0] 35 | # image_seq.append(byte_str) 36 | img = Image.frombytes('RGB', (64, 64), byte_str) 37 | arr = np.array(img.getdata()).reshape(img.size[1], img.size[0], 3) 38 | image_seq.append(arr) 39 | # image_seq = np.concatenate(image_seq, axis=0) 40 | k = k + 1 41 | yield f, k, image_seq 42 | 43 | 44 | def make_h5_from_bair(bair_dir, split='train', out_dir='./h5_ds', vids_per_shard=100000, force_h5=False): 45 | 46 | # H5 maker 47 | h5_maker = HDF5Maker(out_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 48 | 49 | seq_generator = get_seq(bair_dir, split) 50 | 51 | filenames = gfile.Glob(os.path.join('%s/softmotion30_44k/%s' % (bair_dir, split), '*')) 52 | for file in tqdm(filenames): 53 | 54 | # num = sum(1 for _ in tf.python_io.tf_record_iterator(file)) 55 | num = 256 56 | for i in tqdm(range(num)): 57 | 58 | try: 59 | f, k, seq = next(seq_generator) 60 | h5_maker.add_data(seq, dtype=None) 61 | # h5_maker.add_data(seq, dtype='uint8') 62 | 63 | except StopIteration: 64 | break 65 | 66 | except (KeyboardInterrupt, SystemExit): 67 | print("Ctrl+C!!") 68 | break 69 | 70 | except: 71 | e = sys.exc_info()[0] 72 | print("ERROR:", e) 73 | 74 | h5_maker.close() 75 | 76 | 77 | if __name__ == "__main__": 78 | 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--out_dir', type=str, help="Directory to save .hdf5 files") 81 | parser.add_argument('--bair_dir', type=str, help="Directory with videos") 82 | parser.add_argument('--vids_per_shard', type=int, default=100000) 83 | parser.add_argument('--force_h5', type=eval, default=False) 84 | 85 | args = parser.parse_args() 86 | 87 | make_h5_from_bair(out_dir=os.path.join(args.out_dir, 'train'), bair_dir=args.bair_dir, split='train', vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 88 | make_h5_from_bair(out_dir=os.path.join(args.out_dir, 'test'), bair_dir=args.bair_dir, split='test', vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 89 | -------------------------------------------------------------------------------- /data/KTH/01_kth_download.sh: -------------------------------------------------------------------------------- 1 | TARGET_DIR=$1 2 | if [ -z $TARGET_DIR ] 3 | then 4 | echo "Must specify target directory" 5 | else 6 | mkdir $TARGET_DIR/raw 7 | for c in walking jogging running handwaving handclapping boxing 8 | do 9 | URL=http://www.csc.kth.se/cvap/actions/"$c".zip 10 | wget $URL -P $TARGET_DIR/raw 11 | mkdir 
$TARGET_DIR/raw/$c 12 | unzip $TARGET_DIR/raw/"$c".zip -d $TARGET_DIR/raw/$c 13 | rm $TARGET_DIR/raw/"$c".zip 14 | done 15 | 16 | fi -------------------------------------------------------------------------------- /data/KTH/02_kth_train_val_test_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data.KTH.kth_actions_frames import kth_actions_dict, settings, actions, person_ids, MCVD_person_ids 3 | 4 | def convert_with_official_split(): 5 | """ 6 | len train 760 7 | len valid 768 8 | len test 863 9 | min_length train is 30 10 | min_length valid is 26 11 | min_length test is 24 12 | """ 13 | for data_split in ['train', 'valid', 'test']: 14 | print('Converting ' + data_split) 15 | 16 | with open(f"{data_split}-official.txt", 'w') as f: 17 | min_length = 1e6 18 | split_person_ids = person_ids[data_split] 19 | for person_id in split_person_ids: 20 | # print(' Converting person' + person_id) 21 | for action in kth_actions_dict['person'+person_id]: 22 | for setting in kth_actions_dict['person'+person_id][action]: 23 | for frame_idxs in kth_actions_dict['person'+person_id][action][setting]: 24 | file_name = 'person' + person_id + '_' + action + '_' + setting + '_uncomp.avi' 25 | file_path = os.path.join(action, file_name) 26 | # index of kth_actions_frames.py starts from 1 but we need 0 27 | # and wo should fix to [a,b) not [a,b] 28 | # eg: 1-123, 124-345, length is 345 29 | # -> 0-122, 123-344, length is 345 30 | # -> 0-123, 123-345, length is 345 same 31 | start_frame_idxs = frame_idxs[0] - 1 32 | end_frames_idxs = frame_idxs[1] 33 | 34 | min_length = min(min_length, end_frames_idxs - start_frame_idxs) 35 | 36 | f.write(f"{file_path} {start_frame_idxs} {end_frames_idxs}\n") 37 | print('min_length', data_split, 'is', min_length) 38 | print('Converting', data_split, 'done.') 39 | 40 | def convert_with_all_frames(): 41 | """ 42 | len train 191 43 | len valid 192 44 | len test 216 45 | min_length train is 250 46 | min_length valid is 204 47 | min_length test is 256 48 | """ 49 | for data_split in ['train', 'valid', 'test']: 50 | cnt = 0 51 | print('Converting ' + data_split) 52 | 53 | with open(f"{data_split}.txt", 'w') as f: 54 | min_length = 1e6 55 | split_person_ids = person_ids[data_split] 56 | for person_id in split_person_ids: 57 | # print(' Converting person' + person_id) 58 | for action in kth_actions_dict['person'+person_id]: 59 | for setting in kth_actions_dict['person'+person_id][action]: 60 | a_list = sorted(kth_actions_dict['person'+person_id][action][setting]) 61 | # index of kth_actions_frames.py starts from 1 but we need 0 62 | # and wo should fix to [a,b) not [a,b] 63 | # eg: 1-12, ... 124-345, length is 345 64 | # -> 0-11, ... 
123-344, length is 345 65 | # -> 0-345, length is 345 same 66 | start_frame_idxs = a_list[0][0] - 1 67 | end_frames_idxs = a_list[-1][1] 68 | 69 | file_name = 'person' + person_id + '_' + action + '_' + setting + '_uncomp.avi' 70 | file_path = os.path.join(action, file_name) 71 | min_length = min(min_length, end_frames_idxs - start_frame_idxs) 72 | 73 | f.write(f"{file_path}\n") 74 | cnt += 1 75 | print('num ', data_split, 'is', cnt) 76 | print('min_length', data_split, 'is', min_length) 77 | print('Converting', data_split, 'done.') 78 | print("") 79 | 80 | def convert_MCVD_setting(): 81 | """ 82 | num train is 479 83 | num valid is 120 84 | min_length train is 230 85 | min_length valid is 204 86 | """ 87 | 88 | for data_split in ['train', 'valid']: 89 | cnt = 0 90 | print('Converting ' + data_split) 91 | 92 | with open(f"{data_split}.txt", 'w') as f: 93 | min_length = 1e6 94 | split_person_ids = MCVD_person_ids[data_split] 95 | for person_id in split_person_ids: 96 | # print(' Converting person' + person_id) 97 | for action in kth_actions_dict['person'+person_id]: 98 | for setting in kth_actions_dict['person'+person_id][action]: 99 | a_list = sorted(kth_actions_dict['person'+person_id][action][setting]) 100 | start_frame_idxs = a_list[0][0] - 1 101 | end_frames_idxs = a_list[-1][1] 102 | 103 | file_name = 'person' + person_id + '_' + action + '_' + setting + '_uncomp.avi' 104 | file_path = os.path.join(action, file_name) 105 | min_length = min(min_length, end_frames_idxs - start_frame_idxs) 106 | 107 | f.write(f"{file_path}\n") 108 | cnt += 1 109 | print('num ', data_split, 'is', cnt) 110 | print('min_length', data_split, 'is', min_length) 111 | print('Converting', data_split, 'done.') 112 | print("") 113 | 114 | def convert_with_official_train_and_mcvd_valid_split(): 115 | """ 116 | Converting train 117 | min_length train is 24 118 | Converting train done. 119 | Converting valid 120 | num valid is 120 121 | min_length valid is 204 122 | Converting valid done. 
123 | """ 124 | data_split = 'train' 125 | cnt = 0 126 | print('Converting ' + data_split) 127 | with open(f"{data_split}.txt", 'w') as f: 128 | min_length = 1e6 129 | split_person_ids = MCVD_person_ids[data_split] 130 | for person_id in split_person_ids: 131 | # print(' Converting person' + person_id) 132 | for action in kth_actions_dict['person'+person_id]: 133 | for setting in kth_actions_dict['person'+person_id][action]: 134 | for frame_idxs in kth_actions_dict['person'+person_id][action][setting]: 135 | file_name = 'person' + person_id + '_' + action + '_' + setting + '_uncomp.avi' 136 | file_path = os.path.join(action, file_name) 137 | # index of kth_actions_frames.py starts from 1 but we need 0 138 | # and wo should fix to [a,b) not [a,b] 139 | # eg: 1-123, 124-345, length is 345 140 | # -> 0-122, 123-344, length is 345 141 | # -> 0-123, 123-345, length is 345 same 142 | start_frame_idxs = frame_idxs[0] - 1 143 | end_frames_idxs = frame_idxs[1] 144 | 145 | min_length = min(min_length, end_frames_idxs - start_frame_idxs) 146 | 147 | f.write(f"{file_path} {start_frame_idxs} {end_frames_idxs}\n") 148 | cnt += 1 149 | print('num ', data_split, 'is', cnt) 150 | print('min_length', data_split, 'is', min_length) 151 | print('Converting', data_split, 'done.') 152 | 153 | data_split = 'valid' 154 | cnt = 0 155 | print('Converting ' + data_split) 156 | 157 | with open(f"{data_split}.txt", 'w') as f: 158 | min_length = 1e6 159 | split_person_ids = MCVD_person_ids[data_split] 160 | for person_id in split_person_ids: 161 | # print(' Converting person' + person_id) 162 | for action in kth_actions_dict['person'+person_id]: 163 | for setting in kth_actions_dict['person'+person_id][action]: 164 | a_list = sorted(kth_actions_dict['person'+person_id][action][setting]) 165 | start_frame_idxs = a_list[0][0] - 1 166 | end_frames_idxs = a_list[-1][1] 167 | 168 | file_name = 'person' + person_id + '_' + action + '_' + setting + '_uncomp.avi' 169 | file_path = os.path.join(action, file_name) 170 | min_length = min(min_length, end_frames_idxs - start_frame_idxs) 171 | 172 | f.write(f"{file_path}\n") 173 | cnt += 1 174 | print('num ', data_split, 'is', cnt) 175 | print('min_length', data_split, 'is', min_length) 176 | print('Converting', data_split, 'done.') 177 | print("") 178 | 179 | convert_MCVD_setting() 180 | # convert_with_official_split() 181 | # convert_with_official_train_and_mcvd_valid_split() -------------------------------------------------------------------------------- /data/KTH/03_kth_convert.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/convert_bair.py 2 | import argparse 3 | import cv2 4 | import glob 5 | import numpy as np 6 | import os 7 | import pickle 8 | import sys 9 | from tqdm import tqdm 10 | 11 | sys.path.append("..") 12 | from h5 import HDF5Maker 13 | 14 | class KTH_HDF5Maker(HDF5Maker): 15 | 16 | def add_video_info(self): 17 | pass 18 | 19 | def create_video_groups(self): 20 | self.writer.create_group('len') 21 | self.writer.create_group('videos') 22 | 23 | def add_video_data(self, data, dtype=None): 24 | frames = data 25 | self.writer['len'].create_dataset(str(self.count), data=len(frames)) 26 | self.writer.create_group(str(self.count)) 27 | for i, frame in enumerate(frames): 28 | self.writer[str(self.count)].create_dataset(str(i), data=frame, dtype=dtype, compression="lzf") 29 | 30 | def read_video(video_path, image_size): 31 | # opencv is faster than mediapy 32 | cap = 
cv2.VideoCapture(video_path) 33 | frames = [] 34 | while True: 35 | ret, frame = cap.read() 36 | if not ret: 37 | break 38 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 39 | image = cv2.resize(gray, (image_size, image_size)) 40 | frames.append(image) 41 | cap.release() 42 | if frames == []: 43 | print(f"the file {video_path} may has been damaged, you should") 44 | print("1. clean old video file and old hdf5 file") 45 | print("2. download new video file") 46 | print("3. generate hdf5 file again") 47 | ValueError() 48 | return frames 49 | 50 | # def show_video(frames): 51 | # import matplotlib.pyplot as plt 52 | # from matplotlib.animation import FuncAnimation 53 | # im1 = plt.imshow(frames[0]) 54 | # def update(frame): 55 | # im1.set_data(frame) 56 | # ani = FuncAnimation(plt.gcf(), update, frames=frames, interval=10, repeat=False) 57 | # plt.show() 58 | 59 | def make_h5_from_kth(kth_dir, split_dir, image_size=64, out_dir='./h5_ds', vids_per_shard=1000000, force_h5=False): 60 | 61 | # classes = ['train', 'valid', 'test'] 62 | classes = ['train', 'valid'] 63 | 64 | for type in classes: 65 | print(f"process {type}") 66 | dataste_dir = out_dir + '/' + type 67 | h5_maker = KTH_HDF5Maker(dataste_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 68 | count = 0 69 | # try: 70 | with open(f"{split_dir}/{type}.txt", "r") as f: 71 | lines = f.read().splitlines() 72 | 73 | isSplit = len(lines[0].split(' ')) > 1 74 | 75 | for line in tqdm(lines): 76 | if isSplit: 77 | path, start, end = line.split(' ') 78 | path = os.path.join(kth_dir, path) 79 | frames = read_video(path, image_size) 80 | frames = frames[int(start):int(end)] 81 | else: 82 | path = os.path.join(kth_dir, line) 83 | frames = read_video(path, image_size) 84 | h5_maker.add_data(frames, dtype='uint8') 85 | count += 1 86 | # except StopIteration: 87 | # break 88 | # except (KeyboardInterrupt, SystemExit): 89 | # print("Ctrl+C!!") 90 | # break 91 | # except: 92 | # e = sys.exc_info()[0] 93 | # print("ERROR:", e) 94 | 95 | h5_maker.close() 96 | print(f"process {type} done") 97 | 98 | 99 | if __name__ == "__main__": 100 | 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument('--out_dir', type=str, help="Directory to save .hdf5 files") 103 | parser.add_argument('--split_dir', type=str, help="Directory to split dataset files") 104 | parser.add_argument('--kth_dir', type=str, help="Directory with KTH") 105 | parser.add_argument('--image_size', type=int, default=64) 106 | parser.add_argument('--vids_per_shard', type=int, default=1000000) 107 | parser.add_argument('--force_h5', type=eval, default=False) 108 | 109 | args = parser.parse_args() 110 | 111 | make_h5_from_kth(out_dir=args.out_dir, kth_dir=args.kth_dir, split_dir=args.split_dir, image_size=args.image_size, vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) 112 | 113 | # Example: 114 | 115 | -------------------------------------------------------------------------------- /data/KTH/valid.txt: -------------------------------------------------------------------------------- 1 | boxing/person21_boxing_d1_uncomp.avi 2 | boxing/person21_boxing_d2_uncomp.avi 3 | boxing/person21_boxing_d3_uncomp.avi 4 | boxing/person21_boxing_d4_uncomp.avi 5 | handclapping/person21_handclapping_d1_uncomp.avi 6 | handclapping/person21_handclapping_d2_uncomp.avi 7 | handclapping/person21_handclapping_d3_uncomp.avi 8 | handclapping/person21_handclapping_d4_uncomp.avi 9 | handwaving/person21_handwaving_d1_uncomp.avi 10 | handwaving/person21_handwaving_d2_uncomp.avi 11 | 
handwaving/person21_handwaving_d3_uncomp.avi 12 | handwaving/person21_handwaving_d4_uncomp.avi 13 | jogging/person21_jogging_d1_uncomp.avi 14 | jogging/person21_jogging_d2_uncomp.avi 15 | jogging/person21_jogging_d3_uncomp.avi 16 | jogging/person21_jogging_d4_uncomp.avi 17 | running/person21_running_d1_uncomp.avi 18 | running/person21_running_d2_uncomp.avi 19 | running/person21_running_d3_uncomp.avi 20 | running/person21_running_d4_uncomp.avi 21 | walking/person21_walking_d1_uncomp.avi 22 | walking/person21_walking_d2_uncomp.avi 23 | walking/person21_walking_d3_uncomp.avi 24 | walking/person21_walking_d4_uncomp.avi 25 | boxing/person22_boxing_d1_uncomp.avi 26 | boxing/person22_boxing_d2_uncomp.avi 27 | boxing/person22_boxing_d3_uncomp.avi 28 | boxing/person22_boxing_d4_uncomp.avi 29 | handclapping/person22_handclapping_d1_uncomp.avi 30 | handclapping/person22_handclapping_d2_uncomp.avi 31 | handclapping/person22_handclapping_d3_uncomp.avi 32 | handclapping/person22_handclapping_d4_uncomp.avi 33 | handwaving/person22_handwaving_d1_uncomp.avi 34 | handwaving/person22_handwaving_d2_uncomp.avi 35 | handwaving/person22_handwaving_d3_uncomp.avi 36 | handwaving/person22_handwaving_d4_uncomp.avi 37 | jogging/person22_jogging_d1_uncomp.avi 38 | jogging/person22_jogging_d2_uncomp.avi 39 | jogging/person22_jogging_d3_uncomp.avi 40 | jogging/person22_jogging_d4_uncomp.avi 41 | running/person22_running_d1_uncomp.avi 42 | running/person22_running_d2_uncomp.avi 43 | running/person22_running_d3_uncomp.avi 44 | running/person22_running_d4_uncomp.avi 45 | walking/person22_walking_d1_uncomp.avi 46 | walking/person22_walking_d2_uncomp.avi 47 | walking/person22_walking_d3_uncomp.avi 48 | walking/person22_walking_d4_uncomp.avi 49 | boxing/person23_boxing_d1_uncomp.avi 50 | boxing/person23_boxing_d2_uncomp.avi 51 | boxing/person23_boxing_d3_uncomp.avi 52 | boxing/person23_boxing_d4_uncomp.avi 53 | handclapping/person23_handclapping_d1_uncomp.avi 54 | handclapping/person23_handclapping_d2_uncomp.avi 55 | handclapping/person23_handclapping_d3_uncomp.avi 56 | handclapping/person23_handclapping_d4_uncomp.avi 57 | handwaving/person23_handwaving_d1_uncomp.avi 58 | handwaving/person23_handwaving_d2_uncomp.avi 59 | handwaving/person23_handwaving_d3_uncomp.avi 60 | handwaving/person23_handwaving_d4_uncomp.avi 61 | jogging/person23_jogging_d1_uncomp.avi 62 | jogging/person23_jogging_d2_uncomp.avi 63 | jogging/person23_jogging_d3_uncomp.avi 64 | jogging/person23_jogging_d4_uncomp.avi 65 | running/person23_running_d1_uncomp.avi 66 | running/person23_running_d2_uncomp.avi 67 | running/person23_running_d3_uncomp.avi 68 | running/person23_running_d4_uncomp.avi 69 | walking/person23_walking_d1_uncomp.avi 70 | walking/person23_walking_d2_uncomp.avi 71 | walking/person23_walking_d3_uncomp.avi 72 | walking/person23_walking_d4_uncomp.avi 73 | boxing/person24_boxing_d1_uncomp.avi 74 | boxing/person24_boxing_d2_uncomp.avi 75 | boxing/person24_boxing_d3_uncomp.avi 76 | boxing/person24_boxing_d4_uncomp.avi 77 | handclapping/person24_handclapping_d1_uncomp.avi 78 | handclapping/person24_handclapping_d2_uncomp.avi 79 | handclapping/person24_handclapping_d3_uncomp.avi 80 | handclapping/person24_handclapping_d4_uncomp.avi 81 | handwaving/person24_handwaving_d1_uncomp.avi 82 | handwaving/person24_handwaving_d2_uncomp.avi 83 | handwaving/person24_handwaving_d3_uncomp.avi 84 | handwaving/person24_handwaving_d4_uncomp.avi 85 | jogging/person24_jogging_d1_uncomp.avi 86 | jogging/person24_jogging_d2_uncomp.avi 87 | 
jogging/person24_jogging_d3_uncomp.avi 88 | jogging/person24_jogging_d4_uncomp.avi 89 | running/person24_running_d1_uncomp.avi 90 | running/person24_running_d2_uncomp.avi 91 | running/person24_running_d3_uncomp.avi 92 | running/person24_running_d4_uncomp.avi 93 | walking/person24_walking_d1_uncomp.avi 94 | walking/person24_walking_d2_uncomp.avi 95 | walking/person24_walking_d3_uncomp.avi 96 | walking/person24_walking_d4_uncomp.avi 97 | boxing/person25_boxing_d1_uncomp.avi 98 | boxing/person25_boxing_d2_uncomp.avi 99 | boxing/person25_boxing_d3_uncomp.avi 100 | boxing/person25_boxing_d4_uncomp.avi 101 | handclapping/person25_handclapping_d1_uncomp.avi 102 | handclapping/person25_handclapping_d2_uncomp.avi 103 | handclapping/person25_handclapping_d3_uncomp.avi 104 | handclapping/person25_handclapping_d4_uncomp.avi 105 | handwaving/person25_handwaving_d1_uncomp.avi 106 | handwaving/person25_handwaving_d2_uncomp.avi 107 | handwaving/person25_handwaving_d3_uncomp.avi 108 | handwaving/person25_handwaving_d4_uncomp.avi 109 | jogging/person25_jogging_d1_uncomp.avi 110 | jogging/person25_jogging_d2_uncomp.avi 111 | jogging/person25_jogging_d3_uncomp.avi 112 | jogging/person25_jogging_d4_uncomp.avi 113 | running/person25_running_d1_uncomp.avi 114 | running/person25_running_d2_uncomp.avi 115 | running/person25_running_d3_uncomp.avi 116 | running/person25_running_d4_uncomp.avi 117 | walking/person25_walking_d1_uncomp.avi 118 | walking/person25_walking_d2_uncomp.avi 119 | walking/person25_walking_d3_uncomp.avi 120 | walking/person25_walking_d4_uncomp.avi 121 | -------------------------------------------------------------------------------- /data/SMMNIST/01_mnist_download_and_convert.py: -------------------------------------------------------------------------------- 1 | # https://github.com/edenton/svg/blob/master/data/convert_bair.py 2 | import argparse 3 | import cv2 4 | import glob 5 | import numpy as np 6 | import os 7 | import pickle 8 | import sys 9 | from tqdm import tqdm 10 | from stochastic_moving_mnist import StochasticMovingMNIST 11 | import torch 12 | 13 | sys.path.append("..") 14 | from h5 import HDF5Maker 15 | 16 | class KTH_HDF5Maker(HDF5Maker): 17 | 18 | def add_video_info(self): 19 | pass 20 | 21 | def create_video_groups(self): 22 | self.writer.create_group('len') 23 | self.writer.create_group('videos') 24 | 25 | def add_video_data(self, data, dtype=None): 26 | frames = data 27 | self.writer['len'].create_dataset(str(self.count), data=len(frames)) 28 | self.writer.create_group(str(self.count)) 29 | for i, frame in enumerate(frames): 30 | self.writer[str(self.count)].create_dataset(str(i), data=frame, dtype=dtype, compression="lzf") 31 | 32 | def read_video(video_path, image_size): 33 | # opencv is faster than mediapy 34 | cap = cv2.VideoCapture(video_path) 35 | frames = [] 36 | while True: 37 | ret, frame = cap.read() 38 | if not ret: 39 | break 40 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 41 | image = cv2.resize(gray, (image_size, image_size)) 42 | frames.append(image) 43 | cap.release() 44 | if frames == []: 45 | print(f"the file {video_path} may has been damaged, you should") 46 | print("1. clean old video file and old hdf5 file") 47 | print("2. download new video file") 48 | print("3. generate hdf5 file again") 49 | ValueError() 50 | return frames 51 | 52 | def read_data(video): 53 | # data torch [0,1] torch.Size([40, 1, 64, 64]) 54 | # return frames list [0, 255] [ array, array, ...] 
55 | 56 | frames = [] 57 | 58 | video = (video.squeeze()*255).type(torch.uint8).numpy() 59 | 60 | for frame in video: 61 | frames.append(frame) 62 | 63 | if frames == []: 64 | print("Error: ") 65 | print("1. clean old video file and old hdf5 file") 66 | print("2. download new video file") 67 | print("3. generate hdf5 file again") 68 | ValueError() 69 | 70 | return frames 71 | 72 | def make_h5_from_kth(mnist_dir, image_size=64, seq_len=40, out_dir='./h5_ds', vids_per_shard=1000000, force_h5=False): 73 | 74 | train_dataset = StochasticMovingMNIST( 75 | mnist_dir, train=True, seq_len=seq_len, num_digits=2, 76 | step_length=0.1, with_target=False 77 | ) 78 | 79 | test_dataset = StochasticMovingMNIST( 80 | mnist_dir, train=False, seq_len=seq_len, num_digits=2, 81 | step_length=0.1, with_target=False, 82 | total_videos=256 83 | ) 84 | 85 | print(len(train_dataset)) 86 | 87 | print(train_dataset[0].shape) 88 | print(train_dataset[0].shape) 89 | 90 | # import torch 91 | # print(torch.min(train_dataset[0][1]), torch.max(train_dataset[0][1])) 92 | # # value [0, 1] 93 | 94 | # train_dataset 95 | 96 | print(f"process train_dataset") 97 | 98 | dataste_dir = out_dir + '/train' 99 | h5_maker = KTH_HDF5Maker(dataste_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 100 | 101 | 102 | for data in tqdm(train_dataset): 103 | # try: 104 | frames = read_data(data) 105 | h5_maker.add_data(frames, dtype='uint8') 106 | # except StopIteration: 107 | # break 108 | # except (KeyboardInterrupt, SystemExit): 109 | # print("Ctrl+C!!") 110 | # break 111 | # except: 112 | # e = sys.exc_info()[0] 113 | # print("ERROR:", e) 114 | 115 | h5_maker.close() 116 | 117 | # test_dataset 118 | 119 | print(f"process test_dataset") 120 | 121 | dataste_dir = out_dir + '/test' 122 | h5_maker = KTH_HDF5Maker(dataste_dir, num_per_shard=vids_per_shard, force=force_h5, video=True) 123 | 124 | for data in tqdm(test_dataset): 125 | # try: 126 | frames = read_data(data) 127 | h5_maker.add_data(frames, dtype='uint8') 128 | # except StopIteration: 129 | # break 130 | # except (KeyboardInterrupt, SystemExit): 131 | # print("Ctrl+C!!") 132 | # break 133 | # except: 134 | # e = sys.exc_info()[0] 135 | # print("ERROR:", e) 136 | 137 | h5_maker.close() 138 | 139 | print(f"process done!") 140 | 141 | 142 | if __name__ == "__main__": 143 | 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('--seq_len', type=int, default=40) 146 | parser.add_argument('--out_dir', type=str, help="Directory to save .hdf5 files") 147 | parser.add_argument('--mnist_dir', type=str, help="Directory with KTH") 148 | parser.add_argument('--image_size', type=int, default=64) 149 | parser.add_argument('--vids_per_shard', type=int, default=1000000) 150 | parser.add_argument('--force_h5', type=eval, default=False) 151 | 152 | args = parser.parse_args() 153 | 154 | make_h5_from_kth(out_dir=args.out_dir, mnist_dir=args.mnist_dir, image_size=args.image_size, vids_per_shard=args.vids_per_shard, force_h5=args.force_h5) -------------------------------------------------------------------------------- /data/SMMNIST/stochastic_moving_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torchvision import datasets, transforms 5 | 6 | 7 | class ToTensor(object): 8 | """Converts a numpy.ndarray (... x H x W x C) to a torch.FloatTensor of shape (... x C x H x W) in the range [0.0, 1.0]. 
9 | """ 10 | def __init__(self, scale=True): 11 | self.scale = scale 12 | def __call__(self, arr): 13 | if isinstance(arr, np.ndarray): 14 | video = torch.from_numpy(np.rollaxis(arr, axis=-1, start=-3)) 15 | if self.scale: 16 | return video.float() 17 | else: 18 | return video.float() 19 | else: 20 | raise NotImplementedError 21 | 22 | 23 | # https://github.com/edenton/svg/blob/master/data/moving_mnist.py 24 | class StochasticMovingMNIST(object): 25 | 26 | """Data Handler that creates Bouncing MNIST dataset on the fly.""" 27 | 28 | def __init__(self, data_root, train=True, seq_len=20, num_digits=2, image_size=64, deterministic=False, 29 | step_length=0.1, total_videos=-1, with_target=False, transform=transforms.Compose([ToTensor()])): 30 | path = data_root 31 | self.seq_len = seq_len 32 | self.num_digits = num_digits 33 | self.image_size = image_size 34 | self.step_length = step_length 35 | self.with_target = with_target 36 | self.transform = transform 37 | self.deterministic = deterministic 38 | 39 | self.seed_is_set = False # multi threaded loading 40 | self.digit_size = 32 41 | self.channels = 1 42 | 43 | self.data = datasets.MNIST( 44 | path, 45 | train=train, 46 | download=True, 47 | transform=transforms.Compose( 48 | [transforms.Resize(self.digit_size), 49 | transforms.ToTensor()])) 50 | 51 | self.N = len(self.data) if total_videos == -1 else total_videos 52 | 53 | print(f"Dataset length: {self.__len__()}") 54 | 55 | def set_seed(self, seed): 56 | if not self.seed_is_set: 57 | self.seed_is_set = True 58 | np.random.seed(seed) 59 | 60 | def __len__(self): 61 | return self.N 62 | 63 | def __getitem__(self, index): 64 | self.set_seed(index) 65 | image_size = self.image_size 66 | digit_size = self.digit_size 67 | x = np.zeros((self.seq_len, 68 | image_size, 69 | image_size, 70 | self.channels), 71 | dtype=np.float32) 72 | for n in range(self.num_digits): 73 | idx = np.random.randint(self.N) 74 | digit, _ = self.data[idx] 75 | 76 | sx = np.random.randint(image_size-digit_size) 77 | sy = np.random.randint(image_size-digit_size) 78 | dx = np.random.randint(-4, 5) 79 | dy = np.random.randint(-4, 5) 80 | for t in range(self.seq_len): 81 | if sy < 0: 82 | sy = 0 83 | if self.deterministic: 84 | dy = -dy 85 | else: 86 | dy = np.random.randint(1, 5) 87 | dx = np.random.randint(-4, 5) 88 | elif sy >= image_size-32: 89 | sy = image_size-32-1 90 | if self.deterministic: 91 | dy = -dy 92 | else: 93 | dy = np.random.randint(-4, 0) 94 | dx = np.random.randint(-4, 5) 95 | 96 | if sx < 0: 97 | sx = 0 98 | if self.deterministic: 99 | dx = -dx 100 | else: 101 | dx = np.random.randint(1, 5) 102 | dy = np.random.randint(-4, 5) 103 | elif sx >= image_size-32: 104 | sx = image_size-32-1 105 | if self.deterministic: 106 | dx = -dx 107 | else: 108 | dx = np.random.randint(-4, 0) 109 | dy = np.random.randint(-4, 5) 110 | 111 | x[t, sy:sy+32, sx:sx+32, 0] += digit.numpy().squeeze() 112 | sy += dy 113 | sx += dx 114 | 115 | x[x>1] = 1. 
116 | 117 | if self.with_target: 118 | targets = np.array(x >= 0.5, dtype=float) 119 | 120 | if self.transform is not None: 121 | x = self.transform(x) 122 | if self.with_target: 123 | targets = self.transform(targets) 124 | 125 | if self.with_target: 126 | return x, targets 127 | else: 128 | return x 129 | -------------------------------------------------------------------------------- /data/SMMNIST/stochastic_moving_mnist_edited.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torchvision import datasets, transforms 5 | 6 | 7 | class ToTensor(object): 8 | """Converts a numpy.ndarray (... x H x W x C) to a torch.FloatTensor of shape (... x C x H x W) in the range [0.0, 1.0]. 9 | """ 10 | def __init__(self, scale=True): 11 | self.scale = scale 12 | def __call__(self, arr): 13 | if isinstance(arr, np.ndarray): 14 | video = torch.from_numpy(np.rollaxis(arr, axis=-1, start=-3)) 15 | if self.scale: 16 | return video.float() 17 | else: 18 | return video.float() 19 | else: 20 | raise NotImplementedError 21 | 22 | 23 | # https://github.com/edenton/svg/blob/master/data/moving_mnist.py 24 | class StochasticMovingMNIST(object): 25 | 26 | """Data Handler that creates Bouncing MNIST dataset on the fly.""" 27 | 28 | def __init__(self, data_root, train=True, num_digits=2, image_size=64, deterministic=False, 29 | step_length=0.1, total_videos=-1, with_target=False, transform=transforms.Compose([ToTensor()]), 30 | same_samples=6, diff_samples=1, cond_seq_len=10, pred_seq_len=0): 31 | path = data_root 32 | self.num_digits = num_digits 33 | self.image_size = image_size 34 | self.step_length = step_length 35 | self.with_target = with_target 36 | self.transform = transform 37 | self.deterministic = deterministic 38 | 39 | self.seed_is_set = False # multi threaded loading 40 | self.digit_size = 32 41 | self.channels = 1 42 | 43 | self.same_samples = same_samples 44 | self.diff_samples = diff_samples 45 | self.cond_seq_len = cond_seq_len 46 | self.pred_seq_len = pred_seq_len 47 | 48 | self.data = datasets.MNIST( 49 | path, 50 | train=train, 51 | download=True, 52 | transform=transforms.Compose( 53 | [transforms.Resize(self.digit_size), 54 | transforms.ToTensor()])) 55 | 56 | self.N = len(self.data) if total_videos == -1 else total_videos 57 | 58 | print(f"Dataset length: {self.__len__()}") 59 | 60 | def set_seed(self, seed): 61 | if not self.seed_is_set: 62 | self.seed_is_set = True 63 | np.random.seed(seed) 64 | 65 | def __len__(self): 66 | return self.N 67 | 68 | def __getitem__(self, index): 69 | self.set_seed(index) 70 | image_size = self.image_size 71 | digit_size = self.digit_size 72 | x = np.zeros((self.same_samples+self.diff_samples, self.cond_seq_len+self.pred_seq_len, 73 | image_size, 74 | image_size, 75 | self.channels), 76 | dtype=np.float32) 77 | for n in range(self.num_digits): 78 | same_idx = np.random.randint(self.N) 79 | same_digits = [self.data[same_idx][0] for _ in range(self.same_samples)] 80 | diff_idxs = np.random.randint(self.N, size=self.diff_samples) 81 | diff_digits = [self.data[diff_idx][0] for diff_idx in diff_idxs] 82 | digits = same_digits + diff_digits 83 | 84 | sx = np.random.randint(image_size-digit_size) 85 | sy = np.random.randint(image_size-digit_size) 86 | dx = np.random.randint(-4, 5) 87 | dy = np.random.randint(-4, 5) 88 | for t in range(self.cond_seq_len): 89 | if sy < 0: 90 | sy = 0 91 | if self.deterministic: 92 | dy = -dy 93 | else: 94 | dy = np.random.randint(1, 5) 95 | dx = 
np.random.randint(-4, 5) 96 | elif sy >= image_size-32: 97 | sy = image_size-32-1 98 | if self.deterministic: 99 | dy = -dy 100 | else: 101 | dy = np.random.randint(-4, 0) 102 | dx = np.random.randint(-4, 5) 103 | 104 | if sx < 0: 105 | sx = 0 106 | if self.deterministic: 107 | dx = -dx 108 | else: 109 | dx = np.random.randint(1, 5) 110 | dy = np.random.randint(-4, 5) 111 | elif sx >= image_size-32: 112 | sx = image_size-32-1 113 | if self.deterministic: 114 | dx = -dx 115 | else: 116 | dx = np.random.randint(-4, 0) 117 | dy = np.random.randint(-4, 5) 118 | 119 | for x_idx, digit in enumerate(digits): 120 | x[x_idx, t, sy:sy+32, sx:sx+32, 0] += digit.numpy().squeeze() 121 | 122 | sy += dy 123 | sx += dx 124 | 125 | pred_init_sy = sy 126 | pred_init_sx = sx 127 | pred_init_dy = dy 128 | pred_init_dx = dx 129 | 130 | for n in range(self.same_samples+self.diff_samples): 131 | sy = pred_init_sy 132 | sx = pred_init_sx 133 | dy = pred_init_dy 134 | dx = pred_init_dx 135 | for t in range(self.cond_seq_len, self.cond_seq_len+self.pred_seq_len): 136 | if sy < 0: 137 | sy = 0 138 | if self.deterministic: 139 | dy = -dy 140 | else: 141 | dy = np.random.randint(1, 5) 142 | dx = np.random.randint(-4, 5) 143 | elif sy >= image_size-32: 144 | sy = image_size-32-1 145 | if self.deterministic: 146 | dy = -dy 147 | else: 148 | dy = np.random.randint(-4, 0) 149 | dx = np.random.randint(-4, 5) 150 | 151 | if sx < 0: 152 | sx = 0 153 | if self.deterministic: 154 | dx = -dx 155 | else: 156 | dx = np.random.randint(1, 5) 157 | dy = np.random.randint(-4, 5) 158 | elif sx >= image_size-32: 159 | sx = image_size-32-1 160 | if self.deterministic: 161 | dx = -dx 162 | else: 163 | dx = np.random.randint(-4, 0) 164 | dy = np.random.randint(-4, 5) 165 | 166 | x[n, t, sy:sy+32, sx:sx+32, 0] += digits[n].numpy().squeeze() 167 | sy += dy 168 | sx += dx 169 | 170 | x[x>1] = 1. 
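        # same clamp as in stochastic_moving_mnist.py: summed digit intensities are capped at 1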
171 | 172 | if self.with_target: 173 | targets = np.array(x >= 0.5, dtype=float) 174 | 175 | if self.transform is not None: 176 | x = [self.transform(xx) for xx in x] 177 | if self.with_target: 178 | targets = [self.transform(target) for target in targets] 179 | 180 | if self.with_target: 181 | return x, targets 182 | else: 183 | return x 184 | 185 | # import mediapy as media 186 | # mnist_dir="/home/u1120230288/zzc/data/video_prediction/dataset/SMMNIST_h5" 187 | # train_dataset = StochasticMovingMNIST( 188 | # mnist_dir, train=True, num_digits=2, 189 | # step_length=0.1, with_target=False, 190 | # cond_seq_len=10, pred_seq_len=10, same_samples=6, diff_samples=1 191 | # ) 192 | # a_video_samples = train_dataset[0] 193 | # media.show_videos([video.squeeze().numpy() for video in a_video_samples], fps=10) -------------------------------------------------------------------------------- /data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | from torch.utils.data import Dataset, ConcatDataset 4 | import mediapy as media 5 | 6 | import albumentations as A 7 | from albumentations.pytorch import ToTensorV2 8 | import random 9 | 10 | # for image dataset 11 | # import albumentations 12 | # from PIL import Image 13 | import torch 14 | 15 | from data.h5 import HDF5Dataset 16 | 17 | class ConcatDatasetWithIndex(ConcatDataset): 18 | """Modified from original pytorch code to return dataset idx""" 19 | def __getitem__(self, idx): 20 | if idx < 0: 21 | if -idx > len(self): 22 | raise ValueError("absolute value of index should not exceed dataset length") 23 | idx = len(self) + idx 24 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 25 | if dataset_idx == 0: 26 | sample_idx = idx 27 | else: 28 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 29 | return self.datasets[dataset_idx][sample_idx], dataset_idx 30 | 31 | 32 | class VideoPaths(Dataset): 33 | def __init__(self, paths, start_idxs, end_idxs, trans=None, labels=None): 34 | self._length = len(paths) 35 | self._trans = trans 36 | 37 | if labels is None: 38 | self.labels = dict() 39 | else: 40 | self.labels = labels 41 | 42 | self.labels["file_path"] = paths 43 | self.labels["start_idx"] = start_idxs 44 | self.labels["end_idx"] = end_idxs 45 | 46 | def __len__(self): 47 | return self._length 48 | 49 | def preprocess_video(self, video_path, start_idx, end_idx): 50 | video = media.read_video(video_path)[start_idx:end_idx] 51 | video = np.array(video).astype(np.uint8) 52 | tmp_video = [] 53 | for i in range(len(video)): 54 | tmp_video.append(self._trans(image=video[i])["image"]) 55 | video = np.array(tmp_video) 56 | video = (video/127.5 - 1.0).astype(np.float32) 57 | # [0,255] -> [-1,1] 58 | return video 59 | 60 | def __getitem__(self, i): 61 | video = dict() 62 | video["video"] = self.preprocess_video(self.labels["file_path"][i], int(self.labels["start_idx"][i]), int(self.labels["end_idx"][i])) 63 | for k in self.labels: 64 | video[k] = self.labels[k][i] 65 | return video 66 | 67 | 68 | class HDF5InterfaceDataset(Dataset): 69 | def __init__(self, data_dir, frames_per_sample, random_time=True, total_videos=-1, start_at=0, labels=None): 70 | super().__init__() 71 | if labels is None: 72 | self.labels = dict() 73 | else: 74 | self.labels = labels 75 | self.data_dir = data_dir 76 | self.videos_ds = HDF5Dataset(data_dir) 77 | self.total_videos = total_videos 78 | self.start_at = start_at 79 | self.random_time = random_time 80 | self.frames_per_sample = 
frames_per_sample 81 | 82 | # The numpy HWC image is converted to pytorch CHW tensor. 83 | # If the image is in HW format (grayscale image), 、 84 | # it will be converted to pytorch HW tensor. 85 | flag = random.choice([0,1]) 86 | 87 | self.trans = A.Compose([ 88 | A.HorizontalFlip(p=flag), 89 | ToTensorV2() 90 | ]) 91 | 92 | def __len__(self): 93 | if self.total_videos > 0: 94 | return self.total_videos 95 | else: 96 | return len(self.videos_ds) 97 | 98 | def max_index(self): 99 | return len(self.videos_ds) 100 | 101 | def len_of_vid(self, index): 102 | video_index = index % self.__len__() 103 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 104 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 105 | video_len = f['len'][str(idx_in_shard)][()] 106 | return video_len 107 | 108 | def __getitem__(self, index, time_idx=0): 109 | # Use `index` to select the video, and then 110 | # randomly choose a `frames_per_sample` window of frames in the video 111 | video = dict() 112 | 113 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 114 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 115 | final_clip = [] 116 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 117 | video_len = f['len'][str(idx_in_shard)][()] - self.start_at 118 | if self.random_time and video_len > self.frames_per_sample: 119 | time_idx = np.random.choice(video_len - self.frames_per_sample) 120 | time_idx += self.start_at 121 | for i in range(time_idx, min(time_idx + self.frames_per_sample, video_len)): 122 | final_clip.append(self.trans(image=f[str(idx_in_shard)][str(i)][()])["image"]) 123 | final_clip = torch.stack(final_clip) 124 | final_clip = (final_clip/127.5 - 1.0).type(torch.float32) 125 | video["video"] = final_clip 126 | 127 | for k in self.labels: 128 | video[k] = self.labels[k][i] 129 | 130 | return video 131 | 132 | 133 | # class ImagePaths(Dataset): 134 | # def __init__(self, paths, size=None, random_crop=False, labels=None): 135 | # self.size = size 136 | # self.random_crop = random_crop 137 | 138 | # self.labels = dict() if labels is None else labels 139 | # self.labels["file_path"] = paths 140 | # self._length = len(paths) 141 | 142 | # if self.size is not None and self.size > 0: 143 | # self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 144 | # if not self.random_crop: 145 | # self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 146 | # else: 147 | # self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 148 | # self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 149 | # else: 150 | # self.preprocessor = lambda **kwargs: kwargs 151 | 152 | # def __len__(self): 153 | # return self._length 154 | 155 | # def preprocess_image(self, image_path): 156 | # image = Image.open(image_path) 157 | # if not image.mode == "RGB": 158 | # image = image.convert("RGB") 159 | # image = np.array(image).astype(np.uint8) 160 | # image = self.preprocessor(image=image)["image"] 161 | # image = (image/127.5 - 1.0).astype(np.float32) 162 | # return image 163 | 164 | # def __getitem__(self, i): 165 | # example = dict() 166 | # example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 167 | # for k in self.labels: 168 | # example[k] = self.labels[k][i] 169 | # return example 170 | 171 | # class NumpyPaths(ImagePaths): 172 | # def preprocess_image(self, image_path): 173 | # image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 174 | # 
image = np.transpose(image, (1,2,0)) 175 | # image = Image.fromarray(image, mode="RGB") 176 | # image = np.array(image).astype(np.uint8) 177 | # image = self.preprocessor(image=image)["image"] 178 | # image = (image/127.5 - 1.0).astype(np.float32) 179 | # return image 180 | -------------------------------------------------------------------------------- /data/h5.py: -------------------------------------------------------------------------------- 1 | # https://github.com/fab-jul/hdf5_dataloader 2 | import argparse 3 | import glob 4 | import h5py 5 | import numpy as np 6 | import os 7 | import pickle 8 | import torch 9 | 10 | from torch.utils.data import Dataset 11 | 12 | default_opener = lambda p_: h5py.File(p_, 'r') 13 | 14 | 15 | class HDF5Dataset(Dataset): 16 | 17 | @staticmethod 18 | def _get_num_in_shard(shard_p, opener=default_opener): 19 | print(f'\rh5: Opening {shard_p}... ', end='') 20 | try: 21 | with opener(shard_p) as f: 22 | num_per_shard = len(f['len'].keys()) 23 | except: 24 | print(f"h5: Could not open {shard_p}!") 25 | num_per_shard = -1 26 | return num_per_shard 27 | 28 | @staticmethod 29 | def check_shard_lengths(file_paths, opener=default_opener): 30 | """ 31 | Filter away the last shard, which is assumed to be smaller. this double checks that all other shards have the 32 | same number of entries. 33 | :param file_paths: list of .hdf5 files 34 | :param opener: 35 | :return: tuple (ps, num_per_shard) where 36 | ps = filtered file paths, 37 | num_per_shard = number of entries in all of the shards in `ps` 38 | """ 39 | shard_lengths = [] 40 | print("Checking shard_lengths in", file_paths) 41 | for i, p in enumerate(file_paths): 42 | shard_lengths.append(HDF5Dataset._get_num_in_shard(p, opener)) 43 | return shard_lengths 44 | 45 | def __init__(self, data_path, # hdf5 file, or directory of hdf5s 46 | shuffle_shards=False, 47 | opener=default_opener, 48 | seed=29): 49 | self.data_path = data_path 50 | self.shuffle_shards = shuffle_shards 51 | self.opener = opener 52 | self.seed = seed 53 | 54 | # If `data_path` is an hdf5 file 55 | if os.path.splitext(self.data_path)[-1] == '.hdf5' or os.path.splitext(self.data_path)[-1] == '.h5': 56 | self.data_dir = os.path.dirname(self.data_path) 57 | self.shard_paths = [self.data_path] 58 | # Else, if `data_path` is a directory of hdf5s 59 | else: 60 | self.data_dir = self.data_path 61 | self.shard_paths = sorted(glob.glob(os.path.join(self.data_dir, '*.hdf5')) + glob.glob(os.path.join(self.data_dir, '*.h5'))) 62 | 63 | assert len(self.shard_paths) > 0, "h5: Directory does not have any .hdf5 files! Dir: " + self.data_dir 64 | 65 | self.shard_lengths = HDF5Dataset.check_shard_lengths(self.shard_paths, self.opener) 66 | self.num_per_shard = self.shard_lengths[0] 67 | self.total_num = sum(self.shard_lengths) 68 | 69 | assert len(self.shard_paths) > 0, "h5: Could not find .hdf5 files! 
Dir: " + self.data_dir + " ; len(self.shard_paths) = " + str(len(self.shard_paths)) 70 | 71 | self.num_of_shards = len(self.shard_paths) 72 | 73 | print("h5: paths", len(self.shard_paths), "; shard_lengths", self.shard_lengths, "; total", self.total_num) 74 | 75 | # Shuffle shards 76 | if self.shuffle_shards: 77 | np.random.seed(seed) 78 | np.random.shuffle(self.shard_paths) 79 | 80 | def __len__(self): 81 | return self.total_num 82 | 83 | def get_indices(self, idx): 84 | shard_idx = np.digitize(idx, np.cumsum(self.shard_lengths)) 85 | idx_in_shard = str(idx - sum(self.shard_lengths[:shard_idx])) 86 | return shard_idx, idx_in_shard 87 | 88 | def __getitem__(self, index): 89 | idx = index % self.total_num 90 | shard_idx, idx_in_shard = self.get_indices(idx) 91 | # Read from shard 92 | with self.opener(self.shard_paths[shard_idx]) as f: 93 | data = f[idx_in_shard][()] 94 | return data 95 | 96 | 97 | class HDF5Maker(): 98 | 99 | def __init__(self, out_path, num_per_shard=100000, max_shards=None, name=None, name_fmt='shard_{:04d}.hdf5', force=False, video=False): 100 | 101 | # `out_path` could be an hdf5 file, or a directory of hdf5s 102 | # If `out_path` is an hdf5 file, then `name` will be its basename 103 | # If `out_path` is a directory, then `name` will be used if provided else name_fmt will be used 104 | 105 | self.out_path = out_path 106 | self.num_per_shard = num_per_shard 107 | self.max_shards= max_shards 108 | self.name = name 109 | self.name_fmt = name_fmt 110 | self.force = force 111 | self.video = video 112 | 113 | # If `out_path` is an hdf5 file 114 | if os.path.splitext(self.out_path)[-1] == '.hdf5' or os.path.splitext(self.out_path)[-1] == '.h5': 115 | # If it exists, check if it should be deleted 116 | if os.path.isfile(self.out_path): 117 | if not self.force: 118 | raise ValueError('{} already exists.'.format(self.out_path)) 119 | print('Removing {}...'.format(self.out_path)) 120 | os.remove(self.out_path) 121 | # Make the directory if it does not exist 122 | self.out_dir = os.path.dirname(self.out_path) 123 | os.makedirs(self.out_dir, exist_ok=True) 124 | # Extract its name 125 | self.name = os.path.basename(self.out_path) 126 | # Else, if `out_path` is a directory 127 | else: 128 | self.out_dir = self.out_path 129 | # If `out_dir` exists 130 | if os.path.isdir(self.out_dir): 131 | # Check if it should be deleted 132 | if not self.force: 133 | raise ValueError('{} already exists.'.format(self.out_dir)) 134 | print('Removing *.hdf5 files from {}...'.format(self.out_dir)) 135 | files = glob.glob(os.path.join(self.out_dir, "*.hdf5")) 136 | files += glob.glob(os.path.join(self.out_dir, "*.h5")) 137 | for file in files: 138 | os.remove(file) 139 | # Else, make the directory 140 | else: 141 | os.makedirs(self.out_dir) 142 | 143 | self.writer = None 144 | self.shard_paths = [] 145 | self.shard_number = 0 146 | 147 | # To save num_of_objs in each item 148 | shard_idx = 0 149 | idx_in_shard = 0 150 | 151 | self.create_new_shard() 152 | self.add_video_info() 153 | 154 | def create_new_shard(self): 155 | 156 | if self.writer: 157 | self.writer.close() 158 | 159 | self.shard_number += 1 160 | 161 | if self.max_shards is not None and self.shard_number == self.max_shards + 1: 162 | print('Created {} shards, ENDING.'.format(self.max_shards)) 163 | return 164 | 165 | self.shard_p = os.path.join(self.out_dir, self.name_fmt.format(self.shard_number) if self.name is None else self.name) 166 | assert not os.path.exists(self.shard_p), 'Record already exists! 
{}'.format(self.shard_p) 167 | self.shard_paths.append(self.shard_p) 168 | 169 | print('Creating shard # {}: {}...'.format(self.shard_number, self.shard_p)) 170 | self.writer = h5py.File(self.shard_p, 'w') 171 | 172 | if self.video: 173 | self.create_video_groups() 174 | 175 | self.count = 0 176 | 177 | def add_video_info(self): 178 | pass 179 | 180 | def create_video_groups(self): 181 | self.writer.create_group('len') 182 | self.writer.create_group('videos') 183 | 184 | def add_video_data(self, data, dtype=None): 185 | self.writer['len'].create_dataset(str(self.count), data=len(data)) 186 | self.writer.create_group(str(self.count)) 187 | for i, frame in enumerate(data): 188 | self.writer[str(self.count)].create_dataset(str(i), data=frame, dtype=dtype, compression="lzf") 189 | 190 | def add_data(self, data, dtype=None, return_curr_count=False): 191 | 192 | if self.video: 193 | self.add_video_data(data, dtype) 194 | else: 195 | NotImplementedError() 196 | 197 | curr_count = self.count 198 | self.count += 1 199 | 200 | if self.count == self.num_per_shard: 201 | self.create_new_shard() 202 | 203 | if return_curr_count: 204 | return curr_count 205 | 206 | def close(self): 207 | self.writer.close() 208 | assert len(self.shard_paths) 209 | 210 | 211 | if __name__ == "__main__": 212 | 213 | # Make 214 | h5_maker = HDF5Maker('EXPERIMENTS/h5', num_per_shard=10, force=True, video=True) 215 | 216 | a = [torch.zeros(60, 3, 64, 64)] * 20 217 | for data in a: 218 | h5_maker.add_data(data) 219 | 220 | h5_maker.close() 221 | 222 | # Read 223 | h5_ds = HDF5Dataset('EXPERIMENTS/h5') 224 | 225 | print(len(h5_ds)) # 20 226 | 227 | data = h5_ds[11] 228 | 229 | assert torch.all(data == a[11]) 230 | -------------------------------------------------------------------------------- /data/video_dataset.py: -------------------------------------------------------------------------------- 1 | # loading video dataset for training and testing 2 | import os 3 | import torch 4 | 5 | import numpy as np 6 | import torch.utils.data as data 7 | # from skimage.color import gray2rgb 8 | 9 | import cv2 10 | # import torchvision.transforms.functional as F 11 | from torchvision import transforms 12 | 13 | from data.h5 import HDF5Dataset 14 | 15 | from einops import rearrange, repeat 16 | 17 | def dataset2video(video): 18 | if len(video.shape) == 3: 19 | video = repeat(video, 't h w -> t c h w', c=3) 20 | elif video.shape[1] == 1: 21 | video = repeat(video, 't c h w -> t (n c) h w', n=3) 22 | else: 23 | video = rearrange(video, 't h w c -> t c h w') 24 | return video 25 | 26 | def dataset2videos(videos): 27 | if len(videos.shape) == 4: 28 | videos = repeat(videos, 'b t h w -> b t c h w', c=3) 29 | elif videos.shape[2] == 1: 30 | videos = repeat(videos, 'b t c h w -> b t (n c) h w', n=3) 31 | else: 32 | videos = rearrange(videos, 'b t h w c -> b t c h w') 33 | return videos 34 | 35 | def resize(im, desired_size, interpolation): 36 | old_size = im.shape[:2] 37 | ratio = float(desired_size)/max(old_size) 38 | new_size = tuple(int(x*ratio) for x in old_size) 39 | 40 | im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) 41 | delta_w = desired_size - new_size[1] 42 | delta_h = desired_size - new_size[0] 43 | top, bottom = delta_h//2, delta_h-(delta_h//2) 44 | left, right = delta_w//2, delta_w-(delta_w//2) 45 | 46 | color = [0, 0, 0] 47 | new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) 48 | 49 | return new_im 50 | 51 | class VideoDataset(data.Dataset): 52 | def 
__init__(self, 53 | data_dir, 54 | type='train', 55 | total_videos=-1, 56 | num_frames=40, 57 | image_size=64, 58 | random_time=True, 59 | # color_jitter=None, 60 | random_horizontal_flip=False 61 | ): 62 | super(VideoDataset, self).__init__() 63 | self.data_dir = data_dir 64 | self.type = type 65 | self.num_frames = num_frames 66 | self.image_size = image_size 67 | self.total_videos = total_videos 68 | self.random_time = random_time 69 | self.random_horizontal_flip = random_horizontal_flip 70 | # self.jitter = transforms.ColorJitter(hue=color_jitter) if color_jitter else None 71 | 72 | if "UCF" in self.data_dir: 73 | self.videos_ds = HDF5Dataset(self.data_dir) 74 | # Train 75 | # self.num_train_vids = 9624 76 | # self.num_test_vids = 3696 # -> 369 : https://arxiv.org/pdf/1511.05440.pdf takes every 10th test video 77 | with self.videos_ds.opener(self.videos_ds.shard_paths[0]) as f: 78 | self.num_train_vids = f['num_train'][()] 79 | self.num_test_vids = f['num_test'][()]//10 # https://arxiv.org/pdf/1511.05440.pdf takes every 10th test video 80 | else: 81 | self.videos_ds = HDF5Dataset(os.path.join(self.data_dir, type)) 82 | 83 | def __len__(self): 84 | if self.total_videos > 0: 85 | return self.total_videos 86 | else: 87 | if "UCF" in self.data_dir: 88 | return self.num_train_vids if self.type=='train' else self.num_test_vids 89 | else: 90 | return len(self.videos_ds) 91 | 92 | def len_of_vid(self, index): 93 | video_index = index % self.__len__() 94 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 95 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 96 | video_len = f['len'][str(idx_in_shard)][()] 97 | return video_len 98 | 99 | def max_index(self): 100 | if "UCF" in self.data_dir: 101 | return self.num_train_vids if self.type=='train' else self.num_test_vids 102 | else: 103 | return len(self.videos_ds) 104 | 105 | def __getitem__(self, index, time_idx=0): 106 | if "UCF" in self.data_dir: 107 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 108 | if not self.type=='train': 109 | video_index = video_index * 10 + self.num_train_vids # https://arxiv.org/pdf/1511.05440.pdf takes every 10th test video 110 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 111 | 112 | # random crop 113 | crop_c = np.random.randint(int(self.image_size/240*320) - self.image_size) if self.type=='train' else int((self.image_size/240*320 - self.image_size)/2) 114 | 115 | # random horizontal flip 116 | flip_p = np.random.randint(2) == 0 if self.random_horizontal_flip else 0 117 | 118 | # read data 119 | prefinals = [] 120 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 121 | total_num_frames = f['len'][str(idx_in_shard)][()] 122 | 123 | # sample frames 124 | if self.random_time and total_num_frames > self.num_frames: 125 | # sampling start frames 126 | time_idx = np.random.choice(total_num_frames - self.num_frames) 127 | # read frames 128 | for i in range(time_idx, min(time_idx + self.num_frames, total_num_frames)): 129 | img = f[str(idx_in_shard)][str(i)][()] 130 | arr = transforms.RandomHorizontalFlip(flip_p)(transforms.ToTensor()(img[:, crop_c:crop_c + self.image_size])) 131 | prefinals.append(arr) 132 | 133 | data = torch.stack(prefinals) 134 | data = rearrange(data, "t c h w -> t h w c") 135 | return data, video_index 136 | 137 | else: 138 | video_index = round(index / (self.__len__() - 1) * (self.max_index() - 1)) 139 | shard_idx, idx_in_shard = self.videos_ds.get_indices(video_index) 140 | 141 | 
prefinals = [] 142 | with self.videos_ds.opener(self.videos_ds.shard_paths[shard_idx]) as f: 143 | total_num_frames = f['len'][str(idx_in_shard)][()] 144 | 145 | # sample frames 146 | if self.random_time and total_num_frames > self.num_frames: 147 | # sampling start frames 148 | time_idx = np.random.choice(total_num_frames - self.num_frames) 149 | # read frames 150 | for i in range(time_idx, min(time_idx + self.num_frames, total_num_frames)): 151 | img = f[str(idx_in_shard)][str(i)][()] 152 | arr = torch.tensor(img)/255.0 153 | prefinals.append(arr) 154 | 155 | data = torch.stack(prefinals) 156 | return data, video_index 157 | 158 | def check_video_data_structure(): 159 | import mediapy as media 160 | 161 | # dataset_root = "/mnt/sda/hjy/fdm/CARLA_Town_01_h5" # u11 - xs 162 | dataset_root = "/home/ubuntu/zzc/data/video_prediction/UCF101/UCF101_h5" # u11 - xs 163 | 164 | dataset_type = 'train' 165 | train_dataset = VideoDataset(dataset_root, dataset_type) 166 | print(len(train_dataset)) 167 | print(train_dataset[10][0].shape) 168 | print(torch.min(train_dataset[10][0]), torch.max(train_dataset[10][0])) 169 | print(train_dataset[10][1]) 170 | 171 | # dataset_type = 'valid' 172 | dataset_type = 'test' 173 | test_dataset = VideoDataset(dataset_root, dataset_type, total_videos=256) 174 | print(len(test_dataset)) 175 | print(test_dataset[10][0].shape) 176 | print(torch.min(test_dataset[10][0]), torch.max(test_dataset[10][0])) 177 | print(test_dataset[10][1]) 178 | 179 | train_video = train_dataset[20][0] 180 | test_video = test_dataset[20][0] 181 | 182 | train_video = dataset2video(train_video) 183 | test_video = dataset2video(test_video) 184 | 185 | print(train_video.shape) 186 | print(test_video.shape) 187 | 188 | media.show_video(rearrange(train_video, 't c h w -> t h w c').numpy(),fps = 20) 189 | media.show_video(rearrange(test_video, 't c h w -> t h w c').numpy(),fps = 20) 190 | 191 | """ 192 | 479 193 | torch.Size([40, 64, 64]) 194 | tensor(0.0627) tensor(0.8078) 195 | 10 196 | 197 | or like 198 | 199 | 256 200 | torch.Size([30, 128, 128, 3]) 201 | tensor(0.) 
tensor(0.8863) 202 | 60 203 | 204 | """ 205 | 206 | def check_num_workers(): 207 | from time import time 208 | import multiprocessing as mp 209 | from torch.utils.data import DataLoader 210 | 211 | print(f"num of CPU: {mp.cpu_count()}") 212 | 213 | # dataset_root = "/mnt/rhdd/zzc/data/video_prediction/KTH/processed/" # u8 - xs 214 | dataset_root = "/mnt/sda/hjy/kth/processed/" # u11 - xs 215 | # dataset_root = "/mnt/sda/hjy/kth/kth_h5/" # u16 - 0.72s 216 | dataset_type = 'train' 217 | train_dataset = VideoDataset(dataset_root, dataset_type) 218 | 219 | for num_workers in range(8, 10, 2): 220 | train_dataloader = DataLoader( 221 | train_dataset, 222 | batch_size=32, 223 | shuffle=True, 224 | num_workers=num_workers, 225 | pin_memory=True, 226 | drop_last=False 227 | ) 228 | 229 | for _ in range(5): 230 | start = time() 231 | for _, _ in enumerate(train_dataloader, 0): 232 | pass 233 | end = time() 234 | print("Finish with:{} second, num_workers={}".format(end - start, num_workers)) 235 | 236 | if __name__ == "__main__": 237 | check_video_data_structure() 238 | # check_num_workers() 239 | 240 | 241 | -------------------------------------------------------------------------------- /extdm.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: extdm 3 | Version: 1.0.0 4 | Requires-Dist: torch 5 | Requires-Dist: numpy 6 | -------------------------------------------------------------------------------- /extdm.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | extdm.egg-info/PKG-INFO 4 | extdm.egg-info/SOURCES.txt 5 | extdm.egg-info/dependency_links.txt 6 | extdm.egg-info/requires.txt 7 | extdm.egg-info/top_level.txt -------------------------------------------------------------------------------- /extdm.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /extdm.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | -------------------------------------------------------------------------------- /extdm.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /metrics/calculate_fvd.py: -------------------------------------------------------------------------------- 1 | from metrics.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | 6 | def trans(x): 7 | # if greyscale images add channel 8 | if x.shape[-3] == 1: 9 | x = x.repeat(1, 1, 3, 1, 1) 10 | 11 | # permute BTCHW -> BCTHW 12 | x = x.permute(0, 2, 1, 3, 4) 13 | 14 | return x 15 | 16 | def calculate_fvd(videos1, videos2, device): 17 | print("calculate_fvd...") 18 | 19 | # videos [batch_size, timestamps, channel, h, w] 20 | 21 | assert videos1.shape == videos2.shape 22 | 23 | i3d = load_i3d_pretrained(device=device) 24 | fvd_results = [] 25 | 26 | # support grayscale input, if grayscale -> channel*3 27 | # BTCHW -> BCTHW 28 | # videos -> [batch_size, channel, timestamps, h, w] 29 | 30 | videos1 = trans(videos1) 31 | videos2 = trans(videos2) 32 | 33 | fvd_results = {} 34 | 35 | for clip_timestamp in tqdm(range(videos1.shape[-3])): 36 | 37 | # for 
calculate FVD, each clip_timestamp must >= 10 38 | if clip_timestamp < 10: 39 | continue 40 | 41 | # get a video clip 42 | # videos_clip [batch_size, channel, timestamps[:clip], h, w] 43 | videos_clip1 = videos1[:, :, : clip_timestamp] 44 | videos_clip2 = videos2[:, :, : clip_timestamp] 45 | 46 | # get FVD features 47 | feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device) 48 | feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device) 49 | 50 | # calculate FVD when timestamps[:clip] 51 | fvd_results[f'[{clip_timestamp}]'] = frechet_distance(feats1, feats2) 52 | 53 | result = { 54 | "fvd": fvd_results, 55 | "fvd_video_setting": videos1.shape, 56 | "fvd_video_setting_name": "batch_size, channel, time, heigth, width", 57 | } 58 | 59 | return result 60 | 61 | def get_feats(videos, device, mini_bs=10): 62 | i3d = load_i3d_pretrained(device=device) 63 | videos = trans(videos) 64 | feats = get_fvd_feats(videos, i3d=i3d, device=device, bs=mini_bs) 65 | return feats 66 | 67 | def calculate_fvd1(videos1, videos2, device, mini_bs=10): 68 | # assert videos1.shape == videos2.shape 69 | i3d = load_i3d_pretrained(device=device) 70 | videos1 = trans(videos1) 71 | videos2 = trans(videos2) 72 | feats1 = get_fvd_feats(videos1, i3d=i3d, device=device, bs=mini_bs) 73 | feats2 = get_fvd_feats(videos2, i3d=i3d, device=device, bs=mini_bs) 74 | return frechet_distance(feats1, feats2) 75 | 76 | def calculate_fvd2(feats1, feats2): 77 | return frechet_distance(feats1, feats2) 78 | 79 | # test code / using example 80 | 81 | def main(): 82 | NUMBER_OF_VIDEOS = 100 83 | VIDEO_LENGTH = 30 84 | CHANNEL = 3 85 | SIZE = 64 86 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 87 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 88 | device = torch.device("cuda") 89 | # device = torch.device("cpu") 90 | mini_bs=2 91 | 92 | print(calculate_fvd1(videos1,videos2,device, mini_bs=16)) 93 | 94 | if __name__ == "__main__": 95 | main() -------------------------------------------------------------------------------- /metrics/calculate_lpips.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | import torch 7 | import lpips 8 | 9 | spatial = True # Return a spatial map of perceptual distance. 
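# with spatial=True, loss_fn.forward returns a per-pixel distance map; it is reduced to a scalar with .mean() below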
10 | 11 | # Linearly calibrated models (LPIPS) 12 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg' 13 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg' 14 | 15 | def trans(x): 16 | # if greyscale images add channel 17 | if x.shape[-3] == 1: 18 | x = x.repeat(1, 1, 3, 1, 1) 19 | 20 | # value range [0, 1] -> [-1, 1] 21 | x = x * 2 - 1 22 | 23 | return x 24 | 25 | def calculate_lpips(videos1, videos2, device): 26 | # image should be RGB, IMPORTANT: normalized to [-1,1] 27 | # print("calculate_lpips...") 28 | 29 | assert videos1.shape == videos2.shape 30 | 31 | # videos [batch_size, timestamps, channel, h, w] 32 | 33 | # support grayscale input, if grayscale -> channel*3 34 | # value range [0, 1] -> [-1, 1] 35 | videos1 = trans(videos1) 36 | videos2 = trans(videos2) 37 | 38 | lpips_results = [] 39 | 40 | for video_num in range(videos1.shape[0]): 41 | # for video_num in tqdm(range(videos1.shape[0])): 42 | # get a video 43 | # video [timestamps, channel, h, w] 44 | video1 = videos1[video_num] 45 | video2 = videos2[video_num] 46 | 47 | lpips_results_of_a_video = [] 48 | for clip_timestamp in range(len(video1)): 49 | # get a img 50 | # img [timestamps[x], channel, h, w] 51 | # img [channel, h, w] tensor 52 | 53 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 54 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 55 | 56 | loss_fn.to(device) 57 | 58 | # calculate lpips of a video 59 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 60 | lpips_results.append(lpips_results_of_a_video) 61 | 62 | lpips = {} 63 | lpips_std = {} 64 | 65 | for clip_timestamp in range(len(video1)): 66 | lpips[f'avg[{clip_timestamp}]'] = np.mean(lpips_results[:,clip_timestamp]) 67 | lpips_std[f'std[{clip_timestamp}]'] = np.std(lpips_results[:,clip_timestamp]) 68 | 69 | result = { 70 | "lpips": lpips, 71 | "lpips_std": lpips_std, 72 | "lpips_video_setting": video1.shape, 73 | "lpips_video_setting_name": "time, channel, heigth, width", 74 | } 75 | 76 | return result 77 | 78 | def calculate_lpips1(videos1, videos2, device): 79 | assert videos1.shape == videos2.shape 80 | videos1 = trans(videos1) 81 | videos2 = trans(videos2) 82 | lpips_results = [] 83 | for video_num in range(videos1.shape[0]): 84 | # for video_num in tqdm(range(videos1.shape[0])): 85 | video1 = videos1[video_num] 86 | video2 = videos2[video_num] 87 | lpips_results_of_a_video = [] 88 | for clip_timestamp in range(len(video1)): 89 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 90 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 91 | loss_fn.to(device) 92 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 93 | lpips_results.append(lpips_results_of_a_video) 94 | lpips_results = np.array(lpips_results) 95 | return np.mean(lpips_results), np.std(lpips_results) 96 | 97 | def calculate_lpips2(videos1, videos2, device): 98 | assert videos1.shape == videos2.shape 99 | videos1 = trans(videos1) 100 | videos2 = trans(videos2) 101 | lpips_results = [] 102 | for video_num in range(videos1.shape[0]): 103 | # for video_num in tqdm(range(videos1.shape[0])): 104 | video1 = videos1[video_num] 105 | video2 = videos2[video_num] 106 | lpips_results_of_a_video = [] 107 | for clip_timestamp in range(len(video1)): 108 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 109 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 110 | loss_fn.to(device) 111 | 
lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 112 | lpips_results.append(lpips_results_of_a_video) 113 | lpips_results = np.array(lpips_results) 114 | # print(np.mean(lpips_results,axis=-1)) 115 | return np.min(np.mean(lpips_results,axis=-1)) 116 | 117 | def calculate_lpips3(videos1, videos2, device): 118 | assert videos1.shape == videos2.shape 119 | videos1 = trans(videos1) 120 | videos2 = trans(videos2) 121 | lpips_results = [] 122 | for video_num in range(videos1.shape[0]): 123 | # for video_num in tqdm(range(videos1.shape[0])): 124 | video1 = videos1[video_num] 125 | video2 = videos2[video_num] 126 | lpips_results_of_a_video = [] 127 | for clip_timestamp in range(len(video1)): 128 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 129 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 130 | loss_fn.to(device) 131 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 132 | lpips_results.append(lpips_results_of_a_video) 133 | lpips_results = np.array(lpips_results) 134 | # print(np.mean(lpips_results,axis=-1)) 135 | return np.mean(lpips_results,axis=-1) 136 | 137 | # test code / using example 138 | 139 | def main(): 140 | NUMBER_OF_VIDEOS = 8 141 | VIDEO_LENGTH = 20 142 | CHANNEL = 3 143 | SIZE = 64 144 | CALCULATE_PER_FRAME = 5 145 | CALCULATE_FINAL = True 146 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 147 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 148 | device = torch.device("cuda") 149 | # device = torch.device("cpu") 150 | 151 | import json 152 | result = calculate_lpips2(videos1, videos2, device) 153 | # print(json.dumps(result, indent=4)) 154 | print(result) 155 | 156 | if __name__ == "__main__": 157 | main() -------------------------------------------------------------------------------- /metrics/calculate_psnr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | def img_psnr(img1, img2): 7 | # [0,1] 8 | # compute mse 9 | # mse = np.mean((img1-img2)**2) 10 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2) 11 | # compute psnr 12 | if mse < 1e-10: 13 | return 100 14 | psnr = 20 * math.log10(1 / math.sqrt(mse)) 15 | return psnr 16 | 17 | def trans(x): 18 | return x 19 | 20 | def calculate_psnr(videos1, videos2): 21 | # print("calculate_psnr...") 22 | 23 | # videos [batch_size, timestamps, channel, h, w] 24 | 25 | assert videos1.shape == videos2.shape 26 | 27 | videos1 = trans(videos1) 28 | videos2 = trans(videos2) 29 | 30 | psnr_results = [] 31 | 32 | for video_num in range(videos1.shape[0]): 33 | # for video_num in tqdm(range(videos1.shape[0])): 34 | # get a video 35 | # video [timestamps, channel, h, w] 36 | video1 = videos1[video_num] 37 | video2 = videos2[video_num] 38 | 39 | psnr_results_of_a_video = [] 40 | for clip_timestamp in range(len(video1)): 41 | # get a img 42 | # img [timestamps[x], channel, h, w] 43 | # img [channel, h, w] numpy 44 | 45 | img1 = video1[clip_timestamp].cpu().numpy() 46 | img2 = video2[clip_timestamp].cpu().numpy() 47 | 48 | # calculate psnr of a video 49 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 50 | 51 | psnr_results.append(psnr_results_of_a_video) 52 | 53 | psnr_results = np.array(psnr_results) 54 | 55 | psnr = {} 56 | psnr_std = {} 57 | 58 | for clip_timestamp in range(len(video1)): 59 | psnr[f'avg[{clip_timestamp}]'] 
= np.mean(psnr_results[:,clip_timestamp]) 60 | psnr_std[f'std[{clip_timestamp}]'] = np.std(psnr_results[:,clip_timestamp]) 61 | 62 | result = { 63 | "psnr": psnr, 64 | "psnr_std": psnr_std, 65 | "psnr_video_setting": video1.shape, 66 | "psnr_video_setting_name": "time, channel, heigth, width", 67 | } 68 | 69 | return result 70 | 71 | def calculate_psnr1(videos1, videos2): 72 | assert videos1.shape == videos2.shape 73 | videos1 = trans(videos1) 74 | videos2 = trans(videos2) 75 | psnr_results = [] 76 | for video_num in range(videos1.shape[0]): 77 | # for video_num in tqdm(range(videos1.shape[0])): 78 | video1 = videos1[video_num] 79 | video2 = videos2[video_num] 80 | psnr_results_of_a_video = [] 81 | for clip_timestamp in range(len(video1)): 82 | img1 = video1[clip_timestamp].cpu().numpy() 83 | img2 = video2[clip_timestamp].cpu().numpy() 84 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 85 | psnr_results.append(psnr_results_of_a_video) 86 | psnr_results = np.array(psnr_results) 87 | return np.mean(psnr_results), np.std(psnr_results) 88 | 89 | def calculate_psnr2(videos1, videos2): 90 | assert videos1.shape == videos2.shape 91 | videos1 = trans(videos1) 92 | videos2 = trans(videos2) 93 | psnr_results = [] 94 | for video_num in range(videos1.shape[0]): 95 | # for video_num in tqdm(range(videos1.shape[0])): 96 | video1 = videos1[video_num] 97 | video2 = videos2[video_num] 98 | psnr_results_of_a_video = [] 99 | for clip_timestamp in range(len(video1)): 100 | img1 = video1[clip_timestamp].cpu().numpy() 101 | img2 = video2[clip_timestamp].cpu().numpy() 102 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 103 | psnr_results.append(psnr_results_of_a_video) 104 | psnr_results = np.array(psnr_results) 105 | # print(np.mean(psnr_results,axis=-1)) 106 | return np.max(np.mean(psnr_results,axis=-1)) 107 | 108 | def calculate_psnr3(videos1, videos2): 109 | assert videos1.shape == videos2.shape 110 | videos1 = trans(videos1) 111 | videos2 = trans(videos2) 112 | psnr_results = [] 113 | for video_num in range(videos1.shape[0]): 114 | # for video_num in tqdm(range(videos1.shape[0])): 115 | video1 = videos1[video_num] 116 | video2 = videos2[video_num] 117 | psnr_results_of_a_video = [] 118 | for clip_timestamp in range(len(video1)): 119 | img1 = video1[clip_timestamp].cpu().numpy() 120 | img2 = video2[clip_timestamp].cpu().numpy() 121 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 122 | psnr_results.append(psnr_results_of_a_video) 123 | psnr_results = np.array(psnr_results) 124 | # print(np.mean(psnr_results,axis=-1)) 125 | return psnr_results 126 | # test code / using example 127 | 128 | def main(): 129 | NUMBER_OF_VIDEOS = 8 130 | VIDEO_LENGTH = 50 131 | CHANNEL = 3 132 | SIZE = 64 133 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 134 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 135 | device = torch.device("cuda") 136 | 137 | import json 138 | result = calculate_psnr2(videos1, videos2) 139 | # print(json.dumps(result, indent=4)) 140 | print(result) 141 | 142 | 143 | if __name__ == "__main__": 144 | main() -------------------------------------------------------------------------------- /metrics/calculate_ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import cv2 5 | 6 | def ssim(img1, img2): 7 | C1 = 0.01 ** 2 8 | C2 = 0.03 ** 2 9 | img1 = img1.astype(np.float64) 10 | img2 = 
img2.astype(np.float64) 11 | kernel = cv2.getGaussianKernel(11, 1.5) 12 | window = np.outer(kernel, kernel.transpose()) 13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 15 | mu1_sq = mu1 ** 2 16 | mu2_sq = mu2 ** 2 17 | mu1_mu2 = mu1 * mu2 18 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 19 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 22 | (sigma1_sq + sigma2_sq + C2)) 23 | return ssim_map.mean() 24 | 25 | 26 | def calculate_ssim_function(img1, img2): 27 | # [0,1] 28 | # ssim is the only metric extremely sensitive to gray being compared to b/w 29 | if not img1.shape == img2.shape: 30 | raise ValueError('Input images must have the same dimensions.') 31 | if img1.ndim == 2: 32 | return ssim(img1, img2) 33 | elif img1.ndim == 3: 34 | if img1.shape[0] == 3: 35 | ssims = [] 36 | for i in range(3): 37 | ssims.append(ssim(img1[i], img2[i])) 38 | return np.array(ssims).mean() 39 | elif img1.shape[0] == 1: 40 | return ssim(np.squeeze(img1), np.squeeze(img2)) 41 | else: 42 | raise ValueError('Wrong input image dimensions.') 43 | 44 | def trans(x): 45 | return x 46 | 47 | def calculate_ssim(videos1, videos2): 48 | # print("calculate_ssim...") 49 | 50 | # videos [batch_size, timestamps, channel, h, w] 51 | 52 | assert videos1.shape == videos2.shape 53 | 54 | videos1 = trans(videos1) 55 | videos2 = trans(videos2) 56 | 57 | ssim_results = [] 58 | 59 | for video_num in range(videos1.shape[0]): 60 | # for video_num in tqdm(range(videos1.shape[0])): 61 | # get a video 62 | # video [timestamps, channel, h, w] 63 | video1 = videos1[video_num] 64 | video2 = videos2[video_num] 65 | 66 | ssim_results_of_a_video = [] 67 | for clip_timestamp in range(len(video1)): 68 | # get a img 69 | # img [timestamps[x], channel, h, w] 70 | # img [channel, h, w] numpy 71 | 72 | img1 = video1[clip_timestamp].cpu().numpy() 73 | img2 = video2[clip_timestamp].cpu().numpy() 74 | 75 | # calculate ssim of a video 76 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 77 | 78 | ssim_results.append(ssim_results_of_a_video) 79 | 80 | ssim_results = np.array(ssim_results) 81 | 82 | ssim = {} 83 | ssim_std = {} 84 | 85 | for clip_timestamp in range(len(video1)): 86 | ssim[f'avg[{clip_timestamp}]'] = np.mean(ssim_results[:,clip_timestamp]) 87 | ssim_std[f'std[{clip_timestamp}]'] = np.std(ssim_results[:,clip_timestamp]) 88 | 89 | result = { 90 | "ssim": ssim, 91 | "ssim_std": ssim_std, 92 | "ssim_video_setting": video1.shape, 93 | "ssim_video_setting_name": "time, channel, heigth, width", 94 | } 95 | 96 | return result 97 | 98 | def calculate_ssim1(videos1, videos2): 99 | assert videos1.shape == videos2.shape 100 | videos1 = trans(videos1) 101 | videos2 = trans(videos2) 102 | ssim_results = [] 103 | for video_num in range(videos1.shape[0]): 104 | # for video_num in tqdm(range(videos1.shape[0])): 105 | video1 = videos1[video_num] 106 | video2 = videos2[video_num] 107 | ssim_results_of_a_video = [] 108 | for clip_timestamp in range(len(video1)): 109 | img1 = video1[clip_timestamp].cpu().numpy() 110 | img2 = video2[clip_timestamp].cpu().numpy() 111 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 112 | ssim_results.append(ssim_results_of_a_video) 113 | ssim_results = np.array(ssim_results) 114 | return np.mean(ssim_results), 
np.std(ssim_results) 115 | 116 | def calculate_ssim2(videos1, videos2): 117 | assert videos1.shape == videos2.shape 118 | videos1 = trans(videos1) 119 | videos2 = trans(videos2) 120 | ssim_results = [] 121 | for video_num in range(videos1.shape[0]): 122 | # for video_num in tqdm(range(videos1.shape[0])): 123 | video1 = videos1[video_num] 124 | video2 = videos2[video_num] 125 | ssim_results_of_a_video = [] 126 | for clip_timestamp in range(len(video1)): 127 | img1 = video1[clip_timestamp].cpu().numpy() 128 | img2 = video2[clip_timestamp].cpu().numpy() 129 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 130 | ssim_results.append(ssim_results_of_a_video) 131 | ssim_results = np.array(ssim_results) 132 | # print(np.mean(ssim_results,axis=-1)) 133 | return np.max(np.mean(ssim_results,axis=-1)) 134 | 135 | # test code / using example 136 | 137 | def main(): 138 | NUMBER_OF_VIDEOS = 8 139 | VIDEO_LENGTH = 20 140 | CHANNEL = 3 141 | SIZE = 64 142 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 143 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 144 | device = torch.device("cuda") 145 | 146 | import json 147 | result = calculate_ssim2(videos1, videos2) 148 | # print(json.dumps(result, indent=4)) 149 | print(result) 150 | 151 | if __name__ == "__main__": 152 | main() -------------------------------------------------------------------------------- /metrics/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from metrics.calculate_fvd import calculate_fvd 3 | from metrics.calculate_psnr import calculate_psnr 4 | from metrics.calculate_ssim import calculate_ssim 5 | from metrics.calculate_lpips import calculate_lpips 6 | 7 | # ps: pixel value should be in [0, 1]! 
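# each videos tensor is shaped [batch_size, timestamps, channel, height, width]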
8 | 9 | NUMBER_OF_VIDEOS = 8 10 | VIDEO_LENGTH = 30 11 | CHANNEL = 3 12 | SIZE = 64 13 | CALCULATE_PER_FRAME = 8 14 | CALCULATE_FINAL = True 15 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 16 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 17 | device = torch.device("cuda") 18 | # device = torch.device("cpu") 19 | 20 | import json 21 | result = {} 22 | result['fvd'] = calculate_fvd(videos1, videos2, CALCULATE_PER_FRAME, CALCULATE_FINAL, device) 23 | result['ssim'] = calculate_ssim(videos1, videos2, CALCULATE_PER_FRAME, CALCULATE_FINAL) 24 | result['psnr'] = calculate_psnr(videos1, videos2, CALCULATE_PER_FRAME, CALCULATE_FINAL) 25 | result['lpips'] = calculate_lpips(videos1, videos2, CALCULATE_PER_FRAME, CALCULATE_FINAL, device) 26 | print(json.dumps(result, indent=4)) 27 | -------------------------------------------------------------------------------- /metrics/i3d_torchscript.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nku-zhichengzhang/ExtDM/daaae01e926b9b021c81676cbd4a17555b722aa9/metrics/i3d_torchscript.pt -------------------------------------------------------------------------------- /model/BaseDM_adaptor/text.py: -------------------------------------------------------------------------------- 1 | # the code from https://github.com/lucidrains/video-diffusion-pytorch 2 | import torch 3 | from einops import rearrange 4 | 5 | 6 | def exists(val): 7 | return val is not None 8 | 9 | 10 | # singleton globals 11 | MODEL = None 12 | TOKENIZER = None 13 | BERT_MODEL_DIM = 768 14 | 15 | 16 | def get_tokenizer(): 17 | global TOKENIZER 18 | if not exists(TOKENIZER): 19 | TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') 20 | return TOKENIZER 21 | 22 | 23 | def get_bert(): 24 | global MODEL 25 | if not exists(MODEL): 26 | MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased') 27 | if torch.cuda.is_available(): 28 | MODEL = MODEL.cuda() 29 | 30 | return MODEL 31 | 32 | 33 | # tokenize 34 | 35 | def tokenize(texts, add_special_tokens=True): 36 | if not isinstance(texts, (list, tuple)): 37 | texts = [texts] 38 | 39 | tokenizer = get_tokenizer() 40 | 41 | encoding = tokenizer.batch_encode_plus( 42 | texts, 43 | add_special_tokens=add_special_tokens, 44 | padding=True, 45 | return_tensors='pt' 46 | ) 47 | 48 | token_ids = encoding.input_ids 49 | return token_ids 50 | 51 | 52 | # embedding function 53 | 54 | @torch.no_grad() 55 | def bert_embed( 56 | token_ids, 57 | return_cls_repr=False, 58 | eps=1e-8, 59 | pad_id=0. 
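    # tokens equal to pad_id are masked out of the mean pooling below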
60 | ): 61 | model = get_bert() 62 | mask = token_ids != pad_id 63 | 64 | if torch.cuda.is_available(): 65 | token_ids = token_ids.cuda() 66 | mask = mask.cuda() 67 | 68 | outputs = model( 69 | input_ids=token_ids, 70 | attention_mask=mask, 71 | output_hidden_states=True 72 | ) 73 | 74 | hidden_state = outputs.hidden_states[-1] 75 | 76 | if return_cls_repr: 77 | return hidden_state[:, 0] # return [cls] as representation 78 | 79 | if not exists(mask): 80 | return hidden_state.mean(dim=1) 81 | 82 | mask = mask[:, 1:] # mean all tokens excluding [cls], accounting for length 83 | mask = rearrange(mask, 'b n -> b n 1') 84 | 85 | numer = (hidden_state[:, 1:] * mask).sum(dim=1) 86 | denom = mask.sum(dim=1) 87 | masked_mean = numer / (denom + eps) 88 | return masked_mean 89 | -------------------------------------------------------------------------------- /model/LFAE/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nku-zhichengzhang/ExtDM/daaae01e926b9b021c81676cbd4a17555b722aa9/model/LFAE/__init__.py -------------------------------------------------------------------------------- /model/LFAE/bg_motion_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 8 | """ 9 | 10 | from torch import nn 11 | import torch 12 | from model.LFAE.util import Encoder 13 | 14 | 15 | class BGMotionPredictor(nn.Module): 16 | """ 17 | Module for background estimation, return single transformation, parametrized as 3x3 matrix. 
18 | """ 19 | 20 | def __init__(self, block_expansion, num_channels, max_features, num_blocks, bg_type='zero'): 21 | super(BGMotionPredictor, self).__init__() 22 | assert bg_type in ['zero', 'shift', 'affine', 'perspective'] 23 | 24 | self.bg_type = bg_type 25 | if self.bg_type != 'zero': 26 | self.encoder = Encoder(block_expansion, in_features=num_channels * 2, max_features=max_features, 27 | num_blocks=num_blocks) 28 | in_features = min(max_features, block_expansion * (2 ** num_blocks)) 29 | if self.bg_type == 'perspective': 30 | self.fc = nn.Linear(in_features, 8) 31 | self.fc.weight.data.zero_() 32 | self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0], dtype=torch.float)) 33 | elif self.bg_type == 'affine': 34 | self.fc = nn.Linear(in_features, 6) 35 | self.fc.weight.data.zero_() 36 | self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) 37 | elif self.bg_type == 'shift': 38 | self.fc = nn.Linear(in_features, 2) 39 | self.fc.weight.data.zero_() 40 | self.fc.bias.data.copy_(torch.tensor([0, 0], dtype=torch.float)) 41 | 42 | def forward(self, source_image, driving_image): 43 | bs = source_image.shape[0] 44 | out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type()) 45 | if self.bg_type != 'zero': 46 | prediction = self.encoder(torch.cat([source_image, driving_image], dim=1)) 47 | prediction = prediction[-1].mean(dim=(2, 3)) 48 | prediction = self.fc(prediction) 49 | if self.bg_type == 'shift': 50 | out[:, :2, 2] = prediction 51 | elif self.bg_type == 'affine': 52 | out[:, :2, :] = prediction.view(bs, 2, 3) 53 | elif self.bg_type == 'perspective': 54 | out[:, :2, :] = prediction[:, :6].view(bs, 2, 3) 55 | out[:, 2, :2] = prediction[:, 6:].view(bs, 2) 56 | 57 | return out 58 | -------------------------------------------------------------------------------- /model/LFAE/flow_autoenc.py: -------------------------------------------------------------------------------- 1 | # utilize RegionMM to design a flow auto-encoder 2 | 3 | import torch 4 | import torch.nn as nn 5 | import yaml 6 | 7 | from model.LFAE.generator import Generator 8 | from model.LFAE.bg_motion_predictor import BGMotionPredictor 9 | from model.LFAE.region_predictor import RegionPredictor 10 | 11 | 12 | # based on RegionMM 13 | class FlowAE(nn.Module): 14 | def __init__(self, is_train=False, 15 | config=None): 16 | super(FlowAE, self).__init__() 17 | 18 | if '.yaml' in config: 19 | with open(config) as f: 20 | config = yaml.safe_load(f) 21 | 22 | model_params = config['flow_params']['model_params'] 23 | 24 | self.generator = Generator(num_regions=model_params['num_regions'], 25 | num_channels=model_params['num_channels'], 26 | revert_axis_swap=model_params['revert_axis_swap'], 27 | **model_params['generator_params']).cuda() 28 | self.region_predictor = RegionPredictor(num_regions=model_params['num_regions'], 29 | num_channels=model_params['num_channels'], 30 | estimate_affine=model_params['estimate_affine'], 31 | **model_params['region_predictor_params']).cuda() 32 | self.bg_predictor = BGMotionPredictor(num_channels=model_params['num_channels'], 33 | **model_params['bg_predictor_params']) 34 | 35 | self.is_train = is_train 36 | 37 | self.ref_img = None 38 | self.dri_img = None 39 | self.generated = None 40 | 41 | def forward(self): 42 | source_region_params = self.region_predictor(self.ref_img) 43 | self.driving_region_params = self.region_predictor(self.dri_img) 44 | 45 | bg_params = self.bg_predictor(self.ref_img, self.dri_img) 46 | self.generated = 
self.generator(self.ref_img, source_region_params=source_region_params, 47 | driving_region_params=self.driving_region_params, bg_params=bg_params) 48 | self.generated.update({'source_region_params': source_region_params, 49 | 'driving_region_params': self.driving_region_params}) 50 | def set_train_input(self, ref_img, dri_img): 51 | self.ref_img = ref_img.cuda() 52 | self.dri_img = dri_img.cuda() 53 | 54 | 55 | if __name__ == "__main__": 56 | # default image size is 128 57 | import os 58 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 59 | ref_img = torch.rand((5, 3, 128, 128), dtype=torch.float32) 60 | dri_img = torch.rand((5, 3, 128, 128), dtype=torch.float32) 61 | model = FlowAE(is_train=True).cuda() 62 | model.train() 63 | model.set_train_input(ref_img=ref_img, dri_img=dri_img) 64 | model.forward() 65 | 66 | -------------------------------------------------------------------------------- /model/LFAE/generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 8 | """ 9 | 10 | import torch 11 | from torch import nn 12 | import torch.nn.functional as F 13 | from model.LFAE.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d 14 | from model.LFAE.pixelwise_flow_predictor import PixelwiseFlowPredictor 15 | 16 | class Generator(nn.Module): 17 | """ 18 | Generator that given source image and region parameters try to transform image according to movement trajectories 19 | induced by region parameters. Generator follows Johnson architecture. 
20 | """ 21 | 22 | def __init__(self, num_channels, num_regions, block_expansion, max_features, num_down_blocks, 23 | num_bottleneck_blocks, pixelwise_flow_predictor_params=None, skips=False, revert_axis_swap=True): 24 | super(Generator, self).__init__() 25 | 26 | if pixelwise_flow_predictor_params is not None: 27 | self.pixelwise_flow_predictor = PixelwiseFlowPredictor(num_regions=num_regions, num_channels=num_channels, 28 | revert_axis_swap=revert_axis_swap, 29 | **pixelwise_flow_predictor_params) 30 | else: 31 | self.pixelwise_flow_predictor = None 32 | 33 | self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) 34 | 35 | down_blocks = [] 36 | for i in range(num_down_blocks): 37 | in_features = min(max_features, block_expansion * (2 ** i)) 38 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 39 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 40 | self.down_blocks = nn.ModuleList(down_blocks) 41 | 42 | up_blocks = [] 43 | for i in range(num_down_blocks): 44 | in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) 45 | out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) 46 | up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 47 | self.up_blocks = nn.ModuleList(up_blocks) 48 | 49 | self.bottleneck = torch.nn.Sequential() 50 | in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) 51 | for i in range(num_bottleneck_blocks): 52 | self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) 53 | 54 | self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) 55 | self.num_channels = num_channels 56 | self.skips = skips 57 | 58 | @staticmethod 59 | def deform_input(inp, optical_flow): 60 | _, h_old, w_old, _ = optical_flow.shape 61 | _, _, h, w = inp.shape 62 | if h_old != h or w_old != w: 63 | optical_flow = optical_flow.permute(0, 3, 1, 2) 64 | optical_flow = F.interpolate(optical_flow, size=(h, w), mode='bilinear') 65 | optical_flow = optical_flow.permute(0, 2, 3, 1) 66 | return F.grid_sample(inp, optical_flow) 67 | 68 | def apply_optical(self, input_previous=None, input_skip=None, motion_params=None): 69 | if motion_params is not None: 70 | if 'occlusion_map' in motion_params: 71 | occlusion_map = motion_params['occlusion_map'] 72 | else: 73 | occlusion_map = None 74 | deformation = motion_params['optical_flow'] 75 | input_skip = self.deform_input(input_skip, deformation) 76 | 77 | if occlusion_map is not None: 78 | if input_skip.shape[2] != occlusion_map.shape[2] or input_skip.shape[3] != occlusion_map.shape[3]: 79 | occlusion_map = F.interpolate(occlusion_map, size=input_skip.shape[2:], mode='bilinear') 80 | if input_previous is not None: 81 | input_skip = input_skip * occlusion_map + input_previous * (1 - occlusion_map) 82 | else: 83 | input_skip = input_skip * occlusion_map 84 | out = input_skip 85 | else: 86 | out = input_previous if input_previous is not None else input_skip 87 | return out 88 | 89 | def forward_bottle(self, source_image): 90 | out = self.first(source_image) 91 | skips = [out] 92 | for i in range(len(self.down_blocks)): 93 | out = self.down_blocks[i](out) 94 | skips.append(out) 95 | 96 | return out 97 | 98 | def forward(self, source_image, driving_region_params, source_region_params, bg_params=None): 99 | out = self.first(source_image) 100 | skips = [out] 101 | for i in 
range(len(self.down_blocks)): 102 | out = self.down_blocks[i](out) 103 | skips.append(out) 104 | 105 | output_dict = {} 106 | output_dict["bottle_neck_feat"] = out 107 | if self.pixelwise_flow_predictor is not None: 108 | motion_params = self.pixelwise_flow_predictor(source_image=source_image, 109 | driving_region_params=driving_region_params, 110 | source_region_params=source_region_params, 111 | bg_params=bg_params) 112 | output_dict["deformed"] = self.deform_input(source_image, motion_params['optical_flow']) 113 | output_dict["optical_flow"] = motion_params['optical_flow'] 114 | if 'occlusion_map' in motion_params: 115 | output_dict['occlusion_map'] = motion_params['occlusion_map'] 116 | else: 117 | motion_params = None 118 | 119 | out = self.apply_optical(input_previous=None, input_skip=out, motion_params=motion_params) 120 | 121 | out = self.bottleneck(out) 122 | for i in range(len(self.up_blocks)): 123 | if self.skips: 124 | out = self.apply_optical(input_skip=skips[-(i + 1)], input_previous=out, motion_params=motion_params) 125 | out = self.up_blocks[i](out) 126 | if self.skips: 127 | out = self.apply_optical(input_skip=skips[0], input_previous=out, motion_params=motion_params) 128 | out = self.final(out) 129 | out = torch.sigmoid(out) 130 | 131 | if self.skips: 132 | out = self.apply_optical(input_skip=source_image, input_previous=out, motion_params=motion_params) 133 | 134 | output_dict["prediction"] = out 135 | 136 | return output_dict 137 | 138 | def compute_fea(self, source_image): 139 | out = self.first(source_image) 140 | for i in range(len(self.down_blocks)): 141 | out = self.down_blocks[i](out) 142 | return out 143 | 144 | def forward_with_flow(self, source_image, optical_flow, occlusion_map): 145 | out = self.first(source_image) 146 | skips = [out] 147 | for i in range(len(self.down_blocks)): 148 | out = self.down_blocks[i](out) 149 | skips.append(out) 150 | 151 | output_dict = {} 152 | motion_params = {} 153 | motion_params["optical_flow"] = optical_flow 154 | motion_params["occlusion_map"] = occlusion_map 155 | output_dict["deformed"] = self.deform_input(source_image, motion_params['optical_flow']) 156 | 157 | out = self.apply_optical(input_previous=None, input_skip=out, motion_params=motion_params) 158 | 159 | out = self.bottleneck(out) 160 | for i in range(len(self.up_blocks)): 161 | if self.skips: 162 | out = self.apply_optical(input_skip=skips[-(i + 1)], input_previous=out, motion_params=motion_params) 163 | out = self.up_blocks[i](out) 164 | if self.skips: 165 | out = self.apply_optical(input_skip=skips[0], input_previous=out, motion_params=motion_params) 166 | 167 | 168 | ################################################################## 169 | # print("=====out=====", out.shape) 170 | 171 | # import os 172 | # folder_path = './output_tensor' # replace with the actual folder path 173 | 174 | # # collect the indices of the files already in the folder 175 | # existing_indices = [int(file.split('_')[-1].split('.')[0]) for file in os.listdir(folder_path) if file.startswith('output_')] 176 | 177 | # # find the largest index; use 0 if the folder is empty 178 | # max_index = max(existing_indices) if existing_indices else 0 179 | 180 | # # build the new file's index and filename 181 | # new_index = max_index + 1 182 | # output_filename = f'output_{new_index}.pt' 183 | # output_path = os.path.join(folder_path, output_filename) 184 | 185 | # # save the tensor to disk 186 | # torch.save(out, output_path) 187 | 188 | # print(f"saved {output_filename}, new index {new_index}") 189 | #################################################################### 190 | out = self.final(out) 191 | out = torch.sigmoid(out) 192 
| 193 | if self.skips: 194 | out = self.apply_optical(input_skip=source_image, input_previous=out, motion_params=motion_params) 195 | 196 | output_dict["prediction"] = out 197 | 198 | return output_dict 199 | -------------------------------------------------------------------------------- /model/LFAE/pixelwise_flow_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 8 | """ 9 | 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import torch 13 | from model.LFAE.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, region2gaussian 14 | from model.LFAE.util import to_homogeneous, from_homogeneous 15 | 16 | 17 | class PixelwiseFlowPredictor(nn.Module): 18 | """ 19 | Module that predicts a pixelwise flow from sparse motion representation given by 20 | source_region_params and driving_region_params 21 | """ 22 | 23 | def __init__(self, block_expansion, num_blocks, max_features, num_regions, num_channels, 24 | estimate_occlusion_map=False, scale_factor=1, region_var=0.01, 25 | use_covar_heatmap=False, use_deformed_source=True, revert_axis_swap=False): 26 | super(PixelwiseFlowPredictor, self).__init__() 27 | self.hourglass = Hourglass(block_expansion=block_expansion, 28 | in_features=(num_regions + 1) * (num_channels * use_deformed_source + 1), 29 | max_features=max_features, num_blocks=num_blocks) 30 | 31 | self.mask = nn.Conv2d(self.hourglass.out_filters, num_regions + 1, kernel_size=(7, 7), padding=(3, 3)) 32 | 33 | if estimate_occlusion_map: 34 | self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) 35 | else: 36 | self.occlusion = None 37 | 38 | self.num_regions = num_regions 39 | self.scale_factor = scale_factor 40 | self.region_var = region_var 41 | self.use_covar_heatmap = use_covar_heatmap 42 | self.use_deformed_source = use_deformed_source 43 | self.revert_axis_swap = revert_axis_swap 44 | 45 | if self.scale_factor != 1: 46 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 47 | 48 | def create_heatmap_representations(self, source_image, driving_region_params, source_region_params): 49 | """ 50 | Eq 6. 
in the paper H_k(z) 51 | """ 52 | spatial_size = source_image.shape[2:] 53 | covar = self.region_var if not self.use_covar_heatmap else driving_region_params['covar'] 54 | gaussian_driving = region2gaussian(driving_region_params['shift'], covar=covar, spatial_size=spatial_size) 55 | covar = self.region_var if not self.use_covar_heatmap else source_region_params['covar'] 56 | gaussian_source = region2gaussian(source_region_params['shift'], covar=covar, spatial_size=spatial_size) 57 | 58 | heatmap = gaussian_driving - gaussian_source 59 | 60 | # adding background feature 61 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]) 62 | heatmap = torch.cat([zeros.type(heatmap.type()), heatmap], dim=1) 63 | heatmap = heatmap.unsqueeze(2) 64 | return heatmap 65 | 66 | def create_sparse_motions(self, source_image, driving_region_params, source_region_params, bg_params=None): 67 | bs, _, h, w = source_image.shape 68 | identity_grid = make_coordinate_grid((h, w), type=source_region_params['shift'].type()) 69 | identity_grid = identity_grid.view(1, 1, h, w, 2) 70 | coordinate_grid = identity_grid - driving_region_params['shift'].view(bs, self.num_regions, 1, 1, 2) 71 | if 'affine' in driving_region_params: 72 | affine = torch.matmul(source_region_params['affine'], torch.inverse(driving_region_params['affine'])) 73 | if self.revert_axis_swap: 74 | affine = affine * torch.sign(affine[:, :, 0:1, 0:1]) 75 | affine = affine.unsqueeze(-3).unsqueeze(-3) 76 | affine = affine.repeat(1, 1, h, w, 1, 1) 77 | coordinate_grid = torch.matmul(affine, coordinate_grid.unsqueeze(-1)) 78 | coordinate_grid = coordinate_grid.squeeze(-1) 79 | 80 | driving_to_source = coordinate_grid + source_region_params['shift'].view(bs, self.num_regions, 1, 1, 2) 81 | 82 | # adding background feature 83 | if bg_params is None: 84 | bg_grid = identity_grid.repeat(bs, 1, 1, 1, 1) 85 | else: 86 | bg_grid = identity_grid.repeat(bs, 1, 1, 1, 1) 87 | bg_grid = to_homogeneous(bg_grid) 88 | bg_grid = torch.matmul(bg_params.view(bs, 1, 1, 1, 3, 3), bg_grid.unsqueeze(-1)).squeeze(-1) 89 | bg_grid = from_homogeneous(bg_grid) 90 | 91 | sparse_motions = torch.cat([bg_grid, driving_to_source], dim=1) 92 | 93 | return sparse_motions 94 | 95 | def create_deformed_source_image(self, source_image, sparse_motions): 96 | bs, _, h, w = source_image.shape 97 | source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_regions + 1, 1, 1, 1, 1) 98 | source_repeat = source_repeat.view(bs * (self.num_regions + 1), -1, h, w) 99 | sparse_motions = sparse_motions.view((bs * (self.num_regions + 1), h, w, -1)) 100 | sparse_deformed = F.grid_sample(source_repeat, sparse_motions) 101 | sparse_deformed = sparse_deformed.view((bs, self.num_regions + 1, -1, h, w)) 102 | return sparse_deformed 103 | 104 | def forward(self, source_image, driving_region_params, source_region_params, bg_params=None): 105 | if self.scale_factor != 1: 106 | source_image = self.down(source_image) 107 | 108 | bs, _, h, w = source_image.shape 109 | 110 | out_dict = dict() 111 | heatmap_representation = self.create_heatmap_representations(source_image, driving_region_params, 112 | source_region_params) 113 | sparse_motion = self.create_sparse_motions(source_image, driving_region_params, 114 | source_region_params, bg_params=bg_params) 115 | deformed_source = self.create_deformed_source_image(source_image, sparse_motion) 116 | if self.use_deformed_source: 117 | predictor_input = torch.cat([heatmap_representation, deformed_source], dim=2) 118 | else: 119 | 
predictor_input = heatmap_representation 120 | predictor_input = predictor_input.view(bs, -1, h, w) 121 | 122 | prediction = self.hourglass(predictor_input) 123 | 124 | mask = self.mask(prediction) 125 | mask = F.softmax(mask, dim=1) 126 | mask = mask.unsqueeze(2) 127 | sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) 128 | deformation = (sparse_motion * mask).sum(dim=1) 129 | deformation = deformation.permute(0, 2, 3, 1) 130 | 131 | out_dict['optical_flow'] = deformation 132 | 133 | if self.occlusion: 134 | occlusion_map = torch.sigmoid(self.occlusion(prediction)) 135 | out_dict['occlusion_map'] = occlusion_map 136 | 137 | return out_dict 138 | -------------------------------------------------------------------------------- /model/LFAE/region_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 8 | """ 9 | 10 | from torch import nn 11 | import torch 12 | import torch.nn.functional as F 13 | from model.LFAE.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Encoder 14 | 15 | 16 | def svd(covar, fast=False): 17 | if fast: 18 | from torch_batch_svd import svd as fast_svd 19 | return fast_svd(covar) 20 | else: 21 | u, s, v = torch.svd(covar.cpu()) 22 | s = s.to(covar.device) 23 | u = u.to(covar.device) 24 | v = v.to(covar.device) 25 | return u, s, v 26 | 27 | 28 | class RegionPredictor(nn.Module): 29 | """ 30 | Region estimating. Estimate affine parameters of the region. 
31 | """ 32 | 33 | def __init__(self, block_expansion, num_regions, num_channels, max_features, 34 | num_blocks, temperature, estimate_affine=False, scale_factor=1, 35 | pca_based=False, fast_svd=False, pad=3): 36 | super(RegionPredictor, self).__init__() 37 | self.predictor = Hourglass(block_expansion, in_features=num_channels, 38 | max_features=max_features, num_blocks=num_blocks) 39 | 40 | self.regions = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_regions, kernel_size=(7, 7), 41 | padding=pad) 42 | 43 | # FOMM-like regression based representation 44 | if estimate_affine and not pca_based: 45 | self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, 46 | out_channels=4, kernel_size=(7, 7), padding=pad) 47 | self.jacobian.weight.data.zero_() 48 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1], dtype=torch.float)) 49 | else: 50 | self.jacobian = None 51 | 52 | self.temperature = temperature 53 | self.scale_factor = scale_factor 54 | self.pca_based = pca_based 55 | self.fast_svd = fast_svd 56 | 57 | if self.scale_factor != 1: 58 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 59 | 60 | def region2affine(self, region): 61 | shape = region.shape 62 | region = region.unsqueeze(-1) 63 | grid = make_coordinate_grid(shape[2:], region.type()).unsqueeze_(0).unsqueeze_(0) 64 | mean = (region * grid).sum(dim=(2, 3)) 65 | region_params = {'shift': mean} 66 | 67 | if self.pca_based: 68 | mean_sub = grid - mean.unsqueeze(-2).unsqueeze(-2) 69 | covar = torch.matmul(mean_sub.unsqueeze(-1), mean_sub.unsqueeze(-2)) 70 | covar = covar * region.unsqueeze(-1) 71 | covar = covar.sum(dim=(2, 3)) 72 | region_params['covar'] = covar 73 | 74 | return region_params 75 | 76 | def forward(self, x): 77 | if self.scale_factor != 1: 78 | x = self.down(x) 79 | 80 | feature_map = self.predictor(x) 81 | # extract the feature map 82 | prediction = self.regions(feature_map) 83 | # predict a num_regions-channel map from the feature map 84 | 85 | final_shape = prediction.shape 86 | region = prediction.view(final_shape[0], final_shape[1], -1) 87 | region = F.softmax(region / self.temperature, dim=2) 88 | region = region.view(*final_shape) 89 | # produce num_regions heatmaps 90 | 91 | # print("produce num_regions heatmaps") 92 | # print(region.shape, torch.min(region), torch.max(region)) 93 | 94 | region_params = self.region2affine(region) 95 | region_params['heatmap'] = region 96 | 97 | # Regression-based estimation 98 | if self.jacobian is not None: 99 | jacobian_map = self.jacobian(feature_map) 100 | jacobian_map = jacobian_map.reshape(final_shape[0], 1, 4, final_shape[2], 101 | final_shape[3]) 102 | region = region.unsqueeze(2) 103 | 104 | jacobian = region * jacobian_map 105 | jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) 106 | jacobian = jacobian.sum(dim=-1) 107 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) 108 | region_params['affine'] = jacobian 109 | region_params['covar'] = torch.matmul(jacobian, jacobian.permute(0, 1, 3, 2)) 110 | elif self.pca_based: 111 | covar = region_params['covar'] 112 | shape = covar.shape 113 | covar = covar.view(-1, 2, 2) 114 | u, s, v = svd(covar, self.fast_svd) 115 | d = torch.diag_embed(s ** 0.5) 116 | sqrt = torch.matmul(u, d) 117 | sqrt = sqrt.view(*shape) 118 | region_params['affine'] = sqrt 119 | region_params['u'] = u 120 | region_params['d'] = d 121 | 122 | return region_params 123 | -------------------------------------------------------------------------------- /model/LFAE/sync_batchnorm/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /model/LFAE/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 
90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /model/LFAE/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /model/LFAE/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /scripts/AE/run.py: -------------------------------------------------------------------------------- 1 | # Estimate flow and occlusion mask via RegionMM for MHAD dataset 2 | # this code is based on RegionMM from Snap Inc. 
3 | # https://github.com/snap-research/articulated-animation 4 | 5 | import os 6 | import sys 7 | import math 8 | import yaml 9 | from argparse import ArgumentParser 10 | from shutil import copy 11 | 12 | import wandb 13 | import datetime 14 | 15 | from model.LFAE.generator import Generator 16 | from model.LFAE.bg_motion_predictor import BGMotionPredictor 17 | from model.LFAE.region_predictor import RegionPredictor 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | import numpy as np 22 | import random 23 | 24 | from train import train 25 | 26 | from utils.logger import Logger 27 | from utils.seed import setup_seed 28 | 29 | 30 | if __name__ == "__main__": 31 | cudnn.enabled = True 32 | cudnn.benchmark = True 33 | 34 | if sys.version_info[0] < 3: 35 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") 36 | 37 | parser = ArgumentParser() 38 | parser.add_argument("--postfix", default="") 39 | parser.add_argument("--config",default="./config/smmnist64.yaml",help="path to config") 40 | parser.add_argument("--log_dir",default='./logs_training/diffusion', help="path to log into") 41 | parser.add_argument("--checkpoint", # use the pretrained Taichi model provided by Snap 42 | default="", 43 | help="path to checkpoint to restore") 44 | parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))), 45 | help="Names of the devices comma separated.") 46 | 47 | parser.add_argument("--random-seed", default=1234) 48 | parser.add_argument("--set-start", default=False) 49 | parser.add_argument("--mode", default="train", choices=["train"]) 50 | parser.add_argument("--verbose", default=False, help="Print model architecture") 51 | 52 | args = parser.parse_args() 53 | 54 | setup_seed(int(args.random_seed)) 55 | 56 | with open(args.config) as f: 57 | config = yaml.safe_load(f) 58 | 59 | if args.postfix == '': 60 | postfix = '' 61 | else: 62 | postfix = '_' + args.postfix 63 | 64 | log_dir = os.path.join(args.log_dir, os.path.basename(args.config).split('.')[0]+postfix) 65 | if not os.path.exists(log_dir): 66 | os.makedirs(log_dir) 67 | if not os.path.exists(os.path.join(log_dir, os.path.basename(args.config))): 68 | copy(args.config, log_dir) 69 | 70 | # the directory to save checkpoints 71 | config["snapshots"] = os.path.join(log_dir, 'snapshots') 72 | os.makedirs(config["snapshots"], exist_ok=True) 73 | # the directory to save images of training results 74 | config["imgshots"] = os.path.join(log_dir, 'imgshots') 75 | os.makedirs(config["imgshots"], exist_ok=True) 76 | 77 | config["set_start"] = args.set_start 78 | 79 | train_params = config['flow_params']['train_params'] 80 | model_params = config['flow_params']['model_params'] 81 | dataset_params = config['dataset_params'] 82 | 83 | log_txt = os.path.join(log_dir, 84 | "B"+format(train_params['batch_size'], "04d")+ 85 | "E"+format(train_params['max_epochs'], "04d")+".log") 86 | sys.stdout = Logger(log_txt, sys.stdout) 87 | 88 | wandb.login() 89 | wandb.init( 90 | entity="nku428", 91 | project="EDM_v1", 92 | config=config, 93 | name=f"{config['experiment_name']}{postfix}", 94 | dir=log_dir, 95 | tags=["flow"] 96 | ) 97 | 98 | print("postfix:", postfix) 99 | print("checkpoint:", args.checkpoint) 100 | print("batch size:", train_params['batch_size']) 101 | 102 | generator = Generator(num_regions=model_params['num_regions'], 103 | num_channels=model_params['num_channels'], 104 | revert_axis_swap=model_params['revert_axis_swap'], 105 | **model_params['generator_params']) 106 | 
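# Note: generator, region_predictor and bg_predictor together make up the flow
# autoencoder (RegionMM); the snapshot trained here appears to be what the DM
# scripts later load through --flowae_checkpoint (see scripts/DM/run.py).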
107 | if torch.cuda.is_available(): 108 | generator.to(args.device_ids[0]) 109 | if args.verbose: 110 | print(generator) 111 | 112 | region_predictor = RegionPredictor(num_regions=model_params['num_regions'], 113 | num_channels=model_params['num_channels'], 114 | estimate_affine=model_params['estimate_affine'], 115 | **model_params['region_predictor_params']) 116 | 117 | if torch.cuda.is_available(): 118 | region_predictor.to(args.device_ids[0]) 119 | 120 | if args.verbose: 121 | print(region_predictor) 122 | 123 | bg_predictor = BGMotionPredictor(num_channels=model_params['num_channels'], 124 | **model_params['bg_predictor_params']) 125 | if torch.cuda.is_available(): 126 | bg_predictor.to(args.device_ids[0]) 127 | if args.verbose: 128 | print(bg_predictor) 129 | 130 | print("Training...") 131 | train( 132 | config, 133 | dataset_params, 134 | train_params, 135 | generator, 136 | region_predictor, 137 | bg_predictor, 138 | log_dir, 139 | args.checkpoint, 140 | args.device_ids 141 | ) -------------------------------------------------------------------------------- /scripts/AE/train_AE_bair.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/train_AE_bair.sh 2 | 3 | # Training from scratch 4 | python ./scripts/AE/run.py \ 5 | --config ./config/AE/bair.yaml \ 6 | --log_dir ./logs_training/AE/BAIR \ 7 | --device_ids 0,1 \ 8 | --postfix test 9 | 10 | # Resuming training from checkpoint 11 | # --checkpoint ./logs_training/AE//snapshots/RegionMM.pth \ 12 | # --set-start True -------------------------------------------------------------------------------- /scripts/AE/train_AE_cityscapes.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/train_AE_cityscapes.sh 2 | 3 | # Training from scratch 4 | python ./scripts/AE/run.py \ 5 | --config ./config/AE/cityscapes.yaml \ 6 | --log_dir ./logs_training/AE/cityscapes \ 7 | --device_ids 0,1 \ 8 | --postfix test 9 | 10 | # Resuming training from checkpoint 11 | # --checkpoint ./logs_training/AE//snapshots/RegionMM.pth \ 12 | # --set-start True -------------------------------------------------------------------------------- /scripts/AE/train_AE_kth.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/train_AE_kth.sh 2 | 3 | # Training from scratch 4 | python ./scripts/AE/run.py \ 5 | --config ./config/AE/kth.yaml \ 6 | --log_dir ./logs_training/AE/KTH \ 7 | --device_ids 0,1 \ 8 | --postfix test 9 | 10 | # Resuming training from checkpoint 11 | # --checkpoint ./logs_training/AE//snapshots/RegionMM.pth \ 12 | # --set-start True -------------------------------------------------------------------------------- /scripts/AE/train_AE_smmnist.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/train_AE_smmnist.sh 2 | 3 | # Training from scratch 4 | python ./scripts/AE/run.py \ 5 | --config ./config/AE/smmnist.yaml \ 6 | --log_dir ./logs_training/AE/SMMNIST \ 7 | --device_ids 0,1 \ 8 | --postfix test 9 | 10 | # Resuming training from checkpoint 11 | # --checkpoint ./logs_training/AE//snapshots/RegionMM.pth \ 12 | # --set-start True -------------------------------------------------------------------------------- /scripts/AE/train_AE_ucf.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/train_AE_ucf.sh 2 | 3 | # Training from scratch 4 | python ./scripts/AE/run.py \ 5 | --config 
./config/AE/ucf.yaml \ 6 | --log_dir ./logs_training/AE/UCF \ 7 | --device_ids 0,1 \ 8 | --postfix test 9 | 10 | # Resuming training from checkpoint 11 | # --checkpoint ./logs_training/AE//snapshots/RegionMM.pth \ 12 | # --set-start True -------------------------------------------------------------------------------- /scripts/AE/valid_AE_bair.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/valid_AE_bair.sh 2 | 3 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 4 | pretrained_path=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 5 | 6 | python ./scripts/AE/valid.py \ 7 | --cond_frames 2 \ 8 | --pred_frames 28 \ 9 | --num_videos 256 \ 10 | --batch_size 256 \ 11 | --input_size 64 \ 12 | --log_dir "./logs_validation/AE/BAIR/BAIR_test" \ 13 | --data_dir $data_path/bair_h5 \ 14 | --config_path "$pretrained_path/BAIR/bair64_scale0.50/bair64.yaml" \ 15 | --restore_from $pretrained_path/BAIR/bair64_scale0.50/snapshots/RegionMM.pth \ 16 | --data_type "test" \ 17 | --save-video True \ 18 | --random-seed 1000 \ 19 | --gpu "0" -------------------------------------------------------------------------------- /scripts/AE/valid_AE_cityscapes.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/valid_AE_cityscapes.sh 2 | 3 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 4 | pretrained_path=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 5 | 6 | python ./scripts/AE/valid.py \ 7 | --cond_frames 2 \ 8 | --pred_frames 28 \ 9 | --num_videos 256 \ 10 | --batch_size 256 \ 11 | --input_size 128 \ 12 | --log_dir "./logs_validation/AE/Cityscapes/cityscapes_test" \ 13 | --data_dir $data_path/cityscapes_h5 \ 14 | --config_path "$pretrained_path/Cityscapes/cityscapes128_perspective/cityscapes128.yaml" \ 15 | --restore_from $pretrained_path/Cityscapes/cityscapes128_perspective/snapshots/RegionMM.pth \ 16 | --data_type "val" \ 17 | --save-video True \ 18 | --random-seed 1000 \ 19 | --gpu "0" -------------------------------------------------------------------------------- /scripts/AE/valid_AE_kth.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/valid_AE_kth.sh 2 | 3 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 4 | pretrained_path=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 5 | 6 | python ./scripts/AE/valid.py \ 7 | --cond_frames 1 \ 8 | --pred_frames 10 \ 9 | --num_videos 256 \ 10 | --batch_size 2 \ 11 | --input_size 64 \ 12 | --log_dir "./logs_validation/AE/KTH/KTH_test" \ 13 | --data_dir $data_path/kth_h5 \ 14 | --config_path "$pretrained_path/KTH/kth64_region10_res0.5/kth64_origin.yaml" \ 15 | --restore_from $pretrained_path/KTH/kth64_region10_res0.5/snapshots/RegionMM.pth \ 16 | --data_type "valid" \ 17 | --save-video True \ 18 | --random-seed 1000 \ 19 | --gpu "0" -------------------------------------------------------------------------------- /scripts/AE/valid_AE_smmnist.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/valid_AE_smmnist.sh 2 | 3 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 4 | pretrained_path=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 5 | 6 | python ./scripts/AE/valid.py \ 7 | --cond_frames 1 \ 8 | --pred_frames 10 \ 9 | --num_videos 256 \ 10 | --batch_size 2 \ 11 | --input_size 64 \ 12 | --log_dir "./logs_validation/AE/SMMNIST/SMMNIST_test" \ 13 | --data_dir $data_path/smmnist_h5 \ 14 | 
--config_path "$pretrained_path/SMMNIST/smmnist64_FlowAE_Batch100_lr2e-4_Region10_affine_scale0.50/smmnist64.yaml" \ 15 | --restore_from $pretrained_path/SMMNIST/smmnist64_FlowAE_Batch100_lr2e-4_Region10_affine_scale0.50/snapshots/RegionMM.pth \ 16 | --data_type "test" \ 17 | --save-video True \ 18 | --random-seed 1000 \ 19 | --gpu "0" -------------------------------------------------------------------------------- /scripts/AE/valid_AE_ucf.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/AE/valid_AE_ucf.sh 2 | 3 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 4 | pretrained_path=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 5 | 6 | python ./scripts/AE/valid.py \ 7 | --cond_frames 1 \ 8 | --pred_frames 10 \ 9 | --num_videos 256 \ 10 | --batch_size 2 \ 11 | --input_size 64 \ 12 | --log_dir "./logs_validation/AE/UCF/ucf101_test" \ 13 | --data_dir $data_path/UCF101_h5 \ 14 | --config_path "$pretrained_path/UCF101/ucf101_64_FlowAE_Batch100_lr2e-4_Region64_scale0.5/ucf101_64.yaml" \ 15 | --restore_from $pretrained_path/UCF101/ucf101_64_FlowAE_Batch100_lr2e-4_Region64_scale0.5/snapshots/RegionMM_0100_S120000_270.85.pth \ 16 | --data_type "test" \ 17 | --save-video True \ 18 | --random-seed 1000 \ 19 | --gpu "0" -------------------------------------------------------------------------------- /scripts/DM/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import imageio 4 | import torch 5 | from torch.utils import data 6 | import numpy as np 7 | import torch.backends.cudnn as cudnn 8 | import os 9 | import yaml 10 | from shutil import copy 11 | from train import train 12 | 13 | from utils.seed import setup_seed 14 | from utils.logger import Logger 15 | 16 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 17 | 18 | import wandb 19 | import os.path as osp 20 | import timeit 21 | import math 22 | from PIL import Image 23 | import sys 24 | import random 25 | from einops import rearrange 26 | 27 | if __name__ == '__main__': 28 | cudnn.enabled = True 29 | cudnn.benchmark = True 30 | 31 | if sys.version_info[0] < 3: 32 | raise Exception("You must use Python 3 or higher. 
Recommended version is Python 3.7") 33 | 34 | parser = argparse.ArgumentParser(description="Flow Diffusion") 35 | parser.add_argument("--postfix", default="") 36 | parser.add_argument("--fine-tune", default=False) 37 | parser.add_argument("--set-start", default=True) 38 | parser.add_argument("--start-step", default=0, type=int) 39 | parser.add_argument("--log_dir",default='./logs_training/diffusion', help="path to log into") 40 | parser.add_argument("--config",default="./config/smmnist64.yaml",help="path to config") 41 | parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))), 42 | help="Names of the devices comma separated.") 43 | parser.add_argument("--random-seed", type=int, default=1234, 44 | help="Random seed to have reproducible results.") 45 | parser.add_argument("--checkpoint", default="") 46 | parser.add_argument("--flowae_checkpoint", # use the flowae_checkpoint pretrained model provided by Snap 47 | default="/mnt/rhdd/zzc/data/video_prediction/flow_pretrained/better/smmnist64/snapshots/RegionMM.pth", 48 | help="path to flowae_checkpoint checkpoint") 49 | parser.add_argument("--verbose", default=False, help="Print model architecture") 50 | parser.add_argument("--fp16", default=False) 51 | 52 | args = parser.parse_args() 53 | 54 | setup_seed(int(args.random_seed)) 55 | 56 | with open(args.config) as f: 57 | config = yaml.safe_load(f) 58 | 59 | if args.postfix == '': 60 | postfix = '' 61 | else: 62 | postfix = '_' + args.postfix 63 | 64 | log_dir = os.path.join(args.log_dir, os.path.basename(args.config).split('.')[0]+postfix) 65 | if not os.path.exists(log_dir): 66 | os.makedirs(log_dir) 67 | if not os.path.exists(os.path.join(log_dir, os.path.basename(args.config))): 68 | copy(args.config, log_dir) 69 | 70 | # the directory to save checkpoints 71 | config["snapshots"] = os.path.join(log_dir, 'snapshots') 72 | os.makedirs(config["snapshots"], exist_ok=True) 73 | # the directory to save images of training results 74 | config["imgshots"] = os.path.join(log_dir, 'imgshots') 75 | os.makedirs(config["imgshots"], exist_ok=True) 76 | # vidshots 77 | config["vidshots"] = os.path.join(log_dir, 'vidshots') 78 | os.makedirs(config["vidshots"], exist_ok=True) 79 | # samples 80 | config["samples"] = os.path.join(log_dir, 'samples') 81 | os.makedirs(config["samples"], exist_ok=True) 82 | 83 | config["set_start"] = args.set_start 84 | 85 | train_params = config['diffusion_params']['train_params'] 86 | model_params = config['diffusion_params']['model_params'] 87 | dataset_params = config['dataset_params'] 88 | 89 | log_txt = os.path.join(log_dir, 90 | "B"+format(train_params['batch_size'], "04d")+ 91 | "E"+format(train_params['max_epochs'], "04d")+".log") 92 | sys.stdout = Logger(log_txt, sys.stdout) 93 | 94 | wandb.login() 95 | wandb.init( 96 | entity="nku428", 97 | project="EDM_v1", 98 | config=config, 99 | name=f"{config['experiment_name']}{postfix}", 100 | dir=log_dir, 101 | tags=["diffusion"] 102 | ) 103 | 104 | print("postfix:", postfix) 105 | print("checkpoint:", args.checkpoint) 106 | print("flowae checkpoint:", args.flowae_checkpoint) 107 | print("batch size:", train_params['batch_size']) 108 | 109 | config['flowae_checkpoint'] = args.flowae_checkpoint 110 | 111 | print("Training...") 112 | train( 113 | config, 114 | dataset_params, 115 | train_params, 116 | log_dir, 117 | args.checkpoint, 118 | args.device_ids 119 | ) 120 | -------------------------------------------------------------------------------- /scripts/DM/train_DM_bair.sh: 
-------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/train_DM_bair.sh 2 | 3 | AE_CKPT_PATH=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | AE_NAME=bair64_scale0.50 5 | AE_STEP=RegionMM 6 | SEED=1234 7 | 8 | python ./scripts/DM/run.py \ 9 | --random-seed $SEED \ 10 | --flowae_checkpoint $AE_CKPT_PATH/BAIR/$AE_NAME/snapshots/$AE_STEP.pth \ 11 | --config ./config/DM/bair.yaml \ 12 | --log_dir ./logs_training/DM/BAIR \ 13 | --device_ids 0,1,2,3 \ 14 | --postfix test -------------------------------------------------------------------------------- /scripts/DM/train_DM_cityscapes.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/train_DM_cityscapes.sh 2 | 3 | AE_CKPT_PATH=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | AE_NAME=cityscapes128_perspective 5 | AE_STEP=RegionMM 6 | # AE_NAME=cityscapes128_FlowAE_Batch128_lr2e-4_Region40_perspective_scale0.50 7 | # AE_STEP=RegionMM_0128_S100000 8 | # AE_NAME=cityscapes128_FlowAE_Batch64_lr1e-4_Region40_perspective_scale1.00 9 | # AE_STEP=RegionMM_best_123.506 10 | SEED=1234 11 | 12 | python ./scripts/DM/run.py \ 13 | --random-seed $SEED \ 14 | --flowae_checkpoint $AE_CKPT_PATH/Cityscapes/$AE_NAME/snapshots/$AE_STEP.pth \ 15 | --config ./config/DM/cityscapes.yaml \ 16 | --log_dir ./logs_training/DM/Cityscapes \ 17 | --device_ids 0,1,2,3,4,5,6,7 \ 18 | --postfix test -------------------------------------------------------------------------------- /scripts/DM/train_DM_kth.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/train_DM_kth.sh 2 | 3 | AE_CKPT_PATH=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | AE_NAME=kth64_FlowAE_Batch256_lr2e-4_Region20_affine_Max40_2 5 | AE_STEP=RegionMM_0256_S220000 6 | # AE_NAME=kth64_FlowAE_Batch256_lr2e-4_Region20_affine_Max40_2 7 | # AE_STEP=RegionMM_0256_S220000 8 | # AE_NAME=kth64_FlowAE_Batch128_lr1e-4_Region20_affine_scale1.00_resume 9 | # AE_STEP=RegionMM_best_157.143 10 | SEED=1234 11 | 12 | python ./scripts/DM/run.py \ 13 | --random-seed $SEED \ 14 | --flowae_checkpoint $AE_CKPT_PATH/KTH/$AE_NAME/snapshots/$AE_STEP.pth \ 15 | --config ./config/DM/kth.yaml \ 16 | --log_dir ./logs_training/DM/KTH \ 17 | --device_ids 0,1,2,3 \ 18 | --postfix test -------------------------------------------------------------------------------- /scripts/DM/train_DM_smmnist.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/train_DM_smmnist.sh 2 | 3 | AE_CKPT_PATH=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | AE_NAME=smmnist64_FlowAE_Batch100_lr2e-4_Region10_affine_scale0.50 5 | AE_STEP=RegionMM 6 | # AE_NAME=smmnist64_FlowAE_Batch128_lr1e-4_Region10_affine_scale1.00 7 | # AE_STEP=RegionMM_best_2.183 8 | SEED=1234 9 | 10 | python ./scripts/DM/run.py \ 11 | --random-seed $SEED \ 12 | --flowae_checkpoint $AE_CKPT_PATH/SMMNIST/$AE_NAME/snapshots/$AE_STEP.pth \ 13 | --config ./config/DM/smmnist.yaml \ 14 | --log_dir ./logs_training/DM/SMMNIST \ 15 | --device_ids 0,1,2,3 \ 16 | --postfix test -------------------------------------------------------------------------------- /scripts/DM/train_DM_ucf.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/train_DM_ucf.sh 2 | 3 | AE_CKPT_PATH=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | AE_NAME=ucf101_64_FlowAE_Batch100_lr2e-4_Region64_scale0.5 5 | 
AE_STEP=RegionMM_0100_S120000_270.85 6 | # AE_NAME=ucf101_64_FlowAE_Batch100_lr2e-4_Region128_scale0.5 7 | # AE_STEP=RegionMM_0100_S120000_236.946 8 | SEED=1234 9 | 10 | python ./scripts/DM/run.py \ 11 | --random-seed $SEED \ 12 | --flowae_checkpoint $AE_CKPT_PATH/UCF101/$AE_NAME/snapshots/$AE_STEP.pth \ 13 | --config ./config/DM/ucf.yaml \ 14 | --log_dir ./logs_training/DM/ucf \ 15 | --device_ids 0,1 \ 16 | --postfix test -------------------------------------------------------------------------------- /scripts/DM/valid_DM_bair.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/valid_DM_bair.sh 2 | 3 | AE_CKPT=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | DM_CKPT=/home/ubuntu/zzc/data/video_prediction/DM_pretrained 5 | 6 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 7 | 8 | AE_NAME=bair64_scale0.50 9 | AE_STEP=RegionMM 10 | 11 | ######################################## 12 | # - VideoFlowDiffusion_multi_w_ref 13 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_u12 14 | # ------------------------------------- 15 | # pred4 - repeat1 - 256bs - 30966MB 16 | # pred4 - repeat2 - 128bs - 25846MB 17 | # pred4 - repeat4 - 64bs - 26376MB 18 | # ------------------------------------- 19 | DM_architecture=VideoFlowDiffusion_multi_w_ref 20 | Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_u12 21 | DM_NAME=bair64_DM_Batch64_lr2e-4_c2p4_STW_adaptor_scale0.50_multi_traj 22 | DM_STEP=flowdiff_best_73000_315.362 23 | SEED=1000 24 | NUM_SAMPLE=100 25 | NUM_BATCH_SIZE=128 26 | ######################################## 27 | 28 | ######################################## 29 | # - VideoFlowDiffusion_multi_w_ref 30 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_u12 31 | # ------------------------------------- 32 | # pred5 - repeat1 - 256bs - 32474MB 33 | # ------------------------------------- 34 | # DM_architecture=VideoFlowDiffusion_multi_w_ref 35 | # Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_u12 36 | # DM_NAME=bair64_DM_Batch64_lr2.e-4_c2p5_STW_adaptor_multi_traj_resume 37 | # DM_STEP=flowdiff_0064_S190000 38 | # SEED=1000 39 | # NUM_SAMPLE=1 40 | # NUM_BATCH_SIZE=128 41 | ######################################## 42 | 43 | ######################################## 44 | # - VideoFlowDiffusion_multi_w_ref 45 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 46 | # ------------------------------------- 47 | # pred7 - repeat1 - 256bs - 37520MB 48 | # ------------------------------------- 49 | # DM_architecture=VideoFlowDiffusion_multi_w_ref 50 | # Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 51 | # DM_NAME=bair64_DM_Batch66_lr2.e-4_c2p7_STW_adaptor_multi_traj_ada 52 | # DM_STEP=flowdiff_0066_S095000 53 | # SEED=1000 54 | # NUM_SAMPLE=1 55 | # NUM_BATCH_SIZE=128 56 | ######################################## 57 | 58 | ######################################## 59 | # - VideoFlowDiffusion_multi_w_ref 60 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 61 | # ------------------------------------- 62 | # pred10 - repeat1 - 256bs - 42684MB 63 | # ------------------------------------- 64 | # DM_architecture=VideoFlowDiffusion_multi_w_ref 65 | # Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 66 | # DM_NAME=bair64_DM_Batch64_lr2e-4_c2p10_STW_adaptor_scale0.50_multi_traj_ada 67 | # DM_STEP=flowdiff_best_239.058 68 | # SEED=1000 69 | # NUM_SAMPLE=1 70 | # NUM_BATCH_SIZE=64 71 | 
######################################## 72 | 73 | CUDA_VISIBLE_DEVICES=0 \ 74 | python ./scripts/DM/valid.py \ 75 | --num_sample_video $NUM_SAMPLE \ 76 | --total_pred_frames 28 \ 77 | --num_videos 256 \ 78 | --valid_batch_size $NUM_BATCH_SIZE \ 79 | --random-seed $SEED \ 80 | --DM_arch $DM_architecture \ 81 | --Unet3D_arch $Unet3D_architecture \ 82 | --dataset_path $data_path/bair_h5 \ 83 | --flowae_checkpoint $AE_CKPT/BAIR/$AE_NAME/snapshots/$AE_STEP.pth \ 84 | --config $DM_CKPT/BAIR/$DM_NAME/bair64.yaml \ 85 | --checkpoint $DM_CKPT/BAIR/$DM_NAME/snapshots/$DM_STEP.pth \ 86 | --log_dir ./logs_validation/pretrained_DM/BAIR/${DM_NAME}_${DM_STEP}_${SEED}_${NUM_SAMPLE} 87 | 88 | -------------------------------------------------------------------------------- /scripts/DM/valid_DM_cityscapes.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/valid_DM_cityscapes.sh 2 | 3 | AE_CKPT=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | DM_CKPT=/home/ubuntu/zzc/data/video_prediction/DM_pretrained 5 | 6 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 7 | 8 | DM_architecture=VideoFlowDiffusion_multi_w_ref_u22 9 | Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada_u22 10 | 11 | ######################################## 12 | # - VideoFlowDiffusion_multi_w_ref_u22 13 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada_u22 14 | # ------------------------------------- 15 | # AE_NAME=cityscapes128_FlowAE_Batch128_lr2e-4_Region20_perspective_scale0.25 16 | # AE_STEP=RegionMM_0128_S150000 17 | # DM_NAME=cityscapes128_DM_Batch32_lr1.5e-4_c2p4_STW_adaptor_scale0.25_multi_traj_ada 18 | # DM_STEP=flowdiff_best 19 | # SEED=1000 20 | # NUM_SAMPLE=4 21 | # NUM_BATCH_SIZE=32 22 | ######################################## 23 | 24 | ######################################## 25 | # - VideoFlowDiffusion_multi_w_ref_u22 26 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada_u22 27 | # ------------------------------------- 28 | AE_NAME=cityscapes128_FlowAE_Batch128_lr2e-4_Region20_perspective_scale0.25 29 | AE_STEP=RegionMM_0128_S150000 30 | DM_NAME=cityscapes128_DM_Batch40_lr1.5e-4_c2p5_STW_adaptor_scale0.25_multi_traj_ada 31 | DM_STEP=flowdiff_best_33000_181.577 32 | SEED=1000 33 | NUM_SAMPLE=100 34 | NUM_BATCH_SIZE=128 35 | ######################################## 36 | 37 | ######################################## 38 | # - VideoFlowDiffusion_multi_w_ref_u22 39 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada_u22 40 | # ------------------------------------- 41 | # AE_NAME=cityscapes128_FlowAE_Batch128_lr2e-4_Region20_perspective_scale0.25 42 | # AE_STEP=RegionMM_0128_S150000 43 | # DM_NAME=cityscapes128_DM_Batch40_lr1.5e-4_c2p7_STW_adaptor_scale0.25_multi_traj_ada 44 | # DM_STEP=flowdiff_best 45 | # SEED=1000 46 | # NUM_SAMPLE=4 47 | # NUM_BATCH_SIZE=32 48 | ######################################## 49 | 50 | ######################################## 51 | # - VideoFlowDiffusion_multi_w_ref_u22 52 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada_u22 53 | # ------------------------------------- 54 | # AE_NAME=cityscapes128_FlowAE_Batch128_lr2e-4_Region20_perspective_scale0.25 55 | # AE_STEP=RegionMM_0128_S150000 56 | # DM_NAME=cityscapes128_DM_Batch40_lr1.5e-4_c2p10_STW_adaptor_scale0.25_multi_traj_ada 57 | # DM_STEP=flowdiff_0040_S220000 58 | # SEED=1000 59 | # NUM_SAMPLE=4 60 | # NUM_BATCH_SIZE=32 61 | ######################################## 62 | 63 | 
######################################## 64 | # - VideoFlowDiffusion_multi_w_ref_u22 65 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada_u22 66 | # ------------------------------------- 67 | # AE_NAME=cityscapes128_FlowAE_Batch128_lr2e-4_Region40_perspective_scale0.50 68 | # AE_STEP=RegionMM_0128_S100000 69 | # DM_NAME=cityscapes128_DM_Batch28_lr1.5e-4_c2p7_STW_adaptor_scale0.5_multi_traj_ada 70 | # DM_STEP=flowdiff_best 71 | # SEED=1000 72 | # NUM_SAMPLE=4 73 | # NUM_BATCH_SIZE=32 74 | ######################################## 75 | 76 | CUDA_VISIBLE_DEVICES=0 \ 77 | python ./scripts/DM/valid.py \ 78 | --estimate_occlusion_map \ 79 | --num_sample_video $NUM_SAMPLE \ 80 | --total_pred_frames 28 \ 81 | --num_videos 256 \ 82 | --valid_batch_size $NUM_BATCH_SIZE \ 83 | --random-seed $SEED \ 84 | --DM_arch $DM_architecture \ 85 | --Unet3D_arch $Unet3D_architecture \ 86 | --dataset_path $data_path/cityscapes_h5 \ 87 | --flowae_checkpoint $AE_CKPT/Cityscapes/$AE_NAME/snapshots/$AE_STEP.pth \ 88 | --config $DM_CKPT/Cityscapes/$DM_NAME/cityscapes128.yaml \ 89 | --checkpoint $DM_CKPT/Cityscapes/$DM_NAME/snapshots/$DM_STEP.pth \ 90 | --log_dir ./logs_validation/pretrained_DM/Cityscapes/${DM_NAME}_${DM_STEP}_${SEED}_${NUM_SAMPLE} 91 | -------------------------------------------------------------------------------- /scripts/DM/valid_DM_kth.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/valid_DM_kth.sh 2 | 3 | AE_CKPT=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | DM_CKPT=/home/ubuntu/zzc/data/video_prediction/DM_pretrained 5 | 6 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 7 | 8 | AE_NAME=kth64_FlowAE_Batch256_lr2e-4_Region20_affine_Max40_2 9 | AE_STEP=RegionMM_0256_S220000 10 | 11 | ######################################## 12 | # - VideoFlowDiffusion_multi_w_ref 13 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 14 | # ------------------------------------- 15 | DM_architecture=VideoFlowDiffusion_multi_w_ref 16 | Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 17 | DM_NAME=kth64_DM_Batch32_lr2e-4_c10p4_STW_adaptor_scale0.50_multi_traj_ada 18 | DM_STEP=flowdiff_best_355.236 19 | SEED=1000 20 | NUM_SAMPLE=100 21 | NUM_BATCH_SIZE=100 22 | ######################################## 23 | 24 | ######################################## 25 | # - VideoFlowDiffusion_multi_w_ref 26 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 27 | # ------------------------------------- 28 | # DM_architecture=VideoFlowDiffusion_multi_w_ref 29 | # Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 30 | # DM_NAME=kth64_DM_Batch32_lr2e-4_c10p5_STW_adaptor_multi_traj_ada 31 | # DM_STEP=flowdiff_0032_S098000 32 | # SEED=7000 33 | # NUM_SAMPLE=1 34 | # NUM_BATCH_SIZE=100 35 | ######################################## 36 | 37 | ######################################## 38 | # - VideoFlowDiffusion_multi_w_ref 39 | # - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 40 | # ------------------------------------- 41 | # DM_architecture=VideoFlowDiffusion_multi_w_ref 42 | # Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 43 | # DM_NAME=kth64_DM_Batch64_lr2e-4_c10p10_STW_adaptor_scale0.50_multi_traj_ada 44 | # DM_STEP=flowdiff_0064_S088000 45 | # SEED=3000 46 | # NUM_SAMPLE=1 47 | # NUM_BATCH_SIZE=100 48 | ######################################## 49 | 50 | ######################################## 51 | # - VideoFlowDiffusion_multi_w_ref 52 | 
# - DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 53 | # ------------------------------------- 54 | # DM_architecture=VideoFlowDiffusion_multi_w_ref 55 | # Unet3D_architecture=DenoiseNet_STWAtt_w_w_ref_adaptor_cross_multi_traj_ada 56 | # DM_NAME=kth64_DM_Batch32_lr2e-4_c10p20_STW_adaptor_scale0.50_multi_traj_ada 57 | # DM_STEP=flowdiff_0032_S075000 58 | # SEED=3000 59 | # NUM_SAMPLE=1 60 | # NUM_BATCH_SIZE=100 61 | ######################################## 62 | 63 | CUDA_VISIBLE_DEVICES=0 \ 64 | python ./scripts/DM/valid.py \ 65 | --num_sample_video $NUM_SAMPLE \ 66 | --total_pred_frames 40 \ 67 | --num_videos 256 \ 68 | --valid_batch_size $NUM_BATCH_SIZE \ 69 | --random-seed $SEED \ 70 | --DM_arch $DM_architecture \ 71 | --Unet3D_arch $Unet3D_architecture \ 72 | --dataset_path $data_path/kth_h5 \ 73 | --flowae_checkpoint $AE_CKPT/KTH/$AE_NAME/snapshots/$AE_STEP.pth \ 74 | --config $DM_CKPT/KTH/$DM_NAME/kth64.yaml \ 75 | --checkpoint $DM_CKPT/KTH/$DM_NAME/snapshots/$DM_STEP.pth \ 76 | --log_dir ./logs_validation/pretrained_DM/KTH/${DM_NAME}_${DM_STEP}_${SEED}_${NUM_SAMPLE} 77 | # --random_time \ 78 | -------------------------------------------------------------------------------- /scripts/DM/valid_DM_smmnist.sh: -------------------------------------------------------------------------------- 1 | # sh ./scripts/DM/valid_DM_smmnist.sh 2 | 3 | AE_CKPT=/home/ubuntu/zzc/data/video_prediction/AE_pretrained 4 | DM_CKPT=/home/ubuntu/zzc/data/video_prediction/DM_pretrained 5 | 6 | data_path=/home/ubuntu/zzc/data/video_prediction/dataset_h5 7 | 8 | AE_NAME=smmnist64_scale0.50 9 | AE_STEP=RegionMM 10 | DM_architecture=VideoFlowDiffusion_multi1248 11 | Unet3D_architecture=DenoiseNet_STWAtt_w_wo_ref_adaptor_cross_multi 12 | 13 | ######################################## 14 | # - VideoFlowDiffusion_multi1248 15 | # - DenoiseNet_STWAtt_w_wo_ref_adaptor_cross_multi 16 | # ------------------------------------- 17 | DM_NAME=smmnist64_DM_Batch32_lr2e-4_c10p4_STW_adaptor_scale0.50_multi_1248 18 | DM_STEP=flowdiff_best_23.160 19 | SEED=1000 20 | NUM_SAMPLE=100 21 | NUM_BATCH_SIZE=2 22 | ######################################## 23 | 24 | ######################################## 25 | # - VideoFlowDiffusion_multi1248 26 | # - DenoiseNet_STWAtt_w_wo_ref_adaptor_cross_multi 27 | # ------------------------------------- 28 | # DM_NAME=smmnist64_DM_Batch32_lr2.0e-4_c5p5_STW_adaptor_multi_124_resume 29 | # DM_STEP=flowdiff_0036_S265000 30 | # SEED=1000 31 | # NUM_SAMPLE=100 32 | # NUM_BATCH_SIZE=2 33 | ######################################## 34 | 35 | ######################################## 36 | # - VideoFlowDiffusion_multi1248 37 | # - DenoiseNet_STWAtt_w_wo_ref_adaptor_cross_multi 38 | # ------------------------------------- 39 | # DM_NAME=smmnist64_DM_Batch40_lr2e-4_c10p10_STW_adaptor_multi_1248 40 | # DM_STEP=flowdiff_0040_S195000 41 | # SEED=1000 42 | # NUM_SAMPLE=100 43 | # NUM_BATCH_SIZE=2 44 | ######################################## 45 | 46 | CUDA_VISIBLE_DEVICES=1 \ 47 | python ./scripts/DM/valid.py \ 48 | --num_sample_video $NUM_SAMPLE \ 49 | --total_pred_frames 10 \ 50 | --num_videos 256 \ 51 | --valid_batch_size $NUM_BATCH_SIZE \ 52 | --random-seed $SEED \ 53 | --DM_arch $DM_architecture \ 54 | --Unet3D_arch $Unet3D_architecture \ 55 | --dataset_path $data_path/smmnist_h5 \ 56 | --flowae_checkpoint $AE_CKPT/SMMNIST/$AE_NAME/snapshots/$AE_STEP.pth \ 57 | --config $DM_CKPT/SMMNIST/$DM_NAME/smmnist64.yaml \ 58 | --checkpoint $DM_CKPT/SMMNIST/$DM_NAME/snapshots/$DM_STEP.pth \ 59 | --log_dir 
./logs_validation/pretrained_DM/SMMNIST/${DM_NAME}_${DM_STEP}_${SEED}_${NUM_SAMPLE} -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='extdm', 5 | version='1.0.0', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | ], 12 | ) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | class Logger(object): 4 | def __init__(self, filename='default.log', stream=sys.stdout): 5 | self.terminal = stream 6 | self.log = open(filename, 'w') 7 | 8 | def write(self, message): 9 | self.terminal.write(message) 10 | self.log.write(message) 11 | 12 | def flush(self): 13 | pass -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths=[10000000], verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
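        # last_f caches the most recent multiplier; verbosity_interval > 0 makes schedule() print it every that many steps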
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 81 | 82 | def schedule(self, n, **kwargs): 83 | cycle = self.find_in_interval(n) 84 | n = n - self.cum_cycles[cycle] 85 | if self.verbosity_interval > 0: 86 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 87 | f"current cycle {cycle}") 88 | 89 | if n < self.lr_warm_up_steps[cycle]: 90 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 91 | self.last_f = f 92 | return f 93 | else: 94 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 95 | self.last_f = f 96 | return f 97 | -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/seed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | def setup_seed(seed): 6 | torch.manual_seed(seed) 7 | torch.cuda.manual_seed_all(seed) 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /vis/save_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | from utils.visualize import visualize 5 | 6 | from metrics.calculate_fvd import calculate_fvd,calculate_fvd1 7 | from metrics.calculate_psnr import calculate_psnr,calculate_psnr1 8 | from metrics.calculate_ssim import calculate_ssim,calculate_ssim1 9 | from metrics.calculate_lpips import calculate_lpips,calculate_lpips1 10 | import json 11 | 12 | test_name='/home/ubuntu/zzc/code/EDM/logs_validation/diffusion/cityscapes128_DM_Batch44_lr1.5e-4_c2p4_STW_adaptor' 13 | 14 | num_frames_cond = 2 15 | video_num = 256 16 | 17 | device = torch.device("cuda") 
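# NOTE: the .pt tensors loaded below are assumed to be float videos in [0, 1] with shape
# (num_videos, num_frames, channels, H, W); the conditioning frames come first along the time axis.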
18 | 19 | metrics = {} 20 | 21 | videos1 = torch.load(f'{test_name}/origin.pt') 22 | videos2 = torch.load(f'{test_name}/result.pt') 23 | flows1 = torch.load(f'{test_name}/origin_flows.pt') 24 | flows2 = torch.load(f'{test_name}/result_flows.pt') 25 | 26 | abs_diff_videos = torch.sqrt(torch.sum((videos1 - videos2)**2, dim=2)/3).unsqueeze(2).repeat(1,1,3,1,1) 27 | abs_diff_flows = torch.sqrt(torch.sum((flows1 - flows2)**2, dim=2)/3).unsqueeze(2).repeat(1,1,3,1,1) 28 | 29 | from utils.visualize import visualize_ori_pre_flow_diff 30 | visualize_ori_pre_flow_diff( 31 | save_path=f"{test_name}/result", 32 | origin=videos1, 33 | result=videos2, 34 | origin_flow=flows1, 35 | result_flow=flows2, 36 | video_diff=abs_diff_videos, 37 | flow_diff=abs_diff_flows, 38 | epoch_or_step_num=0, 39 | cond_frame_num=num_frames_cond, 40 | skip_pic_num=1 41 | ) 42 | 43 | from metrics.calculate_fvd import calculate_fvd,calculate_fvd1 44 | from metrics.calculate_psnr import calculate_psnr,calculate_psnr1 45 | from metrics.calculate_ssim import calculate_ssim,calculate_ssim1 46 | from metrics.calculate_lpips import calculate_lpips,calculate_lpips1 47 | 48 | fvd = calculate_fvd1(videos1, videos2, torch.device("cuda"), mini_bs=16) 49 | videos1 = videos1[:, num_frames_cond:] 50 | videos2 = videos2[:, num_frames_cond:] 51 | ssim = calculate_ssim1(videos1, videos2)[0] 52 | psnr = calculate_psnr1(videos1, videos2)[0] 53 | lpips = calculate_lpips1(videos1, videos2, torch.device("cuda"))[0] 54 | 55 | print("[fvd ]", fvd) 56 | print("[ssim ]", ssim) 57 | print("[psnr ]", psnr) 58 | print("[lpips ]", lpips) -------------------------------------------------------------------------------- /vis/test_flowae_run_our_result.py: -------------------------------------------------------------------------------- 1 | # use LFAE to reconstruct testing videos and measure the loss in video domain 2 | # using RegionMM 3 | 4 | import argparse 5 | import cv2 6 | import einops 7 | import imageio 8 | import torch 9 | from torch.utils import data 10 | import numpy as np 11 | import torch.backends.cudnn as cudnn 12 | import os 13 | import timeit 14 | from PIL import Image 15 | from data.video_dataset import VideoDataset, dataset2videos 16 | import random 17 | from model.LFAE.flow_autoenc import FlowAE 18 | import torch.nn.functional as F 19 | from model.LFAE.util import Visualizer 20 | import json_tricks as json 21 | import matplotlib.pyplot as plt 22 | from matplotlib.collections import LineCollection 23 | import yaml 24 | 25 | from utils.misc import flow2fig 26 | import mediapy as media 27 | 28 | def fig2data(fig): 29 | """ 30 | @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it 31 | @param fig a matplotlib figure 32 | @return a numpy 3D array of RGBA values 33 | """ 34 | # draw the renderer 35 | fig.canvas.draw() 36 | 37 | # Get the RGBA buffer from the figure 38 | w, h = fig.canvas.get_width_height() 39 | buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) 40 | buf.shape = (w, h, 4) 41 | 42 | # canvas.tostring_argb give pixmap in ARGB mode. 
Roll the ALPHA channel to have it in RGBA mode 43 | buf = np.roll(buf, 3, axis=2) 44 | return buf 45 | 46 | def plot_grid(x, y, ax=None, **kwargs): 47 | ax = ax or plt.gca() 48 | segs1 = np.stack((x, y), axis=2) 49 | segs2 = segs1.transpose(1, 0, 2) 50 | ax.add_collection(LineCollection(segs1, **kwargs)) 51 | ax.add_collection(LineCollection(segs2, **kwargs)) 52 | ax.autoscale() 53 | 54 | def grid2fig(warped_grid, grid_size=32, img_size=256): 55 | dpi = 1000 56 | # plt.ioff() 57 | h_range = torch.linspace(-1, 1, grid_size) 58 | w_range = torch.linspace(-1, 1, grid_size) 59 | grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).flip(2) 60 | flow_uv = grid.cpu().data.numpy() 61 | fig, ax = plt.subplots() 62 | grid_x, grid_y = warped_grid[..., 0], warped_grid[..., 1] 63 | plot_grid(flow_uv[..., 0], flow_uv[..., 1], ax=ax, color="lightgrey") 64 | plot_grid(grid_x, grid_y, ax=ax, color="C0") 65 | plt.axis("off") 66 | plt.tight_layout(pad=0) 67 | fig.set_size_inches(img_size/100, img_size/100) 68 | fig.set_dpi(100) 69 | out = fig2data(fig)[:, :, :3] 70 | plt.close() 71 | plt.cla() 72 | plt.clf() 73 | return out 74 | 75 | 76 | name="bair64_not_onlyflow" 77 | 78 | start = timeit.default_timer() 79 | BATCH_SIZE = 256 80 | data_dir = "/mnt/sda/hjy/bair/BAIR_h5" 81 | GPU = "0" 82 | postfix = "" 83 | INPUT_SIZE = 64 84 | COND_FRAMES = 2 # 10 85 | PRED_FRAMES = 28 # 40 86 | N_FRAMES = COND_FRAMES + PRED_FRAMES # 50 / 30 87 | NUM_VIDEOS = 256 # 16 #256 88 | SAVE_VIDEO = True 89 | NUM_ITER = NUM_VIDEOS // BATCH_SIZE 90 | RANDOM_SEED = 1234 # 1234 91 | MEAN = (0.0, 0.0, 0.0) 92 | # the path to trained LFAE model 93 | # RESTORE_FROM = "/mnt/sda/hjy/flow_pretrained/kth64/snapshots/RegionMM.pth" 94 | RESTORE_FROM = "/mnt/rhdd/zzc/data/video_prediction/flow_pretrained/better/bair64/snapshots/RegionMM.pth" 95 | config_pth = f"./logs_training/diffusion/{name}/bair64.yaml" 96 | 97 | visualizer = Visualizer() 98 | print(postfix) 99 | print("RESTORE_FROM:", RESTORE_FROM) 100 | print("config_path:", config_pth) 101 | print("save video:", SAVE_VIDEO) 102 | 103 | 104 | def sample_img(rec_img_batch, index): 105 | rec_img = rec_img_batch[index].permute(1, 2, 0).data.cpu().numpy().copy() 106 | rec_img += np.array(MEAN)/255.0 107 | rec_img[rec_img < 0] = 0 108 | rec_img[rec_img > 1] = 1 109 | rec_img *= 255 110 | return np.array(rec_img, np.uint8) 111 | 112 | 113 | def main(): 114 | """Create the model and start the training.""" 115 | 116 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 117 | 118 | cudnn.enabled = True 119 | cudnn.benchmark = True 120 | setup_seed(1234) 121 | 122 | model = FlowAE(is_train=False, config=config_pth) 123 | model.cuda() 124 | 125 | if os.path.isfile(RESTORE_FROM): 126 | print("=> loading checkpoint '{}'".format(RESTORE_FROM)) 127 | checkpoint = torch.load(RESTORE_FROM) 128 | model.generator.load_state_dict(checkpoint['generator']) 129 | model.region_predictor.load_state_dict(checkpoint['region_predictor']) 130 | model.bg_predictor.load_state_dict(checkpoint['bg_predictor']) 131 | print("=> loaded checkpoint '{}'".format(RESTORE_FROM)) 132 | else: 133 | print("=> no checkpoint found at '{}'".format(RESTORE_FROM)) 134 | exit(-1) 135 | 136 | model.eval() 137 | 138 | setup_seed(1234) 139 | 140 | with open(config_pth) as f: 141 | config = yaml.safe_load(f) 142 | 143 | valid_origin = torch.load(f'./logs_validation/diffusion/{name}/origin.pt') 144 | valid_result = torch.load(f'./logs_validation/diffusion/{name}/result.pt') 145 | 146 | batch_time = AverageMeter() 147 | data_time = AverageMeter() 148 
| 149 | iter_end = timeit.default_timer() 150 | 151 | 152 | # index = [1,17,24,26,112,176,203,204,223,232] 153 | index = [2,5,27,30,52,61,71,84,134,222] 154 | 155 | for idx in index: 156 | origin_batch = valid_origin[idx] 157 | result_batch = valid_result[idx] 158 | 159 | os.makedirs(f'./flow_output/{name}/{idx}', exist_ok=True) 160 | 161 | data_time.update(timeit.default_timer() - iter_end) 162 | 163 | origin_batch = origin_batch.unsqueeze(0) 164 | result_batch = result_batch.unsqueeze(0) 165 | 166 | # real_vids b t c h w -> b c t h w 167 | # tensor(0.0431) tensor(0.9647) 168 | origin_batch = origin_batch.permute(0,2,1,3,4).contiguous() 169 | result_batch = result_batch.permute(0,2,1,3,4).contiguous() 170 | 171 | cond_vids = origin_batch[:, :, :COND_FRAMES, :, :] 172 | pred_vids = result_batch[:, :, COND_FRAMES:, :, :] 173 | 174 | # use the last conditioning frame of each video as the reference frame (vids: B C T H W) 175 | ref_imgs = cond_vids[:, :, -1, :, :].clone().detach() 176 | 177 | batch_time.update(timeit.default_timer() - iter_end) 178 | 179 | flow = [] 180 | 181 | for frame_idx in range(COND_FRAMES+PRED_FRAMES): 182 | dri_imgs = result_batch[:, :, frame_idx, :, :] 183 | with torch.no_grad(): 184 | model.set_train_input(ref_img=ref_imgs, dri_img=dri_imgs) 185 | model.forward() 186 | 187 | driven = einops.rearrange(dri_imgs[0]*255, "c h w -> h w c").numpy() 188 | warped = model.generated['optical_flow'][0].clone().detach().cpu().numpy() 189 | print('1', np.max(warped), np.min(warped), warped.shape) 190 | output = flow2fig(warped_grid=warped, grid_size=32, img_size=64) 191 | flow.append(output) 192 | print('2', np.max(output), np.min(output), output.shape) 193 | cv2.imwrite(f'./flow_output/{name}/{idx}/driven_{frame_idx}.png', driven[:,:,::-1]) # RGB -> BGR 194 | cv2.imwrite(f'./flow_output/{name}/{idx}/flow_{frame_idx}.png', output[:,:,::-1]) # RGB -> BGR 195 | 196 | video = np.array(einops.rearrange(result_batch[0],'c t h w -> t h w c')) 197 | flow = np.array(flow) 198 | media.show_videos([video, flow], fps=20) 199 | 200 | iter_end = timeit.default_timer() 201 | 202 | end = timeit.default_timer() 203 | print(end - start, 'seconds') 204 | 205 | 206 | class AverageMeter(object): 207 | """Computes and stores the average and current value""" 208 | 209 | def __init__(self): 210 | self.reset() 211 | 212 | def reset(self): 213 | self.val = 0 214 | self.avg = 0 215 | self.sum = 0 216 | self.count = 0 217 | 218 | def update(self, val, n=1): 219 | self.val = val 220 | self.sum += val * n 221 | self.count += n 222 | self.avg = self.sum / self.count 223 | 224 | 225 | def setup_seed(seed): 226 | torch.manual_seed(seed) 227 | torch.cuda.manual_seed_all(seed) 228 | np.random.seed(seed) 229 | random.seed(seed) 230 | torch.backends.cudnn.deterministic = True 231 | 232 | if __name__ == '__main__': 233 | main() 234 | 235 | -------------------------------------------------------------------------------- /vis/test_flowae_run_video2video.py: -------------------------------------------------------------------------------- 1 | from data.SMMNIST.stochastic_moving_mnist_edited import StochasticMovingMNIST 2 | import mediapy as media 3 | 4 | import cv2 5 | import einops 6 | import torch 7 | from torch.utils import data 8 | import numpy as np 9 | import torch.backends.cudnn as cudnn 10 | import os 11 | import random 12 | from model.LFAE.flow_autoenc import FlowAE 13 | import yaml 14 | 15 | from utils.misc import flow2fig 16 | 17 | mnist_dir = "/home/u1120230288/zzc/data/video_prediction/dataset/SMMNIST_h5" 18 | name = 
"smmnist_video2video" 19 | INPUT_SIZE = 64 20 | COND_FRAMES = 10 21 | PRED_FRAMES = 10 22 | RESTORE_FROM = "" 23 | config_pth = "" 24 | 25 | # cond_seq_len: 前10帧轨迹一致 26 | # pred_seq_len: 后面10帧轨迹变化 27 | # same_samples: 运动数字相同的视频数量 28 | # diff_samples: 运动数字不同的视频数量 29 | # 七个视频,最后一个为rdst视频,前六个为src视频 30 | train_dataset = StochasticMovingMNIST( 31 | mnist_dir, train=True, num_digits=2, 32 | step_length=0.1, with_target=False, 33 | cond_seq_len=10, pred_seq_len=10, same_samples=6, diff_samples=1 34 | ) 35 | # 取一组视频 36 | a_video_samples = train_dataset[50] 37 | 38 | def setup_seed(seed): 39 | torch.manual_seed(seed) 40 | torch.cuda.manual_seed_all(seed) 41 | np.random.seed(seed) 42 | random.seed(seed) 43 | torch.backends.cudnn.deterministic = True 44 | 45 | def main(): 46 | """Create the model and start the training.""" 47 | 48 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 49 | 50 | cudnn.enabled = True 51 | cudnn.benchmark = True 52 | setup_seed(1234) 53 | 54 | model = FlowAE(is_train=False, config=config_pth) 55 | model.cuda() 56 | 57 | if os.path.isfile(RESTORE_FROM): 58 | print("=> loading checkpoint '{}'".format(RESTORE_FROM)) 59 | checkpoint = torch.load(RESTORE_FROM) 60 | model.generator.load_state_dict(checkpoint['generator']) 61 | model.region_predictor.load_state_dict(checkpoint['region_predictor']) 62 | model.bg_predictor.load_state_dict(checkpoint['bg_predictor']) 63 | print("=> loaded checkpoint '{}'".format(RESTORE_FROM)) 64 | else: 65 | print("=> no checkpoint found at '{}'".format(RESTORE_FROM)) 66 | exit(-1) 67 | 68 | model.eval() 69 | 70 | setup_seed(1234) 71 | 72 | with open(config_pth) as f: 73 | config = yaml.safe_load(f) 74 | 75 | 76 | # 前面是不同的参考运动samples,最后一个是dist的sample 77 | ref_videos = a_video_samples 78 | dst_video = a_video_samples[-1] 79 | 80 | ref_videos = torch.stack(ref_videos).repeat(1,1,3,1,1) 81 | dst_videos = dst_video.unsqueeze(0).repeat(len(ref_videos),1,3,1,1) 82 | # print(ref_videos.shape, dst_videos.shape) 83 | # torch.Size([6, 20, 3, 64, 64]) 84 | 85 | origin_batch = ref_videos.permute(0,2,1,3,4).contiguous() 86 | result_batch = dst_videos.permute(0,2,1,3,4).contiguous() 87 | 88 | cond_vids = origin_batch[:, :, :COND_FRAMES, :, :] 89 | pred_vids = result_batch[:, :, COND_FRAMES:, :, :] 90 | 91 | # use first frame of each video as reference frame (vids: B C T H W) 92 | ref_imgs = cond_vids[:, :, -1, :, :].cuda() 93 | 94 | real_grid_list = [] 95 | real_conf_list = [] 96 | 97 | for frame_idx in range(COND_FRAMES+PRED_FRAMES): 98 | if frame_idx < COND_FRAMES: 99 | dri_imgs = result_batch[:, :, frame_idx, :, :] 100 | else: 101 | dri_imgs = origin_batch[:, :, frame_idx, :, :] 102 | with torch.no_grad(): 103 | model.set_train_input(ref_img=ref_imgs, dri_img=dri_imgs) 104 | model.forward() 105 | real_grid_list.append(model.generated["optical_flow"].permute(0, 3, 1, 2)) 106 | real_conf_list.append(model.generated["occlusion_map"]) 107 | 108 | # 输出生成的flow 109 | # driven = einops.rearrange(dri_imgs*255, "t c h w ->t h w c").numpy() 110 | # warped = model.generated['optical_flow'].clone().detach().cpu().numpy() 111 | # # print('1', np.max(warped), np.min(warped), warped.shape) 112 | # output = [flow2fig(warped_grid=warped[ii], grid_size=32, img_size=64) for ii in range(len(warped))] 113 | # output = np.stack(output) 114 | # # print('2', np.max(output), np.min(output), output.shape) 115 | # for ii in range(len(origin_batch)): 116 | # os.makedirs(f'./video2video/{name}/motion_{ii}', exist_ok=True) 117 | # print(driven[:,:,::-1].shape, output[:,:,::-1].shape) 118 | # 
cv2.imwrite(f'./video2video/{name}/motion_{ii}/driven_{frame_idx}.png', driven[ii,:,:,::-1]) # RGB -> BGR 119 | # cv2.imwrite(f'./video2video/{name}/motion_{ii}/flow_{frame_idx}.png', output[ii,:,:,::-1]) # RGB -> BGR 120 | # exit() 121 | 122 | real_vid_grid = torch.stack(real_grid_list, dim=2) 123 | real_vid_conf = torch.stack(real_conf_list, dim=2) 124 | 125 | # print(real_vid_grid.shape, real_vid_grid.min(), real_vid_grid.max()) 126 | # print(real_vid_conf.shape, real_vid_conf.min(), real_vid_conf.max()) 127 | 128 | sample_vid_grid = real_vid_grid 129 | sample_vid_conf = real_vid_conf 130 | 131 | sample_out_img_list = [] 132 | sample_warped_img_list = [] 133 | 134 | for idx in range(sample_vid_grid.size(2)): 135 | sample_grid = sample_vid_grid[:, :, idx, :, :].permute(0, 2, 3, 1) 136 | sample_conf = sample_vid_conf[:, :, idx, :, :] 137 | # predict fake out image and fake warped image 138 | with torch.no_grad(): 139 | generated_result = model.generator.forward_with_flow( 140 | source_image=ref_imgs, 141 | optical_flow=sample_grid, 142 | occlusion_map=sample_conf) 143 | sample_out_img_list.append(generated_result["prediction"]) 144 | 145 | sample_out_vid = torch.stack(sample_out_img_list, dim=2).cpu() 146 | # print(sample_out_vid.shape, sample_out_vid.min(), sample_out_vid.max()) 147 | 148 | origin_videos = origin_batch 149 | result_videos = sample_out_vid 150 | 151 | origin_videos = einops.rearrange(origin_videos, 'b c t h w -> b t h w c').cpu() 152 | result_videos = einops.rearrange(result_videos, 'b c t h w -> b t h w c').cpu() 153 | 154 | print(origin_videos.shape, result_videos.shape) 155 | for i in range(origin_videos.shape[0]): 156 | print(origin_videos[i].numpy()) 157 | media.write_video(f"./origin_{i}.gif", origin_videos[i].numpy(), fps=10) 158 | for i in range(result_videos.shape[0]): 159 | media.write_video(f"./result_{i}.gif", result_videos[i].numpy(), fps=10) 160 | 161 | if __name__ == '__main__': 162 | main() 163 | -------------------------------------------------------------------------------- /vis/vis copy.py: -------------------------------------------------------------------------------- 1 | import cv2, shutil 2 | import os, torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from einops import rearrange 6 | from natsort import natsorted 7 | from metrics.calculate_fvd import calculate_fvd,calculate_fvd1 8 | from metrics.calculate_psnr import calculate_psnr,calculate_psnr1 9 | from metrics.calculate_ssim import calculate_ssim,calculate_ssim1 10 | from metrics.calculate_lpips import calculate_lpips,calculate_lpips1 11 | 12 | def get_metrics(videos1, videos2): 13 | videos1 = np.array(videos1).astype(float)/255 14 | videos1 = torch.from_numpy(videos1) 15 | videos1 = rearrange(videos1, 't h w c-> 1 t c h w').float() 16 | 17 | videos2 = np.array(videos2).astype(float)/255 18 | videos2 = torch.from_numpy(videos2) 19 | videos2 = rearrange(videos2, 't h w c-> 1 t c h w').float() 20 | # fvd = calculate_fvd1(videos1, videos2, torch.device("cuda"), mini_bs=1) 21 | ssim = calculate_ssim1(videos1, videos2)[0] 22 | psnr = calculate_psnr1(videos1, videos2)[0] 23 | lpips = calculate_lpips1(videos1, videos2, torch.device("cuda"))[0] 24 | return psnr, ssim, lpips 25 | 26 | dataset = '/home/ubuntu/zzc/data/video_prediction/ExtDM_output/SMMNIST/smmnist64_DM_Batch32_lr2.0e-4_c5p5_STW_adaptor_multi_124_resume_flowdiff_0036_S265000_1000_100' 27 | samplenum=100 28 | c=10 29 | p=10 30 | saveroot = '/home/ubuntu/zzc/data/video_prediction/tile' 31 | savedata = 'SMMNIST' 32 | 33 | if 
os.path.exists(os.path.join(saveroot, savedata)): 34 | shutil.rmtree(os.path.join(saveroot, savedata)) 35 | os.makedirs(os.path.join(saveroot, savedata)) 36 | 37 | origin_fold = os.path.join(dataset, 'result_origin', '0', 'pic') 38 | 39 | for vid in tqdm(os.listdir(origin_fold)): 40 | gts = [] 41 | gt_vid = os.path.join(origin_fold, vid, 'result') 42 | ori_p = natsorted(os.listdir(gt_vid)) 43 | 44 | for i in range(c+p): 45 | gts.append(cv2.imread(os.path.join(gt_vid, ori_p[i]))) 46 | gt_imgs = np.concatenate(gts, axis=1) 47 | 48 | psnr_best = None 49 | ssim_best = None 50 | lpips_best = None 51 | fvd_best = None 52 | 53 | psnr,ssim = 0,0 54 | lpips,fvd = 1e8,1e8 55 | 56 | 57 | for sample in range(samplenum): 58 | # print(sample) 59 | ress = [] 60 | sample_vid = os.path.join(dataset, 'result_'+str(sample), '0', 'pic', vid, 'result') 61 | res_p = natsorted(os.listdir(sample_vid)) 62 | 63 | for i in range(c): 64 | ress.append(np.zeros_like(gts[-1])) 65 | for i in range(p): 66 | ress.append(cv2.imread(os.path.join(sample_vid, res_p[c+i]))) 67 | # print(os.path.join(sample_vid, res_p[i])) 68 | 69 | res_imgs = np.concatenate(ress, axis=1) 70 | 71 | psnr_cur, ssim_cur, lpips_cur = get_metrics(gts[c:], ress[c:]) 72 | # print(psnr_cur, 'psnr_cur') 73 | if psnr_cur > psnr: 74 | psnr = psnr_cur 75 | # print(psnr, 'psnr') 76 | psnr_best = res_imgs 77 | if ssim_cur > ssim: 78 | ssim = ssim_cur 79 | # print(ssim, 'ssim') 80 | ssim_best = res_imgs 81 | if lpips_cur < lpips: 82 | lpips = lpips_cur 83 | # print(lpips, 'lpips') 84 | lpips_best = res_imgs 85 | 86 | 87 | 88 | psnrimg = np.concatenate([gt_imgs, psnr_best], axis=0) 89 | cv2.imwrite(os.path.join(saveroot, savedata, vid + '_psnr' + '.png'), psnrimg) 90 | 91 | ssimimg = np.concatenate([gt_imgs, ssim_best], axis=0) 92 | cv2.imwrite(os.path.join(saveroot, savedata, vid + '_ssim' + '.png'), ssimimg) 93 | 94 | lpipsimg = np.concatenate([gt_imgs, lpips_best], axis=0) 95 | cv2.imwrite(os.path.join(saveroot, savedata, vid + '_lpips' + '.png'), lpipsimg) 96 | 97 | # fvdimg = np.concatenate([gt_imgs, fvd_best], axis=0) 98 | # cv2.imwrite(os.path.join(saveroot, savedata, vid + '_fvd' + '.png'), fvdimg) 99 | --------------------------------------------------------------------------------
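The validation and visualization scripts above all report FVD/SSIM/PSNR/LPIPS through the helpers in metrics/. The snippet below is a minimal sketch (not part of the repository) of how those helpers are invoked, mirroring the calls in vis/save_new.py and vis/vis copy.py. The (batch, time, channel, H, W) layout and the [0, 1] value range are inferred from those scripts, the random stand-in tensors and their sizes are placeholders, and a CUDA device is assumed because the scripts pass torch.device("cuda") to the LPIPS and FVD helpers.

import torch

from metrics.calculate_fvd import calculate_fvd1
from metrics.calculate_lpips import calculate_lpips1
from metrics.calculate_psnr import calculate_psnr1
from metrics.calculate_ssim import calculate_ssim1

num_frames_cond = 2                  # conditioning frames, as in vis/save_new.py
device = torch.device("cuda")        # the repo scripts always pass a CUDA device

# stand-in clips: 16 videos x 30 frames x 3 channels x 64 x 64, float values in [0, 1]
videos_gt = torch.rand(16, 30, 3, 64, 64)
videos_pred = torch.rand(16, 30, 3, 64, 64)

# FVD is computed on the full clips; the frame-wise metrics skip the conditioning frames
fvd = calculate_fvd1(videos_gt, videos_pred, device, mini_bs=8)
gt, pred = videos_gt[:, num_frames_cond:], videos_pred[:, num_frames_cond:]
ssim = calculate_ssim1(gt, pred)[0]
psnr = calculate_psnr1(gt, pred)[0]
lpips = calculate_lpips1(gt, pred, device)[0]

print("[fvd   ]", fvd)
print("[ssim  ]", ssim)
print("[psnr  ]", psnr)
print("[lpips ]", lpips)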