├── modules ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── modeling.cpython-38.pyc │ ├── ast_models.cpython-38.pyc │ ├── file_utils.cpython-38.pyc │ ├── module_clip.cpython-38.pyc │ ├── module_cross.cpython-38.pyc │ ├── optimization.cpython-38.pyc │ ├── resnet_models.cpython-38.pyc │ ├── until_config.cpython-38.pyc │ ├── until_module.cpython-38.pyc │ └── tokenization_clip.cpython-38.pyc ├── cross-base │ └── cross_config.json ├── until_config.py ├── tokenization_clip.py ├── optimization.py ├── file_utils.py ├── module_cross.py ├── ast_models.py ├── until_module.py └── resnet_models.py ├── easy_kill_process.sh ├── dataloaders ├── __pycache__ │ ├── utils.cpython-38.pyc │ ├── decoder.cpython-38.pyc │ ├── transform.cpython-38.pyc │ ├── mel_features.cpython-38.pyc │ ├── rawvideo_util.cpython-38.pyc │ ├── vggish_params.cpython-38.pyc │ ├── video_container.cpython-38.pyc │ ├── data_dataloaders.cpython-38.pyc │ ├── dataloader_msvd_retrieval.cpython-38.pyc │ ├── dataloader_didemo_retrieval.cpython-38.pyc │ ├── dataloader_lsmdc_retrieval.cpython-38.pyc │ ├── dataloader_msrvtt_retrieval.cpython-38.pyc │ ├── dataloader_charades_retrieval.cpython-38.pyc │ ├── dataloader_youcook_retrieval.cpython-38.pyc │ ├── dataloader_activitynet_retrieval.cpython-38.pyc │ ├── dataloader_qvhighlight_retrieval.cpython-38.pyc │ ├── dataloader_youcook_short_retrieval.cpython-38.pyc │ └── dataloader_activitynet_retrieval_demo.cpython-38.pyc ├── video_container.py ├── vggish_params.py ├── rawvideo_util.py ├── dataloader_lsmdc_retrieval.py ├── mel_features.py ├── dataloader_msvd_retrieval.py ├── utils.py └── decoder.py ├── torchvggish ├── __pycache__ │ ├── vggish.cpython-38.pyc │ ├── mel_features.cpython-38.pyc │ ├── vggish_input.cpython-38.pyc │ └── vggish_params.cpython-38.pyc ├── vggish_params.py ├── vggish_input.py ├── vggish.py └── mel_features.py ├── preprocess.sh ├── make_vid.sh ├── requirements.txt ├── run_didemo.sh ├── run_audio-only.sh ├── run_qvh.sh ├── comp_flops.sh ├── run_cha.sh ├── debug.sh ├── run_yc2.sh ├── run_act-B16.sh ├── run_act-B32.sh ├── LICENSE ├── preproc.sh ├── feature_ext.sh ├── util.py ├── README.md ├── metrics.py └── preprocess └── compress_video.py /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /easy_kill_process.sh: -------------------------------------------------------------------------------- 1 | kill $(ps aux | grep "main_task_retrieval.py" | grep -v grep | awk '{print $2}') 2 | -------------------------------------------------------------------------------- /modules/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /dataloaders/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /modules/__pycache__/modeling.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/modeling.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/decoder.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/ast_models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/ast_models.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/file_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/file_utils.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/module_clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/module_clip.cpython-38.pyc -------------------------------------------------------------------------------- /torchvggish/__pycache__/vggish.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/vggish.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/transform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/transform.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/module_cross.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/module_cross.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/optimization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/optimization.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/resnet_models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/resnet_models.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/until_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/until_config.cpython-38.pyc -------------------------------------------------------------------------------- 
/modules/__pycache__/until_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/until_module.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/mel_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/mel_features.cpython-38.pyc -------------------------------------------------------------------------------- /torchvggish/__pycache__/mel_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/mel_features.cpython-38.pyc -------------------------------------------------------------------------------- /torchvggish/__pycache__/vggish_input.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/vggish_input.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/rawvideo_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/rawvideo_util.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/vggish_params.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/vggish_params.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/video_container.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/video_container.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/tokenization_clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/tokenization_clip.cpython-38.pyc -------------------------------------------------------------------------------- /preprocess.sh: -------------------------------------------------------------------------------- 1 | python preprocess/compress_video.py --input_root /playpen-iop/yblin/AVE_Dataset/AVE \ 2 | --output_root /playpen-iop/yblin/AVE_Dataset/raw_audio -------------------------------------------------------------------------------- /torchvggish/__pycache__/vggish_params.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/vggish_params.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/data_dataloaders.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/data_dataloaders.cpython-38.pyc 
-------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_charades_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_charades_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_youcook_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_youcook_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_qvhighlight_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_qvhighlight_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_youcook_short_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_youcook_short_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/dataloader_activitynet_retrieval_demo.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_activitynet_retrieval_demo.cpython-38.pyc -------------------------------------------------------------------------------- /make_vid.sh: -------------------------------------------------------------------------------- 1 | ffmpeg -r 30 -i /playpen-iop/yblin/act_val_att/v_iosb2TdQ7yY/%04d.jpg -i \ 2 | /playpen-storage/yblin/v1-2/audio_raw/v_iosb2TdQ7yY/0000.wav -strict -2 /playpen-iop/yblin/act_val_att/"v_iosb2TdQ7yY"_"att".mp4 3 | -------------------------------------------------------------------------------- /modules/cross-base/cross_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 512, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 2048, 8 | "max_position_embeddings": 300, 9 | "num_attention_heads": 8, 10 | "num_hidden_layers": 4, 11 | "vocab_size": 512 12 | } 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | av==8.0.3 2 | boto3==1.18.51 3 | botocore==1.21.51 4 | deepspeed==0.5.10 5 | einops==0.3.2 6 | ffmpeg==1.4 7 | ftfy==6.0.3 8 | gpuinfo==1.0.0a7 9 | ipdb==0.13.9 10 | jsonlines==3.0.0 11 | numpy==1.20.3 12 | opencv_python==4.5.3.56 13 | pandas==1.3.3 14 | Pillow==9.2.0 15 | psutil==5.8.0 16 | regex==2021.9.24 17 | requests==2.26.0 18 | resampy==0.2.2 19 | scipy==1.7.1 20 | SoundFile==0.10.3.post1 21 | thop==0.0.31.post2005241907 22 | timm==0.4.5 23 | torch==1.7.1+cu110 24 | torchaudio==0.7.2 25 | torchvision==0.8.2+cu110 26 | tqdm==4.62.3 27 | wget==3.2 28 | -------------------------------------------------------------------------------- /run_didemo.sh: -------------------------------------------------------------------------------- 1 | # bash mykill.sh 2 | DATA_PATH=YOUR_PATH 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=30678 \ 6 | main_task_retrieval.py --do_train --num_thread_reader=6 \ 7 | --epochs=5 --batch_size=128 --n_display=50 \ 8 | --data_path ${DATA_PATH} \ 9 | --features_path ${DATA_PATH}/raw_video \ 10 | --output_dir ckpts/ckpt_didemo_retrieval_looseType \ 11 | --lr 1e-4 --coef_lr 5e-3 --yb_factor_aduio 0.01 --yb_factor_audio_decay 0 --max_words 64 --max_frames 32 --max_audio_frames 32 \ 12 | --batch_size_val 16 \ 13 | --datatype didemo --feature_framerate 1 \ 14 | --freeze_layer_num 0 --slice_framepos 2 \ 15 | --loose_type --linear_patch 2d --sim_header meanP \ 16 | --yb_dual 0 --yb_av 1 --wandb 0 --model_name ECLIPSE_didemo_f32_B32 \ 17 | --pretrained_clip_name ViT-B/32 18 | -------------------------------------------------------------------------------- /run_audio-only.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/playpen-iop/yblin/didemo_data 2 | export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | 4 | 5 | 6 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=9864 \ 7 | main_task_retrieval.py --do_train --num_thread_reader=8 \ 8 | --epochs=20 --batch_size=64 --n_display=50 \ 9 | --data_path ${DATA_PATH} \ 10 | --features_path ${DATA_PATH}/raw_video \ 11 | --output_dir ckpts/ckpt_didemo_retrieval_looseType \ 12 | --lr 1e-4 --coef_lr 1e-3 --yb_factor_aduio 0.5 --yb_factor_audio_decay 0.6 --max_words 64 --max_frames 8 --max_audio_frames 8 
--batch_size_val 16 \ 13 | --datatype CUSTOM_DATASET --feature_framerate 1 \ 14 | --freeze_layer_num 0 --slice_framepos 2 \ 15 | --loose_type --linear_patch 2d --sim_header meanP \ 16 | --yb_dual 1 --yb_av 1 --wandb 1 --model_name Rebuttal_qvh_ours-8f \ 17 | --pretrained_clip_name ViT-B/32 -------------------------------------------------------------------------------- /run_qvh.sh: -------------------------------------------------------------------------------- 1 | # bash mykill.sh 2 | DATA_PATH=YOUR_PATH 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | 6 | 7 | 8 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=5278 \ 9 | main_task_retrieval.py --do_train --num_thread_reader=8 \ 10 | --epochs=20 --batch_size=64 --n_display=50 \ 11 | --data_path ${DATA_PATH} \ 12 | --features_path ${DATA_PATH}/raw_video \ 13 | --output_dir ckpts/ckpt_didemo_retrieval_looseType \ 14 | --lr 1e-4 --coef_lr 1e-3 --yb_factor_aduio 0.5 --yb_factor_audio_decay 0.6 --max_words 128 --max_frames 32 \ 15 | --max_audio_frames 32 --batch_size_val 16 \ 16 | --datatype qvhighlight --feature_framerate 1 \ 17 | --freeze_layer_num 0 --slice_framepos 2 \ 18 | --loose_type --linear_patch 2d --sim_header meanP \ 19 | --yb_dual 0 --yb_av 1 --wandb 0 --model_name ECLIPSE_qvh_f32_B32 \ 20 | --pretrained_clip_name ViT-B/32 --audio_pt VGGSound_Audio_features_10s_aligned 21 | -------------------------------------------------------------------------------- /comp_flops.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/playpen-iop/yblin/v1-2 2 | export CUDA_VISIBLE_DEVICES=4 3 | python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=30678 \ 4 | compute_cost.py --do_train --num_thread_reader=0 \ 5 | --epochs=20 --batch_size=64 --n_display=50 \ 6 | --data_path ${DATA_PATH} \ 7 | --features_path ${DATA_PATH}/Activity_Videos \ 8 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 9 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 96 --max_audio_frames 96 \ 10 | --audio_cluster 4 --batch_size_val 1 --gradient_accumulation_steps 1 \ 11 | --datatype activity --feature_framerate 1 \ 12 | --freeze_layer_num 0 --slice_framepos 2 \ 13 | --loose_type --linear_patch 2d --sim_header meanP \ 14 | --yb_coff_dis 0 --yb_coff_loss 1 --yb_factor_1 0.001 --yb_audio_length 10 \ 15 | --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name VGGSound_10s_co-Res_-5block_baseline_16 \ 16 | --audio_pt VGGSound_Audio_features_10s_aligned 17 | -------------------------------------------------------------------------------- /run_cha.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=YOUR_PATH 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3 3 | 4 | python3 -m torch.distributed.launch --nproc_per_node=8 --master_port=30678 \ 5 | main_task_retrieval.py --do_train --num_thread_reader=8 \ 6 | --epochs=20 --batch_size=64 --n_display=50 \ 7 | --data_path ${DATA_PATH} \ 8 | --features_path ${DATA_PATH}/Activity_Videos \ 9 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 10 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 32 --max_audio_frames 32 \ 11 | --audio_cluster 4 --batch_size_val 16 --gradient_accumulation_steps 1 \ 12 | --datatype charades --feature_framerate 1 \ 13 | --freeze_layer_num 0 --slice_framepos 2 \ 14 | --loose_type --linear_patch 2d --sim_header meanP \ 15 | --yb_av 1 --yb_dual 1 --yb_factor_aduio 50 --yb_factor_audio_decay 0.6 --yb_audio_length 10 \ 16 | --pretrained_clip_name ViT-B/32 --wandb 
1 --fixed_length 8 --model_name ECLIPSE_cha_f32_B32 \ 17 | --audio_pt VGGSound_Audio_features_10s_aligned 18 | 19 | 20 | -------------------------------------------------------------------------------- /debug.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/playpen-iop/yblin/v1-2 2 | export CUDA_VISIBLE_DEVICES=4 3 | python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=5277 \ 4 | main_task_retrieval.py --do_train --num_thread_reader=0 \ 5 | --epochs=1 --batch_size=64 --n_display=50 \ 6 | --data_path ${DATA_PATH} \ 7 | --features_path ${DATA_PATH}/Activity_Videos \ 8 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 9 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 2 --max_audio_frames 2 \ 10 | --audio_cluster 16 --batch_size_val 1 --gradient_accumulation_steps 1 \ 11 | --datatype qvhighlight --feature_framerate 1 \ 12 | --freeze_layer_num 0 --slice_framepos 2 \ 13 | --loose_type --linear_patch 2d --sim_header meanP --yb_audio_length 1 --yb_dual 0 --yb_av 0 \ 14 | --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name DistKL_avblock_ACLS+baseline_8 \ 15 | --audio_pt VGGSound_Audio_features_20s_aligned 16 | # --audio_pt Audio_features_all_10s_aligned_stride_16 17 | 18 | -------------------------------------------------------------------------------- /run_yc2.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=YOUR_PATH 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=30668 \ 6 | main_task_retrieval.py --do_train --num_thread_reader=8 \ 7 | --epochs=20 --batch_size=64 --n_display=50 \ 8 | --data_path ${DATA_PATH} \ 9 | --features_path ${DATA_PATH}/Activity_Videos \ 10 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 11 | --lr 1e-4 --coef_lr 1e-3 --max_words 128 --max_frames 32 --max_audio_frames 32 \ 12 | --audio_cluster 4 --batch_size_val 64 --gradient_accumulation_steps 1 \ 13 | --datatype youcook --feature_framerate 1 \ 14 | --freeze_layer_num 0 --slice_framepos 2 \ 15 | --loose_type --linear_patch 2d --sim_header meanP \ 16 | --yb_coff_dis 0 --yb_coff_loss 1 --yb_av 1 --yb_dual 1 --yb_factor_aduio 1 --yb_factor_audio_decay 0.1 --yb_audio_length 10 \ 17 | --pretrained_clip_name ViT-B/32 --wandb 1 --fixed_length 8 --model_name ECLIPSE_yc2_f32_B32 \ 18 | --audio_pt VGGSound_Audio_features_10s_aligned 19 | -------------------------------------------------------------------------------- /run_act-B16.sh: -------------------------------------------------------------------------------- 1 | # bash mykill.sh 2 | DATA_PATH=YOUR_DATA_PATH 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 4 | 5 | 6 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=7414 \ 7 | main_task_retrieval.py --do_train --num_thread_reader=6 \ 8 | --epochs=40 --batch_size=64 --n_display=50 \ 9 | --data_path ${DATA_PATH} \ 10 | --features_path ${DATA_PATH}/Activity_Videos \ 11 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 12 | --lr 1e-4 --coef_lr 1e-3 --max_words 128 --max_frames 32 --max_audio_frames 32 \ 13 | --audio_cluster 4 --batch_size_val 8 --gradient_accumulation_steps 1 \ 14 | --datatype activity --feature_framerate 1 \ 15 | --freeze_layer_num 0 --slice_framepos 2 \ 16 | --loose_type --linear_patch 2d --sim_header meanP \ 17 | --yb_av 1 --yb_dual 1 --yb_factor_aduio 0.5 --yb_time_cross_audio 0 --yb_factor_audio_decay 0.2 --yb_reverse_norm 1 \ 18 | --yb_audio_length 10 \ 
19 | --pretrained_clip_name ViT-B/16 --wandb 1 --fixed_length 8 --model_name ECLIPSE_act_f32_B16 \ 20 | --audio_pt VGGSound_Audio_features_10s_aligned 21 | 22 | 23 | -------------------------------------------------------------------------------- /run_act-B32.sh: -------------------------------------------------------------------------------- 1 | # bash mykill.sh 2 | # train on 4 gpus 3 | DATA_PATH=YOUR_DATA_PATH 4 | export CUDA_VISIBLE_DEVICES=0,1,2,3 5 | 6 | 7 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=7414 \ 8 | main_task_retrieval.py --do_train --num_thread_reader=6 \ 9 | --epochs=40 --batch_size=64 --n_display=50 \ 10 | --data_path ${DATA_PATH} \ 11 | --features_path ${DATA_PATH}/Activity_Videos \ 12 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 13 | --lr 1e-4 --coef_lr 1e-3 --max_words 128 --max_frames 32 --max_audio_frames 32 \ 14 | --audio_cluster 4 --batch_size_val 8 --gradient_accumulation_steps 1 \ 15 | --datatype activity --feature_framerate 1 \ 16 | --freeze_layer_num 0 --slice_framepos 2 \ 17 | --loose_type --linear_patch 2d --sim_header meanP \ 18 | --yb_av 1 --yb_dual 1 --yb_factor_aduio 0.5 --yb_time_cross_audio 0 --yb_factor_audio_decay 0.2 --yb_reverse_norm 1 \ 19 | --yb_audio_length 10 \ 20 | --pretrained_clip_name ViT-B/32 --wandb 1 --fixed_length 8 --model_name ECLIPSE_act_f32_B32 \ 21 | --audio_pt VGGSound_Audio_features_10s_aligned 22 | 23 | 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yan-Bo Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /dataloaders/video_container.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import av 5 | 6 | 7 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"): 8 | """ 9 | Given the path to the video, return the pyav video container. 10 | Args: 11 | path_to_vid (str): path to the video. 12 | multi_thread_decode (bool): if True, perform multi-thread decoding. 13 | backend (str): decoder backend, options include `pyav` and 14 | `torchvision`, default is `pyav`. 15 | Returns: 16 | container (container): video container. 
17 | """ 18 | if backend == "torchvision": 19 | with open(path_to_vid, "rb") as fp: 20 | container = fp.read() 21 | return container 22 | elif backend == "pyav": 23 | container = av.open(path_to_vid) 24 | if multi_thread_decode: 25 | # Enable multiple threads for decoding. 26 | container.streams.video[0].thread_type = "AUTO" 27 | return container 28 | else: 29 | raise NotImplementedError("Unknown backend {}".format(backend)) -------------------------------------------------------------------------------- /preproc.sh: -------------------------------------------------------------------------------- 1 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/training --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_train 2 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/validation --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_val 3 | 4 | 5 | 6 | python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_train --output_root /playpen-iop/yblin/yk2/audio_raw_train 7 | python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_val --output_root /playpen-iop/yblin/yk2/audio_raw_val 8 | 9 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_train --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_train 10 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_val --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_val 11 | 12 | 13 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_train --output_root /playpen-iop/yblin/yk2/raw_videos_all/videos_frame_train 14 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_val --output_root /playpen-iop/yblin/yk2/raw_videos_all/videos_frame_val -------------------------------------------------------------------------------- /feature_ext.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/playpen-iop/yblin/qvhighlight 2 | DATA_PATH=/playpen-storage/yblin/v1-2 3 | export CUDA_VISIBLE_DEVICES=0 4 | 5 | # bash mykill.sh 6 | 7 | python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=7455 \ 8 | yb_feature_ext.py --do_train --num_thread_reader=32 \ 9 | --epochs=1 --batch_size=1 --batch_size_val 1 --n_display=50 \ 10 | --data_path ${DATA_PATH} \ 11 | --features_path ${DATA_PATH}/Activity_Videos \ 12 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 13 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 64 --max_audio_frames 32 \ 14 | --audio_cluster 16 --gradient_accumulation_steps 1 \ 15 | --datatype activity --feature_framerate 1 \ 16 | --freeze_layer_num 0 --slice_framepos 2 \ 17 | --loose_type --linear_patch 2d --sim_header meanP --yb_audio_length 10 \ 18 | --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name DistKL_avblock_ACLS+baseline_8 \ 19 | --yb_av 1 --yb_dual 1 --yb_start_layer 6 \ 20 | --audio_pt VGGSound_Audio_features_10s_aligned 21 | 22 | 23 | 24 | # DATA_PATH=/playpen-iop/yblin/v1-2 25 | # python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=7445 \ 26 | # yb_feature_ext.py --do_train --num_thread_reader=0 \ 27 | # --epochs=1 --batch_size=1 --n_display=50 \ 28 | # --data_path ${DATA_PATH} \ 29 | # --features_path ${DATA_PATH}/Activity_Videos \ 30 | # --output_dir 
ckpts/ckpt_activity_retrieval_looseType \ 31 | # --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 2 --max_audio_frames 2 \ 32 | # --audio_cluster 16 --batch_size_val 1 --gradient_accumulation_steps 1 \ 33 | # --datatype youcook --feature_framerate 1 \ 34 | # --freeze_layer_num 0 --slice_framepos 2 \ 35 | # --loose_type --linear_patch 2d --sim_header meanP --yb_audio_length 10 \ 36 | # --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name DistKL_avblock_ACLS+baseline_8 \ 37 | # --audio_pt VGGSound_Audio_features_10s_aligned -------------------------------------------------------------------------------- /dataloaders/vggish_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Global parameters for the VGGish model. 17 | 18 | See vggish_slim.py for more information. 19 | """ 20 | 21 | # Architectural constants. 22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. 23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. 24 | EMBEDDING_SIZE = 128 # Size of embedding layer. 25 | 26 | # Hyperparameters used in feature and example generation. 27 | SAMPLE_RATE = 16000 28 | STFT_WINDOW_LENGTH_SECONDS = 0.025 29 | STFT_HOP_LENGTH_SECONDS = 0.010 30 | NUM_MEL_BINS = NUM_BANDS 31 | MEL_MIN_HZ = 125 32 | MEL_MAX_HZ = 7500 33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. 34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames 35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. 36 | 37 | # Parameters used for embedding postprocessing. 38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' 39 | PCA_MEANS_NAME = 'pca_means' 40 | QUANTIZE_MIN_VAL = -2.0 41 | QUANTIZE_MAX_VAL = +2.0 42 | 43 | # Hyperparameters used in training. 44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. 46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 47 | 48 | # Names of ops, tensors, and features. 49 | INPUT_OP_NAME = 'vggish/input_features' 50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' 51 | OUTPUT_OP_NAME = 'vggish/embedding' 52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' 53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' 54 | -------------------------------------------------------------------------------- /torchvggish/vggish_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Global parameters for the VGGish model. 17 | 18 | See vggish_slim.py for more information. 19 | """ 20 | 21 | # Architectural constants. 22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. 23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. 24 | EMBEDDING_SIZE = 128 # Size of embedding layer. 25 | 26 | # Hyperparameters used in feature and example generation. 27 | SAMPLE_RATE = 16000 28 | STFT_WINDOW_LENGTH_SECONDS = 0.025 29 | STFT_HOP_LENGTH_SECONDS = 0.010 30 | NUM_MEL_BINS = NUM_BANDS 31 | MEL_MIN_HZ = 125 32 | MEL_MAX_HZ = 7500 33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. 34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames 35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. 36 | 37 | # Parameters used for embedding postprocessing. 38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' 39 | PCA_MEANS_NAME = 'pca_means' 40 | QUANTIZE_MIN_VAL = -2.0 41 | QUANTIZE_MAX_VAL = +2.0 42 | 43 | # Hyperparameters used in training. 44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. 46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 47 | 48 | # Names of ops, tensors, and features. 
49 | INPUT_OP_NAME = 'vggish/input_features' 50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' 51 | OUTPUT_OP_NAME = 'vggish/embedding' 52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' 53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' 54 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import threading 4 | from torch._utils import ExceptionWrapper 5 | import logging 6 | 7 | def get_a_var(obj): 8 | if isinstance(obj, torch.Tensor): 9 | return obj 10 | 11 | if isinstance(obj, list) or isinstance(obj, tuple): 12 | for result in map(get_a_var, obj): 13 | if isinstance(result, torch.Tensor): 14 | return result 15 | if isinstance(obj, dict): 16 | for result in map(get_a_var, obj.items()): 17 | if isinstance(result, torch.Tensor): 18 | return result 19 | return None 20 | 21 | def parallel_apply(fct, model, inputs, device_ids): 22 | modules = nn.parallel.replicate(model, device_ids) 23 | assert len(modules) == len(inputs) 24 | lock = threading.Lock() 25 | results = {} 26 | grad_enabled = torch.is_grad_enabled() 27 | 28 | def _worker(i, module, input): 29 | torch.set_grad_enabled(grad_enabled) 30 | device = get_a_var(input).get_device() 31 | try: 32 | with torch.cuda.device(device): 33 | # this also avoids accidental slicing of `input` if it is a Tensor 34 | if not isinstance(input, (list, tuple)): 35 | input = (input,) 36 | output = fct(module, *input) 37 | with lock: 38 | results[i] = output 39 | except Exception: 40 | with lock: 41 | results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) 42 | 43 | if len(modules) > 1: 44 | threads = [threading.Thread(target=_worker, args=(i, module, input)) 45 | for i, (module, input) in enumerate(zip(modules, inputs))] 46 | 47 | for thread in threads: 48 | thread.start() 49 | for thread in threads: 50 | thread.join() 51 | else: 52 | _worker(0, modules[0], inputs[0]) 53 | 54 | outputs = [] 55 | for i in range(len(inputs)): 56 | output = results[i] 57 | if isinstance(output, ExceptionWrapper): 58 | output.reraise() 59 | outputs.append(output) 60 | return outputs 61 | 62 | def get_logger(filename=None): 63 | logger = logging.getLogger('logger') 64 | logger.setLevel(logging.DEBUG) 65 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 66 | datefmt='%m/%d/%Y %H:%M:%S', 67 | level=logging.INFO) 68 | if filename is not None: 69 | handler = logging.FileHandler(filename) 70 | handler.setLevel(logging.DEBUG) 71 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 72 | logging.getLogger().addHandler(handler) 73 | return logger -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # ECLIPSE: Efficient Long-range Video Retrieval using Sight and Sound 3 | 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 5 | 6 | 7 | This is the PyTorch implementation of our paper:
8 | **ECLIPSE: Efficient Long-range Video Retrieval using Sight and Sound**
9 | [Yan-Bo Lin](https://genjib.github.io/), [Jie Lei](https://jayleicn.github.io/), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/), and [Gedas Bertasius](https://www.gedasbertasius.com/)
10 | In European Conference on Computer Vision (ECCV), 2022.<br>
11 | 12 | [paper](https://arxiv.org/abs/2204.02874) 13 | 14 | ### 📝 Preparation 15 | 1. `pip3 install -r requirements.txt` 16 | 2. Datasets: ActivityNet, QVHighlights, YouCook2, DiDeMo, and Charades. 17 | 3. Extract video frames at 3 fps (an example command is given at the end of this README). 18 | 4. Extract audio features. 19 | 5. Download the pretrained CLIP weights. 20 | 21 | The download links below are from the official [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip) repo. 22 | Download the CLIP (ViT-B/32) weights, 23 | ```sh 24 | wget -P ./modules https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt 25 | ``` 26 | or download the CLIP (ViT-B/16) weights, 27 | ```sh 28 | wget -P ./modules https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt 29 | ``` 30 | 31 |
32 | ### 💿 Expected layout of extracted frames and audio features 33 | ```shell 34 | ActivityNet/ 35 | ├── raw_frames/ 36 | │ └── VIDEO_NAME/ 37 | │ ├── 0001.jpg 38 | │ ├── ... 39 | │ └── 00...jpg 40 | │ 41 | └── VGGSound_Audio_features_10s_aligned/ 42 | └── VIDEO_NAME/ 43 | ├── 0000.pt 44 | ├── ... 45 | └── 00...pt 46 | 47 | ``` 48 | 49 | 50 | 51 | ### 💿 Pre-extracted audio features 52 | VGGSound features on ActivityNet Captions: [Google Drive](https://drive.google.com/file/d/1PbZPrgO5HTuG_CORcS_zScQCUeFo1JOL/view?usp=sharing) 53 |
54 | ### 📚 Train and evaluate 55 | ActivityNet Captions: `bash run_act-B32.sh` (or `bash run_act-B16.sh` for ViT-B/16) \ 56 | DiDeMo: `bash run_didemo.sh` \ 57 | Charades: `bash run_cha.sh` \ 58 | QVHighlights: `bash run_qvh.sh` \ 59 | YouCook2: `bash run_yc2.sh` 60 | 61 | 62 | 63 | 64 | ### 🎓 Cite 65 | 66 | If you use this code in your research, please cite: 67 | 68 | ```bibtex 69 | @InProceedings{ECLIPSE_ECCV22, 70 | author = {Yan-Bo Lin and Jie Lei and Mohit Bansal and Gedas Bertasius}, 71 | title = {ECLIPSE: Efficient Long-range Video Retrieval using Sight and Sound}, 72 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)}, 73 | month = {October}, 74 | year = {2022} 75 | } 76 | ``` 77 |
78 | ### 👍 Acknowledgments 79 | Our code is based on [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip) and [VGGSound](https://www.robots.ox.ac.uk/~vgg/data/vggsound/). 80 | 81 | ### ✏ Future work 82 | * Preprocessed video frames and audio features 83 | 84 | 85 | ## License 86 | 87 | This project is licensed under the [MIT License](LICENSE), as found in the LICENSE file.
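**Example: extracting frames and audio (illustrative).** Frame and audio extraction in this repo is handled by `preprocess/compress_video.py` (see `preprocess.sh` and `preproc.sh`); the plain-`ffmpeg` commands below are only a rough sketch of the equivalent steps, with `VIDEO.mp4`, `raw_frames/`, and `raw_audio/` used as placeholder names:

```sh
# Decode RGB frames at 3 fps (the frame rate assumed in dataloaders/rawvideo_util.py)
mkdir -p raw_frames/VIDEO_NAME
ffmpeg -i VIDEO.mp4 -vf fps=3 raw_frames/VIDEO_NAME/%04d.jpg

# Extract mono 16 kHz audio (the sample rate expected by the VGGish/VGGSound feature pipeline)
mkdir -p raw_audio/VIDEO_NAME
ffmpeg -i VIDEO.mp4 -vn -ac 1 -ar 16000 raw_audio/VIDEO_NAME/0000.wav
```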
88 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | 9 | def compute_metrics(x): 10 | sx = np.sort(-x, axis=1) 11 | d = np.diag(-x) 12 | d = d[:, np.newaxis] 13 | ind = sx - d 14 | ind = np.where(ind == 0) 15 | ind = ind[1] 16 | metrics = {} 17 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 18 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 19 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 20 | metrics['R50'] = float(np.sum(ind < 50)) * 100 / len(ind) 21 | metrics['MR'] = np.median(ind) + 1 22 | metrics["MedianR"] = metrics['MR'] 23 | metrics["MeanR"] = np.mean(ind) + 1 24 | metrics["cols"] = [int(i) for i in list(ind)] 25 | return metrics 26 | 27 | def print_computed_metrics(metrics): 28 | r1 = metrics['R1'] 29 | r5 = metrics['R5'] 30 | r10 = metrics['R10'] 31 | mr = metrics['MR'] 32 | print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) 33 | 34 | # below two functions directly come from: https://github.com/Deferf/Experiments 35 | def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10]): 36 | if not torch.is_tensor(sim_tensor): 37 | sim_tensor = torch.tensor(sim_tensor) 38 | 39 | # Permute sim_tensor so it represents a sequence of text-video similarity matrices. 40 | # Then obtain the double argsort to position the rank on the diagonal 41 | stacked_sim_matrices = sim_tensor.permute(1, 0, 2) 42 | first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) 43 | second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) 44 | 45 | # Extracts ranks i.e diagonals 46 | ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) 47 | 48 | # Now we need to extract valid ranks, as some belong to inf padding values 49 | permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) 50 | mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) 51 | valid_ranks = ranks[mask] 52 | # A quick dimension check validates our results, there may be other correctness tests pending 53 | # Such as dot product localization, but that is for other time. 
54 | #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) 55 | if not torch.is_tensor(valid_ranks): 56 | valid_ranks = torch.tensor(valid_ranks) 57 | results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} 58 | results["MedianR"] = float(torch.median(valid_ranks + 1)) 59 | results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) 60 | results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) 61 | results['MR'] = results["MedianR"] 62 | return results 63 | 64 | def tensor_video_to_text_sim(sim_tensor): 65 | if not torch.is_tensor(sim_tensor): 66 | sim_tensor = torch.tensor(sim_tensor) 67 | # Code to avoid nans 68 | sim_tensor[sim_tensor != sim_tensor] = float('-inf') 69 | # Forms a similarity matrix for use with rank at k 70 | values, _ = torch.max(sim_tensor, dim=1, keepdim=True) 71 | return torch.squeeze(values).T 72 | -------------------------------------------------------------------------------- /torchvggish/vggish_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute input examples for VGGish from audio waveform.""" 17 | 18 | # Modification: Return torch tensors rather than numpy arrays 19 | import torch 20 | 21 | import numpy as np 22 | import resampy 23 | 24 | from . import mel_features 25 | from . import vggish_params 26 | 27 | import soundfile as sf 28 | 29 | from ipdb import set_trace 30 | 31 | def waveform_to_examples(data, sample_rate, return_tensor=True): 32 | """Converts audio waveform into an array of examples for VGGish. 33 | 34 | Args: 35 | data: np.array of either one dimension (mono) or two dimensions 36 | (multi-channel, with the outer dimension representing channels). 37 | Each sample is generally expected to lie in the range [-1.0, +1.0], 38 | although this is not required. 39 | sample_rate: Sample rate of data. 40 | return_tensor: Return data as a Pytorch tensor ready for VGGish 41 | 42 | Returns: 43 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents 44 | a sequence of examples, each of which contains a patch of log mel 45 | spectrogram, covering num_frames frames of audio and num_bands mel frequency 46 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. 47 | 48 | """ 49 | # Convert to mono. 50 | if len(data.shape) > 1: 51 | data = np.mean(data, axis=1) 52 | # Resample to the rate assumed by VGGish. 53 | if sample_rate != vggish_params.SAMPLE_RATE: 54 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) 55 | 56 | set_trace() 57 | # Compute log mel spectrogram features. 
58 | log_mel = mel_features.log_mel_spectrogram( 59 | data, 60 | audio_sample_rate=vggish_params.SAMPLE_RATE, 61 | log_offset=vggish_params.LOG_OFFSET, 62 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, 63 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, 64 | num_mel_bins=vggish_params.NUM_MEL_BINS, 65 | lower_edge_hertz=vggish_params.MEL_MIN_HZ, 66 | upper_edge_hertz=vggish_params.MEL_MAX_HZ) 67 | 68 | # Frame features into examples. 69 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS 70 | example_window_length = int(round( 71 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) 72 | example_hop_length = int(round( 73 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) 74 | log_mel_examples = mel_features.frame( 75 | log_mel, 76 | window_length=example_window_length, 77 | hop_length=example_hop_length) 78 | 79 | if return_tensor: 80 | log_mel_examples = torch.tensor( 81 | log_mel_examples, requires_grad=True)[:, None, :, :].float() 82 | 83 | set_trace() 84 | return log_mel_examples 85 | 86 | 87 | def wavfile_to_examples(wav_file, return_tensor=True): 88 | """Convenience wrapper around waveform_to_examples() for a common WAV format. 89 | 90 | Args: 91 | wav_file: String path to a file, or a file-like object. The file 92 | is assumed to contain WAV audio data with signed 16-bit PCM samples. 93 | torch: Return data as a Pytorch tensor ready for VGGish 94 | 95 | Returns: 96 | See waveform_to_examples. 97 | """ 98 | wav_data, sr = sf.read(wav_file, dtype='int16') 99 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype 100 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] 101 | 102 | # sr = 16000 103 | return waveform_to_examples(samples, sr, return_tensor) 104 | -------------------------------------------------------------------------------- /dataloaders/rawvideo_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from PIL import Image 4 | # pytorch=1.7.1 5 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 6 | # pip install opencv-python 7 | import cv2 8 | from ipdb import set_trace 9 | 10 | class RawVideoExtractorCV2(): 11 | def __init__(self, centercrop=False, size=224, framerate=-1, ): 12 | self.centercrop = centercrop 13 | self.size = size 14 | # self.framerate = framerate 15 | self.framerate = 3 16 | self.transform = self._transform(self.size) 17 | 18 | def _transform(self, n_px): 19 | return Compose([ 20 | Resize(n_px, interpolation=Image.BICUBIC), 21 | CenterCrop(n_px), 22 | lambda image: image.convert("RGB"), 23 | ToTensor(), 24 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 25 | ]) 26 | 27 | def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None): 28 | if start_time is not None or end_time is not None: 29 | assert isinstance(start_time, int) and isinstance(end_time, int) \ 30 | and start_time > -1 and end_time > start_time 31 | assert sample_fp > -1 32 | 33 | # Samples a frame sample_fp X frames. 
34 | cap = cv2.VideoCapture(video_file) 35 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 36 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 37 | 38 | # try: 39 | total_duration = (frameCount + fps - 1) // fps 40 | # print('right: ',fps) 41 | # except: 42 | # print('fpssssssssssssssssssssssss:', fps) 43 | # set_trace() 44 | start_sec, end_sec = 0, total_duration 45 | 46 | if start_time is not None: 47 | start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration 48 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) 49 | 50 | interval = 1 51 | if sample_fp > 0: 52 | interval = fps // sample_fp 53 | else: 54 | sample_fp = fps 55 | if interval == 0: interval = 1 56 | 57 | inds = [ind for ind in np.arange(0, fps, interval)] 58 | assert len(inds) >= sample_fp 59 | inds = inds[:sample_fp] 60 | 61 | ret = True 62 | images, included = [], [] 63 | 64 | for sec in np.arange(start_sec, end_sec + 1): 65 | if not ret: break 66 | sec_base = int(sec * fps) 67 | for ind in inds: 68 | cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) 69 | ret, frame = cap.read() 70 | if not ret: break 71 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 72 | images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) 73 | 74 | cap.release() 75 | 76 | if len(images) > 0: 77 | video_data = th.tensor(np.stack(images)) 78 | else: 79 | video_data = th.zeros(1) 80 | 81 | return {'video': video_data} 82 | 83 | def get_video_data(self, video_path, start_time=None, end_time=None): 84 | image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time) 85 | return image_input 86 | 87 | def process_raw_data(self, raw_video_data): 88 | tensor_size = raw_video_data.size() 89 | tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1]) 90 | return tensor 91 | 92 | def process_frame_order(self, raw_video_data, frame_order=0): 93 | # 0: ordinary order; 1: reverse order; 2: random order. 94 | if frame_order == 0: 95 | pass 96 | elif frame_order == 1: 97 | reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) 98 | raw_video_data = raw_video_data[reverse_order, ...] 99 | elif frame_order == 2: 100 | random_order = np.arange(raw_video_data.size(0)) 101 | np.random.shuffle(random_order) 102 | raw_video_data = raw_video_data[random_order, ...] 103 | 104 | return raw_video_data 105 | 106 | # An ordinary video frame extractor based CV2 107 | RawVideoExtractor = RawVideoExtractorCV2 -------------------------------------------------------------------------------- /modules/until_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import logging 26 | import tarfile 27 | import tempfile 28 | import shutil 29 | import torch 30 | from .file_utils import cached_path 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | class PretrainedConfig(object): 35 | 36 | pretrained_model_archive_map = {} 37 | config_name = "" 38 | weights_name = "" 39 | 40 | @classmethod 41 | def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None): 42 | archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name) 43 | if os.path.exists(archive_file) is False: 44 | if pretrained_model_name in cls.pretrained_model_archive_map: 45 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name] 46 | else: 47 | archive_file = pretrained_model_name 48 | 49 | # redirect to the cache, if necessary 50 | try: 51 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 52 | except FileNotFoundError: 53 | if task_config is None or task_config.local_rank == 0: 54 | logger.error( 55 | "Model name '{}' was not found in model name list. " 56 | "We assumed '{}' was a path or url but couldn't find any file " 57 | "associated to this path or url.".format( 58 | pretrained_model_name, 59 | archive_file)) 60 | return None 61 | if resolved_archive_file == archive_file: 62 | if task_config is None or task_config.local_rank == 0: 63 | logger.info("loading archive file {}".format(archive_file)) 64 | else: 65 | if task_config is None or task_config.local_rank == 0: 66 | logger.info("loading archive file {} from cache at {}".format( 67 | archive_file, resolved_archive_file)) 68 | tempdir = None 69 | if os.path.isdir(resolved_archive_file): 70 | serialization_dir = resolved_archive_file 71 | else: 72 | # Extract archive to temp dir 73 | tempdir = tempfile.mkdtemp() 74 | if task_config is None or task_config.local_rank == 0: 75 | logger.info("extracting archive file {} to temp dir {}".format( 76 | resolved_archive_file, tempdir)) 77 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 78 | archive.extractall(tempdir) 79 | serialization_dir = tempdir 80 | # Load config 81 | config_file = os.path.join(serialization_dir, cls.config_name) 82 | config = cls.from_json_file(config_file) 83 | config.type_vocab_size = type_vocab_size 84 | if task_config is None or task_config.local_rank == 0: 85 | logger.info("Model config {}".format(config)) 86 | 87 | if state_dict is None: 88 | weights_path = os.path.join(serialization_dir, cls.weights_name) 89 | if os.path.exists(weights_path): 90 | state_dict = torch.load(weights_path, map_location='cpu') 91 | else: 92 | if task_config is None or task_config.local_rank == 0: 93 | logger.info("Weight doesn't exsits. 
{}".format(weights_path)) 94 | 95 | if tempdir: 96 | # Clean up temp dir 97 | shutil.rmtree(tempdir) 98 | 99 | return config, state_dict 100 | 101 | @classmethod 102 | def from_dict(cls, json_object): 103 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 104 | config = cls(vocab_size_or_config_json_file=-1) 105 | for key, value in json_object.items(): 106 | config.__dict__[key] = value 107 | return config 108 | 109 | @classmethod 110 | def from_json_file(cls, json_file): 111 | """Constructs a `BertConfig` from a json file of parameters.""" 112 | with open(json_file, "r", encoding='utf-8') as reader: 113 | text = reader.read() 114 | return cls.from_dict(json.loads(text)) 115 | 116 | def __repr__(self): 117 | return str(self.to_json_string()) 118 | 119 | def to_dict(self): 120 | """Serializes this instance to a Python dictionary.""" 121 | output = copy.deepcopy(self.__dict__) 122 | return output 123 | 124 | def to_json_string(self): 125 | """Serializes this instance to a JSON string.""" 126 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" -------------------------------------------------------------------------------- /modules/tokenization_clip.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | from ipdb import set_trace 10 | 11 | @lru_cache() 12 | def default_bpe(): 13 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 14 | 15 | 16 | @lru_cache() 17 | def bytes_to_unicode(): 18 | """ 19 | Returns list of utf-8 byte and a corresponding list of unicode strings. 20 | The reversible bpe codes work on unicode strings. 21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 23 | This is a signficant percentage of your normal, say, 32K bpe vocab. 24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 25 | And avoids mapping to whitespace/control characters the bpe code barfs on. 26 | """ 27 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 28 | cs = bs[:] 29 | n = 0 30 | for b in range(2**8): 31 | if b not in bs: 32 | bs.append(b) 33 | cs.append(2**8+n) 34 | n += 1 35 | cs = [chr(n) for n in cs] 36 | return dict(zip(bs, cs)) 37 | 38 | 39 | def get_pairs(word): 40 | """Return set of symbol pairs in a word. 41 | Word is represented as tuple of symbols (symbols being variable-length strings). 
42 | """ 43 | pairs = set() 44 | prev_char = word[0] 45 | for char in word[1:]: 46 | pairs.add((prev_char, char)) 47 | prev_char = char 48 | return pairs 49 | 50 | 51 | def basic_clean(text): 52 | text = ftfy.fix_text(text) 53 | text = html.unescape(html.unescape(text)) 54 | return text.strip() 55 | 56 | 57 | def whitespace_clean(text): 58 | text = re.sub(r'\s+', ' ', text) 59 | text = text.strip() 60 | return text 61 | 62 | 63 | class SimpleTokenizer(object): 64 | def __init__(self, bpe_path: str = default_bpe()): 65 | self.byte_encoder = bytes_to_unicode() 66 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 67 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 68 | merges = merges[1:49152-256-2+1] 69 | merges = [tuple(merge.split()) for merge in merges] 70 | vocab = list(bytes_to_unicode().values()) 71 | vocab = vocab + [v+'' for v in vocab] 72 | for merge in merges: 73 | vocab.append(''.join(merge)) 74 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 75 | self.encoder = dict(zip(vocab, range(len(vocab)))) 76 | self.decoder = {v: k for k, v in self.encoder.items()} 77 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 78 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 79 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 80 | 81 | self.vocab = self.encoder 82 | 83 | def bpe(self, token): 84 | if token in self.cache: 85 | return self.cache[token] 86 | word = tuple(token[:-1]) + ( token[-1] + '',) 87 | pairs = get_pairs(word) 88 | 89 | if not pairs: 90 | return token+'' 91 | 92 | while True: 93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 94 | if bigram not in self.bpe_ranks: 95 | break 96 | first, second = bigram 97 | new_word = [] 98 | i = 0 99 | while i < len(word): 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except: 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | 137 | def tokenize(self, text): 138 | tokens = [] 139 | text = whitespace_clean(basic_clean(text)).lower() 140 | for token in re.findall(self.pat, text): 141 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 142 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 143 | return tokens 144 | 145 | def convert_tokens_to_ids(self, tokens): 146 | return [self.encoder[bpe_token] for bpe_token in tokens] 
-------------------------------------------------------------------------------- /preprocess/compress_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to compress video in: https://github.com/ArrowLuo/CLIP4Clip 3 | Author: ArrowLuo 4 | """ 5 | import os 6 | import argparse 7 | import ffmpeg 8 | import subprocess 9 | import time 10 | import multiprocessing 11 | from multiprocessing import Pool 12 | import shutil 13 | try: 14 | from psutil import cpu_count 15 | except: 16 | from multiprocessing import cpu_count 17 | # multiprocessing.freeze_support() 18 | from ipdb import set_trace 19 | def compress(paras): 20 | input_video_path, output_video_path = paras 21 | output_video_path_ori = output_video_path #.split('.')[0]+'.mp4' 22 | 23 | 24 | output_video_path = os.path.splitext(output_video_path)[0] #+'.mp4' 25 | 26 | 27 | ## for audio/images extractation 28 | output_img_path = output_video_path + '/%04d.jpg' 29 | output_audio_path = output_video_path + '/%04d.wav' 30 | # output_audio2_path = output_video_path.split('.')[0]+'.wav' 31 | output_audio2_path = output_video_path + '.wav' 32 | try: 33 | # command = ['ffmpeg', 34 | # '-y', # (optional) overwrite output file if it exists 35 | # '-i', input_video_path, 36 | # '-c:v', 37 | # 'libx264', 38 | # '-c:a', 39 | # 'libmp3lame', 40 | # '-b:a', 41 | # '128K', 42 | # '-max_muxing_queue_size', '9999', 43 | # '-vf', 44 | # 'fps=3 ', # scale to 224 "scale=\'if(gt(a,1),trunc(oh*a/2)*2,224)\':\'if(gt(a,1),224,trunc(ow*a/2)*2)\'" 45 | # # '-max_muxing_queue_size', '9999', 46 | # # "scale=224:224", 47 | # # '-c:a', 'copy', 48 | # # 'fps=fps=30', # frames per second 49 | # output_video_path_ori, 50 | # ] 51 | 52 | ### ori compressed ----------> 53 | # command = ['ffmpeg', 54 | # '-y', # (optional) overwrite output file if it exists 55 | # '-i', input_video_path, 56 | # '-filter:v', 57 | # 'scale=\'if(gt(a,1),trunc(oh*a/2)*2,224)\':\'if(gt(a,1),224,trunc(ow*a/2)*2)\'', # scale to 224 58 | # '-map', '0:v', 59 | # '-r', '3', # frames per second 60 | # output_video_path_ori, 61 | # ] 62 | ########### <---------------- 63 | 64 | ############# for extract images ############## 65 | # command = ['ffmpeg', 66 | # '-y', # (optional) overwrite output file if it exists 67 | # '-i', input_video_path, 68 | # '-vf', 69 | # 'fps=3', 70 | # output_img_path, 71 | # ] 72 | ########## end extract images ################################### 73 | 74 | 75 | ######### for extract audio -----> 76 | # ffmpeg -i /playpen-iop/yblin/v1-2/train/v_XazKuBawFCM.mp4 -map 0:a -f segment -segment_time 10 -acodec pcm_s16le -ac 1 -ar 16000 /playpen-iop/yblin/v1-2/train_audio/output_%03d.wav 77 | # ffmpeg -y -i /playpen-iop/yblin/yk2/raw_videos_all/low_all_val/EpNUSTO2BI4.mp4 -map 0:a -f segment -segment_time 10000000 -acodec pcm_s16le -ac 1 -ar 16000 /playpen-iop/yblin/yk2/audio_raw_val/EpNUSTO2BI4.wav 78 | command = ['ffmpeg', 79 | '-y', # (optional) overwrite output file if it exists 80 | '-i', input_video_path, 81 | '-acodec', 'pcm_s16le', '-ac', '1', 82 | '-ar', '16000', # resample 83 | output_audio2_path, 84 | ] 85 | ### <------ 86 | 87 | 88 | ######### for extract audio 2 ########### 89 | # command = ['ffmpeg', 90 | # '-y', # (optional) overwrite output file if it exists 91 | # '-i', input_video_path, 92 | # '-map','0:a', '-f', 'segment', 93 | # '-segment_time', '10000000', # seconds here 94 | # '-acodec', 'pcm_s16le', '-ac', '1', 95 | # '-ar', '16000', # resample 96 | # output_audio_path, 97 | # ] 98 | ##################### 99 
| 100 | # command= [ 101 | # 'mkdir', 102 | # output_video_path, # (optional) overwrite output file if it exists 103 | # ] 104 | 105 | print(command) 106 | 107 | # ffmpeg -y -i /playpen-iop/yblin/v1-2/val/v_6VT2jBflMAM.mp4 -vf fps=3 /playpen-iop/yblin/v1-2/val_low_scale/v_6VT2jBflMAM.mp4 108 | ffmpeg = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 109 | out, err = ffmpeg.communicate() 110 | retcode = ffmpeg.poll() 111 | # print something above for debug 112 | except Exception as e: 113 | raise e 114 | 115 | def prepare_input_output_pairs(input_root, output_root): 116 | input_video_path_list = [] 117 | output_video_path_list = [] 118 | for root, dirs, files in os.walk(input_root): 119 | 120 | for file_name in files: 121 | input_video_path = os.path.join(root, file_name) 122 | 123 | output_video_path = os.path.join(output_root, file_name) 124 | if os.path.exists(output_video_path): 125 | pass 126 | else: 127 | input_video_path_list.append(input_video_path) 128 | output_video_path_list.append(output_video_path) 129 | 130 | return input_video_path_list, output_video_path_list 131 | # ffmpeg -y -i /nas/longleaf/home/yanbo/dataset/msvd_data/MSVD_Videos/-4wsuPCjDBc_5_15.avi -vf "fps=3,scale=224:224" /nas/longleaf/home/yanbo/dataset/msvd_data/MSVD_Comp2/-4wsuPCjDBc_5_15.avi 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description='Compress video for speed-up') 136 | parser.add_argument('--input_root', type=str, help='input root') 137 | parser.add_argument('--output_root', type=str, help='output root') 138 | args = parser.parse_args() 139 | 140 | input_root = args.input_root 141 | output_root = args.output_root 142 | 143 | assert input_root != output_root 144 | 145 | if not os.path.exists(output_root): 146 | os.makedirs(output_root, exist_ok=True) 147 | 148 | input_video_path_list, output_video_path_list = prepare_input_output_pairs(input_root, output_root) 149 | 150 | 151 | print("Total video need to process: {}".format(len(input_video_path_list))) 152 | num_works = cpu_count() 153 | print("Begin with {}-core logical processor.".format(num_works)) 154 | 155 | 156 | # pool = Pool(num_works) 157 | pool = Pool(32) 158 | 159 | data_dict_list = pool.map(compress, 160 | [(input_video_path, output_video_path) for 161 | input_video_path, output_video_path in 162 | zip(input_video_path_list, output_video_path_list)]) 163 | pool.close() 164 | pool.join() 165 | 166 | -------------------------------------------------------------------------------- /torchvggish/vggish.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch import hub 5 | 6 | from . import vggish_input, vggish_params 7 | 8 | 9 | class VGG(nn.Module): 10 | def __init__(self, features): 11 | super(VGG, self).__init__() 12 | self.features = features 13 | self.embeddings = nn.Sequential( 14 | nn.Linear(512 * 4 * 6, 4096), 15 | nn.ReLU(True), 16 | nn.Linear(4096, 4096), 17 | nn.ReLU(True), 18 | nn.Linear(4096, 128), 19 | nn.ReLU(True)) 20 | 21 | def forward(self, x): 22 | x = self.features(x) 23 | 24 | # Transpose the output from features to 25 | # remain compatible with vggish embeddings 26 | x = torch.transpose(x, 1, 3) 27 | x = torch.transpose(x, 1, 2) 28 | x = x.contiguous() 29 | x = x.view(x.size(0), -1) 30 | 31 | return self.embeddings(x) 32 | 33 | 34 | class Postprocessor(nn.Module): 35 | """Post-processes VGGish embeddings. 
Returns a torch.Tensor instead of a 36 | numpy array in order to preserve the gradient. 37 | 38 | "The initial release of AudioSet included 128-D VGGish embeddings for each 39 | segment of AudioSet. These released embeddings were produced by applying 40 | a PCA transformation (technically, a whitening transform is included as well) 41 | and 8-bit quantization to the raw embedding output from VGGish, in order to 42 | stay compatible with the YouTube-8M project which provides visual embeddings 43 | in the same format for a large set of YouTube videos. This class implements 44 | the same PCA (with whitening) and quantization transformations." 45 | """ 46 | 47 | def __init__(self): 48 | """Constructs a postprocessor.""" 49 | super(Postprocessor, self).__init__() 50 | # Create empty matrix, for user's state_dict to load 51 | self.pca_eigen_vectors = torch.empty( 52 | (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), 53 | dtype=torch.float, 54 | ) 55 | self.pca_means = torch.empty( 56 | (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float 57 | ) 58 | 59 | self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False) 60 | self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) 61 | 62 | def postprocess(self, embeddings_batch): 63 | """Applies tensor postprocessing to a batch of embeddings. 64 | 65 | Args: 66 | embeddings_batch: An tensor of shape [batch_size, embedding_size] 67 | containing output from the embedding layer of VGGish. 68 | 69 | Returns: 70 | A tensor of the same shape as the input, containing the PCA-transformed, 71 | quantized, and clipped version of the input. 72 | """ 73 | assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( 74 | embeddings_batch.shape, 75 | ) 76 | assert ( 77 | embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE 78 | ), "Bad batch shape: %r" % (embeddings_batch.shape,) 79 | 80 | # Apply PCA. 81 | # - Embeddings come in as [batch_size, embedding_size]. 82 | # - Transpose to [embedding_size, batch_size]. 83 | # - Subtract pca_means column vector from each column. 84 | # - Premultiply by PCA matrix of shape [output_dims, input_dims] 85 | # where both are are equal to embedding_size in our case. 86 | # - Transpose result back to [batch_size, embedding_size]. 
87 | pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t() 88 | 89 | # Quantize by: 90 | # - clipping to [min, max] range 91 | clipped_embeddings = torch.clamp( 92 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL 93 | ) 94 | # - convert to 8-bit in range [0.0, 255.0] 95 | quantized_embeddings = torch.round( 96 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) 97 | * ( 98 | 255.0 99 | / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) 100 | ) 101 | ) 102 | return torch.squeeze(quantized_embeddings) 103 | 104 | def forward(self, x): 105 | return self.postprocess(x) 106 | 107 | 108 | def make_layers(): 109 | layers = [] 110 | in_channels = 1 111 | for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: 112 | if v == "M": 113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 114 | else: 115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 116 | layers += [conv2d, nn.ReLU(inplace=True)] 117 | in_channels = v 118 | return nn.Sequential(*layers) 119 | 120 | 121 | def _vgg(): 122 | return VGG(make_layers()) 123 | 124 | 125 | # def _spectrogram(): 126 | # config = dict( 127 | # sr=16000, 128 | # n_fft=400, 129 | # n_mels=64, 130 | # hop_length=160, 131 | # window="hann", 132 | # center=False, 133 | # pad_mode="reflect", 134 | # htk=True, 135 | # fmin=125, 136 | # fmax=7500, 137 | # output_format='Magnitude', 138 | # # device=device, 139 | # ) 140 | # return Spectrogram.MelSpectrogram(**config) 141 | 142 | 143 | class VGGish(VGG): 144 | def __init__(self, urls, device=None, pretrained=True, preprocess=True, postprocess=True, progress=True): 145 | super().__init__(make_layers()) 146 | if pretrained: 147 | state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress) 148 | super().load_state_dict(state_dict) 149 | 150 | if device is None: 151 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 152 | self.device = device 153 | self.preprocess = preprocess 154 | self.postprocess = postprocess 155 | if self.postprocess: 156 | self.pproc = Postprocessor() 157 | if pretrained: 158 | state_dict = hub.load_state_dict_from_url(urls['pca'], progress=progress) 159 | # TODO: Convert the state_dict to torch 160 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( 161 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float 162 | ) 163 | state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( 164 | state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float 165 | ) 166 | 167 | self.pproc.load_state_dict(state_dict) 168 | self.to(self.device) 169 | 170 | def forward(self, x, fs=None): 171 | # if self.preprocess: 172 | # x = self._preprocess(x, fs) 173 | x = x.to(self.device) 174 | x = VGG.forward(self, x) 175 | # if self.postprocess: 176 | # x = self._postprocess(x) 177 | return x 178 | 179 | def _preprocess(self, x, fs): 180 | if isinstance(x, np.ndarray): 181 | x = vggish_input.waveform_to_examples(x, fs) 182 | elif isinstance(x, str): 183 | x = vggish_input.wavfile_to_examples(x) 184 | else: 185 | raise AttributeError 186 | return x 187 | 188 | def _postprocess(self, x): 189 | return self.pproc(x) 190 | -------------------------------------------------------------------------------- /modules/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + math.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | grad = p.grad.data 118 | if grad.is_sparse: 119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 120 | 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | state['step'] = 0 126 | # Exponential moving average of gradient values 127 | state['next_m'] = torch.zeros_like(p.data) 128 | # Exponential moving average of squared gradient values 129 | state['next_v'] = torch.zeros_like(p.data) 130 | 131 | next_m, next_v = state['next_m'], state['next_v'] 132 | beta1, beta2 = group['b1'], group['b2'] 133 | 134 | # Add grad clipping 135 | if group['max_grad_norm'] > 0: 136 | clip_grad_norm_(p, group['max_grad_norm']) 137 | 138 | # Decay the first and second moment running average coefficient 139 | # In-place operations to update the averages at the same time 140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. 
This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | else: 161 | lr_scheduled = group['lr'] 162 | 163 | update_with_lr = lr_scheduled * update 164 | p.data.add_(-update_with_lr) 165 | 166 | state['step'] += 1 167 | 168 | return loss -------------------------------------------------------------------------------- /modules/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 
80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 
178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /dataloaders/dataloader_lsmdc_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | import math 11 | from dataloaders.rawvideo_util import RawVideoExtractor 12 | 13 | class LSMDC_DataLoader(Dataset): 14 | """LSMDC dataset loader.""" 15 | def __init__( 16 | self, 17 | subset, 18 | data_path, 19 | features_path, 20 | tokenizer, 21 | max_words=30, 22 | feature_framerate=1.0, 23 | max_frames=100, 24 | image_resolution=224, 25 | frame_order=0, 26 | slice_framepos=0, 27 | ): 28 | self.data_path = data_path 29 | self.features_path = features_path 30 | self.feature_framerate = feature_framerate 31 | self.max_words = max_words 32 | self.max_frames = max_frames 33 | self.tokenizer = tokenizer 34 | # 0: ordinary order; 1: reverse order; 2: random order. 35 | self.frame_order = frame_order 36 | assert self.frame_order in [0, 1, 2] 37 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 
38 | self.slice_framepos = slice_framepos 39 | assert self.slice_framepos in [0, 1, 2] 40 | 41 | self.subset = subset 42 | assert self.subset in ["train", "val", "test"] 43 | 44 | video_json_path_dict = {} 45 | video_json_path_dict["train"] = os.path.join(self.data_path, "LSMDC16_annos_training.csv") 46 | video_json_path_dict["val"] = os.path.join(self.data_path, "LSMDC16_annos_val.csv") 47 | video_json_path_dict["test"] = os.path.join(self.data_path, "LSMDC16_challenge_1000_publictect.csv") 48 | 49 | # \t\t\t\t\t 50 | # is not a unique identifier, i.e. the same can be associated with multiple sentences. 51 | # However, LSMDC16_challenge_1000_publictect.csv has no repeat instances 52 | video_id_list = [] 53 | caption_dict = {} 54 | with open(video_json_path_dict[self.subset], 'r') as fp: 55 | for line in fp: 56 | line = line.strip() 57 | line_split = line.split("\t") 58 | assert len(line_split) == 6 59 | clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split 60 | caption_dict[len(caption_dict)] = (clip_id, sentence) 61 | if clip_id not in video_id_list: video_id_list.append(clip_id) 62 | 63 | video_dict = {} 64 | for root, dub_dir, video_files in os.walk(self.features_path): 65 | for video_file in video_files: 66 | video_id_ = ".".join(video_file.split(".")[:-1]) 67 | if video_id_ not in video_id_list: 68 | continue 69 | file_path_ = os.path.join(root, video_file) 70 | video_dict[video_id_] = file_path_ 71 | 72 | self.video_dict = video_dict 73 | 74 | # Get all captions 75 | self.iter2video_pairs_dict = {} 76 | for clip_id, sentence in caption_dict.values(): 77 | if clip_id not in self.video_dict: 78 | continue 79 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (clip_id, sentence) 80 | 81 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 82 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 83 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 84 | 85 | def __len__(self): 86 | return len(self.iter2video_pairs_dict) 87 | 88 | def _get_video_id_from_pseduo(self, pseudo_video_id): 89 | video_id = pseudo_video_id[2:] 90 | return video_id 91 | 92 | def _get_video_id_single(self, path): 93 | pseudo_video_id_list = [] 94 | video_id_list = [] 95 | print('Loading json: {}'.format(path)) 96 | with open(path, 'r') as f: 97 | json_data = json.load(f) 98 | 99 | for pseudo_video_id in json_data: 100 | if pseudo_video_id in pseudo_video_id_list: 101 | print("reduplicate.") 102 | else: 103 | video_id = self._get_video_id_from_pseduo(pseudo_video_id) 104 | pseudo_video_id_list.append(pseudo_video_id) 105 | video_id_list.append(video_id) 106 | return pseudo_video_id_list, video_id_list 107 | 108 | def _get_captions_single(self, path): 109 | pseudo_caption_dict = {} 110 | with open(path, 'r') as f: 111 | json_data = json.load(f) 112 | 113 | for pseudo_video_id, v_ in json_data.items(): 114 | pseudo_caption_dict[pseudo_video_id] = {} 115 | timestamps = v_["timestamps"] 116 | pseudo_caption_dict[pseudo_video_id]["start"] = \ 117 | np.array([int(math.floor(float(itm[0]))) for itm in timestamps], dtype=object) 118 | pseudo_caption_dict[pseudo_video_id]["end"] = \ 119 | np.array([int(math.ceil(float(itm[1]))) for itm in timestamps], dtype=object) 120 | pseudo_caption_dict[pseudo_video_id]["text"] = np.array(v_["sentences"], dtype=object) 121 | return pseudo_caption_dict 122 | 123 | def _get_text(self, video_id, caption): 124 | k = 1 125 | choice_video_ids = 
[video_id] 126 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 127 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 128 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 129 | 130 | for i, video_id in enumerate(choice_video_ids): 131 | words = self.tokenizer.tokenize(caption) 132 | 133 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 134 | total_length_with_CLS = self.max_words - 1 135 | if len(words) > total_length_with_CLS: 136 | words = words[:total_length_with_CLS] 137 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 138 | 139 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 140 | input_mask = [1] * len(input_ids) 141 | segment_ids = [0] * len(input_ids) 142 | while len(input_ids) < self.max_words: 143 | input_ids.append(0) 144 | input_mask.append(0) 145 | segment_ids.append(0) 146 | assert len(input_ids) == self.max_words 147 | assert len(input_mask) == self.max_words 148 | assert len(segment_ids) == self.max_words 149 | 150 | pairs_text[i] = np.array(input_ids) 151 | pairs_mask[i] = np.array(input_mask) 152 | pairs_segment[i] = np.array(segment_ids) 153 | 154 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids 155 | 156 | def _get_rawvideo(self, choice_video_ids): 157 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 158 | max_video_length = [0] * len(choice_video_ids) 159 | 160 | # Pair x L x T x 3 x H x W 161 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 162 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 163 | 164 | try: 165 | for i, video_id in enumerate(choice_video_ids): 166 | video_path = self.video_dict[video_id] 167 | 168 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 169 | raw_video_data = raw_video_data['video'] 170 | 171 | if len(raw_video_data.shape) > 3: 172 | raw_video_data_clip = raw_video_data 173 | # L x T x 3 x H x W 174 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 175 | if self.max_frames < raw_video_slice.shape[0]: 176 | if self.slice_framepos == 0: 177 | video_slice = raw_video_slice[:self.max_frames, ...] 178 | elif self.slice_framepos == 1: 179 | video_slice = raw_video_slice[-self.max_frames:, ...] 180 | else: 181 | sample_indx = np.linspace(0, raw_video_slice.shape[0]-1, num=self.max_frames, dtype=int) 182 | video_slice = raw_video_slice[sample_indx, ...] 183 | else: 184 | video_slice = raw_video_slice 185 | 186 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 187 | 188 | slice_len = video_slice.shape[0] 189 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 190 | if slice_len < 1: 191 | pass 192 | else: 193 | video[i][:slice_len, ...] = video_slice 194 | else: 195 | print("video path: {} error. 
video id: {}".format(video_path, video_id)) 196 | except Exception as excep: 197 | print("Video ids: {}".format(choice_video_ids)) 198 | raise excep 199 | 200 | for i, v_length in enumerate(max_video_length): 201 | video_mask[i][:v_length] = [1] * v_length 202 | return video, video_mask 203 | 204 | def __getitem__(self, feature_idx): 205 | clip_id, sentence = self.iter2video_pairs_dict[feature_idx] 206 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(clip_id, sentence) 207 | video, video_mask = self._get_rawvideo(choice_video_ids) 208 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 209 | -------------------------------------------------------------------------------- /modules/module_cross.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | import json 8 | import math 9 | import logging 10 | import tarfile 11 | import tempfile 12 | import shutil 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from .file_utils import cached_path 18 | from .until_config import PretrainedConfig 19 | from .until_module import PreTrainedModel, LayerNorm, ACT2FN 20 | from collections import OrderedDict 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 25 | CONFIG_NAME = 'cross_config.json' 26 | WEIGHTS_NAME = 'cross_pytorch_model.bin' 27 | 28 | 29 | class CrossConfig(PretrainedConfig): 30 | """Configuration class to store the configuration of a `CrossModel`. 31 | """ 32 | pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP 33 | config_name = CONFIG_NAME 34 | weights_name = WEIGHTS_NAME 35 | def __init__(self, 36 | vocab_size_or_config_json_file, 37 | hidden_size=768, 38 | num_hidden_layers=12, 39 | num_attention_heads=12, 40 | intermediate_size=3072, 41 | hidden_act="gelu", 42 | hidden_dropout_prob=0.1, 43 | attention_probs_dropout_prob=0.1, 44 | max_position_embeddings=512, 45 | type_vocab_size=2, 46 | initializer_range=0.02): 47 | """Constructs CrossConfig. 48 | 49 | Args: 50 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`. 51 | hidden_size: Size of the encoder layers and the pooler layer. 52 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 53 | num_attention_heads: Number of attention heads for each attention layer in 54 | the Transformer encoder. 55 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 56 | layer in the Transformer encoder. 57 | hidden_act: The non-linear activation function (function or string) in the 58 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 59 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 60 | layers in the embeddings, encoder, and pooler. 61 | attention_probs_dropout_prob: The dropout ratio for the attention 62 | probabilities. 63 | max_position_embeddings: The maximum sequence length that this model might 64 | ever be used with. Typically set this to something large just in case 65 | (e.g., 512 or 1024 or 2048). 66 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 67 | `CrossModel`. 68 | initializer_range: The sttdev of the truncated_normal_initializer for 69 | initializing all weight matrices. 
70 | """ 71 | if isinstance(vocab_size_or_config_json_file, str): 72 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 73 | json_config = json.loads(reader.read()) 74 | for key, value in json_config.items(): 75 | self.__dict__[key] = value 76 | elif isinstance(vocab_size_or_config_json_file, int): 77 | self.vocab_size = vocab_size_or_config_json_file 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.hidden_act = hidden_act 82 | self.intermediate_size = intermediate_size 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.initializer_range = initializer_range 88 | else: 89 | raise ValueError("First argument must be either a vocabulary size (int)" 90 | "or the path to a pretrained model config file (str)") 91 | 92 | class QuickGELU(nn.Module): 93 | def forward(self, x: torch.Tensor): 94 | return x * torch.sigmoid(1.702 * x) 95 | 96 | class ResidualAttentionBlock(nn.Module): 97 | def __init__(self, d_model: int, n_head: int): 98 | super().__init__() 99 | 100 | self.attn = nn.MultiheadAttention(d_model, n_head) 101 | self.ln_1 = LayerNorm(d_model) 102 | self.mlp = nn.Sequential(OrderedDict([ 103 | ("c_fc", nn.Linear(d_model, d_model * 4)), 104 | ("gelu", QuickGELU()), 105 | ("c_proj", nn.Linear(d_model * 4, d_model)) 106 | ])) 107 | self.ln_2 = LayerNorm(d_model) 108 | self.n_head = n_head 109 | 110 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): 111 | attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) 112 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 113 | 114 | def forward(self, para_tuple: tuple): 115 | # x: torch.Tensor, attn_mask: torch.Tensor 116 | # print(para_tuple) 117 | x, attn_mask = para_tuple 118 | x = x + self.attention(self.ln_1(x), attn_mask) 119 | x = x + self.mlp(self.ln_2(x)) 120 | return (x, attn_mask) 121 | 122 | class Transformer(nn.Module): 123 | def __init__(self, width: int, layers: int, heads: int): 124 | super().__init__() 125 | self.width = width 126 | self.layers = layers 127 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) 128 | 129 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 130 | return self.resblocks((x, attn_mask))[0] 131 | 132 | class CrossEmbeddings(nn.Module): 133 | """Construct the embeddings from word, position and token_type embeddings. 
134 | """ 135 | def __init__(self, config): 136 | super(CrossEmbeddings, self).__init__() 137 | 138 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 139 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 140 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 141 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 142 | 143 | def forward(self, concat_embeddings, concat_type=None): 144 | 145 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) 146 | # if concat_type is None: 147 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device) 148 | 149 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device) 150 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1) 151 | 152 | # token_type_embeddings = self.token_type_embeddings(concat_type) 153 | position_embeddings = self.position_embeddings(position_ids) 154 | 155 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings 156 | # embeddings = self.LayerNorm(embeddings) 157 | embeddings = self.dropout(embeddings) 158 | return embeddings 159 | 160 | class CrossPooler(nn.Module): 161 | def __init__(self, config): 162 | super(CrossPooler, self).__init__() 163 | self.ln_pool = LayerNorm(config.hidden_size) 164 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 165 | self.activation = QuickGELU() 166 | 167 | def forward(self, hidden_states, hidden_mask): 168 | # We "pool" the model by simply taking the hidden state corresponding 169 | # to the first token. 170 | hidden_states = self.ln_pool(hidden_states) 171 | pooled_output = hidden_states[:, 0] 172 | pooled_output = self.dense(pooled_output) 173 | pooled_output = self.activation(pooled_output) 174 | return pooled_output 175 | 176 | class CrossModel(PreTrainedModel): 177 | 178 | def initialize_parameters(self): 179 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 180 | attn_std = self.transformer.width ** -0.5 181 | fc_std = (2 * self.transformer.width) ** -0.5 182 | for block in self.transformer.resblocks: 183 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 184 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 185 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 186 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 187 | 188 | def __init__(self, config): 189 | super(CrossModel, self).__init__(config) 190 | 191 | self.embeddings = CrossEmbeddings(config) 192 | 193 | transformer_width = config.hidden_size 194 | transformer_layers = config.num_hidden_layers 195 | transformer_heads = config.num_attention_heads 196 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads,) 197 | self.pooler = CrossPooler(config) 198 | self.apply(self.init_weights) 199 | 200 | def build_attention_mask(self, attention_mask): 201 | extended_attention_mask = attention_mask.unsqueeze(1) 202 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 203 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0 204 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1) 205 | return extended_attention_mask 206 | 207 | def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True): 208 | 209 | if attention_mask is None: 210 | 
attention_mask = torch.ones(concat_input.size(0), concat_input.size(1)) 211 | if concat_type is None: 212 | concat_type = torch.zeros_like(attention_mask) 213 | 214 | extended_attention_mask = self.build_attention_mask(attention_mask) 215 | 216 | embedding_output = self.embeddings(concat_input, concat_type) 217 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND 218 | embedding_output = self.transformer(embedding_output, extended_attention_mask) 219 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD 220 | 221 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask) 222 | 223 | return embedding_output, pooled_output 224 | -------------------------------------------------------------------------------- /dataloaders/mel_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defines routines to compute mel spectrogram features from audio waveform.""" 17 | 18 | import numpy as np 19 | 20 | 21 | def frame(data, window_length, hop_length): 22 | """Convert array into a sequence of successive possibly overlapping frames. 23 | 24 | An n-dimensional array of shape (num_samples, ...) is converted into an 25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame 26 | starts hop_length points after the preceding one. 27 | 28 | This is accomplished using stride_tricks, so the original data is not 29 | copied. However, there is no zero-padding, so any incomplete frames at the 30 | end are not included. 31 | 32 | Args: 33 | data: np.array of dimension N >= 1. 34 | window_length: Number of samples in each frame. 35 | hop_length: Advance (in samples) between each window. 36 | 37 | Returns: 38 | (N+1)-D np.array with as many rows as there are complete frames that can be 39 | extracted. 40 | """ 41 | num_samples = data.shape[0] 42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) 43 | shape = (num_frames, window_length) + data.shape[1:] 44 | strides = (data.strides[0] * hop_length,) + data.strides 45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) 46 | 47 | 48 | def periodic_hann(window_length): 49 | """Calculate a "periodic" Hann window. 50 | 51 | The classic Hann window is defined as a raised cosine that starts and 52 | ends on zero, and where every value appears twice, except the middle 53 | point for an odd-length window. Matlab calls this a "symmetric" window 54 | and np.hanning() returns it. However, for Fourier analysis, this 55 | actually represents just over one cycle of a period N-1 cosine, and 56 | thus is not compactly expressed on a length-N Fourier basis. Instead, 57 | it's better to use a raised cosine that ends just before the final 58 | zero value - i.e. 
a complete cycle of a period-N cosine. Matlab 59 | calls this a "periodic" window. This routine calculates it. 60 | 61 | Args: 62 | window_length: The number of points in the returned window. 63 | 64 | Returns: 65 | A 1D np.array containing the periodic hann window. 66 | """ 67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * 68 | np.arange(window_length))) 69 | 70 | 71 | def stft_magnitude(signal, fft_length, 72 | hop_length=None, 73 | window_length=None): 74 | """Calculate the short-time Fourier transform magnitude. 75 | 76 | Args: 77 | signal: 1D np.array of the input time-domain signal. 78 | fft_length: Size of the FFT to apply. 79 | hop_length: Advance (in samples) between each frame passed to FFT. 80 | window_length: Length of each block of samples to pass to FFT. 81 | 82 | Returns: 83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1 84 | unique values of the FFT for the corresponding frame of input samples. 85 | """ 86 | frames = frame(signal, window_length, hop_length) 87 | # Apply frame window to each frame. We use a periodic Hann (cosine of period 88 | # window_length) instead of the symmetric Hann of np.hanning (period 89 | # window_length-1). 90 | window = periodic_hann(window_length) 91 | windowed_frames = frames * window 92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) 93 | 94 | 95 | # Mel spectrum constants and functions. 96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0 97 | _MEL_HIGH_FREQUENCY_Q = 1127.0 98 | 99 | 100 | def hertz_to_mel(frequencies_hertz): 101 | """Convert frequencies to mel scale using HTK formula. 102 | 103 | Args: 104 | frequencies_hertz: Scalar or np.array of frequencies in hertz. 105 | 106 | Returns: 107 | Object of same size as frequencies_hertz containing corresponding values 108 | on the mel scale. 109 | """ 110 | return _MEL_HIGH_FREQUENCY_Q * np.log( 111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) 112 | 113 | 114 | def spectrogram_to_mel_matrix(num_mel_bins=20, 115 | num_spectrogram_bins=129, 116 | audio_sample_rate=8000, 117 | lower_edge_hertz=125.0, 118 | upper_edge_hertz=3800.0): 119 | """Return a matrix that can post-multiply spectrogram rows to make mel. 120 | 121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of 122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a 123 | "mel spectrogram" M of frames x num_mel_bins. M = S A. 124 | 125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands 126 | to multiply each FFT bin by only one mel weight, then add it, with positive 127 | and negative signs, to the two adjacent mel bands to which that bin 128 | contributes. Here, by expressing this operation as a matrix multiply, we go 129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around 130 | num_fft^2 multiplies and adds. However, because these are all presumably 131 | accomplished in a single call to np.dot(), it's not clear which approach is 132 | faster in Python. The matrix multiplication has the attraction of being more 133 | general and flexible, and much easier to read. 134 | 135 | Args: 136 | num_mel_bins: How many bands in the resulting mel spectrum. This is 137 | the number of columns in the output matrix. 138 | num_spectrogram_bins: How many bins there are in the source spectrogram 139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram 140 | only contains the nonredundant FFT bins. 
141 | audio_sample_rate: Samples per second of the audio at the input to the 142 | spectrogram. We need this to figure out the actual frequencies for 143 | each spectrogram bin, which dictates how they are mapped into mel. 144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel 145 | spectrum. This corresponds to the lower edge of the lowest triangular 146 | band. 147 | upper_edge_hertz: The desired top edge of the highest frequency band. 148 | 149 | Returns: 150 | An np.array with shape (num_spectrogram_bins, num_mel_bins). 151 | 152 | Raises: 153 | ValueError: if frequency edges are incorrectly ordered or out of range. 154 | """ 155 | nyquist_hertz = audio_sample_rate / 2. 156 | if lower_edge_hertz < 0.0: 157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) 158 | if lower_edge_hertz >= upper_edge_hertz: 159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % 160 | (lower_edge_hertz, upper_edge_hertz)) 161 | if upper_edge_hertz > nyquist_hertz: 162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % 163 | (upper_edge_hertz, nyquist_hertz)) 164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) 165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) 166 | # The i'th mel band (starting from i=1) has center frequency 167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge 168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in 169 | # the band_edges_mel arrays. 170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), 171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) 172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins 173 | # of spectrogram values. 174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) 175 | for i in range(num_mel_bins): 176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] 177 | # Calculate lower and upper slopes for every spectrogram bin. 178 | # Line segments are linear in the *mel* domain, not hertz. 179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / 180 | (center_mel - lower_edge_mel)) 181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / 182 | (upper_edge_mel - center_mel)) 183 | # .. then intersect them with each other and zero. 184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, 185 | upper_slope)) 186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero 187 | # coefficient. 188 | mel_weights_matrix[0, :] = 0.0 189 | return mel_weights_matrix 190 | 191 | 192 | def log_mel_spectrogram(data, 193 | audio_sample_rate=8000, 194 | log_offset=0.0, 195 | window_length_secs=0.025, 196 | hop_length_secs=0.010, 197 | **kwargs): 198 | """Convert waveform to a log magnitude mel-frequency spectrogram. 199 | 200 | Args: 201 | data: 1D np.array of waveform data. 202 | audio_sample_rate: The sampling rate of data. 203 | log_offset: Add this to values when taking log to avoid -Infs. 204 | window_length_secs: Duration of each window to analyze. 205 | hop_length_secs: Advance between successive analysis windows. 206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. 207 | 208 | Returns: 209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank 210 | magnitudes for successive frames. 
211 | """ 212 | window_length_samples = int(round(audio_sample_rate * window_length_secs)) 213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) 214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) 215 | spectrogram = stft_magnitude( 216 | data, 217 | fft_length=fft_length, 218 | hop_length=hop_length_samples, 219 | window_length=window_length_samples) 220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( 221 | num_spectrogram_bins=spectrogram.shape[1], 222 | audio_sample_rate=audio_sample_rate, **kwargs)) 223 | return np.log(mel_spectrogram + log_offset) 224 | -------------------------------------------------------------------------------- /torchvggish/mel_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defines routines to compute mel spectrogram features from audio waveform.""" 17 | 18 | import numpy as np 19 | 20 | 21 | def frame(data, window_length, hop_length): 22 | """Convert array into a sequence of successive possibly overlapping frames. 23 | 24 | An n-dimensional array of shape (num_samples, ...) is converted into an 25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame 26 | starts hop_length points after the preceding one. 27 | 28 | This is accomplished using stride_tricks, so the original data is not 29 | copied. However, there is no zero-padding, so any incomplete frames at the 30 | end are not included. 31 | 32 | Args: 33 | data: np.array of dimension N >= 1. 34 | window_length: Number of samples in each frame. 35 | hop_length: Advance (in samples) between each window. 36 | 37 | Returns: 38 | (N+1)-D np.array with as many rows as there are complete frames that can be 39 | extracted. 40 | """ 41 | num_samples = data.shape[0] 42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) 43 | shape = (num_frames, window_length) + data.shape[1:] 44 | strides = (data.strides[0] * hop_length,) + data.strides 45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) 46 | 47 | 48 | def periodic_hann(window_length): 49 | """Calculate a "periodic" Hann window. 50 | 51 | The classic Hann window is defined as a raised cosine that starts and 52 | ends on zero, and where every value appears twice, except the middle 53 | point for an odd-length window. Matlab calls this a "symmetric" window 54 | and np.hanning() returns it. However, for Fourier analysis, this 55 | actually represents just over one cycle of a period N-1 cosine, and 56 | thus is not compactly expressed on a length-N Fourier basis. Instead, 57 | it's better to use a raised cosine that ends just before the final 58 | zero value - i.e. a complete cycle of a period-N cosine. 
Matlab 59 | calls this a "periodic" window. This routine calculates it. 60 | 61 | Args: 62 | window_length: The number of points in the returned window. 63 | 64 | Returns: 65 | A 1D np.array containing the periodic hann window. 66 | """ 67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * 68 | np.arange(window_length))) 69 | 70 | 71 | def stft_magnitude(signal, fft_length, 72 | hop_length=None, 73 | window_length=None): 74 | """Calculate the short-time Fourier transform magnitude. 75 | 76 | Args: 77 | signal: 1D np.array of the input time-domain signal. 78 | fft_length: Size of the FFT to apply. 79 | hop_length: Advance (in samples) between each frame passed to FFT. 80 | window_length: Length of each block of samples to pass to FFT. 81 | 82 | Returns: 83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1 84 | unique values of the FFT for the corresponding frame of input samples. 85 | """ 86 | frames = frame(signal, window_length, hop_length) 87 | # Apply frame window to each frame. We use a periodic Hann (cosine of period 88 | # window_length) instead of the symmetric Hann of np.hanning (period 89 | # window_length-1). 90 | window = periodic_hann(window_length) 91 | windowed_frames = frames * window 92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) 93 | 94 | 95 | # Mel spectrum constants and functions. 96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0 97 | _MEL_HIGH_FREQUENCY_Q = 1127.0 98 | 99 | 100 | def hertz_to_mel(frequencies_hertz): 101 | """Convert frequencies to mel scale using HTK formula. 102 | 103 | Args: 104 | frequencies_hertz: Scalar or np.array of frequencies in hertz. 105 | 106 | Returns: 107 | Object of same size as frequencies_hertz containing corresponding values 108 | on the mel scale. 109 | """ 110 | return _MEL_HIGH_FREQUENCY_Q * np.log( 111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) 112 | 113 | 114 | def spectrogram_to_mel_matrix(num_mel_bins=20, 115 | num_spectrogram_bins=129, 116 | audio_sample_rate=8000, 117 | lower_edge_hertz=125.0, 118 | upper_edge_hertz=3800.0): 119 | """Return a matrix that can post-multiply spectrogram rows to make mel. 120 | 121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of 122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a 123 | "mel spectrogram" M of frames x num_mel_bins. M = S A. 124 | 125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands 126 | to multiply each FFT bin by only one mel weight, then add it, with positive 127 | and negative signs, to the two adjacent mel bands to which that bin 128 | contributes. Here, by expressing this operation as a matrix multiply, we go 129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around 130 | num_fft^2 multiplies and adds. However, because these are all presumably 131 | accomplished in a single call to np.dot(), it's not clear which approach is 132 | faster in Python. The matrix multiplication has the attraction of being more 133 | general and flexible, and much easier to read. 134 | 135 | Args: 136 | num_mel_bins: How many bands in the resulting mel spectrum. This is 137 | the number of columns in the output matrix. 138 | num_spectrogram_bins: How many bins there are in the source spectrogram 139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram 140 | only contains the nonredundant FFT bins. 141 | audio_sample_rate: Samples per second of the audio at the input to the 142 | spectrogram. 
We need this to figure out the actual frequencies for 143 | each spectrogram bin, which dictates how they are mapped into mel. 144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel 145 | spectrum. This corresponds to the lower edge of the lowest triangular 146 | band. 147 | upper_edge_hertz: The desired top edge of the highest frequency band. 148 | 149 | Returns: 150 | An np.array with shape (num_spectrogram_bins, num_mel_bins). 151 | 152 | Raises: 153 | ValueError: if frequency edges are incorrectly ordered or out of range. 154 | """ 155 | nyquist_hertz = audio_sample_rate / 2. 156 | if lower_edge_hertz < 0.0: 157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) 158 | if lower_edge_hertz >= upper_edge_hertz: 159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % 160 | (lower_edge_hertz, upper_edge_hertz)) 161 | if upper_edge_hertz > nyquist_hertz: 162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % 163 | (upper_edge_hertz, nyquist_hertz)) 164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) 165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) 166 | # The i'th mel band (starting from i=1) has center frequency 167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge 168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in 169 | # the band_edges_mel arrays. 170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), 171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) 172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins 173 | # of spectrogram values. 174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) 175 | for i in range(num_mel_bins): 176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] 177 | # Calculate lower and upper slopes for every spectrogram bin. 178 | # Line segments are linear in the *mel* domain, not hertz. 179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / 180 | (center_mel - lower_edge_mel)) 181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / 182 | (upper_edge_mel - center_mel)) 183 | # .. then intersect them with each other and zero. 184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, 185 | upper_slope)) 186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero 187 | # coefficient. 188 | mel_weights_matrix[0, :] = 0.0 189 | return mel_weights_matrix 190 | 191 | 192 | def log_mel_spectrogram(data, 193 | audio_sample_rate=8000, 194 | log_offset=0.0, 195 | window_length_secs=0.025, 196 | hop_length_secs=0.010, 197 | **kwargs): 198 | """Convert waveform to a log magnitude mel-frequency spectrogram. 199 | 200 | Args: 201 | data: 1D np.array of waveform data. 202 | audio_sample_rate: The sampling rate of data. 203 | log_offset: Add this to values when taking log to avoid -Infs. 204 | window_length_secs: Duration of each window to analyze. 205 | hop_length_secs: Advance between successive analysis windows. 206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. 207 | 208 | Returns: 209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank 210 | magnitudes for successive frames. 
211 | """ 212 | window_length_samples = int(round(audio_sample_rate * window_length_secs)) 213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) 214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) 215 | spectrogram = stft_magnitude( 216 | data, 217 | fft_length=fft_length, 218 | hop_length=hop_length_samples, 219 | window_length=window_length_samples) 220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( 221 | num_spectrogram_bins=spectrogram.shape[1], 222 | audio_sample_rate=audio_sample_rate, **kwargs)) 223 | return np.log(mel_spectrogram + log_offset) 224 | -------------------------------------------------------------------------------- /dataloaders/dataloader_msvd_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pickle 10 | from dataloaders.rawvideo_util import RawVideoExtractor 11 | from . import video_container as container 12 | from . import decoder as decoder 13 | from . import utils as utils 14 | from ipdb import set_trace 15 | class MSVD_DataLoader(Dataset): 16 | """MSVD dataset loader.""" 17 | def __init__( 18 | self, 19 | subset, 20 | data_path, 21 | features_path, 22 | tokenizer, 23 | max_words=30, 24 | feature_framerate=1.0, 25 | max_frames=100, 26 | image_resolution=224, 27 | frame_order=0, 28 | slice_framepos=0, 29 | ): 30 | self.data_path = data_path 31 | self.features_path = features_path 32 | self.feature_framerate = feature_framerate 33 | self.max_words = max_words 34 | self.max_frames = max_frames 35 | self.tokenizer = tokenizer 36 | # 0: ordinary order; 1: reverse order; 2: random order. 37 | self.frame_order = frame_order 38 | assert self.frame_order in [0, 1, 2] 39 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 
40 | self.slice_framepos = slice_framepos 41 | assert self.slice_framepos in [0, 1, 2] 42 | 43 | self.subset = subset 44 | assert self.subset in ["train", "val", "test"] 45 | video_id_path_dict = {} 46 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") 47 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") 48 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") 49 | caption_file = os.path.join(self.data_path, "raw-captions.pkl") 50 | 51 | with open(video_id_path_dict[self.subset], 'r') as fp: 52 | video_ids = [itm.strip() for itm in fp.readlines()] 53 | 54 | with open(caption_file, 'rb') as f: 55 | captions = pickle.load(f) 56 | 57 | video_dict = {} 58 | for root, dub_dir, video_files in os.walk(self.features_path): 59 | for video_file in video_files: 60 | video_id_ = ".".join(video_file.split(".")[:-1]) 61 | if video_id_ not in video_ids: 62 | continue 63 | file_path_ = os.path.join(root, video_file) 64 | video_dict[video_id_] = file_path_ 65 | self.video_dict = video_dict 66 | 67 | self.sample_len = 0 68 | self.sentences_dict = {} 69 | self.cut_off_points = [] 70 | for video_id in video_ids: 71 | assert video_id in captions 72 | for cap in captions[video_id]: 73 | cap_txt = " ".join(cap) 74 | self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) 75 | self.cut_off_points.append(len(self.sentences_dict)) 76 | 77 | ## below variables are used to multi-sentences retrieval 78 | # self.cut_off_points: used to tag the label when calculate the metric 79 | # self.sentence_num: used to cut the sentence representation 80 | # self.video_num: used to cut the video representation 81 | self.multi_sentence_per_video = True # !!! important tag for eval 82 | if self.subset == "val" or self.subset == "test": 83 | self.sentence_num = len(self.sentences_dict) 84 | self.video_num = len(video_ids) 85 | assert len(self.cut_off_points) == self.video_num 86 | print("For {}, sentence number: {}".format(self.subset, self.sentence_num)) 87 | print("For {}, video number: {}".format(self.subset, self.video_num)) 88 | 89 | print("Video number: {}".format(len(self.video_dict))) 90 | print("Total Paire: {}".format(len(self.sentences_dict))) 91 | 92 | self.sample_len = len(self.sentences_dict) 93 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 94 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 95 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 96 | 97 | def __len__(self): 98 | return self.sample_len 99 | 100 | def _get_text(self, video_id, caption): 101 | k = 1 102 | choice_video_ids = [video_id] 103 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 104 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 105 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 106 | 107 | for i, video_id in enumerate(choice_video_ids): 108 | words = self.tokenizer.tokenize(caption) 109 | 110 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 111 | total_length_with_CLS = self.max_words - 1 112 | if len(words) > total_length_with_CLS: 113 | words = words[:total_length_with_CLS] 114 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 115 | 116 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 117 | input_mask = [1] * len(input_ids) 118 | segment_ids = [0] * len(input_ids) 119 | while len(input_ids) < self.max_words: 120 | input_ids.append(0) 121 | input_mask.append(0) 122 | segment_ids.append(0) 123 | assert 
len(input_ids) == self.max_words 124 | assert len(input_mask) == self.max_words 125 | assert len(segment_ids) == self.max_words 126 | 127 | pairs_text[i] = np.array(input_ids) 128 | pairs_mask[i] = np.array(input_mask) 129 | pairs_segment[i] = np.array(segment_ids) 130 | 131 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids 132 | 133 | def _get_rawvideo(self, choice_video_ids): 134 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 135 | max_video_length = [0] * len(choice_video_ids) 136 | 137 | # Pair x L x T x 3 x H x W 138 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 139 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 140 | 141 | for i, video_id in enumerate(choice_video_ids): 142 | video_path = self.video_dict[video_id] 143 | 144 | set_trace() 145 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 146 | raw_video_data = raw_video_data['video'] 147 | 148 | if len(raw_video_data.shape) > 3: 149 | raw_video_data_clip = raw_video_data 150 | # L x T x 3 x H x W 151 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 152 | if self.max_frames < raw_video_slice.shape[0]: 153 | if self.slice_framepos == 0: 154 | video_slice = raw_video_slice[:self.max_frames, ...] 155 | elif self.slice_framepos == 1: 156 | video_slice = raw_video_slice[-self.max_frames:, ...] 157 | else: 158 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 159 | video_slice = raw_video_slice[sample_indx, ...] 160 | else: 161 | video_slice = raw_video_slice 162 | 163 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 164 | 165 | slice_len = video_slice.shape[0] 166 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 167 | if slice_len < 1: 168 | pass 169 | else: 170 | video[i][:slice_len, ...] = video_slice 171 | else: 172 | print("video path: {} error. 
video id: {}".format(video_path, video_id)) 173 | 174 | for i, v_length in enumerate(max_video_length): 175 | video_mask[i][:v_length] = [1] * v_length 176 | 177 | return video, video_mask 178 | 179 | def _get_rawvideo_yb(self, choice_video_ids): 180 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 181 | max_video_length = [0] * len(choice_video_ids) 182 | 183 | # Pair x L x T x 3 x H x W 184 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 185 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 186 | 187 | for i, video_id in enumerate(choice_video_ids): 188 | video_path = self.video_dict[video_id] 189 | 190 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 191 | raw_video_data = raw_video_data['video'] 192 | 193 | 194 | 195 | video_container = container.get_video_container( 196 | path_to_vid=video_path, 197 | multi_thread_decode=False, 198 | backend="pyav", 199 | ) 200 | 201 | 202 | frames, vid_len = decoder.decode( 203 | video_container, 204 | sampling_rate=3, 205 | num_frames=12, 206 | clip_idx=1, 207 | num_clips=1, 208 | video_meta=None, 209 | target_fps=1, 210 | backend="pyav", 211 | max_spatial_scale=0, 212 | use_offset=False, 213 | ) 214 | 215 | 216 | frames = utils.tensor_normalize( 217 | frames, [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711] 218 | ) 219 | 220 | # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 221 | # [0.485, 0.456, 0.406] 222 | # [0.229, 0.224, 0.225] 223 | # [0.45, 0.45, 0.45] 224 | # [0.225, 0.225, 0.225] 225 | 226 | 227 | # if len(raw_video_data.shape) > 3: 228 | # raw_video_data_clip = raw_video_data 229 | # # L x T x 3 x H x W 230 | # raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 231 | # if self.max_frames < raw_video_slice.shape[0]: 232 | # if self.slice_framepos == 0: 233 | # video_slice = raw_video_slice[:self.max_frames, ...] 234 | # elif self.slice_framepos == 1: 235 | # video_slice = raw_video_slice[-self.max_frames:, ...] 236 | # else: 237 | # sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 238 | # video_slice = raw_video_slice[sample_indx, ...] 239 | # else: 240 | # video_slice = raw_video_slice 241 | 242 | # video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 243 | 244 | # slice_len = video_slice.shape[0] 245 | # max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 246 | # if slice_len < 1: 247 | # pass 248 | # else: 249 | # video[i][:slice_len, ...] = video_slice 250 | # else: 251 | # print("video path: {} error. 
video id: {}".format(video_path, video_id)) 252 | 253 | for i, v_length in enumerate([int(vid_len)]): 254 | if vid_len > self.max_frames: 255 | video_mask[:] = 1 256 | else: 257 | video_mask[i][:v_length] = [1] * v_length 258 | 259 | 260 | if vid_len < self.max_frames: 261 | frames[int(vid_len):] = 0 262 | 263 | frames = frames.permute(0,3,1,2).unsqueeze(0).unsqueeze(2) 264 | 265 | return frames, video_mask 266 | 267 | def __getitem__(self, idx): 268 | video_id, caption = self.sentences_dict[idx] 269 | 270 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) 271 | video, video_mask = self._get_rawvideo(choice_video_ids) 272 | # video, video_mask = self._get_rawvideo_yb(choice_video_ids) 273 | 274 | 275 | 276 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 277 | -------------------------------------------------------------------------------- /modules/ast_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/10/21 5:04 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : ast_models.py 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.cuda.amp import autocast 11 | import os 12 | import wget 13 | import timm 14 | from timm.models.layers import to_2tuple,trunc_normal_ 15 | from ipdb import set_trace 16 | 17 | # override the timm package to relax the input shape constraint. 18 | class PatchEmbed(nn.Module): 19 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 20 | super().__init__() 21 | 22 | img_size = to_2tuple(img_size) 23 | patch_size = to_2tuple(patch_size) 24 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 25 | self.img_size = img_size 26 | self.patch_size = patch_size 27 | self.num_patches = num_patches 28 | 29 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 30 | # Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10)) for audioset 31 | def forward(self, x): 32 | # x = 1x1x128x1024 33 | x = self.proj(x).flatten(2).transpose(1, 2) # 768x12x101 34 | return x 35 | 36 | class ASTModel(nn.Module): 37 | """ 38 | The AST model. 39 | :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35 40 | :param fstride: the stride of patch spliting on the frequency dimension, for 16*16 patchs, fstride=16 means no overlap, fstride=10 means overlap of 6 41 | :param tstride: the stride of patch spliting on the time dimension, for 16*16 patchs, tstride=16 means no overlap, tstride=10 means overlap of 6 42 | :param input_fdim: the number of frequency bins of the input spectrogram 43 | :param input_tdim: the number of time frames of the input spectrogram 44 | :param imagenet_pretrain: if use ImageNet pretrained model 45 | :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model 46 | :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base 384 are same model, but are trained differently during ImageNet pretraining. 
47 | """ 48 | def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True): 49 | 50 | super(ASTModel, self).__init__() 51 | assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.' 52 | 53 | if verbose == True: 54 | print('---------------AST Model Summary---------------') 55 | print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain))) 56 | # override timm input shape restriction 57 | timm.models.vision_transformer.PatchEmbed = PatchEmbed 58 | 59 | # if AudioSet pretraining is not used (but ImageNet pretraining may still apply) 60 | if audioset_pretrain == False: 61 | if model_size == 'tiny224': 62 | self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain) 63 | elif model_size == 'small224': 64 | self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain) 65 | elif model_size == 'base224': 66 | self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain) 67 | elif model_size == 'base384': 68 | self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain) 69 | else: 70 | raise Exception('Model size must be one of tiny224, small224, base224, base384.') 71 | self.original_num_patches = self.v.patch_embed.num_patches 72 | self.oringal_hw = int(self.original_num_patches ** 0.5) 73 | self.original_embedding_dim = self.v.pos_embed.shape[2] 74 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim)) 75 | 76 | # automatcially get the intermediate shape 77 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim) 78 | num_patches = f_dim * t_dim 79 | self.v.patch_embed.num_patches = num_patches 80 | if verbose == True: 81 | print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride)) 82 | print('number of patches={:d}'.format(num_patches)) 83 | 84 | # the linear projection layer 85 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride)) 86 | if imagenet_pretrain == True: 87 | new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1)) 88 | new_proj.bias = self.v.patch_embed.proj.bias 89 | self.v.patch_embed.proj = new_proj 90 | 91 | # the positional embedding 92 | if imagenet_pretrain == True: 93 | # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24). 
94 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw) 95 | # cut (from middle) or interpolate the second dimension of the positional embedding 96 | if t_dim <= self.oringal_hw: 97 | new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim] 98 | else: 99 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear') 100 | # cut (from middle) or interpolate the first dimension of the positional embedding 101 | if f_dim <= self.oringal_hw: 102 | new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :] 103 | else: 104 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear') 105 | # flatten the positional embedding 106 | new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2) 107 | # concatenate the above positional embedding with the cls token and distillation token of the deit model. 108 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1)) 109 | else: 110 | # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding 111 | # TODO can use sinusoidal positional embedding instead 112 | new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim)) 113 | self.v.pos_embed = new_pos_embed 114 | trunc_normal_(self.v.pos_embed, std=.02) 115 | 116 | # now load a model that is pretrained on both ImageNet and AudioSet 117 | elif audioset_pretrain == True: 118 | if audioset_pretrain == True and imagenet_pretrain == False: 119 | raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.') 120 | if model_size != 'base384': 121 | raise ValueError('currently only has base384 AudioSet pretrained model.') 122 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 123 | if os.path.exists('./pretrained_models/audioset_10_10_0.4593.pth') == False: 124 | # this model performs 0.4593 mAP on the audioset eval set 125 | audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1' 126 | wget.download(audioset_mdl_url, out='./pretrained_models/audioset_10_10_0.4593.pth') 127 | sd = torch.load('./pretrained_models/audioset_10_10_0.4593.pth', map_location=device) 128 | 129 | # for 10 s audioset 130 | audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False) 131 | audio_model = torch.nn.DataParallel(audio_model) 132 | audio_model.load_state_dict(sd, strict=False) 133 | self.v = audio_model.module.v 134 | self.original_embedding_dim = self.v.pos_embed.shape[2] 135 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim)) 136 | 137 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim) 138 | num_patches = f_dim * t_dim 139 | self.v.patch_embed.num_patches = num_patches 140 | if verbose == True: 141 | print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride)) 142 | print('number of 
patches={:d}'.format(num_patches)) 143 | 144 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101) 145 | # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding 146 | if t_dim < 101: 147 | # new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim] ori 148 | 149 | ## yb add: f_dim:8 t_dim:64 150 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear') 151 | self.v.patch_embed.proj.stride = (fstride,tstride) 152 | ## yb end ### 153 | # otherwise interpolate 154 | else: 155 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear') 156 | 157 | 158 | new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2) 159 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1)) 160 | 161 | def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024): 162 | test_input = torch.randn(1, 1, input_fdim, input_tdim) 163 | test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride)) 164 | test_out = test_proj(test_input) 165 | f_dim = test_out.shape[2] 166 | t_dim = test_out.shape[3] 167 | return f_dim, t_dim 168 | 169 | @autocast() 170 | def forward(self, x): 171 | """ 172 | :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 173 | :return: prediction 174 | """ 175 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 176 | x = x.unsqueeze(1) 177 | x = x.transpose(2, 3) 178 | 179 | B = x.shape[0] 180 | x = self.v.patch_embed(x) 181 | cls_tokens = self.v.cls_token.expand(B, -1, -1) 182 | dist_token = self.v.dist_token.expand(B, -1, -1) 183 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 184 | 185 | x = x + self.v.pos_embed 186 | 187 | x = self.v.pos_drop(x) 188 | for blk in self.v.blocks: 189 | x = blk(x) 190 | x = self.v.norm(x) 191 | 192 | 193 | ## two cls here 194 | # x = (x[:, 0] + x[:, 1]) / 2 195 | 196 | # x = self.mlp_head(x) 197 | return x 198 | 199 | if __name__ == '__main__': 200 | input_tdim = 100 201 | ast_mdl = ASTModel(input_tdim=input_tdim) 202 | # input a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins 203 | test_input = torch.rand([10, input_tdim, 128]) 204 | test_output = ast_mdl(test_input) 205 | # output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes. 206 | print(test_output.shape) 207 | 208 | input_tdim = 256 209 | ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True) 210 | # input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins 211 | test_input = torch.rand([10, input_tdim, 128]) 212 | test_output = ast_mdl(test_input) 213 | # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes. 
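    # (Editorial note: with the cls/dist pooling and mlp_head lines commented out in
    # forward(), the model returns the full token sequence, so the shape printed below
    # is (10, num_patches + 2, 768) rather than (10, 50); the audioset_pretrain=True
    # path also requires timm==0.4.5 and downloads the checkpoint into ./pretrained_models/.)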
214 | print(test_output.shape) -------------------------------------------------------------------------------- /dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import numpy as np 5 | import os 6 | import random 7 | import time 8 | from collections import defaultdict 9 | import cv2 10 | import torch 11 | from torch.utils.data.distributed import DistributedSampler 12 | 13 | # from slowfast.utils.env import pathmgr 14 | 15 | from . import transform as transform 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def retry_load_images(image_paths, retry=10, backend="pytorch"): 21 | """ 22 | This function is to load images with support of retrying for failed load. 23 | Args: 24 | image_paths (list): paths of images needed to be loaded. 25 | retry (int, optional): maximum time of loading retrying. Defaults to 10. 26 | backend (str): `pytorch` or `cv2`. 27 | Returns: 28 | imgs (list): list of loaded images. 29 | """ 30 | for i in range(retry): 31 | imgs = [] 32 | for image_path in image_paths: 33 | with pathmgr.open(image_path, "rb") as f: 34 | img_str = np.frombuffer(f.read(), np.uint8) 35 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR) 36 | imgs.append(img) 37 | 38 | if all(img is not None for img in imgs): 39 | if backend == "pytorch": 40 | imgs = torch.as_tensor(np.stack(imgs)) 41 | return imgs 42 | else: 43 | logger.warn("Reading failed. Will retry.") 44 | time.sleep(1.0) 45 | if i == retry - 1: 46 | raise Exception("Failed to load images {}".format(image_paths)) 47 | 48 | 49 | def get_sequence(center_idx, half_len, sample_rate, num_frames): 50 | """ 51 | Sample frames among the corresponding clip. 52 | Args: 53 | center_idx (int): center frame idx for current clip 54 | half_len (int): half of the clip length 55 | sample_rate (int): sampling rate for sampling frames inside of the clip 56 | num_frames (int): number of expected sampled frames 57 | Returns: 58 | seq (list): list of indexes of sampled frames in this clip. 59 | """ 60 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate)) 61 | 62 | for seq_idx in range(len(seq)): 63 | if seq[seq_idx] < 0: 64 | seq[seq_idx] = 0 65 | elif seq[seq_idx] >= num_frames: 66 | seq[seq_idx] = num_frames - 1 67 | return seq 68 | 69 | 70 | def pack_pathway_output(cfg, frames): 71 | """ 72 | Prepare output as a list of tensors. Each tensor corresponding to a 73 | unique pathway. 74 | Args: 75 | frames (tensor): frames of images sampled from the video. The 76 | dimension is `channel` x `num frames` x `height` x `width`. 77 | Returns: 78 | frame_list (list): list of tensors with the dimension of 79 | `channel` x `num frames` x `height` x `width`. 80 | """ 81 | if cfg.DATA.REVERSE_INPUT_CHANNEL: 82 | frames = frames[[2, 1, 0], :, :, :] 83 | if cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH: 84 | frame_list = [frames] 85 | elif cfg.MODEL.ARCH in cfg.MODEL.MULTI_PATHWAY_ARCH: 86 | fast_pathway = frames 87 | # Perform temporal sampling from the fast pathway. 
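        # Worked example (editorial, assuming cfg.SLOWFAST.ALPHA = 4): for a fast
        # pathway holding 32 frames, torch.linspace(0, 31, 32 // 4).long() yields 8
        # roughly evenly spaced frame indices, so the slow pathway keeps 1/ALPHA of
        # the frames while the fast pathway keeps all of them.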
88 | slow_pathway = torch.index_select( 89 | frames, 90 | 1, 91 | torch.linspace( 92 | 0, frames.shape[1] - 1, frames.shape[1] // cfg.SLOWFAST.ALPHA 93 | ).long(), 94 | ) 95 | frame_list = [slow_pathway, fast_pathway] 96 | else: 97 | raise NotImplementedError( 98 | "Model arch {} is not in {}".format( 99 | cfg.MODEL.ARCH, 100 | cfg.MODEL.SINGLE_PATHWAY_ARCH + cfg.MODEL.MULTI_PATHWAY_ARCH, 101 | ) 102 | ) 103 | return frame_list 104 | 105 | 106 | def spatial_sampling( 107 | frames, 108 | spatial_idx=-1, 109 | min_scale=256, 110 | max_scale=320, 111 | crop_size=224, 112 | random_horizontal_flip=True, 113 | inverse_uniform_sampling=False, 114 | aspect_ratio=None, 115 | scale=None, 116 | motion_shift=False, 117 | ): 118 | """ 119 | Perform spatial sampling on the given video frames. If spatial_idx is 120 | -1, perform random scale, random crop, and random flip on the given 121 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 122 | with the given spatial_idx. 123 | Args: 124 | frames (tensor): frames of images sampled from the video. The 125 | dimension is `num frames` x `height` x `width` x `channel`. 126 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 127 | or 2, perform left, center, right crop if width is larger than 128 | height, and perform top, center, buttom crop if height is larger 129 | than width. 130 | min_scale (int): the minimal size of scaling. 131 | max_scale (int): the maximal size of scaling. 132 | crop_size (int): the size of height and width used to crop the 133 | frames. 134 | inverse_uniform_sampling (bool): if True, sample uniformly in 135 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 136 | scale. If False, take a uniform sample from [min_scale, 137 | max_scale]. 138 | aspect_ratio (list): Aspect ratio range for resizing. 139 | scale (list): Scale range for resizing. 140 | motion_shift (bool): Whether to apply motion shift for resizing. 141 | Returns: 142 | frames (tensor): spatially sampled frames. 143 | """ 144 | assert spatial_idx in [-1, 0, 1, 2] 145 | if spatial_idx == -1: 146 | if aspect_ratio is None and scale is None: 147 | frames, _ = transform.random_short_side_scale_jitter( 148 | images=frames, 149 | min_size=min_scale, 150 | max_size=max_scale, 151 | inverse_uniform_sampling=inverse_uniform_sampling, 152 | ) 153 | frames, _ = transform.random_crop(frames, crop_size) 154 | else: 155 | transform_func = ( 156 | transform.random_resized_crop_with_shift 157 | if motion_shift 158 | else transform.random_resized_crop 159 | ) 160 | frames = transform_func( 161 | images=frames, 162 | target_height=crop_size, 163 | target_width=crop_size, 164 | scale=scale, 165 | ratio=aspect_ratio, 166 | ) 167 | if random_horizontal_flip: 168 | frames, _ = transform.horizontal_flip(0.5, frames) 169 | else: 170 | # The testing is deterministic and no jitter should be performed. 171 | # min_scale, max_scale, and crop_size are expect to be the same. 172 | assert len({min_scale, max_scale}) == 1 173 | frames, _ = transform.random_short_side_scale_jitter( 174 | frames, min_scale, max_scale 175 | ) 176 | frames, _ = transform.uniform_crop(frames, crop_size, spatial_idx) 177 | return frames 178 | 179 | 180 | def as_binary_vector(labels, num_classes): 181 | """ 182 | Construct binary label vector given a list of label indices. 183 | Args: 184 | labels (list): The input label list. 185 | num_classes (int): Number of classes of the label vector. 186 | Returns: 187 | labels (numpy array): the resulting binary vector. 
188 | """ 189 | label_arr = np.zeros((num_classes,)) 190 | 191 | for lbl in set(labels): 192 | label_arr[lbl] = 1.0 193 | return label_arr 194 | 195 | 196 | def aggregate_labels(label_list): 197 | """ 198 | Join a list of label list. 199 | Args: 200 | labels (list): The input label list. 201 | Returns: 202 | labels (list): The joint list of all lists in input. 203 | """ 204 | all_labels = [] 205 | for labels in label_list: 206 | for l in labels: 207 | all_labels.append(l) 208 | return list(set(all_labels)) 209 | 210 | 211 | def convert_to_video_level_labels(labels): 212 | """ 213 | Aggregate annotations from all frames of a video to form video-level labels. 214 | Args: 215 | labels (list): The input label list. 216 | Returns: 217 | labels (list): Same as input, but with each label replaced by 218 | a video-level one. 219 | """ 220 | for video_id in range(len(labels)): 221 | video_level_labels = aggregate_labels(labels[video_id]) 222 | for i in range(len(labels[video_id])): 223 | labels[video_id][i] = video_level_labels 224 | return labels 225 | 226 | 227 | def load_image_lists(frame_list_file, prefix="", return_list=False): 228 | """ 229 | Load image paths and labels from a "frame list". 230 | Each line of the frame list contains: 231 | `original_vido_id video_id frame_id path labels` 232 | Args: 233 | frame_list_file (string): path to the frame list. 234 | prefix (str): the prefix for the path. 235 | return_list (bool): if True, return a list. If False, return a dict. 236 | Returns: 237 | image_paths (list or dict): list of list containing path to each frame. 238 | If return_list is False, then return in a dict form. 239 | labels (list or dict): list of list containing label of each frame. 240 | If return_list is False, then return in a dict form. 241 | """ 242 | image_paths = defaultdict(list) 243 | labels = defaultdict(list) 244 | with pathmgr.open(frame_list_file, "r") as f: 245 | assert f.readline().startswith("original_vido_id") 246 | for line in f: 247 | row = line.split() 248 | # original_vido_id video_id frame_id path labels 249 | assert len(row) == 5 250 | video_name = row[0] 251 | if prefix == "": 252 | path = row[3] 253 | else: 254 | path = os.path.join(prefix, row[3]) 255 | image_paths[video_name].append(path) 256 | frame_labels = row[-1].replace('"', "") 257 | if frame_labels != "": 258 | labels[video_name].append( 259 | [int(x) for x in frame_labels.split(",")] 260 | ) 261 | else: 262 | labels[video_name].append([]) 263 | 264 | if return_list: 265 | keys = image_paths.keys() 266 | image_paths = [image_paths[key] for key in keys] 267 | labels = [labels[key] for key in keys] 268 | return image_paths, labels 269 | return dict(image_paths), dict(labels) 270 | 271 | 272 | def tensor_normalize(tensor, mean, std): 273 | """ 274 | Normalize a given tensor by subtracting the mean and dividing the std. 275 | Args: 276 | tensor (tensor): tensor to normalize. 277 | mean (tensor or list): mean value to subtract. 278 | std (tensor or list): std to divide. 
279 | """ 280 | if tensor.dtype == torch.uint8: 281 | tensor = tensor.float() 282 | tensor = tensor / 255.0 283 | if type(mean) == list: 284 | mean = torch.tensor(mean) 285 | if type(std) == list: 286 | std = torch.tensor(std) 287 | tensor = tensor - mean 288 | tensor = tensor / std 289 | return tensor 290 | 291 | 292 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate): 293 | """ 294 | When multigrid training uses a fewer number of frames, we randomly 295 | increase the sampling rate so that some clips cover the original span. 296 | """ 297 | if long_cycle_sampling_rate > 0: 298 | assert long_cycle_sampling_rate >= sampling_rate 299 | return random.randint(sampling_rate, long_cycle_sampling_rate) 300 | else: 301 | return sampling_rate 302 | 303 | 304 | def revert_tensor_normalize(tensor, mean, std): 305 | """ 306 | Revert normalization for a given tensor by multiplying by the std and adding the mean. 307 | Args: 308 | tensor (tensor): tensor to revert normalization. 309 | mean (tensor or list): mean value to add. 310 | std (tensor or list): std to multiply. 311 | """ 312 | if type(mean) == list: 313 | mean = torch.tensor(mean) 314 | if type(std) == list: 315 | std = torch.tensor(std) 316 | tensor = tensor * std 317 | tensor = tensor + mean 318 | return tensor 319 | 320 | 321 | def create_sampler(dataset, shuffle, cfg): 322 | """ 323 | Create sampler for the given dataset. 324 | Args: 325 | dataset (torch.utils.data.Dataset): the given dataset. 326 | shuffle (bool): set to ``True`` to have the data reshuffled 327 | at every epoch. 328 | cfg (CfgNode): configs. Details can be found in 329 | slowfast/config/defaults.py 330 | Returns: 331 | sampler (Sampler): the created sampler. 332 | """ 333 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 334 | 335 | return sampler 336 | 337 | 338 | def loader_worker_init_fn(dataset): 339 | """ 340 | Create init function passed to pytorch data loader. 341 | Args: 342 | dataset (torch.utils.data.Dataset): the given dataset. 343 | """ 344 | return None -------------------------------------------------------------------------------- /modules/until_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | import logging 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | import math 24 | from modules.until_config import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def gelu(x): 29 | """Implementation of the gelu activation function. 
30 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 31 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 32 | """ 33 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 34 | 35 | def swish(x): 36 | return x * torch.sigmoid(x) 37 | 38 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 39 | 40 | class LayerNorm(nn.Module): 41 | def __init__(self, hidden_size, eps=1e-12): 42 | """Construct a layernorm module in the TF style (epsilon inside the square root). 43 | """ 44 | super(LayerNorm, self).__init__() 45 | self.weight = nn.Parameter(torch.ones(hidden_size)) 46 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 47 | self.variance_epsilon = eps 48 | 49 | def forward(self, x): 50 | u = x.mean(-1, keepdim=True) 51 | s = (x - u).pow(2).mean(-1, keepdim=True) 52 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 53 | return self.weight * x + self.bias 54 | 55 | class PreTrainedModel(nn.Module): 56 | """ An abstract class to handle weights initialization and 57 | a simple interface for dowloading and loading pretrained models. 58 | """ 59 | def __init__(self, config, *inputs, **kwargs): 60 | super(PreTrainedModel, self).__init__() 61 | if not isinstance(config, PretrainedConfig): 62 | raise ValueError( 63 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " 64 | "To create a model from a Google pretrained model use " 65 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 66 | self.__class__.__name__, self.__class__.__name__ 67 | )) 68 | self.config = config 69 | 70 | def init_weights(self, module): 71 | """ Initialize the weights. 72 | """ 73 | if isinstance(module, (nn.Linear, nn.Embedding)): 74 | # Slightly different from the TF version which uses truncated_normal for initialization 75 | # cf https://github.com/pytorch/pytorch/pull/5617 76 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 77 | elif isinstance(module, LayerNorm): 78 | if 'beta' in dir(module) and 'gamma' in dir(module): 79 | module.beta.data.zero_() 80 | module.gamma.data.fill_(1.0) 81 | else: 82 | module.bias.data.zero_() 83 | module.weight.data.fill_(1.0) 84 | if isinstance(module, nn.Linear) and module.bias is not None: 85 | module.bias.data.zero_() 86 | 87 | def resize_token_embeddings(self, new_num_tokens=None): 88 | raise NotImplementedError 89 | 90 | @classmethod 91 | def init_preweight(cls, model, state_dict, prefix=None, task_config=None): 92 | old_keys = [] 93 | new_keys = [] 94 | for key in state_dict.keys(): 95 | new_key = None 96 | if 'gamma' in key: 97 | new_key = key.replace('gamma', 'weight') 98 | if 'beta' in key: 99 | new_key = key.replace('beta', 'bias') 100 | if new_key: 101 | old_keys.append(key) 102 | new_keys.append(new_key) 103 | for old_key, new_key in zip(old_keys, new_keys): 104 | state_dict[new_key] = state_dict.pop(old_key) 105 | 106 | if prefix is not None: 107 | old_keys = [] 108 | new_keys = [] 109 | for key in state_dict.keys(): 110 | old_keys.append(key) 111 | new_keys.append(prefix + key) 112 | for old_key, new_key in zip(old_keys, new_keys): 113 | state_dict[new_key] = state_dict.pop(old_key) 114 | 115 | missing_keys = [] 116 | unexpected_keys = [] 117 | error_msgs = [] 118 | # copy state_dict so _load_from_state_dict can modify it 119 | metadata = getattr(state_dict, '_metadata', None) 120 | state_dict = state_dict.copy() 121 | if metadata is not None: 122 | state_dict._metadata = metadata 123 | 124 | 
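        # (Editorial note: the nested load() below walks the module tree recursively,
        # calling _load_from_state_dict on each submodule with the matching key prefix
        # and accumulating missing keys, unexpected keys, and error messages.)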
def load(module, prefix=''): 125 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 126 | module._load_from_state_dict( 127 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 128 | for name, child in module._modules.items(): 129 | if child is not None: 130 | load(child, prefix + name + '.') 131 | 132 | load(model, prefix='') 133 | 134 | if prefix is None and (task_config is None or task_config.local_rank == 0): 135 | logger.info("-" * 20) 136 | if len(missing_keys) > 0: 137 | logger.info("Weights of {} not initialized from pretrained model: {}" 138 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys))) 139 | if len(unexpected_keys) > 0: 140 | logger.info("Weights from pretrained model not used in {}: {}" 141 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys))) 142 | if len(error_msgs) > 0: 143 | logger.error("Weights from pretrained model cause errors in {}: {}" 144 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs))) 145 | 146 | return model 147 | 148 | @property 149 | def dtype(self): 150 | """ 151 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 152 | """ 153 | try: 154 | return next(self.parameters()).dtype 155 | except StopIteration: 156 | # For nn.DataParallel compatibility in PyTorch 1.5 157 | def find_tensor_attributes(module: nn.Module): 158 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 159 | return tuples 160 | 161 | gen = self._named_members(get_members_fn=find_tensor_attributes) 162 | first_tuple = next(gen) 163 | return first_tuple[1].dtype 164 | 165 | @classmethod 166 | def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs): 167 | """ 168 | Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict. 169 | Download and cache the pre-trained model file if needed. 170 | """ 171 | # Instantiate model. 
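        # (Editorial note: when no state_dict is given, the bare cls(config) instance is
        # returned as-is; otherwise init_preweight above remaps legacy 'gamma'/'beta'
        # parameter names to 'weight'/'bias' and loads the weights, logging missing and
        # unexpected keys instead of raising.)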
172 | model = cls(config, *inputs, **kwargs) 173 | if state_dict is None: 174 | return model 175 | model = cls.init_preweight(model, state_dict) 176 | 177 | return model 178 | 179 | ################################## 180 | ###### LOSS FUNCTION ############# 181 | ################################## 182 | class CrossEn(nn.Module): 183 | def __init__(self,): 184 | super(CrossEn, self).__init__() 185 | 186 | def forward(self, sim_matrix, guide_dis=None): 187 | logpt = F.log_softmax(sim_matrix, dim=-1) 188 | logpt = torch.diag(logpt) 189 | nce_loss = -logpt 190 | sim_loss = nce_loss.mean() 191 | return sim_loss 192 | 193 | class yb_CrossEn(nn.Module): 194 | def __init__(self,): 195 | super(yb_CrossEn, self).__init__() 196 | 197 | def forward(self, sim_matrix, id_dict, vid_id, opt=None): 198 | id_dict = id_dict.flatten() 199 | 200 | from ipdb import set_trace 201 | 202 | logpt = F.log_softmax(sim_matrix, dim=-1) 203 | sim_loss = [] 204 | for my_idx in range(len(vid_id)): 205 | vid_loc = (vid_id[my_idx] == id_dict).nonzero(as_tuple=True)[0] 206 | sim_loss.append(-logpt[my_idx, vid_loc]) 207 | 208 | return sum(sim_loss)/len(sim_loss) 209 | 210 | 211 | # gt_loc = [] 212 | # for my_idx in range(len(vid_id)): 213 | # vid_loc = (vid_id[my_idx] == id_dict).nonzero(as_tuple=True)[0] 214 | # gt_loc.append(vid_loc.item()) 215 | 216 | # while True: 217 | # perm = torch.randperm(sim_matrix.size(1)) 218 | # rand_idx = perm[:opt.yb_sample_num] 219 | 220 | # if not any(x in gt_loc for x in rand_idx.tolist()): 221 | # break 222 | 223 | # new_sim_matrix = torch.cat((sim_matrix[:,gt_loc], sim_matrix[:,rand_idx]), dim=-1) 224 | 225 | # logpt = F.log_softmax(new_sim_matrix, dim=-1) 226 | # logpt = torch.diag(logpt) 227 | # nce_loss = -logpt 228 | # sim_loss = nce_loss.mean() 229 | 230 | # return sim_loss 231 | 232 | class MILNCELoss(nn.Module): 233 | def __init__(self, batch_size=1, n_pair=1,): 234 | super(MILNCELoss, self).__init__() 235 | self.batch_size = batch_size 236 | self.n_pair = n_pair 237 | torch_v = float(".".join(torch.__version__.split(".")[:2])) 238 | self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8 239 | 240 | def forward(self, sim_matrix): 241 | mm_mask = np.eye(self.batch_size) 242 | mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair))) 243 | mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device) 244 | 245 | from_text_matrix = sim_matrix + mm_mask * -1e12 246 | from_video_matrix = sim_matrix.transpose(1, 0) 247 | 248 | new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1) 249 | logpt = F.log_softmax(new_sim_matrix, dim=-1) 250 | 251 | mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1) 252 | masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12 253 | 254 | new_logpt = -torch.logsumexp(masked_logpt, dim=-1) 255 | 256 | logpt_choice = torch.zeros_like(new_logpt) 257 | mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2) 258 | logpt_choice[mark_ind] = 1 259 | sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean() 260 | return sim_loss 261 | 262 | class MaxMarginRankingLoss(nn.Module): 263 | def __init__(self, 264 | margin=1.0, 265 | negative_weighting=False, 266 | batch_size=1, 267 | n_pair=1, 268 | hard_negative_rate=0.5, 269 | ): 270 | super(MaxMarginRankingLoss, self).__init__() 271 | self.margin = margin 272 | self.n_pair = n_pair 273 | self.batch_size = batch_size 274 | easy_negative_rate = 1 - hard_negative_rate 275 | 
self.easy_negative_rate = easy_negative_rate 276 | self.negative_weighting = negative_weighting 277 | if n_pair > 1 and batch_size > 1: 278 | alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate)) 279 | mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha 280 | mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair))) 281 | mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate)) 282 | self.mm_mask = mm_mask.float() 283 | 284 | def forward(self, x): 285 | d = torch.diag(x) 286 | max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \ 287 | F.relu(self.margin + x - d.view(1, -1)) 288 | if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1: 289 | max_margin = max_margin * self.mm_mask.to(max_margin.device) 290 | return max_margin.mean() 291 | 292 | class AllGather(torch.autograd.Function): 293 | """An autograd function that performs allgather on a tensor.""" 294 | 295 | @staticmethod 296 | def forward(ctx, tensor, args): 297 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 298 | torch.distributed.all_gather(output, tensor) 299 | ctx.rank = args.rank 300 | ctx.batch_size = tensor.shape[0] 301 | return torch.cat(output, dim=0) 302 | 303 | @staticmethod 304 | def backward(ctx, grad_output): 305 | return ( 306 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 307 | None, 308 | ) 309 | -------------------------------------------------------------------------------- /modules/resnet_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | 8 | class AVENet(nn.Module): 9 | 10 | def __init__(self,args): 11 | super(AVENet, self).__init__() 12 | self.audnet = resnet18(num_classes=309, pool='avgpool') 13 | 14 | def forward(self, audio): 15 | aud = self.audnet(audio) 16 | return aud 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, groups=groups, bias=False, dilation=dilation) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes ,stride=1, downsample=None, groups=1, 33 | base_width=64, dilation=1, norm_layer=None): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 42 | self.conv1 = conv3x3(inplanes, planes, stride) 43 | self.bn1 = norm_layer(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = norm_layer(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | if self.downsample is not None: 61 | identity = self.downsample(x) 62 | 63 | out += identity 64 | out = self.relu(out) 65 | 66 | 
return out 67 | 68 | class ResNet(nn.Module): 69 | 70 | def __init__(self, block, layers, num_classes=1000, pool = 'avgpool',zero_init_residual=False, 71 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 72 | norm_layer=None): 73 | super(ResNet, self).__init__() 74 | self.pool = pool 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | self._norm_layer = norm_layer 78 | 79 | self.inplanes = 64 80 | self.dilation = 1 81 | if replace_stride_with_dilation is None: 82 | # each element in the tuple indicates if we should replace 83 | # the 2x2 stride with a dilated convolution instead 84 | replace_stride_with_dilation = [False, False, False] 85 | if len(replace_stride_with_dilation) != 3: 86 | raise ValueError("replace_stride_with_dilation should be None " 87 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 88 | self.groups = groups 89 | self.base_width = width_per_group 90 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 91 | bias=False) 92 | self.bn1 = norm_layer(self.inplanes) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 95 | self.layer1 = self._make_layer(block, 64, layers[0]) 96 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 97 | dilate=replace_stride_with_dilation[0]) 98 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 99 | dilate=replace_stride_with_dilation[1]) 100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 101 | dilate=replace_stride_with_dilation[2]) 102 | if self.pool == 'avgpool': 103 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 104 | 105 | self.fc = nn.Linear(512 * block.expansion, num_classes) # 8192 106 | elif self.pool == 'vlad': 107 | self.avgpool = NetVLAD() 108 | self.fc_ = nn.Linear(8192 * block.expansion, num_classes) 109 | 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 114 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 115 | nn.init.normal_(m.weight, mean=1, std=0.02) 116 | nn.init.constant_(m.bias, 0) 117 | 118 | # Zero-initialize the last BN in each residual branch, 119 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
120 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 121 | if zero_init_residual: 122 | for m in self.modules(): 123 | if isinstance(m, Bottleneck): 124 | nn.init.constant_(m.bn3.weight, 0) 125 | elif isinstance(m, BasicBlock): 126 | nn.init.constant_(m.bn2.weight, 0) 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 129 | norm_layer = self._norm_layer 130 | downsample = None 131 | previous_dilation = self.dilation 132 | if dilate: 133 | self.dilation *= stride 134 | stride = 1 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | conv1x1(self.inplanes, planes * block.expansion, stride), 138 | norm_layer(planes * block.expansion), 139 | ) 140 | 141 | 142 | layers = [] 143 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 144 | self.base_width, previous_dilation, norm_layer)) 145 | self.inplanes = planes * block.expansion 146 | for _ in range(1, blocks): 147 | layers.append(block(self.inplanes, planes, groups=self.groups, 148 | base_width=self.base_width, dilation=self.dilation, 149 | norm_layer=norm_layer)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x): 154 | x = self.conv1(x) 155 | x = self.bn1(x) 156 | x = self.relu(x) 157 | x = self.maxpool(x) 158 | 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | x = self.layer4(x) 163 | x = self.avgpool(x) 164 | x = x.reshape(x.size(0), -1) 165 | # if self.pool == 'avgpool': 166 | # x = self.fc(x) 167 | # elif self.pool == 'vlad': 168 | # x = self.fc_(x) 169 | 170 | return x 171 | 172 | class NetVLAD(nn.Module): 173 | """NetVLAD layer implementation""" 174 | 175 | def __init__(self, num_clusters=16, dim=512, alpha=100.0, 176 | normalize_input=True): 177 | """ 178 | Args: 179 | num_clusters : int 180 | The number of clusters 181 | dim : int 182 | Dimension of descriptors 183 | alpha : float 184 | Parameter of initialization. Larger value is harder assignment. 185 | normalize_input : bool 186 | If true, descriptor-wise L2 normalization is applied to input. 
187 | """ 188 | super(NetVLAD, self).__init__() 189 | self.num_clusters = num_clusters 190 | self.dim = dim 191 | self.alpha = alpha 192 | self.normalize_input = normalize_input 193 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True) 194 | self.centroids = nn.Parameter(torch.rand(num_clusters, dim)) 195 | self._init_params() 196 | 197 | def _init_params(self): 198 | self.conv.weight = nn.Parameter( 199 | (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1) 200 | ) 201 | self.conv.bias = nn.Parameter( 202 | - self.alpha * self.centroids.norm(dim=1) 203 | ) 204 | 205 | def forward(self, x): 206 | N, C = x.shape[:2] 207 | 208 | if self.normalize_input: 209 | x = F.normalize(x, p=2, dim=1) # across descriptor dim 210 | 211 | # soft-assignment 212 | soft_assign = self.conv(x).view(N, self.num_clusters, -1) 213 | soft_assign = F.softmax(soft_assign, dim=1) 214 | 215 | x_flatten = x.view(N, C, -1) 216 | 217 | # calculate residuals to each clusters 218 | residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \ 219 | self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 220 | residual *= soft_assign.unsqueeze(2) 221 | vlad = residual.sum(dim=-1) 222 | 223 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 224 | vlad = vlad.view(x.size(0), -1) # flatten 225 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 226 | 227 | return vlad 228 | 229 | class Bottleneck(nn.Module): 230 | expansion = 4 231 | 232 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 233 | base_width=64, dilation=1, norm_layer=None): 234 | super(Bottleneck, self).__init__() 235 | if norm_layer is None: 236 | norm_layer = nn.BatchNorm2d 237 | width = int(planes * (base_width / 64.)) * groups 238 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 239 | self.conv1 = conv1x1(inplanes, width) 240 | self.bn1 = norm_layer(width) 241 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 242 | self.bn2 = norm_layer(width) 243 | self.conv3 = conv1x1(width, planes * self.expansion) 244 | self.bn3 = norm_layer(planes * self.expansion) 245 | self.relu = nn.ReLU(inplace=True) 246 | self.downsample = downsample 247 | self.stride = stride 248 | 249 | def forward(self, x): 250 | identity = x 251 | 252 | out = self.conv1(x) 253 | out = self.bn1(out) 254 | out = self.relu(out) 255 | 256 | out = self.conv2(out) 257 | out = self.bn2(out) 258 | out = self.relu(out) 259 | 260 | out = self.conv3(out) 261 | out = self.bn3(out) 262 | 263 | if self.downsample is not None: 264 | identity = self.downsample(x) 265 | 266 | out += identity 267 | out = self.relu(out) 268 | 269 | return out 270 | 271 | 272 | 273 | 274 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 275 | model = ResNet(block, layers, **kwargs) 276 | if pretrained: 277 | state_dict = load_state_dict_from_url(model_urls[arch], 278 | progress=progress) 279 | model.load_state_dict(state_dict) 280 | return model 281 | 282 | 283 | def resnet18(pretrained=False, progress=True, **kwargs): 284 | """Constructs a ResNet-18 model. 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 291 | **kwargs) 292 | 293 | 294 | def resnet34(pretrained=False, progress=True, **kwargs): 295 | """Constructs a ResNet-34 model. 
296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 302 | **kwargs) 303 | 304 | 305 | def resnet50(pretrained=False, progress=True, **kwargs): 306 | """Constructs a ResNet-50 model. 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | progress (bool): If True, displays a progress bar of the download to stderr 311 | """ 312 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 313 | **kwargs) 314 | 315 | 316 | def resnet101(pretrained=False, progress=True, **kwargs): 317 | """Constructs a ResNet-101 model. 318 | 319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | progress (bool): If True, displays a progress bar of the download to stderr 322 | """ 323 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 324 | **kwargs) 325 | 326 | 327 | def resnet152(pretrained=False, progress=True, **kwargs): 328 | """Constructs a ResNet-152 model. 329 | 330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | progress (bool): If True, displays a progress bar of the download to stderr 333 | """ 334 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 335 | **kwargs) 336 | 337 | 338 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 339 | """Constructs a ResNeXt-50 32x4d model. 340 | 341 | Args: 342 | pretrained (bool): If True, returns a model pre-trained on ImageNet 343 | progress (bool): If True, displays a progress bar of the download to stderr 344 | """ 345 | kwargs['groups'] = 32 346 | kwargs['width_per_group'] = 4 347 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 348 | pretrained, progress, **kwargs) 349 | 350 | 351 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 352 | """Constructs a ResNeXt-101 32x8d model. 353 | 354 | Args: 355 | pretrained (bool): If True, returns a model pre-trained on ImageNet 356 | progress (bool): If True, displays a progress bar of the download to stderr 357 | """ 358 | kwargs['groups'] = 32 359 | kwargs['width_per_group'] = 8 360 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 361 | pretrained, progress, **kwargs) 362 | 363 | 364 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 365 | """Constructs a Wide ResNet-50-2 model. 366 | 367 | The model is the same as ResNet except for the bottleneck number of channels 368 | which is twice larger in every block. The number of channels in outer 1x1 369 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 370 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 371 | 372 | Args: 373 | pretrained (bool): If True, returns a model pre-trained on ImageNet 374 | progress (bool): If True, displays a progress bar of the download to stderr 375 | """ 376 | kwargs['width_per_group'] = 64 * 2 377 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 378 | pretrained, progress, **kwargs) 379 | 380 | 381 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 382 | """Constructs a Wide ResNet-101-2 model. 383 | 384 | The model is the same as ResNet except for the bottleneck number of channels 385 | which is twice larger in every block. The number of channels in outer 1x1 386 | convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048 387 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 388 | 389 | Args: 390 | pretrained (bool): If True, returns a model pre-trained on ImageNet 391 | progress (bool): If True, displays a progress bar of the download to stderr 392 | """ 393 | kwargs['width_per_group'] = 64 * 2 394 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 395 | pretrained, progress, **kwargs) 396 | -------------------------------------------------------------------------------- /dataloaders/decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import math 5 | import numpy as np 6 | import random 7 | import torch 8 | import torchvision.io as io 9 | 10 | 11 | def temporal_sampling(frames, start_idx, end_idx, num_samples): 12 | """ 13 | Given the start and end frame index, sample num_samples frames between 14 | the start and end with equal interval. 15 | Args: 16 | frames (tensor): a tensor of video frames, dimension is 17 | `num video frames` x `channel` x `height` x `width`. 18 | start_idx (int): the index of the start frame. 19 | end_idx (int): the index of the end frame. 20 | num_samples (int): number of frames to sample. 21 | Returns: 22 | frames (tersor): a tensor of temporal sampled video frames, dimension is 23 | `num clip frames` x `channel` x `height` x `width`. 24 | """ 25 | index = torch.linspace(start_idx, end_idx, num_samples) 26 | index = torch.clamp(index, 0, frames.shape[0] - 1).long() 27 | frames = torch.index_select(frames, 0, index) 28 | return frames 29 | 30 | 31 | def get_start_end_idx( 32 | video_size, clip_size, clip_idx, num_clips, use_offset=False 33 | ): 34 | """ 35 | Sample a clip of size clip_size from a video of size video_size and 36 | return the indices of the first and last frame of the clip. If clip_idx is 37 | -1, the clip is randomly sampled, otherwise uniformly split the video to 38 | num_clips clips, and select the start and end index of clip_idx-th video 39 | clip. 40 | Args: 41 | video_size (int): number of overall frames. 42 | clip_size (int): size of the clip to sample from the frames. 43 | clip_idx (int): if clip_idx is -1, perform random jitter sampling. If 44 | clip_idx is larger than -1, uniformly split the video to num_clips 45 | clips, and select the start and end index of the clip_idx-th video 46 | clip. 47 | num_clips (int): overall number of clips to uniformly sample from the 48 | given video for testing. 49 | Returns: 50 | start_idx (int): the start frame index. 51 | end_idx (int): the end frame index. 52 | """ 53 | delta = max(video_size - clip_size, 0) 54 | if clip_idx == -1: 55 | # Random temporal sampling. 56 | start_idx = random.uniform(0, delta) 57 | else: 58 | if use_offset: 59 | if num_clips == 1: 60 | # Take the center clip if num_clips is 1. 61 | start_idx = math.floor(delta / 2) 62 | else: 63 | # Uniformly sample the clip with the given index. 64 | start_idx = clip_idx * math.floor(delta / (num_clips - 1)) 65 | else: 66 | # Uniformly sample the clip with the given index. 67 | start_idx = delta * clip_idx / num_clips 68 | end_idx = start_idx + clip_size - 1 69 | return start_idx, end_idx 70 | 71 | 72 | def pyav_decode_stream( 73 | container, start_pts, end_pts, stream, stream_name, buffer_size=0 74 | ): 75 | """ 76 | Decode the video with PyAV decoder. 77 | Args: 78 | container (container): PyAV container. 
79 | start_pts (int): the starting Presentation TimeStamp to fetch the 80 | video frames. 81 | end_pts (int): the ending Presentation TimeStamp of the decoded frames. 82 | stream (stream): PyAV stream. 83 | stream_name (dict): a dictionary of streams. For example, {"video": 0} 84 | means video stream at stream index 0. 85 | buffer_size (int): number of additional frames to decode beyond end_pts. 86 | Returns: 87 | result (list): list of frames decoded. 88 | max_pts (int): max Presentation TimeStamp of the video sequence. 89 | """ 90 | # Seeking in the stream is imprecise. Thus, seek to an ealier PTS by a 91 | # margin pts. 92 | margin = 1024 93 | seek_offset = max(start_pts - margin, 0) 94 | 95 | container.seek(seek_offset, any_frame=False, backward=True, stream=stream) 96 | frames = {} 97 | buffer_count = 0 98 | max_pts = 0 99 | for frame in container.decode(**stream_name): 100 | max_pts = max(max_pts, frame.pts) 101 | if frame.pts < start_pts: 102 | continue 103 | if frame.pts <= end_pts: 104 | frames[frame.pts] = frame 105 | else: 106 | buffer_count += 1 107 | frames[frame.pts] = frame 108 | if buffer_count >= buffer_size: 109 | break 110 | result = [frames[pts] for pts in sorted(frames)] 111 | return result, max_pts 112 | 113 | 114 | def torchvision_decode( 115 | video_handle, 116 | sampling_rate, 117 | num_frames, 118 | clip_idx, 119 | video_meta, 120 | num_clips=10, 121 | target_fps=30, 122 | modalities=("visual",), 123 | max_spatial_scale=0, 124 | use_offset=False, 125 | ): 126 | """ 127 | If video_meta is not empty, perform temporal selective decoding to sample a 128 | clip from the video with TorchVision decoder. If video_meta is empty, decode 129 | the entire video and update the video_meta. 130 | Args: 131 | video_handle (bytes): raw bytes of the video file. 132 | sampling_rate (int): frame sampling rate (interval between two sampled 133 | frames). 134 | num_frames (int): number of frames to sample. 135 | clip_idx (int): if clip_idx is -1, perform random temporal 136 | sampling. If clip_idx is larger than -1, uniformly split the 137 | video to num_clips clips, and select the clip_idx-th video clip. 138 | video_meta (dict): a dict contains VideoMetaData. Details can be found 139 | at `pytorch/vision/torchvision/io/_video_opt.py`. 140 | num_clips (int): overall number of clips to uniformly sample from the 141 | given video. 142 | target_fps (int): the input video may has different fps, convert it to 143 | the target video fps. 144 | modalities (tuple): tuple of modalities to decode. Currently only 145 | support `visual`, planning to support `acoustic` soon. 146 | max_spatial_scale (int): the maximal resolution of the spatial shorter 147 | edge size during decoding. 148 | Returns: 149 | frames (tensor): decoded frames from the video. 150 | fps (float): the number of frames per second of the video. 151 | decode_all_video (bool): if True, the entire video was decoded. 152 | """ 153 | # Convert the bytes to a tensor. 154 | video_tensor = torch.from_numpy(np.frombuffer(video_handle, dtype=np.uint8)) 155 | 156 | decode_all_video = True 157 | video_start_pts, video_end_pts = 0, -1 158 | # The video_meta is empty, fetch the meta data from the raw video. 159 | if len(video_meta) == 0: 160 | # Tracking the meta info for selective decoding in the future. 161 | meta = io._probe_video_from_memory(video_tensor) 162 | # Using the information from video_meta to perform selective decoding. 
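        # Cache the probed video/audio timebase, fps, and duration in video_meta so that
        # subsequent calls on the same video can seek directly to the requested clip
        # instead of probing the container again.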
163 | video_meta["video_timebase"] = meta.video_timebase 164 | video_meta["video_numerator"] = meta.video_timebase.numerator 165 | video_meta["video_denominator"] = meta.video_timebase.denominator 166 | video_meta["has_video"] = meta.has_video 167 | video_meta["video_duration"] = meta.video_duration 168 | video_meta["video_fps"] = meta.video_fps 169 | video_meta["audio_timebas"] = meta.audio_timebase 170 | video_meta["audio_numerator"] = meta.audio_timebase.numerator 171 | video_meta["audio_denominator"] = meta.audio_timebase.denominator 172 | video_meta["has_audio"] = meta.has_audio 173 | video_meta["audio_duration"] = meta.audio_duration 174 | video_meta["audio_sample_rate"] = meta.audio_sample_rate 175 | 176 | fps = video_meta["video_fps"] 177 | if ( 178 | video_meta["has_video"] 179 | and video_meta["video_denominator"] > 0 180 | and video_meta["video_duration"] > 0 181 | ): 182 | # try selective decoding. 183 | decode_all_video = False 184 | clip_size = sampling_rate * num_frames / target_fps * fps 185 | start_idx, end_idx = get_start_end_idx( 186 | fps * video_meta["video_duration"], 187 | clip_size, 188 | clip_idx, 189 | num_clips, 190 | use_offset=use_offset, 191 | ) 192 | # Convert frame index to pts. 193 | pts_per_frame = video_meta["video_denominator"] / fps 194 | video_start_pts = int(start_idx * pts_per_frame) 195 | video_end_pts = int(end_idx * pts_per_frame) 196 | 197 | # Decode the raw video with the tv decoder. 198 | v_frames, _ = io._read_video_from_memory( 199 | video_tensor, 200 | seek_frame_margin=1.0, 201 | read_video_stream="visual" in modalities, 202 | video_width=0, 203 | video_height=0, 204 | video_min_dimension=max_spatial_scale, 205 | video_pts_range=(video_start_pts, video_end_pts), 206 | video_timebase_numerator=video_meta["video_numerator"], 207 | video_timebase_denominator=video_meta["video_denominator"], 208 | ) 209 | 210 | if v_frames.shape == torch.Size([0]): 211 | # failed selective decoding 212 | decode_all_video = True 213 | video_start_pts, video_end_pts = 0, -1 214 | v_frames, _ = io._read_video_from_memory( 215 | video_tensor, 216 | seek_frame_margin=1.0, 217 | read_video_stream="visual" in modalities, 218 | video_width=0, 219 | video_height=0, 220 | video_min_dimension=max_spatial_scale, 221 | video_pts_range=(video_start_pts, video_end_pts), 222 | video_timebase_numerator=video_meta["video_numerator"], 223 | video_timebase_denominator=video_meta["video_denominator"], 224 | ) 225 | 226 | return v_frames, fps, decode_all_video 227 | 228 | 229 | def pyav_decode( 230 | container, 231 | sampling_rate, 232 | num_frames, 233 | clip_idx, 234 | num_clips=10, 235 | target_fps=30, 236 | use_offset=False, 237 | ): 238 | """ 239 | Convert the video from its original fps to the target_fps. If the video 240 | support selective decoding (contain decoding information in the video head), 241 | the perform temporal selective decoding and sample a clip from the video 242 | with the PyAV decoder. If the video does not support selective decoding, 243 | decode the entire video. 244 | Args: 245 | container (container): pyav container. 246 | sampling_rate (int): frame sampling rate (interval between two sampled 247 | frames. 248 | num_frames (int): number of frames to sample. 249 | clip_idx (int): if clip_idx is -1, perform random temporal sampling. If 250 | clip_idx is larger than -1, uniformly split the video to num_clips 251 | clips, and select the clip_idx-th video clip. 252 | num_clips (int): overall number of clips to uniformly sample from the 253 | given video. 
254 |         target_fps (int): the input video may have a different fps, convert it to
255 |             the target video fps before frame sampling.
256 |     Returns:
257 |         frames (tensor): decoded frames from the video. Return None if no
258 |             video stream was found.
259 |         fps (float): the number of frames per second of the video.
260 |         decode_all_video (bool): If True, the entire video was decoded.
261 |     """
262 |     # Try to fetch the decoding information from the video header. Some videos do
263 |     # not support fetching the decoding information; in that case the duration
264 |     # will be None.
265 |     fps = float(container.streams.video[0].average_rate)
266 |     frames_length = container.streams.video[0].frames
267 |     duration = container.streams.video[0].duration
268 | 
269 |     if duration is None:
270 |         # If failed to fetch the decoding information, decode the entire video.
271 |         decode_all_video = True
272 |         video_start_pts, video_end_pts = 0, math.inf
273 |     else:
274 |         # Perform selective decoding.
275 |         decode_all_video = False
276 |         start_idx, end_idx = get_start_end_idx(
277 |             frames_length,
278 |             sampling_rate * num_frames / target_fps * fps,
279 |             clip_idx,
280 |             num_clips,
281 |             use_offset=use_offset,
282 |         )
283 |         timebase = duration / frames_length
284 |         video_start_pts = int(start_idx * timebase)
285 |         video_end_pts = int(end_idx * timebase)
286 | 
287 |     frames = None
288 |     # If video stream was found, fetch video frames from the video.
289 |     if container.streams.video:
290 |         video_frames, max_pts = pyav_decode_stream(
291 |             container,
292 |             video_start_pts,
293 |             video_end_pts,
294 |             container.streams.video[0],
295 |             {"video": 0},
296 |         )
297 |         container.close()
298 | 
299 |         frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
300 | 
301 | 
302 |         frames = torch.as_tensor(np.stack(frames))
303 |     return frames, fps, decode_all_video
304 | 
305 | 
306 | def decode(
307 |     container,
308 |     sampling_rate,
309 |     num_frames,
310 |     clip_idx=-1,
311 |     num_clips=10,
312 |     video_meta=None,
313 |     target_fps=30,
314 |     backend="pyav",
315 |     max_spatial_scale=0,
316 |     use_offset=False,
317 | ):
318 |     """
319 |     Decode the video and perform temporal sampling.
320 |     Args:
321 |         container (container): pyav container.
322 |         sampling_rate (int): frame sampling rate (interval between two sampled
323 |             frames).
324 |         num_frames (int): number of frames to sample.
325 |         clip_idx (int): if clip_idx is -1, perform random temporal
326 |             sampling. If clip_idx is larger than -1, uniformly split the
327 |             video to num_clips clips, and select the
328 |             clip_idx-th video clip.
329 |         num_clips (int): overall number of clips to uniformly
330 |             sample from the given video.
331 |         video_meta (dict): a dict containing VideoMetaData. Details can be found
332 |             at `pytorch/vision/torchvision/io/_video_opt.py`.
333 |         target_fps (int): the input video may have a different fps, convert it to
334 |             the target video fps before frame sampling.
335 |         backend (str): decoding backend, one of `pyav` and `torchvision`. The
336 |             default is `pyav`.
337 |         max_spatial_scale (int): keep the aspect ratio and resize the frame so
338 |             that the shorter edge size is max_spatial_scale. Only used in the
339 |             `torchvision` backend.
340 |     Returns:
341 |         frames (tensor): decoded frames from the video.
342 |     """
343 |     # Two decoders are currently supported: 1) PyAV, and 2) TorchVision.
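    # Decoding errors are caught below and reported; None is returned to signal
    # failure to the caller instead of raising.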
344 |     assert clip_idx >= -1, "Not valid clip_idx {}".format(clip_idx)
345 |     try:
346 | 
347 |         if backend == "pyav":
348 |             frames, fps, decode_all_video = pyav_decode(
349 |                 container,
350 |                 sampling_rate,
351 |                 num_frames,
352 |                 clip_idx,
353 |                 num_clips,
354 |                 target_fps,
355 |                 use_offset=use_offset,
356 |             )
357 |             # Length (in seconds) of the decoded portion of the video.
358 |             total_len = len(frames) / fps
359 |         elif backend == "torchvision":
360 |             frames, fps, decode_all_video = torchvision_decode(
361 |                 container,
362 |                 sampling_rate,
363 |                 num_frames,
364 |                 clip_idx,
365 |                 video_meta,
366 |                 num_clips,
367 |                 target_fps,
368 |                 ("visual",),
369 |                 max_spatial_scale,
370 |                 use_offset=use_offset,
371 |             )
372 |             total_len = frames.shape[0] / fps
373 |         else:
374 |             raise NotImplementedError(
375 |                 "Unknown decoding backend {}".format(backend)
376 |             )
377 |     except Exception as e:
378 |         print("Failed to decode by {} with exception: {}".format(backend, e))
379 |         return None
380 | 
381 |     # Return None if the frames were not decoded successfully.
382 |     if frames is None or frames.size(0) == 0:
383 |         return None
384 | 
385 |     clip_sz = sampling_rate * num_frames / target_fps * fps
386 |     start_idx, end_idx = get_start_end_idx(
387 |         frames.shape[0],
388 |         clip_sz,
389 |         clip_idx if decode_all_video else 0,
390 |         num_clips if decode_all_video else 1,
391 |         use_offset=use_offset,
392 |     )
393 |     # Perform temporal sampling from the decoded video.
394 | 
395 |     frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
396 |     return frames, total_len
--------------------------------------------------------------------------------
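A minimal usage sketch of decode() with the PyAV backend. Assumptions not taken from the repository: PyAV is installed, the script runs from the repository root so that dataloaders.decoder is importable, "example_video.mp4" is a placeholder path, and the sampling_rate/num_frames values are illustrative.

import av
from dataloaders.decoder import decode

container = av.open("example_video.mp4")   # hypothetical local clip

result = decode(
    container,
    sampling_rate=8,     # interval between sampled frames, relative to target_fps
    num_frames=8,        # frames kept after temporal sampling
    clip_idx=-1,         # -1 selects a random temporal window
    num_clips=10,
    video_meta={},
    target_fps=30,
    backend="pyav",
)

if result is not None:
    frames, total_len = result
    # frames: uint8 tensor of shape (8, H, W, 3); total_len: decoded length in seconds.
    print(frames.shape, total_len)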