├── modules
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── modeling.cpython-38.pyc
│   │   ├── ast_models.cpython-38.pyc
│   │   ├── file_utils.cpython-38.pyc
│   │   ├── module_clip.cpython-38.pyc
│   │   ├── module_cross.cpython-38.pyc
│   │   ├── optimization.cpython-38.pyc
│   │   ├── resnet_models.cpython-38.pyc
│   │   ├── until_config.cpython-38.pyc
│   │   ├── until_module.cpython-38.pyc
│   │   └── tokenization_clip.cpython-38.pyc
│   ├── cross-base
│   │   └── cross_config.json
│   ├── until_config.py
│   ├── tokenization_clip.py
│   ├── optimization.py
│   ├── file_utils.py
│   ├── module_cross.py
│   ├── ast_models.py
│   ├── until_module.py
│   └── resnet_models.py
├── easy_kill_process.sh
├── dataloaders
│   ├── __pycache__
│   │   ├── utils.cpython-38.pyc
│   │   ├── decoder.cpython-38.pyc
│   │   ├── transform.cpython-38.pyc
│   │   ├── mel_features.cpython-38.pyc
│   │   ├── rawvideo_util.cpython-38.pyc
│   │   ├── vggish_params.cpython-38.pyc
│   │   ├── video_container.cpython-38.pyc
│   │   ├── data_dataloaders.cpython-38.pyc
│   │   ├── dataloader_msvd_retrieval.cpython-38.pyc
│   │   ├── dataloader_didemo_retrieval.cpython-38.pyc
│   │   ├── dataloader_lsmdc_retrieval.cpython-38.pyc
│   │   ├── dataloader_msrvtt_retrieval.cpython-38.pyc
│   │   ├── dataloader_charades_retrieval.cpython-38.pyc
│   │   ├── dataloader_youcook_retrieval.cpython-38.pyc
│   │   ├── dataloader_activitynet_retrieval.cpython-38.pyc
│   │   ├── dataloader_qvhighlight_retrieval.cpython-38.pyc
│   │   ├── dataloader_youcook_short_retrieval.cpython-38.pyc
│   │   └── dataloader_activitynet_retrieval_demo.cpython-38.pyc
│   ├── video_container.py
│   ├── vggish_params.py
│   ├── rawvideo_util.py
│   ├── dataloader_lsmdc_retrieval.py
│   ├── mel_features.py
│   ├── dataloader_msvd_retrieval.py
│   ├── utils.py
│   └── decoder.py
├── torchvggish
│   ├── __pycache__
│   │   ├── vggish.cpython-38.pyc
│   │   ├── mel_features.cpython-38.pyc
│   │   ├── vggish_input.cpython-38.pyc
│   │   └── vggish_params.cpython-38.pyc
│   ├── vggish_params.py
│   ├── vggish_input.py
│   ├── vggish.py
│   └── mel_features.py
├── preprocess.sh
├── make_vid.sh
├── requirements.txt
├── run_didemo.sh
├── run_audio-only.sh
├── run_qvh.sh
├── comp_flops.sh
├── run_cha.sh
├── debug.sh
├── run_yc2.sh
├── run_act-B16.sh
├── run_act-B32.sh
├── LICENSE
├── preproc.sh
├── feature_ext.sh
├── util.py
├── README.md
├── metrics.py
└── preprocess
    └── compress_video.py
/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/easy_kill_process.sh:
--------------------------------------------------------------------------------
1 | kill $(ps aux | grep "main_task_retrieval.py" | grep -v grep | awk '{print $2}')
2 |
--------------------------------------------------------------------------------
/modules/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/dataloaders/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/modeling.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/modeling.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/decoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/decoder.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/ast_models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/ast_models.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/file_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/file_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/module_clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/module_clip.cpython-38.pyc
--------------------------------------------------------------------------------
/torchvggish/__pycache__/vggish.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/vggish.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/transform.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/transform.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/module_cross.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/module_cross.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/optimization.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/optimization.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/resnet_models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/resnet_models.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/until_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/until_config.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/until_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/until_module.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/mel_features.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/mel_features.cpython-38.pyc
--------------------------------------------------------------------------------
/torchvggish/__pycache__/mel_features.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/mel_features.cpython-38.pyc
--------------------------------------------------------------------------------
/torchvggish/__pycache__/vggish_input.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/vggish_input.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/rawvideo_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/rawvideo_util.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/vggish_params.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/vggish_params.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/video_container.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/video_container.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/tokenization_clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/modules/__pycache__/tokenization_clip.cpython-38.pyc
--------------------------------------------------------------------------------
/preprocess.sh:
--------------------------------------------------------------------------------
1 | python preprocess/compress_video.py --input_root /playpen-iop/yblin/AVE_Dataset/AVE \
2 | --output_root /playpen-iop/yblin/AVE_Dataset/raw_audio
--------------------------------------------------------------------------------
/torchvggish/__pycache__/vggish_params.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/torchvggish/__pycache__/vggish_params.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/data_dataloaders.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/data_dataloaders.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_charades_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_charades_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_youcook_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_youcook_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_qvhighlight_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_qvhighlight_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_youcook_short_retrieval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_youcook_short_retrieval.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/dataloader_activitynet_retrieval_demo.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GenjiB/ECLIPSE/HEAD/dataloaders/__pycache__/dataloader_activitynet_retrieval_demo.cpython-38.pyc
--------------------------------------------------------------------------------
/make_vid.sh:
--------------------------------------------------------------------------------
1 | ffmpeg -r 30 -i /playpen-iop/yblin/act_val_att/v_iosb2TdQ7yY/%04d.jpg -i \
2 | /playpen-storage/yblin/v1-2/audio_raw/v_iosb2TdQ7yY/0000.wav -strict -2 /playpen-iop/yblin/act_val_att/"v_iosb2TdQ7yY"_"att".mp4
3 |
--------------------------------------------------------------------------------
/modules/cross-base/cross_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 512,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 2048,
8 | "max_position_embeddings": 300,
9 | "num_attention_heads": 8,
10 | "num_hidden_layers": 4,
11 | "vocab_size": 512
12 | }
13 |
--------------------------------------------------------------------------------
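The cross-transformer configuration above is small enough to inspect directly. A minimal sketch (the relative path assumes you run from the repo root; the repo's own code loads this file through its config classes rather than raw `json`):

```python
import json

# Peek at the cross-transformer hyperparameters defined above.
with open("modules/cross-base/cross_config.json") as f:
    cfg = json.load(f)

# 4 layers x 8 heads over a 512-d hidden state, 2048-d MLP, up to 300 positions.
print(cfg["num_hidden_layers"], cfg["num_attention_heads"],
      cfg["hidden_size"], cfg["intermediate_size"], cfg["max_position_embeddings"])
```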
/requirements.txt:
--------------------------------------------------------------------------------
1 | av==8.0.3
2 | boto3==1.18.51
3 | botocore==1.21.51
4 | deepspeed==0.5.10
5 | einops==0.3.2
6 | ffmpeg==1.4
7 | ftfy==6.0.3
8 | gpuinfo==1.0.0a7
9 | ipdb==0.13.9
10 | jsonlines==3.0.0
11 | numpy==1.20.3
12 | opencv_python==4.5.3.56
13 | pandas==1.3.3
14 | Pillow==9.2.0
15 | psutil==5.8.0
16 | regex==2021.9.24
17 | requests==2.26.0
18 | resampy==0.2.2
19 | scipy==1.7.1
20 | SoundFile==0.10.3.post1
21 | thop==0.0.31.post2005241907
22 | timm==0.4.5
23 | torch==1.7.1+cu110
24 | torchaudio==0.7.2
25 | torchvision==0.8.2+cu110
26 | tqdm==4.62.3
27 | wget==3.2
28 |
--------------------------------------------------------------------------------
/run_didemo.sh:
--------------------------------------------------------------------------------
1 | # bash mykill.sh
2 | DATA_PATH=YOUR_PATH
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3
4 |
5 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=30678 \
6 | main_task_retrieval.py --do_train --num_thread_reader=6 \
7 | --epochs=5 --batch_size=128 --n_display=50 \
8 | --data_path ${DATA_PATH} \
9 | --features_path ${DATA_PATH}/raw_video \
10 | --output_dir ckpts/ckpt_didemo_retrieval_looseType \
11 | --lr 1e-4 --coef_lr 5e-3 --yb_factor_aduio 0.01 --yb_factor_audio_decay 0 --max_words 64 --max_frames 32 --max_audio_frames 32 \
12 | --batch_size_val 16 \
13 | --datatype didemo --feature_framerate 1 \
14 | --freeze_layer_num 0 --slice_framepos 2 \
15 | --loose_type --linear_patch 2d --sim_header meanP \
16 | --yb_dual 0 --yb_av 1 --wandb 0 --model_name ECLIPSE_didemo_f32_B32 \
17 | --pretrained_clip_name ViT-B/32
18 |
--------------------------------------------------------------------------------
/run_audio-only.sh:
--------------------------------------------------------------------------------
1 | DATA_PATH=/playpen-iop/yblin/didemo_data
2 | export CUDA_VISIBLE_DEVICES=4,5,6,7
3 |
4 |
5 |
6 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=9864 \
7 | main_task_retrieval.py --do_train --num_thread_reader=8 \
8 | --epochs=20 --batch_size=64 --n_display=50 \
9 | --data_path ${DATA_PATH} \
10 | --features_path ${DATA_PATH}/raw_video \
11 | --output_dir ckpts/ckpt_didemo_retrieval_looseType \
12 | --lr 1e-4 --coef_lr 1e-3 --yb_factor_aduio 0.5 --yb_factor_audio_decay 0.6 --max_words 64 --max_frames 8 --max_audio_frames 8 --batch_size_val 16 \
13 | --datatype CUSTOM_DATASET --feature_framerate 1 \
14 | --freeze_layer_num 0 --slice_framepos 2 \
15 | --loose_type --linear_patch 2d --sim_header meanP \
16 | --yb_dual 1 --yb_av 1 --wandb 1 --model_name Rebuttal_qvh_ours-8f \
17 | --pretrained_clip_name ViT-B/32
--------------------------------------------------------------------------------
/run_qvh.sh:
--------------------------------------------------------------------------------
1 | # bash mykill.sh
2 | DATA_PATH=YOUR_PATH
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3
4 |
5 |
6 |
7 |
8 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=5278 \
9 | main_task_retrieval.py --do_train --num_thread_reader=8 \
10 | --epochs=20 --batch_size=64 --n_display=50 \
11 | --data_path ${DATA_PATH} \
12 | --features_path ${DATA_PATH}/raw_video \
13 | --output_dir ckpts/ckpt_didemo_retrieval_looseType \
14 | --lr 1e-4 --coef_lr 1e-3 --yb_factor_aduio 0.5 --yb_factor_audio_decay 0.6 --max_words 128 --max_frames 32 \
15 | --max_audio_frames 32 --batch_size_val 16 \
16 | --datatype qvhighlight --feature_framerate 1 \
17 | --freeze_layer_num 0 --slice_framepos 2 \
18 | --loose_type --linear_patch 2d --sim_header meanP \
19 | --yb_dual 0 --yb_av 1 --wandb 0 --model_name ECLIPSE_qvh_f32_B32 \
20 | --pretrained_clip_name ViT-B/32 --audio_pt VGGSound_Audio_features_10s_aligned
21 |
--------------------------------------------------------------------------------
/comp_flops.sh:
--------------------------------------------------------------------------------
1 | DATA_PATH=/playpen-iop/yblin/v1-2
2 | export CUDA_VISIBLE_DEVICES=4
3 | python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=30678 \
4 | compute_cost.py --do_train --num_thread_reader=0 \
5 | --epochs=20 --batch_size=64 --n_display=50 \
6 | --data_path ${DATA_PATH} \
7 | --features_path ${DATA_PATH}/Activity_Videos \
8 | --output_dir ckpts/ckpt_activity_retrieval_looseType \
9 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 96 --max_audio_frames 96 \
10 | --audio_cluster 4 --batch_size_val 1 --gradient_accumulation_steps 1 \
11 | --datatype activity --feature_framerate 1 \
12 | --freeze_layer_num 0 --slice_framepos 2 \
13 | --loose_type --linear_patch 2d --sim_header meanP \
14 | --yb_coff_dis 0 --yb_coff_loss 1 --yb_factor_1 0.001 --yb_audio_length 10 \
15 | --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name VGGSound_10s_co-Res_-5block_baseline_16 \
16 | --audio_pt VGGSound_Audio_features_10s_aligned
17 |
--------------------------------------------------------------------------------
/run_cha.sh:
--------------------------------------------------------------------------------
1 | DATA_PATH=YOUR_PATH
2 | export CUDA_VISIBLE_DEVICES=0,1,2,3
3 |
4 | python3 -m torch.distributed.launch --nproc_per_node=8 --master_port=30678 \
5 | main_task_retrieval.py --do_train --num_thread_reader=8 \
6 | --epochs=20 --batch_size=64 --n_display=50 \
7 | --data_path ${DATA_PATH} \
8 | --features_path ${DATA_PATH}/Activity_Videos \
9 | --output_dir ckpts/ckpt_activity_retrieval_looseType \
10 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 32 --max_audio_frames 32 \
11 | --audio_cluster 4 --batch_size_val 16 --gradient_accumulation_steps 1 \
12 | --datatype charades --feature_framerate 1 \
13 | --freeze_layer_num 0 --slice_framepos 2 \
14 | --loose_type --linear_patch 2d --sim_header meanP \
15 | --yb_av 1 --yb_dual 1 --yb_factor_aduio 50 --yb_factor_audio_decay 0.6 --yb_audio_length 10 \
16 | --pretrained_clip_name ViT-B/32 --wandb 1 --fixed_length 8 --model_name ECLIPSE_cha_f32_B32 \
17 | --audio_pt VGGSound_Audio_features_10s_aligned
18 |
19 |
20 |
--------------------------------------------------------------------------------
/debug.sh:
--------------------------------------------------------------------------------
1 | DATA_PATH=/playpen-iop/yblin/v1-2
2 | export CUDA_VISIBLE_DEVICES=4
3 | python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=5277 \
4 | main_task_retrieval.py --do_train --num_thread_reader=0 \
5 | --epochs=1 --batch_size=64 --n_display=50 \
6 | --data_path ${DATA_PATH} \
7 | --features_path ${DATA_PATH}/Activity_Videos \
8 | --output_dir ckpts/ckpt_activity_retrieval_looseType \
9 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 2 --max_audio_frames 2 \
10 | --audio_cluster 16 --batch_size_val 1 --gradient_accumulation_steps 1 \
11 | --datatype qvhighlight --feature_framerate 1 \
12 | --freeze_layer_num 0 --slice_framepos 2 \
13 | --loose_type --linear_patch 2d --sim_header meanP --yb_audio_length 1 --yb_dual 0 --yb_av 0 \
14 | --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name DistKL_avblock_ACLS+baseline_8 \
15 | --audio_pt VGGSound_Audio_features_20s_aligned
16 | # --audio_pt Audio_features_all_10s_aligned_stride_16
17 |
18 |
--------------------------------------------------------------------------------
/run_yc2.sh:
--------------------------------------------------------------------------------
1 | DATA_PATH=YOUR_PATH
2 |
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3
4 |
5 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=30668 \
6 | main_task_retrieval.py --do_train --num_thread_reader=8 \
7 | --epochs=20 --batch_size=64 --n_display=50 \
8 | --data_path ${DATA_PATH} \
9 | --features_path ${DATA_PATH}/Activity_Videos \
10 | --output_dir ckpts/ckpt_activity_retrieval_looseType \
11 | --lr 1e-4 --coef_lr 1e-3 --max_words 128 --max_frames 32 --max_audio_frames 32 \
12 | --audio_cluster 4 --batch_size_val 64 --gradient_accumulation_steps 1 \
13 | --datatype youcook --feature_framerate 1 \
14 | --freeze_layer_num 0 --slice_framepos 2 \
15 | --loose_type --linear_patch 2d --sim_header meanP \
16 | --yb_coff_dis 0 --yb_coff_loss 1 --yb_av 1 --yb_dual 1 --yb_factor_aduio 1 --yb_factor_audio_decay 0.1 --yb_audio_length 10 \
17 | --pretrained_clip_name ViT-B/32 --wandb 1 --fixed_length 8 --model_name ECLIPSE_yc2_f32_B32 \
18 | --audio_pt VGGSound_Audio_features_10s_aligned
19 |
--------------------------------------------------------------------------------
/run_act-B16.sh:
--------------------------------------------------------------------------------
1 | # bash mykill.sh
2 | DATA_PATH=YOUR_DATA_PATH
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
4 |
5 |
6 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=7414 \
7 | main_task_retrieval.py --do_train --num_thread_reader=6 \
8 | --epochs=40 --batch_size=64 --n_display=50 \
9 | --data_path ${DATA_PATH} \
10 | --features_path ${DATA_PATH}/Activity_Videos \
11 | --output_dir ckpts/ckpt_activity_retrieval_looseType \
12 | --lr 1e-4 --coef_lr 1e-3 --max_words 128 --max_frames 32 --max_audio_frames 32 \
13 | --audio_cluster 4 --batch_size_val 8 --gradient_accumulation_steps 1 \
14 | --datatype activity --feature_framerate 1 \
15 | --freeze_layer_num 0 --slice_framepos 2 \
16 | --loose_type --linear_patch 2d --sim_header meanP \
17 | --yb_av 1 --yb_dual 1 --yb_factor_aduio 0.5 --yb_time_cross_audio 0 --yb_factor_audio_decay 0.2 --yb_reverse_norm 1 \
18 | --yb_audio_length 10 \
19 | --pretrained_clip_name ViT-B/16 --wandb 1 --fixed_length 8 --model_name ECLIPSE_act_f32_B16 \
20 | --audio_pt VGGSound_Audio_features_10s_aligned
21 |
22 |
23 |
--------------------------------------------------------------------------------
/run_act-B32.sh:
--------------------------------------------------------------------------------
1 | # bash mykill.sh
2 | # train on 4 gpus
3 | DATA_PATH=YOUR_DATA_PATH
4 | export CUDA_VISIBLE_DEVICES=0,1,2,3
5 |
6 |
7 | python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=7414 \
8 | main_task_retrieval.py --do_train --num_thread_reader=6 \
9 | --epochs=40 --batch_size=64 --n_display=50 \
10 | --data_path ${DATA_PATH} \
11 | --features_path ${DATA_PATH}/Activity_Videos \
12 | --output_dir ckpts/ckpt_activity_retrieval_looseType \
13 | --lr 1e-4 --coef_lr 1e-3 --max_words 128 --max_frames 32 --max_audio_frames 32 \
14 | --audio_cluster 4 --batch_size_val 8 --gradient_accumulation_steps 1 \
15 | --datatype activity --feature_framerate 1 \
16 | --freeze_layer_num 0 --slice_framepos 2 \
17 | --loose_type --linear_patch 2d --sim_header meanP \
18 | --yb_av 1 --yb_dual 1 --yb_factor_aduio 0.5 --yb_time_cross_audio 0 --yb_factor_audio_decay 0.2 --yb_reverse_norm 1 \
19 | --yb_audio_length 10 \
20 | --pretrained_clip_name ViT-B/32 --wandb 1 --fixed_length 8 --model_name ECLIPSE_act_f32_B32 \
21 | --audio_pt VGGSound_Audio_features_10s_aligned
22 |
23 |
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Yan-Bo Lin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/dataloaders/video_container.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | import av
5 |
6 |
7 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"):
8 | """
9 | Given the path to the video, return the pyav video container.
10 | Args:
11 | path_to_vid (str): path to the video.
12 | multi_thread_decode (bool): if True, perform multi-thread decoding.
13 | backend (str): decoder backend, options include `pyav` and
14 | `torchvision`, default is `pyav`.
15 | Returns:
16 | container (container): video container.
17 | """
18 | if backend == "torchvision":
19 | with open(path_to_vid, "rb") as fp:
20 | container = fp.read()
21 | return container
22 | elif backend == "pyav":
23 | container = av.open(path_to_vid)
24 | if multi_thread_decode:
25 | # Enable multiple threads for decoding.
26 | container.streams.video[0].thread_type = "AUTO"
27 | return container
28 | else:
29 | raise NotImplementedError("Unknown backend {}".format(backend))
--------------------------------------------------------------------------------
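For reference, a minimal usage sketch of `get_video_container` with the default PyAV backend (the video path is a placeholder, and the import assumes the repo root is on the Python path):

```python
from dataloaders.video_container import get_video_container

# Open a (placeholder) video with multi-threaded PyAV decoding and pull its frames.
container = get_video_container("example.mp4", multi_thread_decode=True, backend="pyav")
frames = [frame.to_image() for frame in container.decode(video=0)]  # list of PIL images
print(f"decoded {len(frames)} frames")
container.close()
```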
/preproc.sh:
--------------------------------------------------------------------------------
1 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/training --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_train
2 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/validation --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_val
3 |
4 |
5 |
6 | python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_train --output_root /playpen-iop/yblin/yk2/audio_raw_train
7 | python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_val --output_root /playpen-iop/yblin/yk2/audio_raw_val
8 |
9 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_train --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_train
10 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_scale_val --output_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_val
11 |
12 |
13 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_train --output_root /playpen-iop/yblin/yk2/raw_videos_all/videos_frame_train
14 | # python3 preprocess/compress_video.py --input_root /playpen-iop/yblin/yk2/raw_videos_all/low_all_val --output_root /playpen-iop/yblin/yk2/raw_videos_all/videos_frame_val
--------------------------------------------------------------------------------
/feature_ext.sh:
--------------------------------------------------------------------------------
1 | DATA_PATH=/playpen-iop/yblin/qvhighlight
2 | DATA_PATH=/playpen-storage/yblin/v1-2
3 | export CUDA_VISIBLE_DEVICES=0
4 |
5 | # bash mykill.sh
6 |
7 | python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=7455 \
8 | yb_feature_ext.py --do_train --num_thread_reader=32 \
9 | --epochs=1 --batch_size=1 --batch_size_val 1 --n_display=50 \
10 | --data_path ${DATA_PATH} \
11 | --features_path ${DATA_PATH}/Activity_Videos \
12 | --output_dir ckpts/ckpt_activity_retrieval_looseType \
13 | --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 64 --max_audio_frames 32 \
14 | --audio_cluster 16 --gradient_accumulation_steps 1 \
15 | --datatype activity --feature_framerate 1 \
16 | --freeze_layer_num 0 --slice_framepos 2 \
17 | --loose_type --linear_patch 2d --sim_header meanP --yb_audio_length 10 \
18 | --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name DistKL_avblock_ACLS+baseline_8 \
19 | --yb_av 1 --yb_dual 1 --yb_start_layer 6 \
20 | --audio_pt VGGSound_Audio_features_10s_aligned
21 |
22 |
23 |
24 | # DATA_PATH=/playpen-iop/yblin/v1-2
25 | # python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=7445 \
26 | # yb_feature_ext.py --do_train --num_thread_reader=0 \
27 | # --epochs=1 --batch_size=1 --n_display=50 \
28 | # --data_path ${DATA_PATH} \
29 | # --features_path ${DATA_PATH}/Activity_Videos \
30 | # --output_dir ckpts/ckpt_activity_retrieval_looseType \
31 | # --lr 1e-4 --coef_lr 1e-3 --max_words 64 --max_frames 2 --max_audio_frames 2 \
32 | # --audio_cluster 16 --batch_size_val 1 --gradient_accumulation_steps 1 \
33 | # --datatype youcook --feature_framerate 1 \
34 | # --freeze_layer_num 0 --slice_framepos 2 \
35 | # --loose_type --linear_patch 2d --sim_header meanP --yb_audio_length 10 \
36 | # --pretrained_clip_name ViT-B/32 --wandb 0 --fixed_length 8 --model_name DistKL_avblock_ACLS+baseline_8 \
37 | # --audio_pt VGGSound_Audio_features_10s_aligned
--------------------------------------------------------------------------------
/dataloaders/vggish_params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Global parameters for the VGGish model.
17 |
18 | See vggish_slim.py for more information.
19 | """
20 |
21 | # Architectural constants.
22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
24 | EMBEDDING_SIZE = 128 # Size of embedding layer.
25 |
26 | # Hyperparameters used in feature and example generation.
27 | SAMPLE_RATE = 16000
28 | STFT_WINDOW_LENGTH_SECONDS = 0.025
29 | STFT_HOP_LENGTH_SECONDS = 0.010
30 | NUM_MEL_BINS = NUM_BANDS
31 | MEL_MIN_HZ = 125
32 | MEL_MAX_HZ = 7500
33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
36 |
37 | # Parameters used for embedding postprocessing.
38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
39 | PCA_MEANS_NAME = 'pca_means'
40 | QUANTIZE_MIN_VAL = -2.0
41 | QUANTIZE_MAX_VAL = +2.0
42 |
43 | # Hyperparameters used in training.
44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
47 |
48 | # Names of ops, tensors, and features.
49 | INPUT_OP_NAME = 'vggish/input_features'
50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
51 | OUTPUT_OP_NAME = 'vggish/embedding'
52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
54 |
--------------------------------------------------------------------------------
/torchvggish/vggish_params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Global parameters for the VGGish model.
17 |
18 | See vggish_slim.py for more information.
19 | """
20 |
21 | # Architectural constants.
22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
24 | EMBEDDING_SIZE = 128 # Size of embedding layer.
25 |
26 | # Hyperparameters used in feature and example generation.
27 | SAMPLE_RATE = 16000
28 | STFT_WINDOW_LENGTH_SECONDS = 0.025
29 | STFT_HOP_LENGTH_SECONDS = 0.010
30 | NUM_MEL_BINS = NUM_BANDS
31 | MEL_MIN_HZ = 125
32 | MEL_MAX_HZ = 7500
33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
36 |
37 | # Parameters used for embedding postprocessing.
38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
39 | PCA_MEANS_NAME = 'pca_means'
40 | QUANTIZE_MIN_VAL = -2.0
41 | QUANTIZE_MAX_VAL = +2.0
42 |
43 | # Hyperparameters used in training.
44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
47 |
48 | # Names of ops, tensors, and features.
49 | INPUT_OP_NAME = 'vggish/input_features'
50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
51 | OUTPUT_OP_NAME = 'vggish/embedding'
52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
54 |
--------------------------------------------------------------------------------
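The two identical `vggish_params.py` copies above fix the VGGish front-end geometry. As a quick sanity check (a sketch assuming only the constants above, run from the repo root), the 0.96 s example window at a 10 ms STFT hop yields exactly `NUM_FRAMES = 96` mel frames per example:

```python
from torchvggish import vggish_params as p

# Derived quantities implied by the constants above.
window_samples = int(round(p.SAMPLE_RATE * p.STFT_WINDOW_LENGTH_SECONDS))  # 400 samples (25 ms)
hop_samples = int(round(p.SAMPLE_RATE * p.STFT_HOP_LENGTH_SECONDS))        # 160 samples (10 ms)
frames_per_example = int(round(p.EXAMPLE_WINDOW_SECONDS / p.STFT_HOP_LENGTH_SECONDS))

assert frames_per_example == p.NUM_FRAMES == 96
print(window_samples, hop_samples, frames_per_example)
```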
/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import threading
4 | from torch._utils import ExceptionWrapper
5 | import logging
6 |
7 | def get_a_var(obj):
8 | if isinstance(obj, torch.Tensor):
9 | return obj
10 |
11 | if isinstance(obj, list) or isinstance(obj, tuple):
12 | for result in map(get_a_var, obj):
13 | if isinstance(result, torch.Tensor):
14 | return result
15 | if isinstance(obj, dict):
16 | for result in map(get_a_var, obj.items()):
17 | if isinstance(result, torch.Tensor):
18 | return result
19 | return None
20 |
21 | def parallel_apply(fct, model, inputs, device_ids):
22 | modules = nn.parallel.replicate(model, device_ids)
23 | assert len(modules) == len(inputs)
24 | lock = threading.Lock()
25 | results = {}
26 | grad_enabled = torch.is_grad_enabled()
27 |
28 | def _worker(i, module, input):
29 | torch.set_grad_enabled(grad_enabled)
30 | device = get_a_var(input).get_device()
31 | try:
32 | with torch.cuda.device(device):
33 | # this also avoids accidental slicing of `input` if it is a Tensor
34 | if not isinstance(input, (list, tuple)):
35 | input = (input,)
36 | output = fct(module, *input)
37 | with lock:
38 | results[i] = output
39 | except Exception:
40 | with lock:
41 | results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))
42 |
43 | if len(modules) > 1:
44 | threads = [threading.Thread(target=_worker, args=(i, module, input))
45 | for i, (module, input) in enumerate(zip(modules, inputs))]
46 |
47 | for thread in threads:
48 | thread.start()
49 | for thread in threads:
50 | thread.join()
51 | else:
52 | _worker(0, modules[0], inputs[0])
53 |
54 | outputs = []
55 | for i in range(len(inputs)):
56 | output = results[i]
57 | if isinstance(output, ExceptionWrapper):
58 | output.reraise()
59 | outputs.append(output)
60 | return outputs
61 |
62 | def get_logger(filename=None):
63 | logger = logging.getLogger('logger')
64 | logger.setLevel(logging.DEBUG)
65 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
66 | datefmt='%m/%d/%Y %H:%M:%S',
67 | level=logging.INFO)
68 | if filename is not None:
69 | handler = logging.FileHandler(filename)
70 | handler.setLevel(logging.DEBUG)
71 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
72 | logging.getLogger().addHandler(handler)
73 | return logger
--------------------------------------------------------------------------------
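A short usage sketch of the `get_logger` helper above (the log file name is a placeholder):

```python
from util import get_logger

# Console logging is configured by default; passing a filename also logs to that file.
logger = get_logger(filename="train.log")
logger.info("starting retrieval training")
```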
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # ECLIPSE: Efficient Long-range Video Retrieval using Sight and Sound
3 |
4 | [](https://opensource.org/licenses/MIT)
5 |
6 |
7 | This is the PyTorch implementation of our paper:
8 | **ECLIPSE: Efficient Long-range Video Retrieval using Sight and Sound**
9 | [Yan-Bo Lin](https://genjib.github.io/), [Jie Lei](https://jayleicn.github.io/), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/), and [Gedas Bertasius](https://www.gedasbertasius.com/)
10 | In European Conference on Computer Vision, 2022.
11 |
12 | [paper](https://arxiv.org/abs/2204.02874)
13 |
14 | ### 📝 Preparation
15 | 1. `pip3 install -r requirements.txt`
16 | 2. Datasets: ActivityNet, QVHighlights, YouCook2, DiDeMo, and Charades.
17 | 3. Extract video frames at 3 fps.
18 | 4. Extract audio features.
19 | 5. Load the pretrained CLIP weights.
20 |
21 | The download links below are from the official [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip) repo.
22 | Download the CLIP (ViT-B/32) weights:
23 | ```sh
24 | wget -P ./modules https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
25 | ```
26 | or download the CLIP (ViT-B/16) weights:
27 | ```sh
28 | wget -P ./modules https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
29 | ```
30 |
31 |
32 | ### 💿 Expected layout of extracted frames and audio features
33 | ```shell
34 | ActivityNet/
35 | ├── raw_frames/
36 | │ └── VIDEO_NAME/
37 | │ ├── 0001.jpg
38 | │ ├── ...
39 | │ └── 00...jpg
40 | │
41 | └── VGGSound_Audio_features_10s_aligned/
42 | └── VIDEO_NAME/
43 | ├── 0000.pt
44 | ├── ...
45 | └── 00...pt
46 |
47 | ```
48 |
49 |
50 |
51 | ### 💿 Extracted audio features.
52 | VGGSound features on ActivityNet Captions: [Google Drive](https://drive.google.com/file/d/1PbZPrgO5HTuG_CORcS_zScQCUeFo1JOL/view?usp=sharing)
53 |
54 | ### 📚 Train and evaluate
55 | ActivityNet Captions: `bash run_act-B32.sh` (ViT-B/32) or `bash run_act-B16.sh` (ViT-B/16) \
56 | DiDemo: `bash run_didemo.sh` \
57 | Charades: `bash run_cha.sh` \
58 | QVHighlight:`bash run_qvh.sh` \
59 | YouCook2: `bash run_yc2.sh`
60 |
61 |
62 |
63 |
64 | ### 🎓 Cite
65 |
66 | If you use this code in your research, please cite:
67 |
68 | ```bibtex
69 | @InProceedings{ECLIPSE_ECCV22,
70 | author = {Yan-Bo Lin and Jie Lei and Mohit Bansal and Gedas Bertasius},
71 | title = {ECLIPSE: Efficient Long-range Video Retrieval using Sight and Sound},
72 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
73 | month = {October},
74 | year = {2022}
75 | }
76 | ```
77 |
78 | ### 👍 Acknowledgments
79 | Our code is based on [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip) and [VGGSound](https://www.robots.ox.ac.uk/~vgg/data/vggsound/).
80 |
81 | ### ✏ Future works
82 | * Preprocessed video frames and audio features
83 |
84 |
85 | ## License
86 |
87 | This project is licensed under the [MIT License](LICENSE), as found in the LICENSE file.
88 |
--------------------------------------------------------------------------------
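Following the directory layout shown in the README, loading the pre-extracted per-clip audio features for one video might look like the sketch below (the root path and video name are placeholders):

```python
import glob
import torch

# Placeholder paths following the README's ActivityNet layout.
feat_dir = "ActivityNet/VGGSound_Audio_features_10s_aligned/VIDEO_NAME"
audio_feats = [torch.load(p, map_location="cpu") for p in sorted(glob.glob(f"{feat_dir}/*.pt"))]
print(f"{len(audio_feats)} audio feature chunks for this video")
```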
/metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | import torch
8 |
9 | def compute_metrics(x):
10 | sx = np.sort(-x, axis=1)
11 | d = np.diag(-x)
12 | d = d[:, np.newaxis]
13 | ind = sx - d
14 | ind = np.where(ind == 0)
15 | ind = ind[1]
16 | metrics = {}
17 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind)
18 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind)
19 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind)
20 | metrics['R50'] = float(np.sum(ind < 50)) * 100 / len(ind)
21 | metrics['MR'] = np.median(ind) + 1
22 | metrics["MedianR"] = metrics['MR']
23 | metrics["MeanR"] = np.mean(ind) + 1
24 | metrics["cols"] = [int(i) for i in list(ind)]
25 | return metrics
26 |
27 | def print_computed_metrics(metrics):
28 | r1 = metrics['R1']
29 | r5 = metrics['R5']
30 | r10 = metrics['R10']
31 | mr = metrics['MR']
32 | print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr))
33 |
34 | # below two functions directly come from: https://github.com/Deferf/Experiments
35 | def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10]):
36 | if not torch.is_tensor(sim_tensor):
37 | sim_tensor = torch.tensor(sim_tensor)
38 |
39 | # Permute sim_tensor so it represents a sequence of text-video similarity matrices.
40 | # Then obtain the double argsort to position the rank on the diagonal
41 | stacked_sim_matrices = sim_tensor.permute(1, 0, 2)
42 | first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True)
43 | second_argsort = torch.argsort(first_argsort, dim = -1, descending= False)
44 |
45 | # Extracts ranks i.e diagonals
46 | ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2))
47 |
48 | # Now we need to extract valid ranks, as some belong to inf padding values
49 | permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2))
50 | mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data))
51 | valid_ranks = ranks[mask]
52 | # A quick dimension check validates our results, there may be other correctness tests pending
53 | # Such as dot product localization, but that is for other time.
54 | #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict])
55 | if not torch.is_tensor(valid_ranks):
56 | valid_ranks = torch.tensor(valid_ranks)
57 | results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k}
58 | results["MedianR"] = float(torch.median(valid_ranks + 1))
59 | results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1))
60 | results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1))
61 | results['MR'] = results["MedianR"]
62 | return results
63 |
64 | def tensor_video_to_text_sim(sim_tensor):
65 | if not torch.is_tensor(sim_tensor):
66 | sim_tensor = torch.tensor(sim_tensor)
67 | # Code to avoid nans
68 | sim_tensor[sim_tensor != sim_tensor] = float('-inf')
69 | # Forms a similarity matrix for use with rank at k
70 | values, _ = torch.max(sim_tensor, dim=1, keepdim=True)
71 | return torch.squeeze(values).T
72 |
--------------------------------------------------------------------------------
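A toy example of `compute_metrics` on a synthetic similarity matrix (ground-truth pairs sit on the diagonal, so boosting the diagonal should give R@1 = 100):

```python
import numpy as np
from metrics import compute_metrics, print_computed_metrics

# 8 text queries x 8 videos; the matching video for query i is video i.
rng = np.random.default_rng(0)
sim = rng.normal(size=(8, 8))
np.fill_diagonal(sim, 5.0)  # make each ground-truth pair the top match

print_computed_metrics(compute_metrics(sim))  # expect R@1 = 100.0 and Median R = 1.0
```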
/torchvggish/vggish_input.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute input examples for VGGish from audio waveform."""
17 |
18 | # Modification: Return torch tensors rather than numpy arrays
19 | import torch
20 |
21 | import numpy as np
22 | import resampy
23 |
24 | from . import mel_features
25 | from . import vggish_params
26 |
27 | import soundfile as sf
28 |
29 | from ipdb import set_trace
30 |
31 | def waveform_to_examples(data, sample_rate, return_tensor=True):
32 | """Converts audio waveform into an array of examples for VGGish.
33 |
34 | Args:
35 | data: np.array of either one dimension (mono) or two dimensions
36 | (multi-channel, with the outer dimension representing channels).
37 | Each sample is generally expected to lie in the range [-1.0, +1.0],
38 | although this is not required.
39 | sample_rate: Sample rate of data.
40 | return_tensor: Return data as a Pytorch tensor ready for VGGish
41 |
42 | Returns:
43 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents
44 | a sequence of examples, each of which contains a patch of log mel
45 | spectrogram, covering num_frames frames of audio and num_bands mel frequency
46 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
47 |
48 | """
49 | # Convert to mono.
50 | if len(data.shape) > 1:
51 | data = np.mean(data, axis=1)
52 | # Resample to the rate assumed by VGGish.
53 | if sample_rate != vggish_params.SAMPLE_RATE:
54 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
55 |
56 |     # set_trace()  # debug breakpoint (ipdb); keep disabled for non-interactive runs
57 | # Compute log mel spectrogram features.
58 | log_mel = mel_features.log_mel_spectrogram(
59 | data,
60 | audio_sample_rate=vggish_params.SAMPLE_RATE,
61 | log_offset=vggish_params.LOG_OFFSET,
62 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
63 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
64 | num_mel_bins=vggish_params.NUM_MEL_BINS,
65 | lower_edge_hertz=vggish_params.MEL_MIN_HZ,
66 | upper_edge_hertz=vggish_params.MEL_MAX_HZ)
67 |
68 | # Frame features into examples.
69 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
70 | example_window_length = int(round(
71 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
72 | example_hop_length = int(round(
73 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
74 | log_mel_examples = mel_features.frame(
75 | log_mel,
76 | window_length=example_window_length,
77 | hop_length=example_hop_length)
78 |
79 | if return_tensor:
80 | log_mel_examples = torch.tensor(
81 | log_mel_examples, requires_grad=True)[:, None, :, :].float()
82 |
83 |     # set_trace()  # debug breakpoint (ipdb); keep disabled for non-interactive runs
84 | return log_mel_examples
85 |
86 |
87 | def wavfile_to_examples(wav_file, return_tensor=True):
88 | """Convenience wrapper around waveform_to_examples() for a common WAV format.
89 |
90 | Args:
91 | wav_file: String path to a file, or a file-like object. The file
92 | is assumed to contain WAV audio data with signed 16-bit PCM samples.
93 |       return_tensor: Return data as a Pytorch tensor ready for VGGish
94 |
95 | Returns:
96 | See waveform_to_examples.
97 | """
98 | wav_data, sr = sf.read(wav_file, dtype='int16')
99 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
100 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
101 |
102 | # sr = 16000
103 | return waveform_to_examples(samples, sr, return_tensor)
104 |
--------------------------------------------------------------------------------
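A minimal usage sketch of `wavfile_to_examples` above (the WAV path is a placeholder, and the import assumes the repo root is on the Python path):

```python
from torchvggish.vggish_input import wavfile_to_examples

# Each example is one 0.96 s log-mel patch: [num_examples, 1, NUM_FRAMES=96, NUM_BANDS=64].
examples = wavfile_to_examples("example.wav")
print(examples.shape)
```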
/dataloaders/rawvideo_util.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | from PIL import Image
4 | # pytorch=1.7.1
5 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
6 | # pip install opencv-python
7 | import cv2
8 | from ipdb import set_trace
9 |
10 | class RawVideoExtractorCV2():
11 | def __init__(self, centercrop=False, size=224, framerate=-1, ):
12 | self.centercrop = centercrop
13 | self.size = size
14 | # self.framerate = framerate
15 | self.framerate = 3
16 | self.transform = self._transform(self.size)
17 |
18 | def _transform(self, n_px):
19 | return Compose([
20 | Resize(n_px, interpolation=Image.BICUBIC),
21 | CenterCrop(n_px),
22 | lambda image: image.convert("RGB"),
23 | ToTensor(),
24 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
25 | ])
26 |
27 | def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None):
28 | if start_time is not None or end_time is not None:
29 | assert isinstance(start_time, int) and isinstance(end_time, int) \
30 | and start_time > -1 and end_time > start_time
31 | assert sample_fp > -1
32 |
33 | # Samples a frame sample_fp X frames.
34 | cap = cv2.VideoCapture(video_file)
35 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
36 | fps = int(cap.get(cv2.CAP_PROP_FPS))
37 |
38 | # try:
39 | total_duration = (frameCount + fps - 1) // fps
40 | # print('right: ',fps)
41 | # except:
42 | # print('fpssssssssssssssssssssssss:', fps)
43 | # set_trace()
44 | start_sec, end_sec = 0, total_duration
45 |
46 | if start_time is not None:
47 | start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration
48 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps))
49 |
50 | interval = 1
51 | if sample_fp > 0:
52 | interval = fps // sample_fp
53 | else:
54 | sample_fp = fps
55 | if interval == 0: interval = 1
56 |
57 | inds = [ind for ind in np.arange(0, fps, interval)]
58 | assert len(inds) >= sample_fp
59 | inds = inds[:sample_fp]
60 |
61 | ret = True
62 | images, included = [], []
63 |
64 | for sec in np.arange(start_sec, end_sec + 1):
65 | if not ret: break
66 | sec_base = int(sec * fps)
67 | for ind in inds:
68 | cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind)
69 | ret, frame = cap.read()
70 | if not ret: break
71 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
72 | images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB")))
73 |
74 | cap.release()
75 |
76 | if len(images) > 0:
77 | video_data = th.tensor(np.stack(images))
78 | else:
79 | video_data = th.zeros(1)
80 |
81 | return {'video': video_data}
82 |
83 | def get_video_data(self, video_path, start_time=None, end_time=None):
84 | image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time)
85 | return image_input
86 |
87 | def process_raw_data(self, raw_video_data):
88 | tensor_size = raw_video_data.size()
89 | tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1])
90 | return tensor
91 |
92 | def process_frame_order(self, raw_video_data, frame_order=0):
93 | # 0: ordinary order; 1: reverse order; 2: random order.
94 | if frame_order == 0:
95 | pass
96 | elif frame_order == 1:
97 | reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1)
98 | raw_video_data = raw_video_data[reverse_order, ...]
99 | elif frame_order == 2:
100 | random_order = np.arange(raw_video_data.size(0))
101 | np.random.shuffle(random_order)
102 | raw_video_data = raw_video_data[random_order, ...]
103 |
104 | return raw_video_data
105 |
106 | # An ordinary video frame extractor based on CV2
107 | RawVideoExtractor = RawVideoExtractorCV2
--------------------------------------------------------------------------------
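Illustrative usage of the extractor above (the video path is a placeholder): frames are sampled at the hard-coded 3 fps, resized and center-cropped to 224, and normalized with the CLIP statistics.

```python
from dataloaders.rawvideo_util import RawVideoExtractor

extractor = RawVideoExtractor(centercrop=True, size=224)  # framerate is fixed to 3 fps internally
video = extractor.get_video_data("example.mp4")["video"]  # shape: [num_sampled_frames, 3, 224, 224]
print(video.shape)
```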
/modules/until_config.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import copy
24 | import json
25 | import logging
26 | import tarfile
27 | import tempfile
28 | import shutil
29 | import torch
30 | from .file_utils import cached_path
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 | class PretrainedConfig(object):
35 |
36 | pretrained_model_archive_map = {}
37 | config_name = ""
38 | weights_name = ""
39 |
40 | @classmethod
41 | def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
42 | archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
43 | if os.path.exists(archive_file) is False:
44 | if pretrained_model_name in cls.pretrained_model_archive_map:
45 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name]
46 | else:
47 | archive_file = pretrained_model_name
48 |
49 | # redirect to the cache, if necessary
50 | try:
51 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
52 | except FileNotFoundError:
53 | if task_config is None or task_config.local_rank == 0:
54 | logger.error(
55 | "Model name '{}' was not found in model name list. "
56 | "We assumed '{}' was a path or url but couldn't find any file "
57 | "associated to this path or url.".format(
58 | pretrained_model_name,
59 | archive_file))
60 | return None
61 | if resolved_archive_file == archive_file:
62 | if task_config is None or task_config.local_rank == 0:
63 | logger.info("loading archive file {}".format(archive_file))
64 | else:
65 | if task_config is None or task_config.local_rank == 0:
66 | logger.info("loading archive file {} from cache at {}".format(
67 | archive_file, resolved_archive_file))
68 | tempdir = None
69 | if os.path.isdir(resolved_archive_file):
70 | serialization_dir = resolved_archive_file
71 | else:
72 | # Extract archive to temp dir
73 | tempdir = tempfile.mkdtemp()
74 | if task_config is None or task_config.local_rank == 0:
75 | logger.info("extracting archive file {} to temp dir {}".format(
76 | resolved_archive_file, tempdir))
77 | with tarfile.open(resolved_archive_file, 'r:gz') as archive:
78 | archive.extractall(tempdir)
79 | serialization_dir = tempdir
80 | # Load config
81 | config_file = os.path.join(serialization_dir, cls.config_name)
82 | config = cls.from_json_file(config_file)
83 | config.type_vocab_size = type_vocab_size
84 | if task_config is None or task_config.local_rank == 0:
85 | logger.info("Model config {}".format(config))
86 |
87 | if state_dict is None:
88 | weights_path = os.path.join(serialization_dir, cls.weights_name)
89 | if os.path.exists(weights_path):
90 | state_dict = torch.load(weights_path, map_location='cpu')
91 | else:
92 | if task_config is None or task_config.local_rank == 0:
93 | logger.info("Weight doesn't exsits. {}".format(weights_path))
94 |
95 | if tempdir:
96 | # Clean up temp dir
97 | shutil.rmtree(tempdir)
98 |
99 | return config, state_dict
100 |
101 | @classmethod
102 | def from_dict(cls, json_object):
103 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
104 | config = cls(vocab_size_or_config_json_file=-1)
105 | for key, value in json_object.items():
106 | config.__dict__[key] = value
107 | return config
108 |
109 | @classmethod
110 | def from_json_file(cls, json_file):
111 | """Constructs a `BertConfig` from a json file of parameters."""
112 | with open(json_file, "r", encoding='utf-8') as reader:
113 | text = reader.read()
114 | return cls.from_dict(json.loads(text))
115 |
116 | def __repr__(self):
117 | return str(self.to_json_string())
118 |
119 | def to_dict(self):
120 | """Serializes this instance to a Python dictionary."""
121 | output = copy.deepcopy(self.__dict__)
122 | return output
123 |
124 | def to_json_string(self):
125 | """Serializes this instance to a JSON string."""
126 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
--------------------------------------------------------------------------------
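A short sketch of the config round-trip above, using the CrossConfig subclass defined in modules/module_cross.py; the field values are arbitrary:

import os
import tempfile
from modules.module_cross import CrossConfig

# Build a config from a plain dict, dump it to JSON, and re-load it with from_json_file.
cfg = CrossConfig.from_dict({
    "hidden_size": 512,
    "num_hidden_layers": 4,
    "num_attention_heads": 8,
    "max_position_embeddings": 128,
})
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cross_config.json")
    with open(path, "w", encoding="utf-8") as f:
        f.write(cfg.to_json_string())
    reloaded = CrossConfig.from_json_file(path)
    print(reloaded)  # __repr__ prints the sorted JSON string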
/modules/tokenization_clip.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 | from ipdb import set_trace
10 |
11 | @lru_cache()
12 | def default_bpe():
13 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
14 |
15 |
16 | @lru_cache()
17 | def bytes_to_unicode():
18 | """
19 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
20 | The reversible bpe codes work on unicode strings.
21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
23 | This is a significant percentage of your normal, say, 32K bpe vocab.
24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
25 | This also avoids mapping to whitespace/control characters that the bpe code barfs on.
26 | """
27 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
28 | cs = bs[:]
29 | n = 0
30 | for b in range(2**8):
31 | if b not in bs:
32 | bs.append(b)
33 | cs.append(2**8+n)
34 | n += 1
35 | cs = [chr(n) for n in cs]
36 | return dict(zip(bs, cs))
37 |
38 |
39 | def get_pairs(word):
40 | """Return set of symbol pairs in a word.
41 | Word is represented as tuple of symbols (symbols being variable-length strings).
42 | """
43 | pairs = set()
44 | prev_char = word[0]
45 | for char in word[1:]:
46 | pairs.add((prev_char, char))
47 | prev_char = char
48 | return pairs
49 |
50 |
51 | def basic_clean(text):
52 | text = ftfy.fix_text(text)
53 | text = html.unescape(html.unescape(text))
54 | return text.strip()
55 |
56 |
57 | def whitespace_clean(text):
58 | text = re.sub(r'\s+', ' ', text)
59 | text = text.strip()
60 | return text
61 |
62 |
63 | class SimpleTokenizer(object):
64 | def __init__(self, bpe_path: str = default_bpe()):
65 | self.byte_encoder = bytes_to_unicode()
66 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
67 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
68 | merges = merges[1:49152-256-2+1]
69 | merges = [tuple(merge.split()) for merge in merges]
70 | vocab = list(bytes_to_unicode().values())
71 | vocab = vocab + [v+'</w>' for v in vocab]
72 | for merge in merges:
73 | vocab.append(''.join(merge))
74 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
75 | self.encoder = dict(zip(vocab, range(len(vocab))))
76 | self.decoder = {v: k for k, v in self.encoder.items()}
77 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
78 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
79 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
80 |
81 | self.vocab = self.encoder
82 |
83 | def bpe(self, token):
84 | if token in self.cache:
85 | return self.cache[token]
86 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
87 | pairs = get_pairs(word)
88 |
89 | if not pairs:
90 | return token+'</w>'
91 |
92 | while True:
93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
94 | if bigram not in self.bpe_ranks:
95 | break
96 | first, second = bigram
97 | new_word = []
98 | i = 0
99 | while i < len(word):
100 | try:
101 | j = word.index(first, i)
102 | new_word.extend(word[i:j])
103 | i = j
104 | except:
105 | new_word.extend(word[i:])
106 | break
107 |
108 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
109 | new_word.append(first+second)
110 | i += 2
111 | else:
112 | new_word.append(word[i])
113 | i += 1
114 | new_word = tuple(new_word)
115 | word = new_word
116 | if len(word) == 1:
117 | break
118 | else:
119 | pairs = get_pairs(word)
120 | word = ' '.join(word)
121 | self.cache[token] = word
122 | return word
123 |
124 | def encode(self, text):
125 | bpe_tokens = []
126 | text = whitespace_clean(basic_clean(text)).lower()
127 | for token in re.findall(self.pat, text):
128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
130 | return bpe_tokens
131 |
132 | def decode(self, tokens):
133 | text = ''.join([self.decoder[token] for token in tokens])
134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
135 | return text
136 |
137 | def tokenize(self, text):
138 | tokens = []
139 | text = whitespace_clean(basic_clean(text)).lower()
140 | for token in re.findall(self.pat, text):
141 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
142 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
143 | return tokens
144 |
145 | def convert_tokens_to_ids(self, tokens):
146 | return [self.encoder[bpe_token] for bpe_token in tokens]
--------------------------------------------------------------------------------
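A quick sketch of the tokenizer above; it relies on the bundled bpe_simple_vocab_16e6.txt.gz plus the ftfy and regex packages, and the example sentence is arbitrary:

from modules.tokenization_clip import SimpleTokenizer

tokenizer = SimpleTokenizer()  # default_bpe() points at the bundled vocab

text = "a man is cooking pasta in the kitchen"
tokens = tokenizer.tokenize(text)              # BPE sub-word strings, e.g. 'man</w>'
ids = tokenizer.convert_tokens_to_ids(tokens)  # integer ids from the BPE vocab
assert tokenizer.encode(text) == ids           # encode() is the one-step equivalent
print(tokenizer.decode(ids))                   # round-trips back to the cleaned, lower-cased text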
/preprocess/compress_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Used to compress video in: https://github.com/ArrowLuo/CLIP4Clip
3 | Author: ArrowLuo
4 | """
5 | import os
6 | import argparse
7 | import ffmpeg
8 | import subprocess
9 | import time
10 | import multiprocessing
11 | from multiprocessing import Pool
12 | import shutil
13 | try:
14 | from psutil import cpu_count
15 | except:
16 | from multiprocessing import cpu_count
17 | # multiprocessing.freeze_support()
18 | from ipdb import set_trace
19 | def compress(paras):
20 | input_video_path, output_video_path = paras
21 | output_video_path_ori = output_video_path #.split('.')[0]+'.mp4'
22 |
23 |
24 | output_video_path = os.path.splitext(output_video_path)[0] #+'.mp4'
25 |
26 |
27 | ## for audio/image extraction
28 | output_img_path = output_video_path + '/%04d.jpg'
29 | output_audio_path = output_video_path + '/%04d.wav'
30 | # output_audio2_path = output_video_path.split('.')[0]+'.wav'
31 | output_audio2_path = output_video_path + '.wav'
32 | try:
33 | # command = ['ffmpeg',
34 | # '-y', # (optional) overwrite output file if it exists
35 | # '-i', input_video_path,
36 | # '-c:v',
37 | # 'libx264',
38 | # '-c:a',
39 | # 'libmp3lame',
40 | # '-b:a',
41 | # '128K',
42 | # '-max_muxing_queue_size', '9999',
43 | # '-vf',
44 | # 'fps=3 ', # scale to 224 "scale=\'if(gt(a,1),trunc(oh*a/2)*2,224)\':\'if(gt(a,1),224,trunc(ow*a/2)*2)\'"
45 | # # '-max_muxing_queue_size', '9999',
46 | # # "scale=224:224",
47 | # # '-c:a', 'copy',
48 | # # 'fps=fps=30', # frames per second
49 | # output_video_path_ori,
50 | # ]
51 |
52 | ### ori compressed ---------->
53 | # command = ['ffmpeg',
54 | # '-y', # (optional) overwrite output file if it exists
55 | # '-i', input_video_path,
56 | # '-filter:v',
57 | # 'scale=\'if(gt(a,1),trunc(oh*a/2)*2,224)\':\'if(gt(a,1),224,trunc(ow*a/2)*2)\'', # scale to 224
58 | # '-map', '0:v',
59 | # '-r', '3', # frames per second
60 | # output_video_path_ori,
61 | # ]
62 | ########### <----------------
63 |
64 | ############# for extract images ##############
65 | # command = ['ffmpeg',
66 | # '-y', # (optional) overwrite output file if it exists
67 | # '-i', input_video_path,
68 | # '-vf',
69 | # 'fps=3',
70 | # output_img_path,
71 | # ]
72 | ########## end extract images ###################################
73 |
74 |
75 | ######### for extract audio ----->
76 | # ffmpeg -i /playpen-iop/yblin/v1-2/train/v_XazKuBawFCM.mp4 -map 0:a -f segment -segment_time 10 -acodec pcm_s16le -ac 1 -ar 16000 /playpen-iop/yblin/v1-2/train_audio/output_%03d.wav
77 | # ffmpeg -y -i /playpen-iop/yblin/yk2/raw_videos_all/low_all_val/EpNUSTO2BI4.mp4 -map 0:a -f segment -segment_time 10000000 -acodec pcm_s16le -ac 1 -ar 16000 /playpen-iop/yblin/yk2/audio_raw_val/EpNUSTO2BI4.wav
78 | command = ['ffmpeg',
79 | '-y', # (optional) overwrite output file if it exists
80 | '-i', input_video_path,
81 | '-acodec', 'pcm_s16le', '-ac', '1',
82 | '-ar', '16000', # resample
83 | output_audio2_path,
84 | ]
85 | ### <------
86 |
87 |
88 | ######### for extract audio 2 ###########
89 | # command = ['ffmpeg',
90 | # '-y', # (optional) overwrite output file if it exists
91 | # '-i', input_video_path,
92 | # '-map','0:a', '-f', 'segment',
93 | # '-segment_time', '10000000', # seconds here
94 | # '-acodec', 'pcm_s16le', '-ac', '1',
95 | # '-ar', '16000', # resample
96 | # output_audio_path,
97 | # ]
98 | #####################
99 |
100 | # command= [
101 | # 'mkdir',
102 | # output_video_path, # (optional) overwrite output file if it exists
103 | # ]
104 |
105 | print(command)
106 |
107 | # ffmpeg -y -i /playpen-iop/yblin/v1-2/val/v_6VT2jBflMAM.mp4 -vf fps=3 /playpen-iop/yblin/v1-2/val_low_scale/v_6VT2jBflMAM.mp4
108 | proc = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)  # renamed to avoid shadowing the imported ffmpeg module
109 | out, err = proc.communicate()
110 | retcode = proc.poll()
111 | # out, err and retcode can be printed here for debugging
112 | except Exception as e:
113 | raise e
114 |
115 | def prepare_input_output_pairs(input_root, output_root):
116 | input_video_path_list = []
117 | output_video_path_list = []
118 | for root, dirs, files in os.walk(input_root):
119 |
120 | for file_name in files:
121 | input_video_path = os.path.join(root, file_name)
122 |
123 | output_video_path = os.path.join(output_root, file_name)
124 | if os.path.exists(output_video_path):
125 | pass
126 | else:
127 | input_video_path_list.append(input_video_path)
128 | output_video_path_list.append(output_video_path)
129 |
130 | return input_video_path_list, output_video_path_list
131 | # ffmpeg -y -i /nas/longleaf/home/yanbo/dataset/msvd_data/MSVD_Videos/-4wsuPCjDBc_5_15.avi -vf "fps=3,scale=224:224" /nas/longleaf/home/yanbo/dataset/msvd_data/MSVD_Comp2/-4wsuPCjDBc_5_15.avi
132 |
133 |
134 | if __name__ == "__main__":
135 | parser = argparse.ArgumentParser(description='Compress video for speed-up')
136 | parser.add_argument('--input_root', type=str, help='input root')
137 | parser.add_argument('--output_root', type=str, help='output root')
138 | args = parser.parse_args()
139 |
140 | input_root = args.input_root
141 | output_root = args.output_root
142 |
143 | assert input_root != output_root
144 |
145 | if not os.path.exists(output_root):
146 | os.makedirs(output_root, exist_ok=True)
147 |
148 | input_video_path_list, output_video_path_list = prepare_input_output_pairs(input_root, output_root)
149 |
150 |
151 | print("Total video need to process: {}".format(len(input_video_path_list)))
152 | num_works = cpu_count()
153 | print("Begin with {}-core logical processor.".format(num_works))
154 |
155 |
156 | # pool = Pool(num_works)
157 | pool = Pool(32)
158 |
159 | data_dict_list = pool.map(compress,
160 | [(input_video_path, output_video_path) for
161 | input_video_path, output_video_path in
162 | zip(input_video_path_list, output_video_path_list)])
163 | pool.close()
164 | pool.join()
165 |
166 |
--------------------------------------------------------------------------------
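The active branch above only re-encodes the audio track as 16 kHz mono PCM. A hedged sketch of the same step for a single file (all paths below are placeholders), alongside the script's command-line entry point:

# Command-line entry point of the script:
#   python preprocess/compress_video.py --input_root /data/videos --output_root /data/audio_16k

import subprocess

def extract_wav(input_video_path, output_wav_path):
    """Extract the audio track as 16 kHz mono PCM, mirroring the command built in compress()."""
    command = [
        "ffmpeg", "-y",
        "-i", input_video_path,
        "-acodec", "pcm_s16le", "-ac", "1",
        "-ar", "16000",
        output_wav_path,
    ]
    subprocess.run(command, check=True, capture_output=True)

# extract_wav("/data/videos/v_example.mp4", "/data/audio_16k/v_example.wav")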
/torchvggish/vggish.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch import hub
5 |
6 | from . import vggish_input, vggish_params
7 |
8 |
9 | class VGG(nn.Module):
10 | def __init__(self, features):
11 | super(VGG, self).__init__()
12 | self.features = features
13 | self.embeddings = nn.Sequential(
14 | nn.Linear(512 * 4 * 6, 4096),
15 | nn.ReLU(True),
16 | nn.Linear(4096, 4096),
17 | nn.ReLU(True),
18 | nn.Linear(4096, 128),
19 | nn.ReLU(True))
20 |
21 | def forward(self, x):
22 | x = self.features(x)
23 |
24 | # Transpose the output from features to
25 | # remain compatible with vggish embeddings
26 | x = torch.transpose(x, 1, 3)
27 | x = torch.transpose(x, 1, 2)
28 | x = x.contiguous()
29 | x = x.view(x.size(0), -1)
30 |
31 | return self.embeddings(x)
32 |
33 |
34 | class Postprocessor(nn.Module):
35 | """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a
36 | numpy array in order to preserve the gradient.
37 |
38 | "The initial release of AudioSet included 128-D VGGish embeddings for each
39 | segment of AudioSet. These released embeddings were produced by applying
40 | a PCA transformation (technically, a whitening transform is included as well)
41 | and 8-bit quantization to the raw embedding output from VGGish, in order to
42 | stay compatible with the YouTube-8M project which provides visual embeddings
43 | in the same format for a large set of YouTube videos. This class implements
44 | the same PCA (with whitening) and quantization transformations."
45 | """
46 |
47 | def __init__(self):
48 | """Constructs a postprocessor."""
49 | super(Postprocessor, self).__init__()
50 | # Create empty matrix, for user's state_dict to load
51 | self.pca_eigen_vectors = torch.empty(
52 | (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
53 | dtype=torch.float,
54 | )
55 | self.pca_means = torch.empty(
56 | (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
57 | )
58 |
59 | self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False)
60 | self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
61 |
62 | def postprocess(self, embeddings_batch):
63 | """Applies tensor postprocessing to a batch of embeddings.
64 |
65 | Args:
66 | embeddings_batch: A tensor of shape [batch_size, embedding_size]
67 | containing output from the embedding layer of VGGish.
68 |
69 | Returns:
70 | A tensor of the same shape as the input, containing the PCA-transformed,
71 | quantized, and clipped version of the input.
72 | """
73 | assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % (
74 | embeddings_batch.shape,
75 | )
76 | assert (
77 | embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE
78 | ), "Bad batch shape: %r" % (embeddings_batch.shape,)
79 |
80 | # Apply PCA.
81 | # - Embeddings come in as [batch_size, embedding_size].
82 | # - Transpose to [embedding_size, batch_size].
83 | # - Subtract pca_means column vector from each column.
84 | # - Premultiply by PCA matrix of shape [output_dims, input_dims]
85 | # where both are equal to embedding_size in our case.
86 | # - Transpose result back to [batch_size, embedding_size].
87 | pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)).t()
88 |
89 | # Quantize by:
90 | # - clipping to [min, max] range
91 | clipped_embeddings = torch.clamp(
92 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL
93 | )
94 | # - convert to 8-bit in range [0.0, 255.0]
95 | quantized_embeddings = torch.round(
96 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL)
97 | * (
98 | 255.0
99 | / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)
100 | )
101 | )
102 | return torch.squeeze(quantized_embeddings)
103 |
104 | def forward(self, x):
105 | return self.postprocess(x)
106 |
107 |
108 | def make_layers():
109 | layers = []
110 | in_channels = 1
111 | for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]:
112 | if v == "M":
113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
114 | else:
115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
116 | layers += [conv2d, nn.ReLU(inplace=True)]
117 | in_channels = v
118 | return nn.Sequential(*layers)
119 |
120 |
121 | def _vgg():
122 | return VGG(make_layers())
123 |
124 |
125 | # def _spectrogram():
126 | # config = dict(
127 | # sr=16000,
128 | # n_fft=400,
129 | # n_mels=64,
130 | # hop_length=160,
131 | # window="hann",
132 | # center=False,
133 | # pad_mode="reflect",
134 | # htk=True,
135 | # fmin=125,
136 | # fmax=7500,
137 | # output_format='Magnitude',
138 | # # device=device,
139 | # )
140 | # return Spectrogram.MelSpectrogram(**config)
141 |
142 |
143 | class VGGish(VGG):
144 | def __init__(self, urls, device=None, pretrained=True, preprocess=True, postprocess=True, progress=True):
145 | super().__init__(make_layers())
146 | if pretrained:
147 | state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress)
148 | super().load_state_dict(state_dict)
149 |
150 | if device is None:
151 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
152 | self.device = device
153 | self.preprocess = preprocess
154 | self.postprocess = postprocess
155 | if self.postprocess:
156 | self.pproc = Postprocessor()
157 | if pretrained:
158 | state_dict = hub.load_state_dict_from_url(urls['pca'], progress=progress)
159 | # TODO: Convert the state_dict to torch
160 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
161 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
162 | )
163 | state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
164 | state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
165 | )
166 |
167 | self.pproc.load_state_dict(state_dict)
168 | self.to(self.device)
169 |
170 | def forward(self, x, fs=None):
171 | # if self.preprocess:
172 | # x = self._preprocess(x, fs)
173 | x = x.to(self.device)
174 | x = VGG.forward(self, x)
175 | # if self.postprocess:
176 | # x = self._postprocess(x)
177 | return x
178 |
179 | def _preprocess(self, x, fs):
180 | if isinstance(x, np.ndarray):
181 | x = vggish_input.waveform_to_examples(x, fs)
182 | elif isinstance(x, str):
183 | x = vggish_input.wavfile_to_examples(x)
184 | else:
185 | raise AttributeError
186 | return x
187 |
188 | def _postprocess(self, x):
189 | return self.pproc(x)
190 |
--------------------------------------------------------------------------------
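A minimal sketch of pushing dummy log-mel examples through the VGGish backbone above; it uses random weights (pretrained=False), so the usual torchvggish checkpoint URL dict is not needed:

import torch
from torchvggish.vggish import VGGish

model = VGGish(urls=None, pretrained=False, preprocess=False, postprocess=False)
model.eval()

# A batch of 3 log-mel "examples": 1 channel, 96 frames, 64 mel bins.
dummy = torch.randn(3, 1, 96, 64)
with torch.no_grad():
    emb = model(dummy)
print(emb.shape)  # torch.Size([3, 128])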
/modules/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def warmup_cosine(x, warmup=0.002):
27 | if x < warmup:
28 | return x/warmup
29 | return 0.5 * (1.0 + math.cos(math.pi * x))
30 |
31 | def warmup_constant(x, warmup=0.002):
32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
33 | Learning rate is 1. afterwards. """
34 | if x < warmup:
35 | return x/warmup
36 | return 1.0
37 |
38 | def warmup_linear(x, warmup=0.002):
39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
40 | After `t_total`-th training step, learning rate is zero. """
41 | if x < warmup:
42 | return x/warmup
43 | return max((x-1.)/(warmup-1.), 0)
44 |
45 | SCHEDULES = {
46 | 'warmup_cosine': warmup_cosine,
47 | 'warmup_constant': warmup_constant,
48 | 'warmup_linear': warmup_linear,
49 | }
50 |
51 |
52 | class BertAdam(Optimizer):
53 | """Implements BERT version of Adam algorithm with weight decay fix.
54 | Params:
55 | lr: learning rate
56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
57 | t_total: total number of training steps for the learning
58 | rate schedule, -1 means constant learning rate. Default: -1
59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
60 | b1: Adam's b1. Default: 0.9
61 | b2: Adam's b2. Default: 0.999
62 | e: Adam's epsilon. Default: 1e-6
63 | weight_decay: Weight decay. Default: 0.01
64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
65 | """
66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
68 | max_grad_norm=1.0):
69 | if lr is not required and lr < 0.0:
70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
71 | if schedule not in SCHEDULES:
72 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
73 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
75 | if not 0.0 <= b1 < 1.0:
76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
77 | if not 0.0 <= b2 < 1.0:
78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
79 | if not e >= 0.0:
80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
83 | max_grad_norm=max_grad_norm)
84 | super(BertAdam, self).__init__(params, defaults)
85 |
86 | def get_lr(self):
87 | lr = []
88 | for group in self.param_groups:
89 | for p in group['params']:
90 | if p.grad is None:
91 | continue
92 | state = self.state[p]
93 | if len(state) == 0:
94 | return [0]
95 | if group['t_total'] != -1:
96 | schedule_fct = SCHEDULES[group['schedule']]
97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
98 | else:
99 | lr_scheduled = group['lr']
100 | lr.append(lr_scheduled)
101 | return lr
102 |
103 | def step(self, closure=None):
104 | """Performs a single optimization step.
105 | Arguments:
106 | closure (callable, optional): A closure that reevaluates the model
107 | and returns the loss.
108 | """
109 | loss = None
110 | if closure is not None:
111 | loss = closure()
112 |
113 | for group in self.param_groups:
114 | for p in group['params']:
115 | if p.grad is None:
116 | continue
117 | grad = p.grad.data
118 | if grad.is_sparse:
119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
120 |
121 | state = self.state[p]
122 |
123 | # State initialization
124 | if len(state) == 0:
125 | state['step'] = 0
126 | # Exponential moving average of gradient values
127 | state['next_m'] = torch.zeros_like(p.data)
128 | # Exponential moving average of squared gradient values
129 | state['next_v'] = torch.zeros_like(p.data)
130 |
131 | next_m, next_v = state['next_m'], state['next_v']
132 | beta1, beta2 = group['b1'], group['b2']
133 |
134 | # Add grad clipping
135 | if group['max_grad_norm'] > 0:
136 | clip_grad_norm_(p, group['max_grad_norm'])
137 |
138 | # Decay the first and second moment running average coefficient
139 | # In-place operations to update the averages at the same time
140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7
141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7
143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
144 | update = next_m / (next_v.sqrt() + group['e'])
145 |
146 | # Just adding the square of the weights to the loss function is *not*
147 | # the correct way of using L2 regularization/weight decay with Adam,
148 | # since that will interact with the m and v parameters in strange ways.
149 | #
150 | # Instead we want to decay the weights in a manner that doesn't interact
151 | # with the m/v parameters. This is equivalent to adding the square
152 | # of the weights to the loss with plain (non-momentum) SGD.
153 | if group['weight_decay'] > 0.0:
154 | update += group['weight_decay'] * p.data
155 |
156 | if group['t_total'] != -1:
157 | schedule_fct = SCHEDULES[group['schedule']]
158 | progress = state['step']/group['t_total']
159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
160 | else:
161 | lr_scheduled = group['lr']
162 |
163 | update_with_lr = lr_scheduled * update
164 | p.data.add_(-update_with_lr)
165 |
166 | state['step'] += 1
167 |
168 | return loss
--------------------------------------------------------------------------------
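A small sketch of BertAdam with the warmup_linear schedule on a toy model; the learning rate, warmup fraction, and step counts are illustrative rather than the values used by the training scripts in this repo:

import torch
from modules.optimization import BertAdam

model = torch.nn.Linear(16, 4)
optimizer = BertAdam(model.parameters(), lr=1e-4, warmup=0.1, t_total=1000,
                     schedule='warmup_linear', weight_decay=0.01, max_grad_norm=1.0)

for step in range(5):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    # The lr ramps up linearly over the first warmup * t_total = 100 steps.
    print(step, optimizer.get_lr()[0])
    optimizer.zero_grad()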
/modules/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 |
7 | import os
8 | import logging
9 | import shutil
10 | import tempfile
11 | import json
12 | from urllib.parse import urlparse
13 | from pathlib import Path
14 | from typing import Optional, Tuple, Union, IO, Callable, Set
15 | from hashlib import sha256
16 | from functools import wraps
17 |
18 | from tqdm import tqdm
19 |
20 | import boto3
21 | from botocore.exceptions import ClientError
22 | import requests
23 |
24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
25 |
26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
27 | Path.home() / '.pytorch_pretrained_bert'))
28 |
29 |
30 | def url_to_filename(url: str, etag: str = None) -> str:
31 | """
32 | Convert `url` into a hashed filename in a repeatable way.
33 | If `etag` is specified, append its hash to the url's, delimited
34 | by a period.
35 | """
36 | url_bytes = url.encode('utf-8')
37 | url_hash = sha256(url_bytes)
38 | filename = url_hash.hexdigest()
39 |
40 | if etag:
41 | etag_bytes = etag.encode('utf-8')
42 | etag_hash = sha256(etag_bytes)
43 | filename += '.' + etag_hash.hexdigest()
44 |
45 | return filename
46 |
47 |
48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
49 | """
50 | Return the url and etag (which may be ``None``) stored for `filename`.
51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
52 | """
53 | if cache_dir is None:
54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
55 | if isinstance(cache_dir, Path):
56 | cache_dir = str(cache_dir)
57 |
58 | cache_path = os.path.join(cache_dir, filename)
59 | if not os.path.exists(cache_path):
60 | raise FileNotFoundError("file {} not found".format(cache_path))
61 |
62 | meta_path = cache_path + '.json'
63 | if not os.path.exists(meta_path):
64 | raise FileNotFoundError("file {} not found".format(meta_path))
65 |
66 | with open(meta_path) as meta_file:
67 | metadata = json.load(meta_file)
68 | url = metadata['url']
69 | etag = metadata['etag']
70 |
71 | return url, etag
72 |
73 |
74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
75 | """
76 | Given something that might be a URL (or might be a local path),
77 | determine which. If it's a URL, download the file and cache it, and
78 | return the path to the cached file. If it's already a local path,
79 | make sure the file exists and then return the path.
80 | """
81 | if cache_dir is None:
82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
83 | if isinstance(url_or_filename, Path):
84 | url_or_filename = str(url_or_filename)
85 | if isinstance(cache_dir, Path):
86 | cache_dir = str(cache_dir)
87 |
88 | parsed = urlparse(url_or_filename)
89 |
90 | if parsed.scheme in ('http', 'https', 's3'):
91 | # URL, so get it from the cache (downloading if necessary)
92 | return get_from_cache(url_or_filename, cache_dir)
93 | elif os.path.exists(url_or_filename):
94 | # File, and it exists.
95 | return url_or_filename
96 | elif parsed.scheme == '':
97 | # File, but it doesn't exist.
98 | raise FileNotFoundError("file {} not found".format(url_or_filename))
99 | else:
100 | # Something unknown
101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
102 |
103 |
104 | def split_s3_path(url: str) -> Tuple[str, str]:
105 | """Split a full s3 path into the bucket name and path."""
106 | parsed = urlparse(url)
107 | if not parsed.netloc or not parsed.path:
108 | raise ValueError("bad s3 path {}".format(url))
109 | bucket_name = parsed.netloc
110 | s3_path = parsed.path
111 | # Remove '/' at beginning of path.
112 | if s3_path.startswith("/"):
113 | s3_path = s3_path[1:]
114 | return bucket_name, s3_path
115 |
116 |
117 | def s3_request(func: Callable):
118 | """
119 | Wrapper function for s3 requests in order to create more helpful error
120 | messages.
121 | """
122 |
123 | @wraps(func)
124 | def wrapper(url: str, *args, **kwargs):
125 | try:
126 | return func(url, *args, **kwargs)
127 | except ClientError as exc:
128 | if int(exc.response["Error"]["Code"]) == 404:
129 | raise FileNotFoundError("file {} not found".format(url))
130 | else:
131 | raise
132 |
133 | return wrapper
134 |
135 |
136 | @s3_request
137 | def s3_etag(url: str) -> Optional[str]:
138 | """Check ETag on S3 object."""
139 | s3_resource = boto3.resource("s3")
140 | bucket_name, s3_path = split_s3_path(url)
141 | s3_object = s3_resource.Object(bucket_name, s3_path)
142 | return s3_object.e_tag
143 |
144 |
145 | @s3_request
146 | def s3_get(url: str, temp_file: IO) -> None:
147 | """Pull a file directly from S3."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
151 |
152 |
153 | def http_get(url: str, temp_file: IO) -> None:
154 | req = requests.get(url, stream=True)
155 | content_length = req.headers.get('Content-Length')
156 | total = int(content_length) if content_length is not None else None
157 | progress = tqdm(unit="B", total=total)
158 | for chunk in req.iter_content(chunk_size=1024):
159 | if chunk: # filter out keep-alive new chunks
160 | progress.update(len(chunk))
161 | temp_file.write(chunk)
162 | progress.close()
163 |
164 |
165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
166 | """
167 | Given a URL, look for the corresponding dataset in the local cache.
168 | If it's not there, download it. Then return the path to the cached file.
169 | """
170 | if cache_dir is None:
171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
172 | if isinstance(cache_dir, Path):
173 | cache_dir = str(cache_dir)
174 |
175 | os.makedirs(cache_dir, exist_ok=True)
176 |
177 | # Get eTag to add to filename, if it exists.
178 | if url.startswith("s3://"):
179 | etag = s3_etag(url)
180 | else:
181 | response = requests.head(url, allow_redirects=True)
182 | if response.status_code != 200:
183 | raise IOError("HEAD request failed for url {} with status code {}"
184 | .format(url, response.status_code))
185 | etag = response.headers.get("ETag")
186 |
187 | filename = url_to_filename(url, etag)
188 |
189 | # get cache path to put the file
190 | cache_path = os.path.join(cache_dir, filename)
191 |
192 | if not os.path.exists(cache_path):
193 | # Download to temporary file, then copy to cache dir once finished.
194 | # Otherwise you get corrupt cache entries if the download gets interrupted.
195 | with tempfile.NamedTemporaryFile() as temp_file:
196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
197 |
198 | # GET file object
199 | if url.startswith("s3://"):
200 | s3_get(url, temp_file)
201 | else:
202 | http_get(url, temp_file)
203 |
204 | # we are copying the file before closing it, so flush to avoid truncation
205 | temp_file.flush()
206 | # shutil.copyfileobj() starts at the current position, so go to the start
207 | temp_file.seek(0)
208 |
209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
210 | with open(cache_path, 'wb') as cache_file:
211 | shutil.copyfileobj(temp_file, cache_file)
212 |
213 | logger.info("creating metadata file for %s", cache_path)
214 | meta = {'url': url, 'etag': etag}
215 | meta_path = cache_path + '.json'
216 | with open(meta_path, 'w') as meta_file:
217 | json.dump(meta, meta_file)
218 |
219 | logger.info("removing temp file %s", temp_file.name)
220 |
221 | return cache_path
222 |
223 |
224 | def read_set_from_file(filename: str) -> Set[str]:
225 | '''
226 | Extract a de-duped collection (set) of text from a file.
227 | Expected file format is one item per line.
228 | '''
229 | collection = set()
230 | with open(filename, 'r', encoding='utf-8') as file_:
231 | for line in file_:
232 | collection.add(line.rstrip())
233 | return collection
234 |
235 |
236 | def get_file_extension(path: str, dot=True, lower: bool = True):
237 | ext = os.path.splitext(path)[1]
238 | ext = ext if dot else ext[1:]
239 | return ext.lower() if lower else ext
240 |
--------------------------------------------------------------------------------
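A brief sketch of cached_path and url_to_filename; it assumes the working directory is the repository root so the bundled cross_config.json resolves as a local file, and the URL/ETag below are placeholders:

from modules.file_utils import cached_path, url_to_filename

# For an existing local file, cached_path is a pass-through.
print(cached_path("modules/cross-base/cross_config.json"))

# For a URL it would download into ~/.pytorch_pretrained_bert (or $PYTORCH_PRETRAINED_BERT_CACHE)
# under a sha256-derived name; url_to_filename shows that naming scheme without downloading.
print(url_to_filename("https://example.com/weights.bin", etag="abc123"))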
/dataloaders/dataloader_lsmdc_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import os
7 | from torch.utils.data import Dataset
8 | import numpy as np
9 | import json
10 | import math
11 | from dataloaders.rawvideo_util import RawVideoExtractor
12 |
13 | class LSMDC_DataLoader(Dataset):
14 | """LSMDC dataset loader."""
15 | def __init__(
16 | self,
17 | subset,
18 | data_path,
19 | features_path,
20 | tokenizer,
21 | max_words=30,
22 | feature_framerate=1.0,
23 | max_frames=100,
24 | image_resolution=224,
25 | frame_order=0,
26 | slice_framepos=0,
27 | ):
28 | self.data_path = data_path
29 | self.features_path = features_path
30 | self.feature_framerate = feature_framerate
31 | self.max_words = max_words
32 | self.max_frames = max_frames
33 | self.tokenizer = tokenizer
34 | # 0: ordinary order; 1: reverse order; 2: random order.
35 | self.frame_order = frame_order
36 | assert self.frame_order in [0, 1, 2]
37 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
38 | self.slice_framepos = slice_framepos
39 | assert self.slice_framepos in [0, 1, 2]
40 |
41 | self.subset = subset
42 | assert self.subset in ["train", "val", "test"]
43 |
44 | video_json_path_dict = {}
45 | video_json_path_dict["train"] = os.path.join(self.data_path, "LSMDC16_annos_training.csv")
46 | video_json_path_dict["val"] = os.path.join(self.data_path, "LSMDC16_annos_val.csv")
47 | video_json_path_dict["test"] = os.path.join(self.data_path, "LSMDC16_challenge_1000_publictect.csv")
48 |
49 | # <CLIP_ID>\t<START_ALIGNED>\t<END_ALIGNED>\t<START_EXTRACTED>\t<END_EXTRACTED>\t<SENTENCE>
50 | # <CLIP_ID> is not a unique identifier, i.e. the same <CLIP_ID> can be associated with multiple sentences.
51 | # However, LSMDC16_challenge_1000_publictect.csv has no repeat instances
52 | video_id_list = []
53 | caption_dict = {}
54 | with open(video_json_path_dict[self.subset], 'r') as fp:
55 | for line in fp:
56 | line = line.strip()
57 | line_split = line.split("\t")
58 | assert len(line_split) == 6
59 | clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split
60 | caption_dict[len(caption_dict)] = (clip_id, sentence)
61 | if clip_id not in video_id_list: video_id_list.append(clip_id)
62 |
63 | video_dict = {}
64 | for root, dub_dir, video_files in os.walk(self.features_path):
65 | for video_file in video_files:
66 | video_id_ = ".".join(video_file.split(".")[:-1])
67 | if video_id_ not in video_id_list:
68 | continue
69 | file_path_ = os.path.join(root, video_file)
70 | video_dict[video_id_] = file_path_
71 |
72 | self.video_dict = video_dict
73 |
74 | # Get all captions
75 | self.iter2video_pairs_dict = {}
76 | for clip_id, sentence in caption_dict.values():
77 | if clip_id not in self.video_dict:
78 | continue
79 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (clip_id, sentence)
80 |
81 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution)
82 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
83 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
84 |
85 | def __len__(self):
86 | return len(self.iter2video_pairs_dict)
87 |
88 | def _get_video_id_from_pseduo(self, pseudo_video_id):
89 | video_id = pseudo_video_id[2:]
90 | return video_id
91 |
92 | def _get_video_id_single(self, path):
93 | pseudo_video_id_list = []
94 | video_id_list = []
95 | print('Loading json: {}'.format(path))
96 | with open(path, 'r') as f:
97 | json_data = json.load(f)
98 |
99 | for pseudo_video_id in json_data:
100 | if pseudo_video_id in pseudo_video_id_list:
101 | print("reduplicate.")
102 | else:
103 | video_id = self._get_video_id_from_pseduo(pseudo_video_id)
104 | pseudo_video_id_list.append(pseudo_video_id)
105 | video_id_list.append(video_id)
106 | return pseudo_video_id_list, video_id_list
107 |
108 | def _get_captions_single(self, path):
109 | pseudo_caption_dict = {}
110 | with open(path, 'r') as f:
111 | json_data = json.load(f)
112 |
113 | for pseudo_video_id, v_ in json_data.items():
114 | pseudo_caption_dict[pseudo_video_id] = {}
115 | timestamps = v_["timestamps"]
116 | pseudo_caption_dict[pseudo_video_id]["start"] = \
117 | np.array([int(math.floor(float(itm[0]))) for itm in timestamps], dtype=object)
118 | pseudo_caption_dict[pseudo_video_id]["end"] = \
119 | np.array([int(math.ceil(float(itm[1]))) for itm in timestamps], dtype=object)
120 | pseudo_caption_dict[pseudo_video_id]["text"] = np.array(v_["sentences"], dtype=object)
121 | return pseudo_caption_dict
122 |
123 | def _get_text(self, video_id, caption):
124 | k = 1
125 | choice_video_ids = [video_id]
126 | pairs_text = np.zeros((k, self.max_words), dtype=np.int64)  # np.long was removed in newer NumPy
127 | pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
128 | pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
129 |
130 | for i, video_id in enumerate(choice_video_ids):
131 | words = self.tokenizer.tokenize(caption)
132 |
133 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
134 | total_length_with_CLS = self.max_words - 1
135 | if len(words) > total_length_with_CLS:
136 | words = words[:total_length_with_CLS]
137 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
138 |
139 | input_ids = self.tokenizer.convert_tokens_to_ids(words)
140 | input_mask = [1] * len(input_ids)
141 | segment_ids = [0] * len(input_ids)
142 | while len(input_ids) < self.max_words:
143 | input_ids.append(0)
144 | input_mask.append(0)
145 | segment_ids.append(0)
146 | assert len(input_ids) == self.max_words
147 | assert len(input_mask) == self.max_words
148 | assert len(segment_ids) == self.max_words
149 |
150 | pairs_text[i] = np.array(input_ids)
151 | pairs_mask[i] = np.array(input_mask)
152 | pairs_segment[i] = np.array(segment_ids)
153 |
154 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids
155 |
156 | def _get_rawvideo(self, choice_video_ids):
157 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
158 | max_video_length = [0] * len(choice_video_ids)
159 |
160 | # Pair x L x T x 3 x H x W
161 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
162 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
163 |
164 | try:
165 | for i, video_id in enumerate(choice_video_ids):
166 | video_path = self.video_dict[video_id]
167 |
168 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path)
169 | raw_video_data = raw_video_data['video']
170 |
171 | if len(raw_video_data.shape) > 3:
172 | raw_video_data_clip = raw_video_data
173 | # L x T x 3 x H x W
174 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip)
175 | if self.max_frames < raw_video_slice.shape[0]:
176 | if self.slice_framepos == 0:
177 | video_slice = raw_video_slice[:self.max_frames, ...]
178 | elif self.slice_framepos == 1:
179 | video_slice = raw_video_slice[-self.max_frames:, ...]
180 | else:
181 | sample_indx = np.linspace(0, raw_video_slice.shape[0]-1, num=self.max_frames, dtype=int)
182 | video_slice = raw_video_slice[sample_indx, ...]
183 | else:
184 | video_slice = raw_video_slice
185 |
186 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order)
187 |
188 | slice_len = video_slice.shape[0]
189 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len
190 | if slice_len < 1:
191 | pass
192 | else:
193 | video[i][:slice_len, ...] = video_slice
194 | else:
195 | print("video path: {} error. video id: {}".format(video_path, video_id))
196 | except Exception as excep:
197 | print("Video ids: {}".format(choice_video_ids))
198 | raise excep
199 |
200 | for i, v_length in enumerate(max_video_length):
201 | video_mask[i][:v_length] = [1] * v_length
202 | return video, video_mask
203 |
204 | def __getitem__(self, feature_idx):
205 | clip_id, sentence = self.iter2video_pairs_dict[feature_idx]
206 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(clip_id, sentence)
207 | video, video_mask = self._get_rawvideo(choice_video_ids)
208 | return pairs_text, pairs_mask, pairs_segment, video, video_mask
209 |
--------------------------------------------------------------------------------
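A hedged sketch of wiring the dataset above into a torch DataLoader; the annotation and video directories are placeholders for a local LSMDC download:

from torch.utils.data import DataLoader
from modules.tokenization_clip import SimpleTokenizer
from dataloaders.dataloader_lsmdc_retrieval import LSMDC_DataLoader

dataset = LSMDC_DataLoader(
    subset="val",
    data_path="/data/lsmdc/annotations",  # holds the LSMDC16_annos_*.csv files
    features_path="/data/lsmdc/videos",   # raw clips, any layout reachable by os.walk
    tokenizer=SimpleTokenizer(),
    max_words=32,
    feature_framerate=1.0,
    max_frames=12,
    slice_framepos=2,                     # 2 = sample frames uniformly when a clip is long
)
loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=4)
pairs_text, pairs_mask, pairs_segment, video, video_mask = next(iter(loader))
print(video.shape)  # (8, 1, 12, 1, 3, 224, 224): batch x pair x frames x 1 x C x H x W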
/modules/module_cross.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import copy
7 | import json
8 | import math
9 | import logging
10 | import tarfile
11 | import tempfile
12 | import shutil
13 |
14 | import torch
15 | from torch import nn
16 | import torch.nn.functional as F
17 | from .file_utils import cached_path
18 | from .until_config import PretrainedConfig
19 | from .until_module import PreTrainedModel, LayerNorm, ACT2FN
20 | from collections import OrderedDict
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | PRETRAINED_MODEL_ARCHIVE_MAP = {}
25 | CONFIG_NAME = 'cross_config.json'
26 | WEIGHTS_NAME = 'cross_pytorch_model.bin'
27 |
28 |
29 | class CrossConfig(PretrainedConfig):
30 | """Configuration class to store the configuration of a `CrossModel`.
31 | """
32 | pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
33 | config_name = CONFIG_NAME
34 | weights_name = WEIGHTS_NAME
35 | def __init__(self,
36 | vocab_size_or_config_json_file,
37 | hidden_size=768,
38 | num_hidden_layers=12,
39 | num_attention_heads=12,
40 | intermediate_size=3072,
41 | hidden_act="gelu",
42 | hidden_dropout_prob=0.1,
43 | attention_probs_dropout_prob=0.1,
44 | max_position_embeddings=512,
45 | type_vocab_size=2,
46 | initializer_range=0.02):
47 | """Constructs CrossConfig.
48 |
49 | Args:
50 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
51 | hidden_size: Size of the encoder layers and the pooler layer.
52 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
53 | num_attention_heads: Number of attention heads for each attention layer in
54 | the Transformer encoder.
55 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
56 | layer in the Transformer encoder.
57 | hidden_act: The non-linear activation function (function or string) in the
58 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
59 | hidden_dropout_prob: The dropout probability for all fully connected
60 | layers in the embeddings, encoder, and pooler.
61 | attention_probs_dropout_prob: The dropout ratio for the attention
62 | probabilities.
63 | max_position_embeddings: The maximum sequence length that this model might
64 | ever be used with. Typically set this to something large just in case
65 | (e.g., 512 or 1024 or 2048).
66 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
67 | `CrossModel`.
68 | initializer_range: The stdev of the truncated_normal_initializer for
69 | initializing all weight matrices.
70 | """
71 | if isinstance(vocab_size_or_config_json_file, str):
72 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
73 | json_config = json.loads(reader.read())
74 | for key, value in json_config.items():
75 | self.__dict__[key] = value
76 | elif isinstance(vocab_size_or_config_json_file, int):
77 | self.vocab_size = vocab_size_or_config_json_file
78 | self.hidden_size = hidden_size
79 | self.num_hidden_layers = num_hidden_layers
80 | self.num_attention_heads = num_attention_heads
81 | self.hidden_act = hidden_act
82 | self.intermediate_size = intermediate_size
83 | self.hidden_dropout_prob = hidden_dropout_prob
84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
85 | self.max_position_embeddings = max_position_embeddings
86 | self.type_vocab_size = type_vocab_size
87 | self.initializer_range = initializer_range
88 | else:
89 | raise ValueError("First argument must be either a vocabulary size (int) "
90 | "or the path to a pretrained model config file (str)")
91 |
92 | class QuickGELU(nn.Module):
93 | def forward(self, x: torch.Tensor):
94 | return x * torch.sigmoid(1.702 * x)
95 |
96 | class ResidualAttentionBlock(nn.Module):
97 | def __init__(self, d_model: int, n_head: int):
98 | super().__init__()
99 |
100 | self.attn = nn.MultiheadAttention(d_model, n_head)
101 | self.ln_1 = LayerNorm(d_model)
102 | self.mlp = nn.Sequential(OrderedDict([
103 | ("c_fc", nn.Linear(d_model, d_model * 4)),
104 | ("gelu", QuickGELU()),
105 | ("c_proj", nn.Linear(d_model * 4, d_model))
106 | ]))
107 | self.ln_2 = LayerNorm(d_model)
108 | self.n_head = n_head
109 |
110 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor):
111 | attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0)
112 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
113 |
114 | def forward(self, para_tuple: tuple):
115 | # x: torch.Tensor, attn_mask: torch.Tensor
116 | # print(para_tuple)
117 | x, attn_mask = para_tuple
118 | x = x + self.attention(self.ln_1(x), attn_mask)
119 | x = x + self.mlp(self.ln_2(x))
120 | return (x, attn_mask)
121 |
122 | class Transformer(nn.Module):
123 | def __init__(self, width: int, layers: int, heads: int):
124 | super().__init__()
125 | self.width = width
126 | self.layers = layers
127 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)])
128 |
129 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
130 | return self.resblocks((x, attn_mask))[0]
131 |
132 | class CrossEmbeddings(nn.Module):
133 | """Construct the embeddings from word, position and token_type embeddings.
134 | """
135 | def __init__(self, config):
136 | super(CrossEmbeddings, self).__init__()
137 |
138 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
139 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
140 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
141 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
142 |
143 | def forward(self, concat_embeddings, concat_type=None):
144 |
145 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
146 | # if concat_type is None:
147 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device)
148 |
149 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device)
150 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1)
151 |
152 | # token_type_embeddings = self.token_type_embeddings(concat_type)
153 | position_embeddings = self.position_embeddings(position_ids)
154 |
155 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings
156 | # embeddings = self.LayerNorm(embeddings)
157 | embeddings = self.dropout(embeddings)
158 | return embeddings
159 |
160 | class CrossPooler(nn.Module):
161 | def __init__(self, config):
162 | super(CrossPooler, self).__init__()
163 | self.ln_pool = LayerNorm(config.hidden_size)
164 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
165 | self.activation = QuickGELU()
166 |
167 | def forward(self, hidden_states, hidden_mask):
168 | # We "pool" the model by simply taking the hidden state corresponding
169 | # to the first token.
170 | hidden_states = self.ln_pool(hidden_states)
171 | pooled_output = hidden_states[:, 0]
172 | pooled_output = self.dense(pooled_output)
173 | pooled_output = self.activation(pooled_output)
174 | return pooled_output
175 |
176 | class CrossModel(PreTrainedModel):
177 |
178 | def initialize_parameters(self):
179 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
180 | attn_std = self.transformer.width ** -0.5
181 | fc_std = (2 * self.transformer.width) ** -0.5
182 | for block in self.transformer.resblocks:
183 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
184 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
185 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
186 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
187 |
188 | def __init__(self, config):
189 | super(CrossModel, self).__init__(config)
190 |
191 | self.embeddings = CrossEmbeddings(config)
192 |
193 | transformer_width = config.hidden_size
194 | transformer_layers = config.num_hidden_layers
195 | transformer_heads = config.num_attention_heads
196 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads,)
197 | self.pooler = CrossPooler(config)
198 | self.apply(self.init_weights)
199 |
200 | def build_attention_mask(self, attention_mask):
201 | extended_attention_mask = attention_mask.unsqueeze(1)
202 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
203 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0
204 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1)
205 | return extended_attention_mask
206 |
207 | def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True):
208 |
209 | if attention_mask is None:
210 | attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))
211 | if concat_type is None:
212 | concat_type = torch.zeros_like(attention_mask)
213 |
214 | extended_attention_mask = self.build_attention_mask(attention_mask)
215 |
216 | embedding_output = self.embeddings(concat_input, concat_type)
217 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND
218 | embedding_output = self.transformer(embedding_output, extended_attention_mask)
219 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD
220 |
221 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask)
222 |
223 | return embedding_output, pooled_output
224 |
--------------------------------------------------------------------------------
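A minimal sketch of running CrossModel on random fused features; the config values here are small illustrative ones rather than those in cross-base/cross_config.json:

import torch
from modules.module_cross import CrossConfig, CrossModel

config = CrossConfig(vocab_size_or_config_json_file=512, hidden_size=512,
                     num_hidden_layers=2, num_attention_heads=8,
                     max_position_embeddings=128)
model = CrossModel(config)
model.eval()

# A batch of 4 fused sequences (e.g. text + video tokens), 20 positions of width 512.
concat_input = torch.randn(4, 20, 512)
attention_mask = torch.ones(4, 20)
with torch.no_grad():
    sequence_output, pooled_output = model(concat_input, attention_mask=attention_mask)
print(sequence_output.shape, pooled_output.shape)  # (4, 20, 512) (4, 512)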
/dataloaders/mel_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Defines routines to compute mel spectrogram features from audio waveform."""
17 |
18 | import numpy as np
19 |
20 |
21 | def frame(data, window_length, hop_length):
22 | """Convert array into a sequence of successive possibly overlapping frames.
23 |
24 | An n-dimensional array of shape (num_samples, ...) is converted into an
25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame
26 | starts hop_length points after the preceding one.
27 |
28 | This is accomplished using stride_tricks, so the original data is not
29 | copied. However, there is no zero-padding, so any incomplete frames at the
30 | end are not included.
31 |
32 | Args:
33 | data: np.array of dimension N >= 1.
34 | window_length: Number of samples in each frame.
35 | hop_length: Advance (in samples) between each window.
36 |
37 | Returns:
38 | (N+1)-D np.array with as many rows as there are complete frames that can be
39 | extracted.
40 | """
41 | num_samples = data.shape[0]
42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
43 | shape = (num_frames, window_length) + data.shape[1:]
44 | strides = (data.strides[0] * hop_length,) + data.strides
45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
46 |
47 |
48 | def periodic_hann(window_length):
49 | """Calculate a "periodic" Hann window.
50 |
51 | The classic Hann window is defined as a raised cosine that starts and
52 | ends on zero, and where every value appears twice, except the middle
53 | point for an odd-length window. Matlab calls this a "symmetric" window
54 | and np.hanning() returns it. However, for Fourier analysis, this
55 | actually represents just over one cycle of a period N-1 cosine, and
56 | thus is not compactly expressed on a length-N Fourier basis. Instead,
57 | it's better to use a raised cosine that ends just before the final
58 | zero value - i.e. a complete cycle of a period-N cosine. Matlab
59 | calls this a "periodic" window. This routine calculates it.
60 |
61 | Args:
62 | window_length: The number of points in the returned window.
63 |
64 | Returns:
65 | A 1D np.array containing the periodic hann window.
66 | """
67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
68 | np.arange(window_length)))
69 |
70 |
71 | def stft_magnitude(signal, fft_length,
72 | hop_length=None,
73 | window_length=None):
74 | """Calculate the short-time Fourier transform magnitude.
75 |
76 | Args:
77 | signal: 1D np.array of the input time-domain signal.
78 | fft_length: Size of the FFT to apply.
79 | hop_length: Advance (in samples) between each frame passed to FFT.
80 | window_length: Length of each block of samples to pass to FFT.
81 |
82 | Returns:
83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1
84 | unique values of the FFT for the corresponding frame of input samples.
85 | """
86 | frames = frame(signal, window_length, hop_length)
87 | # Apply frame window to each frame. We use a periodic Hann (cosine of period
88 | # window_length) instead of the symmetric Hann of np.hanning (period
89 | # window_length-1).
90 | window = periodic_hann(window_length)
91 | windowed_frames = frames * window
92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
93 |
94 |
95 | # Mel spectrum constants and functions.
96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0
97 | _MEL_HIGH_FREQUENCY_Q = 1127.0
98 |
99 |
100 | def hertz_to_mel(frequencies_hertz):
101 | """Convert frequencies to mel scale using HTK formula.
102 |
103 | Args:
104 | frequencies_hertz: Scalar or np.array of frequencies in hertz.
105 |
106 | Returns:
107 | Object of same size as frequencies_hertz containing corresponding values
108 | on the mel scale.
109 | """
110 | return _MEL_HIGH_FREQUENCY_Q * np.log(
111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
112 |
113 |
114 | def spectrogram_to_mel_matrix(num_mel_bins=20,
115 | num_spectrogram_bins=129,
116 | audio_sample_rate=8000,
117 | lower_edge_hertz=125.0,
118 | upper_edge_hertz=3800.0):
119 | """Return a matrix that can post-multiply spectrogram rows to make mel.
120 |
121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of
122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
123 | "mel spectrogram" M of frames x num_mel_bins. M = S A.
124 |
125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands
126 | to multiply each FFT bin by only one mel weight, then add it, with positive
127 | and negative signs, to the two adjacent mel bands to which that bin
128 | contributes. Here, by expressing this operation as a matrix multiply, we go
129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around
130 | num_fft^2 multiplies and adds. However, because these are all presumably
131 | accomplished in a single call to np.dot(), it's not clear which approach is
132 | faster in Python. The matrix multiplication has the attraction of being more
133 | general and flexible, and much easier to read.
134 |
135 | Args:
136 | num_mel_bins: How many bands in the resulting mel spectrum. This is
137 | the number of columns in the output matrix.
138 | num_spectrogram_bins: How many bins there are in the source spectrogram
139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
140 | only contains the nonredundant FFT bins.
141 | audio_sample_rate: Samples per second of the audio at the input to the
142 | spectrogram. We need this to figure out the actual frequencies for
143 | each spectrogram bin, which dictates how they are mapped into mel.
144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel
145 | spectrum. This corresponds to the lower edge of the lowest triangular
146 | band.
147 | upper_edge_hertz: The desired top edge of the highest frequency band.
148 |
149 | Returns:
150 | An np.array with shape (num_spectrogram_bins, num_mel_bins).
151 |
152 | Raises:
153 | ValueError: if frequency edges are incorrectly ordered or out of range.
154 | """
155 | nyquist_hertz = audio_sample_rate / 2.
156 | if lower_edge_hertz < 0.0:
157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
158 | if lower_edge_hertz >= upper_edge_hertz:
159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
160 | (lower_edge_hertz, upper_edge_hertz))
161 | if upper_edge_hertz > nyquist_hertz:
162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
163 | (upper_edge_hertz, nyquist_hertz))
164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
166 | # The i'th mel band (starting from i=1) has center frequency
167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
169 | # the band_edges_mel arrays.
170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
173 | # of spectrogram values.
174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
175 | for i in range(num_mel_bins):
176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
177 | # Calculate lower and upper slopes for every spectrogram bin.
178 | # Line segments are linear in the *mel* domain, not hertz.
179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
180 | (center_mel - lower_edge_mel))
181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
182 | (upper_edge_mel - center_mel))
183 | # .. then intersect them with each other and zero.
184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
185 | upper_slope))
186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero
187 | # coefficient.
188 | mel_weights_matrix[0, :] = 0.0
189 | return mel_weights_matrix
190 |
191 |
192 | def log_mel_spectrogram(data,
193 | audio_sample_rate=8000,
194 | log_offset=0.0,
195 | window_length_secs=0.025,
196 | hop_length_secs=0.010,
197 | **kwargs):
198 | """Convert waveform to a log magnitude mel-frequency spectrogram.
199 |
200 | Args:
201 | data: 1D np.array of waveform data.
202 | audio_sample_rate: The sampling rate of data.
203 | log_offset: Add this to values when taking log to avoid -Infs.
204 | window_length_secs: Duration of each window to analyze.
205 | hop_length_secs: Advance between successive analysis windows.
206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
207 |
208 | Returns:
209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
210 | magnitudes for successive frames.
211 | """
212 | window_length_samples = int(round(audio_sample_rate * window_length_secs))
213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
215 | spectrogram = stft_magnitude(
216 | data,
217 | fft_length=fft_length,
218 | hop_length=hop_length_samples,
219 | window_length=window_length_samples)
220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
221 | num_spectrogram_bins=spectrogram.shape[1],
222 | audio_sample_rate=audio_sample_rate, **kwargs))
223 | return np.log(mel_spectrogram + log_offset)
224 |
--------------------------------------------------------------------------------
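A quick illustration of the `frame()` and `periodic_hann()` helpers defined above (a minimal sketch, not part of the repository; it assumes the repository root is on PYTHONPATH so `dataloaders.mel_features` is importable):

import numpy as np
from dataloaders.mel_features import frame, periodic_hann

signal = np.arange(10, dtype=np.float32)

# frame() uses stride tricks: window_length=4 with hop_length=2 yields 4 overlapping frames;
# a frame starting at sample 8 would be incomplete, so it is dropped, as the docstring states.
frames = frame(signal, window_length=4, hop_length=2)
print(frames.shape)  # (4, 4)
print(frames[0])     # [0. 1. 2. 3.]
print(frames[1])     # [2. 3. 4. 5.]

# periodic_hann() is one full cycle of a period-N cosine, so it does not return to zero at
# the last sample, unlike the symmetric np.hanning window of the same length.
print(periodic_hann(8))  # last value is about 0.146, non-zero
print(np.hanning(8))     # first and last values are exactly 0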
/torchvggish/mel_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Defines routines to compute mel spectrogram features from audio waveform."""
17 |
18 | import numpy as np
19 |
20 |
21 | def frame(data, window_length, hop_length):
22 | """Convert array into a sequence of successive possibly overlapping frames.
23 |
24 | An n-dimensional array of shape (num_samples, ...) is converted into an
25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame
26 | starts hop_length points after the preceding one.
27 |
28 | This is accomplished using stride_tricks, so the original data is not
29 | copied. However, there is no zero-padding, so any incomplete frames at the
30 | end are not included.
31 |
32 | Args:
33 | data: np.array of dimension N >= 1.
34 | window_length: Number of samples in each frame.
35 | hop_length: Advance (in samples) between each window.
36 |
37 | Returns:
38 | (N+1)-D np.array with as many rows as there are complete frames that can be
39 | extracted.
40 | """
41 | num_samples = data.shape[0]
42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
43 | shape = (num_frames, window_length) + data.shape[1:]
44 | strides = (data.strides[0] * hop_length,) + data.strides
45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
46 |
47 |
48 | def periodic_hann(window_length):
49 | """Calculate a "periodic" Hann window.
50 |
51 | The classic Hann window is defined as a raised cosine that starts and
52 | ends on zero, and where every value appears twice, except the middle
53 | point for an odd-length window. Matlab calls this a "symmetric" window
54 | and np.hanning() returns it. However, for Fourier analysis, this
55 | actually represents just over one cycle of a period N-1 cosine, and
56 | thus is not compactly expressed on a length-N Fourier basis. Instead,
57 | it's better to use a raised cosine that ends just before the final
58 | zero value - i.e. a complete cycle of a period-N cosine. Matlab
59 | calls this a "periodic" window. This routine calculates it.
60 |
61 | Args:
62 | window_length: The number of points in the returned window.
63 |
64 | Returns:
65 | A 1D np.array containing the periodic hann window.
66 | """
67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
68 | np.arange(window_length)))
69 |
70 |
71 | def stft_magnitude(signal, fft_length,
72 | hop_length=None,
73 | window_length=None):
74 | """Calculate the short-time Fourier transform magnitude.
75 |
76 | Args:
77 | signal: 1D np.array of the input time-domain signal.
78 | fft_length: Size of the FFT to apply.
79 | hop_length: Advance (in samples) between each frame passed to FFT.
80 | window_length: Length of each block of samples to pass to FFT.
81 |
82 | Returns:
83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1
84 | unique values of the FFT for the corresponding frame of input samples.
85 | """
86 | frames = frame(signal, window_length, hop_length)
87 | # Apply frame window to each frame. We use a periodic Hann (cosine of period
88 | # window_length) instead of the symmetric Hann of np.hanning (period
89 | # window_length-1).
90 | window = periodic_hann(window_length)
91 | windowed_frames = frames * window
92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
93 |
94 |
95 | # Mel spectrum constants and functions.
96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0
97 | _MEL_HIGH_FREQUENCY_Q = 1127.0
98 |
99 |
100 | def hertz_to_mel(frequencies_hertz):
101 | """Convert frequencies to mel scale using HTK formula.
102 |
103 | Args:
104 | frequencies_hertz: Scalar or np.array of frequencies in hertz.
105 |
106 | Returns:
107 | Object of same size as frequencies_hertz containing corresponding values
108 | on the mel scale.
109 | """
110 | return _MEL_HIGH_FREQUENCY_Q * np.log(
111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
112 |
113 |
114 | def spectrogram_to_mel_matrix(num_mel_bins=20,
115 | num_spectrogram_bins=129,
116 | audio_sample_rate=8000,
117 | lower_edge_hertz=125.0,
118 | upper_edge_hertz=3800.0):
119 | """Return a matrix that can post-multiply spectrogram rows to make mel.
120 |
121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of
122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
123 | "mel spectrogram" M of frames x num_mel_bins. M = S A.
124 |
125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands
126 | to multiply each FFT bin by only one mel weight, then add it, with positive
127 | and negative signs, to the two adjacent mel bands to which that bin
128 | contributes. Here, by expressing this operation as a matrix multiply, we go
129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around
130 | num_fft^2 multiplies and adds. However, because these are all presumably
131 | accomplished in a single call to np.dot(), it's not clear which approach is
132 | faster in Python. The matrix multiplication has the attraction of being more
133 | general and flexible, and much easier to read.
134 |
135 | Args:
136 | num_mel_bins: How many bands in the resulting mel spectrum. This is
137 | the number of columns in the output matrix.
138 | num_spectrogram_bins: How many bins there are in the source spectrogram
139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
140 | only contains the nonredundant FFT bins.
141 | audio_sample_rate: Samples per second of the audio at the input to the
142 | spectrogram. We need this to figure out the actual frequencies for
143 | each spectrogram bin, which dictates how they are mapped into mel.
144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel
145 | spectrum. This corresponds to the lower edge of the lowest triangular
146 | band.
147 | upper_edge_hertz: The desired top edge of the highest frequency band.
148 |
149 | Returns:
150 | An np.array with shape (num_spectrogram_bins, num_mel_bins).
151 |
152 | Raises:
153 | ValueError: if frequency edges are incorrectly ordered or out of range.
154 | """
155 | nyquist_hertz = audio_sample_rate / 2.
156 | if lower_edge_hertz < 0.0:
157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
158 | if lower_edge_hertz >= upper_edge_hertz:
159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
160 | (lower_edge_hertz, upper_edge_hertz))
161 | if upper_edge_hertz > nyquist_hertz:
162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
163 | (upper_edge_hertz, nyquist_hertz))
164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
166 | # The i'th mel band (starting from i=1) has center frequency
167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
169 | # the band_edges_mel arrays.
170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
173 | # of spectrogram values.
174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
175 | for i in range(num_mel_bins):
176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
177 | # Calculate lower and upper slopes for every spectrogram bin.
178 | # Line segments are linear in the *mel* domain, not hertz.
179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
180 | (center_mel - lower_edge_mel))
181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
182 | (upper_edge_mel - center_mel))
183 | # .. then intersect them with each other and zero.
184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
185 | upper_slope))
186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero
187 | # coefficient.
188 | mel_weights_matrix[0, :] = 0.0
189 | return mel_weights_matrix
190 |
191 |
192 | def log_mel_spectrogram(data,
193 | audio_sample_rate=8000,
194 | log_offset=0.0,
195 | window_length_secs=0.025,
196 | hop_length_secs=0.010,
197 | **kwargs):
198 | """Convert waveform to a log magnitude mel-frequency spectrogram.
199 |
200 | Args:
201 | data: 1D np.array of waveform data.
202 | audio_sample_rate: The sampling rate of data.
203 | log_offset: Add this to values when taking log to avoid -Infs.
204 | window_length_secs: Duration of each window to analyze.
205 | hop_length_secs: Advance between successive analysis windows.
206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
207 |
208 | Returns:
209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
210 | magnitudes for successive frames.
211 | """
212 | window_length_samples = int(round(audio_sample_rate * window_length_secs))
213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
215 | spectrogram = stft_magnitude(
216 | data,
217 | fft_length=fft_length,
218 | hop_length=hop_length_samples,
219 | window_length=window_length_samples)
220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
221 | num_spectrogram_bins=spectrogram.shape[1],
222 | audio_sample_rate=audio_sample_rate, **kwargs))
223 | return np.log(mel_spectrogram + log_offset)
224 |
--------------------------------------------------------------------------------
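`log_mel_spectrogram` chains these helpers exactly as its docstring describes: frame the waveform, take windowed STFT magnitudes, post-multiply by the mel weight matrix (M = S A), then take the log. A minimal shape-checking sketch (illustrative only; the 64-bin, 125-7500 Hz settings below are example values, not read from this repository's configs):

import numpy as np
from torchvggish.mel_features import log_mel_spectrogram, spectrogram_to_mel_matrix

sr = 16000
t = np.arange(sr) / sr                      # one second of audio
wave = 0.5 * np.sin(2 * np.pi * 440.0 * t)  # a 440 Hz tone

# 25 ms windows with a 10 ms hop at 16 kHz -> 400-sample windows, fft_length = 512,
# hence 512 / 2 + 1 = 257 spectrogram bins and 98 complete frames in one second.
log_mel = log_mel_spectrogram(
    wave,
    audio_sample_rate=sr,
    log_offset=0.01,
    window_length_secs=0.025,
    hop_length_secs=0.010,
    num_mel_bins=64,
    lower_edge_hertz=125.0,
    upper_edge_hertz=7500.0,
)
print(log_mel.shape)  # (98, 64)

# The mel weighting itself is just a (num_spectrogram_bins, num_mel_bins) matrix.
mel_matrix = spectrogram_to_mel_matrix(
    num_mel_bins=64,
    num_spectrogram_bins=257,
    audio_sample_rate=sr,
    lower_edge_hertz=125.0,
    upper_edge_hertz=7500.0,
)
print(mel_matrix.shape)  # (257, 64)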
/dataloaders/dataloader_msvd_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import os
7 | from torch.utils.data import Dataset
8 | import numpy as np
9 | import pickle
10 | from dataloaders.rawvideo_util import RawVideoExtractor
11 | from . import video_container as container
12 | from . import decoder as decoder
13 | from . import utils as utils
14 | from ipdb import set_trace
15 | class MSVD_DataLoader(Dataset):
16 | """MSVD dataset loader."""
17 | def __init__(
18 | self,
19 | subset,
20 | data_path,
21 | features_path,
22 | tokenizer,
23 | max_words=30,
24 | feature_framerate=1.0,
25 | max_frames=100,
26 | image_resolution=224,
27 | frame_order=0,
28 | slice_framepos=0,
29 | ):
30 | self.data_path = data_path
31 | self.features_path = features_path
32 | self.feature_framerate = feature_framerate
33 | self.max_words = max_words
34 | self.max_frames = max_frames
35 | self.tokenizer = tokenizer
36 | # 0: ordinary order; 1: reverse order; 2: random order.
37 | self.frame_order = frame_order
38 | assert self.frame_order in [0, 1, 2]
39 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
40 | self.slice_framepos = slice_framepos
41 | assert self.slice_framepos in [0, 1, 2]
42 |
43 | self.subset = subset
44 | assert self.subset in ["train", "val", "test"]
45 | video_id_path_dict = {}
46 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt")
47 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt")
48 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt")
49 | caption_file = os.path.join(self.data_path, "raw-captions.pkl")
50 |
51 | with open(video_id_path_dict[self.subset], 'r') as fp:
52 | video_ids = [itm.strip() for itm in fp.readlines()]
53 |
54 | with open(caption_file, 'rb') as f:
55 | captions = pickle.load(f)
56 |
57 | video_dict = {}
58 | for root, dub_dir, video_files in os.walk(self.features_path):
59 | for video_file in video_files:
60 | video_id_ = ".".join(video_file.split(".")[:-1])
61 | if video_id_ not in video_ids:
62 | continue
63 | file_path_ = os.path.join(root, video_file)
64 | video_dict[video_id_] = file_path_
65 | self.video_dict = video_dict
66 |
67 | self.sample_len = 0
68 | self.sentences_dict = {}
69 | self.cut_off_points = []
70 | for video_id in video_ids:
71 | assert video_id in captions
72 | for cap in captions[video_id]:
73 | cap_txt = " ".join(cap)
74 | self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt)
75 | self.cut_off_points.append(len(self.sentences_dict))
76 |
77 | ## the variables below are used for multi-sentence retrieval
78 | # self.cut_off_points: used to tag the label when calculating the metric
79 | # self.sentence_num: used to cut the sentence representation
80 | # self.video_num: used to cut the video representation
81 | self.multi_sentence_per_video = True # !!! important tag for eval
82 | if self.subset == "val" or self.subset == "test":
83 | self.sentence_num = len(self.sentences_dict)
84 | self.video_num = len(video_ids)
85 | assert len(self.cut_off_points) == self.video_num
86 | print("For {}, sentence number: {}".format(self.subset, self.sentence_num))
87 | print("For {}, video number: {}".format(self.subset, self.video_num))
88 |
89 | print("Video number: {}".format(len(self.video_dict)))
90 | print("Total Paire: {}".format(len(self.sentences_dict)))
91 |
92 | self.sample_len = len(self.sentences_dict)
93 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution)
94 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
95 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
96 |
97 | def __len__(self):
98 | return self.sample_len
99 |
100 | def _get_text(self, video_id, caption):
101 | k = 1
102 | choice_video_ids = [video_id]
103 | pairs_text = np.zeros((k, self.max_words), dtype=np.long)
104 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
105 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
106 |
107 | for i, video_id in enumerate(choice_video_ids):
108 | words = self.tokenizer.tokenize(caption)
109 |
110 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
111 | total_length_with_CLS = self.max_words - 1
112 | if len(words) > total_length_with_CLS:
113 | words = words[:total_length_with_CLS]
114 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
115 |
116 | input_ids = self.tokenizer.convert_tokens_to_ids(words)
117 | input_mask = [1] * len(input_ids)
118 | segment_ids = [0] * len(input_ids)
119 | while len(input_ids) < self.max_words:
120 | input_ids.append(0)
121 | input_mask.append(0)
122 | segment_ids.append(0)
123 | assert len(input_ids) == self.max_words
124 | assert len(input_mask) == self.max_words
125 | assert len(segment_ids) == self.max_words
126 |
127 | pairs_text[i] = np.array(input_ids)
128 | pairs_mask[i] = np.array(input_mask)
129 | pairs_segment[i] = np.array(segment_ids)
130 |
131 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids
132 |
133 | def _get_rawvideo(self, choice_video_ids):
134 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long)
135 | max_video_length = [0] * len(choice_video_ids)
136 |
137 | # Pair x L x T x 3 x H x W
138 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
139 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float)
140 |
141 | for i, video_id in enumerate(choice_video_ids):
142 | video_path = self.video_dict[video_id]
143 |
144 | # set_trace()  # leftover debugging breakpoint; disabled so __getitem__ does not drop into ipdb
145 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path)
146 | raw_video_data = raw_video_data['video']
147 |
148 | if len(raw_video_data.shape) > 3:
149 | raw_video_data_clip = raw_video_data
150 | # L x T x 3 x H x W
151 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip)
152 | if self.max_frames < raw_video_slice.shape[0]:
153 | if self.slice_framepos == 0:
154 | video_slice = raw_video_slice[:self.max_frames, ...]
155 | elif self.slice_framepos == 1:
156 | video_slice = raw_video_slice[-self.max_frames:, ...]
157 | else:
158 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int)
159 | video_slice = raw_video_slice[sample_indx, ...]
160 | else:
161 | video_slice = raw_video_slice
162 |
163 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order)
164 |
165 | slice_len = video_slice.shape[0]
166 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len
167 | if slice_len < 1:
168 | pass
169 | else:
170 | video[i][:slice_len, ...] = video_slice
171 | else:
172 | print("video path: {} error. video id: {}".format(video_path, video_id))
173 |
174 | for i, v_length in enumerate(max_video_length):
175 | video_mask[i][:v_length] = [1] * v_length
176 |
177 | return video, video_mask
178 |
179 | def _get_rawvideo_yb(self, choice_video_ids):
180 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long)
181 | max_video_length = [0] * len(choice_video_ids)
182 |
183 | # Pair x L x T x 3 x H x W
184 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
185 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float)
186 |
187 | for i, video_id in enumerate(choice_video_ids):
188 | video_path = self.video_dict[video_id]
189 |
190 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path)
191 | raw_video_data = raw_video_data['video']
192 |
193 |
194 |
195 | video_container = container.get_video_container(
196 | path_to_vid=video_path,
197 | multi_thread_decode=False,
198 | backend="pyav",
199 | )
200 |
201 |
202 | frames, vid_len = decoder.decode(
203 | video_container,
204 | sampling_rate=3,
205 | num_frames=12,
206 | clip_idx=1,
207 | num_clips=1,
208 | video_meta=None,
209 | target_fps=1,
210 | backend="pyav",
211 | max_spatial_scale=0,
212 | use_offset=False,
213 | )
214 |
215 |
216 | frames = utils.tensor_normalize(
217 | frames, [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]
218 | )
219 |
220 | # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
221 | # [0.485, 0.456, 0.406]
222 | # [0.229, 0.224, 0.225]
223 | # [0.45, 0.45, 0.45]
224 | # [0.225, 0.225, 0.225]
225 |
226 |
227 | # if len(raw_video_data.shape) > 3:
228 | # raw_video_data_clip = raw_video_data
229 | # # L x T x 3 x H x W
230 | # raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip)
231 | # if self.max_frames < raw_video_slice.shape[0]:
232 | # if self.slice_framepos == 0:
233 | # video_slice = raw_video_slice[:self.max_frames, ...]
234 | # elif self.slice_framepos == 1:
235 | # video_slice = raw_video_slice[-self.max_frames:, ...]
236 | # else:
237 | # sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int)
238 | # video_slice = raw_video_slice[sample_indx, ...]
239 | # else:
240 | # video_slice = raw_video_slice
241 |
242 | # video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order)
243 |
244 | # slice_len = video_slice.shape[0]
245 | # max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len
246 | # if slice_len < 1:
247 | # pass
248 | # else:
249 | # video[i][:slice_len, ...] = video_slice
250 | # else:
251 | # print("video path: {} error. video id: {}".format(video_path, video_id))
252 |
253 | for i, v_length in enumerate([int(vid_len)]):
254 | if vid_len > self.max_frames:
255 | video_mask[:] = 1
256 | else:
257 | video_mask[i][:v_length] = [1] * v_length
258 |
259 |
260 | if vid_len < self.max_frames:
261 | frames[int(vid_len):] = 0
262 |
263 | frames = frames.permute(0,3,1,2).unsqueeze(0).unsqueeze(2)
264 |
265 | return frames, video_mask
266 |
267 | def __getitem__(self, idx):
268 | video_id, caption = self.sentences_dict[idx]
269 |
270 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption)
271 | video, video_mask = self._get_rawvideo(choice_video_ids)
272 | # video, video_mask = self._get_rawvideo_yb(choice_video_ids)
273 |
274 |
275 |
276 | return pairs_text, pairs_mask, pairs_segment, video, video_mask
277 |
--------------------------------------------------------------------------------
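For orientation, each index of `MSVD_DataLoader` yields one (caption, video) pair: `_get_text` returns token ids, attention mask, and segment ids of shape `(1, max_words)`, and `_get_rawvideo` returns a clip tensor of shape `(1, max_frames, 1, 3, image_resolution, image_resolution)` plus a `(1, max_frames)` mask. A minimal usage sketch follows; the paths are placeholders, and the `SimpleTokenizer` import is an assumption about what `modules/tokenization_clip.py` exposes (any tokenizer with `tokenize` and `convert_tokens_to_ids` works):

from torch.utils.data import DataLoader
from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader
from modules.tokenization_clip import SimpleTokenizer  # assumed class name

dataset = MSVD_DataLoader(
    subset="train",
    data_path="/path/to/msvd/annotations",  # expects train_list.txt, raw-captions.pkl, ...
    features_path="/path/to/msvd/videos",
    tokenizer=SimpleTokenizer(),
    max_words=30,
    max_frames=12,
    image_resolution=224,
)

pairs_text, pairs_mask, pairs_segment, video, video_mask = dataset[0]
print(pairs_text.shape)  # (1, 30)
print(video.shape)       # (1, 12, 1, 3, 224, 224)
print(video_mask.shape)  # (1, 12)

loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)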
/modules/ast_models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 6/10/21 5:04 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : ast_models.py
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.cuda.amp import autocast
11 | import os
12 | import wget
13 | import timm
14 | from timm.models.layers import to_2tuple,trunc_normal_
15 | from ipdb import set_trace
16 |
17 | # override the timm package to relax the input shape constraint.
18 | class PatchEmbed(nn.Module):
19 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
20 | super().__init__()
21 |
22 | img_size = to_2tuple(img_size)
23 | patch_size = to_2tuple(patch_size)
24 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
25 | self.img_size = img_size
26 | self.patch_size = patch_size
27 | self.num_patches = num_patches
28 |
29 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
30 | # Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10)) for audioset
31 | def forward(self, x):
32 | # x = 1x1x128x1024
33 | x = self.proj(x).flatten(2).transpose(1, 2) # 768x12x101
34 | return x
35 |
36 | class ASTModel(nn.Module):
37 | """
38 | The AST model.
39 | :param label_dim: the label dimension, i.e., the total number of classes; 527 for AudioSet, 50 for ESC-50, and 35 for Speech Commands v2-35
40 | :param fstride: the stride of patch splitting on the frequency dimension; for 16*16 patches, fstride=16 means no overlap and fstride=10 means an overlap of 6
41 | :param tstride: the stride of patch splitting on the time dimension; for 16*16 patches, tstride=16 means no overlap and tstride=10 means an overlap of 6
42 | :param input_fdim: the number of frequency bins of the input spectrogram
43 | :param input_tdim: the number of time frames of the input spectrogram
44 | :param imagenet_pretrain: whether to use an ImageNet-pretrained model
45 | :param audioset_pretrain: whether to use a model pretrained on both full AudioSet and ImageNet
46 | :param model_size: the AST model size, one of [tiny224, small224, base224, base384]; base224 and base384 are the same model but are trained differently during ImageNet pretraining.
47 | """
48 | def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True):
49 |
50 | super(ASTModel, self).__init__()
51 | assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'
52 |
53 | if verbose == True:
54 | print('---------------AST Model Summary---------------')
55 | print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
56 | # override timm input shape restriction
57 | timm.models.vision_transformer.PatchEmbed = PatchEmbed
58 |
59 | # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
60 | if audioset_pretrain == False:
61 | if model_size == 'tiny224':
62 | self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
63 | elif model_size == 'small224':
64 | self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
65 | elif model_size == 'base224':
66 | self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
67 | elif model_size == 'base384':
68 | self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
69 | else:
70 | raise Exception('Model size must be one of tiny224, small224, base224, base384.')
71 | self.original_num_patches = self.v.patch_embed.num_patches
72 | self.oringal_hw = int(self.original_num_patches ** 0.5)
73 | self.original_embedding_dim = self.v.pos_embed.shape[2]
74 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))
75 |
76 | # automatically get the intermediate shape
77 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
78 | num_patches = f_dim * t_dim
79 | self.v.patch_embed.num_patches = num_patches
80 | if verbose == True:
81 | print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
82 | print('number of patches={:d}'.format(num_patches))
83 |
84 | # the linear projection layer
85 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
86 | if imagenet_pretrain == True:
87 | new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
88 | new_proj.bias = self.v.patch_embed.proj.bias
89 | self.v.patch_embed.proj = new_proj
90 |
91 | # the positional embedding
92 | if imagenet_pretrain == True:
93 | # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
94 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
95 | # cut (from middle) or interpolate the second dimension of the positional embedding
96 | if t_dim <= self.oringal_hw:
97 | new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
98 | else:
99 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
100 | # cut (from middle) or interpolate the first dimension of the positional embedding
101 | if f_dim <= self.oringal_hw:
102 | new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
103 | else:
104 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
105 | # flatten the positional embedding
106 | new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
107 | # concatenate the above positional embedding with the cls token and distillation token of the deit model.
108 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
109 | else:
110 | # if not using the ImageNet pretrained model, just randomly initialize a learnable positional embedding
111 | # TODO can use sinusoidal positional embedding instead
112 | new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
113 | self.v.pos_embed = new_pos_embed
114 | trunc_normal_(self.v.pos_embed, std=.02)
115 |
116 | # now load a model that is pretrained on both ImageNet and AudioSet
117 | elif audioset_pretrain == True:
118 | if audioset_pretrain == True and imagenet_pretrain == False:
119 | raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
120 | if model_size != 'base384':
121 | raise ValueError('currently only has base384 AudioSet pretrained model.')
122 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123 | if os.path.exists('./pretrained_models/audioset_10_10_0.4593.pth') == False:
124 | # this model performs 0.4593 mAP on the audioset eval set
125 | audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
126 | wget.download(audioset_mdl_url, out='./pretrained_models/audioset_10_10_0.4593.pth')
127 | sd = torch.load('./pretrained_models/audioset_10_10_0.4593.pth', map_location=device)
128 |
129 | # for 10 s audioset
130 | audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
131 | audio_model = torch.nn.DataParallel(audio_model)
132 | audio_model.load_state_dict(sd, strict=False)
133 | self.v = audio_model.module.v
134 | self.original_embedding_dim = self.v.pos_embed.shape[2]
135 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))
136 |
137 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
138 | num_patches = f_dim * t_dim
139 | self.v.patch_embed.num_patches = num_patches
140 | if verbose == True:
141 | print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
142 | print('number of patches={:d}'.format(num_patches))
143 |
144 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101)
145 | # if the input sequence is shorter than the original AudioSet input (10 s, i.e. t_dim < 101), adapt the positional embedding (the original cut is replaced by interpolation below)
146 | if t_dim < 101:
147 | # new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim] ori
148 |
149 | ## yb add: f_dim:8 t_dim:64
150 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
151 | self.v.patch_embed.proj.stride = (fstride,tstride)
152 | ## yb end ###
153 | # otherwise interpolate
154 | else:
155 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
156 |
157 |
158 | new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
159 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
160 |
161 | def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
162 | test_input = torch.randn(1, 1, input_fdim, input_tdim)
163 | test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
164 | test_out = test_proj(test_input)
165 | f_dim = test_out.shape[2]
166 | t_dim = test_out.shape[3]
167 | return f_dim, t_dim
168 |
169 | @autocast()
170 | def forward(self, x):
171 | """
172 | :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
173 | :return: prediction
174 | """
175 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
176 | x = x.unsqueeze(1)
177 | x = x.transpose(2, 3)
178 |
179 | B = x.shape[0]
180 | x = self.v.patch_embed(x)
181 | cls_tokens = self.v.cls_token.expand(B, -1, -1)
182 | dist_token = self.v.dist_token.expand(B, -1, -1)
183 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
184 |
185 | x = x + self.v.pos_embed
186 |
187 | x = self.v.pos_drop(x)
188 | for blk in self.v.blocks:
189 | x = blk(x)
190 | x = self.v.norm(x)
191 |
192 |
193 | ## two cls here
194 | # x = (x[:, 0] + x[:, 1]) / 2
195 |
196 | # x = self.mlp_head(x)
197 | return x
198 |
199 | if __name__ == '__main__':
200 | input_tdim = 100
201 | ast_mdl = ASTModel(input_tdim=input_tdim)
202 | # input a batch of 10 spectrograms, each with 100 time frames and 128 frequency bins
203 | test_input = torch.rand([10, input_tdim, 128])
204 | test_output = ast_mdl(test_input)
205 | # note: with the mlp_head pooling disabled in forward() above, the output is the full token sequence of shape [10, num_patches + 2, 768] rather than [10, 527]
206 | print(test_output.shape)
207 |
208 | input_tdim = 256
209 | ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
210 | # input a batch of 10 spectrograms, each with 256 time frames and 128 frequency bins
211 | test_input = torch.rand([10, input_tdim, 128])
212 | test_output = ast_mdl(test_input)
213 | # likewise, the output here is the unpooled token sequence rather than [10, 50], since forward() returns all tokens
214 | print(test_output.shape)
--------------------------------------------------------------------------------
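The patch-grid sizes that `get_shape` measures with a dummy Conv2d follow directly from the convolution arithmetic, which is also where the hard-coded `reshape(1, 768, 12, 101)` (12 * 101 = 1212 patches) for the AudioSet checkpoint comes from. A small illustrative helper, not part of the repository:

def patch_grid(input_fdim=128, input_tdim=1024, fstride=10, tstride=10, kernel=16):
    # A 16x16 Conv2d with stride (fstride, tstride) and no padding produces
    # floor((dim - kernel) / stride) + 1 patches along each axis.
    f_dim = (input_fdim - kernel) // fstride + 1
    t_dim = (input_tdim - kernel) // tstride + 1
    return f_dim, t_dim

print(patch_grid())                        # (12, 101): the AudioSet 10 s input, 1212 patches in total
print(patch_grid(fstride=16, tstride=16))  # (8, 64): non-overlapping 16x16 patches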
/dataloaders/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import logging
4 | import numpy as np
5 | import os
6 | import random
7 | import time
8 | from collections import defaultdict
9 | import cv2
10 | import torch
11 | from torch.utils.data.distributed import DistributedSampler
12 |
13 | import types
14 | pathmgr = types.SimpleNamespace(open=open)  # stand-in for `from slowfast.utils.env import pathmgr`, used by retry_load_images / load_image_lists below
15 | from . import transform as transform
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def retry_load_images(image_paths, retry=10, backend="pytorch"):
21 | """
22 | Load images, with support for retrying failed loads.
23 | Args:
24 | image_paths (list): paths of the images to be loaded.
25 | retry (int, optional): maximum number of load retries. Defaults to 10.
26 | backend (str): `pytorch` or `cv2`.
27 | Returns:
28 | imgs (list): list of loaded images.
29 | """
30 | for i in range(retry):
31 | imgs = []
32 | for image_path in image_paths:
33 | with pathmgr.open(image_path, "rb") as f:
34 | img_str = np.frombuffer(f.read(), np.uint8)
35 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
36 | imgs.append(img)
37 |
38 | if all(img is not None for img in imgs):
39 | if backend == "pytorch":
40 | imgs = torch.as_tensor(np.stack(imgs))
41 | return imgs
42 | else:
43 | logger.warn("Reading failed. Will retry.")
44 | time.sleep(1.0)
45 | if i == retry - 1:
46 | raise Exception("Failed to load images {}".format(image_paths))
47 |
48 |
49 | def get_sequence(center_idx, half_len, sample_rate, num_frames):
50 | """
51 | Sample frames among the corresponding clip.
52 | Args:
53 | center_idx (int): center frame idx for current clip
54 | half_len (int): half of the clip length
55 | sample_rate (int): sampling rate for sampling frames inside of the clip
56 | num_frames (int): number of expected sampled frames
57 | Returns:
58 | seq (list): list of indexes of sampled frames in this clip.
59 | """
60 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate))
61 |
62 | for seq_idx in range(len(seq)):
63 | if seq[seq_idx] < 0:
64 | seq[seq_idx] = 0
65 | elif seq[seq_idx] >= num_frames:
66 | seq[seq_idx] = num_frames - 1
67 | return seq
68 |
69 |
70 | def pack_pathway_output(cfg, frames):
71 | """
72 | Prepare output as a list of tensors. Each tensor corresponding to a
73 | unique pathway.
74 | Args:
75 | frames (tensor): frames of images sampled from the video. The
76 | dimension is `channel` x `num frames` x `height` x `width`.
77 | Returns:
78 | frame_list (list): list of tensors with the dimension of
79 | `channel` x `num frames` x `height` x `width`.
80 | """
81 | if cfg.DATA.REVERSE_INPUT_CHANNEL:
82 | frames = frames[[2, 1, 0], :, :, :]
83 | if cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH:
84 | frame_list = [frames]
85 | elif cfg.MODEL.ARCH in cfg.MODEL.MULTI_PATHWAY_ARCH:
86 | fast_pathway = frames
87 | # Perform temporal sampling from the fast pathway.
88 | slow_pathway = torch.index_select(
89 | frames,
90 | 1,
91 | torch.linspace(
92 | 0, frames.shape[1] - 1, frames.shape[1] // cfg.SLOWFAST.ALPHA
93 | ).long(),
94 | )
95 | frame_list = [slow_pathway, fast_pathway]
96 | else:
97 | raise NotImplementedError(
98 | "Model arch {} is not in {}".format(
99 | cfg.MODEL.ARCH,
100 | cfg.MODEL.SINGLE_PATHWAY_ARCH + cfg.MODEL.MULTI_PATHWAY_ARCH,
101 | )
102 | )
103 | return frame_list
104 |
105 |
106 | def spatial_sampling(
107 | frames,
108 | spatial_idx=-1,
109 | min_scale=256,
110 | max_scale=320,
111 | crop_size=224,
112 | random_horizontal_flip=True,
113 | inverse_uniform_sampling=False,
114 | aspect_ratio=None,
115 | scale=None,
116 | motion_shift=False,
117 | ):
118 | """
119 | Perform spatial sampling on the given video frames. If spatial_idx is
120 | -1, perform random scale, random crop, and random flip on the given
121 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
122 | with the given spatial_idx.
123 | Args:
124 | frames (tensor): frames of images sampled from the video. The
125 | dimension is `num frames` x `height` x `width` x `channel`.
126 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
127 | or 2, perform left, center, right crop if width is larger than
128 | height, and perform top, center, bottom crop if height is larger
129 | than width.
130 | min_scale (int): the minimal size of scaling.
131 | max_scale (int): the maximal size of scaling.
132 | crop_size (int): the size of height and width used to crop the
133 | frames.
134 | inverse_uniform_sampling (bool): if True, sample uniformly in
135 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
136 | scale. If False, take a uniform sample from [min_scale,
137 | max_scale].
138 | aspect_ratio (list): Aspect ratio range for resizing.
139 | scale (list): Scale range for resizing.
140 | motion_shift (bool): Whether to apply motion shift for resizing.
141 | Returns:
142 | frames (tensor): spatially sampled frames.
143 | """
144 | assert spatial_idx in [-1, 0, 1, 2]
145 | if spatial_idx == -1:
146 | if aspect_ratio is None and scale is None:
147 | frames, _ = transform.random_short_side_scale_jitter(
148 | images=frames,
149 | min_size=min_scale,
150 | max_size=max_scale,
151 | inverse_uniform_sampling=inverse_uniform_sampling,
152 | )
153 | frames, _ = transform.random_crop(frames, crop_size)
154 | else:
155 | transform_func = (
156 | transform.random_resized_crop_with_shift
157 | if motion_shift
158 | else transform.random_resized_crop
159 | )
160 | frames = transform_func(
161 | images=frames,
162 | target_height=crop_size,
163 | target_width=crop_size,
164 | scale=scale,
165 | ratio=aspect_ratio,
166 | )
167 | if random_horizontal_flip:
168 | frames, _ = transform.horizontal_flip(0.5, frames)
169 | else:
170 | # The testing is deterministic and no jitter should be performed.
171 | # min_scale, max_scale, and crop_size are expected to be the same.
172 | assert len({min_scale, max_scale}) == 1
173 | frames, _ = transform.random_short_side_scale_jitter(
174 | frames, min_scale, max_scale
175 | )
176 | frames, _ = transform.uniform_crop(frames, crop_size, spatial_idx)
177 | return frames
178 |
179 |
180 | def as_binary_vector(labels, num_classes):
181 | """
182 | Construct binary label vector given a list of label indices.
183 | Args:
184 | labels (list): The input label list.
185 | num_classes (int): Number of classes of the label vector.
186 | Returns:
187 | labels (numpy array): the resulting binary vector.
188 | """
189 | label_arr = np.zeros((num_classes,))
190 |
191 | for lbl in set(labels):
192 | label_arr[lbl] = 1.0
193 | return label_arr
194 |
195 |
196 | def aggregate_labels(label_list):
197 | """
198 | Join a list of label lists.
199 | Args:
200 | label_list (list): The input list of label lists.
201 | Returns:
202 | labels (list): The deduplicated union of all labels in the input lists.
203 | """
204 | all_labels = []
205 | for labels in label_list:
206 | for l in labels:
207 | all_labels.append(l)
208 | return list(set(all_labels))
209 |
210 |
211 | def convert_to_video_level_labels(labels):
212 | """
213 | Aggregate annotations from all frames of a video to form video-level labels.
214 | Args:
215 | labels (list): The input label list.
216 | Returns:
217 | labels (list): Same as input, but with each label replaced by
218 | a video-level one.
219 | """
220 | for video_id in range(len(labels)):
221 | video_level_labels = aggregate_labels(labels[video_id])
222 | for i in range(len(labels[video_id])):
223 | labels[video_id][i] = video_level_labels
224 | return labels
225 |
226 |
227 | def load_image_lists(frame_list_file, prefix="", return_list=False):
228 | """
229 | Load image paths and labels from a "frame list".
230 | Each line of the frame list contains:
231 | `original_vido_id video_id frame_id path labels`
232 | Args:
233 | frame_list_file (string): path to the frame list.
234 | prefix (str): the prefix for the path.
235 | return_list (bool): if True, return a list. If False, return a dict.
236 | Returns:
237 | image_paths (list or dict): list of list containing path to each frame.
238 | If return_list is False, then return in a dict form.
239 | labels (list or dict): list of list containing label of each frame.
240 | If return_list is False, then return in a dict form.
241 | """
242 | image_paths = defaultdict(list)
243 | labels = defaultdict(list)
244 | with pathmgr.open(frame_list_file, "r") as f:
245 | assert f.readline().startswith("original_vido_id")
246 | for line in f:
247 | row = line.split()
248 | # original_vido_id video_id frame_id path labels
249 | assert len(row) == 5
250 | video_name = row[0]
251 | if prefix == "":
252 | path = row[3]
253 | else:
254 | path = os.path.join(prefix, row[3])
255 | image_paths[video_name].append(path)
256 | frame_labels = row[-1].replace('"', "")
257 | if frame_labels != "":
258 | labels[video_name].append(
259 | [int(x) for x in frame_labels.split(",")]
260 | )
261 | else:
262 | labels[video_name].append([])
263 |
264 | if return_list:
265 | keys = image_paths.keys()
266 | image_paths = [image_paths[key] for key in keys]
267 | labels = [labels[key] for key in keys]
268 | return image_paths, labels
269 | return dict(image_paths), dict(labels)
270 |
271 |
272 | def tensor_normalize(tensor, mean, std):
273 | """
274 | Normalize a given tensor by subtracting the mean and dividing the std.
275 | Args:
276 | tensor (tensor): tensor to normalize.
277 | mean (tensor or list): mean value to subtract.
278 | std (tensor or list): std to divide.
279 | """
280 | if tensor.dtype == torch.uint8:
281 | tensor = tensor.float()
282 | tensor = tensor / 255.0
283 | if type(mean) == list:
284 | mean = torch.tensor(mean)
285 | if type(std) == list:
286 | std = torch.tensor(std)
287 | tensor = tensor - mean
288 | tensor = tensor / std
289 | return tensor
290 |
291 |
292 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate):
293 | """
294 | When multigrid training uses fewer frames, we randomly
295 | increase the sampling rate so that some clips cover the original span.
296 | """
297 | if long_cycle_sampling_rate > 0:
298 | assert long_cycle_sampling_rate >= sampling_rate
299 | return random.randint(sampling_rate, long_cycle_sampling_rate)
300 | else:
301 | return sampling_rate
302 |
303 |
304 | def revert_tensor_normalize(tensor, mean, std):
305 | """
306 | Revert normalization for a given tensor by multiplying by the std and adding the mean.
307 | Args:
308 | tensor (tensor): tensor to revert normalization.
309 | mean (tensor or list): mean value to add.
310 | std (tensor or list): std to multiply.
311 | """
312 | if type(mean) == list:
313 | mean = torch.tensor(mean)
314 | if type(std) == list:
315 | std = torch.tensor(std)
316 | tensor = tensor * std
317 | tensor = tensor + mean
318 | return tensor
319 |
320 |
321 | def create_sampler(dataset, shuffle, cfg):
322 | """
323 | Create sampler for the given dataset.
324 | Args:
325 | dataset (torch.utils.data.Dataset): the given dataset.
326 | shuffle (bool): set to ``True`` to have the data reshuffled
327 | at every epoch.
328 | cfg (CfgNode): configs. Details can be found in
329 | slowfast/config/defaults.py
330 | Returns:
331 | sampler (Sampler): the created sampler.
332 | """
333 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
334 |
335 | return sampler
336 |
337 |
338 | def loader_worker_init_fn(dataset):
339 | """
340 | Create init function passed to pytorch data loader.
341 | Args:
342 | dataset (torch.utils.data.Dataset): the given dataset.
343 | """
344 | return None
--------------------------------------------------------------------------------
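Two of these helpers do most of the work in the retrieval dataloaders; a short illustrative sketch of their behavior (the CLIP mean/std values are the ones used in dataloader_msvd_retrieval.py):

import torch
from dataloaders.utils import tensor_normalize, get_sequence

# uint8 frames in (T, H, W, C) are first scaled to [0, 1], then mean/std-normalized per channel.
frames = torch.randint(0, 256, (12, 224, 224, 3), dtype=torch.uint8)
frames = tensor_normalize(
    frames,
    [0.48145466, 0.4578275, 0.40821073],
    [0.26862954, 0.26130258, 0.27577711],
)
print(frames.dtype, frames.shape)  # torch.float32 torch.Size([12, 224, 224, 3])

# get_sequence clamps out-of-range indices into [0, num_frames - 1].
print(get_sequence(center_idx=2, half_len=8, sample_rate=2, num_frames=64))
# [0, 0, 0, 0, 2, 4, 6, 8]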
/modules/until_module.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model."""
17 |
18 | import logging
19 | import numpy as np
20 | import torch
21 | from torch import nn
22 | import torch.nn.functional as F
23 | import math
24 | from modules.until_config import PretrainedConfig
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | def gelu(x):
29 | """Implementation of the gelu activation function.
30 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
31 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
32 | """
33 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
34 |
35 | def swish(x):
36 | return x * torch.sigmoid(x)
37 |
38 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
39 |
40 | class LayerNorm(nn.Module):
41 | def __init__(self, hidden_size, eps=1e-12):
42 | """Construct a layernorm module in the TF style (epsilon inside the square root).
43 | """
44 | super(LayerNorm, self).__init__()
45 | self.weight = nn.Parameter(torch.ones(hidden_size))
46 | self.bias = nn.Parameter(torch.zeros(hidden_size))
47 | self.variance_epsilon = eps
48 |
49 | def forward(self, x):
50 | u = x.mean(-1, keepdim=True)
51 | s = (x - u).pow(2).mean(-1, keepdim=True)
52 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
53 | return self.weight * x + self.bias
54 |
55 | class PreTrainedModel(nn.Module):
56 | """ An abstract class to handle weights initialization and
57 | a simple interface for downloading and loading pretrained models.
58 | """
59 | def __init__(self, config, *inputs, **kwargs):
60 | super(PreTrainedModel, self).__init__()
61 | if not isinstance(config, PretrainedConfig):
62 | raise ValueError(
63 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
64 | "To create a model from a Google pretrained model use "
65 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
66 | self.__class__.__name__, self.__class__.__name__
67 | ))
68 | self.config = config
69 |
70 | def init_weights(self, module):
71 | """ Initialize the weights.
72 | """
73 | if isinstance(module, (nn.Linear, nn.Embedding)):
74 | # Slightly different from the TF version which uses truncated_normal for initialization
75 | # cf https://github.com/pytorch/pytorch/pull/5617
76 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
77 | elif isinstance(module, LayerNorm):
78 | if 'beta' in dir(module) and 'gamma' in dir(module):
79 | module.beta.data.zero_()
80 | module.gamma.data.fill_(1.0)
81 | else:
82 | module.bias.data.zero_()
83 | module.weight.data.fill_(1.0)
84 | if isinstance(module, nn.Linear) and module.bias is not None:
85 | module.bias.data.zero_()
86 |
87 | def resize_token_embeddings(self, new_num_tokens=None):
88 | raise NotImplementedError
89 |
90 | @classmethod
91 | def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
92 | old_keys = []
93 | new_keys = []
94 | for key in state_dict.keys():
95 | new_key = None
96 | if 'gamma' in key:
97 | new_key = key.replace('gamma', 'weight')
98 | if 'beta' in key:
99 | new_key = key.replace('beta', 'bias')
100 | if new_key:
101 | old_keys.append(key)
102 | new_keys.append(new_key)
103 | for old_key, new_key in zip(old_keys, new_keys):
104 | state_dict[new_key] = state_dict.pop(old_key)
105 |
106 | if prefix is not None:
107 | old_keys = []
108 | new_keys = []
109 | for key in state_dict.keys():
110 | old_keys.append(key)
111 | new_keys.append(prefix + key)
112 | for old_key, new_key in zip(old_keys, new_keys):
113 | state_dict[new_key] = state_dict.pop(old_key)
114 |
115 | missing_keys = []
116 | unexpected_keys = []
117 | error_msgs = []
118 | # copy state_dict so _load_from_state_dict can modify it
119 | metadata = getattr(state_dict, '_metadata', None)
120 | state_dict = state_dict.copy()
121 | if metadata is not None:
122 | state_dict._metadata = metadata
123 |
124 | def load(module, prefix=''):
125 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
126 | module._load_from_state_dict(
127 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
128 | for name, child in module._modules.items():
129 | if child is not None:
130 | load(child, prefix + name + '.')
131 |
132 | load(model, prefix='')
133 |
134 | if prefix is None and (task_config is None or task_config.local_rank == 0):
135 | logger.info("-" * 20)
136 | if len(missing_keys) > 0:
137 | logger.info("Weights of {} not initialized from pretrained model: {}"
138 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
139 | if len(unexpected_keys) > 0:
140 | logger.info("Weights from pretrained model not used in {}: {}"
141 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
142 | if len(error_msgs) > 0:
143 | logger.error("Weights from pretrained model cause errors in {}: {}"
144 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
145 |
146 | return model
147 |
148 | @property
149 | def dtype(self):
150 | """
151 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
152 | """
153 | try:
154 | return next(self.parameters()).dtype
155 | except StopIteration:
156 | # For nn.DataParallel compatibility in PyTorch 1.5
157 | def find_tensor_attributes(module: nn.Module):
158 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
159 | return tuples
160 |
161 | gen = self._named_members(get_members_fn=find_tensor_attributes)
162 | first_tuple = next(gen)
163 | return first_tuple[1].dtype
164 |
165 | @classmethod
166 | def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
167 | """
168 | Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
169 | Download and cache the pre-trained model file if needed.
170 | """
171 | # Instantiate model.
172 | model = cls(config, *inputs, **kwargs)
173 | if state_dict is None:
174 | return model
175 | model = cls.init_preweight(model, state_dict)
176 |
177 | return model
178 |
179 | ##################################
180 | ###### LOSS FUNCTION #############
181 | ##################################
182 | class CrossEn(nn.Module):
183 | def __init__(self,):
184 | super(CrossEn, self).__init__()
185 |
186 | def forward(self, sim_matrix, guide_dis=None):
187 | logpt = F.log_softmax(sim_matrix, dim=-1)
188 | logpt = torch.diag(logpt)
189 | nce_loss = -logpt
190 | sim_loss = nce_loss.mean()
191 | return sim_loss
192 |
193 | class yb_CrossEn(nn.Module):
194 | def __init__(self,):
195 | super(yb_CrossEn, self).__init__()
196 |
197 | def forward(self, sim_matrix, id_dict, vid_id, opt=None):
198 | id_dict = id_dict.flatten()
199 |
200 |         # Gather -log p at the ground-truth location(s) of each query and average.
201 |
202 | logpt = F.log_softmax(sim_matrix, dim=-1)
203 | sim_loss = []
204 | for my_idx in range(len(vid_id)):
205 | vid_loc = (vid_id[my_idx] == id_dict).nonzero(as_tuple=True)[0]
206 | sim_loss.append(-logpt[my_idx, vid_loc])
207 |
208 | return sum(sim_loss)/len(sim_loss)
209 |
210 |
211 | # gt_loc = []
212 | # for my_idx in range(len(vid_id)):
213 | # vid_loc = (vid_id[my_idx] == id_dict).nonzero(as_tuple=True)[0]
214 | # gt_loc.append(vid_loc.item())
215 |
216 | # while True:
217 | # perm = torch.randperm(sim_matrix.size(1))
218 | # rand_idx = perm[:opt.yb_sample_num]
219 |
220 | # if not any(x in gt_loc for x in rand_idx.tolist()):
221 | # break
222 |
223 | # new_sim_matrix = torch.cat((sim_matrix[:,gt_loc], sim_matrix[:,rand_idx]), dim=-1)
224 |
225 | # logpt = F.log_softmax(new_sim_matrix, dim=-1)
226 | # logpt = torch.diag(logpt)
227 | # nce_loss = -logpt
228 | # sim_loss = nce_loss.mean()
229 |
230 | # return sim_loss
231 |
232 | class MILNCELoss(nn.Module):
233 | def __init__(self, batch_size=1, n_pair=1,):
234 | super(MILNCELoss, self).__init__()
235 | self.batch_size = batch_size
236 | self.n_pair = n_pair
237 |         torch_major, torch_minor = (int(v) for v in torch.__version__.split(".")[:2])
238 |         self.bool_dtype = torch.bool if (torch_major, torch_minor) >= (1, 3) else torch.uint8
239 |
240 | def forward(self, sim_matrix):
241 | mm_mask = np.eye(self.batch_size)
242 | mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
243 | mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)
244 |
245 | from_text_matrix = sim_matrix + mm_mask * -1e12
246 | from_video_matrix = sim_matrix.transpose(1, 0)
247 |
248 | new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
249 | logpt = F.log_softmax(new_sim_matrix, dim=-1)
250 |
251 | mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
252 | masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
253 |
254 | new_logpt = -torch.logsumexp(masked_logpt, dim=-1)
255 |
256 | logpt_choice = torch.zeros_like(new_logpt)
257 | mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
258 | logpt_choice[mark_ind] = 1
259 | sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
260 | return sim_loss
261 |
262 | class MaxMarginRankingLoss(nn.Module):
263 | def __init__(self,
264 | margin=1.0,
265 | negative_weighting=False,
266 | batch_size=1,
267 | n_pair=1,
268 | hard_negative_rate=0.5,
269 | ):
270 | super(MaxMarginRankingLoss, self).__init__()
271 | self.margin = margin
272 | self.n_pair = n_pair
273 | self.batch_size = batch_size
274 | easy_negative_rate = 1 - hard_negative_rate
275 | self.easy_negative_rate = easy_negative_rate
276 | self.negative_weighting = negative_weighting
277 | if n_pair > 1 and batch_size > 1:
278 | alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
279 | mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
280 | mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
281 | mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
282 | self.mm_mask = mm_mask.float()
283 |
284 | def forward(self, x):
285 | d = torch.diag(x)
286 | max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
287 | F.relu(self.margin + x - d.view(1, -1))
288 | if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
289 | max_margin = max_margin * self.mm_mask.to(max_margin.device)
290 | return max_margin.mean()
291 |
292 | class AllGather(torch.autograd.Function):
293 | """An autograd function that performs allgather on a tensor."""
294 |
295 | @staticmethod
296 | def forward(ctx, tensor, args):
297 | output = [torch.empty_like(tensor) for _ in range(args.world_size)]
298 | torch.distributed.all_gather(output, tensor)
299 | ctx.rank = args.rank
300 | ctx.batch_size = tensor.shape[0]
301 | return torch.cat(output, dim=0)
302 |
303 | @staticmethod
304 | def backward(ctx, grad_output):
305 | return (
306 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
307 | None,
308 | )
309 |
--------------------------------------------------------------------------------
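The loss classes above are typically driven by a text-video cosine-similarity matrix. Below is a minimal usage sketch, not code from this repository: the feature tensors, `logit_scale`, and the `args` namespace are hypothetical, and the import path is assumed from the repository tree (modules/until_module.py). It applies CrossEn symmetrically in both retrieval directions and notes where AllGather would sit in a distributed run.

import torch
import torch.nn.functional as F
from modules.until_module import CrossEn, AllGather  # path assumed from the repo tree

def contrastive_step(text_feat, video_feat, logit_scale=100.0):
    # L2-normalize so the dot product is a cosine similarity.
    text_feat = F.normalize(text_feat, dim=-1)
    video_feat = F.normalize(video_feat, dim=-1)
    sim_matrix = logit_scale * text_feat @ video_feat.t()  # [B_text, B_video]

    loss_fct = CrossEn()
    # Text-to-video and video-to-text directions share the same similarity matrix.
    return (loss_fct(sim_matrix) + loss_fct(sim_matrix.t())) / 2.0

# With DistributedDataParallel, features are usually gathered across ranks first,
# e.g. video_feat = AllGather.apply(video_feat, args), where the hypothetical
# `args` carries .world_size and .rank.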
/modules/resnet_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.hub import load_state_dict_from_url
5 | 
6 | # NOTE: `model_urls` (arch name -> checkpoint URL, as in torchvision) must be
7 | # defined before any constructor below is called with pretrained=True.
8 | class AVENet(nn.Module):
9 |
10 |     def __init__(self, args):
11 | super(AVENet, self).__init__()
12 | self.audnet = resnet18(num_classes=309, pool='avgpool')
13 |
14 | def forward(self, audio):
15 | aud = self.audnet(audio)
16 | return aud
17 |
18 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=dilation, groups=groups, bias=False, dilation=dilation)
22 |
23 |
24 | def conv1x1(in_planes, out_planes, stride=1):
25 | """1x1 convolution"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | expansion = 1
31 |
32 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
33 | base_width=64, dilation=1, norm_layer=None):
34 | super(BasicBlock, self).__init__()
35 | if norm_layer is None:
36 | norm_layer = nn.BatchNorm2d
37 | if groups != 1 or base_width != 64:
38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39 | if dilation > 1:
40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
42 | self.conv1 = conv3x3(inplanes, planes, stride)
43 | self.bn1 = norm_layer(planes)
44 | self.relu = nn.ReLU(inplace=True)
45 | self.conv2 = conv3x3(planes, planes)
46 | self.bn2 = norm_layer(planes)
47 | self.downsample = downsample
48 | self.stride = stride
49 |
50 | def forward(self, x):
51 | identity = x
52 |
53 | out = self.conv1(x)
54 | out = self.bn1(out)
55 | out = self.relu(out)
56 |
57 | out = self.conv2(out)
58 | out = self.bn2(out)
59 |
60 | if self.downsample is not None:
61 | identity = self.downsample(x)
62 |
63 | out += identity
64 | out = self.relu(out)
65 |
66 | return out
67 |
68 | class ResNet(nn.Module):
69 |
70 |     def __init__(self, block, layers, num_classes=1000, pool='avgpool', zero_init_residual=False,
71 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
72 | norm_layer=None):
73 | super(ResNet, self).__init__()
74 | self.pool = pool
75 | if norm_layer is None:
76 | norm_layer = nn.BatchNorm2d
77 | self._norm_layer = norm_layer
78 |
79 | self.inplanes = 64
80 | self.dilation = 1
81 | if replace_stride_with_dilation is None:
82 | # each element in the tuple indicates if we should replace
83 | # the 2x2 stride with a dilated convolution instead
84 | replace_stride_with_dilation = [False, False, False]
85 | if len(replace_stride_with_dilation) != 3:
86 | raise ValueError("replace_stride_with_dilation should be None "
87 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
88 | self.groups = groups
89 | self.base_width = width_per_group
90 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
91 | bias=False)
92 | self.bn1 = norm_layer(self.inplanes)
93 | self.relu = nn.ReLU(inplace=True)
94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
95 | self.layer1 = self._make_layer(block, 64, layers[0])
96 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
97 | dilate=replace_stride_with_dilation[0])
98 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
99 | dilate=replace_stride_with_dilation[1])
100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
101 | dilate=replace_stride_with_dilation[2])
102 | if self.pool == 'avgpool':
103 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
104 |
105 | self.fc = nn.Linear(512 * block.expansion, num_classes) # 8192
106 | elif self.pool == 'vlad':
107 | self.avgpool = NetVLAD()
108 | self.fc_ = nn.Linear(8192 * block.expansion, num_classes)
109 |
110 |
111 | for m in self.modules():
112 | if isinstance(m, nn.Conv2d):
113 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
114 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
115 | nn.init.normal_(m.weight, mean=1, std=0.02)
116 | nn.init.constant_(m.bias, 0)
117 |
118 | # Zero-initialize the last BN in each residual branch,
119 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
120 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
121 | if zero_init_residual:
122 | for m in self.modules():
123 | if isinstance(m, Bottleneck):
124 | nn.init.constant_(m.bn3.weight, 0)
125 | elif isinstance(m, BasicBlock):
126 | nn.init.constant_(m.bn2.weight, 0)
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
129 | norm_layer = self._norm_layer
130 | downsample = None
131 | previous_dilation = self.dilation
132 | if dilate:
133 | self.dilation *= stride
134 | stride = 1
135 | if stride != 1 or self.inplanes != planes * block.expansion:
136 | downsample = nn.Sequential(
137 | conv1x1(self.inplanes, planes * block.expansion, stride),
138 | norm_layer(planes * block.expansion),
139 | )
140 |
141 |
142 | layers = []
143 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
144 | self.base_width, previous_dilation, norm_layer))
145 | self.inplanes = planes * block.expansion
146 | for _ in range(1, blocks):
147 | layers.append(block(self.inplanes, planes, groups=self.groups,
148 | base_width=self.base_width, dilation=self.dilation,
149 | norm_layer=norm_layer))
150 |
151 | return nn.Sequential(*layers)
152 |
153 | def forward(self, x):
154 | x = self.conv1(x)
155 | x = self.bn1(x)
156 | x = self.relu(x)
157 | x = self.maxpool(x)
158 |
159 | x = self.layer1(x)
160 | x = self.layer2(x)
161 | x = self.layer3(x)
162 | x = self.layer4(x)
163 | x = self.avgpool(x)
164 | x = x.reshape(x.size(0), -1)
165 | # if self.pool == 'avgpool':
166 | # x = self.fc(x)
167 | # elif self.pool == 'vlad':
168 | # x = self.fc_(x)
169 |
170 | return x
171 |
172 | class NetVLAD(nn.Module):
173 | """NetVLAD layer implementation"""
174 |
175 | def __init__(self, num_clusters=16, dim=512, alpha=100.0,
176 | normalize_input=True):
177 | """
178 | Args:
179 | num_clusters : int
180 | The number of clusters
181 | dim : int
182 | Dimension of descriptors
183 | alpha : float
184 |                 Parameter of initialization; a larger value gives a harder assignment.
185 | normalize_input : bool
186 | If true, descriptor-wise L2 normalization is applied to input.
187 | """
188 | super(NetVLAD, self).__init__()
189 | self.num_clusters = num_clusters
190 | self.dim = dim
191 | self.alpha = alpha
192 | self.normalize_input = normalize_input
193 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True)
194 | self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
195 | self._init_params()
196 |
197 | def _init_params(self):
198 | self.conv.weight = nn.Parameter(
199 | (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
200 | )
201 | self.conv.bias = nn.Parameter(
202 | - self.alpha * self.centroids.norm(dim=1)
203 | )
204 |
205 | def forward(self, x):
206 | N, C = x.shape[:2]
207 |
208 | if self.normalize_input:
209 | x = F.normalize(x, p=2, dim=1) # across descriptor dim
210 |
211 | # soft-assignment
212 | soft_assign = self.conv(x).view(N, self.num_clusters, -1)
213 | soft_assign = F.softmax(soft_assign, dim=1)
214 |
215 | x_flatten = x.view(N, C, -1)
216 |
217 |     # calculate residuals to each cluster
218 | residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \
219 | self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
220 | residual *= soft_assign.unsqueeze(2)
221 | vlad = residual.sum(dim=-1)
222 |
223 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
224 | vlad = vlad.view(x.size(0), -1) # flatten
225 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
226 |
227 | return vlad
228 |
229 | class Bottleneck(nn.Module):
230 | expansion = 4
231 |
232 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
233 | base_width=64, dilation=1, norm_layer=None):
234 | super(Bottleneck, self).__init__()
235 | if norm_layer is None:
236 | norm_layer = nn.BatchNorm2d
237 | width = int(planes * (base_width / 64.)) * groups
238 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
239 | self.conv1 = conv1x1(inplanes, width)
240 | self.bn1 = norm_layer(width)
241 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
242 | self.bn2 = norm_layer(width)
243 | self.conv3 = conv1x1(width, planes * self.expansion)
244 | self.bn3 = norm_layer(planes * self.expansion)
245 | self.relu = nn.ReLU(inplace=True)
246 | self.downsample = downsample
247 | self.stride = stride
248 |
249 | def forward(self, x):
250 | identity = x
251 |
252 | out = self.conv1(x)
253 | out = self.bn1(out)
254 | out = self.relu(out)
255 |
256 | out = self.conv2(out)
257 | out = self.bn2(out)
258 | out = self.relu(out)
259 |
260 | out = self.conv3(out)
261 | out = self.bn3(out)
262 |
263 | if self.downsample is not None:
264 | identity = self.downsample(x)
265 |
266 | out += identity
267 | out = self.relu(out)
268 |
269 | return out
270 |
271 |
272 |
273 |
274 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
275 | model = ResNet(block, layers, **kwargs)
276 | if pretrained:
277 | state_dict = load_state_dict_from_url(model_urls[arch],
278 | progress=progress)
279 | model.load_state_dict(state_dict)
280 | return model
281 |
282 |
283 | def resnet18(pretrained=False, progress=True, **kwargs):
284 | """Constructs a ResNet-18 model.
285 |
286 | Args:
287 | pretrained (bool): If True, returns a model pre-trained on ImageNet
288 | progress (bool): If True, displays a progress bar of the download to stderr
289 | """
290 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
291 | **kwargs)
292 |
293 |
294 | def resnet34(pretrained=False, progress=True, **kwargs):
295 | """Constructs a ResNet-34 model.
296 |
297 | Args:
298 | pretrained (bool): If True, returns a model pre-trained on ImageNet
299 | progress (bool): If True, displays a progress bar of the download to stderr
300 | """
301 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
302 | **kwargs)
303 |
304 |
305 | def resnet50(pretrained=False, progress=True, **kwargs):
306 | """Constructs a ResNet-50 model.
307 |
308 | Args:
309 | pretrained (bool): If True, returns a model pre-trained on ImageNet
310 | progress (bool): If True, displays a progress bar of the download to stderr
311 | """
312 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
313 | **kwargs)
314 |
315 |
316 | def resnet101(pretrained=False, progress=True, **kwargs):
317 | """Constructs a ResNet-101 model.
318 |
319 | Args:
320 | pretrained (bool): If True, returns a model pre-trained on ImageNet
321 | progress (bool): If True, displays a progress bar of the download to stderr
322 | """
323 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
324 | **kwargs)
325 |
326 |
327 | def resnet152(pretrained=False, progress=True, **kwargs):
328 | """Constructs a ResNet-152 model.
329 |
330 | Args:
331 | pretrained (bool): If True, returns a model pre-trained on ImageNet
332 | progress (bool): If True, displays a progress bar of the download to stderr
333 | """
334 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
335 | **kwargs)
336 |
337 |
338 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
339 | """Constructs a ResNeXt-50 32x4d model.
340 |
341 | Args:
342 | pretrained (bool): If True, returns a model pre-trained on ImageNet
343 | progress (bool): If True, displays a progress bar of the download to stderr
344 | """
345 | kwargs['groups'] = 32
346 | kwargs['width_per_group'] = 4
347 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
348 | pretrained, progress, **kwargs)
349 |
350 |
351 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
352 | """Constructs a ResNeXt-101 32x8d model.
353 |
354 | Args:
355 | pretrained (bool): If True, returns a model pre-trained on ImageNet
356 | progress (bool): If True, displays a progress bar of the download to stderr
357 | """
358 | kwargs['groups'] = 32
359 | kwargs['width_per_group'] = 8
360 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
361 | pretrained, progress, **kwargs)
362 |
363 |
364 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
365 | """Constructs a Wide ResNet-50-2 model.
366 |
367 |     The model is the same as ResNet except that the number of bottleneck channels
368 |     is twice as large in every block. The number of channels in the outer 1x1
369 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
370 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
371 |
372 | Args:
373 | pretrained (bool): If True, returns a model pre-trained on ImageNet
374 | progress (bool): If True, displays a progress bar of the download to stderr
375 | """
376 | kwargs['width_per_group'] = 64 * 2
377 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
378 | pretrained, progress, **kwargs)
379 |
380 |
381 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
382 | """Constructs a Wide ResNet-101-2 model.
383 |
384 |     The model is the same as ResNet except that the number of bottleneck channels
385 |     is twice as large in every block. The number of channels in the outer 1x1
386 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
387 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
388 |
389 | Args:
390 | pretrained (bool): If True, returns a model pre-trained on ImageNet
391 | progress (bool): If True, displays a progress bar of the download to stderr
392 | """
393 | kwargs['width_per_group'] = 64 * 2
394 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
395 | pretrained, progress, **kwargs)
396 |
--------------------------------------------------------------------------------
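Since conv1 above is built with a single input channel, the audio ResNet expects one-channel spectrogram input. A minimal usage sketch under that assumption (the 96x64 spectrogram size is illustrative, not fixed by this file; the import path is assumed from the repository tree):

import torch
from modules.resnet_models import resnet18  # path assumed from the repo tree

audio_net = resnet18(num_classes=309, pool='avgpool')  # same call AVENet makes
spec = torch.randn(4, 1, 96, 64)   # [batch, 1, mel_bins, time_frames], illustrative sizes
feat = audio_net(spec)
print(feat.shape)                  # torch.Size([4, 512]); forward() bypasses the fc head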
/dataloaders/decoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | import math
5 | import numpy as np
6 | import random
7 | import torch
8 | import torchvision.io as io
9 |
10 |
11 | def temporal_sampling(frames, start_idx, end_idx, num_samples):
12 | """
13 | Given the start and end frame index, sample num_samples frames between
14 | the start and end with equal interval.
15 | Args:
16 | frames (tensor): a tensor of video frames, dimension is
17 | `num video frames` x `channel` x `height` x `width`.
18 | start_idx (int): the index of the start frame.
19 | end_idx (int): the index of the end frame.
20 | num_samples (int): number of frames to sample.
21 | Returns:
22 |         frames (tensor): a tensor of temporally sampled video frames, dimension is
23 | `num clip frames` x `channel` x `height` x `width`.
24 | """
25 | index = torch.linspace(start_idx, end_idx, num_samples)
26 | index = torch.clamp(index, 0, frames.shape[0] - 1).long()
27 | frames = torch.index_select(frames, 0, index)
28 | return frames
29 |
30 |
31 | def get_start_end_idx(
32 | video_size, clip_size, clip_idx, num_clips, use_offset=False
33 | ):
34 | """
35 | Sample a clip of size clip_size from a video of size video_size and
36 | return the indices of the first and last frame of the clip. If clip_idx is
37 | -1, the clip is randomly sampled, otherwise uniformly split the video to
38 | num_clips clips, and select the start and end index of clip_idx-th video
39 | clip.
40 | Args:
41 | video_size (int): number of overall frames.
42 | clip_size (int): size of the clip to sample from the frames.
43 | clip_idx (int): if clip_idx is -1, perform random jitter sampling. If
44 | clip_idx is larger than -1, uniformly split the video to num_clips
45 | clips, and select the start and end index of the clip_idx-th video
46 | clip.
47 | num_clips (int): overall number of clips to uniformly sample from the
48 | given video for testing.
49 | Returns:
50 | start_idx (int): the start frame index.
51 | end_idx (int): the end frame index.
52 | """
53 | delta = max(video_size - clip_size, 0)
54 | if clip_idx == -1:
55 | # Random temporal sampling.
56 | start_idx = random.uniform(0, delta)
57 | else:
58 | if use_offset:
59 | if num_clips == 1:
60 | # Take the center clip if num_clips is 1.
61 | start_idx = math.floor(delta / 2)
62 | else:
63 | # Uniformly sample the clip with the given index.
64 | start_idx = clip_idx * math.floor(delta / (num_clips - 1))
65 | else:
66 | # Uniformly sample the clip with the given index.
67 | start_idx = delta * clip_idx / num_clips
68 | end_idx = start_idx + clip_size - 1
69 | return start_idx, end_idx
70 |
71 |
72 | def pyav_decode_stream(
73 | container, start_pts, end_pts, stream, stream_name, buffer_size=0
74 | ):
75 | """
76 | Decode the video with PyAV decoder.
77 | Args:
78 | container (container): PyAV container.
79 | start_pts (int): the starting Presentation TimeStamp to fetch the
80 | video frames.
81 | end_pts (int): the ending Presentation TimeStamp of the decoded frames.
82 | stream (stream): PyAV stream.
83 | stream_name (dict): a dictionary of streams. For example, {"video": 0}
84 | means video stream at stream index 0.
85 | buffer_size (int): number of additional frames to decode beyond end_pts.
86 | Returns:
87 | result (list): list of frames decoded.
88 | max_pts (int): max Presentation TimeStamp of the video sequence.
89 | """
90 |     # Seeking in the stream is imprecise, so seek to an earlier PTS by a small
91 |     # margin (in pts).
92 | margin = 1024
93 | seek_offset = max(start_pts - margin, 0)
94 |
95 | container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
96 | frames = {}
97 | buffer_count = 0
98 | max_pts = 0
99 | for frame in container.decode(**stream_name):
100 | max_pts = max(max_pts, frame.pts)
101 | if frame.pts < start_pts:
102 | continue
103 | if frame.pts <= end_pts:
104 | frames[frame.pts] = frame
105 | else:
106 | buffer_count += 1
107 | frames[frame.pts] = frame
108 | if buffer_count >= buffer_size:
109 | break
110 | result = [frames[pts] for pts in sorted(frames)]
111 | return result, max_pts
112 |
113 |
114 | def torchvision_decode(
115 | video_handle,
116 | sampling_rate,
117 | num_frames,
118 | clip_idx,
119 | video_meta,
120 | num_clips=10,
121 | target_fps=30,
122 | modalities=("visual",),
123 | max_spatial_scale=0,
124 | use_offset=False,
125 | ):
126 | """
127 | If video_meta is not empty, perform temporal selective decoding to sample a
128 | clip from the video with TorchVision decoder. If video_meta is empty, decode
129 | the entire video and update the video_meta.
130 | Args:
131 | video_handle (bytes): raw bytes of the video file.
132 | sampling_rate (int): frame sampling rate (interval between two sampled
133 | frames).
134 | num_frames (int): number of frames to sample.
135 | clip_idx (int): if clip_idx is -1, perform random temporal
136 | sampling. If clip_idx is larger than -1, uniformly split the
137 | video to num_clips clips, and select the clip_idx-th video clip.
138 |         video_meta (dict): a dict containing VideoMetaData. Details can be found
139 | at `pytorch/vision/torchvision/io/_video_opt.py`.
140 | num_clips (int): overall number of clips to uniformly sample from the
141 | given video.
142 |         target_fps (int): the input video may have a different fps; convert it to
143 | the target video fps.
144 | modalities (tuple): tuple of modalities to decode. Currently only
145 | support `visual`, planning to support `acoustic` soon.
146 | max_spatial_scale (int): the maximal resolution of the spatial shorter
147 | edge size during decoding.
148 | Returns:
149 | frames (tensor): decoded frames from the video.
150 | fps (float): the number of frames per second of the video.
151 | decode_all_video (bool): if True, the entire video was decoded.
152 | """
153 | # Convert the bytes to a tensor.
154 | video_tensor = torch.from_numpy(np.frombuffer(video_handle, dtype=np.uint8))
155 |
156 | decode_all_video = True
157 | video_start_pts, video_end_pts = 0, -1
158 | # The video_meta is empty, fetch the meta data from the raw video.
159 | if len(video_meta) == 0:
160 | # Tracking the meta info for selective decoding in the future.
161 | meta = io._probe_video_from_memory(video_tensor)
162 | # Using the information from video_meta to perform selective decoding.
163 | video_meta["video_timebase"] = meta.video_timebase
164 | video_meta["video_numerator"] = meta.video_timebase.numerator
165 | video_meta["video_denominator"] = meta.video_timebase.denominator
166 | video_meta["has_video"] = meta.has_video
167 | video_meta["video_duration"] = meta.video_duration
168 | video_meta["video_fps"] = meta.video_fps
169 |         video_meta["audio_timebase"] = meta.audio_timebase
170 | video_meta["audio_numerator"] = meta.audio_timebase.numerator
171 | video_meta["audio_denominator"] = meta.audio_timebase.denominator
172 | video_meta["has_audio"] = meta.has_audio
173 | video_meta["audio_duration"] = meta.audio_duration
174 | video_meta["audio_sample_rate"] = meta.audio_sample_rate
175 |
176 | fps = video_meta["video_fps"]
177 | if (
178 | video_meta["has_video"]
179 | and video_meta["video_denominator"] > 0
180 | and video_meta["video_duration"] > 0
181 | ):
182 | # try selective decoding.
183 | decode_all_video = False
184 | clip_size = sampling_rate * num_frames / target_fps * fps
185 | start_idx, end_idx = get_start_end_idx(
186 | fps * video_meta["video_duration"],
187 | clip_size,
188 | clip_idx,
189 | num_clips,
190 | use_offset=use_offset,
191 | )
192 | # Convert frame index to pts.
193 | pts_per_frame = video_meta["video_denominator"] / fps
194 | video_start_pts = int(start_idx * pts_per_frame)
195 | video_end_pts = int(end_idx * pts_per_frame)
196 |
197 | # Decode the raw video with the tv decoder.
198 | v_frames, _ = io._read_video_from_memory(
199 | video_tensor,
200 | seek_frame_margin=1.0,
201 | read_video_stream="visual" in modalities,
202 | video_width=0,
203 | video_height=0,
204 | video_min_dimension=max_spatial_scale,
205 | video_pts_range=(video_start_pts, video_end_pts),
206 | video_timebase_numerator=video_meta["video_numerator"],
207 | video_timebase_denominator=video_meta["video_denominator"],
208 | )
209 |
210 | if v_frames.shape == torch.Size([0]):
211 | # failed selective decoding
212 | decode_all_video = True
213 | video_start_pts, video_end_pts = 0, -1
214 | v_frames, _ = io._read_video_from_memory(
215 | video_tensor,
216 | seek_frame_margin=1.0,
217 | read_video_stream="visual" in modalities,
218 | video_width=0,
219 | video_height=0,
220 | video_min_dimension=max_spatial_scale,
221 | video_pts_range=(video_start_pts, video_end_pts),
222 | video_timebase_numerator=video_meta["video_numerator"],
223 | video_timebase_denominator=video_meta["video_denominator"],
224 | )
225 |
226 | return v_frames, fps, decode_all_video
227 |
228 |
229 | def pyav_decode(
230 | container,
231 | sampling_rate,
232 | num_frames,
233 | clip_idx,
234 | num_clips=10,
235 | target_fps=30,
236 | use_offset=False,
237 | ):
238 | """
239 |     Convert the video from its original fps to the target_fps. If the video
240 |     supports selective decoding (i.e., contains decoding information in the video
241 |     header), perform temporal selective decoding and sample a clip from the video
242 |     with the PyAV decoder. If the video does not support selective decoding,
243 |     decode the entire video.
244 | Args:
245 | container (container): pyav container.
246 | sampling_rate (int): frame sampling rate (interval between two sampled
247 |             frames).
248 | num_frames (int): number of frames to sample.
249 | clip_idx (int): if clip_idx is -1, perform random temporal sampling. If
250 | clip_idx is larger than -1, uniformly split the video to num_clips
251 | clips, and select the clip_idx-th video clip.
252 | num_clips (int): overall number of clips to uniformly sample from the
253 | given video.
254 |         target_fps (int): the input video may have a different fps; convert it to
255 | the target video fps before frame sampling.
256 | Returns:
257 |         frames (tensor): decoded frames from the video. Return None if no
258 | video stream was found.
259 | fps (float): the number of frames per second of the video.
260 | decode_all_video (bool): If True, the entire video was decoded.
261 | """
262 |     # Try to fetch the decoding information from the video header. Some videos do
263 |     # not support fetching the decoding information; in that case the duration
264 |     # will be None.
265 | fps = float(container.streams.video[0].average_rate)
266 | frames_length = container.streams.video[0].frames
267 | duration = container.streams.video[0].duration
268 |
269 | if duration is None:
270 | # If failed to fetch the decoding information, decode the entire video.
271 | decode_all_video = True
272 | video_start_pts, video_end_pts = 0, math.inf
273 | else:
274 | # Perform selective decoding.
275 | decode_all_video = False
276 | start_idx, end_idx = get_start_end_idx(
277 | frames_length,
278 | sampling_rate * num_frames / target_fps * fps,
279 | clip_idx,
280 | num_clips,
281 | use_offset=use_offset,
282 | )
283 | timebase = duration / frames_length
284 | video_start_pts = int(start_idx * timebase)
285 | video_end_pts = int(end_idx * timebase)
286 |
287 | frames = None
288 | # If video stream was found, fetch video frames from the video.
289 | if container.streams.video:
290 | video_frames, max_pts = pyav_decode_stream(
291 | container,
292 | video_start_pts,
293 | video_end_pts,
294 | container.streams.video[0],
295 | {"video": 0},
296 | )
297 | container.close()
298 |
299 |         frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
300 |         # Stack into a tensor only when a video stream was found; otherwise
301 |         # `frames` stays None, as documented above.
302 |         frames = torch.as_tensor(np.stack(frames))
303 | return frames, fps, decode_all_video
304 |
305 |
306 | def decode(
307 | container,
308 | sampling_rate,
309 | num_frames,
310 | clip_idx=-1,
311 | num_clips=10,
312 | video_meta=None,
313 | target_fps=30,
314 | backend="pyav",
315 | max_spatial_scale=0,
316 | use_offset=False,
317 | ):
318 | """
319 | Decode the video and perform temporal sampling.
320 | Args:
321 | container (container): pyav container.
322 | sampling_rate (int): frame sampling rate (interval between two sampled
323 | frames).
324 | num_frames (int): number of frames to sample.
325 | clip_idx (int): if clip_idx is -1, perform random temporal
326 | sampling. If clip_idx is larger than -1, uniformly split the
327 | video to num_clips clips, and select the
328 | clip_idx-th video clip.
329 | num_clips (int): overall number of clips to uniformly
330 | sample from the given video.
331 |         video_meta (dict): a dict containing VideoMetaData. Details can be found
332 | at `pytorch/vision/torchvision/io/_video_opt.py`.
333 |         target_fps (int): the input video may have a different fps; convert it to
334 | the target video fps before frame sampling.
335 | backend (str): decoding backend includes `pyav` and `torchvision`. The
336 | default one is `pyav`.
337 | max_spatial_scale (int): keep the aspect ratio and resize the frame so
338 | that shorter edge size is max_spatial_scale. Only used in
339 | `torchvision` backend.
340 | Returns:
341 | frames (tensor): decoded frames from the video.
342 | """
343 | # Currently support two decoders: 1) PyAV, and 2) TorchVision.
344 |     assert clip_idx >= -1, "Not a valid clip_idx {}".format(clip_idx)
345 | try:
346 |
347 | if backend == "pyav":
348 | frames, fps, decode_all_video = pyav_decode(
349 | container,
350 | sampling_rate,
351 | num_frames,
352 | clip_idx,
353 | num_clips,
354 | target_fps,
355 | use_offset=use_offset,
356 | )
357 |             # total_len: duration (seconds) of the frames decoded by PyAV.
358 | total_len = len(frames) / fps
359 | elif backend == "torchvision":
360 | frames, fps, decode_all_video = torchvision_decode(
361 | container,
362 | sampling_rate,
363 | num_frames,
364 | clip_idx,
365 | video_meta,
366 | num_clips,
367 | target_fps,
368 | ("visual",),
369 | max_spatial_scale,
370 | use_offset=use_offset,
371 | )
372 |             total_len = frames.shape[0] / fps  # duration (seconds) of the decoded frames
373 | else:
374 | raise NotImplementedError(
375 | "Unknown decoding backend {}".format(backend)
376 | )
377 | except Exception as e:
378 | print("Failed to decode by {} with exception: {}".format(backend, e))
379 | return None
380 |
381 |     # Return None if the frames were not decoded successfully.
382 | if frames is None or frames.size(0) == 0:
383 | return None
384 |
385 | clip_sz = sampling_rate * num_frames / target_fps * fps
386 | start_idx, end_idx = get_start_end_idx(
387 | frames.shape[0],
388 | clip_sz,
389 | clip_idx if decode_all_video else 0,
390 | num_clips if decode_all_video else 1,
391 | use_offset=use_offset,
392 | )
393 | # Perform temporal sampling from the decoded video.
394 |
395 | frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
396 | return frames,total_len
--------------------------------------------------------------------------------
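A minimal sketch of calling decode() with the default PyAV backend; the video path and sampling parameters are placeholders, and the import path is assumed from the repository tree:

import av
from dataloaders.decoder import decode  # path assumed from the repo tree

container = av.open("example.mp4")  # hypothetical local video file
out = decode(container, sampling_rate=8, num_frames=16, clip_idx=-1,
             num_clips=10, target_fps=30, backend="pyav")
if out is not None:
    frames, total_len = out
    # frames: uint8 tensor of shape [16, H, W, 3] (RGB frames after temporal sampling)
    # total_len: duration in seconds of the span decoded by PyAV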