├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── config ├── config.py ├── hero_finetune.json ├── hero_pretrain.json ├── hero_tvc.json ├── pretrain-tv-16gpu.json ├── train-didemo_video_only-4gpu.json ├── train-didemo_video_sub-8gpu.json ├── train-msrvtt_video_only-4gpu.json ├── train-msrvtt_video_sub-4gpu.json ├── train-tvc-8gpu.json ├── train-tvqa-8gpu.json ├── train-tvr-8gpu.json └── train-violin-8gpu.json ├── data ├── __init__.py ├── data.py ├── fom.py ├── loader.py ├── mfm.py ├── mlm.py ├── tvc.py ├── vcmr.py ├── vcmr_video_only.py ├── videoQA.py ├── violin.py ├── vr.py ├── vr_video_only.py └── vsm.py ├── eval ├── pycocoevalcap │ ├── README.md │ ├── __init__.py │ ├── bleu │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── license.txt │ ├── meteor │ │ ├── __init__.py │ │ ├── meteor.py │ │ └── tests │ │ │ └── test_meteor.py │ ├── rouge │ │ ├── __init__.py │ │ └── rouge.py │ └── tokenizer │ │ ├── __init__.py │ │ └── ptbtokenizer.py └── tvc.py ├── eval_vcmr.py ├── eval_videoQA.py ├── eval_violin.py ├── eval_vr.py ├── inf_tvc.py ├── launch_container.sh ├── load_data.py ├── model ├── __init__.py ├── embed.py ├── encoder.py ├── layers.py ├── model.py ├── modeling_utils.py ├── pretrain.py ├── tvc.py ├── vcmr.py ├── videoQA.py ├── violin.py └── vr.py ├── optim ├── __init__.py ├── adamw.py ├── misc.py └── sched.py ├── pretrain.py ├── scripts ├── collect_video_feature_paths.py ├── convert_videodb.py ├── create_txtdb.sh ├── download_didemo.sh ├── download_msrvtt.sh ├── download_pretrained.sh ├── download_tv_pretrain.sh ├── download_tvc.sh ├── download_tvqa.sh ├── download_tvr.sh ├── download_violin.sh ├── prepro_query.py ├── prepro_sub.py ├── prepro_tvc.py └── prepro_tvc.sh ├── train_tvc.py ├── train_vcmr.py ├── train_videoQA.py ├── train_violin.py ├── train_vr.py └── utils ├── __init__.py ├── basic_utils.py ├── const.py ├── distributed.py ├── logger.py ├── misc.py ├── save.py ├── tvr_eval_utils.py └── tvr_standalone_eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ctags 2 | tags 3 | 4 | # compiled files # 5 | __pycache__ 6 | *.pyc 7 | 8 | # Packages # 9 | ############ 10 | # it's better to unpack these files and commit the raw source 11 | # git has its own built in compression methods 12 | *.7z 13 | *.dmg 14 | *.gz 15 | *.iso 16 | *.jar 17 | *.rar 18 | *.tar 19 | *.zip 20 | 21 | # Logs and databases # 22 | ###################### 23 | *.log 24 | *.sql 25 | *.sqlite 26 | .ipynb_checkpoints/ 27 | *.swp 28 | *.vscode/ 29 | *.idea/ 30 | 31 | # OS generated files # 32 | ###################### 33 | .DS_Store 34 | .DS_Store? 
35 | ._* 36 | .Spotlight-V100 37 | .Trashes 38 | ehthumbs.db 39 | Thumbs.db 40 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.10-py3 2 | 3 | # basic python packages 4 | RUN pip install transformers==2.0.0 \ 5 | tensorboardX==1.7 ipdb==0.12 lz4==2.1.9 lmdb==0.97 6 | 7 | ####### horovod for multi-GPU (distributed) training ####### 8 | # horovod 9 | RUN HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \ 10 | pip install --no-cache-dir horovod==0.18.2 &&\ 11 | ldconfig 12 | 13 | # ssh 14 | RUN apt-get update &&\ 15 | apt-get install -y --no-install-recommends openssh-client openssh-server &&\ 16 | mkdir -p /var/run/sshd 17 | 18 | # Allow OpenSSH to talk to containers without asking for confirmation 19 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ 20 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ 21 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config 22 | 23 | # captioning 24 | 25 | # captioning eval tool (java for PTBtokenizer and METEOR) 26 | RUN apt-get install -y --no-install-recommends openjdk-8-jdk && apt-get clean 27 | 28 | # binaries for cococap eval 29 | ARG PYCOCOEVALCAP=https://github.com/tylin/coco-caption/raw/master/pycocoevalcap 30 | RUN mkdir /workspace/cococap_bin/ && \ 31 | wget $PYCOCOEVALCAP/meteor/meteor-1.5.jar -P /workspace/cococap_bin/ && \ 32 | wget $PYCOCOEVALCAP/meteor/data/paraphrase-en.gz -P /workspace/cococap_bin/ && \ 33 | wget $PYCOCOEVALCAP/tokenizer/stanford-corenlp-3.4.1.jar -P /workspace/cococap_bin/ 34 | 35 | # add new command here 36 | 37 | WORKDIR /src 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
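The Dockerfile above bakes Horovod into the image for multi-GPU training, and the dataset classes under data/ further down shard their example ids across processes with a plain rank stride (ids[hvd.rank()::hvd.size()]). A minimal, standalone sketch of that idiom follows; the toy video_ids list and the script name in the launch comment are assumptions for illustration, not part of this repo.

# Sketch (illustration only) of the rank striding used throughout data/*.py.
import horovod.torch as hvd

hvd.init()  # one process per GPU, e.g. `horovodrun -np 4 python shard_demo.py`

video_ids = [f"vid{i:02d}" for i in range(10)]   # toy ids (assumption)
local_ids = video_ids[hvd.rank()::hvd.size()]    # disjoint per-rank shards covering all ids
print(f"rank {hvd.rank()} of {hvd.size()} handles {local_ids}")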
22 | -------------------------------------------------------------------------------- /config/hero_finetune.json: -------------------------------------------------------------------------------- 1 | {"f_config":{ 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 514, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 6, 11 | "type_vocab_size": 2, 12 | "vocab_size": 50272 13 | }, 14 | "c_config": { 15 | "attention_probs_dropout_prob": 0.1, 16 | "hidden_act": "gelu", 17 | "hidden_dropout_prob": 0.1, 18 | "hidden_size": 768, 19 | "initializer_range": 0.02, 20 | "intermediate_size": 3072, 21 | "max_position_embeddings": 514, 22 | "num_attention_heads": 12, 23 | "num_hidden_layers": 3, 24 | "type_vocab_size": 2 25 | }, 26 | "q_config": { 27 | "attention_probs_dropout_prob": 0.1, 28 | "hidden_act": "gelu", 29 | "hidden_dropout_prob": 0.1, 30 | "hidden_size": 768, 31 | "initializer_range": 0.02, 32 | "intermediate_size": 3072, 33 | "num_attention_heads": 12, 34 | "max_position_embeddings": 514, 35 | "num_hidden_layers": 0, 36 | "type_vocab_size": 1, 37 | "vocab_size": 50272 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /config/hero_pretrain.json: -------------------------------------------------------------------------------- 1 | {"f_config":{ 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 514, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 6, 11 | "type_vocab_size": 1, 12 | "vocab_size": 50265 13 | }, 14 | "c_config": { 15 | "attention_probs_dropout_prob": 0.1, 16 | "hidden_act": "gelu", 17 | "hidden_dropout_prob": 0.1, 18 | "hidden_size": 768, 19 | "initializer_range": 0.02, 20 | "intermediate_size": 3072, 21 | "max_position_embeddings": 514, 22 | "num_attention_heads": 12, 23 | "num_hidden_layers": 3, 24 | "type_vocab_size": 2 25 | }, 26 | "q_config": { 27 | "attention_probs_dropout_prob": 0.1, 28 | "hidden_act": "gelu", 29 | "hidden_dropout_prob": 0.1, 30 | "hidden_size": 768, 31 | "initializer_range": 0.02, 32 | "intermediate_size": 3072, 33 | "num_attention_heads": 12, 34 | "max_position_embeddings": 514, 35 | "num_hidden_layers": 0, 36 | "type_vocab_size": 1, 37 | "vocab_size": 50265 38 | } 39 | } -------------------------------------------------------------------------------- /config/hero_tvc.json: -------------------------------------------------------------------------------- 1 | {"f_config":{ 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 514, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 6, 11 | "type_vocab_size": 2, 12 | "vocab_size": 50272 13 | }, 14 | "c_config": { 15 | "attention_probs_dropout_prob": 0.1, 16 | "hidden_act": "gelu", 17 | "hidden_dropout_prob": 0.1, 18 | "hidden_size": 768, 19 | "initializer_range": 0.02, 20 | "intermediate_size": 3072, 21 | "max_position_embeddings": 514, 22 | "num_attention_heads": 12, 23 | "num_hidden_layers": 3, 24 | "type_vocab_size": 2 25 | }, 26 | "d_config": { 27 | "attention_probs_dropout_prob": 0.1, 28 | "hidden_act": "gelu", 29 | "hidden_dropout_prob": 0.1, 30 | 
"hidden_size": 768, 31 | "initializer_range": 0.02, 32 | "intermediate_size": 3072, 33 | "max_position_embeddings": 1024, 34 | "num_attention_heads": 12, 35 | "num_hidden_layers": 2, 36 | "type_vocab_size": 1, 37 | "vocab_size": 50272 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /config/pretrain-tv-16gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "txt_db": "/txt", 3 | "img_db": "/video", 4 | "targets": [ 5 | {"name": "tv", 6 | "sub_txt_db": "tv_subtitles.db", 7 | "vfeat_db": "tv", 8 | "vfeat_interval": 1.5, 9 | "splits": [ 10 | {"name": "all", 11 | "tasks": ["mlm", "mfm-nce", "fom", "vsm"], 12 | "train_idx": "pretrain_splits/tv_train.json", 13 | "val_idx": "pretrain_splits/tv_val.json", 14 | "ratio": [2, 2, 1, 2] 15 | } 16 | ] 17 | } 18 | ], 19 | "targets_ratio": [1], 20 | "mask_prob": 0.15, 21 | "compressed_db": false, 22 | "model_config": "config/hero_pretrain.json", 23 | "checkpoint": "/pretrain/pretrain-tv-init.bin", 24 | "load_partial_pretrained" : true, 25 | "skip_layer_loading" : true, 26 | "output_dir": "/storage/default_pretrain_tv", 27 | "max_clip_len": 100, 28 | "max_txt_len": 60, 29 | "vfeat_version": "resnet_slowfast", 30 | "drop_svmr_prob": 0.8, 31 | "train_batch_size": 32, 32 | "val_batch_size": 32, 33 | "gradient_accumulation_steps": 2, 34 | "learning_rate": 3e-05, 35 | "valid_steps": 500, 36 | "save_steps": 500, 37 | "num_train_steps": 100000, 38 | "optim": "adamw", 39 | "betas": [ 40 | 0.9, 41 | 0.98 42 | ], 43 | "dropout": 0.1, 44 | "weight_decay": 0.01, 45 | "grad_norm": 1.0, 46 | "warmup_steps": 10000, 47 | "lw_neg_q": 8.0, 48 | "lw_neg_ctx": 8.0, 49 | "lw_st_ed": 0.01, 50 | "ranking_loss_type": "hinge", 51 | "margin": 0.1, 52 | "hard_pool_size": [ 53 | 20 54 | ], 55 | "hard_neg_weights": [ 56 | 10 57 | ], 58 | "hard_negtiave_start_step": [ 59 | 20000 60 | ], 61 | "train_span_start_step": 0, 62 | "sub_ctx_len": 0, 63 | "use_all_neg": true, 64 | "seed": 77, 65 | "fp16": true, 66 | "n_workers": 4, 67 | "pin_mem": true, 68 | "rank": 0 69 | } 70 | -------------------------------------------------------------------------------- /config/train-didemo_video_only-4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "didemo_video_only", 3 | "sub_txt_db": null, 4 | "vfeat_db": "/video/didemo", 5 | "train_query_txt_db": "/txt/didemo_train.db", 6 | "val_query_txt_db": "/txt/didemo_val.db", 7 | "test_query_txt_db": "/txt/didemo_test.db", 8 | "compressed_db": false, 9 | "model_config": "config/hero_finetune.json", 10 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 11 | "output_dir": "/storage/didemo_video_only_default", 12 | "max_before_nms": 200, 13 | "max_after_nms": 100, 14 | "distributed_eval": true, 15 | "nms_thd": -1, 16 | "q2c_alpha": 20, 17 | "max_vcmr_video": 100, 18 | "full_eval_tasks": [ 19 | "VCMR", 20 | "SVMR", 21 | "VR" 22 | ], 23 | "max_clip_len": 20, 24 | "max_txt_len": 60, 25 | "vfeat_version": "resnet_slowfast", 26 | "vfeat_interval": 1.5, 27 | "min_pred_l": 3, 28 | "max_pred_l": 6, 29 | "drop_svmr_prob": 0.8, 30 | "train_batch_size": 32, 31 | "val_batch_size": 20, 32 | "vcmr_eval_video_batch_size": 50, 33 | "vcmr_eval_batch_size": 80, 34 | "gradient_accumulation_steps": 1, 35 | "learning_rate": 7e-05, 36 | "valid_steps": 200, 37 | "save_steps": 200, 38 | "num_train_steps": 5000, 39 | "optim": "adamw", 40 | "betas": [ 41 | 0.9, 42 | 0.98 43 | ], 44 | "dropout": 0.1, 45 | "weight_decay": 0.01, 46 | 
"grad_norm": 1.0, 47 | "warmup_steps": 500, 48 | "lw_neg_q": 10.0, 49 | "lw_neg_ctx": 10.0, 50 | "lw_st_ed": 0.01, 51 | "ranking_loss_type": "hinge", 52 | "margin": 0.1, 53 | "hard_pool_size": [ 54 | 80 55 | ], 56 | "hard_neg_weights": [ 57 | 10 58 | ], 59 | "hard_negtiave_start_step": [ 60 | 2000 61 | ], 62 | "train_span_start_step": 0, 63 | "use_all_neg": true, 64 | "seed": 77, 65 | "fp16": true, 66 | "n_workers": 4, 67 | "pin_mem": true, 68 | "rank": 0 69 | } 70 | -------------------------------------------------------------------------------- /config/train-didemo_video_sub-8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "didemo_video_sub", 3 | "sub_txt_db": "/txt/didemo_subtitles.db", 4 | "vfeat_db": "/video/didemo", 5 | "train_query_txt_db": "/txt/didemo_train.db", 6 | "val_query_txt_db": "/txt/didemo_val.db", 7 | "test_query_txt_db": "/txt/didemo_test.db", 8 | "compressed_db": false, 9 | "model_config": "config/hero_finetune.json", 10 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 11 | "output_dir": "/storage/didemo_video_sub_default", 12 | "eval_with_query_type": false, 13 | "max_before_nms": 200, 14 | "max_after_nms": 100, 15 | "distributed_eval": true, 16 | "nms_thd": -1, 17 | "q2c_alpha": 20, 18 | "max_vcmr_video": 100, 19 | "full_eval_tasks": [ 20 | "VCMR", 21 | "SVMR", 22 | "VR" 23 | ], 24 | "max_clip_len": 100, 25 | "max_txt_len": 60, 26 | "vfeat_version": "resnet_slowfast", 27 | "vfeat_interval": 1.5, 28 | "min_pred_l": 2, 29 | "max_pred_l": 16, 30 | "drop_svmr_prob": 0.8, 31 | "train_batch_size": 32, 32 | "val_batch_size": 20, 33 | "vcmr_eval_video_batch_size": 50, 34 | "vcmr_eval_batch_size": 80, 35 | "gradient_accumulation_steps":2, 36 | "learning_rate": 0.0001, 37 | "valid_steps": 200, 38 | "save_steps": 200, 39 | "num_train_steps": 5000, 40 | "optim": "adamw", 41 | "betas": [ 42 | 0.9, 43 | 0.98 44 | ], 45 | "dropout": 0.1, 46 | "weight_decay": 0.01, 47 | "grad_norm": 1.0, 48 | "warmup_steps": 500, 49 | "lw_neg_q": 10.0, 50 | "lw_neg_ctx": 10.0, 51 | "lw_st_ed": 0.01, 52 | "ranking_loss_type": "hinge", 53 | "margin": 0.1, 54 | "hard_pool_size": [ 55 | 80 56 | ], 57 | "hard_neg_weights": [ 58 | 10 59 | ], 60 | "hard_negtiave_start_step": [ 61 | 2000 62 | ], 63 | "train_span_start_step": 0, 64 | "sub_ctx_len": 0, 65 | "use_all_neg": true, 66 | "seed": 77, 67 | "fp16": true, 68 | "n_workers": 4, 69 | "pin_mem": true, 70 | "rank": 0 71 | } 72 | -------------------------------------------------------------------------------- /config/train-msrvtt_video_only-4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "msrvtt_video_only", 3 | "sub_txt_db": null, 4 | "vfeat_db": "/video/msrvtt", 5 | "train_query_txt_db": "/txt/msrvtt_train.db", 6 | "val_query_txt_db": "/txt/msrvtt_val.db", 7 | "test_query_txt_db": "/txt/msrvtt_test.db", 8 | "compressed_db": false, 9 | "model_config": "config/hero_finetune.json", 10 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 11 | "output_dir": "/storage/msrvtt_video_only_default", 12 | "distributed_eval": true, 13 | "max_vr_video": 100, 14 | "max_clip_len": 100, 15 | "max_txt_len": 60, 16 | "vfeat_version": "resnet_slowfast", 17 | "vfeat_interval": 2, 18 | "train_batch_size": 96, 19 | "val_batch_size": 20, 20 | "vr_eval_video_batch_size": 50, 21 | "vr_eval_q_batch_size": 80, 22 | "gradient_accumulation_steps": 2, 23 | "learning_rate": 7e-05, 24 | "valid_steps": 200, 25 | "save_steps": 200, 26 | "num_train_steps": 4000, 27 | "optim": 
"adamw", 28 | "betas": [ 29 | 0.9, 30 | 0.98 31 | ], 32 | "dropout": 0.1, 33 | "weight_decay": 0.01, 34 | "grad_norm": 1.0, 35 | "warmup_steps": 400, 36 | "lw_neg_q": 10.0, 37 | "lw_neg_ctx": 10.0, 38 | "ranking_loss_type": "hinge", 39 | "margin": 0.1, 40 | "hard_pool_size": [ 41 | 80 42 | ], 43 | "hard_neg_weights": [ 44 | 10 45 | ], 46 | "hard_negtiave_start_step": [ 47 | 2000 48 | ], 49 | "use_all_neg": true, 50 | "seed": 77, 51 | "fp16": true, 52 | "n_workers": 4, 53 | "pin_mem": true, 54 | "rank": 0 55 | } 56 | -------------------------------------------------------------------------------- /config/train-msrvtt_video_sub-4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "msrvtt_video_sub", 3 | "sub_txt_db": "/txt/msrvtt_subtitles.db", 4 | "vfeat_db": "/video/msrvtt", 5 | "train_query_txt_db": "/txt/msrvtt_train.db", 6 | "val_query_txt_db": "/txt/msrvtt_val.db", 7 | "test_query_txt_db": "/txt/msrvtt_test.db", 8 | "compressed_db": false, 9 | "model_config": "config/hero_finetune.json", 10 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 11 | "output_dir": "/storage/msrvtt_video_sub_default", 12 | "distributed_eval": true, 13 | "max_vr_video": 100, 14 | "max_clip_len": 100, 15 | "max_txt_len": 60, 16 | "vfeat_version": "resnet_slowfast", 17 | "vfeat_interval": 2, 18 | "train_batch_size": 96, 19 | "val_batch_size": 20, 20 | "vr_eval_video_batch_size": 50, 21 | "vr_eval_q_batch_size": 80, 22 | "gradient_accumulation_steps": 2, 23 | "learning_rate": 7e-05, 24 | "valid_steps": 200, 25 | "save_steps": 200, 26 | "num_train_steps": 4000, 27 | "optim": "adamw", 28 | "betas": [ 29 | 0.9, 30 | 0.98 31 | ], 32 | "dropout": 0.1, 33 | "weight_decay": 0.01, 34 | "grad_norm": 1.0, 35 | "warmup_steps": 400, 36 | "lw_neg_q": 10.0, 37 | "lw_neg_ctx": 10.0, 38 | "ranking_loss_type": "hinge", 39 | "margin": 0.1, 40 | "hard_pool_size": [ 41 | 80 42 | ], 43 | "hard_neg_weights": [ 44 | 10 45 | ], 46 | "hard_negtiave_start_step": [ 47 | 2000 48 | ], 49 | "use_all_neg": true, 50 | "sub_ctx_len": 1, 51 | "seed": 77, 52 | "fp16": true, 53 | "n_workers": 4, 54 | "pin_mem": true, 55 | "rank": 0 56 | } 57 | -------------------------------------------------------------------------------- /config/train-tvc-8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "sub_txt_db": "/txt/tv_subtitles.db", 3 | "vfeat_db": "/video/tv", 4 | "train_db": "/txt/tvc_train.db", 5 | "val_db": "/txt/tvc_val.db", 6 | "val_ref": "/txt/tvc_val_release.jsonl", 7 | "model_config": "/src/config/hero_tvc.json", 8 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 9 | "output_dir": "/storage/tvc_default", 10 | "max_clip_len": 100, 11 | "max_txt_len": 60, 12 | "max_cap_per_vid": -1, 13 | "max_gen_step": 30, 14 | "vfeat_version": "resnet_slowfast", 15 | "vfeat_interval": 1.5, 16 | "compressed_db": false, 17 | "train_batch_size": 4, 18 | "val_batch_size": 8, 19 | "gradient_accumulation_steps": 1, 20 | "learning_rate": 1e-4, 21 | "lr_mul": 10.0, 22 | "valid_steps": 500, 23 | "num_train_steps": 7000, 24 | "optim": "adamw", 25 | "betas": [0.9, 0.98], 26 | "lsr": 0.1, 27 | "dropout": 0.1, 28 | "weight_decay": 0.01, 29 | "grad_norm": 1.0, 30 | "warmup_steps": 700, 31 | "sub_ctx_len": 1, 32 | "seed": 77, 33 | "fp16": true, 34 | "n_workers": 4, 35 | "pin_mem": true 36 | } 37 | -------------------------------------------------------------------------------- /config/train-tvqa-8gpu.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "task": "tvqa", 3 | "sub_txt_db": "/txt/tv_subtitles.db", 4 | "vfeat_db": "/video/tv", 5 | "train_query_txt_db": "/txt/tvqa_train.db", 6 | "val_query_txt_db": "/txt/tvqa_val.db", 7 | "test_query_txt_db": null, 8 | "compressed_db": false, 9 | "model_config": "config/hero_finetune.json", 10 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 11 | "output_dir": "/storage/tvqa_default", 12 | "max_clip_len": 100, 13 | "max_txt_len": 120, 14 | "vfeat_version": "resnet_slowfast", 15 | "vfeat_interval": 1.5, 16 | "train_batch_size": 4, 17 | "val_batch_size": 10, 18 | "gradient_accumulation_steps": 2, 19 | "learning_rate": 5e-05, 20 | "valid_steps": 200, 21 | "save_steps": 200, 22 | "num_train_steps": 10000, 23 | "optim": "adamw", 24 | "betas": [ 25 | 0.9, 26 | 0.98 27 | ], 28 | "dropout": 0.1, 29 | "weight_decay": 0.01, 30 | "lr_mul": 10.0, 31 | "grad_norm": 1.0, 32 | "warmup_steps": 1000, 33 | "lw_st_ed": 0.4, 34 | "sub_ctx_len": 0, 35 | "seed": 77, 36 | "fp16": true, 37 | "n_workers": 4, 38 | "pin_mem": true, 39 | "rank": 0 40 | } 41 | -------------------------------------------------------------------------------- /config/train-tvr-8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "tvr", 3 | "sub_txt_db": "/txt/tv_subtitles.db", 4 | "vfeat_db": "/video/tv", 5 | "train_query_txt_db": "/txt/tvr_train.db", 6 | "val_query_txt_db": "/txt/tvr_val.db", 7 | "test_query_txt_db": null, 8 | "compressed_db": false, 9 | "model_config": "config/hero_finetune.json", 10 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 11 | "output_dir": "/storage/tvr_default", 12 | "eval_with_query_type": true, 13 | "max_before_nms": 200, 14 | "max_after_nms": 100, 15 | "distributed_eval": true, 16 | "nms_thd": 0.5, 17 | "q2c_alpha": 20, 18 | "max_vcmr_video": 100, 19 | "full_eval_tasks": [ 20 | "VCMR", 21 | "SVMR", 22 | "VR" 23 | ], 24 | "max_clip_len": 100, 25 | "max_txt_len": 60, 26 | "vfeat_version": "resnet_slowfast", 27 | "vfeat_interval": 1.5, 28 | "min_pred_l": 2, 29 | "max_pred_l": 16, 30 | "drop_svmr_prob": 0.8, 31 | "train_batch_size": 32, 32 | "val_batch_size": 20, 33 | "vcmr_eval_video_batch_size": 50, 34 | "vcmr_eval_batch_size": 80, 35 | "gradient_accumulation_steps":2, 36 | "learning_rate": 1e-04, 37 | "valid_steps": 200, 38 | "save_steps": 200, 39 | "num_train_steps": 5000, 40 | "optim": "adamw", 41 | "betas": [ 42 | 0.9, 43 | 0.98 44 | ], 45 | "dropout": 0.1, 46 | "weight_decay": 0.01, 47 | "grad_norm": 1.0, 48 | "warmup_steps": 500, 49 | "lw_neg_q": 8.0, 50 | "lw_neg_ctx": 8.0, 51 | "lw_st_ed": 0.01, 52 | "ranking_loss_type": "hinge", 53 | "margin": 0.1, 54 | "hard_pool_size": [ 55 | 20 56 | ], 57 | "hard_neg_weights": [ 58 | 10 59 | ], 60 | "hard_negtiave_start_step": [ 61 | 2000 62 | ], 63 | "train_span_start_step": 0, 64 | "sub_ctx_len": 0, 65 | "use_all_neg": true, 66 | "seed": 77, 67 | "fp16": true, 68 | "n_workers": 4, 69 | "pin_mem": true, 70 | "rank": 0 71 | } 72 | -------------------------------------------------------------------------------- /config/train-violin-8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "violin", 3 | "sub_txt_db": "/txt/violin_subtitles.db", 4 | "vfeat_db": "/video/violin", 5 | "train_query_txt_db": "/txt/violin_train.db", 6 | "val_query_txt_db": "/txt/violin_val.db", 7 | "test_query_txt_db": "/txt/violin_test.db", 8 | "compressed_db": true, 9 | "model_config": 
"config/hero_finetune.json", 10 | "checkpoint": "/pretrain/hero-tv-ht100.pt", 11 | "output_dir": "/storage/violin_default", 12 | "max_clip_len": 100, 13 | "max_txt_len": 120, 14 | "vfeat_version": "resnet_slowfast", 15 | "vfeat_interval": 1.5, 16 | "train_batch_size": 4, 17 | "val_batch_size": 10, 18 | "gradient_accumulation_steps": 2, 19 | "learning_rate": 3e-05, 20 | "valid_steps": 200, 21 | "save_steps": 200, 22 | "num_train_steps": 6000, 23 | "optim": "adamw", 24 | "betas": [ 25 | 0.9, 26 | 0.98 27 | ], 28 | "dropout": 0.1, 29 | "weight_decay": 0.01, 30 | "lr_mul": 8.0, 31 | "grad_norm": 1.0, 32 | "warmup_steps": 600, 33 | "sub_ctx_len": 2, 34 | "seed": 77, 35 | "fp16": true, 36 | "n_workers": 4, 37 | "pin_mem": true, 38 | "rank": 0 39 | } 40 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | """ 6 | from .data import ( 7 | TxtTokLmdb, VideoFeatLmdb, SubTokLmdb, 8 | QueryTokLmdb, VideoFeatSubTokDataset, video_collate, 9 | QaQueryTokLmdb) 10 | from .loader import PrefetchLoader, MetaLoader 11 | from .vcmr import ( 12 | VcmrDataset, vcmr_collate, VcmrEvalDataset, vcmr_eval_collate, 13 | VcmrFullEvalDataset, vcmr_full_eval_collate) 14 | from .vcmr_video_only import ( 15 | VcmrVideoOnlyDataset, VcmrVideoOnlyEvalDataset, 16 | VcmrVideoOnlyFullEvalDataset) 17 | from .vr_video_only import ( 18 | VideoFeatDataset, 19 | VrVideoOnlyDataset, VrVideoOnlyEvalDataset, 20 | VrVideoOnlyFullEvalDataset) 21 | from .vr import ( 22 | VrDataset, VrEvalDataset, VrSubTokLmdb, VrQueryTokLmdb, 23 | MsrvttQueryTokLmdb, 24 | VrFullEvalDataset, vr_collate, vr_eval_collate, 25 | vr_full_eval_collate) 26 | from .videoQA import ( 27 | VideoQaDataset, video_qa_collate, 28 | VideoQaEvalDataset, video_qa_eval_collate) 29 | from .violin import ( 30 | ViolinDataset, violin_collate, 31 | ViolinEvalDataset, violin_eval_collate) 32 | from .fom import ( 33 | FomDataset, fom_collate, 34 | FomEvalDataset, fom_eval_collate) 35 | from .vsm import VsmDataset, vsm_collate 36 | from .mlm import ( 37 | VideoMlmDataset, mlm_collate) 38 | from .mfm import MfmDataset, mfm_collate 39 | from .tvc import TvcTrainDataset, TvcValDataset, CaptionTokLmdb 40 | -------------------------------------------------------------------------------- /data/fom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Pretrain FOM dataset 6 | """ 7 | import copy 8 | import random 9 | 10 | from torch.utils.data import Dataset 11 | import torch 12 | from toolz.sandbox import unzip 13 | import horovod.torch as hvd 14 | 15 | from .data import VideoFeatSubTokDataset, _check_ngpu, video_collate 16 | 17 | 18 | class FomDataset(Dataset): 19 | def __init__(self, video_ids, vid_sub_db, random_reorder_p=0.15): 20 | assert isinstance(vid_sub_db, VideoFeatSubTokDataset) 21 | self.vid_sub_db = vid_sub_db 22 | if _check_ngpu() > 1: 23 | self.ids = video_ids[hvd.rank()::hvd.size()] 24 | else: 25 | self.ids = video_ids 26 | self.random_reorder_p = random_reorder_p 27 | 28 | def __len__(self): 29 | return len(self.ids) 30 | 31 | def __getitem__(self, i): 32 | vid_ = self.ids[i] 33 | (f_sub_input_ids, f_v_feats, f_attn_masks, 34 | c_v_feats, c_attn_masks, 35 | num_subs, sub2frames) = self.vid_sub_db[vid_] 36 | c_pos_ids = [i for i in range(len(c_v_feats))] 37 | # Random shuffle 15% of pos_ids 38 | orders, targets = random_reorder( 39 | list(range(len(c_pos_ids))), self.random_reorder_p) 40 | orders = torch.tensor(orders, dtype=torch.long) 41 | targets = torch.tensor(targets, dtype=torch.long) 42 | video_inputs = ( 43 | f_sub_input_ids, f_v_feats, f_attn_masks, 44 | c_v_feats, c_attn_masks, 45 | num_subs, sub2frames) 46 | out = (video_inputs, orders, targets) 47 | return out 48 | 49 | 50 | def fom_collate(inputs): 51 | (video_inputs, orders, targets) = map(list, unzip(inputs)) 52 | batch = video_collate(video_inputs) 53 | 54 | clip_level_v_feats = batch["c_v_feats"] 55 | num_frames = [item.size(0) for item in orders] 56 | 57 | all_orders = torch.arange( 58 | 0, clip_level_v_feats.size(1), dtype=torch.long).unsqueeze(0).repeat( 59 | clip_level_v_feats.size(0), 1) 60 | all_targets = torch.ones_like(all_orders) * -1 61 | for i, nframe in enumerate(num_frames): 62 | all_orders[i, :nframe] = orders[i] 63 | all_targets[i, :nframe] = targets[i] 64 | reordered_frame_idx = [] 65 | binary_targets = [] 66 | bs, max_vl = all_orders.size() 67 | for clip_idx in range(bs): 68 | for i in range(num_frames[clip_idx]): 69 | if all_targets[clip_idx, i] == -1: 70 | continue 71 | for j in range(i+1, num_frames[clip_idx]): 72 | if all_targets[clip_idx, j] == -1: 73 | continue 74 | reordered_frame_idx.append(clip_idx*max_vl+i) 75 | reordered_frame_idx.append(clip_idx*max_vl+j) 76 | if all_targets[clip_idx, i] > all_targets[clip_idx, j]: 77 | binary_targets.append(0) 78 | else: 79 | binary_targets.append(1) 80 | 81 | reordered_frame_idx.append(clip_idx*max_vl+j) 82 | reordered_frame_idx.append(clip_idx*max_vl+i) 83 | if all_targets[clip_idx, j] > all_targets[clip_idx, i]: 84 | binary_targets.append(0) 85 | else: 86 | binary_targets.append(1) 87 | reordered_frame_idx = torch.tensor(reordered_frame_idx, dtype=torch.long) 88 | binary_targets = torch.tensor(binary_targets, dtype=torch.long) 89 | batch["shuffled_orders"] = all_orders 90 | batch["targets"] = all_targets 91 | batch['reordered_frame_idx'] = reordered_frame_idx 92 | batch['binary_targets'] = binary_targets 93 | return batch 94 | 95 | 96 | def random_reorder(pos_ids, random_reorder_p=0.15): 97 | """ 98 | random reorder frame positions 99 | """ 100 | selected_pos = [] 101 | target_pos = [] 102 | for i, pos_id in enumerate(pos_ids): 103 | prob = random.random() 104 | # mask token with 15% probability 105 | if prob < random_reorder_p: 106 | selected_pos.append(i) 107 | target_pos.append(pos_id) 108 | target_pos_shuffled = copy.deepcopy(target_pos) 109 | 
random.shuffle(target_pos_shuffled) 110 | output_order = copy.deepcopy(pos_ids) 111 | output_target = [-1] * len(output_order) 112 | for i, pos in enumerate(selected_pos): 113 | output_order[pos] = target_pos_shuffled[i] 114 | output_target[target_pos_shuffled[i]] = pos 115 | return output_order, output_target 116 | 117 | 118 | class FomEvalDataset(FomDataset): 119 | def __getitem__(self, i): 120 | vid = self.ids[i] 121 | tensors = super().__getitem__(i) 122 | return (vid, *tensors) 123 | 124 | 125 | def fom_eval_collate(inputs): 126 | vids, batch = [], [] 127 | for id_, *tensors in inputs: 128 | vids.append(id_) 129 | batch.append(tensors) 130 | batch = fom_collate(batch) 131 | batch['vids'] = vids 132 | return batch 133 | -------------------------------------------------------------------------------- /data/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | A meta data loader for sampling from different datasets / training tasks 6 | 7 | A prefetch loader to speedup data loading 8 | Modified from Nvidia Deep Learning Examples 9 | (https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). 10 | """ 11 | import random 12 | 13 | import torch 14 | from torch.utils.data import DataLoader 15 | 16 | from utils.distributed import any_broadcast 17 | 18 | 19 | class MetaLoader(object): 20 | """ wraps multiple data loader """ 21 | def __init__(self, loaders, accum_steps=1, distributed=False): 22 | assert isinstance(loaders, dict) 23 | self.name2loader = {} 24 | self.name2iter = {} 25 | self.sampling_pools = [] 26 | for n, l in loaders.items(): 27 | if isinstance(l, tuple): 28 | l, r = l 29 | elif isinstance(l, DataLoader): 30 | r = 1 31 | else: 32 | raise ValueError() 33 | self.name2loader[n] = l 34 | self.name2iter[n] = iter(l) 35 | self.sampling_pools.extend([n]*r) 36 | 37 | self.accum_steps = accum_steps 38 | self.distributed = distributed 39 | self.step = 0 40 | 41 | def __iter__(self): 42 | """ this iterator will run indefinitely """ 43 | task = self.sampling_pools[0] 44 | while True: 45 | if self.step % self.accum_steps == 0: 46 | task = random.choice(self.sampling_pools) 47 | if self.distributed: 48 | # make sure all process is training same task 49 | task = any_broadcast(task, 0) 50 | self.step += 1 51 | iter_ = self.name2iter[task] 52 | try: 53 | batch = next(iter_) 54 | except StopIteration: 55 | iter_ = iter(self.name2loader[task]) 56 | batch = next(iter_) 57 | self.name2iter[task] = iter_ 58 | 59 | yield task, batch 60 | 61 | 62 | def move_to_cuda(batch): 63 | if isinstance(batch, torch.Tensor): 64 | return batch.cuda(non_blocking=True) 65 | elif isinstance(batch, list): 66 | new_batch = [move_to_cuda(t) for t in batch] 67 | elif isinstance(batch, tuple): 68 | new_batch = tuple(move_to_cuda(t) for t in batch) 69 | elif isinstance(batch, dict): 70 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()} 71 | else: 72 | return batch 73 | return new_batch 74 | 75 | 76 | def record_cuda_stream(batch): 77 | if isinstance(batch, torch.Tensor): 78 | batch.record_stream(torch.cuda.current_stream()) 79 | elif isinstance(batch, list) or isinstance(batch, tuple): 80 | for t in batch: 81 | record_cuda_stream(t) 82 | elif isinstance(batch, dict): 83 | for t in batch.values(): 84 | record_cuda_stream(t) 85 | else: 86 | pass 87 | 88 | 89 | class PrefetchLoader(object): 90 | """ 91 | overlap compute and cuda data transfer 92 | (copied and then modified from 
nvidia apex) 93 | """ 94 | def __init__(self, loader): 95 | self.loader = loader 96 | self.stream = torch.cuda.Stream() 97 | 98 | def __iter__(self): 99 | loader_it = iter(self.loader) 100 | self.preload(loader_it) 101 | batch = self.next(loader_it) 102 | while batch is not None: 103 | yield batch 104 | batch = self.next(loader_it) 105 | 106 | def __len__(self): 107 | return len(self.loader) 108 | 109 | def preload(self, it): 110 | try: 111 | self.batch = next(it) 112 | except StopIteration: 113 | self.batch = None 114 | return 115 | # if record_stream() doesn't work, another option is to make sure 116 | # device inputs are created on the main stream. 117 | # self.next_input_gpu = torch.empty_like(self.next_input, 118 | # device='cuda') 119 | # self.next_target_gpu = torch.empty_like(self.next_target, 120 | # device='cuda') 121 | # Need to make sure the memory allocated for next_* is not still in use 122 | # by the main stream at the time we start copying to next_*: 123 | # self.stream.wait_stream(torch.cuda.current_stream()) 124 | with torch.cuda.stream(self.stream): 125 | self.batch = move_to_cuda(self.batch) 126 | # more code for the alternative if record_stream() doesn't work: 127 | # copy_ will record the use of the pinned source tensor in this 128 | # side stream. 129 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 130 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 131 | # self.next_input = self.next_input_gpu 132 | # self.next_target = self.next_target_gpu 133 | 134 | def next(self, it): 135 | torch.cuda.current_stream().wait_stream(self.stream) 136 | batch = self.batch 137 | if batch is not None: 138 | record_cuda_stream(batch) 139 | self.preload(it) 140 | return batch 141 | 142 | def __getattr__(self, name): 143 | method = self.loader.__getattribute__(name) 144 | return method 145 | -------------------------------------------------------------------------------- /data/mfm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Pretrain MFM dataset 6 | """ 7 | import random 8 | 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from torch.utils.data import Dataset 12 | from toolz.sandbox import unzip 13 | from cytoolz import concat 14 | import horovod.torch as hvd 15 | 16 | from .data import VideoFeatSubTokDataset, video_collate, _check_ngpu 17 | 18 | 19 | def _get_img_mask(mask_prob, num_frame): 20 | img_mask = [random.random() < mask_prob for _ in range(num_frame)] 21 | if not any(img_mask): 22 | # at least mask 1 23 | img_mask[random.choice(range(num_frame))] = True 24 | img_mask = torch.tensor(img_mask) 25 | return img_mask 26 | 27 | 28 | def _get_feat_target(img_feat, img_masks): 29 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d) 30 | feat_dim = img_feat.size(-1) 31 | feat_targets = img_feat[img_masks_ext].contiguous().view( 32 | -1, feat_dim) # (s, d) 33 | return feat_targets 34 | 35 | 36 | def _mask_img_feat(img_feat, img_masks): 37 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) 38 | img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) 39 | return img_feat_masked 40 | 41 | 42 | class MfmDataset(Dataset): 43 | def __init__(self, video_ids, vid_sub_db, mask_prob=0.15): 44 | assert isinstance(vid_sub_db, VideoFeatSubTokDataset) 45 | self.mask_prob = mask_prob 46 | self.vid_sub_db = vid_sub_db 47 | if _check_ngpu() > 1: 48 | self.ids = video_ids[hvd.rank()::hvd.size()] 49 | else: 50 | self.ids = video_ids 51 | 52 | def __len__(self): 53 | return len(self.ids) 54 | 55 | def __getitem__(self, i): 56 | vid = self.ids[i] 57 | (all_input_ids, f_v_feats, f_attn_masks, 58 | c_v_feats, c_attn_masks, 59 | num_subs, sub2frames) = self.vid_sub_db[vid] 60 | 61 | c_frame_mask = _get_img_mask(self.mask_prob, c_v_feats.size(0)) 62 | frame_masks = [] 63 | for i, frames in sub2frames: 64 | if len(frames): 65 | frame_masks.append( 66 | c_frame_mask.index_select(0, torch.tensor(frames))) 67 | else: 68 | frame_masks.append(torch.zeros(1, dtype=torch.bool)) 69 | c_pos_ids = torch.tensor(range(len(c_v_feats)), dtype=torch.long) 70 | c_frame_mask = c_frame_mask.index_select(0, c_pos_ids) 71 | return ((all_input_ids, f_v_feats, f_attn_masks, 72 | c_v_feats, c_attn_masks, 73 | num_subs, sub2frames), 74 | frame_masks, c_frame_mask) 75 | 76 | 77 | def mfm_collate(inputs): 78 | video_inputs, all_frame_masks, c_frame_masks = map(list, unzip(inputs)) 79 | batch = video_collate(video_inputs) 80 | 81 | # mask features 82 | frame_masks = pad_sequence(list(concat(all_frame_masks)), 83 | batch_first=True, padding_value=0) 84 | c_frame_masks = pad_sequence(c_frame_masks, 85 | batch_first=True, padding_value=0) 86 | f_v_feats = batch['f_v_feats'] 87 | f_v_feats = _mask_img_feat(f_v_feats, frame_masks) 88 | c_v_feats = batch['c_v_feats'] 89 | feat_targets = _get_feat_target(c_v_feats, c_frame_masks) 90 | c_v_feats = _mask_img_feat(c_v_feats, c_frame_masks) 91 | 92 | batch['f_v_feats'] = f_v_feats 93 | batch['f_v_masks'] = frame_masks 94 | batch['c_v_feats'] = c_v_feats 95 | batch['c_v_masks'] = c_frame_masks 96 | batch['feat_targets'] = feat_targets 97 | return batch 98 | -------------------------------------------------------------------------------- /data/mlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Pretrain MLM dataset 6 | """ 7 | import random 8 | 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from torch.utils.data import Dataset 12 | from toolz.sandbox import unzip 13 | from cytoolz import concat 14 | import horovod.torch as hvd 15 | import copy 16 | 17 | from .data import (VideoFeatSubTokDataset, 18 | pad_tensors, get_gather_index, _check_ngpu) 19 | 20 | 21 | def random_word(tokens, vocab_range, mask, mask_prob=0.15): 22 | """ 23 | Masking some random tokens for Language Model task with probabilities as in 24 | the original BERT paper. 25 | :param tokens: list of int, tokenized sentence. 26 | :param vocab_range: for choosing a random word 27 | :return: (list of int, list of int), masked tokens and related labels for 28 | LM prediction 29 | """ 30 | output_label = [] 31 | 32 | for i, token in enumerate(tokens): 33 | prob = random.random() 34 | # mask token with 15% probability 35 | if prob < mask_prob: 36 | prob /= mask_prob 37 | 38 | # 80% randomly change token to mask token 39 | if prob < 0.8: 40 | tokens[i] = mask 41 | 42 | # 10% randomly change token to random token 43 | elif prob < 0.9: 44 | tokens[i] = random.choice(list(range(*vocab_range))) 45 | 46 | # -> rest 10% randomly keep current token 47 | 48 | # append current token to output (we will predict these later) 49 | output_label.append(token) 50 | else: 51 | # no masking token (will be ignored by loss function later) 52 | output_label.append(-1) 53 | if all(o == -1 for o in output_label): 54 | # at least mask 1 55 | output_label[0] = tokens[0] 56 | tokens[0] = mask 57 | 58 | return tokens, output_label 59 | 60 | 61 | def _get_txt_tgt_mask(txt_mask, n_frame): 62 | z = torch.zeros(n_frame, dtype=torch.bool) 63 | txt_mask_tgt = torch.cat([z, txt_mask], dim=0) 64 | return txt_mask_tgt 65 | 66 | 67 | def create_mlm_io(input_ids, db, mask_prob, cls_tok=True): 68 | input_ids, txt_labels = random_word( 69 | input_ids, db.v_range, db.mask, mask_prob) 70 | if cls_tok: 71 | input_ids = [db.cls_] + input_ids 72 | else: 73 | input_ids = [db.sep] + input_ids 74 | txt_labels = torch.tensor([-1] + txt_labels) 75 | return input_ids, txt_labels 76 | 77 | 78 | class VideoMlmDataset(Dataset): 79 | def __init__(self, video_ids, vid_sub_db, mask_prob=0.15, 80 | sub_ctx_len=0): 81 | assert isinstance(vid_sub_db, VideoFeatSubTokDataset) 82 | self.mask_prob = mask_prob 83 | self.vid_sub_db = vid_sub_db 84 | if _check_ngpu() > 1: 85 | self.ids = video_ids[hvd.rank()::hvd.size()] 86 | else: 87 | self.ids = video_ids 88 | self.sub_ctx_len = sub_ctx_len 89 | 90 | def __len__(self): 91 | return len(self.ids) 92 | 93 | def __getitem__(self, i): 94 | vid = self.ids[i] 95 | example = self.vid_sub_db.txt_db[vid] 96 | v_feat, nframes = self.vid_sub_db._get_v_feat(vid) 97 | sub2frames = self.vid_sub_db.vid_sub2frame[vid] 98 | num_subs = len(sub2frames) 99 | outputs = [] 100 | for sub_idx, matched_frames in sub2frames: 101 | # text input 102 | orig_input_ids = [] 103 | for tmp_sub_idx in range(sub_idx-self.sub_ctx_len, 104 | sub_idx+1): 105 | if tmp_sub_idx >= 0 and tmp_sub_idx < num_subs: 106 | in_ids = example['input_ids'][tmp_sub_idx] 107 | if self.vid_sub_db.max_txt_len != -1: 108 | in_ids = in_ids[:self.vid_sub_db.max_txt_len] 109 | orig_input_ids.extend(copy.deepcopy(in_ids)) 110 | input_ids, txt_labels = create_mlm_io( 111 | orig_input_ids, self.vid_sub_db.txt_db, 112 | self.mask_prob) 113 | 114 | # video input 115 | n_frame = len(matched_frames) 116 | if n_frame: 117 | matched_v_feats = torch.index_select( 118 | 
v_feat, 0, torch.tensor(matched_frames)) 119 | attn_masks = torch.ones(len(input_ids) + n_frame, 120 | dtype=torch.long) 121 | txt_mask_tgt = _get_txt_tgt_mask(txt_labels != -1, n_frame) 122 | else: 123 | matched_v_feats = torch.zeros(1, v_feat.shape[1]) 124 | attn_masks = torch.ones(len(input_ids) + 1, dtype=torch.long) 125 | attn_masks.data[0].fill_(0) 126 | txt_mask_tgt = _get_txt_tgt_mask(txt_labels != -1, 1) 127 | input_ids = torch.tensor(input_ids) 128 | outputs.append((input_ids, matched_v_feats, attn_masks, 129 | txt_mask_tgt, txt_labels)) 130 | 131 | return outputs 132 | 133 | 134 | def mlm_collate(inputs): 135 | """ 136 | Return: 137 | :input_ids (n, max_L) padded with 0 138 | :position_ids (n, max_L) padded with 0 139 | :img_feat (n, max_num_bb, feat_dim) 140 | :img_pos_feat (n, max_num_bb, 7) 141 | :attn_masks (n, max_{L + num_bb}) padded with 0 142 | :txt_labels (n, max_L) padded with -1 143 | """ 144 | (input_ids, v_feats, attn_masks, txt_masks, txt_labels 145 | ) = map(list, unzip(concat(inputs))) 146 | 147 | # text batches 148 | txt_lens = [i.size(0) for i in input_ids] 149 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1) 150 | txt_mask_tgt = pad_sequence(txt_masks, batch_first=True, padding_value=0) 151 | txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) 152 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 153 | ).unsqueeze(0) 154 | 155 | # image batches 156 | num_fs = [f.size(0) for f in v_feats] 157 | v_feat = pad_tensors(v_feats, num_fs) 158 | 159 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 160 | 161 | bs, max_vl, _ = v_feat.size() 162 | out_size = attn_masks.size(1) 163 | if max_vl > 0: 164 | gather_index = get_gather_index(txt_lens, num_fs, bs, max_vl, out_size) 165 | else: 166 | gather_index = None 167 | v_feat = None 168 | 169 | batch = {'input_ids': input_ids, 170 | 'position_ids': position_ids, 171 | 'v_feat': v_feat, 172 | 'attn_masks': attn_masks, 173 | 'gather_index': gather_index, 174 | 'txt_mask_tgt': txt_mask_tgt, 175 | 'txt_labels': txt_labels[txt_labels != -1]} 176 | return batch 177 | -------------------------------------------------------------------------------- /data/vcmr_video_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | VCMR video-only dataset 6 | """ 7 | 8 | import torch 9 | import horovod.torch as hvd 10 | from .data import (QueryTokLmdb, get_ids_and_lens, _check_ngpu) 11 | from .vr_video_only import VideoFeatDataset 12 | from .vcmr import VcmrDataset 13 | 14 | 15 | class VcmrVideoOnlyDataset(VcmrDataset): 16 | def __init__(self, video_ids, video_db, query_db, max_num_query=5, 17 | sampled_by_q=True): 18 | assert isinstance(query_db, QueryTokLmdb) 19 | assert isinstance(video_db, VideoFeatDataset) 20 | self.video_db = video_db 21 | self.query_db = query_db 22 | self.vid2dur = self.video_db.vid2dur 23 | self.vids = video_ids 24 | self.global_vid2idx = video_db.vid2idx 25 | self.vid2idx = { 26 | vid_name: self.global_vid2idx[vid_name] 27 | for vid_name in video_ids} 28 | self.query_data = query_db.query_data 29 | self.frame_interval = video_db.img_db.frame_interval 30 | self.max_num_query = max_num_query 31 | self.sampled_by_q = sampled_by_q 32 | 33 | if sampled_by_q: 34 | self.lens, self.qids = get_ids_and_lens(query_db) 35 | # FIXME 36 | if _check_ngpu() > 1: 37 | # partition data by rank 38 | self.qids = self.qids[hvd.rank()::hvd.size()] 39 | self.lens = self.lens[hvd.rank()::hvd.size()] 40 | else: 41 | # FIXME 42 | if _check_ngpu() > 1: 43 | # partition data by rank 44 | self.vids = self.vids[hvd.rank()::hvd.size()] 45 | self.lens = [video_db.txt_db.id2len[vid] for vid in self.vids] 46 | 47 | 48 | class VcmrVideoOnlyEvalDataset(VcmrVideoOnlyDataset): 49 | def __getitem__(self, i): 50 | vid, qids = self.getids(i) 51 | outs = super().__getitem__(i) 52 | return qids, outs 53 | 54 | 55 | class VcmrVideoOnlyFullEvalDataset(VcmrVideoOnlyDataset): 56 | def __init__(self, video_ids, video_db, query_db, max_num_query=5, 57 | distributed=False): 58 | super().__init__([], video_db, query_db, sampled_by_q=True) 59 | qlens, qids = get_ids_and_lens(query_db) 60 | # this dataset does not support multi GPU 61 | del self.vids 62 | self.vid2idx = { 63 | vid_name: self.global_vid2idx[vid_name] 64 | for vid_name in video_ids} 65 | 66 | # FIXME 67 | if _check_ngpu() > 1 and distributed: 68 | # partition data by rank 69 | self.qids = qids[hvd.rank()::hvd.size()] 70 | self.lens = qlens[hvd.rank()::hvd.size()] 71 | else: 72 | self.qids = qids 73 | self.lens = qlens 74 | 75 | def __len__(self): 76 | return len(self.qids) 77 | 78 | def getids(self, i): 79 | qid = self.qids[i] 80 | if len(self.query_db.query2video): 81 | vid = self.query_db.query2video[qid] 82 | else: 83 | vid = -1 84 | return vid, [qid] 85 | 86 | def __getitem__(self, i): 87 | vid, qids = self.getids(i) 88 | if vid != -1: 89 | video_inputs = self.video_db.__getitem__(vid) 90 | (frame_level_input_ids, frame_level_v_feats, 91 | frame_level_attn_masks, 92 | clip_level_v_feats, clip_level_attn_masks, num_subs, 93 | sub_idx2frame_idx) = video_inputs 94 | nframes = len(clip_level_v_feats) 95 | query_and_targets = [] 96 | for qid in qids: 97 | example = self.query_db[qid] 98 | if example['target'] is not None: 99 | st_idx, ed_idx = self.get_st_ed_label( 100 | example['target'], max_idx=nframes-1) 101 | target = torch.LongTensor( 102 | [st_idx, ed_idx]) 103 | else: 104 | target = torch.LongTensor([-1, -1]) 105 | query_input_ids = example["input_ids"] 106 | query_input_ids = torch.tensor( 107 | [self.query_db.cls_] + query_input_ids) 108 | 109 | query_attn_mask = torch.tensor([1]*len(query_input_ids)) 110 | 111 | query_and_targets.append( 112 | (query_input_ids, query_attn_mask, vid, target)) 113 | return (qid, query_and_targets) 114 | 
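MetaLoader and PrefetchLoader in data/loader.py above are the glue between the per-task datasets in this directory and the training loops: MetaLoader samples a task name according to the configured ratios and yields (task, batch) pairs indefinitely, while PrefetchLoader overlaps host-to-device copies with compute on a side CUDA stream. Below is a minimal sketch of how they compose, assuming it runs with the repo's dependencies installed (e.g. inside the provided container); the toy TensorDataset loaders and the 5-step cutoff are illustrative assumptions, and PrefetchLoader is only mentioned in a comment because it requires a CUDA device.

# Sketch (illustration only): multi-task sampling with MetaLoader from data/loader.py.
import torch
from torch.utils.data import DataLoader, TensorDataset

from data.loader import MetaLoader  # PrefetchLoader(meta_loader) would add GPU prefetching

# stand-ins for the real task loaders (mlm, mfm-nce, fom, vsm, ...)
mlm_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
fom_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)

# (loader, ratio) tuples: 'mlm' is sampled twice as often as 'fom'
meta_loader = MetaLoader({'mlm': (mlm_loader, 2), 'fom': (fom_loader, 1)},
                         accum_steps=1, distributed=False)

for step, (task, batch) in enumerate(meta_loader):
    if step >= 5:          # MetaLoader iterates forever; stop after a few steps
        break
    print(step, task, batch)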
-------------------------------------------------------------------------------- /data/videoQA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Video QA dataset 6 | """ 7 | import random 8 | 9 | from torch.utils.data import Dataset 10 | import torch 11 | from torch.nn.utils.rnn import pad_sequence 12 | from toolz.sandbox import unzip 13 | import horovod.torch as hvd 14 | 15 | from .data import (VideoFeatSubTokDataset, QaQueryTokLmdb, 16 | get_ids_and_lens, video_collate, _check_ngpu, 17 | txt_input_collate) 18 | import math 19 | 20 | 21 | class VideoQaDataset(Dataset): 22 | def __init__(self, video_ids, video_db, query_db, max_num_query=5, 23 | sampled_by_q=True): 24 | assert isinstance(query_db, QaQueryTokLmdb) 25 | assert isinstance(video_db, VideoFeatSubTokDataset) 26 | self.video_db = video_db 27 | self.query_db = query_db 28 | self.vid2dur = self.video_db.vid2dur 29 | self.vid2idx = self.video_db.vid2idx 30 | self.max_clip_len = video_db.txt_db.max_clip_len 31 | self.frame_interval = video_db.img_db.frame_interval 32 | self.max_num_query = max_num_query 33 | self.sampled_by_q = sampled_by_q 34 | self.vids = video_ids 35 | 36 | if sampled_by_q: 37 | self.lens, self.qids = get_ids_and_lens(query_db) 38 | # FIXME 39 | if _check_ngpu() > 1: 40 | # partition data by rank 41 | self.qids = self.qids[hvd.rank()::hvd.size()] 42 | self.lens = self.lens[hvd.rank()::hvd.size()] 43 | else: 44 | # FIXME 45 | if _check_ngpu() > 1: 46 | # partition data by rank 47 | self.vids = self.vids[hvd.rank()::hvd.size()] 48 | self.lens = [video_db.txt_db.id2len[vid] for vid in self.vids] 49 | 50 | def getids(self, i): 51 | if not self.sampled_by_q: 52 | vid = self.vids[i] 53 | # TVR video loss assumes fix number of queries 54 | qids = self.query_db.video2query[vid][:self.max_num_query] 55 | if len(qids) < self.max_num_query: 56 | qids += random.sample(qids, self.max_num_query - len(qids)) 57 | else: 58 | qids = [self.qids[i]] 59 | vid = self.query_db.query2video[qids[0]] 60 | return vid, qids 61 | 62 | def __getitem__(self, i): 63 | vid, qids = self.getids(i) 64 | video_inputs = self.video_db.__getitem__(vid) 65 | (frame_level_input_ids, frame_level_v_feats, 66 | frame_level_attn_masks, 67 | clip_level_v_feats, clip_level_attn_masks, num_subs, 68 | sub_idx2frame_idx) = video_inputs 69 | nframes = len(clip_level_v_feats) 70 | 71 | all_vids = [] 72 | all_targets = [] 73 | all_ts_targets = [] 74 | all_qa_input_ids = [] 75 | all_qa_attn_masks = [] 76 | all_video_qa_inputs = [] 77 | for qid in qids: 78 | example = self.query_db[qid] 79 | if example['target'] is not None: 80 | target = torch.LongTensor([example['target']]) 81 | else: 82 | target = torch.LongTensor([-1]) 83 | if example['ts'] is not None: 84 | st_idx, ed_idx = self.get_st_ed_label( 85 | example['ts'], max_idx=nframes-1) 86 | ts_target = torch.LongTensor( 87 | [st_idx, ed_idx]) 88 | else: 89 | ts_target = torch.LongTensor([-1, -1]) 90 | 91 | input_ids = example["input_ids"] 92 | q_input_ids = input_ids[0] 93 | for a_input_ids in input_ids[1:]: 94 | f_sub_qa_input_ids = [] 95 | f_sub_qa_attn_masks = [] 96 | curr_qa_input_id = torch.tensor( 97 | [self.query_db.sep] + q_input_ids + [ 98 | self.query_db.sep] + a_input_ids) 99 | curr_qa_attn_masks = torch.tensor([1]*len(curr_qa_input_id)) 100 | all_qa_input_ids.append(curr_qa_input_id) 101 | all_qa_attn_masks.append(curr_qa_attn_masks) 102 | for f_sub_input_ids, f_attn_masks in 
zip( 103 | frame_level_input_ids, frame_level_attn_masks): 104 | curr_f_sub_qa_input_ids = torch.cat(( 105 | f_sub_input_ids, curr_qa_input_id)) 106 | curr_f_sub_qa_attn_masks = torch.cat(( 107 | f_attn_masks, curr_qa_attn_masks)) 108 | f_sub_qa_input_ids.append(curr_f_sub_qa_input_ids) 109 | f_sub_qa_attn_masks.append(curr_f_sub_qa_attn_masks) 110 | curr_video_qa_inputs = ( 111 | f_sub_qa_input_ids, frame_level_v_feats, 112 | f_sub_qa_attn_masks, 113 | clip_level_v_feats, clip_level_attn_masks, num_subs, 114 | sub_idx2frame_idx) 115 | all_video_qa_inputs.append(curr_video_qa_inputs) 116 | all_vids.append(vid) 117 | all_targets.append(target) 118 | all_ts_targets.append(ts_target) 119 | out = (all_video_qa_inputs, all_qa_input_ids, all_qa_attn_masks, 120 | all_vids, all_targets, all_ts_targets) 121 | return out 122 | 123 | def __len__(self): 124 | if self.sampled_by_q: 125 | return len(self.qids) 126 | return len(self.vids) 127 | 128 | def get_st_ed_label(self, ts, max_idx): 129 | """ 130 | Args: 131 | ts: [st (float), ed (float)] in seconds, ed > st 132 | max_idx: length of the video 133 | 134 | Returns: 135 | [st_idx, ed_idx]: int, 136 | 137 | Given ts = [3.2, 7.6], st_idx = 2, ed_idx = 6, 138 | clips should be indexed as [2: 6), 139 | the translated back ts should be [3:9]. 140 | # TODO which one is better, [2: 5] or [2: 6) 141 | """ 142 | try: 143 | ts = ts.split("-") 144 | st = float(ts[0]) 145 | ed = float(ts[1]) 146 | st_idx = min(math.floor(st/self.frame_interval), max_idx) 147 | ed_idx = min(max(math.ceil(ed/self.frame_interval)-1, 148 | st_idx+1), max_idx) 149 | except Exception: 150 | st_idx, ed_idx = -1, -1 151 | 152 | return st_idx, ed_idx 153 | 154 | 155 | def video_qa_collate(inputs): 156 | (video_qa_inputs, qa_input_ids, qa_attn_masks, 157 | vids, target, ts_target) = map( 158 | list, unzip(inputs)) 159 | all_video_qa_inputs = [] 160 | all_target, all_ts_target = [], [] 161 | all_qa_input_ids, all_qa_attn_masks = [], [] 162 | for i in range(len(video_qa_inputs)): 163 | all_video_qa_inputs.extend(video_qa_inputs[i]) 164 | all_qa_input_ids.extend(qa_input_ids[i]) 165 | all_qa_attn_masks.extend(qa_attn_masks[i]) 166 | for j in range(len(vids)): 167 | all_target.extend(target[j]) 168 | all_ts_target.extend(ts_target[j]) 169 | batch = video_collate(all_video_qa_inputs) 170 | 171 | targets = pad_sequence( 172 | all_target, batch_first=True, padding_value=-1) 173 | ts_targets = pad_sequence( 174 | all_ts_target, batch_first=True, padding_value=-1) 175 | input_ids, pos_ids, attn_masks =\ 176 | txt_input_collate(all_qa_input_ids, all_qa_attn_masks) 177 | batch["targets"] = targets 178 | batch["ts_targets"] = ts_targets 179 | batch['qa_input_ids'] = input_ids 180 | batch['qa_pos_ids'] = pos_ids 181 | batch['qa_attn_masks'] = attn_masks 182 | return batch 183 | 184 | 185 | class VideoQaEvalDataset(VideoQaDataset): 186 | def __getitem__(self, i): 187 | vid, qids = self.getids(i) 188 | outs = super().__getitem__(i) 189 | return qids, outs 190 | 191 | 192 | def video_qa_eval_collate(inputs): 193 | qids, batch = [], [] 194 | for id_, tensors in inputs: 195 | qids.extend(id_) 196 | batch.append(tensors) 197 | batch = video_qa_collate(batch) 198 | batch['qids'] = qids 199 | return batch 200 | -------------------------------------------------------------------------------- /data/violin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Violin dataset 6 | """ 7 | import random 8 | 9 | from torch.utils.data import Dataset 10 | import torch 11 | from torch.nn.utils.rnn import pad_sequence 12 | from toolz.sandbox import unzip 13 | import horovod.torch as hvd 14 | 15 | from .data import (VideoFeatSubTokDataset, QaQueryTokLmdb, 16 | get_ids_and_lens, video_collate, _check_ngpu, 17 | txt_input_collate) 18 | 19 | 20 | def get_paired_statement_id(qid): 21 | parsed_qid = qid.split("-") 22 | label = int(parsed_qid[-1]) 23 | paired_qid = "-".join(parsed_qid[:-1]+[str(1 - label)]) 24 | return paired_qid 25 | 26 | 27 | class ViolinDataset(Dataset): 28 | def __init__(self, video_ids, video_db, query_db, max_num_query=6, 29 | sampled_by_q=True): 30 | assert isinstance(query_db, QaQueryTokLmdb) 31 | assert isinstance(video_db, VideoFeatSubTokDataset) 32 | self.video_db = video_db 33 | self.query_db = query_db 34 | self.vid2dur = self.video_db.vid2dur 35 | self.vid2idx = self.video_db.vid2idx 36 | self.max_clip_len = video_db.txt_db.max_clip_len 37 | self.frame_interval = video_db.img_db.frame_interval 38 | self.max_num_query = max_num_query 39 | self.sampled_by_q = sampled_by_q 40 | self.vids = video_ids 41 | if sampled_by_q: 42 | self.lens, self.qids = get_ids_and_lens(query_db) 43 | # FIXME 44 | if _check_ngpu() > 1: 45 | # partition data by rank 46 | self.qids = self.qids[hvd.rank()::hvd.size()] 47 | self.lens = self.lens[hvd.rank()::hvd.size()] 48 | else: 49 | # FIXME 50 | if _check_ngpu() > 1: 51 | # partition data by rank 52 | self.vids = self.vids[hvd.rank()::hvd.size()] 53 | self.lens = [video_db.txt_db.id2len[vid] for vid in self.vids] 54 | 55 | def getids(self, i): 56 | if not self.sampled_by_q: 57 | vid = self.vids[i] 58 | qids = self.query_db.video2query[vid][:self.max_num_query] 59 | if len(qids) < self.max_num_query: 60 | qids += random.sample(qids, self.max_num_query - len(qids)) 61 | else: 62 | qids = [self.qids[i], get_paired_statement_id(self.qids[i])] 63 | vid = self.query_db.query2video[qids[0]] 64 | return vid, qids 65 | 66 | def __getitem__(self, i): 67 | vid, qids = self.getids(i) 68 | video_inputs = self.video_db.__getitem__(vid) 69 | (frame_level_input_ids, frame_level_v_feats, 70 | frame_level_attn_masks, 71 | clip_level_v_feats, clip_level_attn_masks, num_subs, 72 | sub_idx2frame_idx) = video_inputs 73 | 74 | all_vids = [] 75 | all_targets = [] 76 | all_q_input_ids = [] 77 | all_q_attn_masks = [] 78 | all_video_q_inputs = [] 79 | for qid in qids: 80 | example = self.query_db[qid] 81 | if example['target']: 82 | target = torch.LongTensor([1]) 83 | else: 84 | target = torch.LongTensor([0]) 85 | 86 | curr_q_input_ids = torch.tensor( 87 | [self.query_db.sep] + example["input_ids"]) 88 | curr_q_attn_masks = torch.tensor([1]*len(curr_q_input_ids)) 89 | all_q_input_ids.append(curr_q_input_ids) 90 | all_q_attn_masks.append(curr_q_attn_masks) 91 | f_sub_q_input_ids, f_sub_q_attn_masks = [], [] 92 | for f_sub_input_ids, f_attn_masks in zip( 93 | frame_level_input_ids, frame_level_attn_masks): 94 | curr_f_sub_q_input_ids = torch.cat(( 95 | f_sub_input_ids, curr_q_input_ids)) 96 | curr_f_sub_q_attn_masks = torch.cat(( 97 | f_attn_masks, curr_q_attn_masks)) 98 | f_sub_q_input_ids.append(curr_f_sub_q_input_ids) 99 | f_sub_q_attn_masks.append(curr_f_sub_q_attn_masks) 100 | curr_video_q_inputs = ( 101 | f_sub_q_input_ids, frame_level_v_feats, 102 | f_sub_q_attn_masks, 103 | clip_level_v_feats, clip_level_attn_masks, num_subs, 104 | sub_idx2frame_idx) 105 | all_video_q_inputs.append(curr_video_q_inputs) 106 | 
all_vids.append(vid) 107 | all_targets.append(target) 108 | out = (all_video_q_inputs, all_q_input_ids, all_q_attn_masks, 109 | all_vids, all_targets) 110 | return out 111 | 112 | def __len__(self): 113 | if self.sampled_by_q: 114 | return len(self.qids) 115 | return len(self.vids) 116 | 117 | 118 | def violin_collate(inputs): 119 | (video_q_inputs, q_input_ids, q_attn_masks, 120 | vids, target) = map( 121 | list, unzip(inputs)) 122 | all_video_qa_inputs = [] 123 | all_target = [] 124 | all_q_input_ids, all_q_attn_masks = [], [] 125 | for i in range(len(video_q_inputs)): 126 | all_video_qa_inputs.extend(video_q_inputs[i]) 127 | all_q_input_ids.extend(q_input_ids[i]) 128 | all_q_attn_masks.extend(q_attn_masks[i]) 129 | for j in range(len(vids)): 130 | all_target.extend(target[j]) 131 | batch = video_collate(all_video_qa_inputs) 132 | 133 | targets = pad_sequence( 134 | all_target, batch_first=True, padding_value=-1) 135 | input_ids, pos_ids, attn_masks =\ 136 | txt_input_collate(all_q_input_ids, all_q_attn_masks) 137 | batch["targets"] = targets 138 | batch['q_input_ids'] = input_ids 139 | batch['q_pos_ids'] = pos_ids 140 | batch['q_attn_masks'] = attn_masks 141 | return batch 142 | 143 | 144 | class ViolinEvalDataset(ViolinDataset): 145 | def getids(self, i): 146 | if not self.sampled_by_q: 147 | vid = self.vids[i] 148 | # TVR video loss assumes fix number of queries 149 | qids = self.query_db.video2query[vid][:self.max_num_query] 150 | if len(qids) < self.max_num_query: 151 | qids += random.sample(qids, self.max_num_query - len(qids)) 152 | else: 153 | qids = [self.qids[i]] 154 | vid = self.query_db.query2video[qids[0]] 155 | return vid, qids 156 | 157 | def __getitem__(self, i): 158 | vid, qids = self.getids(i) 159 | outs = super().__getitem__(i) 160 | return qids, outs 161 | 162 | 163 | def violin_eval_collate(inputs): 164 | qids, batch = [], [] 165 | for id_, tensors in inputs: 166 | qids.extend(id_) 167 | batch.append(tensors) 168 | batch = violin_collate(batch) 169 | batch['qids'] = qids 170 | return batch 171 | -------------------------------------------------------------------------------- /data/vr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | VR dataset 6 | """ 7 | import torch 8 | import horovod.torch as hvd 9 | from utils.basic_utils import load_jsonl 10 | import os 11 | import json 12 | from .data import (VideoFeatSubTokDataset, TxtTokLmdb, SubTokLmdb, 13 | get_ids_and_lens, _check_ngpu) 14 | from .vcmr import VcmrDataset, vcmr_collate, vcmr_full_eval_collate 15 | 16 | 17 | class VrSubTokLmdb(SubTokLmdb): 18 | def __init__(self, db_dir, max_clip_len=-1): 19 | super().__init__(db_dir, max_clip_len=-1) 20 | self.max_clip_len = max_clip_len 21 | self.vid2max_len = json.load( 22 | open(f'{db_dir}/vid2max_frame_sub_len.json')) 23 | self.id2len = json.load( 24 | open(f'{db_dir}/vid2len.json')) 25 | self.vid2dur, self.vid2idx = {}, {} 26 | 27 | 28 | class VrQueryTokLmdb(TxtTokLmdb): 29 | def __init__(self, db_dir, max_txt_len=-1): 30 | super().__init__(db_dir, max_txt_len) 31 | if os.path.exists(f'{self.db_dir}/query2video.json'): 32 | self.query2video = json.load( 33 | open(f'{self.db_dir}/query2video.json')) 34 | self.video2query = {} 35 | for k, v in self.query2video.items(): 36 | if v not in self.video2query: 37 | self.video2query[v] = [k] 38 | else: 39 | self.video2query[v].append(k) 40 | else: 41 | self.query2video = {} 42 | self.video2query = {} 43 | self.query_data_f = load_jsonl(f'{self.db_dir}/query_data.jsonl') 44 | 45 | def __getitem__(self, id_): 46 | txt_dump = self.db[id_] 47 | return txt_dump 48 | 49 | 50 | class MsrvttQueryTokLmdb(VrQueryTokLmdb): 51 | @property 52 | def query_data(self): 53 | try: 54 | data = { 55 | str(item["sen_id"]): item 56 | for item in self.query_data_f} 57 | except Exception: 58 | data = { 59 | str(item["retrieval_key"]): item 60 | for item in self.query_data_f} 61 | return data 62 | 63 | 64 | class VrDataset(VcmrDataset): 65 | def __init__(self, video_ids, video_db, query_db, max_num_query=5, 66 | sampled_by_q=True): 67 | assert isinstance(query_db, VrQueryTokLmdb) 68 | assert isinstance(video_db, VideoFeatSubTokDataset) 69 | self.video_db = video_db 70 | self.query_db = query_db 71 | self.vid2dur = self.video_db.img_db.name2nframe 72 | self.query_data = query_db.query_data 73 | self.max_clip_len = video_db.txt_db.max_clip_len 74 | self.frame_interval = video_db.img_db.frame_interval 75 | self.max_num_query = max_num_query 76 | self.sampled_by_q = sampled_by_q 77 | self.vids = video_ids 78 | self.global_vid2idx = { 79 | vid_name: idx for idx, vid_name in 80 | enumerate(sorted(list(self.vid2dur.keys())))} 81 | self.vid2idx = { 82 | vid_name: self.global_vid2idx[vid_name] 83 | for vid_name in video_ids} 84 | if sampled_by_q: 85 | self.lens, self.qids = get_ids_and_lens(query_db) 86 | # FIXME 87 | if _check_ngpu() > 1: 88 | # partition data by rank 89 | self.qids = self.qids[hvd.rank()::hvd.size()] 90 | self.lens = self.lens[hvd.rank()::hvd.size()] 91 | else: 92 | # FIXME 93 | if _check_ngpu() > 1: 94 | # partition data by rank 95 | self.vids = self.vids[hvd.rank()::hvd.size()] 96 | self.lens = [video_db.vid2dur[vid] for vid in self.vids] 97 | 98 | def __getitem__(self, i): 99 | vid, qids = self.getids(i) 100 | 101 | video_inputs = self.video_db.__getitem__(vid) 102 | (frame_level_input_ids, frame_level_v_feats, 103 | frame_level_attn_masks, 104 | clip_level_v_feats, clip_level_attn_masks, num_subs, 105 | sub_idx2frame_idx) = video_inputs 106 | 107 | query_and_targets = [] 108 | for qid in qids: 109 | example = self.query_db[qid] 110 | target = torch.LongTensor([-1, -1]) 111 | query_input_ids = example["input_ids"] 112 | query_input_ids = torch.tensor( 113 | [self.query_db.cls_] + 
query_input_ids) 114 | 115 | query_attn_mask = torch.tensor([1]*len(query_input_ids)) 116 | 117 | query_and_targets.append( 118 | (query_input_ids, query_attn_mask, vid, target)) 119 | 120 | return (video_inputs, vid, tuple(query_and_targets)) 121 | 122 | 123 | def vr_collate(inputs): 124 | return vcmr_collate(inputs) 125 | 126 | 127 | class VrEvalDataset(VrDataset): 128 | def __getitem__(self, i): 129 | vid, qids = self.getids(i) 130 | outs = super().__getitem__(i) 131 | return qids, outs 132 | 133 | 134 | def vr_eval_collate(inputs): 135 | qids, batch = [], [] 136 | for id_, tensors in inputs: 137 | qids.extend(id_) 138 | batch.append(tensors) 139 | batch = vr_collate(batch) 140 | batch['qids'] = qids 141 | return batch 142 | 143 | 144 | class VrFullEvalDataset(VrDataset): 145 | def __init__(self, video_ids, video_db, query_db, max_num_query=5, 146 | distributed=False): 147 | super().__init__(video_ids, video_db, query_db, sampled_by_q=True) 148 | qlens, qids = get_ids_and_lens(query_db) 149 | # this dataset does not support multi GPU 150 | del self.vids 151 | self.vid2idx = { 152 | vid_name: self.global_vid2idx[vid_name] 153 | for vid_name in video_ids} 154 | 155 | # FIXME 156 | if _check_ngpu() > 1 and distributed: 157 | # partition data by rank 158 | self.qids = qids[hvd.rank()::hvd.size()] 159 | self.lens = qlens[hvd.rank()::hvd.size()] 160 | else: 161 | self.qids = qids 162 | self.lens = qlens 163 | 164 | def __len__(self): 165 | return len(self.qids) 166 | 167 | def getids(self, i): 168 | qid = self.qids[i] 169 | if len(self.query_db.query2video): 170 | vid = self.query_db.query2video[qid] 171 | else: 172 | vid = -1 173 | return vid, [qid] 174 | 175 | def __getitem__(self, i): 176 | vid, qids = self.getids(i) 177 | if vid != -1: 178 | video_inputs = self.video_db.__getitem__(vid) 179 | (frame_level_input_ids, frame_level_v_feats, 180 | frame_level_attn_masks, 181 | clip_level_v_feats, clip_level_attn_masks, num_subs, 182 | sub_idx2frame_idx) = video_inputs 183 | query_and_targets = [] 184 | for qid in qids: 185 | example = self.query_db[qid] 186 | target = torch.LongTensor([-1, -1]) 187 | query_input_ids = example["input_ids"] 188 | 189 | query_input_ids = torch.tensor( 190 | [self.query_db.cls_] + query_input_ids) 191 | 192 | query_attn_mask = torch.tensor([1]*len(query_input_ids)) 193 | 194 | query_and_targets.append( 195 | (query_input_ids, query_attn_mask, vid, target)) 196 | return (qid, query_and_targets) 197 | 198 | 199 | def vr_full_eval_collate(inputs): 200 | return vcmr_full_eval_collate(inputs) 201 | -------------------------------------------------------------------------------- /data/vr_video_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | VR video-only dataset 6 | """ 7 | from torch.utils.data import Dataset 8 | import torch 9 | import horovod.torch as hvd 10 | from .data import (VideoFeatLmdb, 11 | get_ids_and_lens, _check_ngpu) 12 | from .vr import VrQueryTokLmdb, VrDataset 13 | 14 | 15 | class VideoFeatDataset(Dataset): 16 | def __init__(self, meta, img_db): 17 | assert isinstance(img_db, VideoFeatLmdb) 18 | self.img_db = img_db 19 | self.max_clip_len = self.img_db.max_clip_len 20 | self.vid2dur = self.img_db.name2nframe 21 | self.vids = sorted(list(self.vid2dur.keys())) 22 | self.vid2idx = {vid: idx for idx, vid in enumerate(self.vids)} 23 | self.cls_ = meta['CLS'] 24 | self.sep = meta['SEP'] 25 | 26 | def __len__(self): 27 | return len(self.vids) 28 | 29 | def __getitem__(self, vid_): 30 | v_feat, nframes = self._get_v_feat(vid_) 31 | num_subs = 1 # fake an empty sub 32 | sub2frames = [(0, list(range(len(v_feat))))] 33 | frame_level_input_ids, frame_level_v_feats = ( 34 | [torch.tensor([self.cls_])], 35 | [v_feat]) 36 | frame_level_attn_masks = [ 37 | torch.tensor([1] * (1+len(v_feat)))] # [(fffwww)] 38 | 39 | clip_level_v_feats = v_feat 40 | clip_level_attn_masks = [1] * len(clip_level_v_feats) 41 | clip_level_attn_masks = torch.tensor(clip_level_attn_masks) 42 | 43 | out = (frame_level_input_ids, # num_subs list[tensor(sep,w0,w1,...)] 44 | frame_level_v_feats, # num_subs list[tensor(#sub_frames, d)] 45 | frame_level_attn_masks, # num_subs list[L_sub + #sub_frames] 46 | clip_level_v_feats, # tensor(num_frames, d) 47 | clip_level_attn_masks, # #frames list[1] 48 | num_subs, sub2frames) # num_subs, [(sub_ix, [frame_ix]) ] 49 | return out 50 | 51 | def _get_v_feat(self, fname): 52 | v_feat = self.img_db[fname] 53 | nframes = v_feat.size(0) 54 | return v_feat, nframes 55 | 56 | 57 | class VrVideoOnlyDataset(VrDataset): 58 | def __init__(self, video_ids, video_db, query_db, max_num_query=5, 59 | sampled_by_q=True): 60 | assert isinstance(query_db, VrQueryTokLmdb) 61 | assert isinstance(video_db, VideoFeatDataset) 62 | self.video_db = video_db 63 | self.query_db = query_db 64 | self.vid2dur = self.video_db.vid2dur 65 | self.query_data = query_db.query_data 66 | self.max_clip_len = video_db.max_clip_len 67 | self.frame_interval = video_db.img_db.frame_interval 68 | self.max_num_query = max_num_query 69 | self.sampled_by_q = sampled_by_q 70 | self.vids = video_ids 71 | self.global_vid2idx = video_db.vid2idx 72 | self.vid2idx = { 73 | vid_name: self.global_vid2idx[vid_name] 74 | for vid_name in video_ids} 75 | if sampled_by_q: 76 | self.lens, self.qids = get_ids_and_lens(query_db) 77 | # FIXME 78 | if _check_ngpu() > 1: 79 | # partition data by rank 80 | self.qids = self.qids[hvd.rank()::hvd.size()] 81 | self.lens = self.lens[hvd.rank()::hvd.size()] 82 | else: 83 | # FIXME 84 | if _check_ngpu() > 1: 85 | # partition data by rank 86 | self.vids = self.vids[hvd.rank()::hvd.size()] 87 | self.lens = [video_db.vid2dur[vid] for vid in self.vids] 88 | 89 | 90 | class VrVideoOnlyEvalDataset(VrVideoOnlyDataset): 91 | def __getitem__(self, i): 92 | vid, qids = self.getids(i) 93 | outs = super().__getitem__(i) 94 | return qids, outs 95 | 96 | 97 | class VrVideoOnlyFullEvalDataset(VrVideoOnlyDataset): 98 | def __init__(self, video_ids, video_db, query_db, max_num_query=5, 99 | distributed=False): 100 | super().__init__(video_ids, video_db, query_db, sampled_by_q=True) 101 | qlens, qids = get_ids_and_lens(query_db) 102 | # this dataset does not support multi GPU 103 | del self.vids 104 | self.vid2idx = { 105 | vid_name: 
self.global_vid2idx[vid_name] 106 | for vid_name in video_ids} 107 | 108 | # FIXME 109 | if _check_ngpu() > 1 and distributed: 110 | # partition data by rank 111 | self.qids = qids[hvd.rank()::hvd.size()] 112 | self.lens = qlens[hvd.rank()::hvd.size()] 113 | else: 114 | self.qids = qids 115 | self.lens = qlens 116 | 117 | def __len__(self): 118 | return len(self.qids) 119 | 120 | def getids(self, i): 121 | qid = self.qids[i] 122 | if len(self.query_db.query2video): 123 | vid = self.query_db.query2video[qid] 124 | else: 125 | vid = -1 126 | return vid, [qid] 127 | 128 | def __getitem__(self, i): 129 | vid, qids = self.getids(i) 130 | if vid != -1: 131 | video_inputs = self.video_db.__getitem__(vid) 132 | (frame_level_input_ids, frame_level_v_feats, 133 | frame_level_attn_masks, 134 | clip_level_v_feats, clip_level_attn_masks, num_subs, 135 | sub_idx2frame_idx) = video_inputs 136 | query_and_targets = [] 137 | for qid in qids: 138 | example = self.query_db[qid] 139 | target = torch.LongTensor([-1, -1]) 140 | query_input_ids = example["input_ids"] 141 | 142 | query_input_ids = torch.tensor( 143 | [self.query_db.cls_] + query_input_ids) 144 | 145 | query_attn_mask = torch.tensor([1]*len(query_input_ids)) 146 | 147 | query_and_targets.append( 148 | (query_input_ids, query_attn_mask, vid, target)) 149 | return (qid, query_and_targets) 150 | -------------------------------------------------------------------------------- /data/vsm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Pretrain VSM dataset 6 | """ 7 | import random 8 | 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from torch.utils.data import Dataset 12 | from toolz.sandbox import unzip 13 | from cytoolz import concat 14 | import horovod.torch as hvd 15 | import copy 16 | 17 | from .data import VideoFeatSubTokDataset, _check_ngpu, video_collate 18 | 19 | 20 | class VsmDataset(Dataset): 21 | def __init__(self, video_ids, vid_sub_db, query_per_video=5, 22 | sub_ctx_len=0): 23 | assert isinstance(vid_sub_db, VideoFeatSubTokDataset) 24 | self.query_per_video = query_per_video 25 | self.vid_sub_db = vid_sub_db 26 | if _check_ngpu() > 1: 27 | self.ids = video_ids[hvd.rank()::hvd.size()] 28 | else: 29 | self.ids = video_ids 30 | self.sub_ctx_len = sub_ctx_len 31 | 32 | def __len__(self): 33 | return len(self.ids) 34 | 35 | def __getitem__(self, i): 36 | vid = self.ids[i] 37 | example = self.vid_sub_db.txt_db[vid] 38 | v_feat, nframes = self.vid_sub_db._get_v_feat(vid) 39 | sub2frames = self.vid_sub_db.vid_sub2frame[vid] 40 | 41 | frame_level_input_ids, frame_level_v_feats = [], [] 42 | frame_level_attn_masks = [] 43 | num_subs = len(sub2frames) 44 | 45 | sub_queries_and_targets = [] 46 | matched_sub_idx = [sub_idx for sub_idx, matched_frames in sub2frames 47 | if matched_frames] 48 | n_samples = min(len(matched_sub_idx), self.query_per_video) 49 | query_sub_ids = set(random.sample(matched_sub_idx, n_samples)) 50 | for sub_idx, matched_frames in sub2frames: 51 | # text input 52 | if self.sub_ctx_len >= 0: 53 | curr_sub_ctx_input_ids = [] 54 | for tmp_sub_idx in range(sub_idx-self.sub_ctx_len, 55 | sub_idx+1): 56 | if tmp_sub_idx >= 0 and tmp_sub_idx < num_subs\ 57 | and tmp_sub_idx not in query_sub_ids: 58 | in_ids = example['input_ids'][tmp_sub_idx] 59 | if self.vid_sub_db.max_txt_len != -1: 60 | in_ids = in_ids[:self.vid_sub_db.max_txt_len] 61 | 
curr_sub_ctx_input_ids.extend(copy.deepcopy(in_ids)) 62 | curr_sub_ctx_input_ids = [ 63 | self.vid_sub_db.txt_db.sep] + curr_sub_ctx_input_ids 64 | 65 | n_frame = len(matched_frames) 66 | attn_masks_fill_0_pos = None 67 | if n_frame: 68 | matched_v_feats = torch.index_select( 69 | v_feat, 0, torch.tensor(matched_frames)) 70 | 71 | if sub_idx in query_sub_ids: 72 | in_ids = example['input_ids'][sub_idx] 73 | if self.vid_sub_db.max_txt_len != -1: 74 | in_ids = in_ids[:self.vid_sub_db.max_txt_len] 75 | sub_quries_input_ids = torch.tensor( 76 | [self.vid_sub_db.txt_db.cls_] + copy.deepcopy(in_ids)) 77 | sub_query_attn_masks = torch.ones( 78 | len(sub_quries_input_ids), dtype=torch.long) 79 | st, ed = matched_frames[0], min(max( 80 | matched_frames[0]+1, matched_frames[-1]), nframes-1) 81 | assert st <= ed, "st frame must <= ed frame" 82 | assert st >= 0, "st frame must >= 0" 83 | assert ed < nframes, f"ed frame must < frame_len {nframes}" 84 | targets = torch.tensor([st, ed], dtype=torch.long) 85 | sub_queries_and_targets.append( 86 | (sub_quries_input_ids, sub_query_attn_masks, 87 | vid, targets)) 88 | if len(curr_sub_ctx_input_ids) == 0: 89 | curr_sub_ctx_input_ids = [self.vid_sub_db.txt_db.mask] 90 | attn_masks_fill_0_pos = -1 91 | attn_masks = torch.ones( 92 | len(curr_sub_ctx_input_ids) + n_frame, 93 | dtype=torch.long) 94 | else: 95 | matched_v_feats = torch.zeros(1, v_feat.shape[1]) 96 | attn_masks = torch.ones( 97 | len(curr_sub_ctx_input_ids) + 1, dtype=torch.long) 98 | attn_masks_fill_0_pos = 0 99 | if attn_masks_fill_0_pos is not None: 100 | attn_masks.data[attn_masks_fill_0_pos].fill_(0) 101 | 102 | frame_level_input_ids.append(torch.tensor(curr_sub_ctx_input_ids)) 103 | frame_level_attn_masks.append(attn_masks) 104 | frame_level_v_feats.append(matched_v_feats) 105 | while len(sub_queries_and_targets) < self.query_per_video: 106 | sub_queries_and_targets.append( 107 | copy.deepcopy(sub_queries_and_targets[-1])) 108 | clip_level_v_feats = v_feat 109 | clip_level_attn_masks = [1] * len(clip_level_v_feats) 110 | clip_level_attn_masks = torch.tensor(clip_level_attn_masks) 111 | video_inputs = (frame_level_input_ids, frame_level_v_feats, 112 | frame_level_attn_masks, 113 | clip_level_v_feats, clip_level_attn_masks, 114 | num_subs, sub2frames) 115 | out = (video_inputs, vid, tuple(sub_queries_and_targets)) 116 | 117 | return out 118 | 119 | 120 | def vsm_collate(inputs): 121 | (video_inputs, vids, sub_queries_and_targets) = map(list, unzip(inputs)) 122 | (input_ids, attn_masks, sub_vids, targets) = map( 123 | list, unzip(concat(outs for outs in sub_queries_and_targets))) 124 | 125 | batch = video_collate(video_inputs) 126 | vid2idx = {vid: i for i, vid in enumerate(vids)} 127 | batch["q_vidx"] = torch.tensor([vid2idx[s_vid] for s_vid in sub_vids], 128 | dtype=torch.long) 129 | 130 | # text batches 131 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1) 132 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 133 | ).unsqueeze(0) 134 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 135 | 136 | vsm_targets = pad_sequence( 137 | targets, batch_first=True, padding_value=-1) 138 | batch.update({ 139 | 'query_input_ids': input_ids, 140 | 'query_pos_ids': position_ids, 141 | 'query_attn_masks': attn_masks, 142 | 'targets': vsm_targets, 143 | 'vids': vids}) 144 | 145 | return batch 146 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/README.md: 
-------------------------------------------------------------------------------- 1 | # coco-caption 2 | 3 | Original README can be found at [tylin/coco-caption](https://github.com/tylin/coco-caption/blob/3f0fe9b819c0ea881a56441e4de1146924a394eb/README.md). 4 | 5 | ## License 6 | 7 | All files in the pycocoevalcap directory are under 8 | [BSD 2-clause "Simplified" License](https://github.com/tylin/coco-caption/blob/3f0fe9b819c0ea881a56441e4de1146924a394eb/license.txt) 9 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 
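                # res[id] : hypothesis, a list containing exactly one sentence
                # gts[id] : references, a non-empty list of sentences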
32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | import pdb 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" 55 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. 27 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | from __future__ import division 6 | 7 | import atexit 8 | import logging 9 | import os 10 | from os.path import abspath 11 | import re 12 | import subprocess 13 | import sys 14 | import threading 15 | 16 | import psutil 17 | 18 | # binaries are built into docker image 19 | METEOR_JAR = '/workspace/cococap_bin/meteor-1.5.jar' 20 | PARAPHRASE = '/workspace/cococap_bin/paraphrase-en.gz' 21 | 22 | 23 | def enc(s): 24 | return s.encode('utf-8') 25 | 26 | 27 | def dec(s): 28 | return s.decode('utf-8') 29 | 30 | 31 | class Meteor: 32 | 33 | def __init__(self): 34 | # Used to guarantee thread safety 35 | self.lock = threading.Lock() 36 | 37 | mem = '2G' 38 | mem_available_G = psutil.virtual_memory().available / 1E9 39 | if mem_available_G < 2: 40 | logging.warning("There is less than 2GB of available memory.\n" 41 | "Will try with limiting Meteor to 1GB of memory but this might cause issues.\n" 42 | "If you have problems using Meteor, " 43 | "then you can try to lower the `mem` variable in meteor.py") 44 | mem = '1G' 45 | 46 | meteor_cmd = ['java', '-jar', '-Xmx{}'.format(mem), METEOR_JAR, 47 | '-', '-', '-stdio', '-l', 'en', '-norm', 48 | '-a', PARAPHRASE] 49 | env = os.environ.copy() 50 | env['LC_ALL'] = "C" 51 | self.meteor_p = subprocess.Popen(meteor_cmd, 52 | cwd=os.path.dirname(abspath(__file__)), 53 | env=env, 54 | stdin=subprocess.PIPE, 55 | stdout=subprocess.PIPE, 56 | stderr=subprocess.PIPE) 57 | 58 | atexit.register(self.close) 59 | 60 | def close(self): 61 | with self.lock: 62 | if self.meteor_p: 63 | self.meteor_p.kill() 64 | self.meteor_p.wait() 65 | self.meteor_p = None 66 | # if the user calls close() manually, remove the 67 | # reference from atexit so the object can be garbage-collected. 
68 | if atexit is not None and atexit.unregister is not None: 69 | atexit.unregister(self.close) 70 | 71 | def compute_score(self, gts, res): 72 | assert (gts.keys() == res.keys()) 73 | imgIds = gts.keys() 74 | scores = [] 75 | 76 | eval_line = 'EVAL' 77 | with self.lock: 78 | for i in imgIds: 79 | assert (len(res[i]) == 1) 80 | stat = self._stat(res[i][0], gts[i]) 81 | eval_line += ' ||| {}'.format(stat) 82 | 83 | self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) 84 | self.meteor_p.stdin.flush() 85 | for i in range(0, len(imgIds)): 86 | v = self.meteor_p.stdout.readline() 87 | try: 88 | scores.append(float(dec(v.strip()))) 89 | except: 90 | sys.stderr.write("Error handling value: {}\n".format(v)) 91 | sys.stderr.write("Decoded value: {}\n".format(dec(v.strip()))) 92 | sys.stderr.write("eval_line: {}\n".format(eval_line)) 93 | # You can try uncommenting the next code line to show stderr from the Meteor JAR. 94 | # If the Meteor JAR is not writing to stderr, then the line will just hang. 95 | # sys.stderr.write("Error from Meteor:\n{}".format(self.meteor_p.stderr.read())) 96 | raise 97 | score = float(dec(self.meteor_p.stdout.readline()).strip()) 98 | 99 | return score, scores 100 | 101 | def method(self): 102 | return "METEOR" 103 | 104 | def _stat(self, hypothesis_str, reference_list): 105 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 106 | hypothesis_str = hypothesis_str.replace('|||', '') 107 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 108 | score_line = re.sub(r'\s+', ' ', score_line) 109 | self.meteor_p.stdin.write(enc(score_line)) 110 | self.meteor_p.stdin.write(enc('\n')) 111 | self.meteor_p.stdin.flush() 112 | return dec(self.meteor_p.stdout.readline()).strip() 113 | 114 | def _score(self, hypothesis_str, reference_list): 115 | with self.lock: 116 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 117 | hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') 118 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 119 | self.meteor_p.stdin.write(enc('{}\n'.format(score_line))) 120 | self.meteor_p.stdin.flush() 121 | stats = dec(self.meteor_p.stdout.readline()).strip() 122 | eval_line = 'EVAL ||| {}'.format(stats) 123 | # EVAL ||| stats 124 | self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) 125 | self.meteor_p.stdin.flush() 126 | score = float(dec(self.meteor_p.stdout.readline()).strip()) 127 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 128 | # thanks for Andrej for pointing this out 129 | score = float(dec(self.meteor_p.stdout.readline()).strip()) 130 | return score 131 | 132 | def __del__(self): 133 | self.close() 134 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/meteor/tests/test_meteor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | 4 | import unittest 5 | 6 | from nlgeval.pycocoevalcap.meteor.meteor import Meteor 7 | 8 | 9 | class TestMeteor(unittest.TestCase): 10 | def test_compute_score(self): 11 | m = Meteor() 12 | 13 | s = m.compute_score({0: ["test"]}, {0: ["test"]}) 14 | self.assertEqual(s, (1.0, [1.0])) 15 | 16 | s = m.compute_score({0: ["テスト"]}, {0: ["テスト"]}) 17 | self.assertEqual(s, (1.0, [1.0])) 18 | -------------------------------------------------------------------------------- 
/eval/pycocoevalcap/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, 
ref)) 94 | 95 | # Sanity check. 96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /eval/pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | # path to the stanford corenlp jar 16 | STANFORD_CORENLP_3_4_1_JAR = ('/workspace/cococap_bin/' 17 | 'stanford-corenlp-3.4.1.jar') 18 | 19 | # punctuations to be removed from the sentences 20 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 21 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 22 | 23 | class PTBTokenizer: 24 | """Python wrapper of Stanford PTBTokenizer""" 25 | 26 | def tokenize(self, captions_for_image): 27 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 28 | 'edu.stanford.nlp.process.PTBTokenizer', \ 29 | '-preserveLines', '-lowerCase'] 30 | 31 | # ====================================================== 32 | # prepare data for PTB Tokenizer 33 | # ====================================================== 34 | final_tokenized_captions_for_image = {} 35 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 36 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 37 | 38 | # ====================================================== 39 | # save sentences to temporary file 40 | # ====================================================== 41 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 42 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 43 | tmp_file.write(sentences.encode()) 44 | tmp_file.close() 45 | 46 | # ====================================================== 47 | # tokenize sentence 48 | # ====================================================== 49 | cmd.append(os.path.basename(tmp_file.name)) 50 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 51 | stdout=subprocess.PIPE) 52 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 53 | token_lines = token_lines.decode() 54 | lines = token_lines.split('\n') 55 | # remove temp file 56 | os.remove(tmp_file.name) 57 | 58 | # ====================================================== 59 | # create dictionary for tokenized captions 60 | # ====================================================== 61 | for k, line in zip(image_id, lines): 62 | if not k in final_tokenized_captions_for_image: 63 | final_tokenized_captions_for_image[k] = [] 64 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 65 | if w not in PUNCTUATIONS]) 66 | final_tokenized_captions_for_image[k].append(tokenized_caption) 67 | 68 | return final_tokenized_captions_for_image 69 | 
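# Minimal usage sketch (hypothetical clip id and caption): the tokenizer maps
# {id: [caption, ...]} to lower-cased, punctuation-stripped captions and needs
# the CoreNLP jar at STANFORD_CORENLP_3_4_1_JAR to be present at runtime.
if __name__ == '__main__':
    tokenizer = PTBTokenizer()
    print(tokenizer.tokenize({'clip0': ['A man is cooking dinner.']}))
    # expected: {'clip0': ['a man is cooking dinner']}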
-------------------------------------------------------------------------------- /eval/tvc.py: -------------------------------------------------------------------------------- 1 | """ 2 | reproduce TVC evaluation using pycocoevalcap from Maluuba nlg-eval (Python 3) 3 | """ 4 | import json 5 | 6 | from .pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 7 | from .pycocoevalcap.bleu.bleu import Bleu 8 | from .pycocoevalcap.cider.cider import Cider 9 | from .pycocoevalcap.meteor.meteor import Meteor 10 | from .pycocoevalcap.rouge.rouge import Rouge 11 | 12 | 13 | def _remove_nonascii(text): 14 | return ''.join([i if ord(i) < 128 else ' ' for i in text]) 15 | 16 | 17 | class TVCEval(object): 18 | """ preload evaluation tools and references for repeated evaluation """ 19 | def __init__(self, ref_path): 20 | self.tokenizer = PTBTokenizer() 21 | id2refs = {ex['clip_id']: [_remove_nonascii(cap['desc'].strip()) 22 | for cap in ex['descs']] 23 | for ex in map(json.loads, open(ref_path))} 24 | self.id2refs = self.tokenizer.tokenize(id2refs) 25 | self.scorers = [] 26 | self.scorers.append((Bleu(4), 27 | ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"])) 28 | self.scorers.append((Meteor(), "METEOR")) 29 | self.scorers.append((Rouge(), "ROUGE_L")) 30 | self.scorers.append((Cider(), "CIDEr")) 31 | 32 | def __call__(self, json_res): 33 | """ corpus level metrics, take list of results """ 34 | id2hyps = { 35 | res['clip_id']: [_remove_nonascii(res['descs'][0]['desc'].strip())] 36 | for res in json_res 37 | } 38 | id2hyps = self.tokenizer.tokenize(id2hyps) 39 | assert len(id2hyps) == len(self.id2refs) 40 | 41 | ret_scores = {} 42 | for scorer, method in self.scorers: 43 | print(f"Computing {method} score...") 44 | score, scores = scorer.compute_score(self.id2refs, id2hyps) 45 | if isinstance(method, list): 46 | for sc, scs, m in zip(score, scores, method): 47 | ret_scores[m] = sc * 100 48 | else: 49 | ret_scores[method] = score * 100 50 | 51 | return ret_scores 52 | -------------------------------------------------------------------------------- /eval_violin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | run evaluation of VIOLIN 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | from os.path import exists 11 | from time import time 12 | 13 | import torch 14 | from torch.utils.data import DataLoader 15 | from torch.nn import functional as F 16 | 17 | from apex import amp 18 | from horovod import torch as hvd 19 | 20 | from data import (ViolinEvalDataset, violin_eval_collate, 21 | QaQueryTokLmdb, PrefetchLoader) 22 | from load_data import get_video_ids, load_video_sub_dataset 23 | from model.violin import HeroForViolin 24 | 25 | from utils.basic_utils import save_json, save_pickle 26 | from utils.distributed import all_gather_list 27 | from utils.logger import LOGGER 28 | from utils.const import VFEAT_DIM 29 | from utils.misc import Struct 30 | 31 | 32 | def main(opts): 33 | hvd.init() 34 | n_gpu = hvd.size() 35 | device = torch.device("cuda", hvd.local_rank()) 36 | torch.cuda.set_device(hvd.local_rank()) 37 | LOGGER.info("device: {} n_gpu: {}, rank: {}, " 38 | "16-bits training: {}".format( 39 | device, n_gpu, hvd.rank(), opts.fp16)) 40 | if hvd.rank() != 0: 41 | LOGGER.disabled = True 42 | hps_file = f'{opts.output_dir}/log/hps.json' 43 | model_opts = Struct(json.load(open(hps_file))) 44 | model_config = f'{opts.output_dir}/log/model_config.json' 45 | 46 | # load DBs and image dirs 47 | video_ids = get_video_ids(opts.query_txt_db) 48 | video_db = load_video_sub_dataset( 49 | opts.vfeat_db, opts.sub_txt_db, model_opts.vfeat_interval, 50 | model_opts) 51 | assert opts.split in opts.query_txt_db 52 | q_txt_db = QaQueryTokLmdb(opts.query_txt_db, -1) 53 | eval_dataset = ViolinEvalDataset( 54 | video_ids, video_db, q_txt_db, 55 | sampled_by_q=model_opts.sampled_by_q) 56 | collate_fn = violin_eval_collate 57 | 58 | # Prepare model 59 | if exists(opts.checkpoint): 60 | ckpt_file = opts.checkpoint 61 | else: 62 | ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt' 63 | checkpoint = torch.load(ckpt_file) 64 | img_pos_embed_weight_key = "v_encoder.f_encoder.img_embeddings" +\ 65 | ".position_embeddings.weight" 66 | assert img_pos_embed_weight_key in checkpoint 67 | max_frm_seq_len = len(checkpoint[img_pos_embed_weight_key]) 68 | 69 | model = HeroForViolin.from_pretrained( 70 | model_config, 71 | state_dict=checkpoint, 72 | vfeat_dim=VFEAT_DIM, 73 | max_frm_seq_len=max_frm_seq_len 74 | ) 75 | model.to(device) 76 | if opts.fp16: 77 | model = amp.initialize(model, enabled=opts.fp16, opt_level='O2') 78 | 79 | eval_dataloader = DataLoader(eval_dataset, batch_size=opts.batch_size, 80 | num_workers=opts.n_workers, 81 | pin_memory=opts.pin_mem, 82 | collate_fn=collate_fn) 83 | eval_dataloader = PrefetchLoader(eval_dataloader) 84 | 85 | _, results, logits = validate_violin( 86 | model, eval_dataloader, opts.split, opts.save_logits) 87 | result_dir = f'{opts.output_dir}/results_{opts.split}' 88 | if opts.save_logits: 89 | result_dir += '_w_logit' 90 | if not exists(result_dir) and hvd.rank() == 0: 91 | os.makedirs(result_dir) 92 | 93 | all_results = {} 94 | for id2res in all_gather_list(results): 95 | all_results.update(id2res) 96 | if opts.save_logits: 97 | all_logits = {} 98 | for id2logit in all_gather_list(logits): 99 | all_logits.update(id2logit) 100 | if hvd.rank() == 0: 101 | save_json( 102 | all_results, 103 | f'{result_dir}/results_{opts.checkpoint}_all.json') 104 | LOGGER.info('All results written......') 105 | if opts.save_logits: 106 | save_pickle( 107 | all_logits, 108 | f'{result_dir}/logits_{opts.checkpoint}_all.pkl') 109 | LOGGER.info('All logits 
written......') 110 | 111 | 112 | def compute_accuracies(predictions, labels): 113 | matched_qa = predictions.squeeze() == labels.squeeze() 114 | n_correct_qa = matched_qa.sum().item() 115 | return n_correct_qa 116 | 117 | 118 | @torch.no_grad() 119 | def validate_violin(model, val_loader, split, save_logits=False): 120 | LOGGER.info(f"start running validation on VIOLIN {split} split...") 121 | model.eval() 122 | val_loss = 0 123 | n_ex = 0 124 | tot_score = 0 125 | results = {} 126 | logits = {} 127 | st = time() 128 | 129 | for i, batch in enumerate(val_loader): 130 | targets = batch['targets'] 131 | if 'qids' in batch: 132 | qids = batch['qids'] 133 | del batch['qids'] 134 | 135 | scores = model(batch, "violin", compute_loss=False) 136 | predictions = (torch.sigmoid(scores) > 0.5).long() 137 | answers = predictions.squeeze().cpu().tolist() 138 | for qid, answer in zip(qids, answers): 139 | results[str(qid)] = answer 140 | if save_logits: 141 | scores = scores.cpu().tolist() 142 | for qid, logit in zip(qids, scores): 143 | logits[str(qid)] = logit 144 | 145 | loss = F.binary_cross_entropy( 146 | torch.sigmoid(scores), targets.to(dtype=scores.dtype), 147 | reduction='sum') 148 | val_loss += loss.item() 149 | tot_score += compute_accuracies(predictions, targets) 150 | n_ex += len(qids) 151 | 152 | val_loss = sum(all_gather_list(val_loss)) 153 | tot_score = sum(all_gather_list(tot_score)) 154 | n_ex = sum(all_gather_list(n_ex)) 155 | tot_time = time()-st 156 | val_loss /= n_ex 157 | val_acc = tot_score / n_ex 158 | val_log = { 159 | f'valid/{split}_loss': val_loss, 160 | f'valid/{split}_acc': val_acc, 161 | f'valid/{split}_ex_per_s': n_ex/tot_time} 162 | LOGGER.info(f"validation of {split} split finished in {int(tot_time)}s, " 163 | f"loss:{val_loss:.2f}, score: {val_acc*100:.2f}") 164 | return val_log, results, logits 165 | 166 | 167 | if __name__ == "__main__": 168 | parser = argparse.ArgumentParser() 169 | 170 | # Required parameters 171 | parser.add_argument("--sub_txt_db", 172 | default="/txt/violin_subtitles.db", 173 | type=str, 174 | help="The input video subtitle corpus. (LMDB)") 175 | parser.add_argument("--vfeat_db", 176 | default="/video/violin", type=str, 177 | help="The input video frame features.") 178 | parser.add_argument("--query_txt_db", 179 | default="/txt/violin_test.db", 180 | type=str, 181 | help="The input test query corpus. 
(LMDB)") 182 | parser.add_argument("--split", choices=["val", "test"], 183 | default="test", type=str, 184 | help="The input query split") 185 | parser.add_argument("--checkpoint", 186 | default=None, type=str, 187 | help="pretrained model checkpoint steps") 188 | parser.add_argument("--batch_size", 189 | default=10, type=int, 190 | help="number of queries in a batch") 191 | 192 | parser.add_argument( 193 | "--output_dir", default=None, type=str, 194 | help="The output directory where the model checkpoints will be " 195 | "written.") 196 | 197 | # Prepro parameters 198 | 199 | # device parameters 200 | parser.add_argument('--fp16', 201 | action='store_true', 202 | help="Whether to use 16-bit float precision instead " 203 | "of 32-bit") 204 | parser.add_argument('--n_workers', type=int, default=4, 205 | help="number of data workers") 206 | parser.add_argument('--pin_mem', action='store_true', 207 | help="pin memory") 208 | parser.add_argument("--save_logits", action='store_true', 209 | help="Whether to save logits") 210 | 211 | args = parser.parse_args() 212 | 213 | # options safe guard 214 | # TODO 215 | 216 | main(args) 217 | -------------------------------------------------------------------------------- /inf_tvc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | TVC inference 6 | generate prediction from JSON file 7 | """ 8 | import argparse 9 | import json 10 | from time import time 11 | 12 | import torch 13 | from horovod import torch as hvd 14 | from transformers import RobertaTokenizer 15 | from apex import amp 16 | from tqdm import tqdm 17 | 18 | from data.tvc import TvcEvalDataset 19 | from model.tvc import HeroForTvc, TvcGenerator 20 | from eval.tvc import TVCEval 21 | from utils.misc import Struct 22 | from utils.distributed import all_gather_list 23 | from utils.const import VFEAT_DIM, MAX_FRM_SEQ_LEN 24 | from utils.basic_utils import save_jsonl 25 | 26 | from load_data import load_video_sub_dataset 27 | from train_tvc import build_dataloader 28 | 29 | 30 | def main(opts): 31 | hvd.init() 32 | if hvd.rank() == 0: 33 | toker = RobertaTokenizer.from_pretrained('roberta-base') 34 | all_gather_list(None) 35 | else: 36 | all_gather_list(None) 37 | toker = RobertaTokenizer.from_pretrained('roberta-base') 38 | 39 | model_opts = Struct(json.load(open(f"{opts.model_dir}/log/hps.json"))) 40 | model_config = f"{opts.model_dir}/log/model_config.json" 41 | 42 | video_db = load_video_sub_dataset(model_opts.vfeat_db, 43 | model_opts.sub_txt_db, 44 | model_opts.vfeat_interval, 45 | model_opts) 46 | dset = TvcEvalDataset(video_db, opts.target_clip) 47 | loader = build_dataloader(dset, opts.batch_size, 48 | TvcEvalDataset.collate, False, opts) 49 | 50 | checkpoint = torch.load(f"{opts.model_dir}/ckpt/" 51 | f"model_step_{opts.ckpt_step}.pt") 52 | 53 | img_pos_embed_weight_key = "v_encoder.f_encoder.img_embeddings" +\ 54 | ".position_embeddings.weight" 55 | if img_pos_embed_weight_key in checkpoint: 56 | max_frm_seq_len = len(checkpoint[img_pos_embed_weight_key]) 57 | else: 58 | max_frm_seq_len = MAX_FRM_SEQ_LEN 59 | 60 | model = HeroForTvc.from_pretrained(model_config, 61 | state_dict=checkpoint, 62 | vfeat_dim=VFEAT_DIM, 63 | max_frm_seq_len=max_frm_seq_len, 64 | lsr=model_opts.lsr) 65 | model.cuda() 66 | model = amp.initialize(model, enabled=opts.fp16, opt_level='O2') 67 | 68 | bos = toker.convert_tokens_to_ids(['<s>'])[0] 69 | eos = toker.convert_tokens_to_ids(['</s>'])[0]
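    # '<s>' / '</s>' are RoBERTa's BOS and EOS token ids; they are handed to
    # TvcGenerator below, which presumably uses them to start and stop greedy
    # decoding (generation length is capped by opts.max_gen_step).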
70 | model.eval() 71 | generator = TvcGenerator(model, opts.max_gen_step, bos, eos, opts.fp16) 72 | results = decode(loader, generator, toker) 73 | save_jsonl(results, opts.output) 74 | 75 | # evaluate score if possible 76 | if (hvd.rank() == 0 77 | and 'descs' in json.loads(next(iter(open(opts.target_clip))))): 78 | evaluator = TVCEval(opts.target_clip) 79 | score = evaluator(results) 80 | print(score) 81 | 82 | 83 | def decode(loader, generator, tokenizer): 84 | st = time() 85 | results = [] 86 | for batch in tqdm(loader, desc='decoding...'): 87 | vids = batch['vid_names'] 88 | cids = batch['clip_ids'] 89 | all_ts = batch['all_ts'] 90 | outputs = generator.greedy_decode(batch) 91 | for vid, cid, ts, out_ids in zip(vids, cids, all_ts, outputs): 92 | output = tokenizer.convert_tokens_to_string( 93 | tokenizer.convert_ids_to_tokens(out_ids)) 94 | results.append({'vid_name': vid, 'clip_id': cid, 'ts': ts, 95 | 'descs': [{'desc': output}]}) 96 | results = [r for rs in all_gather_list(results) for r in rs] 97 | print(f'decoding finished in {int(time() - st)} seconds') 98 | return results 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument("--model_dir", required=True, type=str, 104 | help="dir root to trained model") 105 | parser.add_argument("--ckpt_step", required=True, type=int, 106 | help="checkpoint step") 107 | parser.add_argument("--output", type=str, required=True, 108 | help="output file name") 109 | 110 | parser.add_argument("--batch_size", default=8, type=int, 111 | help="validation batch size (per GPU)") 112 | parser.add_argument("--max_gen_step", default=30, type=int, 113 | help="max generation steps") 114 | 115 | parser.add_argument('--n_workers', type=int, default=4, 116 | help="number of data workers") 117 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem', 118 | help="disable pin memory") 119 | parser.add_argument("--no_fp16", action='store_false', dest='fp16', 120 | help="disable fp16") 121 | 122 | parser.add_argument("--target_clip", required=True, type=str, 123 | help="jsonl annotation") 124 | 125 | args = parser.parse_args() 126 | 127 | main(args) 128 | -------------------------------------------------------------------------------- /launch_container.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
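# Usage (sketch): bash launch_container.sh <TXT_DB> <VID_DIR> <OUTPUT> <PRETRAIN_DIR> [--prepro]
# The four paths are bind-mounted into the container as /txt, /video, /storage
# and /pretrain; passing --prepro as the 5th argument mounts the text DB
# read-write instead of read-only.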
3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | TXT_DB=$1 7 | VID_DIR=$2 8 | OUTPUT=$3 9 | PRETRAIN_DIR=$4 10 | 11 | if [ -z $CUDA_VISIBLE_DEVICES ]; then 12 | CUDA_VISIBLE_DEVICES='all' 13 | fi 14 | 15 | if [ "$5" = "--prepro" ]; then 16 | RO="" 17 | else 18 | RO=",readonly" 19 | fi 20 | 21 | docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --network=host --rm -it \ 22 | --mount src=$(pwd),dst=/src,type=bind \ 23 | --mount src=$OUTPUT,dst=/storage,type=bind \ 24 | --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \ 25 | --mount src=$TXT_DB,dst=/txt,type=bind$RO \ 26 | --mount src=$VID_DIR,dst=/video,type=bind,readonly \ 27 | -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 28 | -w /src linjieli222/hero 29 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from utils.basic_utils import load_json 3 | from data import ( 4 | VideoFeatLmdb, SubTokLmdb, VrSubTokLmdb, 5 | VideoFeatSubTokDataset, VideoFeatDataset, 6 | VcmrDataset, vcmr_collate, VcmrEvalDataset, vcmr_eval_collate, 7 | VrVideoOnlyDataset, VrVideoOnlyEvalDataset, 8 | vr_collate, vr_eval_collate, 9 | VrDataset, VrEvalDataset, 10 | VcmrVideoOnlyDataset, VcmrVideoOnlyEvalDataset, 11 | VideoQaDataset, video_qa_collate, 12 | VideoQaEvalDataset, video_qa_eval_collate, 13 | ViolinDataset, violin_collate, 14 | ViolinEvalDataset, violin_eval_collate, 15 | PrefetchLoader) 16 | from utils.logger import LOGGER 17 | from utils.distributed import all_gather_list 18 | import os 19 | 20 | 21 | def get_video_ids(query_txt_db): 22 | if os.path.exists(f'{query_txt_db}/query2video.json'): 23 | q2v = load_json(f'{query_txt_db}/query2video.json') 24 | qids = load_json(f'{query_txt_db}/id2len.json').keys() 25 | video_ids = list(set([q2v[qid] for qid in qids])) 26 | else: 27 | video_ids = load_json(f'{query_txt_db}/video_ids.json') 28 | return video_ids 29 | 30 | 31 | def load_video_sub_dataset(v_feat_path, sub_txt_db, vfeat_interval, opts): 32 | vfeat_db = VideoFeatLmdb( 33 | v_feat_path, opts.vfeat_version, 34 | vfeat_interval, opts.compressed_db, 35 | opts.max_clip_len) 36 | if not isinstance(sub_txt_db, SubTokLmdb): 37 | if hasattr(opts, "task") and "msrvtt" in opts.task: 38 | sub_txt_db = VrSubTokLmdb(sub_txt_db, opts.max_clip_len) 39 | else: 40 | sub_txt_db = SubTokLmdb(sub_txt_db, opts.max_clip_len) 41 | video_db = VideoFeatSubTokDataset( 42 | sub_txt_db, vfeat_db, 43 | sub_ctx_len=opts.sub_ctx_len) 44 | return video_db 45 | 46 | 47 | def load_video_only_dataset(v_feat_path, txt_meta, vfeat_interval, opts): 48 | vfeat_db = VideoFeatLmdb( 49 | v_feat_path, opts.vfeat_version, 50 | vfeat_interval, opts.compressed_db, 51 | opts.max_clip_len) 52 | video_db = VideoFeatDataset( 53 | txt_meta, vfeat_db) 54 | return video_db 55 | 56 | 57 | def build_downstream_dataloaders( 58 | tasks, video_db, video_ids, is_train, opts, 59 | q_txt_db=None, shuffle=False): 60 | dataloaders = {} 61 | assert q_txt_db is not None 62 | for i, task in enumerate(tasks): 63 | if is_train: 64 | LOGGER.info(f"Loading {task} train dataset " 65 | f"{video_db.img_db.img_dir}") 66 | batch_size = opts.train_batch_size 67 | else: 68 | batch_size = opts.val_batch_size 69 | LOGGER.info(f"Loading {task} validation dataset" 70 | f"{video_db.img_db.img_dir}") 71 | if task in ["tvqa", "how2qa"]: 72 | if is_train: 73 | dataset = VideoQaDataset( 74 | video_ids, 
video_db, q_txt_db) 75 | collate_fn = video_qa_collate 76 | else: 77 | dataset = VideoQaEvalDataset( 78 | video_ids, video_db, q_txt_db) 79 | collate_fn = video_qa_eval_collate 80 | elif task in ["tvr", "how2r", "didemo_video_sub"]: 81 | if is_train: 82 | dataset = VcmrDataset( 83 | video_ids, video_db, q_txt_db) 84 | collate_fn = vcmr_collate 85 | else: 86 | dataset = VcmrEvalDataset( 87 | video_ids, video_db, q_txt_db) 88 | collate_fn = vcmr_eval_collate 89 | elif task == "didemo_video_only": 90 | if is_train: 91 | dataset = VcmrVideoOnlyDataset( 92 | video_ids, video_db, q_txt_db) 93 | collate_fn = vcmr_collate 94 | else: 95 | dataset = VcmrVideoOnlyEvalDataset( 96 | video_ids, video_db, q_txt_db) 97 | collate_fn = vcmr_eval_collate 98 | elif task == "msrvtt_video_only": 99 | if is_train: 100 | dataset = VrVideoOnlyDataset( 101 | video_ids, video_db, q_txt_db) 102 | collate_fn = vr_collate 103 | else: 104 | dataset = VrVideoOnlyEvalDataset( 105 | video_ids, video_db, q_txt_db) 106 | collate_fn = vr_eval_collate 107 | elif task == "msrvtt_video_sub": 108 | if is_train: 109 | dataset = VrDataset( 110 | video_ids, video_db, q_txt_db) 111 | collate_fn = vr_collate 112 | else: 113 | dataset = VrEvalDataset( 114 | video_ids, video_db, q_txt_db) 115 | collate_fn = vr_eval_collate 116 | elif task == "violin": 117 | if is_train: 118 | dataset = ViolinDataset( 119 | video_ids, video_db, q_txt_db) 120 | collate_fn = violin_collate 121 | else: 122 | dataset = ViolinEvalDataset( 123 | video_ids, video_db, q_txt_db) 124 | collate_fn = violin_eval_collate 125 | else: 126 | raise ValueError(f'Undefined task {task}') 127 | LOGGER.info(f"{sum(all_gather_list(len(dataset)))} samples loaded") 128 | loader = DataLoader(dataset, batch_size=batch_size, 129 | num_workers=opts.n_workers, 130 | pin_memory=opts.pin_mem, 131 | collate_fn=collate_fn, 132 | shuffle=shuffle) 133 | if is_train: 134 | ratio = 1 135 | dataloaders[task] = (loader, ratio) 136 | else: 137 | dataloaders[task] = PrefetchLoader(loader) 138 | return dataloaders 139 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjieli222/HERO/32c1c523c7a9f547a29f14c8e33dec24ebd14156/model/__init__.py -------------------------------------------------------------------------------- /model/modeling_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | some functions are modified from HuggingFace 6 | (https://github.com/huggingface/transformers) 7 | """ 8 | import torch 9 | from torch import nn 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def prune_linear_layer(layer, index, dim=0): 15 | """ Prune a linear layer (a model parameters) 16 | to keep only entries in index. 17 | Return the pruned layer as a new layer with requires_grad=True. 18 | Used to remove heads. 
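        Illustrative example (a sketch, not from the original source):
            layer = nn.Linear(10, 6)
            pruned = prune_linear_layer(layer, torch.tensor([0, 1, 2]), dim=0)
            # pruned is an nn.Linear(10, 3) that keeps output units 0-2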
19 | """ 20 | index = index.to(layer.weight.device) 21 | W = layer.weight.index_select(dim, index).clone().detach() 22 | if layer.bias is not None: 23 | if dim == 1: 24 | b = layer.bias.clone().detach() 25 | else: 26 | b = layer.bias[index].clone().detach() 27 | new_size = list(layer.weight.size()) 28 | new_size[dim] = len(index) 29 | new_layer = nn.Linear( 30 | new_size[1], new_size[0], bias=layer.bias is not None).to( 31 | layer.weight.device) 32 | new_layer.weight.requires_grad = False 33 | new_layer.weight.copy_(W.contiguous()) 34 | new_layer.weight.requires_grad = True 35 | if layer.bias is not None: 36 | new_layer.bias.requires_grad = False 37 | new_layer.bias.copy_(b.contiguous()) 38 | new_layer.bias.requires_grad = True 39 | return new_layer 40 | 41 | 42 | def mask_logits(target, mask, eps=-1e4): 43 | return target * mask + (1 - mask) * eps 44 | 45 | 46 | def load_partial_checkpoint(checkpoint, n_layers, skip_layers=True): 47 | if skip_layers: 48 | new_checkpoint = {} 49 | gap = int(12/n_layers) 50 | prefix = "roberta.encoder.layer." 51 | layer_range = {str(l): str(i) for i, l in enumerate( 52 | list(range(gap-1, 12, gap)))} 53 | for k, v in checkpoint.items(): 54 | if prefix in k: 55 | layer_name = k.split(".") 56 | layer_num = layer_name[3] 57 | if layer_num in layer_range: 58 | layer_name[3] = layer_range[layer_num] 59 | new_layer_name = ".".join(layer_name) 60 | new_checkpoint[new_layer_name] = v 61 | else: 62 | new_checkpoint[k] = v 63 | else: 64 | new_checkpoint = checkpoint 65 | return new_checkpoint 66 | 67 | 68 | def load_pretrained_weight(model, state_dict): 69 | # Load from a PyTorch state_dict 70 | old_keys = [] 71 | new_keys = [] 72 | for key in state_dict.keys(): 73 | new_key = None 74 | if 'gamma' in key: 75 | new_key = key.replace('gamma', 'weight') 76 | if 'beta' in key: 77 | new_key = key.replace('beta', 'bias') 78 | if new_key: 79 | old_keys.append(key) 80 | new_keys.append(new_key) 81 | for old_key, new_key in zip(old_keys, new_keys): 82 | state_dict[new_key] = state_dict.pop(old_key) 83 | 84 | missing_keys = [] 85 | unexpected_keys = [] 86 | error_msgs = [] 87 | # copy state_dict so _load_from_state_dict can modify it 88 | metadata = getattr(state_dict, '_metadata', None) 89 | state_dict = state_dict.copy() 90 | if metadata is not None: 91 | state_dict._metadata = metadata 92 | 93 | def load(module, prefix=''): 94 | local_metadata = ({} if metadata is None 95 | else metadata.get(prefix[:-1], {})) 96 | module._load_from_state_dict( 97 | state_dict, prefix, local_metadata, True, missing_keys, 98 | unexpected_keys, error_msgs) 99 | for name, child in module._modules.items(): 100 | if child is not None: 101 | load(child, prefix + name + '.') 102 | start_prefix = '' 103 | if not hasattr(model, 'roberta') and\ 104 | any(s.startswith('roberta.') for s in state_dict.keys()): 105 | start_prefix = 'roberta.' 
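    # Note (added comment): if the checkpoint keys carry a 'roberta.' prefix
    # but the target model has no `roberta` submodule, the recursive load()
    # below starts from that prefix so checkpoint keys still line up with the
    # model's modules.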
106 | 107 | load(model, prefix=start_prefix) 108 | if len(missing_keys) > 0: 109 | logger.info("Weights of {} not initialized from " 110 | "pretrained model: {}".format( 111 | model.__class__.__name__, missing_keys)) 112 | if len(unexpected_keys) > 0: 113 | logger.info("Weights from pretrained model not used in " 114 | "{}: {}".format( 115 | model.__class__.__name__, unexpected_keys)) 116 | if len(error_msgs) > 0: 117 | raise RuntimeError('Error(s) in loading state_dict for ' 118 | '{}:\n\t{}'.format( 119 | model.__class__.__name__, 120 | "\n\t".join(error_msgs))) 121 | return model 122 | 123 | 124 | def pad_tensor_to_mul(tensor, dim=0, mul=8): 125 | """ pad tensor to multiples (8 for tensor cores) """ 126 | t_size = list(tensor.size()) 127 | n_pad = mul - t_size[dim] % mul 128 | if n_pad == mul: 129 | n_pad = 0 130 | padded_tensor = tensor 131 | else: 132 | t_size[dim] = n_pad 133 | pad = torch.zeros(*t_size, dtype=tensor.dtype, device=tensor.device) 134 | padded_tensor = torch.cat([tensor, pad], dim=dim) 135 | return padded_tensor, n_pad 136 | -------------------------------------------------------------------------------- /model/vcmr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | HERO for Video Corpus Moment Retrieval Tasks, shared by: 6 | 1. TVR 7 | 2. How2R 8 | 3. DiDeMo with video and sub 9 | 4. DiDeMo with video only 10 | """ 11 | from .pretrain import HeroForPretraining 12 | 13 | 14 | class HeroForVcmr(HeroForPretraining): 15 | def __init__(self, config, vfeat_dim, max_frm_seq_len, 16 | conv_stride=1, conv_kernel_size=5, 17 | ranking_loss_type="hinge", margin=0.1, 18 | lw_neg_ctx=0, lw_neg_q=0, lw_st_ed=0.01, drop_svmr_prob=0, 19 | use_hard_negative=False, hard_pool_size=20, 20 | hard_neg_weight=10, use_all_neg=True): 21 | super().__init__( 22 | config, vfeat_dim, max_frm_seq_len, 23 | conv_stride, conv_kernel_size, 24 | ranking_loss_type, margin, 25 | lw_neg_ctx, lw_neg_q, lw_st_ed, drop_svmr_prob, 26 | use_hard_negative, hard_pool_size, 27 | hard_neg_weight, use_all_neg) 28 | 29 | def forward(self, batch, task='tvr', compute_loss=True): 30 | if task in ['tvr', 'how2r', 'didemo_video_sub', 31 | 'didemo_video_only']: 32 | return super().forward( 33 | batch, task='vsm', compute_loss=compute_loss) 34 | else: 35 | raise ValueError(f'Unrecognized task {task}') 36 | 37 | def get_pred_from_raw_query(self, frame_embeddings, c_attn_masks, 38 | query_input_ids, query_pos_ids, 39 | query_attn_masks, cross=False, 40 | val_gather_gpus=False): 41 | modularized_query = self.encode_txt_inputs( 42 | query_input_ids, query_pos_ids, 43 | query_attn_masks, attn_layer=self.q_feat_attn, 44 | normalized=False) 45 | 46 | st_prob, ed_prob = self.get_pred_from_mod_query( 47 | frame_embeddings, c_attn_masks, 48 | modularized_query, cross=cross) 49 | 50 | if self.lw_neg_ctx != 0 or self.lw_neg_q != 0: 51 | q2video_scores = self.get_video_level_scores( 52 | modularized_query, frame_embeddings, c_attn_masks, 53 | val_gather_gpus) 54 | else: 55 | q2video_scores = None 56 | return q2video_scores, st_prob, ed_prob 57 | -------------------------------------------------------------------------------- /model/videoQA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | HERO for Video Question Answering Tasks, shared by: 6 | 1. TVQA 7 | 2. 
How2QA 8 | """ 9 | from collections import defaultdict 10 | import copy 11 | 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | 16 | from .model import HeroModel 17 | from .layers import MLPLayer 18 | from .modeling_utils import mask_logits 19 | 20 | 21 | class HeroForVideoQA(HeroModel): 22 | def __init__(self, config, vfeat_dim, max_frm_seq_len): 23 | super().__init__( 24 | config, vfeat_dim, max_frm_seq_len) 25 | 26 | hsz = config.c_config.hidden_size 27 | 28 | self.qa_pool = nn.Linear( 29 | in_features=hsz, out_features=1, bias=False) 30 | self.qa_pred_head = MLPLayer(hsz, 1) 31 | 32 | # in tvqa/how2qa, we also have annotations for st and ed frame idx 33 | self.st_ed_pool = copy.deepcopy(self.qa_pool) 34 | self.st_ed_pred_head = MLPLayer(hsz, 2) 35 | 36 | def get_modularized_video(self, frame_embeddings, frame_mask): 37 | """ 38 | Args: 39 | frame_embeddings: (Nv, Nq, L, D) 40 | frame_mask: (Nv, Nq, L) 41 | """ 42 | st_ed_attn_scores = self.st_ed_pool( 43 | frame_embeddings) # (Nv, Nq, L, 1) 44 | qa_attn_scores = self.qa_pool(frame_embeddings) 45 | 46 | st_ed_attn_scores = F.softmax( 47 | mask_logits(st_ed_attn_scores, 48 | frame_mask.unsqueeze(-1)), dim=1) 49 | qa_attn_scores = F.softmax( 50 | mask_logits(qa_attn_scores, 51 | frame_mask.unsqueeze(-1)), dim=2) 52 | # TODO check whether it is the same 53 | st_ed_pooled_video = torch.einsum( 54 | "vqlm,vqld->vlmd", st_ed_attn_scores, 55 | frame_embeddings) # (Nv, L, 1, D) 56 | qa_pooled_video = torch.einsum( 57 | "vqlm,vqld->vqmd", qa_attn_scores, 58 | frame_embeddings) # (Nv, Nq, 1, D) 59 | return st_ed_pooled_video.squeeze(2), qa_pooled_video.squeeze(2) 60 | 61 | def forward(self, batch, task='tvqa', compute_loss=True): 62 | batch = defaultdict(lambda: None, batch) 63 | if task == 'tvqa' or task == 'how2qa': 64 | targets = batch['targets'].squeeze(-1) 65 | c_attn_masks = batch["c_attn_masks"] 66 | ts_targets = batch["ts_targets"] 67 | # (num_video * 5, num_frames, hid_size) 68 | frame_embeddings = self.v_encoder.forward_repr( 69 | batch, encode_clip=False) 70 | frame_embeddings = self.v_encoder.c_encoder.embeddings( 71 | frame_embeddings, 72 | position_ids=None) 73 | qa_embeddings = self.v_encoder.f_encoder._compute_txt_embeddings( 74 | batch["qa_input_ids"], batch["qa_pos_ids"], txt_type_ids=None) 75 | frame_qa_embeddings = torch.cat( 76 | (frame_embeddings, qa_embeddings), dim=1) 77 | frame_qa_attn_mask = torch.cat( 78 | (c_attn_masks, batch["qa_attn_masks"]), dim=1) 79 | fused_video_qa = self.v_encoder.c_encoder.forward_encoder( 80 | frame_qa_embeddings, frame_qa_attn_mask) 81 | num_frames = c_attn_masks.shape[1] 82 | video_embeddings = fused_video_qa[:, :num_frames, :] 83 | 84 | num_videos = len(targets) 85 | num_frames, hid_size = video_embeddings.shape[1:3] 86 | video_embeddings = video_embeddings.view( 87 | num_videos, -1, num_frames, hid_size) 88 | video_masks = c_attn_masks.view(num_videos, -1, num_frames) 89 | video_masks = video_masks.to(dtype=video_embeddings.dtype) 90 | st_ed_pooled_video, qa_pooled_video = self.get_modularized_video( 91 | video_embeddings, video_masks) 92 | pred_st_ed = self.st_ed_pred_head(st_ed_pooled_video) 93 | st_prob = mask_logits(pred_st_ed[:, :, 0], video_masks[:, 0]) 94 | ed_prob = mask_logits(pred_st_ed[:, :, 1], video_masks[:, 0]) 95 | logits = self.qa_pred_head(qa_pooled_video).squeeze(-1) 96 | 97 | if compute_loss: 98 | st_target, ed_target = ts_targets[:, 0], ts_targets[:, 1] 99 | st_loss = F.cross_entropy( 100 | st_prob, st_target, reduction="mean", 
101 | ignore_index=-1) 102 | ed_loss = F.cross_entropy( 103 | ed_prob, ed_target, reduction="mean", 104 | ignore_index=-1) 105 | temporal_loss = (st_loss + ed_loss)/2. 106 | qa_loss = F.cross_entropy(logits, targets, reduction='mean', 107 | ignore_index=-1) 108 | return qa_loss, temporal_loss 109 | else: 110 | return logits 111 | else: 112 | raise ValueError(f'Unrecognized task: {task}') 113 | -------------------------------------------------------------------------------- /model/violin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | HERO for VIOLIN 6 | """ 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from .model import HeroModel 14 | from .layers import MLPLayer 15 | from .modeling_utils import mask_logits 16 | 17 | 18 | class HeroForViolin(HeroModel): 19 | def __init__(self, config, vfeat_dim, max_frm_seq_len): 20 | super().__init__( 21 | config, vfeat_dim, max_frm_seq_len) 22 | hsz = config.c_config.hidden_size 23 | 24 | self.violin_pool = nn.Linear( 25 | in_features=hsz, 26 | out_features=1, 27 | bias=False) 28 | self.violin_pred_head = MLPLayer(hsz, 1) 29 | 30 | def get_modularized_video(self, frame_embeddings, frame_mask): 31 | """ 32 | Args: 33 | frame_embeddings: (Nv, L, D) 34 | frame_mask: (Nv, L) 35 | """ 36 | violin_attn_scores = self.violin_pool( 37 | frame_embeddings) # (Nv, L, 1) 38 | 39 | violin_attn_scores = F.softmax( 40 | mask_logits(violin_attn_scores, 41 | frame_mask.unsqueeze(-1)), dim=1) 42 | 43 | # TODO check whether it is the same 44 | violin_pooled_video = torch.einsum( 45 | "vlm,vld->vmd", violin_attn_scores, 46 | frame_embeddings) # (Nv, 1, D) 47 | return violin_pooled_video.squeeze(1) 48 | 49 | def forward(self, batch, task='violin', compute_loss=True): 50 | batch = defaultdict(lambda: None, batch) 51 | if task == 'violin': 52 | c_attn_masks = batch["c_attn_masks"] 53 | # (num_video * 5, num_frames, hid_size) 54 | frame_embeddings = self.v_encoder.forward_repr( 55 | batch, encode_clip=False) 56 | frame_embeddings = self.v_encoder.c_encoder.embeddings( 57 | frame_embeddings, 58 | position_ids=None) 59 | q_embeddings = self.v_encoder.f_encoder._compute_txt_embeddings( 60 | batch["q_input_ids"], batch["q_pos_ids"], txt_type_ids=None) 61 | frame_q_embeddings = torch.cat( 62 | (frame_embeddings, q_embeddings), dim=1) 63 | frame_q_attn_mask = torch.cat( 64 | (c_attn_masks, batch["q_attn_masks"]), dim=1) 65 | fused_video_q = self.v_encoder.c_encoder.forward_encoder( 66 | frame_q_embeddings, frame_q_attn_mask) 67 | num_frames = c_attn_masks.shape[1] 68 | video_embeddings = fused_video_q[:, :num_frames, :] 69 | 70 | video_masks = c_attn_masks.to(dtype=video_embeddings.dtype) 71 | violin_pooled_video = self.get_modularized_video( 72 | video_embeddings, video_masks) 73 | logits = self.violin_pred_head(violin_pooled_video) 74 | 75 | if compute_loss: 76 | targets = batch['targets'] 77 | scores = torch.sigmoid(logits).squeeze(-1) 78 | targets = targets.squeeze(-1).to(dtype=scores.dtype) 79 | violin_loss = F.binary_cross_entropy( 80 | scores, targets, reduction='mean') 81 | return violin_loss 82 | else: 83 | return logits 84 | raise ValueError(f'Unrecognized task: {task}') 85 | -------------------------------------------------------------------------------- /model/vr.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | HERO for Video Retrieval Tasks, shared by: 6 | 1. MSR-VTT with video and sub 7 | 2. MSR-VTT with video only 8 | """ 9 | from .vcmr import HeroForVcmr 10 | 11 | 12 | class HeroForVr(HeroForVcmr): 13 | def __init__(self, config, vfeat_dim, max_frm_seq_len, 14 | ranking_loss_type="hinge", margin=0.1, 15 | lw_neg_ctx=1, lw_neg_q=1, 16 | use_hard_negative=False, hard_pool_size=20, 17 | hard_neg_weight=10, use_all_neg=True): 18 | assert lw_neg_ctx != 0 or lw_neg_q != 0,\ 19 | "Need to set lw_neg_ctx or lw_neg_q for VR training" 20 | super().__init__( 21 | config, vfeat_dim, max_frm_seq_len, 22 | ranking_loss_type=ranking_loss_type, margin=margin, 23 | lw_neg_ctx=lw_neg_ctx, lw_neg_q=lw_neg_q, 24 | lw_st_ed=0, drop_svmr_prob=1.0, 25 | use_hard_negative=use_hard_negative, 26 | hard_pool_size=hard_pool_size, 27 | hard_neg_weight=hard_neg_weight, 28 | use_all_neg=use_all_neg) 29 | assert self.lw_st_ed == 0, "For VR, lw_st_ed should be 0" 30 | 31 | def forward(self, batch, task='msrvtt_video_sub', compute_loss=True): 32 | if task in ['msrvtt_video_sub', 'msrvtt_video_only']: 33 | if compute_loss: 34 | _, loss_neg_ctx, loss_neg_q = super().forward( 35 | batch, task='tvr', compute_loss=True) 36 | return loss_neg_ctx, loss_neg_q 37 | else: 38 | q2video_scores, _, _ = super().forward( 39 | batch, task='tvr', compute_loss=False) 40 | return q2video_scores 41 | else: 42 | raise ValueError(f'Unrecognized task {task}') 43 | 44 | def get_pred_from_raw_query(self, frame_embeddings, c_attn_masks, 45 | query_input_ids, query_pos_ids, 46 | query_attn_masks, cross=False, 47 | val_gather_gpus=False): 48 | modularized_query = self.encode_txt_inputs( 49 | query_input_ids, query_pos_ids, 50 | query_attn_masks, attn_layer=self.q_feat_attn, 51 | normalized=False) 52 | 53 | q2video_scores = self.get_video_level_scores( 54 | modularized_query, frame_embeddings, c_attn_masks, 55 | val_gather_gpus) 56 | return q2video_scores 57 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Copied from UNITER 6 | (https://github.com/ChenRocks/UNITER) 7 | """ 8 | from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched 9 | from .adamw import AdamW -------------------------------------------------------------------------------- /optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamW optimizer (weight decay fix) 3 | originally from hugginface (https://github.com/huggingface/transformers). 4 | 5 | Copied from UNITER 6 | (https://github.com/ChenRocks/UNITER) 7 | """ 8 | import math 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | class AdamW(Optimizer): 15 | """ Implements Adam algorithm with weight decay fix. 16 | Parameters: 17 | lr (float): learning rate. Default 1e-3. 18 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). 19 | Default: (0.9, 0.999) 20 | eps (float): Adams epsilon. Default: 1e-6 21 | weight_decay (float): Weight decay. Default: 0.0 22 | correct_bias (bool): can be set to False to avoid correcting bias 23 | in Adam (e.g. like in Bert TF repository). Default True. 
24 | """ 25 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 26 | weight_decay=0.0, correct_bias=True): 27 | if lr < 0.0: 28 | raise ValueError( 29 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 30 | if not 0.0 <= betas[0] < 1.0: 31 | raise ValueError("Invalid beta parameter: {} - " 32 | "should be in [0.0, 1.0[".format(betas[0])) 33 | if not 0.0 <= betas[1] < 1.0: 34 | raise ValueError("Invalid beta parameter: {} - " 35 | "should be in [0.0, 1.0[".format(betas[1])) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {} - " 38 | "should be >= 0.0".format(eps)) 39 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 40 | correct_bias=correct_bias) 41 | super(AdamW, self).__init__(params, defaults) 42 | 43 | def step(self, closure=None): 44 | """Performs a single optimization step. 45 | Arguments: 46 | closure (callable, optional): A closure that reevaluates the model 47 | and returns the loss. 48 | """ 49 | loss = None 50 | if closure is not None: 51 | loss = closure() 52 | 53 | for group in self.param_groups: 54 | for p in group['params']: 55 | if p.grad is None: 56 | continue 57 | grad = p.grad.data 58 | if grad.is_sparse: 59 | raise RuntimeError( 60 | 'Adam does not support sparse ' 61 | 'gradients, please consider SparseAdam instead') 62 | 63 | state = self.state[p] 64 | 65 | # State initialization 66 | if len(state) == 0: 67 | state['step'] = 0 68 | # Exponential moving average of gradient values 69 | state['exp_avg'] = torch.zeros_like(p.data) 70 | # Exponential moving average of squared gradient values 71 | state['exp_avg_sq'] = torch.zeros_like(p.data) 72 | 73 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 74 | beta1, beta2 = group['betas'] 75 | 76 | state['step'] += 1 77 | 78 | # Decay the first and second moment running average coefficient 79 | # In-place operations to update the averages at the same time 80 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 82 | denom = exp_avg_sq.sqrt().add_(group['eps']) 83 | 84 | step_size = group['lr'] 85 | if group['correct_bias']: # No bias correction for Bert 86 | bias_correction1 = 1.0 - beta1 ** state['step'] 87 | bias_correction2 = 1.0 - beta2 ** state['step'] 88 | step_size = (step_size * math.sqrt(bias_correction2) 89 | / bias_correction1) 90 | 91 | p.data.addcdiv_(-step_size, exp_avg, denom) 92 | 93 | # Just adding the square of the weights to the loss function is 94 | # *not* the correct way of using L2 regularization/weight decay 95 | # with Adam, since that will interact with the m and v 96 | # parameters in strange ways. 97 | # 98 | # Instead we want to decay the weights in a manner that doesn't 99 | # interact with the m/v parameters. This is equivalent to 100 | # adding the square of the weights to the loss with plain 101 | # (non-momentum) SGD. 102 | # Add weight decay at the end (fixed version) 103 | if group['weight_decay'] > 0.0: 104 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 105 | 106 | return loss 107 | -------------------------------------------------------------------------------- /optim/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
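build_optimizer below creates four parameter groups: parameters outside the
pretrained v_encoder (the newly added task layers) use lr_mul * learning_rate,
v_encoder parameters use the base learning_rate, and bias/LayerNorm parameters
in either group are excluded from weight decay.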
4 | 5 | Copied from UNITER 6 | (https://github.com/ChenRocks/UNITER) 7 | 8 | Misc lr helper 9 | """ 10 | from torch.optim import Adam, Adamax 11 | from .adamw import AdamW 12 | 13 | 14 | def build_optimizer(model, opts): 15 | # Prepare optimizer 16 | param_optimizer = [(n, p) for n, p in model.named_parameters() 17 | if 'v_encoder' in n and p.requires_grad] 18 | # top layer has larger learning rate 19 | param_top = [(n, p) for n, p in model.named_parameters() 20 | if 'v_encoder' not in n and p.requires_grad] 21 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 22 | optimizer_grouped_parameters = [ 23 | {'params': [p for n, p in param_top 24 | if not any(nd in n for nd in no_decay)], 25 | 'lr': opts.lr_mul*opts.learning_rate, 26 | 'weight_decay': opts.weight_decay}, 27 | {'params': [p for n, p in param_top 28 | if any(nd in n for nd in no_decay)], 29 | 'lr': opts.lr_mul*opts.learning_rate, 30 | 'weight_decay': 0.0}, 31 | {'params': [p for n, p in param_optimizer 32 | if not any(nd in n for nd in no_decay)], 33 | 'weight_decay': opts.weight_decay}, 34 | {'params': [p for n, p in param_optimizer 35 | if any(nd in n for nd in no_decay)], 36 | 'weight_decay': 0.0} 37 | ] 38 | 39 | # currently Adam only 40 | if opts.optim == 'adam': 41 | OptimCls = Adam 42 | elif opts.optim == 'adamax': 43 | OptimCls = Adamax 44 | elif opts.optim == 'adamw': 45 | OptimCls = AdamW 46 | else: 47 | raise ValueError('invalid optimizer') 48 | optimizer = OptimCls(optimizer_grouped_parameters, 49 | lr=opts.learning_rate, betas=opts.betas) 50 | return optimizer 51 | -------------------------------------------------------------------------------- /optim/sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
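For example (illustrative numbers): with warmup_step=1000 and tot_step=10000,
warmup_linear returns 0.5 at step 500 (linear warmup) and 0.5 again at step
5500 (linear decay toward 0 at tot_step).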
4 | 5 | Copied from UNITER 6 | (https://github.com/ChenRocks/UNITER) 7 | 8 | optimizer learning rate scheduling helpers 9 | """ 10 | from math import ceil 11 | 12 | 13 | def noam_schedule(step, warmup_step=4000): 14 | """ original Transformer schedule""" 15 | if step <= warmup_step: 16 | return step / warmup_step 17 | return (warmup_step ** 0.5) * (step ** -0.5) 18 | 19 | 20 | def warmup_linear(step, warmup_step, tot_step): 21 | """ BERT schedule """ 22 | if step < warmup_step: 23 | return step / warmup_step 24 | return max(0, (tot_step-step)/(tot_step-warmup_step)) 25 | 26 | 27 | def vqa_schedule(step, warmup_interval, decay_interval, 28 | decay_start, decay_rate): 29 | """ VQA schedule from MCAN """ 30 | if step < warmup_interval: 31 | return 1/4 32 | elif step < 2 * warmup_interval: 33 | return 2/4 34 | elif step < 3 * warmup_interval: 35 | return 3/4 36 | elif step >= decay_start: 37 | num_decay = ceil((step - decay_start) / decay_interval) 38 | return decay_rate ** num_decay 39 | else: 40 | return 1 41 | 42 | 43 | def get_lr_sched(global_step, opts): 44 | # learning rate scheduling 45 | lr_this_step = opts.learning_rate * warmup_linear( 46 | global_step, opts.warmup_steps, opts.num_train_steps) 47 | if lr_this_step <= 0: 48 | lr_this_step = 1e-8 49 | return lr_this_step -------------------------------------------------------------------------------- /scripts/collect_video_feature_paths.py: -------------------------------------------------------------------------------- 1 | """ 2 | gather slowfast/resnet feature paths 3 | """ 4 | import os 5 | import numpy as np 6 | import pickle as pkl 7 | import argparse 8 | from tqdm import tqdm 9 | from cytoolz import curry 10 | import multiprocessing as mp 11 | 12 | 13 | @curry 14 | def load_npz(slowfast_dir, resnet_dir, slowfast_f): 15 | vid = slowfast_f.split("/")[-1].split(".npz")[0] 16 | folder_name = slowfast_f.split("/")[-2] 17 | resnet_f = slowfast_f.replace(slowfast_dir, resnet_dir) 18 | try: 19 | slowfast_data = np.load(slowfast_f, allow_pickle=True) 20 | slowfast_frame_len = max(0, len(slowfast_data["features"])) 21 | except Exception: 22 | slowfast_frame_len = 0 23 | resnet_frame_len = 0 24 | if slowfast_frame_len == 0: 25 | slowfast_f = "" 26 | print(f"Corrupted slowfast files for {vid}") 27 | # print(resnet_f) 28 | if not os.path.exists(resnet_f): 29 | resnet_f = "" 30 | print(f"resnet files for {vid} does not exists") 31 | else: 32 | try: 33 | resnet_data = np.load(resnet_f, allow_pickle=True) 34 | resnet_frame_len = len(resnet_data["features"]) 35 | except Exception: 36 | resnet_frame_len = 0 37 | resnet_f = "" 38 | print(f"Corrupted resnet files for {vid}") 39 | frame_len = min(slowfast_frame_len, resnet_frame_len) 40 | return vid, frame_len, slowfast_f, resnet_f, folder_name 41 | 42 | 43 | def main(opts): 44 | slowfast_dir = os.path.join(opts.feature_dir, "slowfast_features/") 45 | resnet_dir = os.path.join(opts.feature_dir, "resnet_features/") 46 | failed_resnet_files = [] 47 | failed_slowfast_files = [] 48 | loaded_file = [] 49 | for root, dirs, curr_files in os.walk(f'{slowfast_dir}/'): 50 | for f in curr_files: 51 | if f.endswith('.npz'): 52 | slowfast_f = os.path.join(root, f) 53 | loaded_file.append(slowfast_f) 54 | print(f"Found {len(loaded_file)} slowfast files....") 55 | print(f"sample loaded_file: {loaded_file[:3]}") 56 | failed_resnet_files, failed_slowfast_files = [], [] 57 | files = {} 58 | load = load_npz(slowfast_dir, resnet_dir) 59 | with mp.Pool(opts.nproc) as pool, tqdm(total=len(loaded_file)) as pbar: 60 | for 
i, (vid, frame_len, slowfast_f, 61 | resnet_f, folder_name) in enumerate( 62 | pool.imap_unordered(load, loaded_file, chunksize=128)): 63 | files[vid] = (frame_len, slowfast_f, resnet_f, folder_name) 64 | if resnet_f == "": 65 | video_file = os.path.join(folder_name, vid) 66 | failed_resnet_files.append(video_file) 67 | if slowfast_f == "": 68 | video_file = os.path.join(folder_name, vid) 69 | failed_slowfast_files.append(video_file) 70 | pbar.update(1) 71 | output_dir = os.path.join(opts.output, opts.dataset) 72 | if not os.path.exists(output_dir): 73 | os.makedirs(output_dir, exist_ok=True) 74 | pkl.dump(files, open(os.path.join( 75 | output_dir, "video_feat_info.pkl"), "wb")) 76 | if len(failed_slowfast_files): 77 | pkl.dump(failed_slowfast_files, open(os.path.join( 78 | output_dir, "failed_slowfast_files.pkl"), "wb")) 79 | if len(failed_resnet_files): 80 | pkl.dump(failed_resnet_files, open(os.path.join( 81 | output_dir, "failed_resnet_files.pkl"), "wb")) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--feature_dir", 87 | default="", 88 | type=str, help="The input video feature dir.") 89 | parser.add_argument("--output", default=None, type=str, 90 | help="output dir") 91 | parser.add_argument('--dataset', type=str, 92 | default="") 93 | parser.add_argument('--nproc', type=int, default=10, 94 | help='number of cores used') 95 | args = parser.parse_args() 96 | main(args) 97 | -------------------------------------------------------------------------------- /scripts/convert_videodb.py: -------------------------------------------------------------------------------- 1 | """ 2 | convert feature npz file to lmdb 3 | """ 4 | import argparse 5 | import glob 6 | import io 7 | import json 8 | import multiprocessing as mp 9 | import os 10 | from os.path import exists 11 | 12 | from cytoolz import curry 13 | import numpy as np 14 | from tqdm import tqdm 15 | import lmdb 16 | import pickle as pkl 17 | 18 | import msgpack 19 | import msgpack_numpy 20 | msgpack_numpy.patch() 21 | 22 | 23 | @curry 24 | def load_npz(fname): 25 | try: 26 | vid, nframes, slowfast_fname, resnet_fname, _ = fname 27 | except Exception: 28 | vid, nframes, slowfast_fname, resnet_fname = fname 29 | try: 30 | if nframes == 0: 31 | raise ValueError('wrong ndim') 32 | slowfast_features = np.load( 33 | slowfast_fname, allow_pickle=True)["features"] 34 | if slowfast_features.dtype == np.float16: 35 | slowfast_features = slowfast_features.astype(np.float32) 36 | resnet_features = np.load( 37 | resnet_fname, allow_pickle=True)["features"] 38 | if resnet_features.dtype == np.float16: 39 | resnet_features = resnet_features.astype(np.float32) 40 | resnet_features = resnet_features[:nframes, :] 41 | slowfast_features = slowfast_features[:nframes, :] 42 | dump = {"features": np.concatenate( 43 | (resnet_features, slowfast_features), axis=1)} 44 | except Exception as e: 45 | # corrupted file 46 | print(f'corrupted file {vid}', e) 47 | dump = {} 48 | nframes = 0 49 | 50 | return vid, dump, nframes 51 | 52 | 53 | def dumps_npz(dump, compress=False): 54 | with io.BytesIO() as writer: 55 | if compress: 56 | np.savez_compressed(writer, **dump, allow_pickle=True) 57 | else: 58 | np.savez(writer, **dump, allow_pickle=True) 59 | return writer.getvalue() 60 | 61 | 62 | def dumps_msgpack(dump): 63 | return msgpack.dumps(dump, use_bin_type=True) 64 | 65 | 66 | def main(opts): 67 | db_name = f'{opts.feat_version}_{opts.frame_length}' 68 | if opts.compress: 69 | db_name += '_compressed' 70 | 
if not exists(f'{opts.output}/{opts.dataset}'): 71 | os.makedirs(f'{opts.output}/{opts.dataset}') 72 | env = lmdb.open(f'{opts.output}/{opts.dataset}/{db_name}', map_size=1024**4) 73 | txn = env.begin(write=True) 74 | clip_interval = int(opts.clip_interval/opts.frame_length) 75 | # files = glob.glob(f'{opts.img_dir}/*.npz') 76 | files_dict = pkl.load(open(opts.vfeat_info_file, "rb")) 77 | files = [[key]+list(val) for key, val in files_dict.items()] 78 | # for root, dirs, curr_files in os.walk(f'{opts.img_dir}/'): 79 | # for f in curr_files: 80 | # if f.endswith('.npz'): 81 | # files.append(os.path.join(root, f)) 82 | load = load_npz() 83 | name2nframes = {} 84 | corrupted_files = set() 85 | with mp.Pool(opts.nproc) as pool, tqdm(total=len(files)) as pbar: 86 | for i, (fname, features, nframes) in enumerate( 87 | pool.imap_unordered(load, files, chunksize=128)): 88 | if not features or nframes == 0: 89 | pbar.update(1) 90 | corrupted_files.add(fname) 91 | continue # corrupted feature 92 | if opts.clip_interval != -1: 93 | feature_values = features["features"] 94 | clip_id = 0 95 | for st_ind in range(0, nframes, clip_interval): 96 | clip_name = fname+f".{clip_id}" 97 | ed_ind = min(st_ind + clip_interval, nframes) 98 | clip_features = { 99 | "features": feature_values[st_ind: ed_ind]} 100 | clip_id += 1 101 | if opts.compress: 102 | clip_dump = dumps_npz(clip_features, compress=True) 103 | else: 104 | clip_dump = dumps_msgpack(clip_features) 105 | txn.put(key=clip_name.encode('utf-8'), value=clip_dump) 106 | name2nframes[clip_name] = ed_ind - st_ind 107 | else: 108 | if opts.compress: 109 | dump = dumps_npz(features, compress=True) 110 | else: 111 | dump = dumps_msgpack(features) 112 | txn.put(key=fname.encode('utf-8'), value=dump) 113 | name2nframes[fname] = nframes 114 | if i % 1000 == 0: 115 | txn.commit() 116 | txn = env.begin(write=True) 117 | pbar.update(1) 118 | txn.commit() 119 | env.close() 120 | id2frame_len_file = f'{opts.output}/{opts.dataset}/id2nframe.json' 121 | if os.path.exists(id2frame_len_file): 122 | id2frame = json.load(open(id2frame_len_file, "r")) 123 | for key, val in id2frame.items(): 124 | if val != name2nframes[key]: 125 | print(f"Mismatch: {val} vs. 
{name2nframes[key]}") 126 | id2frame[key] = min(val, name2nframes[key]) 127 | assert id2frame[key] > 0 128 | else: 129 | id2frame = name2nframes 130 | with open(id2frame_len_file, 'w') as f: 131 | json.dump(id2frame, f) 132 | corrupted_files = list(corrupted_files) 133 | if len(corrupted_files) > 0: 134 | corrupted_output_file = f'{opts.output}/{opts.dataset}/corrupted.json' 135 | with open(corrupted_output_file, 'w') as f: 136 | json.dump(corrupted_files, f) 137 | 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("--vfeat_info_file", default=None, type=str, 142 | help="The input feature paths stored in pkl file.") 143 | parser.add_argument("--output", default=None, type=str, 144 | help="output lmdb") 145 | parser.add_argument( 146 | '--frame_length', type=float, default=2, 147 | help='1 feature per "frame_length" seconds used in feature extraction,' 148 | 'in seconds (1.5/2)') 149 | parser.add_argument('--dataset', type=str, 150 | default="") 151 | parser.add_argument('--feat_version', type=str, 152 | default="resnet_slowfast") 153 | parser.add_argument('--nproc', type=int, default=4, 154 | help='number of cores used') 155 | parser.add_argument('--compress', action='store_true', 156 | help='compress the tensors') 157 | parser.add_argument( 158 | '--clip_interval', type=int, default=-1, 159 | help="cut the whole video into small clips, in seconds" 160 | "set to 60 for HowTo100M videos, set to -1 otherwise") 161 | args = parser.parse_args() 162 | main(args) 163 | -------------------------------------------------------------------------------- /scripts/create_txtdb.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # https://github.com/ChenRocks/UNITER 5 | 6 | OUT_DIR=$1 7 | ANN_DIR=$2 8 | 9 | set -e 10 | 11 | # annotations 12 | URL='https://raw.githubusercontent.com/jayleicn/TVRetrieval/master/data' 13 | BLOB='https://convaisharables.blob.core.windows.net/hero' 14 | 15 | if [ ! -d $OUT_DIR ]; then 16 | mkdir -p $OUT_DIR 17 | fi 18 | if [ ! -d $ANN_DIR ]; then 19 | mkdir -p $ANN_DIR 20 | fi 21 | 22 | for SPLIT in 'train' 'val' 'test_public'; do 23 | if [ ! -f $ANN_DIR/tvr_$SPLIT.jsonl ]; then 24 | echo "downloading ${SPLIT} annotations..." 25 | wget $URL/tvr_${SPLIT}_release.jsonl -O $ANN_DIR/tvr_$SPLIT.jsonl 26 | fi 27 | 28 | echo "preprocessing tvr ${SPLIT} annotations..." 29 | docker run --ipc=host --rm -it \ 30 | --mount src=$(pwd),dst=/src,type=bind \ 31 | --mount src=$OUT_DIR,dst=/txt_db,type=bind \ 32 | --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \ 33 | -w /src linjieli222/hero \ 34 | python script/prepro_query.py --annotation /ann/tvr_$SPLIT.jsonl \ 35 | --output /txt_db/tvr_${SPLIT}.db \ 36 | --task tvr 37 | done 38 | 39 | wget $URL/tvqa_preprocessed_subtitles.jsonl -O $ANN_DIR/tv_subtitles.jsonl 40 | wget $BLOB/tv_vid2nframe.json -O $ANN_DIR/tv_vid2nframe.json 41 | wget $URL/tvr_video2dur_idx.json -O $ANN_DIR/vid2dur_idx.json 42 | echo "preprocessing tv subtitles..." 
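# The docker command below tokenizes the TV subtitles into an LMDB
# (prepro_sub.py with --frame_length 1.5) and copies vid2dur_idx.json into
# the resulting tv_subtitles.db directory.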
43 | docker run --ipc=host --rm -it \ 44 | --mount src=$(pwd),dst=/src,type=bind \ 45 | --mount src=$OUT_DIR,dst=/txt_db,type=bind \ 46 | --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \ 47 | -w /src linjieli222/hero \ 48 | /bin/bash -c "python script/prepro_sub.py --annotation /ann/tv_subtitles.jsonl --output /txt_db/tv_subtitles.db --vid2nframe /ann/tv_vid2nframe.json --frame_length 1.5; cp /ann/vid2dur_idx.json /txt_db/tv_subtitles.db/" 49 | echo "done" -------------------------------------------------------------------------------- /scripts/download_didemo.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | for FOLDER in 'video_db' 'txt_db' 'pretrained' 'finetune'; do 9 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 10 | mkdir -p $DOWNLOAD/$FOLDER 11 | fi 12 | done 13 | 14 | BLOB='https://convaisharables.blob.core.windows.net/hero' 15 | 16 | 17 | 18 | # video dbs 19 | if [ ! -d $DOWNLOAD/video_db/didemo/ ] ; then 20 | wget $BLOB/video_db/didemo.tar -P $DOWNLOAD/video_db/ 21 | tar -xvf $DOWNLOAD/video_db/didemo.tar -C $DOWNLOAD/video_db --strip-components 1 22 | rm $DOWNLOAD/video_db/didemo.tar 23 | fi 24 | 25 | # text dbs 26 | for SPLIT in 'train' 'val' 'test'; do 27 | wget $BLOB/txt_db/didemo_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 28 | tar -xvf $DOWNLOAD/txt_db/didemo_$SPLIT.db.tar -C $DOWNLOAD/txt_db 29 | rm $DOWNLOAD/txt_db/didemo_$SPLIT.db.tar 30 | done 31 | if [ ! -d $DOWNLOAD/txt_db/didemo_subtitles.db/ ] ; then 32 | wget $BLOB/txt_db/didemo_subtitles.db.tar -P $DOWNLOAD/txt_db/ 33 | tar -xvf $DOWNLOAD/txt_db/didemo_subtitles.db.tar -C $DOWNLOAD/txt_db 34 | rm $DOWNLOAD/txt_db/didemo_subtitles.db.tar 35 | fi 36 | 37 | # pretrainedsd 38 | if [ ! -f $DOWNLOAD/pretrained/hero-tv-ht100.pt ] ; then 39 | wget $BLOB/pretrained/hero-tv-ht100.pt -P $DOWNLOAD/pretrained/ 40 | fi 41 | -------------------------------------------------------------------------------- /scripts/download_msrvtt.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | for FOLDER in 'video_db' 'txt_db' 'pretrained' 'finetune'; do 9 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 10 | mkdir -p $DOWNLOAD/$FOLDER 11 | fi 12 | done 13 | 14 | BLOB='https://convaisharables.blob.core.windows.net/hero' 15 | 16 | 17 | 18 | # video dbs 19 | if [ ! -d $DOWNLOAD/video_db/msrvtt/ ] ; then 20 | wget $BLOB/video_db/msrvtt.tar -P $DOWNLOAD/video_db/ 21 | tar -xvf $DOWNLOAD/video_db/msrvtt.tar -C $DOWNLOAD/video_db --strip-components 1 22 | rm $DOWNLOAD/video_db/msrvtt.tar 23 | fi 24 | 25 | # text dbs 26 | for SPLIT in 'train' 'val' 'test'; do 27 | wget $BLOB/txt_db/msrvtt_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 28 | tar -xvf $DOWNLOAD/txt_db/msrvtt_$SPLIT.db.tar -C $DOWNLOAD/txt_db 29 | rm $DOWNLOAD/txt_db/msrvtt_$SPLIT.db.tar 30 | done 31 | if [ ! -d $DOWNLOAD/txt_db/msrvtt_subtitles.db/ ] ; then 32 | wget $BLOB/txt_db/msrvtt_subtitles.db.tar -P $DOWNLOAD/txt_db/ 33 | tar -xvf $DOWNLOAD/txt_db/msrvtt_subtitles.db.tar -C $DOWNLOAD/txt_db 34 | rm $DOWNLOAD/txt_db/msrvtt_subtitles.db.tar 35 | fi 36 | 37 | # pretrained 38 | if [ ! 
-f $DOWNLOAD/pretrained/hero-tv-ht100.pt ] ; then 39 | wget $BLOB/pretrained/hero-tv-ht100.pt -P $DOWNLOAD/pretrained/ 40 | fi 41 | -------------------------------------------------------------------------------- /scripts/download_pretrained.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | if [ ! -d $DOWNLOAD/pretrained ] ; then 9 | mkdir -p $DOWNLOAD/pretrained 10 | fi 11 | 12 | BLOB='https://convaisharables.blob.core.windows.net/hero' 13 | 14 | # This will overwrite models 15 | wget $BLOB/pretrained/hero-tv-ht100.pt -O $DOWNLOAD/pretrained/hero-tv-ht100.pt 16 | -------------------------------------------------------------------------------- /scripts/download_tv_pretrain.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | for FOLDER in 'video_db' 'txt_db' 'pretrained' 'finetune'; do 9 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 10 | mkdir -p $DOWNLOAD/$FOLDER 11 | fi 12 | done 13 | 14 | BLOB='https://convaisharables.blob.core.windows.net/hero' 15 | 16 | 17 | # video dbs 18 | if [ ! -d $DOWNLOAD/video_db/tv/ ] ; then 19 | wget $BLOB/video_db/tv.tar -P $DOWNLOAD/video_db/ 20 | tar -xvf $DOWNLOAD/video_db/tv.tar -C $DOWNLOAD/video_db --strip-components 1 21 | rm $DOWNLOAD/video_db/tv.tar 22 | fi 23 | 24 | # text dbs 25 | if [ ! -d $DOWNLOAD/txt_db/tv_subtitles.db/ ] ; then 26 | wget $BLOB/txt_db/tv_subtitles.db.tar -P $DOWNLOAD/txt_db/ 27 | tar -xvf $DOWNLOAD/txt_db/tv_subtitles.db.tar -C $DOWNLOAD/txt_db 28 | rm $DOWNLOAD/txt_db/tv_subtitles.db.tar 29 | fi 30 | 31 | # pretrain splits 32 | wget $BLOB/txt_db/pretrain_splits.tar -P $DOWNLOAD/txt_db/ 33 | tar -xvf $DOWNLOAD/txt_db/pretrain_splits.tar -C $DOWNLOAD/txt_db 34 | rm $DOWNLOAD/txt_db/pretrain_splits.tar 35 | 36 | # converted RoBERTa 37 | wget $BLOB/pretrained/pretrain-tv-init.bin -P $DOWNLOAD/pretrained/ 38 | -------------------------------------------------------------------------------- /scripts/download_tvc.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | for FOLDER in 'video_db' 'txt_db' 'pretrained' 'finetune'; do 9 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 10 | mkdir -p $DOWNLOAD/$FOLDER 11 | fi 12 | done 13 | 14 | BLOB='https://convaisharables.blob.core.windows.net/hero' 15 | 16 | # video dbs 17 | if [ ! -d $DOWNLOAD/video_db/tv/ ] ; then 18 | wget $BLOB/video_db/tv.tar -P $DOWNLOAD/video_db/ 19 | tar -xvf $DOWNLOAD/video_db/tv.tar -C $DOWNLOAD/video_db --strip-components 1 20 | rm $DOWNLOAD/video_db/tv.tar 21 | fi 22 | 23 | # text dbs 24 | for SPLIT in 'train' 'val' ; do 25 | wget $BLOB/txt_db/tvc_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 26 | tar -xvf $DOWNLOAD/txt_db/tvc_$SPLIT.db.tar -C $DOWNLOAD/txt_db 27 | rm $DOWNLOAD/txt_db/tvc_$SPLIT.db.tar 28 | done 29 | if [ ! 
-d $DOWNLOAD/txt_db/tv_subtitles.db/ ] ; then 30 | wget $BLOB/txt_db/tv_subtitles.db.tar -P $DOWNLOAD/txt_db/ 31 | tar -xvf $DOWNLOAD/txt_db/tv_subtitles.db.tar -C $DOWNLOAD/txt_db 32 | rm $DOWNLOAD/txt_db/tv_subtitles.db.tar 33 | fi 34 | 35 | # pretrained 36 | if [ ! -f $DOWNLOAD/pretrained/hero-tv-ht100.pt ] ; then 37 | wget $BLOB/pretrained/hero-tv-ht100.pt -P $DOWNLOAD/pretrained/ 38 | fi 39 | 40 | # raw data 41 | RAW_URL=https://raw.githubusercontent.com/jayleicn/TVCaption/66666ec08657d8963b165b18eafabd6427d44261/data/ 42 | for SPLIT in 'train' 'val' 'test_public'; do 43 | wget $RAW_URL/tvc_${SPLIT}_release.jsonl -P $DOWNLOAD/txt_db 44 | done 45 | wget $RAW_URL/tvqa_preprocessed_subtitles.jsonl -P $DOWNLOAD/txt_db 46 | -------------------------------------------------------------------------------- /scripts/download_tvqa.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | for FOLDER in 'video_db' 'txt_db' 'pretrained' 'finetune'; do 9 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 10 | mkdir -p $DOWNLOAD/$FOLDER 11 | fi 12 | done 13 | 14 | BLOB='https://convaisharables.blob.core.windows.net/hero' 15 | 16 | # video dbs 17 | if [ ! -d $DOWNLOAD/video_db/tv/ ] ; then 18 | wget $BLOB/video_db/tv.tar -P $DOWNLOAD/video_db/ 19 | tar -xvf $DOWNLOAD/video_db/tv.tar -C $DOWNLOAD/video_db --strip-components 1 20 | rm $DOWNLOAD/video_db/tv.tar 21 | fi 22 | 23 | # text dbs 24 | for SPLIT in 'train' 'val' 'test_public'; do 25 | wget $BLOB/txt_db/tvqa_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 26 | tar -xvf $DOWNLOAD/txt_db/tvqa_$SPLIT.db.tar -C $DOWNLOAD/txt_db 27 | rm $DOWNLOAD/txt_db/tvqa_$SPLIT.db.tar 28 | done 29 | if [ ! -d $DOWNLOAD/txt_db/tv_subtitles.db/ ] ; then 30 | wget $BLOB/txt_db/tv_subtitles.db.tar -P $DOWNLOAD/txt_db/ 31 | tar -xvf $DOWNLOAD/txt_db/tv_subtitles.db.tar -C $DOWNLOAD/txt_db 32 | rm $DOWNLOAD/txt_db/tv_subtitles.db.tar 33 | fi 34 | 35 | # pretrained 36 | if [ ! -f $DOWNLOAD/pretrained/hero-tv-ht100.pt ] ; then 37 | wget $BLOB/pretrained/hero-tv-ht100.pt -P $DOWNLOAD/pretrained/ 38 | fi 39 | -------------------------------------------------------------------------------- /scripts/download_tvr.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | for FOLDER in 'video_db' 'txt_db' 'pretrained' 'finetune'; do 9 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 10 | mkdir -p $DOWNLOAD/$FOLDER 11 | fi 12 | done 13 | 14 | BLOB='https://convaisharables.blob.core.windows.net/hero' 15 | 16 | # video dbs 17 | if [ ! -d $DOWNLOAD/video_db/tv/ ] ; then 18 | wget $BLOB/video_db/tv.tar -P $DOWNLOAD/video_db/ 19 | tar -xvf $DOWNLOAD/video_db/tv.tar -C $DOWNLOAD/video_db --strip-components 1 20 | rm $DOWNLOAD/video_db/tv.tar 21 | fi 22 | 23 | # text dbs 24 | for SPLIT in 'train' 'val' 'test_public'; do 25 | wget $BLOB/txt_db/tvr_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 26 | tar -xvf $DOWNLOAD/txt_db/tvr_$SPLIT.db.tar -C $DOWNLOAD/txt_db 27 | rm $DOWNLOAD/txt_db/tvr_$SPLIT.db.tar 28 | done 29 | if [ ! 
-d $DOWNLOAD/txt_db/tv_subtitles.db/ ] ; then 30 | wget $BLOB/txt_db/tv_subtitles.db.tar -P $DOWNLOAD/txt_db/ 31 | tar -xvf $DOWNLOAD/txt_db/tv_subtitles.db.tar -C $DOWNLOAD/txt_db 32 | rm $DOWNLOAD/txt_db/tv_subtitles.db.tar 33 | fi 34 | 35 | # pretrained 36 | if [ ! -f $DOWNLOAD/pretrained/hero-tv-ht100.pt ] ; then 37 | wget $BLOB/pretrained/hero-tv-ht100.pt -P $DOWNLOAD/pretrained/ 38 | fi 39 | 40 | # finetuned 41 | wget $BLOB/finetune/tvr_default.tar -P $DOWNLOAD/finetune/ 42 | tar -xvf $DOWNLOAD/finetune/tvr_default.tar -C $DOWNLOAD/finetune --strip-components 1 43 | rm $DOWNLOAD/finetune/tvr_default.tar 44 | -------------------------------------------------------------------------------- /scripts/download_violin.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # Modified from UNITER 4 | # (https://github.com/ChenRocks/UNITER) 5 | 6 | DOWNLOAD=$1 7 | 8 | for FOLDER in 'video_db' 'txt_db' 'pretrained' 'finetune'; do 9 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 10 | mkdir -p $DOWNLOAD/$FOLDER 11 | fi 12 | done 13 | 14 | BLOB='https://convaisharables.blob.core.windows.net/hero' 15 | 16 | # video dbs 17 | if [ ! -d $DOWNLOAD/video_db/violin/ ] ; then 18 | wget $BLOB/video_db/violin.tar -P $DOWNLOAD/video_db/ 19 | tar -xvf $DOWNLOAD/video_db/violin.tar -C $DOWNLOAD/video_db --strip-components 1 20 | rm $DOWNLOAD/video_db/violin.tar 21 | fi 22 | 23 | # text dbs 24 | for SPLIT in 'train' 'val' 'test'; do 25 | wget $BLOB/txt_db/violin_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 26 | tar -xvf $DOWNLOAD/txt_db/violin_$SPLIT.db.tar -C $DOWNLOAD/txt_db 27 | rm $DOWNLOAD/txt_db/violin_$SPLIT.db.tar 28 | done 29 | if [ ! -d $DOWNLOAD/txt_db/violin_subtitles.db/ ] ; then 30 | wget $BLOB/txt_db/violin_subtitles.db.tar -P $DOWNLOAD/txt_db/ 31 | tar -xvf $DOWNLOAD/txt_db/violin_subtitles.db.tar -C $DOWNLOAD/txt_db 32 | rm $DOWNLOAD/txt_db/violin_subtitles.db.tar 33 | fi 34 | 35 | # pretrained 36 | if [ ! -f $DOWNLOAD/pretrained/hero-tv-ht100.pt ] ; then 37 | wget $BLOB/pretrained/hero-tv-ht100.pt -P $DOWNLOAD/pretrained/ 38 | fi 39 | -------------------------------------------------------------------------------- /scripts/prepro_query.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
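Illustrative invocation (adapted from scripts/create_txtdb.sh):
    python scripts/prepro_query.py --annotation /ann/tvr_train.jsonl \
        --output /txt_db/tvr_train.db --task tvr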
4 | 5 | preprocess TVR/TVQA/VIOLIN annotations into LMDB 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | from os.path import exists 11 | from cytoolz import curry 12 | from tqdm import tqdm 13 | from transformers import RobertaTokenizer 14 | import copy 15 | 16 | # quick hack for import 17 | import sys 18 | sys.path.insert(0, '/src') 19 | from utils.basic_utils import save_jsonl, save_json 20 | from data.data import open_lmdb 21 | 22 | 23 | @curry 24 | def roberta_tokenize(tokenizer, text): 25 | if text.isupper(): 26 | text = text.lower() 27 | words = tokenizer.tokenize(text) 28 | ids = tokenizer.convert_tokens_to_ids(words) 29 | return ids 30 | 31 | 32 | def process_tvr(jsonl, db, tokenizer): 33 | id2len = {} 34 | query2video = {} # not sure if useful 35 | query_data = [] 36 | for line in tqdm( 37 | jsonl, 38 | desc='processing TVR with raw query text'): 39 | example = json.loads(line) 40 | query_data.append(copy.copy(example)) 41 | id_ = example['desc_id'] 42 | input_ids = tokenizer(example["desc"]) 43 | if 'vid_name' in example: 44 | vid = example['vid_name'] 45 | else: 46 | vid = None 47 | if 'ts' in example: 48 | target = example['ts'] 49 | else: 50 | target = None 51 | if vid is not None: 52 | query2video[id_] = vid 53 | example['vid'] = vid 54 | id2len[id_] = len(input_ids) 55 | example['input_ids'] = input_ids 56 | example['target'] = target 57 | example['qid'] = str(id_) 58 | db[str(id_)] = example 59 | return id2len, query2video, query_data 60 | 61 | 62 | def process_tvqa(jsonl, db, tokenizer): 63 | id2len = {} 64 | query2video = {} # not sure if useful 65 | query_data = [] 66 | for line in tqdm(jsonl, desc='processing TVQA with raw QA text'): 67 | example = json.loads(line) 68 | query_data.append(copy.copy(example)) 69 | id_ = example['qid'] 70 | input_ids = [tokenizer(example["q"]), tokenizer(example["a0"]), 71 | tokenizer(example["a1"]), tokenizer(example["a2"]), 72 | tokenizer(example["a3"]), tokenizer(example["a4"])] 73 | vid = example['vid_name'] 74 | if 'ts' in example: 75 | ts = example['ts'] 76 | else: 77 | ts = None 78 | 79 | if 'answer_idx' in example: 80 | target = example['answer_idx'] 81 | else: 82 | target = None 83 | 84 | query2video[id_] = vid 85 | id2len[id_] = [len(input_ids[0]), len(input_ids[1]), len(input_ids[2]), 86 | len(input_ids[3]), len(input_ids[4]), len(input_ids[5])] 87 | example['input_ids'] = input_ids 88 | example['vid'] = vid 89 | example['ts'] = ts 90 | example['target'] = target 91 | example['qid'] = str(id_) 92 | db[str(id_)] = example 93 | return id2len, query2video, query_data 94 | 95 | 96 | def process_violin(jsonl, db, tokenizer): 97 | id2len = {} 98 | query2video = {} # not sure if useful 99 | query_data = [] 100 | for line in tqdm( 101 | jsonl, 102 | desc='processing Violin with raw statement text'): 103 | example = json.loads(line) 104 | query_data.append(copy.copy(example)) 105 | id_ = example['desc_id'] 106 | input_ids = tokenizer(example["desc"]) 107 | vid = example['vid_name'] 108 | target = example['label'] 109 | query2video[id_] = vid 110 | example['vid'] = vid 111 | id2len[id_] = len(input_ids) 112 | example['input_ids'] = input_ids 113 | example['target'] = target 114 | example['qid'] = str(id_) 115 | db[str(id_)] = example 116 | return id2len, query2video, query_data 117 | 118 | 119 | def main(opts): 120 | if not exists(opts.output): 121 | os.makedirs(opts.output) 122 | else: 123 | raise ValueError('Found existing DB. 
Please explicitly remove ' 124 | 'for re-processing') 125 | meta = vars(opts) 126 | meta['tokenizer'] = opts.toker 127 | toker = RobertaTokenizer.from_pretrained( 128 | opts.toker) 129 | tokenizer = roberta_tokenize(toker) 130 | meta['BOS'] = toker.convert_tokens_to_ids(['<s>'])[0] 131 | meta['EOS'] = toker.convert_tokens_to_ids(['</s>'])[0] 132 | meta['SEP'] = toker.convert_tokens_to_ids(['</s>'])[0] 133 | meta['CLS'] = toker.convert_tokens_to_ids(['<s>'])[0] 134 | meta['PAD'] = toker.convert_tokens_to_ids(['<pad>'])[0] 135 | meta['MASK'] = toker.convert_tokens_to_ids(['<mask>'])[0] 136 | meta['UNK'] = toker.convert_tokens_to_ids(['<unk>'])[0] 137 | meta['v_range'] = (toker.convert_tokens_to_ids(['.'])[0], 138 | toker.convert_tokens_to_ids(['<|endoftext|>'])[0]+1) 139 | save_json(vars(opts), f'{opts.output}/meta.json', save_pretty=True) 140 | 141 | open_db = curry(open_lmdb, opts.output, readonly=False) 142 | with open_db() as db: 143 | with open(opts.annotation, "r") as ann: 144 | if opts.task == "tvr": 145 | id2lens, query2video, query_data = process_tvr( 146 | ann, db, tokenizer) 147 | elif opts.task == "tvqa": 148 | id2lens, query2video, query_data = process_tvqa( 149 | ann, db, tokenizer) 150 | elif opts.task == "violin": 151 | id2lens, query2video, query_data = process_violin( 152 | ann, db, tokenizer) 153 | else: 154 | raise NotImplementedError( 155 | f"prepro for {opts.task} not implemented") 156 | 157 | save_json(id2lens, f'{opts.output}/id2len.json') 158 | save_json(query2video, f'{opts.output}/query2video.json') 159 | save_jsonl(query_data, f'{opts.output}/query_data.jsonl') 160 | 161 | 162 | if __name__ == '__main__': 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument('--annotation', required=True, 165 | help='annotation JSON') 166 | parser.add_argument('--output', required=True, 167 | help='output dir of DB') 168 | parser.add_argument('--toker', default='roberta-base', 169 | help='which RoBERTa tokenizer to use') 170 | parser.add_argument('--task', default='tvr', 171 | choices=["tvr", "tvqa", "violin"], 172 | help='which downstream task the annotations are for') 173 | args = parser.parse_args() 174 | main(args) 175 | -------------------------------------------------------------------------------- /scripts/prepro_tvc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license.
4 | preprocess TVC annotations into LMDB 5 | """ 6 | import argparse 7 | from collections import defaultdict 8 | import json 9 | import os 10 | from os.path import exists 11 | 12 | from cytoolz import curry 13 | from tqdm import tqdm 14 | from transformers import RobertaTokenizer 15 | 16 | # quick hack for import 17 | import sys 18 | sys.path.insert(0, '/src') 19 | from data.data import open_lmdb 20 | 21 | 22 | @curry 23 | def roberta_tokenize(tokenizer, text): 24 | words = tokenizer.tokenize(text) 25 | ids = tokenizer.convert_tokens_to_ids(words) 26 | return ids 27 | 28 | 29 | def _compute_overlapped_subs(ts, subtitles): 30 | st, ed = ts 31 | inds = [] 32 | for i, sub in enumerate(subtitles): 33 | if (st < sub['start'] < ed 34 | or st < sub['end'] < ed 35 | or sub['start'] < st < ed < sub['end']): 36 | inds.append(i) 37 | return inds 38 | 39 | 40 | def process_tvc(cap_jsonl, sub_jsonl, cap_db, clip_db, tokenizer): 41 | # load subtitles 42 | vid2subs = {} 43 | for line in tqdm(sub_jsonl): 44 | sub_info = json.loads(line) 45 | vid2subs[sub_info['vid_name']] = sub_info['sub'] 46 | 47 | id2len = {} 48 | cap2vid = {} 49 | clip2vid = {} 50 | vid2caps = defaultdict(list) 51 | vid2clips = defaultdict(list) 52 | for line in tqdm(cap_jsonl, desc='processing TVC data'): 53 | example = json.loads(line) 54 | vid = example['vid_name'] 55 | ts = example['ts'] 56 | clip_id = str(example['clip_id']) 57 | clip2vid[clip_id] = vid 58 | sub_inds = _compute_overlapped_subs(ts, vid2subs[vid]) 59 | clip = {'vid_name': vid, 'ts': ts, 'sub_indices': sub_inds, 60 | 'duration': example['duration'], 'captions': []} 61 | vid2clips[vid].append(clip_id) 62 | for cap in example['descs']: 63 | cap_id = str(cap['desc_id']) 64 | input_ids = tokenizer(cap['desc']) 65 | cap['input_ids'] = input_ids 66 | cap['vid_name'] = vid 67 | cap['clip_id'] = clip_id 68 | cap['ts'] = ts 69 | cap['sub_indices'] = sub_inds 70 | cap['duration'] = example['duration'] 71 | cap_db[cap_id] = cap 72 | cap2vid[cap_id] = vid 73 | vid2caps[vid].append(cap_id) 74 | 75 | clip['captions'].append({'id': cap['desc_id'], 76 | 'input_ids': input_ids, 77 | 'text': cap['desc']}) 78 | 79 | clip_db[clip_id] = clip 80 | return id2len, cap2vid, clip2vid, vid2caps, vid2clips 81 | 82 | 83 | def main(opts): 84 | if not exists(opts.output): 85 | os.makedirs(opts.output) 86 | else: 87 | print(opts.output) 88 | raise ValueError('Found existing DB. 
Please explicitly remove ' 89 | 'for re-processing') 90 | meta = vars(opts) 91 | meta['tokenizer'] = opts.toker 92 | toker = RobertaTokenizer.from_pretrained(opts.toker) 93 | tokenizer = roberta_tokenize(toker) 94 | meta['BOS'] = toker.convert_tokens_to_ids(['<s>'])[0] 95 | meta['EOS'] = toker.convert_tokens_to_ids(['</s>'])[0] 96 | meta['SEP'] = toker.convert_tokens_to_ids(['</s>'])[0] 97 | meta['CLS'] = toker.convert_tokens_to_ids(['<s>'])[0] 98 | meta['PAD'] = toker.convert_tokens_to_ids(['<pad>'])[0] 99 | meta['MASK'] = toker.convert_tokens_to_ids(['<mask>'])[0] 100 | meta['UNK'] = toker.convert_tokens_to_ids(['<unk>'])[0] 101 | meta['v_range'] = (toker.convert_tokens_to_ids(['.'])[0], 102 | toker.convert_tokens_to_ids(['<|endoftext|>'])[0]+1) 103 | with open(f'{opts.output}/meta.json', 'w') as f: 104 | json.dump(vars(opts), f, indent=4) 105 | 106 | open_cap_db = curry(open_lmdb, f"{opts.output}/cap.db", readonly=False) 107 | open_clip_db = curry(open_lmdb, f"{opts.output}/clip.db", readonly=False) 108 | with open_cap_db() as cap_db, open_clip_db() as clip_db: 109 | with open(opts.annotation) as ann, open(opts.subtitles) as sub: 110 | (id2lens, cap2vid, clip2vid, vid2caps, vid2clips 111 | ) = process_tvc(ann, sub, cap_db, clip_db, tokenizer) 112 | 113 | with open(f'{opts.output}/cap.db/id2len.json', 'w') as f: 114 | json.dump(id2lens, f) 115 | with open(f'{opts.output}/cap.db/cap2vid.json', 'w') as f: 116 | json.dump(cap2vid, f) 117 | with open(f'{opts.output}/clip.db/clip2vid.json', 'w') as f: 118 | json.dump(clip2vid, f) 119 | with open(f'{opts.output}/cap.db/vid2caps.json', 'w') as f: 120 | json.dump(vid2caps, f) 121 | with open(f'{opts.output}/clip.db/vid2clips.json', 'w') as f: 122 | json.dump(vid2clips, f) 123 | 124 | 125 | if __name__ == '__main__': 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('--annotation', required=True, 128 | help='annotation JSON') 129 | parser.add_argument('--subtitles', required=True, 130 | help='subtitle JSON') 131 | parser.add_argument('--output', required=True, 132 | help='output dir of DB') 133 | parser.add_argument('--toker', default='roberta-base', 134 | choices=["roberta-base", "roberta-large"], 135 | help='which RoBERTa tokenizer to use') 136 | args = parser.parse_args() 137 | main(args) 138 | -------------------------------------------------------------------------------- /scripts/prepro_tvc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license.
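# Illustrative usage (the data path below is a placeholder, not taken from this repo's docs):
#   bash scripts/prepro_tvc.sh /path/to/tvc_txt_db
# The single positional argument is the host directory holding the TVC annotation and
# TVQA subtitle jsonl files; it is bind-mounted into the container at /txt below.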
4 | 5 | DATA=$1 # txt_db 6 | 7 | for SPLIT in 'val' 'train'; do 8 | CMD="python scripts/prepro_tvc.py \ 9 | --annotation /txt/tvc_${SPLIT}_release.jsonl \ 10 | --subtitles /txt/tvqa_preprocessed_subtitles.jsonl \ 11 | --output /txt/tvc_${SPLIT}_new.db" 12 | 13 | docker run --ipc=host --rm \ 14 | --mount src=$(pwd),dst=/src,type=bind \ 15 | --mount src=$DATA,dst=/txt,type=bind \ 16 | -w /src linjieli222/hero \ 17 | bash -c "$CMD" 18 | done 19 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjieli222/HERO/32c1c523c7a9f547a29f14c8e33dec24ebd14156/utils/__init__.py -------------------------------------------------------------------------------- /utils/basic_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Basic util functions copied from TVRetrieval implementation 3 | (https://github.com/jayleicn/TVRetrieval) 4 | ''' 5 | import os 6 | import json 7 | import zipfile 8 | import numpy as np 9 | import pickle 10 | import pprint 11 | 12 | 13 | def load_pickle(filename): 14 | with open(filename, "rb") as f: 15 | return pickle.load(f) 16 | 17 | 18 | def save_pickle(data, filename): 19 | with open(filename, "wb") as f: 20 | pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) 21 | 22 | 23 | def load_json(filename): 24 | with open(filename, "r") as f: 25 | return json.load(f) 26 | 27 | 28 | def save_json(data, filename, save_pretty=False, sort_keys=False): 29 | with open(filename, "w") as f: 30 | if save_pretty: 31 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) 32 | else: 33 | json.dump(data, f) 34 | 35 | 36 | def load_jsonl(filename): 37 | with open(filename, "r") as f: 38 | return [json.loads(line.strip("\n")) for line in f.readlines()] 39 | 40 | 41 | def save_jsonl(data, filename): 42 | """data is a list""" 43 | with open(filename, "w") as f: 44 | f.write("\n".join([json.dumps(e) for e in data])) 45 | 46 | 47 | def save_lines(list_of_str, filepath): 48 | with open(filepath, "w") as f: 49 | f.write("\n".join(list_of_str)) 50 | 51 | 52 | def read_lines(filepath): 53 | with open(filepath, "r") as f: 54 | return [e.strip("\n") for e in f.readlines()] 55 | 56 | 57 | def mkdirp(p): 58 | if not os.path.exists(p): 59 | os.makedirs(p) 60 | 61 | 62 | def flat_list_of_lists(in_list): 63 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" 64 | return [item for sublist in in_list for item in sublist] 65 | 66 | 67 | def convert_to_seconds(hms_time): 68 | """ convert '00:01:12' to 72 seconds. 69 | :hms_time (str): time in colon separated string, e.g. '00:01:12' 70 | :return (int): time in seconds, e.g.
72 71 | """ 72 | times = [float(t) for t in hms_time.split(":")] 73 | return times[0] * 3600 + times[1] * 60 + times[2] 74 | 75 | 76 | def get_video_name_from_url(url): 77 | return url.split("/")[-1][:-4] 78 | 79 | 80 | def merge_dicts(list_dicts): 81 | merged_dict = list_dicts[0].copy() 82 | for i in range(1, len(list_dicts)): 83 | merged_dict.update(list_dicts[i]) 84 | return merged_dict 85 | 86 | 87 | def l2_normalize_np_array(np_array, eps=1e-5): 88 | """np_array: np.ndarray, (*, D), where the last dim will be normalized""" 89 | return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps) 90 | 91 | 92 | def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, 93 | exclude_extensions=None, exclude_dirs_substring=None): 94 | """make a zip file of src_dir and save it to save_path. 95 | exclude_dirs are excluded if they are subdirs of src_dir. 96 | An enclosing_dir is added if specified. 97 | """ 98 | abs_src = os.path.abspath(src_dir) 99 | with zipfile.ZipFile(save_path, "w") as zf: 100 | for dirname, subdirs, files in os.walk(src_dir): 101 | if exclude_dirs is not None: 102 | for e_p in exclude_dirs: 103 | if e_p in subdirs: 104 | subdirs.remove(e_p) 105 | if exclude_dirs_substring is not None: 106 | to_rm = [] 107 | for d in subdirs: 108 | if exclude_dirs_substring in d: 109 | to_rm.append(d) 110 | for e in to_rm: 111 | subdirs.remove(e) 112 | arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:]) 113 | zf.write(dirname, arcname) 114 | for filename in files: 115 | if exclude_extensions is not None: 116 | if os.path.splitext(filename)[1] in exclude_extensions: 117 | continue # do not zip it 118 | absname = os.path.join(dirname, filename) 119 | arcname = os.path.join( 120 | enclosing_dir, absname[len(abs_src) + 1:]) 121 | zf.write(absname, arcname) 122 | 123 | 124 | def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True): 125 | """Dissect an array (N, D) into a list of sub-arrays; 126 | np_array.shape[0] == sum(lengths). Output is a list of nd arrays; 127 | singleton dimension is kept""" 128 | if assert_equal: 129 | assert len(np_array) == sum(lengths) 130 | length_indices = [0, ] 131 | for i in range(len(lengths)): 132 | length_indices.append(length_indices[i] + lengths[i]) 133 | if dim == 0: 134 | array_list = [np_array[length_indices[i]:length_indices[i+1]] 135 | for i in range(len(lengths))] 136 | elif dim == 1: 137 | array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] 138 | for i in range(len(lengths))] 139 | elif dim == 2: 140 | array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] 141 | for i in range(len(lengths))] 142 | else: 143 | raise NotImplementedError 144 | return array_list 145 | 146 | 147 | def get_ratio_from_counter(counter_obj, threshold=200): 148 | keys = counter_obj.keys() 149 | values = counter_obj.values() 150 | filtered_values = [counter_obj[k] for k in keys if k > threshold] 151 | return float(sum(filtered_values)) / sum(values) 152 | 153 | 154 | def get_show_name(vid_name): 155 | """ 156 | get tvshow name from vid_name 157 | :param vid_name: video clip name 158 | :return: tvshow name 159 | """ 160 | show_list = ["friends", "met", "castle", "house", "grey"] 161 | vid_name_prefix = vid_name.split("_")[0] 162 | show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt" 163 | return show_name 164 | 165 | 166 | class FormatPrinter(pprint.PrettyPrinter): 167 | 168 | def __init__(self, formats): 169 | super(FormatPrinter, self).__init__() 170 | self.formats = formats 171 | 172
| def format(self, obj, ctx, maxlvl, lvl): 173 | if type(obj) in self.formats: 174 | return self.formats[type(obj)] % obj, 1, 0 175 | return pprint.PrettyPrinter.format(self, obj, ctx, maxlvl, lvl) 176 | -------------------------------------------------------------------------------- /utils/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | constants 5 | """ 6 | VFEAT_DIM = 4352 7 | MAX_FRM_SEQ_LEN = 100 8 | VCMR_IOU_THDS = (0.5, 0.7) 9 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | distributed API using Horovod 6 | Modified from OpenNMT's native pytorch distributed utils 7 | (https://github.com/OpenNMT/OpenNMT-py) 8 | 9 | Copied from UNITER 10 | (https://github.com/ChenRocks/UNITER) 11 | """ 12 | import math 13 | import pickle 14 | 15 | import torch 16 | from horovod import torch as hvd 17 | 18 | 19 | def all_reduce_and_rescale_tensors(tensors, rescale_denom): 20 | """All-reduce and rescale tensors at once (as a flattened tensor) 21 | 22 | Args: 23 | tensors: list of Tensors to all-reduce 24 | rescale_denom: denominator for rescaling summed Tensors 25 | """ 26 | # buffer size in bytes, determine equiv. # of elements based on data type 27 | sz = sum(t.numel() for t in tensors) 28 | buffer_t = tensors[0].new(sz).zero_() 29 | 30 | # copy tensors into buffer_t 31 | offset = 0 32 | for t in tensors: 33 | numel = t.numel() 34 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 35 | offset += numel 36 | 37 | # all-reduce and rescale 38 | hvd.allreduce_(buffer_t[:offset]) 39 | buffer_t.div_(rescale_denom) 40 | 41 | # copy all-reduced buffer back into tensors 42 | offset = 0 43 | for t in tensors: 44 | numel = t.numel() 45 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 46 | offset += numel 47 | 48 | 49 | def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom, 50 | buffer_size=10485760): 51 | """All-reduce and rescale tensors in chunks of the specified size. 52 | 53 | Args: 54 | tensors: list of Tensors to all-reduce 55 | rescale_denom: denominator for rescaling summed Tensors 56 | buffer_size: all-reduce chunk size in bytes 57 | """ 58 | # buffer size in bytes, determine equiv. 
# of elements based on data type 59 | buffer_t = tensors[0].new( 60 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 61 | buffer = [] 62 | 63 | def all_reduce_buffer(): 64 | # copy tensors into buffer_t 65 | offset = 0 66 | for t in buffer: 67 | numel = t.numel() 68 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 69 | offset += numel 70 | 71 | # all-reduce and rescale 72 | hvd.allreduce_(buffer_t[:offset]) 73 | buffer_t.div_(rescale_denom) 74 | 75 | # copy all-reduced buffer back into tensors 76 | offset = 0 77 | for t in buffer: 78 | numel = t.numel() 79 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 80 | offset += numel 81 | 82 | filled = 0 83 | for t in tensors: 84 | sz = t.numel() * t.element_size() 85 | if sz > buffer_size: 86 | # tensor is bigger than buffer, all-reduce and rescale directly 87 | hvd.allreduce_(t) 88 | t.div_(rescale_denom) 89 | elif filled + sz > buffer_size: 90 | # buffer is full, all-reduce and replace buffer with grad 91 | all_reduce_buffer() 92 | buffer = [t] 93 | filled = sz 94 | else: 95 | # add tensor to buffer 96 | buffer.append(t) 97 | filled += sz 98 | 99 | if len(buffer) > 0: 100 | all_reduce_buffer() 101 | 102 | 103 | def broadcast_tensors(tensors, root_rank, buffer_size=10485760): 104 | """broadcast tensors in chunks of the specified size. 105 | 106 | Args: 107 | tensors: list of Tensors to broadcast 108 | root_rank: rank to broadcast 109 | buffer_size: broadcast chunk size in bytes 110 | """ 111 | # buffer size in bytes, determine equiv. # of elements based on data type 112 | buffer_t = tensors[0].new( 113 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 114 | buffer = [] 115 | 116 | def broadcast_buffer(): 117 | # copy tensors into buffer_t 118 | offset = 0 119 | for t in buffer: 120 | numel = t.numel() 121 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 122 | offset += numel 123 | 124 | # broadcast 125 | hvd.broadcast_(buffer_t[:offset], root_rank) 126 | 127 | # copy all-reduced buffer back into tensors 128 | offset = 0 129 | for t in buffer: 130 | numel = t.numel() 131 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 132 | offset += numel 133 | 134 | filled = 0 135 | for t in tensors: 136 | sz = t.numel() * t.element_size() 137 | if sz > buffer_size: 138 | # tensor is bigger than buffer, broadcast directly 139 | hvd.broadcast_(t, root_rank) 140 | elif filled + sz > buffer_size: 141 | # buffer is full, broadcast and replace buffer with tensor 142 | broadcast_buffer() 143 | buffer = [t] 144 | filled = sz 145 | else: 146 | # add tensor to buffer 147 | buffer.append(t) 148 | filled += sz 149 | 150 | if len(buffer) > 0: 151 | broadcast_buffer() 152 | 153 | 154 | def _encode(enc, max_size, use_max_size=False): 155 | enc_size = len(enc) 156 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) 157 | if use_max_size: 158 | # this is used for broadcasting 159 | buffer_ = torch.cuda.ByteTensor(max_size+enc_byte) 160 | else: 161 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) 162 | remainder = enc_size 163 | for i in range(enc_byte): 164 | base = 256 ** (enc_byte-i-1) 165 | buffer_[i] = remainder // base 166 | remainder %= base 167 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) 168 | return buffer_, enc_byte 169 | 170 | 171 | def _decode(buffer_, enc_byte): 172 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() 173 | for i in range(enc_byte)) 174 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) 175 | shift = size + enc_byte 176 | return bytes_list, shift 177 | 178 | 179 | 
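# Worked example of the _encode/_decode length header (illustrative values only):
# _encode prepends enc_byte big-endian bytes holding the payload size, where
# enc_byte = max(math.floor(math.log(max_size, 256) + 1), 1). For max_size = 300
# this gives enc_byte = 2, so a 300-byte pickle is prefixed with header bytes
# [1, 44] (1 * 256 + 44 = 300); _decode reads the same header to recover the
# payload size and the offset at which the pickled bytes start.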
_BUFFER_SIZE = 4096 180 | 181 | 182 | def all_gather_list(data): 183 | """Gathers arbitrary data from all nodes into a list.""" 184 | enc = pickle.dumps(data) 185 | 186 | enc_size = len(enc) 187 | max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item() 188 | in_buffer, enc_byte = _encode(enc, max_size) 189 | 190 | out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size]) 191 | 192 | results = [] 193 | for _ in range(hvd.size()): 194 | bytes_list, shift = _decode(out_buffer, enc_byte) 195 | out_buffer = out_buffer[shift:] 196 | result = pickle.loads(bytes_list) 197 | results.append(result) 198 | return results 199 | 200 | 201 | def any_broadcast(data, root_rank): 202 | """broadcast arbitrary data from root_rank to all nodes.""" 203 | enc = pickle.dumps(data) 204 | 205 | max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item() 206 | buffer_, enc_byte = _encode(enc, max_size, use_max_size=True) 207 | 208 | hvd.broadcast_(buffer_, root_rank) 209 | 210 | bytes_list, _ = _decode(buffer_, enc_byte) 211 | result = pickle.loads(bytes_list) 212 | return result -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | some functions are modified from UNITER 6 | (https://github.com/ChenRocks/UNITER) 7 | 8 | helper for logging 9 | NOTE: loggers are global objects; use with caution 10 | """ 11 | import logging 12 | 13 | import tensorboardX 14 | 15 | 16 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 17 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 18 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 19 | LOGGER = logging.getLogger('__main__') # this is the global logger 20 | 21 | 22 | def add_log_to_file(log_path): 23 | fh = logging.FileHandler(log_path) 24 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 25 | fh.setFormatter(formatter) 26 | LOGGER.addHandler(fh) 27 | 28 | 29 | class TensorboardLogger(object): 30 | def __init__(self): 31 | self._logger = None 32 | self._global_step = 0 33 | 34 | def create(self, path): 35 | self._logger = tensorboardX.SummaryWriter(path) 36 | 37 | def noop(self, *args, **kwargs): 38 | return 39 | 40 | def step(self): 41 | self._global_step += 1 42 | 43 | @property 44 | def global_step(self): 45 | return self._global_step 46 | 47 | @global_step.setter 48 | def global_step(self, step): 49 | self._global_step = step 50 | 51 | def log_scaler_dict(self, log_dict, prefix=''): 52 | """ log a dictionary of scalar values""" 53 | if self._logger is None: 54 | return 55 | if prefix: 56 | prefix = f'{prefix}_' 57 | for name, value in log_dict.items(): 58 | if isinstance(value, dict): 59 | self.log_scaler_dict(value, 60 | prefix=f'{prefix}{name}') 61 | else: 62 | self._logger.add_scalar(f'{prefix}{name}', value, 63 | self._global_step) 64 | 65 | def __getattr__(self, name): 66 | if self._logger is None: 67 | return self.noop 68 | return self._logger.__getattribute__(name) 69 | 70 | 71 | TB_LOGGER = TensorboardLogger() 72 | 73 | 74 | class RunningMeter(object): 75 | """ running meter of a scalar value 76 | (useful for monitoring training loss) 77 | """ 78 | def __init__(self, name, val=None, smooth=0.99): 79 | self._name = name 80 | self._sm = smooth 81 | self._val = val 82 | 83 | def __call__(self, value): 84 | self._val = (value if self._val is None 85 | else
value*(1-self._sm) + self._val*self._sm) 86 | 87 | def __str__(self): 88 | return f'{self._name}: {self._val:.4f}' 89 | 90 | @property 91 | def val(self): 92 | return self._val 93 | 94 | @property 95 | def name(self): 96 | return self._name 97 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Copied from UNITER 6 | (https://github.com/ChenRocks/UNITER) 7 | 8 | Misc utilities 9 | """ 10 | import random 11 | 12 | import torch 13 | import numpy as np 14 | 15 | from utils.logger import LOGGER 16 | 17 | 18 | class Struct(object): 19 | def __init__(self, dict_): 20 | self.__dict__.update(dict_) 21 | 22 | 23 | class NoOp(object): 24 | """ useful for distributed training No-Ops """ 25 | def __getattr__(self, name): 26 | return self.noop 27 | 28 | def noop(self, *args, **kwargs): 29 | return 30 | 31 | 32 | def set_dropout(model, drop_p): 33 | for name, module in model.named_modules(): 34 | # we might want to tune dropout for smaller datasets 35 | if isinstance(module, torch.nn.Dropout): 36 | if module.p != drop_p: 37 | module.p = drop_p 38 | LOGGER.info(f'{name} set to {drop_p}') 39 | 40 | 41 | def set_random_seed(seed): 42 | random.seed(seed) 43 | np.random.seed(seed) 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | -------------------------------------------------------------------------------- /utils/save.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Modified from UNITER 6 | (https://github.com/ChenRocks/UNITER) 7 | 8 | saving utilities 9 | """ 10 | import json 11 | import os 12 | from os.path import abspath, dirname, exists, join, realpath 13 | import subprocess 14 | from apex import amp 15 | import torch 16 | 17 | from utils.logger import LOGGER 18 | from utils.basic_utils import save_json, make_zipfile, load_json 19 | 20 | 21 | def save_training_meta(args): 22 | # Commented out since rank is not saved to args; the training scripts already safeguard calls to save_training_meta.
23 | # if args.rank > 0: 24 | # return 25 | 26 | # args is an EasyDict object, treat it the same as a normal dict 27 | os.makedirs(join(args.output_dir, 'log'), exist_ok=True) 28 | os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True) 29 | 30 | # training args 31 | save_args_path = join(args.output_dir, 'log', 'hps.json') 32 | save_json(vars(args), save_args_path, save_pretty=True) 33 | 34 | # model args 35 | model_config = load_json(args.model_config) 36 | save_model_config_path = join(args.output_dir, 'log', 'model_config.json') 37 | save_json(model_config, save_model_config_path, save_pretty=True) 38 | # git info 39 | try: 40 | LOGGER.info("Waiting on git info....") 41 | c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], 42 | timeout=10, stdout=subprocess.PIPE) 43 | git_branch_name = c.stdout.decode().strip() 44 | LOGGER.info("Git branch: %s", git_branch_name) 45 | c = subprocess.run(["git", "rev-parse", "HEAD"], 46 | timeout=10, stdout=subprocess.PIPE) 47 | git_sha = c.stdout.decode().strip() 48 | LOGGER.info("Git SHA: %s", git_sha) 49 | git_dir = abspath(dirname(__file__)) 50 | git_status = subprocess.check_output( 51 | ['git', 'status', '--short'], 52 | cwd=git_dir, universal_newlines=True).strip() 53 | with open(join(args.output_dir, 'log', 'git_info.json'), 54 | 'w') as writer: 55 | json.dump({'branch': git_branch_name, 56 | 'is_dirty': bool(git_status), 57 | 'status': git_status, 58 | 'sha': git_sha}, 59 | writer, indent=4) 60 | except (subprocess.TimeoutExpired, subprocess.CalledProcessError) as e: 61 | LOGGER.exception(e) 62 | LOGGER.warn("Git info not found. Saving code into zip instead...") 63 | # save a copy of the codebase. 64 | # !!! Do not store heavy files in your codebase when using it. 65 | code_dir = dirname(dirname(realpath(__file__))) 66 | code_zip_filename = os.path.join(args.output_dir, "code.zip") 67 | LOGGER.info(f"Saving code from {code_dir} to {code_zip_filename}...") 68 | make_zipfile(code_dir, code_zip_filename, 69 | enclosing_dir="code", 70 | exclude_dirs_substring="results", 71 | exclude_dirs=["results", "debug_results", "__pycache__"], 72 | exclude_extensions=[".pyc", ".ipynb", ".swap"]) 73 | LOGGER.info("Saving code done.") 74 | 75 | 76 | def _to_cuda(state): 77 | """ usually load from cpu checkpoint but need to load to cuda """ 78 | if isinstance(state, torch.Tensor): 79 | ret = state.cuda() # assumes device is properly set via torch.cuda.set_device 80 | if 'Half' in state.type(): 81 | ret = ret.float() # apex O2 requires it 82 | return ret 83 | elif isinstance(state, list): 84 | new_state = [_to_cuda(t) for t in state] 85 | elif isinstance(state, tuple): 86 | new_state = tuple(_to_cuda(t) for t in state) 87 | elif isinstance(state, dict): 88 | new_state = {n: _to_cuda(t) for n, t in state.items()} 89 | else: 90 | return state 91 | return new_state 92 | 93 | 94 | def _to_cpu(state): 95 | """ store in cpu to avoid GPU0 device, fp16 to save space """ 96 | if isinstance(state, torch.Tensor): 97 | ret = state.cpu() 98 | if 'Float' in state.type(): 99 | ret = ret.half() 100 | return ret 101 | elif isinstance(state, list): 102 | new_state = [_to_cpu(t) for t in state] 103 | elif isinstance(state, tuple): 104 | new_state = tuple(_to_cpu(t) for t in state) 105 | elif isinstance(state, dict): 106 | new_state = {n: _to_cpu(t) for n, t in state.items()} 107 | else: 108 | return state 109 | return new_state 110 | 111 | 112 | class ModelSaver(object): 113 | def __init__(self, output_dir, prefix='model_step', suffix='pt'): 114 | self.output_dir = output_dir 115 |
self.prefix = prefix 116 | self.suffix = suffix 117 | 118 | def save(self, model, step, optimizer=None): 119 | output_model_file = join(self.output_dir, 120 | f"{self.prefix}_{step}.{self.suffix}") 121 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v 122 | for k, v in model.state_dict().items()} 123 | for k, v in state_dict.items(): 124 | if 'word_embeddings.weight' in k or 'decoder.weight' in k: 125 | assert v.size(0) % 8 == 0 126 | state_dict['vocab_padded'] = True 127 | break 128 | else: 129 | state_dict['vocab_padded'] = False 130 | torch.save(state_dict, output_model_file) 131 | if optimizer is not None: 132 | dump = {'step': step, 'optimizer': optimizer.state_dict()} 133 | torch.save(dump, f'{self.output_dir}/train_state_{step}.pt') 134 | 135 | 136 | class TrainingRestorer(object): 137 | def __init__(self, opts, model, optimizer): 138 | if exists(f'{opts.output_dir}/log/hps.json'): 139 | restore_opts = json.load(open( 140 | f'{opts.output_dir}/log/hps.json', 'r')) 141 | assert vars(opts) == restore_opts 142 | # keep 2 checkpoints in case one gets corrupted 143 | self.save_path = f'{opts.output_dir}/restore.pt' 144 | self.backup_path = f'{opts.output_dir}/restore_backup.pt' 145 | self.model = model 146 | self.optimizer = optimizer 147 | self.save_steps = opts.save_steps 148 | self.amp = opts.fp16 149 | if exists(self.save_path) or exists(self.backup_path): 150 | LOGGER.info('found previous checkpoint. try to resume...') 151 | self.restore(opts) 152 | else: 153 | self.global_step = 0 154 | 155 | def step(self): 156 | self.global_step += 1 157 | if self.global_step % self.save_steps == 0: 158 | self.save() 159 | 160 | def save(self): 161 | checkpoint = {'global_step': self.global_step, 162 | 'model_state_dict': _to_cpu(self.model.state_dict()), 163 | 'optim_state_dict': _to_cpu(self.optimizer.state_dict())} 164 | if self.amp: 165 | checkpoint['amp_state_dict'] = amp.state_dict() 166 | if exists(self.save_path): 167 | os.rename(self.save_path, self.backup_path) 168 | torch.save(checkpoint, self.save_path) 169 | 170 | def restore(self, opts): 171 | try: 172 | checkpoint = torch.load(self.save_path) 173 | except Exception: 174 | checkpoint = torch.load(self.backup_path) 175 | self.global_step = checkpoint['global_step'] 176 | self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict'])) 177 | self.optimizer.load_state_dict( 178 | _to_cuda(checkpoint['optim_state_dict'])) 179 | if self.amp: 180 | amp.load_state_dict(checkpoint['amp_state_dict']) 181 | LOGGER.info(f'resume training from step {self.global_step}') 182 | --------------------------------------------------------------------------------