├── .gitignore ├── LICENSE ├── README.md ├── THIRD_PARTY_LICENSES.md ├── config └── vast │ ├── captioner_cfg │ ├── caption-generation-audio.json │ └── caption-generation-vision.json │ ├── default_model_cfg.json │ ├── default_run_cfg.json │ ├── finetune_cfg │ ├── VQA-activitynet.json │ ├── VQA-msrvtt.json │ ├── VQA-msvd.json │ ├── VQA-music.json │ ├── VQA-tgif.json │ ├── VQA-vqav2.json │ ├── caption-audiocaps.json │ ├── caption-clothov2.json │ ├── caption-mscoco.json │ ├── caption-msrvtt.json │ ├── caption-msvd.json │ ├── caption-tv.json │ ├── caption-valor32k.json │ ├── caption-vatex.json │ ├── caption-youcook.json │ ├── retrieval-activitynet.json │ ├── retrieval-audiocaps.json │ ├── retrieval-clothov2.json │ ├── retrieval-didemo.json │ ├── retrieval-flickr.json │ ├── retrieval-lsmdc.json │ ├── retrieval-mscoco.json │ ├── retrieval-msrvtt.json │ ├── retrieval-valor32k.json │ ├── retrieval-vatex.json │ └── retrieval-youcook.json │ └── pretrain_cfg │ └── pretrain_vast.json ├── data ├── IndexAnno.py ├── IndexSrc.py ├── __init__.py ├── audio_mapper.py ├── loader.py └── vision_mapper.py ├── evaluation ├── __init__.py └── evaluation_mm.py ├── evaluation_tools ├── __init__.py ├── caption_tools │ ├── cocoEvalAllSPICEDemo.ipynb │ ├── cocoEvalCapDemo.ipynb │ ├── get_google_word2vec_model.sh │ ├── get_stanford_models.sh │ ├── pycocoevalcap │ │ ├── __init__.py │ │ ├── bleu │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── bleu.py │ │ │ └── bleu_scorer.py │ │ ├── cider │ │ │ ├── __init__.py │ │ │ ├── cider.py │ │ │ └── cider_scorer.py │ │ ├── eval.py │ │ ├── eval_spice.py │ │ ├── meteor │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ └── paraphrase-en.gz │ │ │ ├── meteor-1.5.jar │ │ │ └── meteor.py │ │ ├── rouge │ │ │ ├── __init__.py │ │ │ └── rouge.py │ │ ├── tokenizer │ │ │ ├── __init__.py │ │ │ ├── ptbtokenizer.py │ │ │ ├── stanford-corenlp-3.4.1.jar │ │ │ ├── tmp0h6fcu13 │ │ │ ├── tmp0tjxfbx9 │ │ │ ├── tmp6p13kv9q │ │ │ ├── tmp_j4nl228 │ │ │ ├── tmp_t697skg │ │ │ ├── tmpagmqx2xa │ │ │ ├── tmpcgw5utq8 │ │ │ ├── tmpfpxy5t7t │ │ │ ├── tmpfqd7mk6p │ │ │ ├── tmphtkot7_1 │ │ │ ├── tmpi4w8y1s9 │ │ │ ├── tmpjve_qo10 │ │ │ ├── tmpk_do254y │ │ │ ├── tmpkwuh2_su │ │ │ ├── tmpqf8v96v2 │ │ │ ├── tmpqpusdi9z │ │ │ ├── tmpsohw7poj │ │ │ ├── tmpsuczi0e7 │ │ │ ├── tmpsvgf063u │ │ │ ├── tmpswwdg3wt │ │ │ ├── tmpv72jh5ig │ │ │ ├── tmpvr9bp8m8 │ │ │ ├── tmpw7hxvxkw │ │ │ ├── tmpz5ekywsl │ │ │ └── tmpzkac7gs7 │ │ └── wmd │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ └── stopwords.txt │ │ │ └── wmd.py │ └── pycocotools │ │ ├── __init__.py │ │ └── coco.py └── vqa_tools │ ├── __init__.py │ ├── vqa.py │ └── vqa_eval.py ├── img ├── VAST-model.jpg └── radar_compare_alldata_vast.png ├── model ├── __init__.py ├── audio_encoders │ ├── ast │ │ └── ast.py │ └── beats │ │ └── beats.py ├── general_module.py ├── text_encoders │ └── bert │ │ └── bert.py ├── vast.py └── vision_encoders │ ├── clip │ ├── clip.py │ └── clip_tokenizer.py │ ├── evaclip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── constants.py │ ├── eva_vit_model.py │ ├── factory.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── EVA01-CLIP-B-16.json │ │ ├── EVA01-CLIP-g-14-plus.json │ │ ├── EVA01-CLIP-g-14.json │ │ ├── EVA02-CLIP-B-16.json │ │ ├── EVA02-CLIP-L-14-336.json │ │ ├── EVA02-CLIP-L-14.json │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ └── EVA02-CLIP-bigE-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pretrained.py │ ├── rope.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py 
│ └── utils.py │ ├── swin │ ├── swin.py │ └── swin_config.py │ └── videoswin │ └── videoswin.py ├── preinstall.sh ├── run.py ├── scripts └── vast │ ├── audio_captioner.sh │ ├── finetune_cap.sh │ ├── finetune_qa.sh │ ├── finetune_ret.sh │ ├── pretrain_vast.sh │ └── vision_captioner.sh └── utils ├── __init__.py ├── args.py ├── build_dataloader.py ├── build_model.py ├── build_optimizer.py ├── distributed.py ├── initialize.py ├── logger.py ├── offline_process_data.py ├── pipeline.py ├── save.py ├── sched.py └── tool.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ctags 2 | tags 3 | apex/build 4 | aic_caption/ 5 | cococaption/pycocoevalcap/spice 6 | output/ 7 | ouptut/ 8 | datasets/ 9 | lpips/ 10 | upload/ 11 | __pycache__ 12 | UVG/evaluation 13 | gen_evaluation 14 | evaluation_tools/bleurt_master 15 | pretrained_weights 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | .attach_pid* 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | 121 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sihan Chen, Handong Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /THIRD_PARTY_LICENSES.md: -------------------------------------------------------------------------------- 1 | # Third-Party Licenses 2 | 3 | This project includes code from the following open-source projects, each of which may have its own license: 4 | 5 | ## 1. Microsoft COCO Toolbox 6 | - Authors: Piotr Dollar, Tsung-Yi Lin 7 | - License: Simplified BSD License 8 | - Link: [Microsoft COCO Toolbox](http://mscoco.org/) 9 | 10 | ## 2. salesforce.com, inc. 11 | - License: BSD-3-Clause 12 | - Link to full license text: [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause) 13 | 14 | ## 3. Google AI Language Team Authors, The HuggingFace Inc. team, NVIDIA CORPORATION 15 | - License: Apache License, Version 2.0 16 | - Link to full license text: [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) 17 | 18 | ## 4. BEATs: Audio Pre-Training with Acoustic Tokenizers 19 | - Authors: Microsoft 20 | - License: MIT License 21 | - Link: [BEATs](https://arxiv.org/abs/2212.09058) 22 | 23 | ## 5. CLIP Model 24 | - Authors: OpenAI 25 | - Adapted from [OpenAI/CLIP](https://github.com/openai/CLIP) 26 | - Original License: MIT License 27 | 28 | ## 6. Swin Transformer 29 | - Authors: Microsoft 30 | - License: MIT License 31 | 32 | -------------------------------------------------------------------------------- /config/vast/captioner_cfg/caption-generation-audio.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json", 3 | "mode":"testing"}, 4 | 5 | "model_cfg": 6 | {"default":"./config/vast/default_model_cfg.json"}, 7 | 8 | "data_cfg": 9 | 10 | {"train":{}, 11 | 12 | "val": 13 | [ 14 | { 15 | "type":"annoindexed", 16 | "training":false, 17 | "name": "yourdata", 18 | "txt": "datasets/annotations/yourdata/meta.json", 19 | "audio": "datasets/srcdata/yourdata/audios", 20 | "audio_sample_num": 3, 21 | "task" : "cap%ta", 22 | "n_workers": 8, 23 | "batch_size": 64 } 24 | ]}} 25 | 26 | 27 | -------------------------------------------------------------------------------- /config/vast/captioner_cfg/caption-generation-vision.json: -------------------------------------------------------------------------------- 1 | 2 | {"run_cfg": 3 | {"default":"./config/vast/default_run_cfg.json", 4 | "mode":"testing"}, 5 | 6 | "model_cfg": 7 | {"default":"./config/vast/default_model_cfg.json"}, 8 | 9 | "data_cfg": 10 | 11 | {"train":{}, 12 | 13 | "val": 14 | [{ 15 | "type":"annoindexed", 16 | "training":false, 17 | "name": "yourdata", 18 | "txt": "datasets/annotations/yourdata/meta.json", 19 | "vision": "datasets/srcdata/yourdata/videos", 20 | "vision_format": "video_rawvideo", 21 | "vision_sample_num": 8, 22 | "task" : "cap%tv", 23 | "n_workers": 8, 24 | "batch_size": 64 25 | }]}} -------------------------------------------------------------------------------- /config/vast/default_model_cfg.json: 
-------------------------------------------------------------------------------- 1 | {"model_type": "vast", 2 | "itm_ratio":0.1, 3 | "frozen_vision":false, 4 | "frozen_audio":false, 5 | "checkpointing":false, 6 | "max_caption_len":40, 7 | "max_omni_caption_len":70, 8 | "max_subtitle_len":70, 9 | "contra_dim":512, 10 | "inherit_keys":["vision_encoder_type","audio_encoder_type","audio_melbins","audio_target_length"], 11 | "frame_embedding_type":"adaptive", 12 | "vision_resolution":224, 13 | "vision_encoder_type":"evaclip01_giant", 14 | "audio_encoder_type":"beats", 15 | "audio_melbins":64, 16 | "audio_target_length": 1024, 17 | "beam_size":3, 18 | "captioner_mode":false, 19 | "generate_nums":1, 20 | "ret_bidirection_evaluation":false, 21 | "itm_rerank_num":50, 22 | "evaluation_type":"evaluation_mm"} 23 | 24 | -------------------------------------------------------------------------------- /config/vast/default_run_cfg.json: -------------------------------------------------------------------------------- 1 | {"checkpoint":"", 2 | "output_dir":"none", 3 | "gradient_accumulation_steps":1, 4 | "clip_lr":5e-7, 5 | "optim":"adamw", 6 | "learning_rate":1e-4, 7 | "betas":[0.9, 0.98], 8 | "weight_decay":0.01, 9 | "grad_norm":2.0, 10 | "warmup_ratio":0.1, 11 | "resume":false, 12 | "seed":50, 13 | "fp16":true, 14 | "bf16":false, 15 | "zero_shot":false, 16 | "scheduler":"warmup_linear", 17 | "new_lr":0, 18 | "new_params_name":[], 19 | "valid_freq":10, 20 | "dataset_mix_type":"random", 21 | "remove_before_ckpt":true, 22 | "first_eval":true, 23 | "pretrain_dir":"", 24 | "num_train_steps":0, 25 | "save_best":false, 26 | "pin_mem":true, 27 | "vision_resolution":224, 28 | "use_ddp":true, 29 | "mode":"training", 30 | "log_steps":100 31 | } -------------------------------------------------------------------------------- /config/vast/finetune_cfg/VQA-activitynet.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | 5 | "model_cfg": 6 | {"default":"./config/vast/default_model_cfg.json"}, 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "activitynet_qa", 13 | "txt": "datasets/annotations/activitynet/descs_qa_train.json", 14 | "vision": "datasets/srcdata/activitynet/videos", 15 | "audio": "datasets/srcdata/activitynet/audios", 16 | "vision_format": "video_rawvideo", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 2, 19 | "task" : "qa%tva", 20 | "epoch": 10, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "activitynet_qa", 28 | "txt": "datasets/annotations/activitynet/descs_qa_test.json", 29 | "vision": "datasets/srcdata/activitynet/videos", 30 | "audio": "datasets/srcdata/activitynet/audios", 31 | "vision_format": "video_rawvideo", 32 | "vision_sample_num": 16, 33 | "audio_sample_num": 2, 34 | "task" : "qa%tva", 35 | "n_workers": 8, 36 | "batch_size": 8 37 | }]}} 38 | 39 | 40 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/VQA-msrvtt.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | 5 | "model_cfg": 6 | {"default":"./config/vast/default_model_cfg.json"}, 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "msrvtt_qa", 13 | "txt": 
"datasets/annotations/msrvtt/descs_qa_trainval.json", 14 | "vision": "datasets/srcdata/msrvtt/videos", 15 | "audio": "datasets/srcdata/msrvtt/audios", 16 | "vision_format": "video_rawvideo", 17 | "vision_transforms":"crop_flip", 18 | "vision_sample_num": 8, 19 | "audio_sample_num": 1, 20 | "task" : "qa%tvas", 21 | "epoch": 4.5, 22 | "n_workers":8, 23 | "batch_size": 64}], 24 | "val": 25 | [{ 26 | "type":"annoindexed", 27 | "training":false, 28 | "name": "msrvtt_qa", 29 | "txt": "datasets/annotations/msrvtt/descs_qa_test.json", 30 | "vision": "datasets/srcdata/msrvtt/videos", 31 | "audio": "datasets/srcdata/msrvtt/audios", 32 | "vision_transforms":"crop_flip", 33 | "vision_format": "video_rawvideo", 34 | "vision_sample_num": 8, 35 | "audio_sample_num": 1, 36 | "task" : "qa%tvas", 37 | "n_workers": 8, 38 | "batch_size": 8 39 | }]}} 40 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/VQA-msvd.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | 5 | "model_cfg": 6 | {"default":"./config/vast/default_model_cfg.json"}, 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "msvd_qa", 13 | "txt": "datasets/annotations/msvd/descs_qa_trainval.json", 14 | "vision": "datasets/srcdata/msvd/videos", 15 | "vision_format": "video_rawvideo", 16 | "vision_sample_num": 8, 17 | "task" : "qa%tv", 18 | "epoch": 10, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "msvd_qa", 26 | "txt": "datasets/annotations/msvd/descs_qa_test.json", 27 | "vision": "datasets/srcdata/msvd/videos", 28 | "vision_format": "video_rawvideo", 29 | "vision_sample_num": 8, 30 | "task" : "qa%tv", 31 | "n_workers": 8, 32 | "batch_size": 8 33 | }]}} 34 | 35 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/VQA-music.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | 5 | "model_cfg": 6 | {"default":"./config/vast/default_model_cfg.json"}, 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "music_qa", 13 | "txt": "datasets/annotations/music/descs_qa_train.json", 14 | "vision": "datasets/srcdata/music/videos", 15 | "audio": "datasets/srcdata/music/audios", 16 | "vision_format": "video_rawvideo", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 2, 19 | "task" : "qa%tva", 20 | "epoch": 4.5, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "music_qa", 28 | "txt": "datasets/annotations/music/descs_qa_test.json", 29 | "vision": "datasets/srcdata/music/videos", 30 | "audio": "datasets/srcdata/music/audios", 31 | "vision_format": "video_rawvideo", 32 | "vision_sample_num": 8, 33 | "audio_sample_num": 2, 34 | "task" : "qa%tva", 35 | "n_workers": 8, 36 | "batch_size": 8 37 | }]}} 38 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/VQA-tgif.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | 5 | "model_cfg": 6 | {"default":"./config/vast/default_model_cfg.json"}, 7 | 8 | "data_cfg": 9 | 
{"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "tgif_qa", 13 | "txt": "datasets/annotations/tgif/descs_qa_train.json", 14 | "vision": "datasets/srcdata/tgif/videos", 15 | "vision_format": "video_rawvideo", 16 | "vision_sample_num": 4, 17 | "task" : "qa%tv", 18 | "epoch": 10, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "tgif_qa", 26 | "txt": "datasets/annotations/tgif/descs_qa_test.json", 27 | "vision": "datasets/srcdata/tgif/videos", 28 | "vision_format": "video_rawvideo", 29 | "vision_sample_num": 4, 30 | "task" : "qa%tv", 31 | "n_workers": 8, 32 | "batch_size": 8 33 | }]}} 34 | 35 | 36 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/VQA-vqav2.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | 5 | "model_cfg": 6 | {"default":"./config/vast/default_model_cfg.json"}, 7 | 8 | "data_cfg":{"train": 9 | [{ "type":"annoindexed", 10 | "training":true, 11 | "name": "vqav2_trainval", 12 | "txt": "datasets/annotations/mscoco/descs_qa_train.json", 13 | "vision": "datasets/srcdata/mscoco/images", 14 | "max_caption_len":30, 15 | "task" : "qa%tv", 16 | "vision_format":"image_rawimage", 17 | "epoch": 20, 18 | "n_workers":8, 19 | "batch_size": 128} 20 | 21 | ], 22 | "val": 23 | [ 24 | 25 | {"type":"annoindexed", 26 | "training":false, 27 | "name": "vqav2_test", 28 | "txt": "datasets/annotations/mscoco/descs_qa_test.json", 29 | "vision": "datasets/srcdata/mscoco/images", 30 | "vision_format":"image_rawimage", 31 | "max_caption_len":30, 32 | "task" : "qa%tv", 33 | "n_workers": 8, 34 | "batch_size": 64, 35 | "make_submission":true 36 | 37 | } 38 | ]}} 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-audiocaps.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | 10 | {"train": 11 | [{ "type":"annoindexed", 12 | "training":true, 13 | "name": "audiocaps_cap", 14 | "txt": "datasets/annotations/audiocaps/descs_cap_trainval.json", 15 | "audio": "datasets/srcdata/audiocaps/audios", 16 | "audio_sample_num": 1, 17 | "task" : "cap%ta", 18 | "epoch": 10, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "audiocaps_cap", 26 | "txt": "datasets/annotations/audiocaps/descs_cap_test.json", 27 | "audio": "datasets/srcdata/audiocaps/audios", 28 | "annfile": "datasets/annotations/audiocaps/caption_annotation.json", 29 | "audio_sample_num": 1, 30 | "task" : "cap%ta", 31 | "n_workers": 8, 32 | "batch_size": 64 33 | }]}} 34 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-clothov2.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | 10 | {"train": 11 | [{ "type":"annoindexed", 12 | "training":true, 13 | "name": "clothov2_cap", 14 | "txt": "datasets/annotations/clothov2/descs_cap_trainval.json", 15 | "audio": 
"datasets/srcdata/clothov2/audios", 16 | "audio_sample_num": 3, 17 | "task" : "cap%ta", 18 | "epoch": 10, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "clothov2_cap", 26 | "txt": "datasets/annotations/clothov2/descs_cap_test.json", 27 | "audio": "datasets/srcdata/clothov2/audios", 28 | "annfile": "datasets/annotations/clothov2/caption_annotation.json", 29 | "audio_sample_num": 3, 30 | "task" : "cap%ta", 31 | "n_workers": 8, 32 | "batch_size": 64 33 | }]}} 34 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-mscoco.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "mscoco_cap", 13 | "txt": "datasets/annotations/mscoco/descs_cap_train.json", 14 | "vision": "datasets/srcdata/mscoco/images", 15 | "vision_format": "image_rawimage", 16 | "task" : "cap%tv", 17 | "epoch": 5, 18 | "n_workers": 8, 19 | "batch_size": 64}], 20 | "val": 21 | [{ 22 | "type":"annoindexed", 23 | "training":false, 24 | "name": "mscoco_cap", 25 | "txt": "datasets/annotations/mscoco/descs_cap_test.json", 26 | "vision": "datasets/srcdata/mscoco/images", 27 | "annfile": "datasets/annotations/mscoco/caption_annotation.json", 28 | "vision_format": "image_rawimage", 29 | "task" : "cap%tv", 30 | "n_workers": 8, 31 | "batch_size": 128 32 | }]}} 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-msrvtt.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "msrvtt_cap", 13 | "txt": "datasets/annotations/msrvtt/descs_cap_train.json", 14 | "vision": "datasets/srcdata/msrvtt/videos", 15 | "vision_transforms":"crop_flip", 16 | "audio": "datasets/srcdata/msrvtt/audios", 17 | "vision_format": "video_rawvideo", 18 | "vision_sample_num": 8, 19 | "audio_sample_num": 1, 20 | "task" : "cap%tvas", 21 | "epoch": 5, 22 | "n_workers":8, 23 | "batch_size": 64}], 24 | "val": 25 | [{ 26 | "type":"annoindexed", 27 | "training":false, 28 | "name": "msrvtt_cap", 29 | "txt": "datasets/annotations/msrvtt/descs_cap_test.json", 30 | "vision": "datasets/srcdata/msrvtt/videos", 31 | "audio": "datasets/srcdata/msrvtt/audios", 32 | "vision_transforms":"crop_flip", 33 | "annfile": "datasets/annotations/msrvtt/caption_annotation.json", 34 | "vision_format": "video_rawvideo", 35 | "vision_sample_num": 16, 36 | "audio_sample_num": 1, 37 | "task" : "cap%tvas", 38 | "n_workers": 8, 39 | "batch_size": 64 40 | }]}} 41 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-msvd.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": 
"msvd_cap", 13 | "txt": "datasets/annotations/msvd/descs_cap_train.json", 14 | "vision": "datasets/srcdata/msvd/videos", 15 | "vision_format": "video_rawvideo", 16 | "vision_sample_num": 8, 17 | "task" : "cap%tv", 18 | "epoch": 2.5, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "msvd_cap", 26 | "txt": "datasets/annotations/msvd/descs_cap_test.json", 27 | "vision": "datasets/srcdata/msvd/videos", 28 | "annfile": "datasets/annotations/msvd/caption_annotation.json", 29 | "vision_format": "video_rawvideo", 30 | "vision_sample_num": 8, 31 | "task" : "cap%tv", 32 | "n_workers": 8, 33 | "batch_size": 64 34 | }]}} 35 | 36 | 37 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-tv.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "tv_cap", 13 | "txt": "datasets/annotations/tv/descs_cap_train.json", 14 | "vision": "datasets/srcdata/tv/frames_fps3", 15 | "vision_format": "video_frame", 16 | "vision_sample_num": 8, 17 | "task" : "cap%tv", 18 | "epoch": 20, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "tv_cap", 26 | "txt": "datasets/annotations/tv/descs_cap_test.json", 27 | "vision": "datasets/srcdata/tv/frames_fps3", 28 | "annfile": "datasets/annotations/tv/caption_annotation.json", 29 | "vision_format": "video_frame", 30 | "vision_sample_num": 8, 31 | "task" : "cap%tv", 32 | "n_workers": 8, 33 | "batch_size": 64 34 | }]}} 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-valor32k.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "valor32k_cap", 13 | "txt": "datasets/annotations/valor32k/descs_cap_train.json", 14 | "vision": "datasets/srcdata/valor32k/videos", 15 | "vision_format": "video_rawvideo", 16 | "audio": "datasets/srcdata/valor32k/audios", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 1, 19 | "task" : "cap%tva", 20 | "epoch": 30, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "valor32k_cap", 28 | "txt": "datasets/annotations/valor32k/descs_cap_test.json", 29 | "vision": "datasets/srcdata/valor32k/videos", 30 | "annfile": "datasets/annotations/valor32k/caption_annotation.json", 31 | "vision_format": "video_rawvideo", 32 | "vision_sample_num": 8, 33 | "audio_sample_num": 1, 34 | "audio": "datasets/srcdata/valor32k/audios", 35 | "task" : "cap%tva", 36 | "n_workers": 8, 37 | "batch_size": 64 38 | }]}} 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-vatex.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | 
{"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "vatex_cap", 13 | "txt": "datasets/annotations/vatex/descs_cap_trainval.json", 14 | "vision": "datasets/srcdata/vatex/videos", 15 | "audio": "datasets/srcdata/vatex/audios", 16 | "vision_format": "video_rawvideo", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 1, 19 | "task" : "cap%tvas", 20 | "epoch": 10, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "vatex_cap", 28 | "txt": "datasets/annotations/vatex/descs_cap_test.json", 29 | "vision": "datasets/srcdata/vatex/videos", 30 | "audio": "datasets/srcdata/vatex/audios", 31 | "annfile": "datasets/annotations/vatex/caption_annotation.json", 32 | "vision_format": "video_rawvideo", 33 | "vision_sample_num": 20, 34 | "audio_sample_num": 1, 35 | "task" : "cap%tvas", 36 | "n_workers": 8, 37 | "batch_size": 64 38 | }]}} 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/caption-youcook.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "youcook_cap", 13 | "txt": "datasets/annotations/youcook/descs_cap_train.json", 14 | "vision": "datasets/srcdata/youcook/videos", 15 | "vision_format": "video_rawvideo", 16 | "audio": "datasets/srcdata/youcook/audios", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 1, 19 | "task" : "cap%tvas", 20 | "epoch": 30, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "youcook_cap", 28 | "txt": "datasets/annotations/youcook/descs_cap_test.json", 29 | "vision": "datasets/srcdata/youcook/videos", 30 | "annfile": "datasets/annotations/youcook/caption_annotation.json", 31 | "vision_format": "video_rawvideo", 32 | "vision_sample_num": 16, 33 | "audio_sample_num": 1, 34 | "audio": "datasets/srcdata/youcook/audios", 35 | "task" : "cap%tvas", 36 | "n_workers": 8, 37 | "batch_size": 64 38 | }]}} 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-activitynet.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json", 6 | "max_caption_len":70}, 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "activitynet_ret", 13 | "txt": "datasets/annotations/activitynet/descs_ret_train.json", 14 | "vision": "datasets/srcdata/activitynet/videos", 15 | "audio": "datasets/srcdata/activitynet/audios", 16 | "vision_format": "video_rawvideo", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 2, 19 | "task" : "ret%tva", 20 | "epoch": 20, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "activitynet_ret", 28 | "txt": "datasets/annotations/activitynet/descs_ret_test.json", 
29 | "vision": "datasets/srcdata/activitynet/videos", 30 | "audio": "datasets/srcdata/activitynet/audios", 31 | "vision_format": "video_rawvideo", 32 | "vision_sample_num": 32, 33 | "audio_sample_num": 2, 34 | "task" : "ret%tva", 35 | "n_workers": 8, 36 | "batch_size": 64 37 | }]}} 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-audiocaps.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | 10 | {"train": 11 | [{ "type":"annoindexed", 12 | "training":true, 13 | "name": "audiocaps_ret", 14 | "txt": "datasets/annotations/audiocaps/descs_ret_trainval.json", 15 | "audio": "datasets/srcdata/audiocaps/audios", 16 | "audio_sample_num": 1, 17 | "task" : "ret%ta", 18 | "epoch": 10, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "audiocaps_ret", 26 | "txt": "datasets/annotations/audiocaps/descs_ret_test.json", 27 | "audio": "datasets/srcdata/audiocaps/audios", 28 | "audio_sample_num": 1, 29 | "task" : "ret%ta", 30 | "n_workers": 8, 31 | "batch_size": 64 32 | }]}} 33 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-clothov2.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | 10 | {"train": 11 | [{ "type":"annoindexed", 12 | "training":true, 13 | "name": "clothov2_ret", 14 | "txt": "datasets/annotations/clothov2/descs_cap_trainval.json", 15 | "audio": "datasets/srcdata/clothov2/audios", 16 | "audio_sample_num": 3, 17 | "task" : "ret%ta", 18 | "epoch": 10, 19 | "n_workers":8, 20 | "batch_size": 64}], 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "clothov2_ret", 26 | "txt": "datasets/annotations/clothov2/descs_cap_test.json", 27 | "audio": "datasets/srcdata/clothov2/audios", 28 | "audio_sample_num": 3, 29 | "task" : "ret%ta", 30 | "n_workers": 8, 31 | "batch_size": 64 32 | }]}} 33 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-didemo.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json", 6 | "max_caption_len":70}, 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "didemo_ret", 13 | "txt": "datasets/annotations/didemo/descs_ret_train.json", 14 | "vision": "datasets/srcdata/didemo/videos", 15 | "audio": "datasets/srcdata/didemo/audios", 16 | "vision_format": "video_rawvideo", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 2, 19 | "task" : "ret%tva", 20 | "epoch": 40, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "didemo_ret", 28 | "txt": "datasets/annotations/didemo/descs_ret_test.json", 29 | "vision": "datasets/srcdata/didemo/videos", 30 | "audio": "datasets/srcdata/didemo/audios", 31 | "vision_format": "video_rawvideo", 
32 | "vision_sample_num": 32, 33 | "audio_sample_num": 2, 34 | "task" : "ret%tva", 35 | "n_workers": 8, 36 | "batch_size": 64 37 | }]}} 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-flickr.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "flickr_ret", 13 | "txt": "datasets/annotations/flickr/descs_ret_trainval.json", 14 | "vision": "datasets/srcdata/flickr/images", 15 | "vision_format": "image_rawimage", 16 | "task": "ret%tv", 17 | "epoch": 5, 18 | "n_workers": 8, 19 | "batch_size": 256}], 20 | 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "flickr_ret", 26 | "txt": "datasets/annotations/flickr/descs_ret_test.json", 27 | "vision": "datasets/srcdata/flickr/images", 28 | "vision_format": "image_rawimage", 29 | "task": "ret%tv", 30 | "n_workers": 8, 31 | "batch_size": 128 32 | }]}} 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-lsmdc.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | {"train": 9 | [{ "type":"annoindexed", 10 | "training":true, 11 | "name": "lsmdc_ret", 12 | "txt": "datasets/annotations/lsmdc/descs_ret_trainval.json", 13 | "vision": "datasets/srcdata/lsmdc/videos", 14 | "audio": "datasets/srcdata/lsmdc/audios", 15 | "vision_format": "video_rawvideo", 16 | "vision_sample_num": 8, 17 | "audio_sample_num": 1, 18 | "task" : "ret%tva", 19 | "epoch": 5, 20 | "n_workers":8, 21 | "batch_size": 64}], 22 | "val": 23 | [{ 24 | "type":"annoindexed", 25 | "training":false, 26 | "name": "lsmdc_ret", 27 | "txt": "datasets/annotations/lsmdc/descs_ret_test.json", 28 | "vision": "datasets/srcdata/lsmdc/videos", 29 | "audio": "datasets/srcdata/lsmdc/audios", 30 | "vision_format": "video_rawvideo", 31 | "vision_sample_num": 32, 32 | "audio_sample_num": 1, 33 | "task" : "ret%tva", 34 | "n_workers": 8, 35 | "batch_size": 64 36 | }]}} 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-mscoco.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | 8 | "data_cfg": 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "mscoco_ret", 13 | "txt": "datasets/annotations/mscoco/descs_cap_train.json", 14 | "vision": "datasets/srcdata/mscoco/images", 15 | "vision_format": "image_rawimage", 16 | "task": "'ret%tv", 17 | "epoch": 5, 18 | "n_workers": 8, 19 | "batch_size": 256}], 20 | 21 | "val": 22 | [{ 23 | "type":"annoindexed", 24 | "training":false, 25 | "name": "mscoco_ret", 26 | "txt": "datasets/annotations/mscoco/descs_cap_test.json", 27 | "vision": "datasets/srcdata/mscoco/images", 28 | "vision_format": "image_rawimage", 29 | "task": "'ret%tv", 30 | "n_workers": 8, 31 | "batch_size": 
128 32 | }]}} 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-msrvtt.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | {"train": 9 | [{ "type":"annoindexed", 10 | "training":true, 11 | "name": "msrvtt_ret", 12 | "txt": "datasets/annotations/msrvtt/descs_ret_train.json", 13 | "vision": "datasets/srcdata/msrvtt/videos", 14 | "audio": "datasets/srcdata/msrvtt/audios", 15 | "vision_transforms":"crop_flip", 16 | "vision_format": "video_rawvideo", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 1, 19 | "task" : "ret%tvas", 20 | "epoch": 3.6, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "msrvtt_ret", 28 | "txt": "datasets/annotations/msrvtt/descs_ret_test.json", 29 | "vision": "datasets/srcdata/msrvtt/videos", 30 | "vision_transforms":"crop_flip", 31 | "vision_format": "video_rawvideo", 32 | "audio": "datasets/srcdata/msrvtt/audios", 33 | "vision_sample_num": 16, 34 | "audio_sample_num": 1, 35 | "task" : "ret%tvas", 36 | "n_workers": 8, 37 | "batch_size": 64 38 | }]}} 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-valor32k.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "valor32k_cap", 13 | "txt": "datasets/annotations/valor32k/descs_cap_train.json", 14 | "vision": "datasets/srcdata/valor32k/videos", 15 | "vision_format": "video_rawvideo", 16 | "audio": "datasets/srcdata/valor32k/audios", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 1, 19 | "task" : "ret%tva", 20 | "epoch": 30, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "valor32k_cap", 28 | "txt": "datasets/annotations/valor32k/descs_cap_train.json", 29 | "vision": "datasets/srcdata/valor32k/videos", 30 | "vision_format": "video_rawvideo", 31 | "vision_sample_num": 8, 32 | "audio_sample_num": 1, 33 | "audio": "datasets/srcdata/valor32k/audios", 34 | "task" : "ret%tva", 35 | "n_workers": 8, 36 | "batch_size": 64 37 | }]}} 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-vatex.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | {"train": 9 | [{ "type":"annoindexed", 10 | "training":true, 11 | "name": "vatex_ret", 12 | "txt": "datasets/annotations/vatex/descs_ret_train.json", 13 | "vision": "datasets/srcdata/vatex/videos", 14 | "audio": "datasets/srcdata/vatex/audios", 15 | "vision_format": "video_rawvideo", 16 | "vision_sample_num": 8, 17 | "audio_sample_num": 1, 18 | "task" : "ret%tvas", 19 | "epoch": 2.5, 20 | "n_workers":8, 21 | "batch_size": 64}], 22 | 
"val": 23 | [{ 24 | "type":"annoindexed", 25 | "training":false, 26 | "name": "vatex_ret", 27 | "txt": "datasets/annotations/vatex/descs_ret_test.json", 28 | "vision": "datasets/srcdata/vatex/videos", 29 | "audio": "datasets/srcdata/vatex/audios", 30 | "vision_format": "video_rawvideo", 31 | "vision_sample_num": 16, 32 | "audio_sample_num": 1, 33 | "task" : "ret%tvas", 34 | "n_workers": 8, 35 | "batch_size": 64 36 | }]}} 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /config/vast/finetune_cfg/retrieval-youcook.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"./config/vast/default_run_cfg.json"}, 3 | 4 | "model_cfg": 5 | {"default":"./config/vast/default_model_cfg.json"}, 6 | 7 | "data_cfg": 8 | 9 | {"train": 10 | [{ "type":"annoindexed", 11 | "training":true, 12 | "name": "youcook_ret", 13 | "txt": "datasets/annotations/youcook/descs_cap_train.json", 14 | "vision": "datasets/srcdata/youcook/videos", 15 | "vision_format": "video_rawvideo", 16 | "audio": "datasets/srcdata/youcook/audios", 17 | "vision_sample_num": 8, 18 | "audio_sample_num": 1, 19 | "task" : "ret%tvas", 20 | "epoch": 30, 21 | "n_workers":8, 22 | "batch_size": 64}], 23 | "val": 24 | [{ 25 | "type":"annoindexed", 26 | "training":false, 27 | "name": "youcook_ret", 28 | "txt": "datasets/annotations/youcook/descs_cap_test.json", 29 | "vision": "datasets/srcdata/youcook/videos", 30 | "vision_format": "video_rawvideo", 31 | "vision_sample_num": 16, 32 | "audio_sample_num": 1, 33 | "audio": "datasets/srcdata/youcook/audios", 34 | "task" : "ret%tvas", 35 | "n_workers": 8, 36 | "batch_size": 64 37 | }]}} 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /config/vast/pretrain_cfg/pretrain_vast.json: -------------------------------------------------------------------------------- 1 | { "run_cfg": 2 | {"default":"./config/default_run_cfg.json", 3 | "learning_rate": 5e-05 4 | }, 5 | 6 | 7 | "model_cfg": 8 | { "default":"./config/newvlp/default_model_cfg.json", 9 | "vision_encoder_type":"evaclip01_giant" 10 | 11 | }, 12 | "data_cfg":{"train": 13 | 14 | [{"type":"annoindexed", 15 | "training":true, 16 | "name": "vast27m", 17 | "txt": "/PATH/TO/vast27m/train_desc.json", 18 | "vision": "/PATH/TO/vast27m/videos", 19 | "audio":"/PATH/TO/vast27m/audios", 20 | "datatype": "video_rawvideo", 21 | "vision_sample_num": 1, 22 | "audio_sample_num": 1, 23 | "task" : "ret%tvas%tvs%tv%ta_cap%tvas%tvs%tv%ta", 24 | "steps": 60000, 25 | "n_workers":8, 26 | "batch_size": 1024}, 27 | 28 | {"type":"annoindexed", 29 | "training":true, 30 | "name": "valor1m", 31 | "txt": "/PATH/TO/valor1m/train_desc.json", 32 | "vision": "/PATH/TO/valor1m/videos", 33 | "audio":"/PATH/TO/valor1m/audios", 34 | "vision_format": "video_rawvideo", 35 | "vision_sample_num": 1, 36 | "audio_sample_num": 1, 37 | "task" : "ret%tva%tv%ta_cap%tva%tv%ta", 38 | "steps": 25000, 39 | "n_workers":4, 40 | "batch_size": 1024}, 41 | 42 | 43 | 44 | {"type":"annoindexed", 45 | "name": "cc4m", 46 | "training":true, 47 | "vision_format": "image_rawimage", 48 | "vision":"/PATH/TO/cc4m/images", 49 | "txt":"/PATH/TO/cc4m/train_desc.json", 50 | "task" : "ret%tv_cap%tv", 51 | "steps": 55000, 52 | "n_workers":2, 53 | "batch_size": 2048}, 54 | 55 | {"type":"annoindexed", 56 | "name": "cc12m", 57 | "training":true, 58 | "vision_format": "image_rawimage", 59 | "vision":"/PATH/TO/cc12m/images", 60 | 
"txt":"/PATH/TO/cc12m/train_desc.json", 61 | "task" : "ret%tv_cap%tv", 62 | "steps": 20000, 63 | "n_workers":4, 64 | "batch_size": 2048}, 65 | 66 | 67 | {"type":"annoindexed", 68 | "name": "laion2b", 69 | "training":true, 70 | "vision_format": "image_rawimage", 71 | "vision":"/PATH/TO/laion2b/images", 72 | "txt":"/PATH/TO/laion2b/train_desc.json", 73 | "task" : "ret%tv_cap%tv", 74 | "steps": 55000, 75 | "n_workers":4, 76 | "batch_size": 2048}, 77 | 78 | 79 | {"type":"annoindexed", 80 | "training":true, 81 | "name": "audioset-SL", 82 | "txt": "/PATH/TO/AudioSetSL/train_desc.json", 83 | "audio":"/PATH/TO/AudioSetSL/audios", 84 | "audio_sample_num": 1, 85 | "task" : "ret%ta_cap%ta", 86 | "steps": 7500, 87 | "n_workers":4, 88 | "batch_size": 1024}, 89 | 90 | { "type":"annoindexed", 91 | "training":true, 92 | "name": "freesound", 93 | "txt": "/PATH/TO/FreeSound/train_desc.json", 94 | "audio":"/PATH/TO/FreeSound/audios", 95 | "audio_sample_num": 2, 96 | "task" : "ret%ta_cap%ta", 97 | "steps": 7500, 98 | "n_workers":4, 99 | "batch_size": 1024}], 100 | 101 | 102 | 103 | 104 | "val": 105 | []}} 106 | -------------------------------------------------------------------------------- /data/IndexAnno.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import random 5 | import torch 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from toolz.sandbox import unzip 9 | from torch.utils.data import Dataset 10 | from utils.logger import LOGGER 11 | from .vision_mapper import VisionMapper 12 | from .audio_mapper import AudioMapper 13 | 14 | from torch.utils.data import ConcatDataset 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | class AnnoIndexedDataset(Dataset): 24 | def __init__(self, d_cfg, args): 25 | self.vision_mapper = VisionMapper(d_cfg, args) if 'vision' in d_cfg else None 26 | self.audio_mapper = AudioMapper(d_cfg, args) if 'audio' in d_cfg else None 27 | self.annos = json.load(open(d_cfg['txt'])) 28 | self.idx = list(range(len(self.annos))) 29 | self.dataset_name = d_cfg['name'] 30 | self.training = d_cfg.training 31 | 32 | self.worker_init_fn = None 33 | self.use_sampler = True 34 | self.collate_fn = annoindexedcollate 35 | 36 | self.annfile = getattr(d_cfg,'annfile',None) 37 | self.make_submission = getattr(d_cfg,'make_submission',False) 38 | self.multi_evaluation = getattr(d_cfg,'multi_evaluation',False) 39 | self.vqa_anno_file = getattr(d_cfg,'vqa_anno_file',None) 40 | self.vqa_question_file = getattr(d_cfg,'vqa_question_file',None) 41 | 42 | 43 | def __len__(self): 44 | return len(self.annos) 45 | 46 | def __getitem__(self, i): 47 | anno = self.annos[i] 48 | 49 | for key in ['video_id','image_id','image','id']: 50 | if key in anno: 51 | id_ = anno[key] 52 | break 53 | 54 | raw_captions = None 55 | raw_subtitles = None 56 | question_id = None 57 | question = None 58 | answer = None 59 | id_txt = None 60 | vision_pixels = None 61 | audio_spectrograms = None 62 | 63 | 64 | 65 | 66 | 67 | 68 | raw_captions = anno['desc'] if 'desc' in anno else anno['caption'] 69 | num_samples = len(raw_captions) if isinstance(raw_captions, list) else 1 70 | id_txt = [id_] * num_samples 71 | 72 | 73 | if 'subtitle' in anno: 74 | raw_subtitles = anno['subtitle'] 75 | 76 | if 'question' in anno: 77 | 78 | if self.training: 79 | question = anno['question'] 80 | if isinstance(anno['answer'],list): #### vqav2 81 | answer = random.choice(anno['answer']) 82 | else: 83 | answer = anno['answer'] 84 | 85 | else: 86 | question = anno['question'] 
87 | answer = anno['answer'] 88 | if 'question_id' in anno: 89 | question_id = anno['question_id'] 90 | 91 | 92 | if self.vision_mapper: 93 | if self.vision_mapper.vision_format == 'video_feats': 94 | vision_feats = self.vision_mapper.read(id_) 95 | 96 | else: 97 | vision_pixels = self.vision_mapper.read(id_) 98 | if vision_pixels is None: ###wrong img/video, resample when training and raise error when testing 99 | if self.training: 100 | resample_idx = random.choice(self.idx) 101 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong image/video, use {resample_idx} instead.') 102 | return self.__getitem__(resample_idx) 103 | else: 104 | resample_idx = random.choice(self.idx) 105 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong image/video,!!!!!!!!!!!!!!!!!!!!!!!! use {resample_idx} instead.') 106 | return self.__getitem__(resample_idx) 107 | # raise ValueError 108 | 109 | if self.audio_mapper: 110 | audio_spectrograms = self.audio_mapper.read(id_) 111 | if audio_spectrograms is None: ### wrong audio, resample when training and raise error when testing 112 | if self.training: 113 | resample_idx = random.choice(self.idx) 114 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong audio, use {resample_idx} instead.') 115 | return self.__getitem__(resample_idx) 116 | else: 117 | raise ValueError 118 | 119 | return id_, raw_captions, vision_pixels, id_txt, question, answer, question_id, \ 120 | audio_spectrograms, raw_subtitles 121 | 122 | 123 | 124 | def annoindexedcollate(inputs): 125 | 126 | batch = {} 127 | all_data = map(list, unzip(inputs)) 128 | keys = ['ids', 129 | 'raw_captions', 130 | 'vision_pixels', 131 | 'ids_txt', 132 | 'raw_questions', 133 | 'raw_answers', 134 | 'question_ids', 135 | 'audio_spectrograms', 136 | 'raw_subtitles'] 137 | 138 | for key, data in zip(keys, all_data): 139 | 140 | if data[0] is None: 141 | continue 142 | elif isinstance(data[0], torch.Tensor): 143 | batch[key] = torch.stack(data, dim=0).float() 144 | 145 | else: 146 | batch[key] = data 147 | 148 | 149 | 150 | return batch 151 | 152 | 153 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .IndexAnno import AnnoIndexedDataset 3 | from .IndexSrc import SrcIndexedDataset 4 | 5 | data_registry={ 6 | 'annoindexed':AnnoIndexedDataset, 7 | 'srcindexed':SrcIndexedDataset, 8 | 9 | } 10 | -------------------------------------------------------------------------------- /data/audio_mapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torchaudio 5 | from utils.logger import LOGGER 6 | from utils.tool import split 7 | 8 | 9 | class AudioMapper(object): 10 | # def __init__(self, audio_dir, opts, sample_num, check_exists=True): 11 | def __init__(self, d_cfg, args): 12 | self.audio_dir = d_cfg.audio 13 | self.melbins = args.model_cfg.audio_melbins 14 | self.target_length = args.model_cfg.audio_target_length 15 | self.training = d_cfg.training 16 | self.frame_shift = 10 17 | self.sample_num = d_cfg.audio_sample_num 18 | self.audio_encoder_type = args.model_cfg.audio_encoder_type 19 | if self.audio_encoder_type == 'ast': 20 | self.mean = -4.2677393 21 | self.std = 4.5689974 22 | elif self.audio_encoder_type == 'beats': 23 | self.mean = 15.41663 24 | self.std = 6.55582 25 | else: 26 | raise NotImplementedError 27 | 28 | 29 | 30 
| def read(self, id_): 31 | 32 | wav_file = os.path.join(self.audio_dir, id_) 33 | 34 | if not os.path.exists(wav_file): 35 | wav_file = os.path.join(self.audio_dir, id_+'.wav') 36 | if not os.path.exists(wav_file): 37 | wav_file = wav_file.replace('wav','mp3') 38 | if not os.path.exists(wav_file): 39 | wav_file = wav_file.replace('mp3','mkv') 40 | if not os.path.exists(wav_file): 41 | print('not have audios', id_) 42 | return torch.zeros(self.sample_num, self.target_length, self.melbins) 43 | try: 44 | if self.audio_encoder_type == 'ast': 45 | 46 | waveform, sr = torchaudio.load(wav_file) 47 | 48 | waveform = waveform - waveform.mean() 49 | fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False, 50 | window_type='hanning', num_mel_bins=self.melbins, dither=0.0, frame_shift=self.frame_shift) 51 | 52 | 53 | 54 | elif self.audio_encoder_type == 'beats': 55 | 56 | waveform, sr = torchaudio.load(wav_file) 57 | if sr != 16000: 58 | trans = torchaudio.transforms.Resample(sr, 16000) 59 | waveform = trans(waveform) 60 | 61 | waveform = waveform * 2 ** 15 62 | fbank = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=self.melbins, sample_frequency=16000, frame_length=25, frame_shift=10) 63 | 64 | else: 65 | raise NotImplementedError 66 | 67 | # ### normalization 68 | fbank = (fbank - self.mean) / (self.std * 2) 69 | src_length = fbank.shape[0] 70 | # #### sample 71 | output_slices = [] 72 | pad_len = max(self.target_length * self.sample_num -src_length, self.target_length - src_length%self.target_length) 73 | fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank) 74 | total_slice_num = fbank.shape[0] // self.target_length 75 | total_slice_num = list(range(total_slice_num)) 76 | total_slice_num = split(total_slice_num, self.sample_num) 77 | 78 | if self.training: 79 | sample_idx = [random.choice(i) for i in total_slice_num] 80 | else: 81 | sample_idx = [i[(len(i)+1)//2-1] for i in total_slice_num] 82 | 83 | 84 | for i in sample_idx: 85 | cur_bank = fbank[i*self.target_length : (i+1)*self.target_length] 86 | output_slices.append(cur_bank) 87 | 88 | fbank = torch.stack(output_slices,dim=0) ### n, 1024, 128 89 | return fbank 90 | 91 | except Exception as e: 92 | print(e) 93 | return 94 | 95 | -------------------------------------------------------------------------------- /data/loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from utils.distributed import any_broadcast 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | 8 | class MetaLoader(object): 9 | """ wraps multiple data loaders """ 10 | def __init__(self, loaders, accum_steps=1, distributed=False): 11 | assert isinstance(loaders, dict) 12 | self.name2loader = {} 13 | self.name2iter = {} 14 | self.sampling_pools = [] 15 | self.name2labelname={} 16 | for idx, (n, l) in enumerate(loaders.items()): 17 | if isinstance(l, tuple): 18 | l, r = l 19 | elif isinstance(l, DataLoader): 20 | r = 1 21 | else: 22 | raise ValueError() 23 | self.name2loader[n] = l 24 | self.name2iter[n] = iter(l) 25 | self.sampling_pools.extend([n]*r) 26 | # import ipdb 27 | # ipdb.set_trace() 28 | 29 | 30 | self.accum_steps = accum_steps 31 | self.distributed = distributed 32 | self.step = 0 33 | self.epoch = 0 34 | 35 | 36 | def __iter__(self): 37 | """ this iterator will run indefinitely """ 38 | task = self.sampling_pools[0] 39 | while True: 40 | if self.step % 
self.accum_steps == 0: 41 | task = random.choice(self.sampling_pools) 42 | if self.distributed: 43 | # make sure all process is training same task 44 | task = any_broadcast(task, 0) 45 | self.step += 1 46 | iter_ = self.name2iter[task] 47 | try: 48 | batch = next(iter_) 49 | except StopIteration: 50 | self.epoch = self.epoch + 1 51 | if isinstance(self.name2loader[task].sampler, DistributedSampler): 52 | self.name2loader[task].sampler.set_epoch(self.epoch) 53 | else: 54 | pass 55 | iter_ = iter(self.name2loader[task]) 56 | batch = next(iter_) 57 | self.name2iter[task] = iter_ 58 | 59 | 60 | yield task, batch 61 | 62 | 63 | def move_to_cuda(batch): 64 | if isinstance(batch, torch.Tensor): 65 | return batch.cuda(non_blocking=True) 66 | elif isinstance(batch, list): 67 | new_batch = [move_to_cuda(t) for t in batch] 68 | elif isinstance(batch, tuple): 69 | new_batch = tuple(move_to_cuda(t) for t in batch) 70 | elif isinstance(batch, dict): 71 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()} 72 | else: 73 | return batch 74 | return new_batch 75 | 76 | 77 | def record_cuda_stream(batch): 78 | if isinstance(batch, torch.Tensor): 79 | batch.record_stream(torch.cuda.current_stream()) 80 | elif isinstance(batch, list) or isinstance(batch, tuple): 81 | for t in batch: 82 | record_cuda_stream(t) 83 | elif isinstance(batch, dict): 84 | for t in batch.values(): 85 | record_cuda_stream(t) 86 | else: 87 | pass 88 | 89 | 90 | class PrefetchLoader(object): 91 | """ 92 | overlap compute and cuda data transfer 93 | (copied and then modified from nvidia apex) 94 | """ 95 | def __init__(self, loader): 96 | self.loader = loader 97 | self.stream = torch.cuda.Stream() 98 | 99 | def __iter__(self): 100 | loader_it = iter(self.loader) 101 | self.preload(loader_it) 102 | batch = self.next(loader_it) 103 | while batch is not None: 104 | yield batch 105 | batch = self.next(loader_it) 106 | 107 | def __len__(self): 108 | return len(self.loader) 109 | 110 | def preload(self, it): 111 | try: 112 | self.batch = next(it) 113 | except StopIteration: 114 | self.batch = None 115 | return 116 | # if record_stream() doesn't work, another option is to make sure 117 | # device inputs are created on the main stream. 118 | # self.next_input_gpu = torch.empty_like(self.next_input, 119 | # device='cuda') 120 | # self.next_target_gpu = torch.empty_like(self.next_target, 121 | # device='cuda') 122 | # Need to make sure the memory allocated for next_* is not still in use 123 | # by the main stream at the time we start copying to next_*: 124 | # self.stream.wait_stream(torch.cuda.current_stream()) 125 | with torch.cuda.stream(self.stream): 126 | self.batch = move_to_cuda(self.batch) 127 | # more code for the alternative if record_stream() doesn't work: 128 | # copy_ will record the use of the pinned source tensor in this 129 | # side stream. 
130 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 131 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 132 | # self.next_input = self.next_input_gpu 133 | # self.next_target = self.next_target_gpu 134 | 135 | def next(self, it): 136 | 137 | torch.cuda.current_stream().wait_stream(self.stream) 138 | batch = self.batch 139 | if batch is not None: 140 | record_cuda_stream(batch) 141 | self.preload(it) 142 | return batch 143 | 144 | 145 | 146 | def __getattr__(self, name): 147 | method = self.loader.__getattribute__(name) 148 | return method 149 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .evaluation_mm import evaluate_mm 3 | 4 | evaluation_registry={ 5 | 'evaluation_mm':evaluate_mm 6 | } 7 | -------------------------------------------------------------------------------- /evaluation_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/__init__.py -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/cocoEvalAllSPICEDemo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "from pycocotools.coco import COCO\n", 11 | "from pycocoevalcap.eval import COCOEvalCap\n", 12 | "from pycocoevalcap.eval_spice import COCOEvalCapSpice\n", 13 | "\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import skimage.io as io\n", 16 | "import pylab\n", 17 | "pylab.rcParams['figure.figsize'] = (10.0, 8.0)\n", 18 | "\n", 19 | "import json\n", 20 | "from json import encoder\n", 21 | "encoder.FLOAT_REPR = lambda o: format(o, '.3f')" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": "Found Stanford CoreNLP.\nDownloading...\nsed: illegal option -- r\nusage: sed script [-Ealn] [-i extension] [file ...]\n sed [-Ealn] [-i extension] [-e script] ... [-f script_file] ... [file ...]\n--2020-02-23 15:20:46-- https://docs.google.com/uc?export=download&id=0B7XkCwpI5KDYNlNUTTlSS21pQmM\nResolving docs.google.com... 172.217.0.14\nConnecting to docs.google.com|172.217.0.14|:443...connected.\nHTTP request sent, awaiting response...200 OK\nLength: unspecified [text/html]\nSaving to: 'STDOUT'\n\n- [ <=> ] 0 --.-KB/s in 0s \n\n\nCannot write to '-' (Success).\nCode:\n--2020-02-23 15:20:48-- https://docs.google.com/uc?export=download&confirm=&id=0B7XkCwpI5KDYNlNUTTlSS21pQmM\nResolving docs.google.com... 
172.217.0.14\nConnecting to docs.google.com|172.217.0.14|:443...connected.\nHTTP request sent, awaiting response...200 OK\nLength: unspecified [text/html]\nSaving to: 'pycocoevalcap/wmd/data/GoogleNews-vectors-negative300.bin.gz'\n\npycocoevalcap/wmd/d [ <=> ] 3.19K --.-KB/s in 0.003s \n\n2020-02-23 15:20:49 (1.18 MB/s) - 'pycocoevalcap/wmd/data/GoogleNews-vectors-negative300.bin.gz' saved [3268]\n\nUnzipping...\ngzip: pycocoevalcap/wmd/data/GoogleNews-vectors-negative300.bin.gz: not in gzip format\ngzip: pycocoevalcap/wmd/data/ is a directory\nDone.\n" 33 | } 34 | ], 35 | "source": [ 36 | "# set up file names and pathes\n", 37 | "dataDir='.'\n", 38 | "dataType='val2014'\n", 39 | "algName = 'fakecap'\n", 40 | "annFile='%s/annotations/captions_%s.json'%(dataDir,dataType)\n", 41 | "subtypes=['results', 'evalImgs', 'eval']\n", 42 | "[resFile, evalImgsFile, evalFile]= \\\n", 43 | "['%s/results/captions_%s_%s_%s.json'%(dataDir,dataType,algName,subtype) for subtype in subtypes]\n", 44 | "\n", 45 | "# download Stanford models\n", 46 | "! bash get_stanford_models.sh\n", 47 | "\n", 48 | "# download Google word2vec model\n", 49 | "! bash get_google_word2vec_model.sh" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": "[{'image_id': 404464,\n 'caption': 'black and white photo of a man standing in front of a building'},\n {'image_id': 404464,\n 'caption': 'group of people are on the side of a snowy field'},\n {'image_id': 565778, 'caption': 'train traveling down a train station'},\n {'image_id': 565778,\n 'caption': 'red fire hydrant sitting on a park bench in front of a road'},\n {'image_id': 322226,\n 'caption': 'black and white cat is sitting on top of a wooden bench'},\n {'image_id': 322226, 'caption': 'baseball player swinging a bat at a game'},\n {'image_id': 351053, 'caption': 'laptop computer sitting on top of a table'},\n {'image_id': 351053,\n 'caption': 'zebra standing on top of a lush green field'},\n {'image_id': 40102,\n 'caption': 'group of giraffes standing next to each other in a grassy field'},\n {'image_id': 40102,\n 'caption': 'close up of a pile of oranges sitting on a table'}]" 60 | }, 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "import tempfile\n", 68 | "preds = json.load(open(resFile, 'r'))\n", 69 | "# Create fake predictions\n", 70 | "for i in range(1, len(preds), 2):\n", 71 | " preds[i]['image_id'] = preds[i-1]['image_id']\n", 72 | "tmp_resFile = tempfile.NamedTemporaryFile('w+')\n", 73 | "tmp_resFile.write(json.dumps(preds))\n", 74 | "preds[:10]\n" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": "loading annotations into memory...\n0:00:00.366050\ncreating index...\nindex created!\nLoading and preparing results... 
\nDONE (t=0.01s)\ncreating index...\nindex created!\ntokenization...\nsetting up scorers...\ncomputing SPICE score...\nSPICE: 0.121\n" 86 | } 87 | ], 88 | "source": [ 89 | "# Eval AllSPICE\n", 90 | "coco = COCO(annFile)\n", 91 | "cocoRes_n = coco.loadRes(tmp_resFile.name)\n", 92 | "cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)\n", 93 | "cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()\n", 94 | "cocoEvalAllSPICE.evaluate()\n", 95 | "tmp_resFile.close()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": "AllSPICE: 0.121\n" 107 | } 108 | ], 109 | "source": [ 110 | "# print output evaluation scores\n", 111 | "for metric, score in cocoEvalAllSPICE.eval.items():\n", 112 | " print('%s: %.3f'%('All'+metric, score))" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.7.4-final" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 1 137 | } -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/get_google_word2vec_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This script downloads the Google News word2vec negative 300 model. 3 | # Script code was taken from https://gist.github.com/yanaiela/cfef50380de8a5bfc8c272bb0c91d6e1 4 | 5 | WMDDATA=pycocoevalcap/wmd/data 6 | MODEL=GoogleNews-vectors-negative300 7 | 8 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 9 | cd $DIR 10 | 11 | if [ -f $WMDDATA/$MODEL.bin ]; then 12 | echo "Found Google news word2vec model." 13 | else 14 | echo "Downloading..." 15 | OUTPUT=$( wget --save-cookies $WMDDATA/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=0B7XkCwpI5KDYNlNUTTlSS21pQmM' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/Code: \1\n/p' ) 16 | CODE=${OUTPUT##*Code: } 17 | echo Code: $CODE 18 | wget --load-cookies $WMDDATA/cookies.txt 'https://docs.google.com/uc?export=download&confirm='$CODE'&id=0B7XkCwpI5KDYNlNUTTlSS21pQmM' -O $WMDDATA/$MODEL.bin.gz 19 | rm $WMDDATA/cookies.txt 20 | echo "Unzipping..." 21 | gzip -d $WMDDATA/$MODEL.bin.gz $WMDDATA/ 22 | echo "Done." 23 | fi -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/get_stanford_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This script downloads the Stanford CoreNLP models. 3 | 4 | CORENLP=stanford-corenlp-full-2015-12-09 5 | SPICELIB=pycocoevalcap/spice/lib 6 | JAR=stanford-corenlp-3.6.0 7 | 8 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 9 | cd $DIR 10 | 11 | if [ -f $SPICELIB/$JAR.jar ]; then 12 | echo "Found Stanford CoreNLP." 13 | else 14 | echo "Downloading..." 15 | wget http://nlp.stanford.edu/software/$CORENLP.zip 16 | echo "Unzipping..." 
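  # Only the CoreNLP jar and its companion models jar are kept: they are moved
  # into $SPICELIB below, and the downloaded zip plus the rest of the extracted
  # distribution are then removed.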
17 | unzip $CORENLP.zip -d $SPICELIB/ 18 | mv $SPICELIB/$CORENLP/$JAR.jar $SPICELIB/ 19 | mv $SPICELIB/$CORENLP/$JAR-models.jar $SPICELIB/ 20 | rm -f $CORENLP.zip 21 | rm -rf $SPICELIB/$CORENLP/ 22 | echo "Done." 23 | fi 24 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | from .bleu_scorer import BleuScorer 15 | 16 | 17 | class Bleu: 18 | def __init__(self, n=4): 19 | # default compute Blue score up to 4 20 | self._n = n 21 | self._hypo_for_image = {} 22 | self.ref_for_image = {} 23 | 24 | def compute_score(self, gts, res): 25 | 26 | assert(list(gts.keys()) == list(res.keys())) 27 | imgIds = list(gts.keys()) 28 | 29 | bleu_scorer = BleuScorer(n=self._n) 30 | for id in imgIds: 31 | hypo = res[id] 32 | ref = gts[id] 33 | 34 | # Sanity check. 
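            # res[id] must hold exactly one candidate sentence and gts[id] at least
            # one reference sentence; the asserts below enforce that contract.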
35 | assert(type(hypo) is list) 36 | assert(len(hypo) == 1) 37 | assert(type(ref) is list) 38 | assert(len(ref) >= 1) 39 | 40 | bleu_scorer += (hypo[0], ref) 41 | 42 | #score, scores = bleu_scorer.compute_score(option='shortest') 43 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 44 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 45 | 46 | # return (bleu, bleu_info) 47 | return score, scores 48 | 49 | def method(self): 50 | return "Bleu" 51 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | from .cider_scorer import CiderScorer 14 | import pdb 15 | 16 | class Cider: 17 | """ 18 | Main Class to compute the CIDEr metric 19 | 20 | """ 21 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 22 | # set cider to sum over 1 to 4-grams 23 | self._n = n 24 | # set the standard deviation parameter for gaussian penalty 25 | self._sigma = sigma 26 | 27 | def compute_score(self, gts, res): 28 | """ 29 | Main function to compute CIDEr score 30 | :param hypo_for_image (dict) : dictionary with key and value 31 | ref_for_image (dict) : dictionary with key and value 32 | :return: cider (float) : computed CIDEr score for the corpus 33 | """ 34 | 35 | assert(list(gts.keys()) == list(res.keys())) 36 | imgIds = list(gts.keys()) 37 | 38 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 39 | 40 | for id in imgIds: 41 | hypo = res[id] 42 | ref = gts[id] 43 | 44 | # Sanity check. 
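            # Same input contract as the BLEU scorer: a single candidate per id in
            # res and a non-empty reference list per id in gts.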
45 | assert(type(hypo) is list) 46 | assert(len(hypo) == 1) 47 | assert(type(ref) is list) 48 | assert(len(ref) > 0) 49 | 50 | cider_scorer += (hypo[0], ref) 51 | 52 | (score, scores) = cider_scorer.compute_score() 53 | 54 | return score, scores 55 | 56 | def method(self): 57 | return "CIDEr" -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | __author__ = 'tylin' 5 | from .tokenizer.ptbtokenizer import PTBTokenizer 6 | from .bleu.bleu import Bleu 7 | from .meteor.meteor import Meteor 8 | from .rouge.rouge import Rouge 9 | from .cider.cider import Cider 10 | # from .spice.spice import Spice 11 | from .wmd.wmd import WMD 12 | import ipdb 13 | class COCOEvalCap: 14 | def __init__(self, coco, cocoRes,process=True): 15 | self.evalVideos = [] 16 | self.eval = {} 17 | self.videoToEval = {} 18 | self.coco = coco 19 | self.cocoRes = cocoRes 20 | #self.params = {'video_id': coco.getVideoIds()} 21 | self.params = {'video_id': cocoRes.getVideoIds()} 22 | self.process = process 23 | 24 | # self.Spice = Spice() 25 | 26 | def evaluate(self): 27 | videoIds = self.params['video_id'] 28 | # videoIds = self.coco.getvideoIds() 29 | gts = {} 30 | res = {} 31 | print('total test num:{}'.format(len(videoIds))) 32 | for videoId in videoIds: 33 | gts[videoId] = self.coco.videoToAnns[videoId] 34 | res[videoId] = self.cocoRes.videoToAnns[videoId] 35 | 36 | #ipdb.set_trace() 37 | # ================================================= 38 | # Set up scorers 39 | # ================================================= 40 | print('tokenization...') 41 | tokenizer = PTBTokenizer() 42 | # import ipdb 43 | # ipdb.set_trace() 44 | 45 | 46 | if self.process: 47 | gts = tokenizer.tokenize(gts) 48 | res = tokenizer.tokenize(res) 49 | else: 50 | gts = {i:[gts[i][0]['caption']] for i in gts} 51 | res = {i:[res[i][0]['caption']] for i in res} 52 | # ================================================= 53 | # Set up scorers 54 | # ================================================= 55 | print('setting up scorers...') 56 | scorers = [ 57 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 58 | (Meteor(),"METEOR"), 59 | (Rouge(), "ROUGE_L"), 60 | (Cider(), "CIDEr"), 61 | # (Spice(), "SPICE") 62 | #(self.Spice, "SPICE"), 63 | #(WMD(), "WMD"), 64 | ] 65 | 66 | 67 | # ================================================= 68 | # Compute scores 69 | # ================================================= 70 | for scorer, method in scorers: 71 | print('computing %s score...'%(scorer.method())) 72 | score, scores = scorer.compute_score(gts, res) 73 | if type(method) == list: 74 | for sc, scs, m in zip(score, scores, method): 75 | self.setEval(sc, m) 76 | self.setvideoToEvalvideos(scs, list(gts.keys()), m) 77 | #print("%s: %0.1f"%(m, sc*100)) 78 | else: 79 | self.setEval(score, method) 80 | self.setvideoToEvalvideos(scores, list(gts.keys()), method) 81 | #print("%s: %0.1f"%(method, score*100)) 82 | self.setEvalvideos() 83 | 84 | def setEval(self, score, method): 85 | self.eval[method] = score 86 | 87 | def setvideoToEvalvideos(self, scores, videoIds, method): 88 | for videoId, score in zip(videoIds, scores): 89 | if not videoId in self.videoToEval: 90 | self.videoToEval[videoId] = {} 91 | self.videoToEval[videoId]["video_id"] = videoId 92 | 
self.videoToEval[videoId][method] = score 93 | 94 | def setEvalvideos(self): 95 | self.evalvideos = [eval for videoId, eval in list(self.videoToEval.items())] 96 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/eval_spice.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | __author__ = 'tylin' 5 | from .tokenizer.ptbtokenizer import PTBTokenizer 6 | from .spice.spice import Spice 7 | 8 | class SpiceEval(): 9 | def __init__(self): 10 | self.evalImgs = [] 11 | self.eval = {} 12 | self.imgToEval = {} 13 | self.spice = Spice() 14 | self.tokenizer = PTBTokenizer() 15 | 16 | """ 17 | The input have structure 18 | {'123': [{'image_id':123, 'caption': 'xxxxx'}, {'image_id':123, 'caption': 'yyy'}], ...} 19 | """ 20 | def evaluate(self, gts, res): 21 | assert set(gts.keys()) == set(res.keys()) 22 | imgIds = list(gts.keys()) 23 | gts = self.tokenizer.tokenize(gts) 24 | res = self.tokenizer.tokenize(res) 25 | 26 | # ================================================= 27 | # Set up scorers 28 | # ================================================= 29 | 30 | # ================================================= 31 | # Compute scores 32 | # ================================================= 33 | print('computing %s score...'%(self.spice.method())) 34 | score, scores = self.spice.compute_score(gts, res) 35 | print("%s: %0.3f"%("spice", score)) 36 | self.eval['spice'] = score 37 | print(scores) 38 | for imgId, score in zip(sorted(imgIds), scores): 39 | if not imgId in self.imgToEval: 40 | self.imgToEval[imgId] = {} 41 | self.imgToEval[imgId]["image_id"] = imgId 42 | self.imgToEval[imgId]["spice"] = score 43 | return self.eval['spice'], self.imgToEval 44 | # self.evalImgs = [self.imgToEval[imgId] for imgId in sorted(self.imgToEval.keys())] 45 | 46 | 47 | class COCOEvalCapSpice: 48 | def __init__(self, coco, cocoRes): 49 | self.evalImgs = [] 50 | self.eval = {} 51 | self.imgToEval = {} 52 | self.coco = coco 53 | self.cocoRes = cocoRes 54 | self.params = {'image_id': coco.getImgIds()} 55 | 56 | self.Spice = Spice() 57 | 58 | def evaluate(self): 59 | imgIds = self.params['image_id'] 60 | # imgIds = self.coco.getImgIds() 61 | gts = {} 62 | res = {} 63 | for imgId in imgIds: 64 | gts[imgId] = self.coco.imgToAnns[imgId] 65 | res[imgId] = self.cocoRes.imgToAnns[imgId] 66 | 67 | # ================================================= 68 | # Set up scorers 69 | # ================================================= 70 | print('tokenization...') 71 | tokenizer = PTBTokenizer() 72 | gts = tokenizer.tokenize(gts) 73 | res = tokenizer.tokenize(res) 74 | 75 | # ================================================= 76 | # Set up scorers 77 | # ================================================= 78 | print('setting up scorers...') 79 | scorers = [ 80 | (self.Spice, "SPICE") 81 | ] 82 | 83 | # ================================================= 84 | # Compute scores 85 | # ================================================= 86 | for scorer, method in scorers: 87 | print('computing %s score...'%(scorer.method())) 88 | score, scores = scorer.compute_score(gts, res) 89 | if type(method) == list: 90 | for sc, scs, m in zip(score, scores, method): 91 | self.setEval(sc, m) 92 | self.setImgToEvalImgs(scs, list(gts.keys()), m) 93 | print("%s: %0.3f"%(m, sc)) 94 | else: 95 | self.setEval(score, method) 96 | 
self.setImgToEvalImgs(scores, list(gts.keys()), method) 97 | print("%s: %0.3f"%(method, score)) 98 | self.setEvalImgs() 99 | 100 | def setEval(self, score, method): 101 | self.eval[method] = score 102 | 103 | def setImgToEvalImgs(self, scores, imgIds, method): 104 | for imgId, score in zip(sorted(imgIds), scores): 105 | if not imgId in self.imgToEval: 106 | self.imgToEval[imgId] = {} 107 | self.imgToEval[imgId]["image_id"] = imgId 108 | self.imgToEval[imgId][method] = score 109 | 110 | def setEvalImgs(self): 111 | self.evalImgs = [self.imgToEval[imgId] for imgId in sorted(self.imgToEval.keys())] 112 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | import sys 11 | import subprocess 12 | import threading 13 | 14 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 
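# The Meteor wrapper below launches one long-lived METEOR process
# ('java -jar -Xmx2G meteor-1.5.jar - - -stdio -l en -norm') and streams scoring
# requests to it over stdin/stdout, so a Java runtime must be available on PATH.
# Typical use: score, per_image_scores = Meteor().compute_score(gts, res), where
# gts and res map each id to lists of tokenized caption strings (res holding
# exactly one candidate per id).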
15 | METEOR_JAR = 'meteor-1.5.jar' 16 | # print METEOR_JAR 17 | 18 | class Meteor: 19 | 20 | def __init__(self): 21 | self.env = os.environ 22 | self.env['LC_ALL'] = 'en_US.UTF_8' 23 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 24 | '-', '-', '-stdio', '-l', 'en', '-norm'] 25 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 26 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 27 | stdin=subprocess.PIPE, \ 28 | stdout=subprocess.PIPE, \ 29 | stderr=subprocess.PIPE, 30 | env=self.env, universal_newlines=True, bufsize=1) 31 | # Used to guarantee thread safety 32 | self.lock = threading.Lock() 33 | 34 | def compute_score(self, gts, res): 35 | 36 | assert(gts.keys() == res.keys()) 37 | imgIds = sorted(list(gts.keys())) 38 | scores = [] 39 | 40 | eval_line = 'EVAL' 41 | self.lock.acquire() 42 | for i in imgIds: 43 | assert(len(res[i]) == 1) 44 | stat = self._stat(res[i][0], gts[i]) 45 | eval_line += ' ||| {}'.format(stat) 46 | 47 | # Send to METEOR 48 | self.meteor_p.stdin.write(eval_line + '\n') 49 | 50 | # Collect segment scores 51 | for i in range(len(imgIds)): 52 | score = float(self.meteor_p.stdout.readline().strip()) 53 | scores.append(score) 54 | 55 | # Final score 56 | final_score = float(self.meteor_p.stdout.readline().strip()) 57 | self.lock.release() 58 | 59 | return final_score, scores 60 | 61 | def method(self): 62 | return "METEOR" 63 | 64 | def _stat(self, hypothesis_str, reference_list): 65 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 66 | hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') 67 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 68 | self.meteor_p.stdin.write(score_line+'\n') 69 | return self.meteor_p.stdout.readline().strip() 70 | 71 | def __del__(self): 72 | self.lock.acquire() 73 | self.meteor_p.stdin.close() 74 | self.meteor_p.kill() 75 | self.meteor_p.wait() 76 | self.lock.release() 77 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import numpy as np 14 | 15 | def my_lcs(string, sub): 16 | """ 17 | Calculates longest common subsequence for a pair of tokenized strings 18 | :param string : list of str : tokens from a string split using whitespace 19 | :param sub : list of str : shorter string, also split using whitespace 20 | :returns: length (list of int): length of the longest common subsequence between the two strings 21 | 22 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 23 | """ 24 | if(len(string)< len(sub)): 25 | sub, string = string, sub 26 | 27 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 28 | 29 | for j in range(1,len(sub)+1): 30 | for i in range(1,len(string)+1): 31 | if(string[i-1] == 
sub[j-1]): 32 | lengths[i][j] = lengths[i-1][j-1] + 1 33 | else: 34 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 35 | 36 | return lengths[len(string)][len(sub)] 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | def __init__(self): 44 | # vrama91: updated the value below based on discussion with Hovey 45 | self.beta = 1.2 46 | 47 | def calc_score(self, candidate, refs): 48 | """ 49 | Compute ROUGE-L score given one candidate and references for an image 50 | :param candidate: str : candidate sentence to be evaluated 51 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 52 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 53 | """ 54 | assert(len(candidate)==1) 55 | assert(len(refs)>0) 56 | prec = [] 57 | rec = [] 58 | 59 | # split into tokens 60 | token_c = candidate[0].split(" ") 61 | 62 | for reference in refs: 63 | # split into tokens 64 | token_r = reference.split(" ") 65 | # compute the longest common subsequence 66 | lcs = my_lcs(token_r, token_c) 67 | prec.append(lcs/float(len(token_c))) 68 | rec.append(lcs/float(len(token_r))) 69 | 70 | prec_max = max(prec) 71 | rec_max = max(rec) 72 | 73 | if(prec_max!=0 and rec_max !=0): 74 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 75 | else: 76 | score = 0.0 77 | return score 78 | 79 | def compute_score(self, gts, res): 80 | """ 81 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 82 | Invoked by evaluate_captions.py 83 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 84 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 85 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 86 | """ 87 | assert(list(gts.keys()) == list(res.keys())) 88 | imgIds = list(gts.keys()) 89 | 90 | score = [] 91 | for id in imgIds: 92 | hypo = res[id] 93 | ref = gts[id] 94 | 95 | score.append(self.calc_score(hypo, ref)) 96 | 97 | # Sanity check. 98 | assert(type(hypo) is list) 99 | assert(len(hypo) == 1) 100 | assert(type(ref) is list) 101 | assert(len(ref) > 0) 102 | 103 | average_score = np.mean(np.array(score)) 104 | return average_score, np.array(score) 105 | 106 | def method(self): 107 | return "Rouge" 108 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
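#                 Tokenization shells out to the bundled stanford-corenlp-3.4.1.jar
#                 via subprocess, so Java must be on PATH for caption evaluation.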
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os 15 | import sys 16 | import subprocess 17 | import tempfile 18 | import itertools 19 | 20 | # path to the stanford corenlp jar 21 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 22 | 23 | # punctuations to be removed from the sentences 24 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 25 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 26 | 27 | class PTBTokenizer: 28 | """Python wrapper of Stanford PTBTokenizer""" 29 | 30 | def tokenize(self, captions_for_image): 31 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 32 | 'edu.stanford.nlp.process.PTBTokenizer', \ 33 | '-preserveLines', '-lowerCase'] 34 | 35 | # ====================================================== 36 | # prepare data for PTB Tokenizer 37 | # ====================================================== 38 | final_tokenized_captions_for_image = {} 39 | image_id = [k for k, v in list(captions_for_image.items()) for _ in range(len(v))] 40 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in list(captions_for_image.items()) for c in v]) 41 | 42 | # ====================================================== 43 | # save sentences to temporary file 44 | # ====================================================== 45 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 46 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 47 | tmp_file.write(sentences.encode('utf-8')) 48 | tmp_file.close() 49 | 50 | # ====================================================== 51 | # tokenize sentence 52 | # ====================================================== 53 | cmd.append(os.path.basename(tmp_file.name)) 54 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 55 | stdout=subprocess.PIPE) 56 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 57 | lines = token_lines.decode("utf-8").split('\n') 58 | # remove temp file 59 | os.remove(tmp_file.name) 60 | 61 | # ====================================================== 62 | # create dictionary for tokenized captions 63 | # ====================================================== 64 | for k, line in zip(image_id, lines): 65 | if not k in final_tokenized_captions_for_image: 66 | final_tokenized_captions_for_image[k] = [] 67 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 68 | if w not in PUNCTUATIONS]) 69 | final_tokenized_captions_for_image[k].append(tokenized_caption) 70 | 71 | return final_tokenized_captions_for_image 72 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmp0h6fcu13: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmp0h6fcu13 -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmp0tjxfbx9: -------------------------------------------------------------------------------- 1 | a collage of photos of plants and flowers 2 | a person sitting in a field with a skateboard 3 | a basket full of apples sitting on top of a table 4 | a man in a yellow jacket is skiing down a hill 5 | a couple of people standing next to each other 6 | a nintendo wii controller sitting on a table 7 | a man riding a snow mobile pulling a man on a snowboard 8 | a little girl walking down a street with an umbrella -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmp6p13kv9q: -------------------------------------------------------------------------------- 1 | the woman in the red dress looks at her man at a table 2 | a skier with a black jacket and blue pants 3 | a reflection of people walking with luggage in a wet sidewalk 4 | a woman that is in a fridge with many types of drinks 5 | the bathroom is clean and ready for the guests to use 6 | some cars and traffic lights a building and a street light 7 | a bowl of dip with sliced carrots cucumbers and tomatoes 8 | a man sitting down with a slice of pizza -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmp_j4nl228: -------------------------------------------------------------------------------- 1 | two sheep standing next to each other on a lush green hillside 2 | a top sign pole with plants grown on it 3 | there is a bed with a metal frame and no mattress 4 | roses in full bloom sit in a vase 5 | a clock in a statue showing the time 6 | two zebras and a giraffe grazing on grass and trees 7 | an apple has a knife stuck in it white it sits on a shiny surface 8 | a large black refrigerator is in a kitchen -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmp_t697skg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmp_t697skg -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpagmqx2xa: -------------------------------------------------------------------------------- 1 | ambulance in route to hospital on a busy london street 2 | a girl holding a doughnut close to her face 3 | a man holds up his middle finger to a parking meter 4 | a woman looks on as an older woman reads a magazine 5 | people stand near some canadian flags at the base of a mountain 6 | two young children eat on a porch bench 7 | a boy stands in front of a skateboard in a yard covered in snow 8 | two sheep standing in brambles on a grassy hillside -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpcgw5utq8: -------------------------------------------------------------------------------- 1 | a man is standing next to a bicycle on a street 2 | two men in yellow jackets are eating at a table 
3 | a large air force plane sitting on top of an airport tarmac 4 | a woman is playing tennis on a tennis court 5 | a room with a lot of bookshelves and a chair 6 | two young girls standing next to a fence with a giraffe behind them 7 | a baseball player is sliding into home plate 8 | a clock on a brick building with windows -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpfpxy5t7t: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpfpxy5t7t -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpfqd7mk6p: -------------------------------------------------------------------------------- 1 | an airplane that has a large black propeller 2 | a wall that has a bunch of bunches of bananas on it 3 | a man and a women eating and drinking a drink 4 | a person is lighting a candle on a decorated cake 5 | a skateboarder is jumping over a pile of dirt 6 | a living room filled with furniture and a fireplace 7 | some cute stuffed animals sitting on a bed 8 | a batter in a baseball game swinging a bat -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpi4w8y1s9: -------------------------------------------------------------------------------- 1 | a large tray holding an assortment of decorated doughnuts 2 | there are fur apples that are in the bowl on the table 3 | an electric train traveling next to its station 4 | a zebra stands and eats in a grassy area 5 | baby in crib looking to large bear stuffed animal 6 | an old picture shows a bride and groom 7 | zebras graze in a field but one remains alert for predators 8 | a subway car filled with people and plants -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpjve_qo10: -------------------------------------------------------------------------------- 1 | a white toilet bowl with water spray from the faucets 2 | a man and a woman that are using wii game controllers 3 | two people laying in bed with food in their hands 4 | three urinals are on one wall in a restroom 5 | a large bouquet of colorful flowers sits in a plastic container 6 | the plane is flying over the water on a cloudy day 7 | a boat and boat next to a building with a pier in it 8 | a horse drawn carriage parked in front of a white building -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpk_do254y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpk_do254y -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpkwuh2_su: -------------------------------------------------------------------------------- 1 | a toilet that is covered with a shower curtain 2 | the cake is cut up and ready for a slice 3 | a man talking on a cell phone in an outside setting 4 | two buses parked next to each other with two buses nearby 5 | the 
elephant is standing alone in the zoo 6 | two young girl making a green beans dish with utensils in a kitchen 7 | a close up of a branch of a plant with a bug on the leaves 8 | a player is getting ready to hit the ball on the court -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpqf8v96v2: -------------------------------------------------------------------------------- 1 | kitchen counter cluttered with cake and ice cream buckets 2 | a bathroom with a counter and sink inside of it 3 | a very fake attempt at showing a woman committing suicide in a bath 4 | a hand holding three oranges in front of a white wall 5 | a mustachioed cowboy riding a horse in a rodeo 6 | a large church with a towering clock tower above it 7 | a man being led by two white cows 8 | two men surfboarding along a large ocean wave -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpqpusdi9z: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpqpusdi9z -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpsohw7poj: -------------------------------------------------------------------------------- 1 | a bride and groom are cutting a birthday cake 2 | a large white bus inside of a building 3 | a close up of a young child sitting on a luggage bag 4 | a snow skier is on the snow with red pants 5 | a side view of a kitchen with french doors 6 | a tractor trailer with safeway written on the side 7 | the man on the motorcycle is traveling down the road 8 | pit bull playing with soccer ball in the grass -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpsuczi0e7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpsuczi0e7 -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpswwdg3wt: -------------------------------------------------------------------------------- 1 | a man is sitting down next to a pizza on a pan 2 | a street in a city with tall buildings 3 | a number of zebras stand near one another in a field 4 | the large airplane is flying high over a mountain range 5 | a white toilet and sink that are in a bathroom 6 | a man is posing for a picture in the room holding a drink in his hand 7 | two people in a canoe with vegetables on top of a river 8 | a man holding two cell phones in his hand -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpv72jh5ig: -------------------------------------------------------------------------------- 1 | a person in a green jacket doing a snowboarding trick on a rail 2 | two women are riding in a horse drawn carriage 3 | a pizza on some rack on the kitchen counter 4 | a woman and a dog with a dog sticking its tongue 5 | a photo of a water stream in the sky 6 | a very big red bus on the street near trees 7 | a white bench is 
near a large rock wall 8 | a group of people riding skis down a snow covered slope -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpvr9bp8m8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpvr9bp8m8 -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpw7hxvxkw: -------------------------------------------------------------------------------- 1 | a meal is served with a variety of vegetables 2 | a person riding a surf board on a wave 3 | the woman is standing on the grass talking on a cell phone 4 | a white and blue truck driving down a highway 5 | a sail boat that is on a body of water 6 | clothes and a ladder are on display in a room 7 | a young woman is sitting at a restaurant table in front of a large pizza 8 | a vintage picture with a man holding a beer bottle and hat -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpz5ekywsl: -------------------------------------------------------------------------------- 1 | a rack with large number of different colored ties hanging from it 2 | an passenger train travels down the track under power lines 3 | a person that is in a room with a table 4 | a baseball game is in progress with the stands filled 5 | a cellphone is sitting open on top of a workbook 6 | a man pitching a ball towards the batter 7 | a dog wearing a black bandanna stands near a white counter and a window with plants in it 8 | a big brown teddy bear sitting in a toy store -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpzkac7gs7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/tokenizer/tmpzkac7gs7 -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/wmd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/evaluation_tools/caption_tools/pycocoevalcap/wmd/__init__.py -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/wmd/data/stopwords.txt: -------------------------------------------------------------------------------- 1 | i me my myself we our ours ourselves you your yours yourself yourselves he him his himself she her hers herself it its itself they them their theirs themselves what which who whom this that these those am is are was were be been being have has had having do does did doing a an the and but if or because as until while of at by for with about against between into through during before after above below to from up down in out on off over under again further then once here there when where why how all any both each few more most other some such no nor not only own same so than too very s t can will just don should now d ll m o re ve y ain aren couldn didn doesn hadn hasn haven isn ma mightn 
mustn needn shan shouldn wasn weren won wouldn -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocoevalcap/wmd/wmd.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Special thanks to Mert Kilickaya, first author of 'Re-evaluating Automatic Metrics for Image Captioning' [http://aclweb.org/anthology/E17-1019] for giving exact instructions on how to implement the Word Mover's Distance metric here. 3 | ''' 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import numpy as np 9 | # import gensim 10 | import os 11 | 12 | class WMD: 13 | 14 | def __init__(self): 15 | with open(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data', 'stopwords.txt'), 'rb') as f: 16 | self.stop_words = set(f.read().decode('utf-8').strip().split(' ')) #Stop words were taken from NLTK nltk.stopwords.words('english') 17 | self.model = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data', 'GoogleNews-vectors-negative300.bin'), binary=True) 18 | self.sigma = 1.0 19 | 20 | def calc_score(self, candidate, refs): 21 | scores = list() 22 | c_tokens = [ token for token in candidate[0].split(' ') if token not in self.stop_words ] 23 | for ref in refs: 24 | r_tokens = [ token for token in ref.split(' ') if token not in self.stop_words ] 25 | dist = self.model.wmdistance(c_tokens, r_tokens) 26 | score = np.exp(-dist/self.sigma) 27 | scores.append(score) 28 | return max(scores) 29 | 30 | def compute_score(self, gts, res): 31 | assert(sorted(gts.keys()) == sorted(res.keys())) 32 | imgIds = sorted(gts.keys()) 33 | 34 | score = [] 35 | for id in imgIds: 36 | hypo = res[id] 37 | ref = gts[id] 38 | 39 | score.append(self.calc_score(hypo, ref)) 40 | 41 | # Sanity check. 42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) >= 1) 46 | 47 | average_score = np.mean(np.array(score)) 48 | return average_score, np.array(score) 49 | 50 | def method(self): 51 | return "WMD" 52 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /evaluation_tools/caption_tools/pycocotools/coco.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | __author__ = 'tylin' 5 | __version__ = '1.0.1' 6 | # Interface for accessing the Microsoft COCO dataset. 7 | 8 | # Microsoft COCO is a large image dataset designed for object detection, 9 | # segmentation, and caption generation. pycocotools is a Python API that 10 | # assists in loading, parsing and visualizing the annotations in COCO. 11 | # Please visit http://mscoco.org/ for more information on COCO, including 12 | # for the data, paper, and tutorials. The exact format of the annotations 13 | # is also described on the COCO website. For example usage of the pycocotools 14 | # please see pycocotools_demo.ipynb. In addition to this API, please download both 15 | # the COCO images and annotations in order to run the demo. 
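# Note: this vendored copy has been trimmed and adapted for video captioning:
# annotations are indexed by 'video_id' (see createIndex/getVideoIds below), and
# only the caption-related helpers used by the evaluation code are kept.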
16 | 17 | # An alternative to using the API is to load the annotations directly 18 | # into Python dictionary 19 | # Using the API provides additional utility functions. Note that this API 20 | # supports both *instance* and *caption* annotations. In the case of 21 | # captions not all functions are defined (e.g. categories are undefined). 22 | 23 | # The following API functions are defined: 24 | # COCO - COCO api class that loads COCO annotation file and prepare data structures. 25 | # decodeMask - Decode binary mask M encoded via run-length encoding. 26 | # encodeMask - Encode binary mask M using run-length encoding. 27 | # getAnnIds - Get ann ids that satisfy given filter conditions. 28 | # getCatIds - Get cat ids that satisfy given filter conditions. 29 | # getImgIds - Get img ids that satisfy given filter conditions. 30 | # loadAnns - Load anns with the specified ids. 31 | # loadCats - Load cats with the specified ids. 32 | # loadImgs - Load imgs with the specified ids. 33 | # segToMask - Convert polygon segmentation to binary mask. 34 | # showAnns - Display the specified annotations. 35 | # loadRes - Load result file and create result api object. 36 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 37 | # Help on each functions can be accessed by: "help COCO>function". 38 | 39 | # See also COCO>decodeMask, 40 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds, 41 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats, 42 | # COCO>loadImgs, COCO>segToMask, COCO>showAnns 43 | 44 | # Microsoft COCO Toolbox. Version 1.0 45 | # Data, paper, and tutorials available at: http://mscoco.org/ 46 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014. 47 | # Licensed under the Simplified BSD License [see bsd.txt] 48 | 49 | import json 50 | import datetime 51 | # import matplotlib.pyplot as plt 52 | # from matplotlib.collections import PatchCollection 53 | # from matplotlib.patches import Polygon 54 | import numpy as np 55 | # from skimage.draw import polygon 56 | import copy 57 | import ipdb 58 | 59 | class COCO: 60 | def __init__(self, annotation_file=None): 61 | """ 62 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 63 | :param annotation_file (str): location of annotation file 64 | :param image_folder (str): location to the folder that hosts images. 
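        (this adapted constructor only accepts annotation_file; image_folder is not used)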
65 | :return: 66 | """ 67 | # load dataset 68 | self.dataset = {} 69 | self.anns = [] 70 | self.videoToAnns = {} 71 | self.videos = [] 72 | if not annotation_file == None: 73 | print('loading annotations into memory...') 74 | time_t = datetime.datetime.utcnow() 75 | dataset = json.load(open(annotation_file, 'r')) 76 | if 'type' not in dataset: 77 | dataset['type']='caption' 78 | print(datetime.datetime.utcnow() - time_t) 79 | self.dataset = dataset 80 | self.createIndex() 81 | 82 | def createIndex(self): 83 | # create index 84 | print('creating index...') 85 | videoToAnns = {ann['video_id']: [] for ann in self.dataset['annotations']} 86 | #ipdb.set_trace() 87 | #anns = {ann['sen_id']: [] for ann in self.dataset['annotations']} 88 | for ann in self.dataset['annotations']: 89 | videoToAnns[ann['video_id']] += [ann] 90 | #anns[ann['sen_id']] = ann 91 | 92 | # videos = {vi['video_id']: {} for vi in self.dataset['videos']} 93 | # for vi in self.dataset['videos']: 94 | # videos[vi['video_id']] = vi 95 | 96 | videos = {vi['video_id']: {} for vi in self.dataset['annotations']} 97 | 98 | print('index created!') 99 | 100 | # create class members 101 | #self.anns = anns 102 | self.videoToAnns = videoToAnns 103 | self.videos = videos 104 | 105 | 106 | def info(self): 107 | """ 108 | Print information about the annotation file. 109 | :return: 110 | """ 111 | for key, value in list(self.datset['info'].items()): 112 | print('%s: %s'%(key, value)) 113 | 114 | 115 | def getVideoIds(self): 116 | return list(self.videos.keys()) 117 | 118 | 119 | def loadRes(self, resFile): 120 | """ 121 | Load result file and return a result api object. 122 | :param resFile (str) : file name of result file 123 | :return: res (obj) : result api object 124 | """ 125 | res = COCO() 126 | #res.dataset['videos'] = [vi for vi in self.dataset['videos']] 127 | #res.dataset['info'] = copy.deepcopy(self.dataset['info']) 128 | #res.dataset['type'] = copy.deepcopy(self.dataset['type']) 129 | 130 | 131 | print('Loading and preparing results... ') 132 | time_t = datetime.datetime.utcnow() 133 | #anns = json.load(open(resFile)) 134 | anns = resFile 135 | assert type(anns) == list, 'results in not an array of objects' 136 | annsVideoIds = [ann['video_id'] for ann in anns] 137 | # print(len(set(annsVideoIds))) 138 | # print(set(annsVideoIds)) 139 | # print(set(self.getVideoIds())) 140 | # print(len(set(self.getVideoIds()))) 141 | # import ipdb 142 | # ipdb.set_trace() 143 | assert set(annsVideoIds) == (set(annsVideoIds) & set(self.getVideoIds())), \ 144 | 'Results do not correspond to current coco set' 145 | 146 | #videoIds = set([vi['video_id'] for vi in res.dataset['videos']]) & set([ann['video_id'] for ann in anns]) 147 | #res.dataset['videos'] = [vi for vi in res.dataset['videos'] if vi['video_id'] in videoIds] 148 | # for id, ann in enumerate(anns): 149 | # ann['sen_id'] = id 150 | 151 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())) 152 | 153 | res.dataset['annotations'] = anns 154 | res.createIndex() 155 | return res 156 | -------------------------------------------------------------------------------- /evaluation_tools/vqa_tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | __author__ = "aagrawal" 9 | -------------------------------------------------------------------------------- /img/VAST-model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/img/VAST-model.jpg -------------------------------------------------------------------------------- /img/radar_compare_alldata_vast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/img/radar_compare_alldata_vast.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .vast import VAST 2 | model_registry = { 3 | 'vast':VAST 4 | } -------------------------------------------------------------------------------- /model/vision_encoders/clip/clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | import ftfy 8 | import regex as re 9 | import torch 10 | 11 | 12 | @lru_cache() 13 | def default_bpe(): 14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 27 | """ 28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 29 | cs = bs[:] 30 | n = 0 31 | for b in range(2**8): 32 | if b not in bs: 33 | bs.append(b) 34 | cs.append(2**8+n) 35 | n += 1 36 | cs = [chr(n) for n in cs] 37 | return dict(zip(bs, cs)) 38 | 39 | 40 | def get_pairs(word): 41 | """Return set of symbol pairs in a word. 42 | Word is represented as tuple of symbols (symbols being variable-length strings). 
43 | """ 44 | pairs = set() 45 | prev_char = word[0] 46 | for char in word[1:]: 47 | pairs.add((prev_char, char)) 48 | prev_char = char 49 | return pairs 50 | 51 | 52 | def basic_clean(text): 53 | text = ftfy.fix_text(text) 54 | text = html.unescape(html.unescape(text)) 55 | return text.strip() 56 | 57 | 58 | def whitespace_clean(text): 59 | text = re.sub(r'\s+', ' ', text) 60 | text = text.strip() 61 | return text 62 | 63 | 64 | class SimpleTokenizer(object): 65 | def __init__(self, bpe_path: str = default_bpe()): 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 69 | merges = merges[1:49152-256-2+1] 70 | merges = [tuple(merge.split()) for merge in merges] 71 | vocab = list(bytes_to_unicode().values()) 72 | vocab = vocab + [v+'</w>' for v in vocab] 73 | for merge in merges: 74 | vocab.append(''.join(merge)) 75 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 76 | self.encoder = dict(zip(vocab, range(len(vocab)))) 77 | self.decoder = {v: k for k, v in self.encoder.items()} 78 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'</w>' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 134 | return text 135 | 136 | 137 | 138 | _tokenizer = SimpleTokenizer() 139 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 140 | """ 141 | Returns the tokenized representation of given input string(s) 142 | 143 | Parameters 144 | ---------- 145 | texts : Union[str, List[str]] 146 | An input string or a list of input strings to tokenize 147 | 148 | context_length : int 149 | The context length to use; all CLIP models use 77 as the context length 150 | 151 | truncate: bool 152 |
Whether to truncate the text in case its encoding is longer than the context length 153 | 154 | Returns 155 | ------- 156 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 157 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 158 | """ 159 | if isinstance(texts, str): 160 | texts = [texts] 161 | 162 | sot_token = _tokenizer.encoder["<|startoftext|>"] 163 | eot_token = _tokenizer.encoder["<|endoftext|>"] 164 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 165 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 166 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 167 | else: 168 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 169 | 170 | for i, tokens in enumerate(all_tokens): 171 | if len(tokens) > context_length: 172 | if truncate: 173 | tokens = tokens[:context_length] 174 | tokens[-1] = eot_token 175 | else: 176 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 177 | result[i, :len(tokens)] = torch.tensor(tokens) 178 | 179 | return result 180 | -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | # from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ 6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 7 | from .openai import load_openai_model, list_openai_models 8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ 9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 10 | from .tokenizer import SimpleTokenizer, tokenize 11 | from .transform import image_transform -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/model/vision_encoders/evaclip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": 
"layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings" 54 | }, 55 | "pooler": "mean_pooler", 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/loss.py: -------------------------------------------------------------------------------- 1 | # import math 2 | # import torch 3 | # import torch.nn as nn 4 | # from torch.nn import functional as F 5 | 6 | # try: 7 | # import torch.distributed.nn 8 | # from torch import distributed as dist 9 | # has_distributed = True 10 | # except ImportError: 11 | # has_distributed = False 12 | 13 | # try: 14 | # import horovod.torch as hvd 15 | # except ImportError: 16 | # hvd = None 17 | 18 | # from timm.loss import LabelSmoothingCrossEntropy 19 | 20 | 21 | # def gather_features( 22 | # image_features, 23 | # text_features, 24 | # local_loss=False, 25 | # gather_with_grad=False, 26 | # rank=0, 27 | # world_size=1, 28 | # use_horovod=False 29 | # ): 30 | # assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
31 | # if use_horovod: 32 | # assert hvd is not None, 'Please install horovod' 33 | # if gather_with_grad: 34 | # all_image_features = hvd.allgather(image_features) 35 | # all_text_features = hvd.allgather(text_features) 36 | # else: 37 | # with torch.no_grad(): 38 | # all_image_features = hvd.allgather(image_features) 39 | # all_text_features = hvd.allgather(text_features) 40 | # if not local_loss: 41 | # # ensure grads for local rank when all_* features don't have a gradient 42 | # gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 43 | # gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 44 | # gathered_image_features[rank] = image_features 45 | # gathered_text_features[rank] = text_features 46 | # all_image_features = torch.cat(gathered_image_features, dim=0) 47 | # all_text_features = torch.cat(gathered_text_features, dim=0) 48 | # else: 49 | # # We gather tensors from all gpus 50 | # if gather_with_grad: 51 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 52 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 53 | # # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 54 | # # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 55 | # else: 56 | # gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 57 | # gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 58 | # dist.all_gather(gathered_image_features, image_features) 59 | # dist.all_gather(gathered_text_features, text_features) 60 | # if not local_loss: 61 | # # ensure grads for local rank when all_* features don't have a gradient 62 | # gathered_image_features[rank] = image_features 63 | # gathered_text_features[rank] = text_features 64 | # all_image_features = torch.cat(gathered_image_features, dim=0) 65 | # all_text_features = torch.cat(gathered_text_features, dim=0) 66 | 67 | # return all_image_features, all_text_features 68 | 69 | 70 | # class ClipLoss(nn.Module): 71 | 72 | # def __init__( 73 | # self, 74 | # local_loss=False, 75 | # gather_with_grad=False, 76 | # cache_labels=False, 77 | # rank=0, 78 | # world_size=1, 79 | # use_horovod=False, 80 | # smoothing=0., 81 | # ): 82 | # super().__init__() 83 | # self.local_loss = local_loss 84 | # self.gather_with_grad = gather_with_grad 85 | # self.cache_labels = cache_labels 86 | # self.rank = rank 87 | # self.world_size = world_size 88 | # self.use_horovod = use_horovod 89 | # self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 90 | 91 | # # cache state 92 | # self.prev_num_logits = 0 93 | # self.labels = {} 94 | 95 | # def forward(self, image_features, text_features, logit_scale=1.): 96 | # device = image_features.device 97 | # if self.world_size > 1: 98 | # all_image_features, all_text_features = gather_features( 99 | # image_features, text_features, 100 | # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 101 | 102 | # if self.local_loss: 103 | # logits_per_image = logit_scale * image_features @ all_text_features.T 104 | # logits_per_text = logit_scale * text_features @ all_image_features.T 105 | # else: 106 | # logits_per_image = logit_scale * all_image_features @ all_text_features.T 107 | # logits_per_text = logits_per_image.T 108 | # else: 109 | # logits_per_image = logit_scale * 
image_features @ text_features.T 110 | # logits_per_text = logit_scale * text_features @ image_features.T 111 | # # calculated ground-truth and cache if enabled 112 | # num_logits = logits_per_image.shape[0] 113 | # if self.prev_num_logits != num_logits or device not in self.labels: 114 | # labels = torch.arange(num_logits, device=device, dtype=torch.long) 115 | # if self.world_size > 1 and self.local_loss: 116 | # labels = labels + num_logits * self.rank 117 | # if self.cache_labels: 118 | # self.labels[device] = labels 119 | # self.prev_num_logits = num_logits 120 | # else: 121 | # labels = self.labels[device] 122 | 123 | # if self.label_smoothing_cross_entropy: 124 | # total_loss = ( 125 | # self.label_smoothing_cross_entropy(logits_per_image, labels) + 126 | # self.label_smoothing_cross_entropy(logits_per_text, labels) 127 | # ) / 2 128 | # else: 129 | # total_loss = ( 130 | # F.cross_entropy(logits_per_image, labels) + 131 | # F.cross_entropy(logits_per_text, labels) 132 | # ) / 2 133 | 134 | # acc = None 135 | # i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 136 | # t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 137 | # acc = {"i2t": i2t_acc, "t2i": t2i_acc} 138 | # return total_loss, acc -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 
512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": 
{ 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | 
use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 
| x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | def broadcat(tensors, dim = -1): 8 | num_tensors = len(tensors) 9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 11 | shape_len = list(shape_lens)[0] 12 | dim = (dim + shape_len) if dim < 0 else dim 13 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 14 | expandable_dims = [(i, val) 
for i, val in enumerate(dims) if i != dim] 15 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' 16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 18 | expanded_dims.insert(dim, (dim, dims[dim])) 19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 21 | return torch.cat(tensors, dim = dim) 22 | 23 | def rotate_half(x): 24 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 25 | x1, x2 = x.unbind(dim = -1) 26 | x = torch.stack((-x2, x1), dim = -1) 27 | return rearrange(x, '... d r -> ... (d r)') 28 | 29 | 30 | class VisionRotaryEmbedding(nn.Module): 31 | def __init__( 32 | self, 33 | dim, 34 | pt_seq_len, 35 | ft_seq_len=None, 36 | custom_freqs = None, 37 | freqs_for = 'lang', 38 | theta = 10000, 39 | max_freq = 10, 40 | num_freqs = 1, 41 | ): 42 | super().__init__() 43 | if custom_freqs: 44 | freqs = custom_freqs 45 | elif freqs_for == 'lang': 46 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 47 | elif freqs_for == 'pixel': 48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 49 | elif freqs_for == 'constant': 50 | freqs = torch.ones(num_freqs).float() 51 | else: 52 | raise ValueError(f'unknown modality {freqs_for}') 53 | 54 | if ft_seq_len is None: ft_seq_len = pt_seq_len 55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 56 | 57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 59 | 60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 62 | 63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 64 | 65 | self.register_buffer("freqs_cos", freqs.cos()) 66 | self.register_buffer("freqs_sin", freqs.sin()) 67 | 68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 69 | 70 | def forward(self, t, start_index = 0): 71 | rot_dim = self.freqs_cos.shape[-1] 72 | end_index = start_index + rot_dim 73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 76 | 77 | return torch.cat((t_left, t, t_right), dim = -1) 78 | 79 | class VisionRotaryEmbeddingFast(nn.Module): 80 | def __init__( 81 | self, 82 | dim, 83 | pt_seq_len, 84 | ft_seq_len=None, 85 | custom_freqs = None, 86 | freqs_for = 'lang', 87 | theta = 10000, 88 | max_freq = 10, 89 | num_freqs = 1, 90 | patch_dropout = 0. 91 | ): 92 | super().__init__() 93 | if custom_freqs: 94 | freqs = custom_freqs 95 | elif freqs_for == 'lang': 96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 97 | elif freqs_for == 'pixel': 98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 99 | elif freqs_for == 'constant': 100 | freqs = torch.ones(num_freqs).float() 101 | else: 102 | raise ValueError(f'unknown modality {freqs_for}') 103 | 104 | if ft_seq_len is None: ft_seq_len = pt_seq_len 105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 106 | 107 | freqs = torch.einsum('..., f -> ... f', t, freqs) 108 | freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) 109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 110 | 111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 113 | 114 | self.patch_dropout = patch_dropout 115 | 116 | self.register_buffer("freqs_cos", freqs_cos) 117 | self.register_buffer("freqs_sin", freqs_sin) 118 | 119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 120 | 121 | def forward(self, t, patch_indices_keep=None): 122 | if patch_indices_keep is not None: 123 | batch = t.size()[0] 124 | batch_indices = torch.arange(batch) 125 | batch_indices = batch_indices[..., None] 126 | 127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 129 | 130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') 132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') 134 | 135 | return t * freqs_cos + rotate_half(t) * freqs_sin 136 | 137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | pretrained=False): 43 | super().__init__() 44 | if timm is None: 45 | raise RuntimeError("Please `pip install timm` to use timm models.") 46 | 47 | self.image_size = to_2tuple(image_size) 48 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 49 | feat_size = self.trunk.default_cfg.get('pool_size', None) 50 | feature_ndim = 1 if not feat_size else 2 51 | if pool in ('abs_attn', 'rot_attn'): 52 | assert feature_ndim == 2 53 | # if attn pooling used, remove both classifier and default pool 54 | self.trunk.reset_classifier(0, global_pool='') 55 | else: 56 | # reset global pool if pool config set, otherwise leave as network default 57 | reset_kwargs = dict(global_pool=pool) if pool else {} 58 | self.trunk.reset_classifier(0, **reset_kwargs) 59 | prev_chs = self.trunk.num_features 60 | 61 | head_layers = OrderedDict() 62 | if pool == 'abs_attn': 63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 64 | 
prev_chs = embed_dim 65 | elif pool == 'rot_attn': 66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 67 | prev_chs = embed_dim 68 | else: 69 | assert proj, 'projection layer needed if non-attention pooling is used.' 70 | 71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 72 | if proj == 'linear': 73 | head_layers['drop'] = nn.Dropout(drop) 74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 75 | elif proj == 'mlp': 76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 77 | 78 | self.head = nn.Sequential(head_layers) 79 | 80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 81 | """ lock modules 82 | Args: 83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 84 | """ 85 | if not unlocked_groups: 86 | # lock full model 87 | for param in self.trunk.parameters(): 88 | param.requires_grad = False 89 | if freeze_bn_stats: 90 | freeze_batch_norm_2d(self.trunk) 91 | else: 92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 93 | try: 94 | # FIXME import here until API stable and in an official release 95 | from timm.models.helpers import group_parameters, group_modules 96 | except ImportError: 97 | raise RuntimeError( 98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 99 | matcher = self.trunk.group_matcher() 100 | gparams = group_parameters(self.trunk, matcher) 101 | max_layer_id = max(gparams.keys()) 102 | max_layer_id = max_layer_id - unlocked_groups 103 | for group_idx in range(max_layer_id + 1): 104 | group = gparams[group_idx] 105 | for param in group: 106 | self.trunk.get_parameter(param).requires_grad = False 107 | if freeze_bn_stats: 108 | gmodules = group_modules(self.trunk, matcher, reverse=True) 109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 110 | freeze_batch_norm_2d(self.trunk, gmodules) 111 | 112 | @torch.jit.ignore 113 | def set_grad_checkpointing(self, enable=True): 114 | try: 115 | self.trunk.set_grad_checkpointing(enable) 116 | except Exception as e: 117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 118 | 119 | def forward(self, x): 120 | x = self.trunk(x) 121 | x = self.head(x) 122 | return x 123 | -------------------------------------------------------------------------------- /model/vision_encoders/evaclip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 8 | CenterCrop 9 | 10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 11 | 12 | 13 | class ResizeMaxSize(nn.Module): 14 | 15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 16 | super().__init__() 17 | if not isinstance(max_size, int): 18 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 19 | self.max_size = max_size 20 | self.interpolation = interpolation 21 | self.fn = min if fn == 'min' else min 22 | self.fill = fill 23 | 24 | def forward(self, img): 25 | if isinstance(img, torch.Tensor): 26 | height, width = img.shape[:2] 27 | else: 28 | width, height = img.size 29 | scale = self.max_size / float(max(height, width)) 30 | if scale != 1.0: 31 | new_size = tuple(round(dim * scale) for dim in (height, width)) 32 | img = F.resize(img, new_size, self.interpolation) 33 | pad_h = self.max_size - new_size[0] 34 | pad_w = self.max_size - new_size[1] 35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 36 | return img 37 | 38 | 39 | def _convert_to_rgb(image): 40 | return image.convert('RGB') 41 | 42 | 43 | # class CatGen(nn.Module): 44 | # def __init__(self, num=4): 45 | # self.num = num 46 | # def mixgen_batch(image, text): 47 | # batch_size = image.shape[0] 48 | # index = np.random.permutation(batch_size) 49 | 50 | # cat_images = [] 51 | # for i in range(batch_size): 52 | # # image mixup 53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 54 | # # text concat 55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 56 | # text = torch.stack(text) 57 | # return image, text 58 | 59 | 60 | def image_transform( 61 | image_size: int, 62 | is_train: bool, 63 | mean: Optional[Tuple[float, ...]] = None, 64 | std: Optional[Tuple[float, ...]] = None, 65 | resize_longest_max: bool = False, 66 | fill_color: int = 0, 67 | ): 68 | mean = mean or OPENAI_DATASET_MEAN 69 | if not isinstance(mean, (list, tuple)): 70 | mean = (mean,) * 3 71 | 72 | std = std or OPENAI_DATASET_STD 73 | if not isinstance(std, (list, tuple)): 74 | std = (std,) * 3 75 | 76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 78 | image_size = image_size[0] 79 | 80 | normalize = Normalize(mean=mean, std=std) 81 | if is_train: 82 | return Compose([ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ]) 88 | else: 89 | if resize_longest_max: 90 | transforms = [ 91 | ResizeMaxSize(image_size, fill=fill_color) 92 | ] 93 | else: 94 | transforms = [ 95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 96 | CenterCrop(image_size), 97 | ] 98 | transforms.extend([ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ]) 103 | return Compose(transforms) 104 | -------------------------------------------------------------------------------- /preinstall.sh: -------------------------------------------------------------------------------- 1 | pip install torch==2.0.1 -i https://pypi.tuna.tsinghua.edu.cn/simple 2 | pip install torchvision==0.15.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 3 | pip install torchaudio==2.0.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 4 | pip install decord -i https://pypi.tuna.tsinghua.edu.cn/simple 5 | pip install h5py -i https://pypi.tuna.tsinghua.edu.cn/simple 6 | pip install ffmpeg-python -i https://pypi.tuna.tsinghua.edu.cn/simple 7 | pip install yacs -i https://pypi.tuna.tsinghua.edu.cn/simple 8 | pip install toolz -i https://pypi.tuna.tsinghua.edu.cn/simple 9 | pip install ipdb -i https://pypi.tuna.tsinghua.edu.cn/simple 10 | pip install einops -i https://pypi.tuna.tsinghua.edu.cn/simple 11 | pip install easydict -i 
https://pypi.tuna.tsinghua.edu.cn/simple 12 | pip install transformers==4.31.0 -i https://pypi.tuna.tsinghua.edu.cn/simple 13 | pip install webdataset -i https://pypi.tuna.tsinghua.edu.cn/simple 14 | pip install SentencePiece -i https://pypi.tuna.tsinghua.edu.cn/simple 15 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import torch.distributed as dist 5 | from utils.args import get_args,logging_cfgs 6 | from utils.initialize import initialize 7 | from utils.build_model import build_model 8 | from utils.build_optimizer import build_optimizer 9 | from utils.build_dataloader import create_train_dataloaders, create_val_dataloaders 10 | from utils.pipeline import train, test 11 | 12 | 13 | def main(): 14 | 15 | ### init 16 | 17 | args = get_args() 18 | initialize(args) 19 | 20 | ### logging cfgs 21 | logging_cfgs(args) 22 | 23 | 24 | if args.run_cfg.mode == 'training': 25 | 26 | ### create datasets and dataloader 27 | train_loader = create_train_dataloaders(args) 28 | val_loaders = create_val_dataloaders(args) 29 | 30 | ### build model and optimizer 31 | 32 | model, optimizer_ckpt, start_step = build_model(args) 33 | 34 | optimizer = build_optimizer(model, args, optimizer_ckpt) 35 | 36 | 37 | ### start evaluation 38 | if args.run_cfg.first_eval or args.run_cfg.zero_shot: 39 | test(model, val_loaders, args.run_cfg) 40 | if args.run_cfg.zero_shot: 41 | return 42 | 43 | ### start training 44 | 45 | 46 | train(model, optimizer, train_loader, val_loaders, args.run_cfg, start_step = start_step, verbose_time=False) 47 | 48 | elif args.run_cfg.mode == 'testing': 49 | ### build model 50 | model,_,_ = build_model(args) 51 | 52 | ### create datasets and dataloader 53 | val_loaders = create_val_dataloaders(args) 54 | 55 | ### start evaluation 56 | test(model, val_loaders, args.run_cfg) 57 | 58 | else: 59 | raise NotImplementedError 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /scripts/vast/audio_captioner.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch \ 2 | --nnodes 1 \ 3 | --node_rank 0 \ 4 | --nproc_per_node 8 \ 5 | --master_port 9814 \ 6 | ./run.py \ 7 | --config ./config/vast/captioner_cfg/caption-generation.json \ 8 | --pretrain_dir './output/vast/audio_captioner' \ 9 | --output_dir './output/vast/audio_captioner/generation' \ 10 | --test_batch_size 128 \ 11 | --generate_nums 3 \ 12 | --captioner_mode true \ -------------------------------------------------------------------------------- /scripts/vast/finetune_cap.sh: -------------------------------------------------------------------------------- 1 | config_name='pretrain_vast' 2 | output_dir=./output/vast/$config_name 3 | 4 | 5 | 6 | 7 | ##### VIDEO-CAP 8 | 9 | # caption-msrvtt 10 | python3 -m torch.distributed.launch \ 11 | --nnodes 1 \ 12 | --node_rank 0 \ 13 | --nproc_per_node 8 \ 14 | --master_port 9634 \ 15 | ./run.py \ 16 | --learning_rate 2e-5 \ 17 | --train_batch_size 128 \ 18 | --train_epoch 10 \ 19 | --checkpointing true \ 20 | --save_best true \ 21 | --config ./config/vast/finetune_cfg/caption-msrvtt.json \ 22 | --pretrain_dir $output_dir \ 23 | --beam_size 3 \ 24 | --first_eval false \ 25 | --output_dir $output_dir/downstream/caption-msrvtt-tvas \ 26 | 27 | # # caption-msvd 28 | # python3 -m 
torch.distributed.launch \ 29 | # --nnodes 1 \ 30 | # --node_rank 0 \ 31 | # --nproc_per_node 8 \ 32 | # --master_port 9834 \ 33 | # ./run.py \ 34 | # --learning_rate 1e-5 \ 35 | # --checkpointing true \ 36 | # --config ./config/vast/finetune_cfg/caption-msvd.json \ 37 | # --pretrain_dir $output_dir \ 38 | # --save_best true \ 39 | # --output_dir $output_dir/downstream/caption-msvd \ 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | # # caption-youcook 48 | # python3 -m torch.distributed.launch \ 49 | # --nnodes 1 \ 50 | # --node_rank 0 \ 51 | # --nproc_per_node 8 \ 52 | # --master_port 9834 \ 53 | # ./run.py \ 54 | # --learning_rate 3e-5 \ 55 | # --checkpointing true \ 56 | # --config ./config/vast/finetune_cfg/caption-youcook.json \ 57 | # --pretrain_dir $output_dir \ 58 | # --save_best true \ 59 | # --output_dir $output_dir/downstream/caption-youcook \ 60 | 61 | 62 | 63 | # # caption-valor32k 64 | # python3 -m torch.distributed.launch \ 65 | # --nnodes 1 \ 66 | # --node_rank 0 \ 67 | # --nproc_per_node 8 \ 68 | # --master_port 9834 \ 69 | # ./run.py \ 70 | # --learning_rate 2e-5 \ 71 | # --pretrain_dir $output_dir \ 72 | # --checkpointing true \ 73 | # --config ./config/vast/finetune_cfg/caption-valor32k.json \ 74 | # --output_dir $output_dir/downstream/caption-valor32k \ 75 | 76 | 77 | 78 | 79 | 80 | # # caption-tv 81 | # python3 -m torch.distributed.launch \ 82 | # --nnodes 1 \ 83 | # --node_rank 0 \ 84 | # --nproc_per_node 8 \ 85 | # --master_port 9834 \ 86 | # ./run.py \ 87 | # --learning_rate 3e-5 \ 88 | # --checkpointing true \ 89 | # --train_epoch 40 \ 90 | # --save_best true \ 91 | # --config ./config/vast/finetune_cfg/caption-tv.json \ 92 | # --pretrain_dir $output_dir \ 93 | # --output_dir $output_dir/downstream/caption-tv \ 94 | 95 | 96 | # # caption-vatex 97 | # python3 -m torch.distributed.launch \ 98 | # --nnodes 1 \ 99 | # --node_rank 0 \ 100 | # --nproc_per_node 8 \ 101 | # --master_port 9834 \ 102 | # ./run.py \ 103 | # --learning_rate 2e-5 \ 104 | # --checkpointing true \ 105 | # --config ./config/vast/finetune_cfg/caption-vatex.json \ 106 | # --pretrain_dir $output_dir \ 107 | # --output_dir $output_dir/downstream/caption-vatex \ 108 | # --save_best true \ 109 | 110 | 111 | # #### AUDIO-CAP 112 | 113 | 114 | # # caption-clothov2 115 | # python3 -m torch.distributed.launch \ 116 | # --nnodes 1 \ 117 | # --node_rank 0 \ 118 | # --nproc_per_node 8 \ 119 | # --master_port 9834 \ 120 | # ./run.py \ 121 | # --learning_rate 2e-5 \ 122 | # --checkpointing true \ 123 | # --config ./config/vast/finetune_cfg/caption-clothov2.json \ 124 | # --pretrain_dir $output_dir \ 125 | # --save_best true \ 126 | # --output_dir $output_dir/downstream/caption-clothov2 127 | 128 | 129 | # # caption-audiocaps 130 | # python3 -m torch.distributed.launch \ 131 | # --nnodes 1 \ 132 | # --node_rank 0 \ 133 | # --nproc_per_node 8 \ 134 | # --master_port 9834 \ 135 | # ./run.py \ 136 | # --learning_rate 2e-5 \ 137 | # --checkpointing true \ 138 | # --config ./config/vast/finetune_cfg/caption-audiocaps.json \ 139 | # --pretrain_dir $output_dir \ 140 | # --save_best true \ 141 | # --output_dir $output_dir/downstream/caption-audiocaps 142 | 143 | 144 | # #### IMAGE-CAP 145 | 146 | # # caption-mscoco 147 | # python3 -m torch.distributed.launch \ 148 | # --nnodes 1 \ 149 | # --node_rank 0 \ 150 | # --nproc_per_node 8 \ 151 | # --master_port 9834 \ 152 | # ./run.py \ 153 | # --learning_rate 1e-5 \ 154 | # --config ./config/vast/finetune_cfg/caption-mscoco.json \ 155 | # --pretrain_dir $output_dir \ 156 | # 
--vision_resolution 480 \
157 | # --output_dir $output_dir/downstream/caption-mscoco \
158 | # --checkpointing false \
159 | # --save_best true \
160 | 
161 | 
--------------------------------------------------------------------------------
/scripts/vast/finetune_qa.sh:
--------------------------------------------------------------------------------
1 | config_name='pretrain_vast'
2 | output_dir=./output/vast/$config_name
3 | 
4 | 
5 | ### VIDEO-QA
6 | 
7 | 
8 | # vqa-msrvtt
9 | python3 -m torch.distributed.launch \
10 | --nnodes 1 \
11 | --node_rank 0 \
12 | --nproc_per_node 8 \
13 | --master_port 9834 \
14 | ./run.py \
15 | --learning_rate 2e-5 \
16 | --checkpointing true \
17 | --beam_size 1 \
18 | --config ./config/vast/finetune_cfg/VQA-msrvtt.json \
19 | --first_eval false \
20 | --save_best true \
21 | --valid_freq 1 \
22 | --pretrain_dir $output_dir \
23 | --output_dir $output_dir/downstream/VQA-msrvtt \
24 | 
25 | 
26 | 
27 | # vqa-msvd
28 | python3 -m torch.distributed.launch \
29 | --nnodes 1 \
30 | --node_rank 0 \
31 | --nproc_per_node 8 \
32 | --master_port 9834 \
33 | ./run.py \
34 | --learning_rate 1e-5 \
35 | --checkpointing true \
36 | --first_eval false \
37 | --config ./config/vast/finetune_cfg/VQA-msvd.json \
38 | --pretrain_dir $output_dir \
39 | --save_best true \
40 | --output_dir $output_dir/downstream/VQA-msvd \
41 | 
42 | 
43 | 
44 | 
45 | 
46 | 
47 | # vqa-tgif-frame
48 | python3 -m torch.distributed.launch \
49 | --nnodes 1 \
50 | --node_rank 0 \
51 | --nproc_per_node 8 \
52 | --master_port 9834 \
53 | ./run.py \
54 | --learning_rate 2e-5 \
55 | --checkpointing true \
56 | --first_eval false \
57 | --save_best true \
58 | --pretrain_dir $output_dir \
59 | --config ./config/vast/finetune_cfg/VQA-tgif.json \
60 | --output_dir $output_dir/downstream/VQA-tgif \
61 | 
62 | 
63 | 
64 | 
65 | 
66 | 
67 | # vqa-anet
68 | python3 -m torch.distributed.launch \
69 | --nnodes 1 \
70 | --node_rank 0 \
71 | --nproc_per_node 8 \
72 | --master_port 9834 \
73 | ./run.py \
74 | --learning_rate 2e-5 \
75 | --checkpointing true \
76 | --save_best true \
77 | --config ./config/vast/finetune_cfg/VQA-activitynet.json \
78 | --first_eval false \
79 | --pretrain_dir $output_dir \
80 | --output_dir $output_dir/downstream/VQA-activitynet \
81 | 
82 | 
83 | 
84 | # vqa-music-avqa
85 | python3 -m torch.distributed.launch \
86 | --nnodes 1 \
87 | --node_rank 0 \
88 | --nproc_per_node 8 \
89 | --master_port 9834 \
90 | ./run.py \
91 | --learning_rate 2e-5 \
92 | --pretrain_dir $output_dir \
93 | --first_eval false \
94 | --checkpointing true \
95 | --config ./config/vast/finetune_cfg/VQA-music.json \
96 | --output_dir $output_dir/downstream/VQA-music
97 | 
98 | 
99 | 
100 | 
101 | ### IMAGE-QA
102 | 
103 | 
104 | # vqa-vqav2
105 | python3 -m torch.distributed.launch \
106 | --nnodes 1 \
107 | --node_rank 0 \
108 | --nproc_per_node 8 \
109 | --master_port 9824 \
110 | ./run.py \
111 | --learning_rate 2e-5 \
112 | --config ./config/vast/finetune_cfg/VQA-vqav2.json \
113 | --pretrain_dir $output_dir \
114 | --first_eval false \
115 | --vision_resolution 384 \
116 | --valid_freq 1 \
117 | --output_dir $output_dir/downstream/VQA-vqav2 \
118 | 
119 | 
120 | 
121 | 
--------------------------------------------------------------------------------
/scripts/vast/finetune_ret.sh:
--------------------------------------------------------------------------------
1 | config_name='pretrain_vast'
2 | output_dir=./output/vast/$config_name
3 | 
4 | 
5 | 
6 | ### VIDEO-RET
7 | 
8 | #retrieval-msrvtt
9 | python3 -m 
torch.distributed.launch \ 10 | --nnodes 1 \ 11 | --node_rank 0 \ 12 | --nproc_per_node 8 \ 13 | --master_port 9834 \ 14 | ./run.py \ 15 | --learning_rate 2e-5 \ 16 | --checkpointing true \ 17 | --first_eval true \ 18 | --save_best true \ 19 | --config ./config/vast/finetune_cfg/retrieval-msrvtt.json \ 20 | --pretrain_dir $output_dir \ 21 | --output_dir $output_dir/downstream/retrieval-msrvtt \ 22 | 23 | 24 | 25 | # #retrieval-vatex 26 | # python3 -m torch.distributed.launch \ 27 | # --nnodes 1 \ 28 | # --node_rank 0 \ 29 | # --nproc_per_node 8 \ 30 | # --master_port 9834 \ 31 | # ./run.py \ 32 | # --learning_rate 2e-5 \ 33 | # --checkpointing true \ 34 | # --config ./config/vast/finetune_cfg/retrieval-vatex.json \ 35 | # --pretrain_dir $output_dir \ 36 | # --save_best true \ 37 | # --output_dir $output_dir/downstream/retrieval-vatex \ 38 | 39 | 40 | 41 | # #retrieval-valor32k 42 | # python3 -m torch.distributed.launch \ 43 | # --nnodes 1 \ 44 | # --node_rank 0 \ 45 | # --nproc_per_node 8 \ 46 | # --master_port 9834 \ 47 | # ./run.py \ 48 | # --learning_rate 2e-5 \ 49 | # --pretrain_dir $output_dir \ 50 | # --checkpointing true \ 51 | # --config ./config/vast/finetune_cfg/retrieval-valor32k.json \ 52 | # --output_dir $output_dir/downstream/retrieval-valor32k \ 53 | 54 | 55 | 56 | 57 | # #retrieval-lsmdc 58 | # python3 -m torch.distributed.launch \ 59 | # --nnodes 1 \ 60 | # --node_rank 0 \ 61 | # --nproc_per_node 8 \ 62 | # --master_port 9834 \ 63 | # ./run.py \ 64 | # --learning_rate 2e-5 \ 65 | # --checkpointing true \ 66 | # --first_eval false \ 67 | # --config ./config/vast/finetune_cfg/retrieval-lsmdc.json \ 68 | # --pretrain_dir $output_dir \ 69 | # --save_best true \ 70 | # --output_dir $output_dir/downstream/retrieval-lsmdc \ 71 | 72 | 73 | # #retrieval-youcook 74 | # python3 -m torch.distributed.launch \ 75 | # --nnodes 1 \ 76 | # --node_rank 0 \ 77 | # --nproc_per_node 8 \ 78 | # --master_port 9834 \ 79 | # ./run.py \ 80 | # --learning_rate 3e-5 \ 81 | # --checkpointing true \ 82 | # --config ./config/vast/finetune_cfg/retrieval-youcook.json \ 83 | # --pretrain_dir $output_dir \ 84 | # --save_best true \ 85 | # --output_dir $output_dir/downstream/retrieval-youcook \ 86 | 87 | 88 | 89 | # #retrieval-didemo 90 | # python3 -m torch.distributed.launch \ 91 | # --nnodes 1 \ 92 | # --node_rank 0 \ 93 | # --nproc_per_node 8 \ 94 | # --master_port 9834 \ 95 | # ./run.py \ 96 | # --learning_rate 2e-5 \ 97 | # --checkpointing true \ 98 | # --config ./config/vast/finetune_cfg/retrieval-didemo.json \ 99 | # --pretrain_dir $output_dir \ 100 | # --save_best true \ 101 | # --output_dir $output_dir/downstream/retrieval-didemo \ 102 | 103 | 104 | # #retrieval-activitynet 105 | # python3 -m torch.distributed.launch \ 106 | # --nnodes 1 \ 107 | # --node_rank 0 \ 108 | # --nproc_per_node 8 \ 109 | # --master_port 9834 \ 110 | # ./run.py \ 111 | # --learning_rate 2e-5 \ 112 | # --checkpointing true \ 113 | # --config ./config/vast/finetune_cfg/retrieval-activitynet.json \ 114 | # --pretrain_dir $output_dir \ 115 | # --output_dir $output_dir/downstream/retrieval-activitynet \ 116 | # --save_best true \ 117 | 118 | 119 | 120 | # ### AUDIO-RET 121 | 122 | # #retrieval-clothov2 123 | # python3 -m torch.distributed.launch \ 124 | # --nnodes 1 \ 125 | # --node_rank 0 \ 126 | # --nproc_per_node 8 \ 127 | # --master_port 9834 \ 128 | # ./run.py \ 129 | # --learning_rate 2e-5 \ 130 | # --checkpointing true \ 131 | # --config ./config/vast/finetune_cfg/retrieval-clothov2.json \ 132 | # --pretrain_dir 
$output_dir \
133 | # --output_dir $output_dir/downstream/retrieval-clothov2 \
134 | 
135 | # #retrieval-audiocaps
136 | # python3 -m torch.distributed.launch \
137 | # --nnodes 1 \
138 | # --node_rank 0 \
139 | # --nproc_per_node 8 \
140 | # --master_port 9834 \
141 | # ./run.py \
142 | # --learning_rate 2e-5 \
143 | # --checkpointing true \
144 | # --config ./config/vast/finetune_cfg/retrieval-audiocaps.json \
145 | # --pretrain_dir $output_dir \
146 | # --output_dir $output_dir/downstream/retrieval-audiocaps
147 | 
148 | 
149 | 
150 | 
151 | # ### IMAGE_RET
152 | 
153 | # #retrieval-mscoco
154 | # python3 -m torch.distributed.launch \
155 | # --nproc_per_node 8 \
156 | # --master_port 9134 \
157 | # ./run.py \
158 | # --learning_rate 1e-5 \
159 | # --checkpointing true \
160 | # --config ./config/vast/finetune_cfg/retrieval-mscoco.json \
161 | # --first_eval true \
162 | # --save_best true \
163 | # --pretrain_dir $output_dir \
164 | # --output_dir $output_dir/downstream/retrieval-mscoco \
165 | # --vision_resolution 384 \
166 | 
167 | 
168 | # #retrieval-flickr
169 | # python3 -m torch.distributed.launch \
170 | # --nproc_per_node 8 \
171 | # --master_port 9134 \
172 | # ./run.py \
173 | # --learning_rate 1e-5 \
174 | # --checkpointing true \
175 | # --config ./config/vast/finetune_cfg/retrieval-flickr.json \
176 | # --first_eval true \
177 | # --save_best true \
178 | # --pretrain_dir $output_dir \
179 | # --output_dir $output_dir/downstream/retrieval-flickr \
180 | # --vision_resolution 384 \
181 | 
182 | 
--------------------------------------------------------------------------------
/scripts/vast/pretrain_vast.sh:
--------------------------------------------------------------------------------
1 | config_name='pretrain_vast'
2 | output_dir=./output/vast/$config_name
3 | 
4 | python3 -m torch.distributed.launch \
5 | --nnodes 1 \
6 | --node_rank 0 \
7 | --nproc_per_node 8 \
8 | --master_port 9834 \
9 | ./run.py \
10 | --config ./config/vast/pretrain_cfg/$config_name.json \
11 | --output_dir $output_dir \
12 | --checkpointing true
13 | 
14 | 
--------------------------------------------------------------------------------
/scripts/vast/vision_captioner.sh:
--------------------------------------------------------------------------------
1 | python3 -m torch.distributed.launch \
2 | --nnodes 1 \
3 | --node_rank 0 \
4 | --nproc_per_node 8 \
5 | --master_port 9814 \
6 | ./run.py \
7 | --config ./config/vast/captioner_cfg/caption-generation-vision.json \
8 | --pretrain_dir './output/vast/vision_captioner' \
9 | --output_dir './output/vast/vision_captioner/generation' \
10 | --test_batch_size 64 \
11 | --test_vision_sample_num 8 \
12 | --generate_nums 3 \
13 | --captioner_mode true \
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TXH-mercury/VAST/410ca47acf40d4ab098e345b76159df66bc42239/utils/__init__.py
--------------------------------------------------------------------------------
/utils/build_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 | from torch.utils.data.distributed import DistributedSampler
3 | from torch.utils.data import DataLoader
4 | from data.loader import MetaLoader, PrefetchLoader
5 | from data import data_registry
6 | from utils.distributed import DistributedSampler_wopadding
7 | from .logger import LOGGER
8 | 
9 | 
10 | 
11 | def 
create_train_dataloaders(args): 12 | data_cfg = args.data_cfg.train 13 | dataloaders = [] 14 | dataloaders_dict={} 15 | train_steps = [] 16 | loader_names = [] 17 | 18 | if len(data_cfg) == 0: 19 | return 20 | 21 | for d_cfg in data_cfg: 22 | 23 | dataset_ls = [] 24 | 25 | name = d_cfg['name'] 26 | dataset = data_registry[d_cfg.type](d_cfg, args) 27 | 28 | # print(dataset[0]) 29 | 30 | collate_fn = dataset.collate_fn 31 | worker_init_fn = dataset.worker_init_fn 32 | use_sampler = dataset.use_sampler 33 | 34 | 35 | LOGGER.info("Create Dataset {} Success".format(name)) 36 | task = d_cfg['task'] 37 | batch_size = d_cfg['batch_size'] 38 | n_workers = d_cfg['n_workers'] 39 | 40 | if 'steps' in d_cfg: 41 | train_steps.append(d_cfg['steps']) 42 | elif 'epoch' in d_cfg: 43 | epoch = d_cfg['epoch'] 44 | train_steps.append(int((len(dataset) // batch_size) * epoch)) 45 | 46 | loader = build_dataloader(dataset, collate_fn, True, batch_size // args.run_cfg.gradient_accumulation_steps , n_workers, worker_init_fn, use_sampler) 47 | 48 | dataloaders.append(loader) 49 | loader_names.append(f'{task}--{name}') 50 | 51 | 52 | for i in range(len(dataloaders)): 53 | ratio = train_steps[i] 54 | dataloaders_dict[loader_names[i]] = (dataloaders[i], ratio) 55 | 56 | n_gpu = dist.get_world_size() 57 | for name, (loader, ratio) in dataloaders_dict.items(): 58 | # epoch = (ratio * loader.batch_size * n_gpu ) // len(loader.dataset) 59 | LOGGER.info(f" loader {name} , ratio {ratio} , bs_pergpu {loader.batch_size}, n_workers {loader.num_workers}" ) 60 | 61 | 62 | meta_loader = MetaLoader(dataloaders_dict, 63 | accum_steps=args.run_cfg.gradient_accumulation_steps, 64 | distributed=n_gpu > 1) 65 | 66 | if args.run_cfg.num_train_steps == 0: 67 | total_train_steps = sum(train_steps) 68 | args.run_cfg.num_train_steps = total_train_steps 69 | 70 | 71 | 72 | meta_loader = PrefetchLoader(meta_loader) 73 | meta_loader.ndata = len(dataloaders_dict) 74 | args.run_cfg.valid_steps = args.run_cfg.num_train_steps // args.run_cfg.valid_freq -1 75 | 76 | 77 | 78 | return meta_loader 79 | 80 | 81 | def create_val_dataloaders(args): 82 | data_cfg = args.data_cfg.val 83 | dataloaders = {} 84 | for d_cfg in data_cfg: 85 | name = d_cfg['name'] 86 | dataset = data_registry[d_cfg.type](d_cfg, args) 87 | collate_fn = dataset.collate_fn 88 | worker_init_fn = dataset.worker_init_fn 89 | use_sampler = dataset.use_sampler 90 | # task = d_cfg['task'].split('_') 91 | # if 'qa' in task: 92 | # dataset.make_submission = d_cfg.get('make_submission', False) 93 | 94 | # if 'cap' in task: 95 | # dataset.annfile = d_cfg['annfile'] 96 | 97 | # dataset.data_type = data_type 98 | dataset.name = name 99 | LOGGER.info("Create Dataset {} Success".format(name)) 100 | task = d_cfg['task'] 101 | batch_size = d_cfg['batch_size'] 102 | n_workers = d_cfg['n_workers'] 103 | loader = build_dataloader(dataset, collate_fn, False, batch_size, n_workers, worker_init_fn, use_sampler) 104 | task_name = f'{task}--{name}' 105 | dataloaders[task_name] = PrefetchLoader(loader) 106 | return dataloaders 107 | 108 | 109 | def build_dataloader(dataset, collate_fn, is_train, batch_size, n_workers=None, worker_init_fn=None, use_sampler=True): 110 | batch_size = batch_size // dist.get_world_size() 111 | if use_sampler: 112 | if is_train: 113 | sampler = DistributedSampler(dataset) 114 | else: 115 | sampler = DistributedSampler_wopadding(dataset) 116 | loader = DataLoader(dataset, sampler = sampler, batch_size = batch_size, 117 | num_workers=n_workers, pin_memory=True, 118 | 
collate_fn=collate_fn, drop_last=is_train,worker_init_fn=worker_init_fn) 119 | else: 120 | 121 | loader = DataLoader(dataset, batch_size = batch_size, 122 | num_workers=n_workers, pin_memory=True, 123 | collate_fn=collate_fn, drop_last=is_train,worker_init_fn=worker_init_fn) 124 | 125 | return loader 126 | 127 | -------------------------------------------------------------------------------- /utils/build_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | from model import model_registry 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | from .logger import LOGGER 8 | from .build_optimizer import build_optimizer 9 | 10 | 11 | class DDP_modify(DDP): 12 | def __getattr__(self, name): 13 | try: 14 | return super().__getattr__(name) 15 | except: 16 | return getattr(self.module,name) 17 | 18 | 19 | def build_model(args): 20 | 21 | model = model_registry[args.model_cfg.model_type](args.model_cfg) 22 | checkpoint = {} 23 | 24 | ### load ckpt from a pretrained_dir 25 | if args.run_cfg.pretrain_dir: 26 | checkpoint = load_from_pretrained_dir(args) 27 | LOGGER.info("Load from pretrained dir {}".format(args.run_cfg.pretrain_dir)) 28 | 29 | ### load ckpt from specific path 30 | if args.run_cfg.checkpoint: 31 | checkpoint = torch.load(args.run_cfg.checkpoint, map_location = 'cpu') 32 | 33 | ### resume training 34 | if args.run_cfg.resume: 35 | checkpoint, checkpoint_optim, start_step = load_from_resume(args.run_cfg) 36 | else: 37 | checkpoint_optim, start_step = None , 0 38 | 39 | 40 | checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()} 41 | 42 | if checkpoint != {}: 43 | 44 | checkpoint = model.modify_checkpoint(checkpoint) 45 | if "model" in checkpoint.keys(): 46 | checkpoint = checkpoint["model"] 47 | 48 | missing_keys,unexpected_keys = model.load_state_dict(checkpoint,strict=False) 49 | LOGGER.info(f"Unexpected keys {unexpected_keys}") 50 | LOGGER.info(f"missing_keys {missing_keys}") 51 | 52 | 53 | local_rank = args.local_rank 54 | device = torch.device("cuda", local_rank) 55 | model.to(device) 56 | if args.run_cfg.use_ddp: 57 | model = DDP_modify(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 58 | else: 59 | pass 60 | 61 | return model, checkpoint_optim, start_step 62 | 63 | 64 | 65 | def load_from_pretrained_dir(args): 66 | 67 | 68 | try: ### huggingface trainer 69 | checkpoint_dir = args.run_cfg.pretrain_dir 70 | checkpoint_ls = [ i for i in os.listdir(checkpoint_dir) if i.startswith('checkpoint')] 71 | checkpoint_ls = [int(i.split('-')[1]) for i in checkpoint_ls] 72 | checkpoint_ls.sort() 73 | step = checkpoint_ls[-1] 74 | 75 | try: 76 | checkpoint_name = f'checkpoint-{step}/pytorch_model.bin' 77 | ckpt_file = os.path.join(checkpoint_dir, checkpoint_name) 78 | checkpoint = torch.load(ckpt_file, map_location = 'cpu') 79 | except: 80 | checkpoint_name1 = f'checkpoint-{step}/pytorch_model-00001-of-00002.bin' 81 | ckpt_file1 = torch.load(os.path.join(checkpoint_dir, checkpoint_name1), map_location = 'cpu') 82 | checkpoint_name2 = f'checkpoint-{step}/pytorch_model-00002-of-00002.bin' 83 | ckpt_file2 = torch.load(os.path.join(checkpoint_dir, checkpoint_name2), map_location = 'cpu') 84 | ckpt_file1.update(ckpt_file2) 85 | checkpoint = ckpt_file1 86 | # checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()} 87 | LOGGER.info(f'load_from_pretrained: {ckpt_file}') 88 | 89 | except: 90 | checkpoint_dir 
= os.path.join(args.run_cfg.pretrain_dir,'ckpt') 91 | checkpoint_ls = [ i for i in os.listdir(checkpoint_dir) if i.startswith('model_step')] 92 | checkpoint_ls = [int(i.split('_')[2].split('.')[0]) for i in checkpoint_ls] 93 | checkpoint_ls.sort() 94 | step = checkpoint_ls[-1] 95 | 96 | checkpoint_name = 'model_step_'+str(step)+'.pt' 97 | ckpt_file = os.path.join(checkpoint_dir, checkpoint_name) 98 | checkpoint = torch.load(ckpt_file, map_location = 'cpu') 99 | # checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()} 100 | LOGGER.info(f'load_from_pretrained: {ckpt_file}') 101 | 102 | 103 | return checkpoint 104 | 105 | 106 | def load_from_resume(run_cfg): 107 | ckpt_dir = os.path.join(run_cfg.output_dir,'ckpt') 108 | previous_optimizer_state = [i for i in os.listdir(ckpt_dir) if i.startswith('optimizer')] 109 | steps = [i.split('.pt')[0].split('_')[-1] for i in previous_optimizer_state] 110 | steps = [ int(i) for i in steps] 111 | steps.sort() 112 | previous_step = steps[-1] 113 | previous_optimizer_state = f'optimizer_step_{previous_step}.pt' 114 | previous_model_state = f'model_step_{previous_step}.pt' 115 | previous_step = int(previous_model_state.split('.')[0].split('_')[-1]) 116 | previous_optimizer_state = os.path.join(ckpt_dir, previous_optimizer_state) 117 | previous_model_state = os.path.join(ckpt_dir, previous_model_state) 118 | 119 | assert os.path.exists(previous_optimizer_state) and os.path.exists(previous_model_state) 120 | LOGGER.info("choose previous model: {}".format(previous_model_state)) 121 | LOGGER.info("choose previous optimizer: {}".format(previous_optimizer_state)) 122 | previous_model_state = torch.load(previous_model_state,map_location='cpu') 123 | previous_optimizer_state = torch.load(previous_optimizer_state,map_location='cpu') 124 | return previous_model_state, previous_optimizer_state, previous_step 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import torch 4 | import torch.distributed as dist 5 | from torch.autograd import Function 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | 9 | 10 | 11 | 12 | class GatherLayer(torch.autograd.Function): 13 | """ 14 | Gather tensors from all workers with support for backward propagation: 15 | This implementation does not cut the gradients as torch.distributed.all_gather does. 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, x): 20 | output = [ 21 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 22 | ] 23 | torch.distributed.all_gather(output, x) 24 | return tuple(output) 25 | 26 | @staticmethod 27 | def backward(ctx, *grads): 28 | all_gradients = torch.stack(grads) 29 | torch.distributed.all_reduce(all_gradients) 30 | return all_gradients[torch.distributed.get_rank()] 31 | 32 | 33 | def all_gather_with_grad(tensors): 34 | """ 35 | Performs all_gather operation on the provided tensors. 36 | Graph remains connected for backward grad computation. 
37 | """ 38 | # Queue the gathered tensors 39 | world_size = torch.distributed.get_world_size() 40 | # There is no need for reduction in the single-proc case 41 | if world_size == 1: 42 | return tensors 43 | 44 | # tensor_all = GatherLayer.apply(tensors) 45 | tensor_all = GatherLayer.apply(tensors) 46 | 47 | return torch.cat(tensor_all, dim=0) 48 | 49 | 50 | @torch.no_grad() 51 | def concat_all_gather(tensor): 52 | """ 53 | Performs all_gather operation on the provided tensors. 54 | *** Warning ***: torch.distributed.all_gather has no gradient. 55 | """ 56 | # if use distributed training 57 | # if not is_dist_avail_and_initialized(): 58 | # return tensor 59 | 60 | tensors_gather = [ 61 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 62 | ] 63 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 64 | 65 | output = torch.cat(tensors_gather, dim=0) 66 | return output 67 | 68 | 69 | 70 | def _encode(enc, max_size, use_max_size=False): 71 | enc_size = len(enc) 72 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) 73 | if use_max_size: 74 | # this is used for broadcasting 75 | buffer_ = torch.cuda.ByteTensor(max_size+enc_byte) 76 | else: 77 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) 78 | remainder = enc_size 79 | for i in range(enc_byte): 80 | base = 256 ** (enc_byte-i-1) 81 | buffer_[i] = remainder // base 82 | remainder %= base 83 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) 84 | return buffer_, enc_byte 85 | 86 | 87 | def _decode(buffer_, enc_byte): 88 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() 89 | for i in range(enc_byte)) 90 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) 91 | shift = size + enc_byte 92 | return bytes_list, shift 93 | 94 | 95 | 96 | 97 | 98 | def all_gather_list(data): 99 | """Gathers arbitrary data from all nodes into a list.""" 100 | enc = pickle.dumps(data) 101 | 102 | enc_size = len(enc) 103 | max_size = ddp_allgather(torch.tensor([enc_size]).cuda()).max().item() 104 | in_buffer, enc_byte = _encode(enc, max_size) 105 | 106 | out_buffer = ddp_allgather(in_buffer[:enc_byte+enc_size]) 107 | 108 | results = [] 109 | for _ in range(dist.get_world_size()): 110 | bytes_list, shift = _decode(out_buffer, enc_byte) 111 | out_buffer = out_buffer[shift:] 112 | result = pickle.loads(bytes_list) 113 | results.append(result) 114 | return results 115 | 116 | 117 | def any_broadcast(data, root_rank): 118 | """broadcast arbitrary data from root_rank to all nodes.""" 119 | enc = pickle.dumps(data) 120 | 121 | max_size = ddp_allgather(torch.tensor([len(enc)]).cuda()).max().item() 122 | buffer_, enc_byte = _encode(enc, max_size, use_max_size=True) 123 | 124 | dist.broadcast(buffer_, root_rank) 125 | 126 | bytes_list, _ = _decode(buffer_, enc_byte) 127 | result = pickle.loads(bytes_list) 128 | return result 129 | 130 | 131 | 132 | ###### with different batch_size ~ 133 | def ddp_allgather(input): 134 | tmp_input = input.cuda() 135 | size = torch.tensor(tmp_input.shape[0]).cuda() 136 | size_list = [torch.zeros_like(size) for i in range(dist.get_world_size())] 137 | dist.all_gather(size_list, size) 138 | max_size = max(size_list).item() 139 | padding_size = max_size - size 140 | if padding_size > 0 : 141 | padding_tensor = torch.zeros(padding_size,*tmp_input.shape[1:]).to(tmp_input) 142 | tmp_input = torch.cat((tmp_input, padding_tensor), dim = 0) 143 | tmp_list = [torch.zeros_like(tmp_input) for i in range(dist.get_world_size())] 144 | dist.all_gather(tmp_list, tmp_input) 145 | 
output = [] 146 | for t, s in zip(tmp_list, size_list): 147 | output.append(t[:s]) 148 | output = torch.cat(output,dim=0) 149 | return output 150 | 151 | 152 | 153 | class DistributedSampler_wopadding(DistributedSampler): 154 | 155 | def __iter__(self): 156 | if self.shuffle: 157 | # deterministically shuffle based on epoch and seed 158 | g = torch.Generator() 159 | g.manual_seed(self.seed + self.epoch) 160 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 161 | else: 162 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 163 | 164 | # if not self.drop_last: 165 | # # add extra samples to make it evenly divisible 166 | # padding_size = self.total_size - len(indices) 167 | # if padding_size <= len(indices): 168 | # indices += indices[:padding_size] 169 | # else: 170 | # indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 171 | # else: 172 | # remove tail of data to make it evenly divisible. 173 | if self.drop_last: 174 | indices = indices[:self.total_size] 175 | #assert len(indices) == self.total_size 176 | 177 | # subsample 178 | indices = indices[self.rank:len(indices):self.num_replicas] 179 | # assert len(indices) == self.num_samples 180 | 181 | return iter(indices) 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /utils/initialize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import torch.distributed as dist 6 | from .logger import LOGGER,add_log_to_file 7 | 8 | def initialize(opts): 9 | 10 | # if not os.path.exists(opts.run_cfg.output_dir): 11 | os.makedirs(os.path.join(opts.run_cfg.output_dir, 'log'), exist_ok=True) 12 | os.makedirs(os.path.join(opts.run_cfg.output_dir, 'ckpt'), exist_ok=True) 13 | 14 | local_rank = opts.local_rank 15 | torch.cuda.set_device(local_rank) 16 | dist.init_process_group(backend='nccl') 17 | if opts.run_cfg.gradient_accumulation_steps < 1: 18 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " 19 | "should be >= 1".format( 20 | opts.run_cfg.gradient_accumulation_steps)) 21 | set_random_seed(opts.run_cfg.seed) 22 | torch.backends.cudnn.benchmark = True 23 | torch.backends.cudnn.enabled = True 24 | if dist.get_rank() == 0: 25 | # TB_LOGGER.create(os.path.join(opts.output_dir, 'log')) 26 | add_log_to_file(os.path.join(opts.run_cfg.output_dir, 'log', 'log.txt')) 27 | else: 28 | LOGGER.disabled = True 29 | 30 | 31 | def set_random_seed(seed): 32 | random.seed(seed) 33 | np.random.seed(seed) 34 | torch.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | 37 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | 5 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 6 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 7 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 8 | LOGGER = logging.getLogger('__main__') # this is the global logger 9 | 10 | 11 | def add_log_to_file(log_path): 12 | fh = logging.FileHandler(log_path) 13 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 14 | fh.setFormatter(formatter) 15 | LOGGER.addHandler(fh) 16 | 17 | 18 | class RunningMeter(object): 19 | """ running meteor of a scalar value 20 | (useful for monitoring training loss) 21 | """ 22 | def 
__init__(self, name=None, val=None, smooth=0.99): 23 | self._name = name 24 | self._sm = smooth 25 | self._val = val 26 | 27 | def __call__(self, value): 28 | val = (value if self._val is None 29 | else value*(1-self._sm) + self._val*self._sm) 30 | if not math.isnan(val): 31 | self._val = val 32 | 33 | def __str__(self): 34 | return f'{self._name}: {self._val:.4f}' 35 | 36 | @property 37 | def val(self): 38 | if self._val is None: 39 | return 0 40 | return self._val 41 | 42 | @property 43 | def name(self): 44 | return self._name 45 | 46 | -------------------------------------------------------------------------------- /utils/offline_process_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import ffmpeg 4 | import subprocess 5 | import multiprocessing 6 | import numpy as np 7 | 8 | from multiprocessing import Pool 9 | 10 | 11 | input_path = '/public/chensihan/datasets/tgif/gifs_used' 12 | output_path = '/public/chensihan/datasets/tgif/' 13 | 14 | data_list = os.listdir(input_path) 15 | 16 | def execCmd(cmd): 17 | r = os.popen(cmd) 18 | text = r.read() 19 | r.close() 20 | return text 21 | 22 | def pipline(video_path, video_probe, output_dir, fps, sr, duration_target): 23 | video_name = os.path.basename(video_path) 24 | 25 | video_name = video_name.replace(".mp4", "") 26 | 27 | 28 | # extract video frames fps 29 | fps_frame_dir = os.path.join(output_dir, f"frames_fps{fps}", video_name) 30 | os.makedirs(fps_frame_dir, exist_ok=True) 31 | cmd = "ffmpeg -loglevel error -i {} -vsync 0 -f image2 -vf fps=fps={:.02f} -qscale:v 2 {}/frame_%04d.jpg".format( 32 | video_path, fps, fps_frame_dir) 33 | 34 | ## extract fixed number frames 35 | # fps_frame_dir = os.path.join(output_dir, f"frames_32", video_name) 36 | # os.makedirs(fps_frame_dir, exist_ok=True) 37 | # cmd = "ffmpeg -loglevel error -i {} -vsync 0 -f image2 -vframes 32 -qscale:v 2 {}/frame_%04d.jpg".format( 38 | # video_path, fps_frame_dir) 39 | 40 | 41 | # ## extract audios 42 | # sr_audio_dir = os.path.join(output_dir,f"audios") 43 | # os.makedirs(sr_audio_dir, exist_ok=True) 44 | # # print(sr_audio_dir) 45 | # audio_name = video_name+'.wav' 46 | # audio_file_path = os.path.join(sr_audio_dir, audio_name) 47 | 48 | 49 | cmd = "ffmpeg -i {} -loglevel error -f wav -vn -ac 1 -ab 16k -ar {} -y {}".format( 50 | video_path, sr, audio_file_path) 51 | 52 | 53 | subprocess.call(cmd, shell=True) 54 | 55 | 56 | 57 | def extract_thread(video_id): 58 | 59 | video_name = os.path.join(input_path, video_id) 60 | 61 | if not os.path.exists(video_name): 62 | 63 | return 64 | try: 65 | # print(1) 66 | probe = ffmpeg.probe(video_name) 67 | # print(1) 68 | pipline(video_name, probe, output_path, fps=1, sr=22050, duration_target=10) 69 | except Exception as e: 70 | print(e) 71 | return 72 | 73 | 74 | def extract_all(video_ids, thread_num, start): 75 | length = len(video_ids) 76 | print(length) 77 | with Pool(thread_num) as p: 78 | list(tqdm.tqdm(p.imap(extract_thread, video_ids), total=length)) 79 | 80 | if __name__=='__main__': 81 | thread_num = 20 82 | start = 0 83 | 84 | print(len(data_list)) 85 | extract_all(data_list, thread_num, start) 86 | 87 | -------------------------------------------------------------------------------- /utils/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | # from apex import amp 7 | 
from collections import defaultdict 8 | from torch.nn.utils import clip_grad_norm_ 9 | from evaluation import evaluation_registry 10 | from .save import ModelSaver 11 | from .tool import NoOp 12 | from .logger import LOGGER, RunningMeter 13 | from .sched import get_lr_sched 14 | from torch.cuda.amp import autocast, GradScaler 15 | 16 | 17 | def train(model, optimizer, train_loader, val_loaders, run_cfg, start_step=0, verbose_time=False): 18 | 19 | if dist.get_rank() == 0: 20 | pbar = tqdm(total=run_cfg.num_train_steps, initial=start_step) 21 | model_saver = ModelSaver(os.path.join(run_cfg.output_dir, 'ckpt'),remove_before_ckpt=run_cfg.remove_before_ckpt) 22 | else: 23 | pbar = NoOp() 24 | model_saver = NoOp() 25 | 26 | loss_moving_averagetors ={} 27 | metric_logger_dict = defaultdict(dict) 28 | global_step = start_step 29 | 30 | scaler = GradScaler() 31 | 32 | best_indicator = {} 33 | evaluate_fn = evaluation_registry[model.config.evaluation_type] 34 | 35 | for step, (name, batch) in enumerate(train_loader): 36 | 37 | ndata = train_loader.ndata 38 | task = name.split('--')[0] 39 | 40 | 41 | 42 | if run_cfg.fp16: 43 | with autocast(): 44 | loss_dict = model(batch, task=task, compute_loss=True) 45 | loss = sum(list(loss_dict.values())) 46 | loss_dict['total_loss'] = loss 47 | loss_dict = {k:v.item() for k,v in loss_dict.items()} 48 | 49 | else: 50 | loss_dict = model(batch, task=task, compute_loss=True) 51 | loss = sum(list(loss_dict.values())) 52 | loss_dict['total_loss'] = loss 53 | loss_dict = {k:v.item() for k,v in loss_dict.items()} 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | if not name in loss_moving_averagetors: 64 | ### first time initialize 65 | for k in loss_dict.keys(): 66 | loss_moving_averagetors[f'loss_{name}/{k}'] = RunningMeter() 67 | ####accumulate loss 68 | 69 | for k,v in loss_dict.items(): 70 | loss_moving_averagetors[f'loss_{name}/{k}'](v) 71 | 72 | 73 | global_step += 1 74 | # learning rate scheduling 75 | lr_ratio = get_lr_sched(global_step, run_cfg) 76 | 77 | for param_group in optimizer.param_groups: 78 | param_group['lr'] = param_group['init_lr'] * lr_ratio 79 | 80 | if global_step % 50 == 0: 81 | LOGGER.info({name : averagetor.val for name, averagetor in loss_moving_averagetors.items()}) 82 | 83 | # update model params 84 | 85 | 86 | if run_cfg.fp16: 87 | optimizer.zero_grad() 88 | scaler.scale(loss).backward() 89 | else: 90 | loss.backward() 91 | 92 | if not run_cfg.use_ddp: 93 | works = [] 94 | for p in model.parameters(): 95 | # to speed it up, you can also organize grads to larger buckets to make allreduce more efficient 96 | if p.grad is not None: 97 | works.append(dist.all_reduce(p.grad, async_op=True)) 98 | for work in works: 99 | work.wait() 100 | 101 | 102 | # if run_cfg.grad_norm != -1: 103 | # grad_norm = clip_grad_norm_(model.parameters(), run_cfg.grad_norm) 104 | 105 | if run_cfg.fp16: 106 | scaler.step(optimizer) 107 | scaler.update() 108 | else: 109 | optimizer.step() 110 | optimizer.zero_grad() 111 | pbar.update(1) 112 | 113 | 114 | 115 | if (global_step+1) % run_cfg.valid_steps == 0: 116 | eval_log = evaluate_fn(model, val_loaders, run_cfg, global_step) 117 | 118 | if dist.get_rank() == 0: 119 | for task_name, val_log in eval_log.items(): 120 | for eval_name, metric in val_log.items(): 121 | eval_name = task_name +'_' +eval_name 122 | metric_logger_dict[eval_name][str(global_step)] = metric 123 | LOGGER.info(f"====-evaluation--{eval_name}=====step {global_step}--===========\n") 124 | LOGGER.info(metric) 125 | best_name = 
get_best_name(eval_name, metric) 126 | if best_name is not None: 127 | if ('best_step' not in metric_logger_dict[eval_name]) or \ 128 | (metric[best_name] >= metric_logger_dict[eval_name]['best_value']): 129 | metric_logger_dict[eval_name]['best_step'] = global_step 130 | metric_logger_dict[eval_name]['best_value'] = metric[best_name] 131 | best_indicator[eval_name] = True 132 | else: 133 | best_indicator[eval_name] = False 134 | best_step = metric_logger_dict[eval_name]['best_step'] 135 | LOGGER.info(f"======evaluation--{eval_name}====history best step: {best_step}=======\n") 136 | LOGGER.info(metric_logger_dict[eval_name][str(best_step)]) 137 | 138 | model_saver.save(model, global_step, optimizer,best_indicator, run_cfg.save_best) 139 | 140 | 141 | if global_step >= run_cfg.num_train_steps: 142 | break 143 | pbar.close() 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | def test(model, test_loader, run_cfg): 153 | 154 | evaluate_fn = evaluation_registry[model.config.evaluation_type] 155 | eval_log = evaluate_fn(model, test_loader, run_cfg, global_step=0) 156 | if dist.get_rank()==0: 157 | for task_name, val_log in eval_log.items(): 158 | for eval_name, metric in val_log.items(): 159 | eval_name = task_name +'_' +eval_name 160 | # TB_LOGGER.log_scaler_dict({f"eval/{eval_name}/test_{k}": v 161 | # for k, v in metric.items() if not isinstance(v,str)}) 162 | LOGGER.info(f"==== evaluation--{eval_name}========\n") 163 | LOGGER.info(metric) 164 | 165 | 166 | 167 | 168 | def get_best_name(eval_name, metric): 169 | if eval_name.startswith('cap'): 170 | return 'CIDEr' 171 | elif eval_name.startswith('qa'): 172 | return 'accuracy' 173 | elif eval_name.startswith('ret'): 174 | if 'video_r1' in metric: 175 | return 'video_r1' 176 | elif eval_name.startswith('pt'): 177 | return None 178 | else: 179 | raise NotImplementedError 180 | 181 | -------------------------------------------------------------------------------- /utils/save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from os.path import join 4 | from utils.logger import LOGGER 5 | 6 | 7 | 8 | 9 | class ModelSaver(object): 10 | def __init__(self, output_dir, prefix='model_step', suffix='pt',remove_before_ckpt=True): 11 | self.output_dir = output_dir 12 | self.prefix = prefix 13 | self.suffix = suffix 14 | self.remove_before_ckpt = remove_before_ckpt 15 | def save(self, model, step, optimizer=None, best_indicator=None, save_best=False): 16 | ###remove previous model 17 | previous_state = [i for i in os.listdir(self.output_dir) if i.startswith('model')] 18 | # if not self.pretraining: 19 | if self.remove_before_ckpt: 20 | for p in previous_state: 21 | os.remove(os.path.join(self.output_dir,p)) 22 | output_model_file = join(self.output_dir, 23 | f"{self.prefix}_{step}.{self.suffix}") 24 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v 25 | for k, v in model.state_dict().items()} 26 | torch.save(state_dict, output_model_file) 27 | 28 | if save_best: 29 | for k in best_indicator: 30 | if best_indicator[k]: 31 | torch.save(state_dict, join(self.output_dir, 32 | f"best_{k}.{self.suffix}")) 33 | 34 | if optimizer is not None: 35 | if hasattr(optimizer, '_amp_stash'): 36 | pass # TODO fp16 optimizer 37 | previous_state = [i for i in os.listdir(self.output_dir) if i.startswith('optimizer')] 38 | if self.remove_before_ckpt: 39 | for p in previous_state: 40 | os.remove(os.path.join(self.output_dir,p)) 41 | torch.save(optimizer.state_dict(), 
f'{self.output_dir}/optimizer_step_{step}.pt') 42 | -------------------------------------------------------------------------------- /utils/sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def warmup_cosine(x, warmup_ratio): 4 | if x < warmup_ratio: 5 | return x/warmup_ratio 6 | return 0.5 * (1.0 + math.cos(math.pi * x)) 7 | 8 | def warmup_constant(x, warmup_ratio): 9 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 10 | Learning rate is 1. afterwards. """ 11 | if x < warmup_ratio: 12 | return x/warmup_ratio 13 | return 1.0 14 | 15 | def warmup_linear(x, warmup_ratio): 16 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 17 | After `t_total`-th training step, learning rate is zero. """ 18 | if x < warmup_ratio: 19 | return x/warmup_ratio 20 | return max((x-1.)/(warmup_ratio-1.), 0) 21 | 22 | scheduler_dict = {'warmup_linear' : warmup_linear, 23 | 'warmup_cosine' : warmup_cosine} 24 | 25 | def get_lr_sched(global_step, opts): 26 | warmup_ratio = opts.warmup_ratio 27 | current_ratio = global_step / opts.num_train_steps 28 | lr_ratio = scheduler_dict[opts.scheduler](current_ratio, warmup_ratio) 29 | return lr_ratio 30 | 31 | 32 | -------------------------------------------------------------------------------- /utils/tool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class NoOp(object): 4 | """ useful for distributed training No-Ops """ 5 | def __getattr__(self, name): 6 | return self.noop 7 | 8 | def noop(self, *args, **kwargs): 9 | return 10 | 11 | 12 | 13 | 14 | def split(frame_name_lists, sample_num): 15 | if len(frame_name_lists) < sample_num: ###padding with the last frame 16 | frame_name_lists += [frame_name_lists[-1]]*(sample_num - len(frame_name_lists)) 17 | k, m = divmod(len(frame_name_lists), sample_num) 18 | return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(sample_num))] 19 | 20 | 21 | 22 | --------------------------------------------------------------------------------