├── .gitignore ├── LICENSE ├── README.md ├── configs ├── config.json ├── config_bert.json ├── config_mistral.json ├── instruction_data.py └── model.py ├── dataset ├── TimeIT │ ├── dense_video_captioning │ │ ├── anet │ │ │ ├── instruct_dvc_10.0k_anet.json │ │ │ ├── instruct_dvc_10.0k_anet_15asr.json │ │ │ ├── test.caption_coco_format.json │ │ │ ├── train.caption_coco_format.json │ │ │ └── val.caption_coco_format.json │ │ ├── dense_video_captioning_instructions.json │ │ ├── vitt │ │ │ ├── instruct_dvc_5.1k_vitt.json │ │ │ └── instruct_dvc_5.1k_vitt_15asr.json │ │ ├── youcook2 │ │ │ ├── instruct_dvc_1.2k_youcook2.json │ │ │ ├── test.caption_coco_format.json │ │ │ ├── train.caption_coco_format.json │ │ │ └── val.caption_coco_format.json │ │ ├── youcook2_new │ │ │ └── val.caption_coco_format.json │ │ └── youcook2_new_224_6 │ │ │ └── val.caption_coco_format.json │ ├── step_localization │ │ ├── coin │ │ │ ├── instruct_action_9.0k_coin.json │ │ │ └── instruct_action_9.0k_coin_15asr.json │ │ ├── hirest_step │ │ │ ├── annotations │ │ │ │ ├── all_data_test.json │ │ │ │ ├── all_data_test_negative_samples.json │ │ │ │ ├── all_data_train.json │ │ │ │ └── all_data_val.json │ │ │ └── instruct_action_0.5k_hirest.json │ │ └── step_localization_instructions.json │ ├── temporal_video_grounding │ │ ├── charades │ │ │ ├── charades_annotation │ │ │ │ ├── charades_sta_test.txt │ │ │ │ ├── charades_sta_train.txt │ │ │ │ ├── get_coco_format.py │ │ │ │ ├── test.caption_coco_format.json │ │ │ │ └── train.caption_coco_format.json │ │ │ └── instruct_tvg_12.4k_charades.json │ │ ├── didemo │ │ │ ├── instruct_tvg_33.0k_didemo.json │ │ │ └── instruct_tvg_33.0k_didemo_15asr.json │ │ ├── hirest │ │ │ └── instruct_tvg_0.5k_hirest.json │ │ ├── queryd │ │ │ ├── instruct_tvg_14.6k_queryd.json │ │ │ └── instruct_tvg_14.6k_queryd_15asr.json │ │ └── temporal_video_grounding_instructions.json │ ├── time │ │ ├── instruct_time-sensitive_104k.json │ │ └── instruct_time-sensitive_104k_asr.json │ ├── transcribed_speech_generation │ │ ├── transcribed_speech_generation_instructions.json │ │ └── yttemporal │ │ │ ├── instruct_tsg_31.6k_yttemporal.json │ │ │ └── instruct_tsg_31.6k_yttemporal_15asr.json │ ├── valley │ │ ├── Valley_instruct_73k.json │ │ └── instruct_valley_72k.json │ ├── video_highlight_detection │ │ ├── qvhighlights │ │ │ ├── annotations_raw │ │ │ │ ├── LICENSE │ │ │ │ ├── README.md │ │ │ │ ├── highlight_test_release.jsonl │ │ │ │ ├── highlight_train_release.jsonl │ │ │ │ ├── highlight_val_release.jsonl │ │ │ │ ├── subs_train.jsonl │ │ │ │ └── val.caption_coco_format.json │ │ │ ├── get_coco_format.py │ │ │ ├── instruct_vhd_6.9k_qvhighlights.json │ │ │ ├── train.caption_coco_format.json │ │ │ └── val.caption_coco_format.json │ │ └── video_highlight_detection_instructions.json │ └── video_summarization │ │ ├── summe │ │ ├── instruct_vhd_25_summe.json │ │ └── instruct_vhd_25_summe_15asr.json │ │ ├── tvsum │ │ ├── instruct_vhd_50_tvsum.json │ │ └── instruct_vhd_50_tvsum_15asr.json │ │ └── video_summarization_instructions.json ├── __init__.py ├── base_dataset.py ├── dataloader.py ├── it_dataset.py ├── it_dataset_mistral.py ├── pt_dataset.py ├── sampler.py ├── utils.py ├── video_transforms.py └── video_utils.py ├── demo ├── demo.ipynb └── example │ ├── bear.jpg │ ├── dog.png │ ├── jesse_dance.mp4 │ ├── people.jpg │ └── yoga.mp4 ├── download └── folder_keeper ├── eval ├── Egoschema_trans_csv.py ├── eval_egoschema.sh ├── eval_infer.py ├── eval_qa_tasks.ipynb ├── format_dvc.py ├── format_eval.py ├── format_tvg.py ├── format_vhd.py ├── 
get_grounding_result.sh ├── test_grounding.sh └── validate_egoschema.py ├── images ├── 123 ├── abstract.png ├── data.png └── structure.png ├── metrics ├── README.md ├── dvc │ ├── eval_dvc.py │ ├── eval_dvc.sh │ ├── example_gt_file.json │ ├── example_pred_file.json │ └── metrics │ │ ├── README.md │ │ ├── cider.py │ │ ├── cider_scorer.py │ │ ├── meteor.py │ │ └── ptbtokenizer.py ├── tvg │ ├── cd │ ├── eval_tvg.py │ ├── eval_tvg.sh │ ├── example_gt_file.json │ └── example_pred_file.json └── vhd │ ├── cd │ ├── eval_highlights.sh │ ├── eval_vhd.py │ ├── example_gt_file.json │ ├── example_pred_file.json │ ├── metrics.json │ └── utils.py ├── models ├── __init__.py ├── bert │ ├── __init__.py │ ├── builder.py │ ├── tokenization_bert.py │ └── xbert.py ├── blip2 │ ├── Qformer.py │ ├── __init__.py │ ├── blip2.py │ ├── builder.py │ ├── modeling_llama.py │ ├── modeling_llama_mem.py │ ├── utils.py │ └── vit.py ├── criterions.py ├── utils.py ├── videochat2_qformer.py └── videochat_mistra │ ├── __init__.py │ ├── videochat2_it4_mistral_LinearP.py │ └── videochat2_it4_mistral_LinearProAda.py ├── prompts ├── alignment_image.txt ├── concise_description.txt ├── concise_image_description.txt ├── dvc_description.txt ├── dvc_description_with_asr.txt ├── dvc_description_zeroshot.txt ├── dvc_post_check.txt ├── tvg_description.txt ├── tvg_description_zeroshot.txt ├── tvg_post_check.txt ├── vhd_description.txt ├── vhd_description_zeroshot.txt ├── vhd_description_zeroshot_new.txt ├── vhd_description_zeroshot_post.txt └── vhd_post_check.txt ├── requirements.txt ├── scripts └── videochat_mistral │ ├── config_LinearP.py │ ├── config_LinearProAda.py │ ├── config_LinearProAdaFT.py │ ├── run_7b_stage4.sh │ └── run_7b_stage4_ds.sh ├── tasks ├── retrieval_utils.py ├── shared_utils.py ├── shared_utils_ds.py ├── shared_utils_qformer.py ├── train_it.py ├── train_it4.py ├── train_it4_bug.py ├── train_it4_ds.py ├── train_it4_ds_new.py ├── train_it_ds.py ├── train_pt.py ├── train_pt_ds.py └── train_qformer.py ├── torchrun.sh └── utils ├── basic_utils.py ├── config.py ├── config_utils.py ├── distributed.py ├── easydict.py ├── logger.py ├── optimizer.py ├── peft.py ├── quant.py ├── quantization.py └── scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # local # 2 | tmp*/ 3 | cache/* 4 | */cache*/ 5 | tmp*.py 6 | tmp* 7 | *pickle 8 | data/ 9 | 10 | # Zip Files/Packages # 11 | *.7z 12 | *.dmg 13 | *.gz 14 | *.iso 15 | *.jar 16 | *.rar 17 | *.tar 18 | *.zip 19 | 20 | # Logs and databases # 21 | *.log 22 | *.sql 23 | *.sqlite 24 | .ipynb_checkpoints/ 25 | *.swp 26 | *.vscode/ 27 | *.idea/ 28 | *.pyc 29 | __pycache__ 30 | slurm*out 31 | 32 | # OS files # 33 | .DS_Store 34 | .DS_Store? 
35 | ._* 36 | .Spotlight-V100 37 | .Trashes 38 | ehthumbs.db 39 | Thumbs.db 40 | 41 | 42 | .vim-arsync 43 | scratch.norg 44 | sync_to_red.sh 45 | 46 | anno/ 47 | wandb/ 48 | logs/ 49 | *.pth 50 | 51 | # personal 52 | test.ipynb 53 | 54 | jupyter/ 55 | 56 | phoenix-slurm* 57 | batchscript-* 58 | 59 | debug* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 OpenGVLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | [ICLR 2025] TimeSuite: Improving MLLMs for Long Video Understanding via Grounded Tuning
4 | 5 | [Xiangyu Zeng](https://scholar.google.com/citations?user=jS13DXkAAAAJ&hl=zh-CN), [Kunchang Li](https://scholar.google.com/citations?user=D4tLSbsAAAAJ), [Chenting Wang](https://scholar.google.com/citations?user=f81ulHQAAAAJ&hl=zh-CN), [Xinhao Li](https://scholar.google.com/citations?user=evR3uR0AAAAJ&hl=zh-CN), Tianxiang Jiang, [Ziang Yan](https://scholar.google.com/citations?user=78lx13MAAAAJ&hl=zh-CN), [Songze Li](https://scholar.google.com/citations?user=8rBMUD4AAAAJ&hl=zh-CN), Yansong Shi, Zhengrong Yue, [Yi Wang](https://scholar.google.com.hk/citations?hl=zh-CN&user=Xm2M8UwAAAAJ), [Yali Wang](https://scholar.google.com/citations?user=hD948dkAAAAJ), [Yu Qiao](https://scholar.google.com/citations?user=gFtI-8QAAAAJ&hl), and [Limin Wang](https://scholar.google.com/citations?user=HEuN8PcAAAAJ) 6 | 7 |
8 | 9 | ## :parrot: Introduction 10 | 11 | This paper proposes TimeSuite, a collection of new designs that adapt existing short-form video MLLMs to long video understanding, including a simple yet efficient framework for processing long video sequences, a high-quality video dataset for grounded tuning of MLLMs, and a carefully designed instruction-tuning task that explicitly incorporates grounding supervision into the traditional QA format. 12 | 13 | **State-of-the-art performance**: VideoChat-T demonstrates strong performance on both long-form video question answering and temporal grounding. 14 | ![alt text](images/abstract.png) 15 | 16 | **Highly efficient model architecture** with exceptional inference speed: each video frame is encoded into just **3 tokens**, so the FLOPs of VideoChat-T are only 5.1% of those of LLaVA-OneVision. 17 | ![alt text](images/structure.png) 18 | 19 | **High-quality data** 20 | - We introduce the comprehensive dataset TimePro, which covers 9 task types with video sources from 15 different datasets. 21 | - We design a novel Temporal Grounded Caption fine-tuning task to effectively mitigate hallucinations in MLLMs. 22 | ![alt text](images/data.png) 23 | 24 | ## :fire: Updates 25 | 26 | - 2025.02.12 TimeSuite is open-sourced. We welcome everyone to try it out! 27 | - 2025.01.23 TimeSuite has been accepted by ICLR 2025. 28 | - 2024.10.25 The TimeSuite paper has been uploaded to arXiv. 29 | 30 | ## Preparation 31 | 32 | - Create a new environment and run the following commands to install the necessary dependencies. 33 | 34 | ``` 35 | conda create --name TimeSuite 36 | conda activate TimeSuite 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | - Download the model and code of TimeSuite from [https://huggingface.co/Lanxingxuan/TimeSuite](https://huggingface.co/Lanxingxuan/TimeSuite) to the `./download` folder. (Please note that you also need to download [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) to `./download/parameters`.) 41 | 42 | - Search for all instances of `/path_to_the_timesuite_root_folder` and replace them with the directory of the TimeSuite root folder. 43 | 44 | - Search for all video dataset paths containing `s3://` and replace them with the corresponding video dataset paths on your server. 45 | 46 | ## Inference & Demo 47 | 48 | - Run `demo/demo.ipynb` to see the demo provided in the paper, or try out videos and questions of your choice. 49 | 50 | - Run `eval/eval_qa_tasks.ipynb` to test the general QA performance of the model. 51 | 52 | - To test the temporal grounding capability of TimeSuite, please follow these two steps: 53 | 54 | ``` 55 | bash eval/test_grounding.sh 56 | bash eval/get_grounding_result.sh 57 | ``` 58 | 59 | ## Grounded Tuning 60 | 61 | - Please configure the video dataset paths in `configs/instruction_data.py`. 62 | - Modify `scripts/videochat_mistral/config_LinearP.py` and `scripts/videochat_mistral/config_LinearProAda.py` to adjust the model training parameter settings. 63 | - Please run `bash scripts/videochat_mistral/run_7b_stage4.sh` to start fine-tuning the model. 64 | - To reproduce the fine-tuning results reported in the paper, you need to run the training in two stages. For detailed parameter settings, please refer to Appendix D of the paper. 65 | 66 | 67 | ## TimePro Dataset 68 | 69 | ### Annotations 70 | 71 | - All data used for fine-tuning is now open-sourced. Please visit [https://huggingface.co/Lanxingxuan/TimeSuite/tree/main/datasets/TimePro](https://huggingface.co/Lanxingxuan/TimeSuite/tree/main/datasets/TimePro) to download.
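If you prefer to script this step, the snippet below is a minimal sketch (not an official utility of this repo) that fetches only the TimePro annotation JSONs with `huggingface_hub`; it assumes the annotations sit under `datasets/TimePro/` in the Hugging Face repo, matching the `anno_root_it` path that `configs/instruction_data.py` expects under `download/`.

```python
# Hedged sketch: pull only the TimePro annotation JSONs into ./download,
# so they end up at ./download/datasets/TimePro/*.json as instruction_data.py assumes.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Lanxingxuan/TimeSuite",
    allow_patterns=["datasets/TimePro/*.json"],  # annotation files only; skip weights and code
    local_dir="./download",
)
```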
72 | 73 | ### Videos 74 | 75 | **_TimePro_** 76 | - DiDeMo: [https://github.com/LisaAnne/LocalizingMoments?tab=readme-ov-file#dataset](https://github.com/LisaAnne/LocalizingMoments?tab=readme-ov-file#dataset) 77 | - QuerYD: [https://www.robots.ox.ac.uk/~vgg/data/queryd/](https://www.robots.ox.ac.uk/~vgg/data/queryd/) 78 | - HiREST: [https://github.com/j-min/HiREST](https://github.com/j-min/HiREST) 79 | - ActivityNet: [http://activity-net.org/download.html](http://activity-net.org/download.html) 80 | - ViTT: [https://github.com/google-research-datasets/Video-Timeline-Tags-ViTT](https://github.com/google-research-datasets/Video-Timeline-Tags-ViTT) 81 | - YouCook2: [http://youcook2.eecs.umich.edu/download](http://youcook2.eecs.umich.edu/download) 82 | - TVSum: [https://github.com/yalesong/tvsum](https://github.com/yalesong/tvsum) 83 | - SumMe: [http://classif.ai/dataset/ethz-cvl-video-summe/](http://classif.ai/dataset/ethz-cvl-video-summe/) 84 | - COIN: [https://github.com/coin-dataset/annotations](https://github.com/coin-dataset/annotations) 85 | - YT-Temporal: [https://rowanzellers.com/merlot/#data](https://rowanzellers.com/merlot/#data) 86 | - Internvid: [https://github.com/OpenGVLab/InternVideo/blob/main/Data/InternVid/README_CN.md](https://github.com/OpenGVLab/InternVideo/blob/main/Data/InternVid/README_CN.md) 87 | - HowTo100M(CosMo): [https://www.di.ens.fr/willow/research/howto100m/](https://www.di.ens.fr/willow/research/howto100m/) 88 | 89 | **_Normal_** 90 | - VideoChatGPT: [https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data) 91 | - VideoChat: [https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data](https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data) 92 | - EgoQA: [https://ego4d-data.org/](https://ego4d-data.org/) 93 | - STAR: [https://bobbywu.com/STAR/](https://bobbywu.com/STAR/) 94 | - MovieChat: [https://huggingface.co/datasets/Enxin/MovieChat-1K_train](https://huggingface.co/datasets/Enxin/MovieChat-1K_train) 95 | 96 | **_FT_** 97 | - Charades-STA: [https://github.com/jiyanggao/TALL#charades-sta-anno-download](https://github.com/jiyanggao/TALL#charades-sta-anno-download) 98 | - QVHighlight: [https://github.com/jayleicn/moment_detr/blob/main/data/README.md](https://github.com/jayleicn/moment_detr/blob/main/data/README.md) 99 | 100 | # :page_facing_up: Citation 101 | 102 | If you find this project useful in your research, please consider citing it: 103 | ```BibTeX 104 | @misc{zeng2024timesuite, 105 | title={TimeSuite: Improving MLLMs for Long Video Understanding via Grounded Tuning}, 106 | author={Xiangyu Zeng and Kunchang Li and Chenting Wang and Xinhao Li and Tianxiang Jiang and Ziang Yan and Songze Li and Yansong Shi and Zhengrong Yue and Yi Wang and Yali Wang and Yu Qiao and Limin Wang}, 107 | year={2024}, 108 | eprint={2410.19702}, 109 | archivePrefix={arXiv}, 110 | primaryClass={cs.CV}, 111 | url={https://arxiv.org/abs/2410.19702}, 112 | } 113 | ``` 114 | 115 | # :dizzy: Acknowledgement 116 | 117 | Thanks to the following open-source projects: 118 | 119 | - [UMT](https://github.com/OpenGVLab/unmasked_teacher) 120 | - [MVBench](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) 121 | - [TimeChat](https://github.com/RenShuhuai-Andy/TimeChat) 122 | -
[LITA](https://github.com/NVlabs/LITA) 123 | -------------------------------------------------------------------------------- /configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "model_cls": "VideoChat2_it", 4 | "vit_blip_model_path": "your_model_path/umt_l16_qformer.pth", 5 | "llama_model_path": "your_model_path/vicuna-7b-v0", 6 | "videochat2_model_path": "your_model_path/videochat2_7b_stage2.pth", 7 | "freeze_vit": false, 8 | "freeze_qformer": false, 9 | "max_txt_len": 512, 10 | "low_resource": false, 11 | "vision_encoder": { 12 | "name": "vit_l14", 13 | "img_size": 224, 14 | "patch_size": 16, 15 | "d_model": 1024, 16 | "encoder_embed_dim": 1024, 17 | "encoder_depth": 24, 18 | "encoder_num_heads": 16, 19 | "drop_path_rate": 0.0, 20 | "num_frames": 8, 21 | "tubelet_size": 1, 22 | "use_checkpoint": false, 23 | "checkpoint_num": 0, 24 | "pretrained": "", 25 | "return_index": -2, 26 | "vit_add_ln": true, 27 | "ckpt_num_frame": 4 28 | }, 29 | "num_query_token": 32, 30 | "qformer_hidden_dropout_prob": 0.1, 31 | "qformer_attention_probs_dropout_prob": 0.1, 32 | "qformer_drop_path_rate": 0.2, 33 | "extra_num_query_token": 64, 34 | "qformer_text_input": true, 35 | "system": "", 36 | "start_token": "", 38 | "img_start_token": "", 39 | "img_end_token": "", 40 | "random_shuffle": true, 41 | "use_lora": false, 42 | "lora_r": 16, 43 | "lora_alpha": 32, 44 | "lora_dropout": 0.1 45 | }, 46 | "device": "cuda" 47 | } 48 | -------------------------------------------------------------------------------- /configs/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 9, 20 | "encoder_width": 768, 21 | "cross_module": "ca" 22 | } 23 | -------------------------------------------------------------------------------- /configs/config_mistral.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "model_cls": "VideoChat2_it_mistral", 4 | "vit_blip_model_path": "/path_to_the_timesuite_root_folder/download/parameters/umt_l16_qformer.pth", 5 | "mistral_model_path": "/path_to_the_timesuite_root_folder/download/parameters/Mistral-7B-Instruct-v0.2", 6 | "videochat2_model_path": "/path_to_the_timesuite_root_folder/download/parameters/videochat2_mistral_7b_stage2.pth", 7 | "freeze_vit": false, 8 | "freeze_qformer": false, 9 | "max_txt_len": 512, 10 | "low_resource": false, 11 | "vision_encoder": { 12 | "name": "vit_l14", 13 | "img_size": 224, 14 | "patch_size": 16, 15 | "d_model": 1024, 16 | "encoder_embed_dim": 1024, 17 | "encoder_depth": 24, 18 | "encoder_num_heads": 16, 19 | "drop_path_rate": 0.0, 20 | "num_frames": 8, 21 | "tubelet_size": 1, 22 | "use_checkpoint": true, 23 | "checkpoint_num": 18, 24 | "pretrained": "", 25 | "return_index": -2, 26 | "vit_add_ln": true, 27 | "ckpt_num_frame": 4 28 | }, 29 | "num_query_token": 32, 30 | "qformer_hidden_dropout_prob": 0.1, 31 | "qformer_attention_probs_dropout_prob": 0.1, 32 | "qformer_drop_path_rate": 0.2, 33 
| "extra_num_query_token": 64, 34 | "qformer_text_input": true, 35 | "system": "", 36 | "start_token": "", 38 | "add_second_msg": true, 39 | "img_start_token": "", 40 | "img_end_token": "", 41 | "random_shuffle": true, 42 | "return_question_instruction": false, 43 | "use_flash_attention": true, 44 | "use_lora": false, 45 | "lora_r": 16, 46 | "lora_alpha": 32, 47 | "lora_dropout": 0.1 48 | }, 49 | "device": "cuda" 50 | } 51 | -------------------------------------------------------------------------------- /configs/instruction_data.py: -------------------------------------------------------------------------------- 1 | import os as __os # add "__" if not want to be exported 2 | from copy import deepcopy as __deepcopy 3 | 4 | anno_root_it = "/path_to_the_timesuite_root_folder/download/datasets/TimePro" 5 | 6 | 7 | # ============== pretraining datasets================= 8 | available_corpus = dict( 9 | 10 | caption_youcook2=[ 11 | f"{anno_root_it}/caption_youcook2.json", 12 | "pnorm2:s3://youcook2/split_videos", 13 | "video" 14 | ], 15 | conversation_videochat1=[ 16 | f"{anno_root_it}/conversation_videochat1.json", 17 | "pnorm2:s3://webvid10m", 18 | "video" 19 | ], 20 | conversation_videochat2=[ 21 | f"{anno_root_it}/conversation_videochat2.json", 22 | "pnorm:s3://videointernsegvideos", 23 | "video" 24 | ], 25 | conversation_videochatgpt=[ 26 | f"{anno_root_it}/conversation_videochatgpt.json", 27 | "pnorm2:s3://anet/ANet_320p_fps30", 28 | "video" 29 | ], 30 | reasoning_star=[ 31 | f"{anno_root_it}/reasoning_star.json", 32 | "pnorm2:s3://star/Charades_v1_480", 33 | "video" 34 | ], 35 | vqa_ego_qa=[ 36 | f"{anno_root_it}/vqa_ego_qa.json", 37 | "pnorm2:s3://egoqa/split_videos", 38 | "video" 39 | ], 40 | 41 | 42 | 43 | 44 | # TimeIT 45 | timeit_ANet=[ 46 | f"{anno_root_it}/timeit_ANet.json", 47 | "pnorm2:s3://anet", 48 | "video" 49 | ], 50 | 51 | timeit_COIN=[ 52 | f"{anno_root_it}/timeit_COIN.json", 53 | "pnorm:s3://COIN_320p", 54 | "video" 55 | ], 56 | 57 | timeit_DiDeMo=[ 58 | f"{anno_root_it}/timeit_DiDeMo.json", 59 | "sssd:s3://yjsBucket", 60 | "video" 61 | ], 62 | 63 | timeit_HiREST=[ 64 | f"{anno_root_it}/timeit_HiREST.json", 65 | "pnorm2zxy:s3://hirest", 66 | "video" 67 | ], 68 | 69 | 70 | timeit_QuerYD=[ 71 | f"{anno_root_it}/timeit_QuerYD.json", 72 | "pnorm2zxy:s3://queryd", 73 | "video" 74 | ], 75 | 76 | timeit_TVSum=[ 77 | f"{anno_root_it}/timeit_TVSum.json", 78 | "pnorm2zxy:s3://tvsum", 79 | "video" 80 | ], 81 | 82 | timeit_ViTT=[ 83 | f"{anno_root_it}/timeit_ViTT.json", 84 | "sssd:s3://ViTT", 85 | "video" 86 | ], 87 | 88 | timeit_yttemporal180m=[ 89 | f"{anno_root_it}/timeit_yttemporal180m.json", 90 | "pnorm:s3://YT-Temporal-180M", 91 | "video" 92 | ], 93 | 94 | grounding_ANetRTL=[ 95 | f"{anno_root_it}/grounding_ANetRTL.json", 96 | "pnorm2:s3://anet/ANet_320p_fps30/train", 97 | "video" 98 | ], 99 | 100 | grounding_IntrenvidVTime_100K=[ 101 | f"{anno_root_it}/grounding_IntrenvidVTime_100K.json", 102 | "pnorm:s3://youtubeBucket/videos/", 103 | "video" 104 | ], 105 | grounding_ANetHL2=[ 106 | f"{anno_root_it}/grounding_ANetHL2.json", 107 | "pnorm2:s3://anet/ANet_320p_fps30/train", 108 | "video" 109 | ], 110 | 111 | grounding_CosmoCap_93K=[ 112 | f"{anno_root_it}/grounding_CosmoCap_93K.json", 113 | "pvideo:s3://howto100m/", 114 | "video" 115 | ], 116 | vqa_moviechat = [ 117 | f'{anno_root_it}/vqa_moviechat.json', 118 | 'pnorm2:s3://MovieChat/real_video/', 119 | 'video' 120 | ], 121 | caption_moviechat = [ 122 | f'{anno_root_it}/caption_moviechat.json', 123 | 
'pnorm2:s3://MovieChat/real_video/', 124 | 'video' 125 | ], 126 | 127 | 128 | FT_Charades=[ 129 | f"{anno_root_it}/FT_Charades.json", 130 | "s3://zengxiangyu/Charades/", 131 | "video" 132 | ], 133 | 134 | FT_QVHighlights=[ 135 | f"{anno_root_it}/FT_QVHighlights.json", 136 | "s3://QVHighlight/videos/", 137 | "video" 138 | ], 139 | 140 | ) 141 | 142 | 143 | available_corpus["TimePro_Normal"] = [ #final dataset 144 | #TiIT 145 | available_corpus["timeit_ANet"], 146 | available_corpus["timeit_COIN"], 147 | available_corpus["timeit_DiDeMo"], 148 | available_corpus["timeit_HiREST"], 149 | available_corpus["timeit_QuerYD"], 150 | available_corpus["timeit_TVSum"], 151 | available_corpus["timeit_ViTT"], 152 | available_corpus["timeit_yttemporal180m"], 153 | #Conv 154 | available_corpus["conversation_videochatgpt"], 155 | available_corpus["conversation_videochat2"], 156 | available_corpus["conversation_videochat1"], 157 | #DvcVqa 158 | available_corpus["caption_youcook2"], 159 | available_corpus["vqa_ego_qa"], 160 | #Gro 161 | available_corpus["grounding_ANetRTL"], 162 | available_corpus["grounding_IntrenvidVTime_100K"], 163 | available_corpus["grounding_ANetHL2"], 164 | available_corpus["grounding_CosmoCap_93K"], 165 | available_corpus["vqa_moviechat"], 166 | available_corpus["caption_moviechat"], 167 | available_corpus["reasoning_star"], 168 | ] 169 | 170 | 171 | 172 | available_corpus["FT_Temporal_Grounding_Both"] = [ 173 | available_corpus["FT_Charades"], 174 | available_corpus["FT_QVHighlights"], 175 | available_corpus["grounding_ANetHL2"], 176 | available_corpus["caption_youcook2"], 177 | ] -------------------------------------------------------------------------------- /configs/model.py: -------------------------------------------------------------------------------- 1 | TextEncoders = dict() 2 | TextEncoders["bert"] = dict( 3 | name="bert_base", 4 | pretrained="bert-base-uncased", 5 | config="configs/config_bert.json", 6 | d_model=768, 7 | fusion_layer=9, 8 | ) -------------------------------------------------------------------------------- /dataset/TimeIT/dense_video_captioning/anet/instruct_dvc_10.0k_anet.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/36470518c0a555bbc7e7ae0b30393441ec533e03 -------------------------------------------------------------------------------- /dataset/TimeIT/dense_video_captioning/anet/instruct_dvc_10.0k_anet_15asr.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/2d6aad3236b910b1877aa8058dd0be19b3f333b7cefebd1f6c852880d13a6dc3 -------------------------------------------------------------------------------- /dataset/TimeIT/dense_video_captioning/dense_video_captioning_instructions.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences. The output format of each predicted event should be like: 'start - end seconds, event description'. 
A specific example is: ' 90 - 102 seconds, spread margarine on two slices of white bread in the video'.", 3 | "1": "Determine the start and end times of various activity events in the video, accompanied by descriptions.", 4 | "2": "Capture and describe the activity events in the given video, specifying their respective time intervals, and outputting the time intervals in the 'start - end seconds format'.", 5 | "3": "Identify, timestamp, and describe various activity events occurring in the video. The timestamp should include the start time and end time in seconds.", 6 | "4": "Detect and report the start and end timestamps of activity events in the video, along with descriptions.", 7 | "5": "Pinpoint the time intervals of activity events in the video, and provide detailed descriptions for each event." 8 | } -------------------------------------------------------------------------------- /dataset/TimeIT/dense_video_captioning/vitt/instruct_dvc_5.1k_vitt_15asr.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/1dc9787ee6fa38f8c3223b14eb10da0efdfa1c17ef9f0dea77fafd5425a5c5dc -------------------------------------------------------------------------------- /dataset/TimeIT/dense_video_captioning/youcook2/train.caption_coco_format.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/d257625e0c84e7a4c6dcbbed5054e0fc3053e6c6 -------------------------------------------------------------------------------- /dataset/TimeIT/step_localization/step_localization_instructions.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "Localize a series of action steps in the given video, output a start and end timestamp for each step, and briefly describe the step. ", 3 | "1": "Locate and describe a series of actions or steps in the video, including their start and end timestamps.", 4 | "2": "Identify and mark the video segments corresponding to a series of actions or steps, specifying the timestamps and describing the steps.", 5 | "3": "Find, identify, and determine the temporal boundaries of a series of distinct actions or steps occurring throughout the video. For each action, output the corresponding start and end timestamps, accompanied by a concise description.", 6 | "4": "Identify and localize a series of steps or actions occurring in the video, providing start and end timestamps and related descriptions.", 7 | "5": "Locate and pinpoint a sequential series of specific actions or steps in the video, accurately specifying the start and end timestamps for each action. Additionally, provide a succinct description of each action." 
8 | } -------------------------------------------------------------------------------- /dataset/TimeIT/temporal_video_grounding/charades/charades_annotation/get_coco_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | from copy import deepcopy 5 | import pdb 6 | import numpy as np 7 | import random 8 | from pathlib import Path 9 | 10 | 11 | # read json files 12 | def read_json(path): 13 | with open(path, "r") as fin: 14 | datas = json.load(fin) 15 | annos = datas["annotations"] 16 | return annos 17 | 18 | 19 | def read_jsonl(path): 20 | anno = [] 21 | with open(path, "r") as fin: 22 | datas = fin.readlines() 23 | for data in datas: 24 | anno.append(json.loads(data.strip())) 25 | return anno 26 | 27 | 28 | 29 | def write_json(data, path): 30 | with open(path, "w") as fout: 31 | json.dump(data, fout) 32 | return 33 | 34 | 35 | def read_txt(path): 36 | data = [] 37 | with open(path, "r") as fin: 38 | lines = fin.readlines() 39 | for i, line in enumerate(lines): 40 | # e.g. AO8RW 0.0 6.9##a person is putting a book on a shelf. 41 | line = line.strip("\n") 42 | cap = line.split("##")[-1] 43 | if len(cap) < 2: 44 | continue 45 | terms = line.split("##")[0].split(" ") 46 | vid = terms[0] + ".mp4" 47 | start_time = float(terms[1]) 48 | end_time = float(terms[2]) 49 | data.append({"image_id": vid, "caption": cap, "timestamp": [start_time, end_time], "id": i}) 50 | return data 51 | 52 | 53 | def filter_sent(sent): 54 | sent = sent.strip(" ") 55 | if len(sent) < 2: 56 | return False 57 | sent = sent.replace("#", "") 58 | return sent 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--dataset', default='charades') # anet 64 | parser.add_argument('--anno_path', default='/home/yaolinli/dataset/Charades/charades_annotation/') 65 | parser.add_argument('--video_path', default='/home/yaolinli/dataset/Charades/videos') # ActivityNet_asr_denseCap/anet_6fps_224 66 | parser.add_argument('--outpath', default='./') 67 | args = parser.parse_args() 68 | '''output data example: 69 | { 70 | "annotations": [ 71 | { 72 | "image_id": "3MSZA.mp4", 73 | "caption": "person turn a light on.", 74 | "timestamp": [24.3, 30.4], 75 | }], 76 | } 77 | ''' 78 | 79 | for split in ["train", "test"]: # "val", "test" 80 | if args.dataset == "charades": 81 | filename = f"charades_sta_{split}.txt" 82 | annos = read_txt(os.path.join(args.anno_path, filename)) 83 | data = {} 84 | data["annotations"] = annos 85 | 86 | else: 87 | print("Do not support this dataset!") 88 | exit(0) 89 | 90 | print(f"==> {args.dataset} dataset \t# examples num: {len(annos)}") 91 | out_name = "{}.caption_coco_format.json".format(split) 92 | Path(args.outpath).mkdir(parents=True, exist_ok=True) 93 | write_json(data, os.path.join(args.outpath, out_name)) 94 | -------------------------------------------------------------------------------- /dataset/TimeIT/temporal_video_grounding/didemo/instruct_tvg_33.0k_didemo.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/2f0961ca863a66444d1149b608439787ba413ddf -------------------------------------------------------------------------------- /dataset/TimeIT/temporal_video_grounding/didemo/instruct_tvg_33.0k_didemo_15asr.json: -------------------------------------------------------------------------------- 1 | 
../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/a7c74718716319c268c762e5db50be0289d8e679467db97e2f8460b2aacd6677 -------------------------------------------------------------------------------- /dataset/TimeIT/temporal_video_grounding/queryd/instruct_tvg_14.6k_queryd.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/77267f35d35c1c44a512e024114b5b6c8b88ac40 -------------------------------------------------------------------------------- /dataset/TimeIT/temporal_video_grounding/queryd/instruct_tvg_14.6k_queryd_15asr.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/c4e98592e6d5e6d604c66ef1923f5385f0e080978fc1945c019ae28cf02bf342 -------------------------------------------------------------------------------- /dataset/TimeIT/temporal_video_grounding/temporal_video_grounding_instructions.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "Localize the visual content described by the given textual query in the video, and output the start and end timestamps in seconds.", 3 | "1": "Detect and report the start and end timestamps of the video segment that semantically matches the given textual query .", 4 | "2": "Give you a textual query: When does the described content occur in the video? Please return the timestamp in seconds.", 5 | "3": "Locate and describe the visual content mentioned in the text query within the video, including timestamps.", 6 | "4": "The given natural language query is semantically aligned with a video moment, please give the start time and end time of the video moment.", 7 | "5": "Find the video segment that corresponds to the given textual query and determine its start and end seconds." 8 | } -------------------------------------------------------------------------------- /dataset/TimeIT/time/instruct_time-sensitive_104k.json: -------------------------------------------------------------------------------- 1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/f343a7a774a15b3b9cf05d3bca346d75397246c50d50367a0f19cd90c55460ec -------------------------------------------------------------------------------- /dataset/TimeIT/time/instruct_time-sensitive_104k_asr.json: -------------------------------------------------------------------------------- 1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/01c33679dc0b6f8fda5040441b636765c374709d0103f3c9dc285f47f8d1feff -------------------------------------------------------------------------------- /dataset/TimeIT/transcribed_speech_generation/transcribed_speech_generation_instructions.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "After watching the video from the YTTemporal dataset, transcribe the spoken content into text and document the start and end time for each segment. The format should be: 'start time - end time, transcribed speech'.", 3 | "1": "Observe the video thoroughly and transcribe the speech in a maximum of 20 segments. Make sure to include the starting and ending times for each segment in the following format: 'start time - end time, transcribed speech'.", 4 | "2": "Watch the provided video and transcribe the audio content. 
For each transcribed speech segment, note down its duration in the format: 'start time - end time, transcribed speech'.", 5 | "3": "Review the video from the YTTemporal dataset. Identify segments where speech occurs and transcribe those into text. Record the start and end time for each segment in this format: 'start time - end time, transcribed speech'.", 6 | "4": "Transcribe the spoken words in the video and note down the timestamps for each segment. Your output should look like this: 'start time - end time, transcribed speech'.", 7 | "5": "Watch the video, transcribe the speech, and indicate when each segment starts and ends. Follow this format: 'start time - end time, transcribed speech'." 8 | } -------------------------------------------------------------------------------- /dataset/TimeIT/transcribed_speech_generation/yttemporal/instruct_tsg_31.6k_yttemporal.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/9156641dad2e4c413db3568ae1b409f65ad2c22392de17396955922536a2184a -------------------------------------------------------------------------------- /dataset/TimeIT/transcribed_speech_generation/yttemporal/instruct_tsg_31.6k_yttemporal_15asr.json: -------------------------------------------------------------------------------- 1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/a33f3d1606cde5cc3da373e3737f7ba036697aa196d2b85b185ba02d41507c78 -------------------------------------------------------------------------------- /dataset/TimeIT/valley/Valley_instruct_73k.json: -------------------------------------------------------------------------------- 1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/50425f539707f196b66957997d197da33da969ea5054f4884eae7e8ee17490a0 -------------------------------------------------------------------------------- /dataset/TimeIT/valley/instruct_valley_72k.json: -------------------------------------------------------------------------------- 1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/4370bbe8cb6fd279898e2d7fd6eb0c76e34bb6bad80a631407edaf042bbf09c1 -------------------------------------------------------------------------------- /dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw/README.md: -------------------------------------------------------------------------------- 1 | ## QVHighlights Dataset 2 | 3 | Our annotation files include 3 splits: `train`, `val` and `test`. Each file is in [JSON Line](https://jsonlines.org/) format, each row of the files can be loaded as a single `dict` in Python. Below is an example of the annotation: 4 | 5 | ``` 6 | { 7 | "qid": 8737, 8 | "query": "A family is playing basketball together on a green court outside.", 9 | "duration": 126, 10 | "vid": "bP5KfdFJzC4_660.0_810.0", 11 | "relevant_windows": [[0, 16]], 12 | "relevant_clip_ids": [0, 1, 2, 3, 4, 5, 6, 7], 13 | "saliency_scores": [[4, 1, 1], [4, 1, 1], [4, 2, 1], [4, 3, 2], [4, 3, 2], [4, 3, 3], [4, 3, 3], [4, 3, 2]] 14 | } 15 | ``` 16 | `qid` is a unique identifier of a `query`. This query corresponds to a video identified by its video id `vid`. The `vid` is formatted as `{youtube_id}_{start_time}_{end_time}`. Use this information, one can retrieve the YouTube video from a url `https://www.youtube.com/embed/{youtube_id}?start={start_time}&end={end_time}&version=3`. 
For example, the video in this example is `https://www.youtube.com/embed/bP5KfdFJzC4?start=660&end=810&version=3`. 17 | `duration` is an integer indicating the duration of this video. 18 | `relevant_windows` is the list of windows that localize the moments, each window has two numbers, one indicates the start time of the moment, another one indicates the end time. `relevant_clip_ids` is the list of ids to the segmented 2-second clips that fall into the moments specified by `relevant_windows`, starting from 0. 19 | `saliency_scores` contains the saliency scores annotations, each sublist corresponds to a clip in `relevant_clip_ids`. There are 3 elements in each sublist, they are the scores from three different annotators. A score of `4` means `Very Good`, while `0` means `Very Bad`. 20 | 21 | Note that the three fields `relevant_clip_ids`, `relevant_windows` and `saliency_scores` for `test` split is not included. Please refer to [../standalone_eval/README.md](../standalone_eval/README.md) for details on evaluating predictions on `test`. 22 | 23 | In addition to the annotation files, we also provided the subtitle file for our weakly supervised ASR pre-training: [subs_train.jsonl](./subs_train.jsonl). This file is formatted similarly as our annotation files, but without the `saliency_scores` entry. This file is not needed if you do not plan to pretrain models using it. 24 | 25 | -------------------------------------------------------------------------------- /dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw/subs_train.jsonl: -------------------------------------------------------------------------------- 1 | ../../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/33086f20181724a477d7da5b2a063e7935d04d886548917b8ba9912f5a0c7dc5 -------------------------------------------------------------------------------- /dataset/TimeIT/video_highlight_detection/qvhighlights/get_coco_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | from copy import deepcopy 5 | import pdb 6 | import numpy as np 7 | import random 8 | from pathlib import Path 9 | from collections import Counter 10 | 11 | # read json files 12 | def read_json(path): 13 | with open(path, "r") as fin: 14 | datas = json.load(fin) 15 | annos = datas["annotations"] 16 | return annos 17 | 18 | 19 | def read_jsonl(path): 20 | anno = [] 21 | with open(path, "r") as fin: 22 | datas = fin.readlines() 23 | for data in datas: 24 | anno.append(json.loads(data.strip())) 25 | return anno 26 | 27 | 28 | 29 | def write_json(data, path): 30 | with open(path, "w") as fout: 31 | json.dump(data, fout) 32 | return 33 | 34 | 35 | def read_txt(path): 36 | data = [] 37 | with open(path, "r") as fin: 38 | lines = fin.readlines() 39 | for i, line in enumerate(lines): 40 | # e.g. AO8RW 0.0 6.9##a person is putting a book on a shelf. 
41 | line = line.strip("\n") 42 | cap = line.split("##")[-1] 43 | if len(cap) < 2: 44 | continue 45 | terms = line.split("##")[0].split(" ") 46 | vid = terms[0] + ".mp4" 47 | start_time = float(terms[1]) 48 | end_time = float(terms[2]) 49 | data.append({"image_id": vid, "caption": cap, "timestamp": [start_time, end_time], "id": i}) 50 | return data 51 | 52 | 53 | def filter_sent(sent): 54 | sent = sent.strip(" ") 55 | if len(sent) < 2: 56 | return False 57 | sent = sent.replace("#", "") 58 | return sent 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--dataset', default='qvhighlights') # anet 64 | parser.add_argument('--anno_path', default='annotations_raw/') 65 | parser.add_argument('--video_path', default='videos/') # ActivityNet_asr_denseCap/anet_6fps_224 66 | parser.add_argument('--outpath', default='./') 67 | args = parser.parse_args() 68 | '''output data example: 69 | { 70 | "annotations": [ 71 | { 72 | "image_id": "3MSZA.mp4", 73 | "caption": "person turn a light on.", 74 | "timestamp": [24.3, 30.4], 75 | }], 76 | } 77 | ''' 78 | miss_videos = [] 79 | num_clips = [] 80 | for split in ["train", "val"]: # "val", "test" 81 | if args.dataset == "charades": 82 | filename = f"charades_sta_{split}.txt" 83 | annos = read_txt(os.path.join(args.anno_path, filename)) 84 | data = {} 85 | data["annotations"] = annos 86 | elif args.dataset == "qvhighlights": 87 | filename = f"highlight_{split}_release.jsonl" 88 | annos = read_jsonl(os.path.join(args.anno_path, filename)) 89 | new_data = [] 90 | for jterm in annos: 91 | new_term = {} 92 | new_term["image_id"] = "v_" + jterm["vid"] + ".mp4" 93 | # check the existance of the video 94 | if not os.path.exists(os.path.join(args.video_path, split, new_term["image_id"])): 95 | miss_videos.append(new_term["image_id"]) 96 | continue 97 | new_term["id"] = jterm["qid"] 98 | new_term["caption"] = jterm["query"] 99 | new_term["timestamp"] = jterm["relevant_windows"] 100 | new_term["duration"] = jterm["duration"] 101 | new_term["relevant_clip_ids"] = jterm["relevant_clip_ids"] 102 | new_term["saliency_scores"] = jterm["saliency_scores"] 103 | new_data.append(new_term) 104 | num_clips.append(int(jterm["duration"]/2)) 105 | data = {} 106 | data["annotations"] = new_data 107 | else: 108 | print("Do not support this dataset!") 109 | exit(0) 110 | 111 | print(f"==> {args.dataset} dataset \t# examples num: {len(new_data)} \t# miss videos num: {len(miss_videos)}\t# raw data num: {len(annos)}") 112 | out_name = "{}.caption_coco_format.json".format(split) 113 | Path(args.outpath).mkdir(parents=True, exist_ok=True) 114 | write_json(data, os.path.join(args.outpath, out_name)) 115 | 116 | if len(num_clips) >= 1: 117 | count = Counter(num_clips) 118 | # sort count dict with the clip num 119 | print(count) 120 | print(max(list(count.keys()))) 121 | 122 | -------------------------------------------------------------------------------- /dataset/TimeIT/video_highlight_detection/video_highlight_detection_instructions.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "You are given a video from the QVHighlights dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and its saliency score on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 seconds. 
Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. Now I will give you the sentence query: . Please return the query-based highlight timestamps and salient scores.", 3 | "1": "Watch the provided video and mark out the scenes that stand out based on the description: . Document the timestamps of these highlights and evaluate their saliency scores.", 4 | "2": "Perform a thorough review of the video content, extracting key highlight moments that align with . It is essential to record the times of these moments and assign a distinct saliency value to each.", 5 | "3": "Examine the video and, in accordance with query , highlight the standout moments. You're required to provide the exact timing alongside a saliency rating for each segment.", 6 | "4": "In the video presented, seek moments that are a perfect match with . It's vital to notate their timestamps and to score each based on their level of saliency.", 7 | "5": "Go through the video content, and upon identifying highlight moments that resonate with , list their timestamps. Subsequently, provide a saliency score for each identified highlight." 8 | } -------------------------------------------------------------------------------- /dataset/TimeIT/video_summarization/video_summarization_instructions.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "From the dataset, generate a summarized version of the video, focusing on extracting key frames that best represent the overall narrative. The output should be a list of timestamps in seconds and their corresponding salient scores", 3 | "1": "You are given a video from the dataset. Please find the highlight contents in the video, determining the highlight timestamps and its saliency score on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 second. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. ", 4 | "2": "Identify and extract the most emotionally impactful moments from the video provided by dataset, rating their intensity on a scale from 1 to 5.", 5 | "3": "Watch the provided video from the dataset and mark out the timestamps with stand-out visual content. Document the timestamps of these highlights and evaluate their saliency scores.", 6 | "4": "In the video presented from dataset, seek moments that could serve as an executive summary for a busy stakeholder. It's vital to notate their timestamps and to score each based on their level of saliency.", 7 | "5": "Go through the video content from dataset, and upon identifying highlight moments, list their timestamps. Subsequently, provide a saliency score for each identified highlight." 
8 | } -------------------------------------------------------------------------------- /dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | 6 | from dataset.utils import load_image_from_path 7 | 8 | try: 9 | from petrel_client.client import Client 10 | has_client = True 11 | except ImportError: 12 | has_client = False 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ImageVideoBaseDataset(Dataset): 18 | """Base class that implements the image and video loading methods""" 19 | 20 | media_type = "video" 21 | 22 | def __init__(self): 23 | assert self.media_type in ["image", "video", "only_video"] 24 | self.data_root = None 25 | self.anno_list = ( 26 | None # list(dict), each dict contains {"image": str, # image or video path} 27 | ) 28 | self.transform = None 29 | self.video_reader = None 30 | self.num_tries = None 31 | 32 | self.client = None 33 | if has_client: 34 | self.client = Client('~/petreloss.conf') 35 | 36 | def __getitem__(self, index): 37 | raise NotImplementedError 38 | 39 | def __len__(self): 40 | raise NotImplementedError 41 | 42 | def get_anno(self, index): 43 | """obtain the annotation for one media (video or image) 44 | 45 | Args: 46 | index (int): The media index. 47 | 48 | Returns: dict. 49 | - "image": the filename, video also use "image". 50 | - "caption": The caption for this file. 51 | 52 | """ 53 | anno = self.anno_list[index] 54 | if self.data_root is not None: 55 | anno["image"] = os.path.join(self.data_root, anno["image"]) 56 | return anno 57 | 58 | def load_and_transform_media_data(self, index, data_path): 59 | if self.media_type == "image": 60 | return self.load_and_transform_media_data_image(index, data_path) 61 | else: 62 | return self.load_and_transform_media_data_video(index, data_path) 63 | 64 | def load_and_transform_media_data_image(self, index, data_path): 65 | image = load_image_from_path(data_path, client=self.client) 66 | image = self.transform(image) 67 | return image, index 68 | 69 | def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None): 70 | for _ in range(self.num_tries): 71 | try: 72 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1 73 | frames, frame_indices, fps = self.video_reader( 74 | data_path, self.num_frames, self.sample_type, 75 | max_num_frames=max_num_frames, client=self.client, clip=clip 76 | ) 77 | except Exception as e: 78 | logger.warning( 79 | f"Caught exception {e} when loading video {data_path}, " 80 | f"randomly sample a new video as replacement" 81 | ) 82 | index = random.randint(0, len(self) - 1) 83 | ann = self.get_anno(index) 84 | data_path = ann["image"] 85 | continue 86 | # shared aug for video frames 87 | frames = self.transform(frames) 88 | if return_fps: 89 | if clip is None: 90 | sec = [str(round(f / fps, 1)) for f in frame_indices] 91 | else: 92 | sec = [str(round( abs(f / fps - clip[0] + 0.1), 1)) for f in frame_indices] 93 | return frames, index, sec 94 | else: 95 | return frames, index 96 | else: 97 | raise RuntimeError( 98 | f"Failed to fetch video after {self.num_tries} tries. " 99 | f"This might indicate that you have many corrupted videos." 
100 | ) 101 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process 4 | import random 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class MetaLoader(object): 11 | """ wraps multiple data loader """ 12 | def __init__(self, name2loader): 13 | """Iterates over multiple dataloaders, it ensures all processes 14 | work on data from the same dataloader. This loader will end when 15 | the shorter dataloader raises StopIteration exception. 16 | 17 | loaders: Dict, {name: dataloader} 18 | """ 19 | self.name2loader = name2loader 20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()} 21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())} 22 | index2name = {v: k for k, v in name2index.items()} 23 | 24 | iter_order = [] 25 | for n, l in name2loader.items(): 26 | iter_order.extend([name2index[n]]*len(l)) 27 | 28 | random.shuffle(iter_order) 29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 30 | 31 | # sync 32 | if is_dist_avail_and_initialized(): 33 | # make sure all processes have the same order so that 34 | # each step they will have data from the same loader 35 | dist.broadcast(iter_order, src=0) 36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()] 37 | 38 | logger.info(str(self)) 39 | 40 | def __str__(self): 41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"] 42 | for idx, (name, loader) in enumerate(self.name2loader.items()): 43 | output.append( 44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} " 45 | ) 46 | return "\n".join(output) 47 | 48 | def __len__(self): 49 | return len(self.iter_order) 50 | 51 | def __iter__(self): 52 | """ this iterator will run indefinitely """ 53 | for name in self.iter_order: 54 | _iter = self.name2iter[name] 55 | batch = next(_iter) 56 | yield name, batch 57 | 58 | 59 | class MetaLoader_rs(object): 60 | """ wraps multiple data loader """ 61 | def __init__(self, name2loader, skip_num=0): 62 | """Iterates over multiple dataloaders, it ensures all processes 63 | work on data from the same dataloader. This loader will end when 64 | the shorter dataloader raises StopIteration exception. 
65 | 66 | loaders: Dict, {name: dataloader} 67 | """ 68 | self.name2loader = name2loader 69 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())} 70 | index2name = {v: k for k, v in name2index.items()} 71 | 72 | iter_order = [] 73 | for n, l in name2loader.items(): 74 | iter_order.extend([name2index[n]]*len(l)) 75 | 76 | random.shuffle(iter_order) 77 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 78 | 79 | # sync 80 | if is_dist_avail_and_initialized(): 81 | # make sure all processes have the same order so that 82 | # each step they will have data from the same loader 83 | dist.broadcast(iter_order, src=0) 84 | 85 | if skip_num > 0: 86 | iter_order_skip = iter_order[:skip_num] 87 | for k, v in index2name.items(): 88 | media_step = (iter_order_skip == k).sum().item() 89 | name2loader[v].sampler.set_start_iter(media_step) 90 | logger.info(f"{v} dataloder skip steps: {media_step}") 91 | iter_order = iter_order[skip_num:] 92 | self.name2loader = name2loader 93 | else: 94 | logger.info("Do not skip steps for any dataloader!") 95 | for k, v in index2name.items(): 96 | name2loader[v].sampler.set_start_iter(0) 97 | 98 | self.name2iter = {name: iter(l) for name, l in name2loader.items()} 99 | self.iter_idx = iter_order 100 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()] 101 | 102 | logger.info(str(self)) 103 | 104 | def __str__(self): 105 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"] 106 | for idx, (name, loader) in enumerate(self.name2loader.items()): 107 | length = (self.iter_idx == idx).sum() 108 | output.append( 109 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={length} " 110 | ) 111 | return "\n".join(output) 112 | 113 | def __len__(self): 114 | return len(self.iter_order) 115 | 116 | def __iter__(self): 117 | """ this iterator will run indefinitely """ 118 | for name in self.iter_order: 119 | _iter = self.name2iter[name] 120 | batch = next(_iter) 121 | yield name, batch -------------------------------------------------------------------------------- /dataset/it_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import random 5 | 6 | import numpy as np 7 | 8 | from dataset.base_dataset import ImageVideoBaseDataset 9 | from dataset.video_utils import VIDEO_READER_FUNCS 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class ITImgTrainDataset(ImageVideoBaseDataset): 15 | media_type = "image" 16 | 17 | def __init__( 18 | self, ann_file, transform, 19 | system="", role=("Human", "Assistant"), 20 | start_token="", end_token="", 21 | random_shuffle=True, # if True, shuffle the QA list 22 | ): 23 | super().__init__() 24 | 25 | if len(ann_file) == 3 and ann_file[2] == "video": 26 | self.media_type = "video" 27 | else: 28 | self.media_type = "image" 29 | self.label_file, self.data_root = ann_file[:2] 30 | 31 | logger.info('Load json file') 32 | with open(self.label_file, 'r') as f: 33 | self.anno = json.load(f) 34 | self.num_examples = len(self.anno) 35 | self.transform = transform 36 | 37 | # prompt parameters 38 | if system: 39 | assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token." 
40 | # currently not support add start_token and end_token in the system, since the msg should be added properly 41 | self.begin_signal = "###" 42 | self.end_signal = " " 43 | self.start_token = start_token 44 | self.end_token = end_token 45 | self.system = system 46 | self.role = role 47 | self.random_shuffle = random_shuffle 48 | # instruction location and number 49 | logger.info(f"Random shuffle: {self.random_shuffle}") 50 | 51 | def get_anno(self, index): 52 | filename = self.anno[index][self.media_type] 53 | qa = self.anno[index]["QA"] 54 | if "start" in self.anno[index] and "end" in self.anno[index]: 55 | anno = { 56 | "image": os.path.join(self.data_root, filename), "qa": qa, 57 | "start": self.anno[index]["start"], "end": self.anno[index]["end"], 58 | } 59 | else: 60 | anno = {"image": os.path.join(self.data_root, filename), "qa": qa} 61 | return anno 62 | 63 | def __len__(self): 64 | return self.num_examples 65 | 66 | def process_qa(self, qa, msg=""): 67 | cur_instruction = "" 68 | # randomly shuffle qa for conversation 69 | if self.random_shuffle and len(qa) > 1: 70 | random.shuffle(qa) 71 | if "i" in qa[0].keys() and qa[0]["i"] != "": 72 | cur_instruction = qa[0]["i"] + self.end_signal 73 | 74 | conversation = self.system 75 | # add instruction as system message 76 | if cur_instruction: 77 | conversation += cur_instruction 78 | 79 | # rstrip() for the extra " " in msg 80 | conversation += ( 81 | self.begin_signal + self.role[0] + ": " + 82 | self.start_token + self.end_token + msg.rstrip() + self.end_signal 83 | ) 84 | 85 | for sentence in qa: 86 | q = sentence["q"] 87 | a = sentence["a"] 88 | if q != "": 89 | conversation += (self.begin_signal + self.role[0] + ": " + q + self.end_signal) 90 | else: 91 | # no question, often in caption dataset 92 | pass 93 | conversation += (self.begin_signal + self.role[1] + ": " + a + self.end_signal) 94 | conversation += self.begin_signal 95 | 96 | if cur_instruction: 97 | cur_instruction += qa[0]["q"] 98 | return conversation, cur_instruction.strip() 99 | 100 | def __getitem__(self, index): 101 | try: 102 | ann = self.get_anno(index) 103 | image, index = self.load_and_transform_media_data_image(index, ann["image"]) 104 | conversation, instruction = self.process_qa(ann["qa"]) 105 | return image, conversation, instruction, index 106 | except Exception as e: 107 | logger.warning(f"Caught exception {e} when loading image {ann['image']}") 108 | index = np.random.randint(0, len(self)) 109 | return self.__getitem__(index) 110 | 111 | 112 | class ITVidTrainDataset(ITImgTrainDataset): 113 | media_type = "video" 114 | 115 | def __init__( 116 | self, ann_file, transform, 117 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3, 118 | system="", role=("Human", "Assistant"), 119 | start_token="", 120 | add_second_msg=True, 121 | random_shuffle=True, 122 | ): 123 | super().__init__( 124 | ann_file, transform, 125 | system=system, role=role, 126 | start_token=start_token, end_token=end_token, 127 | random_shuffle=random_shuffle, 128 | ) 129 | self.num_frames = num_frames 130 | self.video_reader_type = video_reader_type 131 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 132 | self.sample_type = sample_type 133 | self.num_tries = num_tries 134 | self.add_second_msg = add_second_msg 135 | 136 | logger.info(f"Use {video_reader_type} for data in {ann_file}") 137 | if add_second_msg: 138 | logger.info(f"Add second message: The video contains X frames sampled at T seconds.") 139 | 140 | def __getitem__(self, index): 141 | try: 
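            # Flow: read the annotation, restrict decoding to the optional
            # [start, end] clip, then (if add_second_msg) prepend a message telling
            # the model how many frames were sampled and at which seconds.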
142 | ann = self.get_anno(index) 143 | msg = "" 144 | clip = None 145 | if "start" in ann and "end" in ann: 146 | clip = [ann["start"], ann["end"]] 147 | video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip) 148 | if self.add_second_msg: 149 | # " " should be added in the start and end 150 | msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. " 151 | conversation, instruction = self.process_qa(ann["qa"], msg) 152 | return video, conversation, instruction, index 153 | except Exception as e: 154 | logger.warning(f"Caught exception {e} when loading video {ann['image']}") 155 | index = np.random.randint(0, len(self)) 156 | return self.__getitem__(index) -------------------------------------------------------------------------------- /dataset/it_dataset_mistral.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import random 5 | 6 | import numpy as np 7 | 8 | from dataset.base_dataset import ImageVideoBaseDataset 9 | from dataset.video_utils import VIDEO_READER_FUNCS 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class ITImgTrainDataset_mistral(ImageVideoBaseDataset): 15 | media_type = "image" 16 | 17 | def __init__( 18 | self, ann_file, transform, 19 | system="", 20 | start_token="", end_token="", 21 | random_shuffle=True, # if True, shuffle the QA list 22 | return_question_instruction=False # if True, return instruction with instruciton 23 | ): 24 | super().__init__() 25 | 26 | if len(ann_file) == 3 and ann_file[2] == "video": 27 | self.media_type = "video" 28 | else: 29 | self.media_type = "image" 30 | self.label_file, self.data_root = ann_file[:2] 31 | 32 | logger.info('Load json file') 33 | with open(self.label_file, 'r') as f: 34 | self.anno = json.load(f) 35 | self.num_examples = len(self.anno) 36 | self.transform = transform 37 | 38 | # prompt parameters 39 | if system: 40 | assert system[-1] == " ", "' ' should be add in the end of system." 
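        # process_qa below assembles a Mistral-style prompt, roughly (illustrative,
        # single QA turn, empty start/end tokens):
        #   "[INST] <msg> [/INST] [INST] <q> [/INST] <a>" + self.assist_end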
41 | 42 | self.human_start = "[INST]" 43 | self.human_end = "[/INST]" 44 | self.assist_end = "" 45 | self.start_token = start_token 46 | self.end_token = end_token 47 | self.system = system 48 | self.random_shuffle = random_shuffle 49 | # instruction location and number 50 | self.return_question_instruction = return_question_instruction 51 | logger.info(f"Random shuffle: {self.random_shuffle}") 52 | logger.info(f"Return question with instruction: {self.return_question_instruction}") 53 | 54 | def get_anno(self, index): 55 | filename = self.anno[index][self.media_type] 56 | qa = self.anno[index]["QA"] 57 | if "start" in self.anno[index] and "end" in self.anno[index]: 58 | anno = { 59 | "image": os.path.join(self.data_root, filename), "qa": qa, 60 | "start": self.anno[index]["start"], "end": self.anno[index]["end"], 61 | } 62 | else: 63 | anno = {"image": os.path.join(self.data_root, filename), "qa": qa} 64 | return anno 65 | 66 | def __len__(self): 67 | return self.num_examples 68 | 69 | def process_qa(self, qa, msg=""): 70 | cur_instruction = "" 71 | # randomly shuffle qa for conversation 72 | if self.random_shuffle and len(qa) > 1: 73 | random.shuffle(qa) 74 | if "i" in qa[0].keys() and qa[0]["i"] != "": 75 | cur_instruction = qa[0]["i"] + " " 76 | 77 | conversation = self.system 78 | # add instruction as system message 79 | if cur_instruction: 80 | conversation += cur_instruction 81 | 82 | # rstrip() for the extra " " in msg 83 | conversation += ( 84 | self.human_start + " " + self.start_token + self.end_token + msg.rstrip() + " " + self.human_end 85 | ) 86 | 87 | for idx, sentence in enumerate(qa): 88 | q = sentence["q"] 89 | a = sentence["a"] 90 | if q != "": 91 | conversation += (" " + self.human_start + " " + q + " " + self.human_end) 92 | else: 93 | # no question, often in caption dataset 94 | pass 95 | conversation += (" " + a + " " + self.assist_end) 96 | 97 | if self.return_question_instruction and cur_instruction: 98 | cur_instruction += qa[0]["q"] 99 | return conversation.strip(), cur_instruction.strip() 100 | 101 | def __getitem__(self, index): 102 | try: 103 | ann = self.get_anno(index) 104 | image, index = self.load_and_transform_media_data_image(index, ann["image"]) 105 | conversation, instruction = self.process_qa(ann["qa"]) 106 | return image, conversation, instruction, index 107 | except Exception as e: 108 | logger.warning(f"Caught exception {e} when loading image {ann['image']}") 109 | index = np.random.randint(0, len(self)) 110 | return self.__getitem__(index) 111 | 112 | 113 | class ITVidTrainDataset_mistral(ITImgTrainDataset_mistral): 114 | media_type = "video" 115 | 116 | def __init__( 117 | self, ann_file, transform, 118 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3, 119 | system="", start_token="", 120 | add_second_msg=False, 121 | random_shuffle=True, 122 | return_question_instruction=False # if True, return instruction with instruciton 123 | ): 124 | super().__init__( 125 | ann_file, transform, 126 | system=system, 127 | start_token=start_token, end_token=end_token, 128 | random_shuffle=random_shuffle, 129 | return_question_instruction=return_question_instruction 130 | ) 131 | self.num_frames = num_frames 132 | self.video_reader_type = video_reader_type 133 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 134 | self.sample_type = sample_type 135 | self.num_tries = num_tries 136 | self.add_second_msg = add_second_msg 137 | 138 | logger.info(f"Use {video_reader_type} for data in {ann_file}") 139 | if add_second_msg: 140 | 
logger.info(f"Add second message: The video contains X frames sampled at T seconds.") 141 | 142 | def __getitem__(self, index): 143 | try: 144 | ann = self.get_anno(index) 145 | msg = "" 146 | clip = None 147 | if "start" in ann and "end" in ann: 148 | clip = [ann["start"], ann["end"]] 149 | video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip) 150 | if self.add_second_msg: 151 | # " " should be added in the start and end 152 | msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. " 153 | conversation, instruction = self.process_qa(ann["qa"], msg) 154 | return video, conversation, instruction, index 155 | except Exception as e: 156 | logger.warning(f"Caught exception {e} when loading video {ann['image']}") 157 | index = np.random.randint(0, len(self)) 158 | return self.__getitem__(index) -------------------------------------------------------------------------------- /dataset/pt_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | 5 | import numpy as np 6 | 7 | from dataset.base_dataset import ImageVideoBaseDataset 8 | from dataset.utils import load_anno, pre_text 9 | from dataset.video_utils import VIDEO_READER_FUNCS 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class PTImgTrainDataset(ImageVideoBaseDataset): 15 | media_type = "image" 16 | 17 | def __init__(self, ann_file, transform, pre_text=True): 18 | super().__init__() 19 | 20 | if len(ann_file) == 3 and ann_file[2] == "video": 21 | self.media_type = "video" 22 | else: 23 | self.media_type = "image" 24 | self.label_file, self.data_root = ann_file[:2] 25 | 26 | logger.info('Load json file') 27 | with open(self.label_file, 'r') as f: 28 | self.anno = json.load(f) 29 | self.num_examples = len(self.anno) 30 | 31 | self.transform = transform 32 | self.pre_text = pre_text 33 | logger.info(f"Pre-process text: {pre_text}") 34 | 35 | def get_anno(self, index): 36 | filename = self.anno[index][self.media_type] 37 | caption = self.anno[index]["caption"] 38 | anno = {"image": os.path.join(self.data_root, filename), "caption": caption} 39 | return anno 40 | 41 | def __len__(self): 42 | return self.num_examples 43 | 44 | def __getitem__(self, index): 45 | try: 46 | ann = self.get_anno(index) 47 | image, index = self.load_and_transform_media_data(index, ann["image"]) 48 | caption = pre_text(ann["caption"], pre_text=self.pre_text) 49 | return image, caption, index 50 | except Exception as e: 51 | logger.warning(f"Caught exception {e} when loading image {ann['image']}") 52 | index = np.random.randint(0, len(self)) 53 | return self.__getitem__(index) 54 | 55 | 56 | class PTVidTrainDataset(PTImgTrainDataset): 57 | media_type = "video" 58 | 59 | def __init__( 60 | self, 61 | ann_file, 62 | transform, 63 | num_frames=4, 64 | video_reader_type="decord", 65 | sample_type="rand", 66 | num_tries=3, 67 | pre_text=True 68 | ): 69 | super().__init__(ann_file, transform, pre_text=pre_text) 70 | self.num_frames = num_frames 71 | self.video_reader_type = video_reader_type 72 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 73 | self.sample_type = sample_type 74 | self.num_tries = num_tries 75 | 76 | 77 | class PTImgEvalDataset(ImageVideoBaseDataset): 78 | media_type = "image" 79 | 80 | def __init__(self, ann_file, transform, has_multi_vision_gt=False): 81 | super(PTImgEvalDataset, self).__init__() 82 | self.raw_anno_list = load_anno(ann_file) 83 | self.transform = 
transform 84 | self.has_multi_vision_gt = has_multi_vision_gt # each caption has multiple image as ground_truth 85 | 86 | self.text = None 87 | self.image = None 88 | self.txt2img = None 89 | self.img2txt = None 90 | self.build_data() 91 | 92 | def build_data(self): 93 | self.text = [] 94 | self.image = [] 95 | self.txt2img = {} 96 | self.img2txt = {} 97 | if self.has_multi_vision_gt: 98 | self.build_data_multi_img_gt() 99 | else: 100 | self.build_data_multi_txt_gt() 101 | self.anno_list = [dict(image=e) for e in self.image] 102 | 103 | def build_data_multi_img_gt(self): 104 | """each text may have multiple ground_truth image, e.g., ssv2""" 105 | img_id = 0 106 | for txt_id, ann in enumerate(self.raw_anno_list): 107 | self.text.append(pre_text(ann["caption"])) 108 | self.txt2img[txt_id] = [] 109 | _images = ann["image"] \ 110 | if isinstance(ann["image"], list) else [ann["image"], ] 111 | for i, image in enumerate(_images): 112 | self.image.append(image) 113 | self.txt2img[txt_id].append(img_id) 114 | self.img2txt[img_id] = txt_id 115 | img_id += 1 116 | 117 | def build_data_multi_txt_gt(self): 118 | """each image may have multiple ground_truth text, e.g., COCO and Flickr30K""" 119 | txt_id = 0 120 | for img_id, ann in enumerate(self.raw_anno_list): 121 | self.image.append(ann["image"]) 122 | self.img2txt[img_id] = [] 123 | _captions = ann["caption"] \ 124 | if isinstance(ann["caption"], list) else [ann["caption"], ] 125 | for i, caption in enumerate(_captions): 126 | self.text.append(pre_text(caption)) 127 | self.img2txt[img_id].append(txt_id) 128 | self.txt2img[txt_id] = img_id 129 | txt_id += 1 130 | 131 | def __len__(self): 132 | return len(self.anno_list) 133 | 134 | def __getitem__(self, index): 135 | ann = self.anno_list[index] 136 | image, index = self.load_and_transform_media_data(index, ann["image"]) 137 | return image, index 138 | 139 | 140 | def preprocess_para_retrieval_data(anno_list): 141 | processed_anno_list = [] 142 | for d in anno_list: 143 | d["caption"] = " ".join(d.pop("caption")) 144 | processed_anno_list.append(d) 145 | return processed_anno_list 146 | 147 | 148 | class PTVidEvalDataset(PTImgEvalDataset): 149 | media_type = "video" 150 | 151 | def __init__( 152 | self, ann_file, transform, num_frames=4, 153 | video_reader_type="decord", sample_type="rand", num_tries=1, 154 | is_paragraph_retrieval=False, has_multi_vision_gt=False, 155 | ): 156 | super(PTVidEvalDataset, self).__init__(ann_file, transform, has_multi_vision_gt) 157 | self.num_frames = num_frames 158 | self.video_reader_type = video_reader_type 159 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 160 | self.sample_type = sample_type 161 | self.num_tries = num_tries 162 | self.is_paragraph_retrieval = is_paragraph_retrieval 163 | 164 | if is_paragraph_retrieval: 165 | self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list) 166 | self.build_data() 167 | -------------------------------------------------------------------------------- /dataset/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import logging 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | 7 | # stolen from https://github.com/facebookresearch/vissl/blob/94def58538d3c7037f5e093196494331eea1a2a2/vissl/data/data_helper.py#L93 8 | class StatefulDistributedSampler(DistributedSampler): 9 | """ 10 | More fine-grained state DataSampler that uses training iteration and epoch 11 | both for shuffling data. 
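    Illustrative resume flow (a sketch; variable names are hypothetical):

        sampler = StatefulDistributedSampler(train_set, batch_size=32)
        sampler.set_start_iter(resumed_iteration)  # skip batches already consumed
        loader = DataLoader(train_set, batch_size=32, sampler=sampler)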
PyTorch DistributedSampler only uses epoch 12 | for the shuffling and starts sampling data from the start. In case of training 13 | on very large data, we train for one epoch only and when we resume training, 14 | we want to resume the data sampler from the training iteration. 15 | """ 16 | 17 | def __init__(self, dataset, batch_size=None, seed: int = 0): 18 | """ 19 | Initializes the instance of StatefulDistributedSampler. Random seed is set 20 | for the epoch set and data is shuffled. For starting the sampling, use 21 | the start_iter (set to 0 or set by checkpointing resuming) to 22 | sample data from the remaining images. 23 | 24 | Args: 25 | dataset (Dataset): Pytorch dataset that sampler will shuffle 26 | batch_size (int): batch size we want the sampler to sample 27 | seed (int): Seed for the torch generator. 28 | """ 29 | super().__init__(dataset, shuffle=False, seed=seed) 30 | 31 | self.start_iter = 0 32 | self.batch_size = batch_size 33 | self.total_size = len(dataset) - (len(dataset) % self.num_replicas) 34 | self.num_samples = self.total_size // self.num_replicas 35 | print(f"rank: {self.rank}: Sampler created...") 36 | 37 | def __iter__(self): 38 | # partition data into num_replicas and optionally shuffle within a rank 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch + self.seed) 41 | shuffling = torch.randperm(self.num_samples, generator=g).tolist() 42 | indices = np.array( 43 | list( 44 | range( 45 | (self.rank * self.num_samples), (self.rank + 1) * self.num_samples 46 | ) 47 | ) 48 | )[shuffling].tolist() 49 | 50 | # make sure we have correct number of samples per replica 51 | assert len(indices) == self.num_samples 52 | assert self.batch_size > 0, "batch_size not set for the sampler" 53 | 54 | # resume the sampler 55 | start_index = self.start_iter * self.batch_size 56 | indices = indices[start_index:] 57 | return iter(indices) 58 | 59 | def set_start_iter(self, start_iter): 60 | """ 61 | Set the iteration number from which the sampling should start. This is 62 | used to find the marker in the data permutation order from where the 63 | sampler should start sampling. 64 | """ 65 | self.start_iter = start_iter 66 | -------------------------------------------------------------------------------- /dataset/video_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py 3 | """ 4 | import random 5 | import io 6 | import av 7 | import cv2 8 | import decord 9 | import imageio 10 | from decord import VideoReader 11 | import torch 12 | import numpy as np 13 | import math 14 | decord.bridge.set_bridge("torch") 15 | 16 | import logging 17 | logger = logging.getLogger(__name__) 18 | 19 | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: 20 | """ 21 | Converts a present time with the given time base and start_pts offset to seconds. 22 | 23 | Returns: 24 | time_in_seconds (float): The corresponding time in seconds. 
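    Example (illustrative): with time_base=1/90000 and start_pts=0, pts=90000
    corresponds to 1.0 second.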
25 | 26 | https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64 27 | """ 28 | if pts == math.inf: 29 | return math.inf 30 | 31 | return int(pts - start_pts) * time_base 32 | 33 | 34 | def get_pyav_video_duration(video_reader): 35 | video_stream = video_reader.streams.video[0] 36 | video_duration = pts_to_secs( 37 | video_stream.duration, 38 | video_stream.time_base, 39 | video_stream.start_time 40 | ) 41 | return float(video_duration) 42 | 43 | 44 | def get_frame_indices_by_fps(): 45 | pass 46 | 47 | 48 | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): 49 | if sample in ["rand", "middle"]: # uniform sampling 50 | acc_samples = min(num_frames, vlen) 51 | # split the video into `acc_samples` intervals, and sample from each interval. 52 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 53 | ranges = [] 54 | for idx, interv in enumerate(intervals[:-1]): 55 | ranges.append((interv, intervals[idx + 1] - 1)) 56 | if sample == 'rand': 57 | try: 58 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] 59 | except: 60 | frame_indices = np.random.permutation(vlen)[:acc_samples] 61 | frame_indices.sort() 62 | frame_indices = list(frame_indices) 63 | elif fix_start is not None: 64 | frame_indices = [x[0] + fix_start for x in ranges] 65 | elif sample == 'middle': 66 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges] 67 | else: 68 | raise NotImplementedError 69 | 70 | if len(frame_indices) < num_frames: # padded with last frame 71 | padded_frame_indices = [frame_indices[-1]] * num_frames 72 | padded_frame_indices[:len(frame_indices)] = frame_indices 73 | frame_indices = padded_frame_indices 74 | elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps 75 | output_fps = float(sample[3:]) 76 | duration = float(vlen) / input_fps 77 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents 78 | frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) 79 | frame_indices = np.around(frame_seconds * input_fps).astype(int) 80 | frame_indices = [e for e in frame_indices if e < vlen] 81 | if max_num_frames > 0 and len(frame_indices) > max_num_frames: 82 | frame_indices = frame_indices[:max_num_frames] 83 | # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) 84 | else: 85 | raise ValueError 86 | return frame_indices 87 | 88 | 89 | def read_frames_av( 90 | video_path, num_frames, sample='rand', fix_start=None, 91 | max_num_frames=-1, client=None, clip=None, 92 | ): 93 | reader = av.open(video_path) 94 | frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)] 95 | vlen = len(frames) 96 | duration = get_pyav_video_duration(reader) 97 | fps = vlen / float(duration) 98 | frame_indices = get_frame_indices( 99 | num_frames, vlen, sample=sample, fix_start=fix_start, 100 | input_fps=fps, max_num_frames=max_num_frames 101 | ) 102 | frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8 103 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 104 | return frames, frame_indices, fps 105 | 106 | 107 | def read_frames_gif( 108 | video_path, num_frames, sample='rand', fix_start=None, 109 | max_num_frames=-1, client=None, clip=None, 110 | ): 111 | if video_path.startswith('s3') or video_path.startswith('p2'): 112 | video_bytes = client.get(video_path) 113 | gif = 
imageio.get_reader(io.BytesIO(video_bytes)) 114 | else: 115 | gif = imageio.get_reader(video_path) 116 | vlen = len(gif) 117 | frame_indices = get_frame_indices( 118 | num_frames, vlen, sample=sample, fix_start=fix_start, 119 | max_num_frames=max_num_frames 120 | ) 121 | frames = [] 122 | for index, frame in enumerate(gif): 123 | # for index in frame_idxs: 124 | if index in frame_indices: 125 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) 126 | frame = torch.from_numpy(frame).byte() 127 | # # (H x W x C) to (C x H x W) 128 | frame = frame.permute(2, 0, 1) 129 | frames.append(frame) 130 | frames = torch.stack(frames) # .float() / 255 131 | return frames, frame_indices, 25. # for tgif 132 | 133 | 134 | def read_frames_decord( 135 | video_path, num_frames, sample='rand', fix_start=None, 136 | max_num_frames=-1, client=None, clip=None 137 | ): 138 | if "s3" in video_path: 139 | video_bytes = client.get(video_path) 140 | video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1) 141 | else: 142 | video_reader = VideoReader(video_path, num_threads=1) 143 | vlen = len(video_reader) 144 | fps = video_reader.get_avg_fps() 145 | duration = vlen / float(fps) 146 | 147 | if clip: 148 | start, end = clip 149 | duration = end - start 150 | vlen = int(duration * fps) 151 | start_index = int(start * fps) 152 | 153 | frame_indices = get_frame_indices( 154 | num_frames, vlen, sample=sample, fix_start=fix_start, 155 | input_fps=fps, max_num_frames=max_num_frames 156 | ) 157 | if clip: 158 | frame_indices = [f + start_index for f in frame_indices] 159 | 160 | frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 161 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 162 | return frames, frame_indices, float(fps) 163 | 164 | 165 | VIDEO_READER_FUNCS = { 166 | 'av': read_frames_av, 167 | 'decord': read_frames_decord, 168 | 'gif': read_frames_gif, 169 | } 170 | -------------------------------------------------------------------------------- /demo/example/bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/bear.jpg -------------------------------------------------------------------------------- /demo/example/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/dog.png -------------------------------------------------------------------------------- /demo/example/jesse_dance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/jesse_dance.mp4 -------------------------------------------------------------------------------- /demo/example/people.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/people.jpg -------------------------------------------------------------------------------- /demo/example/yoga.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/yoga.mp4 -------------------------------------------------------------------------------- /download/folder_keeper: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/download/folder_keeper -------------------------------------------------------------------------------- /eval/Egoschema_trans_csv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | 4 | file_dir="/path_to_the_timesuite_root_folder/scripts/Ablation/Token_Shufflue/F128_CF8_PoolMax_Ablation_PoolMax/Egoschema_test_ckpt_01" 5 | 6 | # 加载你的json数据 7 | with open(f'{file_dir}/result.json', 'r') as f: 8 | data = json.load(f) 9 | 10 | # 将数据写入CSV文件 11 | with open(f'{file_dir}/result_submit_kaggle.csv', 'w', newline='') as f: 12 | writer = csv.writer(f) 13 | writer.writerow(['q_uid', 'answer']) 14 | for key, value in data.items(): 15 | writer.writerow([key, value]) # 写入数据 16 | -------------------------------------------------------------------------------- /eval/eval_egoschema.sh: -------------------------------------------------------------------------------- 1 | ROOT_DIR="/path_to_the_timesuite_root_folder/download/parameters" 2 | 3 | python3 eval/validate_egoschema.py \ 4 | --f ${ROOT_DIR}/Egoschema_test_timesuite/result.json \ 5 | 2>&1 | tee "${ROOT_DIR}/Egoschema_test_timesuite/acc.txt" 6 | -------------------------------------------------------------------------------- /eval/format_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import json 4 | 5 | def read_txt(path): 6 | with open(path, "r") as fin: 7 | data = fin.readline().strip() 8 | return data 9 | 10 | 11 | def load_data(args, anno_path, split=None): 12 | ''' 13 | anno data example: 14 | {"annotations": 15 | [ 16 | { 17 | "image_id": "xHr8X2Wpmno.mp4" 18 | ... 19 | }, 20 | ... 
21 | ] 22 | } 23 | ''' 24 | file_path = os.path.join(anno_path, f'{split}.caption_coco_format.json') 25 | with open(file_path, 'r') as f: 26 | data = json.load(f)["annotations"] 27 | 28 | if args.debug: 29 | data = data[:10] 30 | return data 31 | 32 | 33 | def merge_seg_caps(results): 34 | """merge mulple generated captions from a same video into paragraph.""" 35 | merge_results = {} 36 | for jterm in results: 37 | vname = jterm["vname"] 38 | cap = jterm["generated_cap"] 39 | postfix = vname.split(".mp4")[-1] 40 | start_time, end_time = float(postfix.split("_")[-2]), float(postfix.split("_")[-1]) 41 | vid = vname.split(".mp4")[0] + ".mp4" 42 | if vid not in merge_results: 43 | merge_results[vid] = [] 44 | merge_results[vid].append({"timestamp": [start_time, end_time], "caption": cap}) 45 | return merge_results 46 | 47 | 48 | def save_result(args, output_dir, results, split_name='test', format=False): 49 | Path(output_dir).mkdir(parents=True, exist_ok=True) 50 | file_name = f'{args.dataset}_{split_name}_clipF{args.infer_clip_frames}_result.json' 51 | if args.timestamp: 52 | if args.timestamp_file != '': 53 | file_name = f'{args.dataset}_{split_name}_clipF{args.infer_clip_frames}_result_with_pred_timestamp.json' 54 | else: 55 | file_name = f'{args.dataset}_{split_name}_clipF{args.infer_clip_frames}_result_with_gt_timestamp.json' 56 | if args.debug: 57 | file_name = 'debug_' + file_name 58 | if format: 59 | file_name = 'fmt_' + file_name 60 | with open(os.path.join(output_dir, file_name), 'w') as f: 61 | json.dump(results, f) 62 | return 63 | 64 | 65 | def get_timestamp_from_file(timestamp_file): 66 | timestamp = {} 67 | with open(timestamp_file, 'r') as f: 68 | data = json.load(f) 69 | for vid, vlist in data.items(): 70 | timestamp[vid] = [] 71 | for vterm in vlist: 72 | timestamp[vid].append(vterm["timestamp"]) 73 | return timestamp 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /eval/format_tvg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import re 5 | from copy import deepcopy 6 | import pdb 7 | import numpy as np 8 | from pathlib import Path 9 | 10 | # read json files 11 | def read_json(path): 12 | with open(path, "r") as fin: 13 | datas = json.load(fin) 14 | return datas 15 | 16 | 17 | def write_json(path, data): 18 | with open(path, "w") as fout: 19 | json.dump(data, fout) 20 | print("The format file has been saved at:{}".format(path)) 21 | return 22 | 23 | 24 | def extract_time(paragraph): 25 | prompt = 'A specific example is : 20.8 - 30.0 seconds'.lower() 26 | paragraph = paragraph.lower() 27 | paragraph.replace(prompt, '') 28 | # Split text into sentences based on common delimiters 29 | sentences = re.split(r'[!?\n]', paragraph) 30 | 31 | # Keywords that might indicate the presence of time information 32 | keywords = ["starts", "ends", "happens in", "start time", "end time", "start", "end", "happen"] 33 | # filter sentences by keywords 34 | candidates = [] 35 | for sentence in sentences: 36 | # If sentence contains one of the keywords 37 | if any(keyword in sentence for keyword in keywords): 38 | candidates.append(sentence) 39 | 40 | timestamps = [] 41 | # Check for The given query happens in m - n (seconds) 42 | patterns = [ 43 | r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)" 44 | ] 45 | 46 | for time_pattern in patterns: 47 | time_matches = re.findall(time_pattern, paragraph) 48 | if time_matches: 49 | timestamps = [[float(start), float(end)] for 
start, end in time_matches] 50 | 51 | if len(sentences) == 0: 52 | return [] 53 | # check for other formats e.g.: 54 | # 1 .Starting time: 0.8 seconds 55 | # Ending time: 1.1 seconds 56 | # 2. The start time for this event is 0 seconds, and the end time is 12 seconds. 57 | if len(timestamps) == 0: 58 | times = [] 59 | time_regex = re.compile(r'\b(\d+\.\d+\b|\b\d+)\b') # time formats (e.g., 18, 18.5) 60 | for sentence in candidates: 61 | time = re.findall(time_regex, sentence) 62 | if time: 63 | time_in_sec = float(time[0]) 64 | times.append(time_in_sec) 65 | times = times[:len(times)//2*2] 66 | timestamps = [(times[i], times[i+1]) for i in range(0, len(times), 2)] 67 | # Check for examples like: 68 | # 3. The event 'person flipped the light switch near the door' starts at 00:00:18 and ends at 00:00:23. 69 | if len(timestamps) == 0: 70 | times = [] 71 | time_regex = re.compile(r'\b((\d{1,2}:\d{2}:\d{2}))\b') # time formats (e.g., 18:00, 00:18:05) 72 | for sentence in candidates: 73 | time = re.findall(time_regex, sentence) 74 | if time: 75 | t = time[0] 76 | else: 77 | continue 78 | # If time is in HH:MM:SS format, convert to seconds 79 | if t.count(':') == 2: 80 | h, m, s = map(int, t.split(':')) 81 | time_in_sec = h * 3600 + m * 60 + s 82 | elif t.count(':') == 1: 83 | m, s = map(int, t.split(':')) 84 | time_in_sec = m * 60 + s 85 | times.append(time_in_sec) 86 | times = times[:len(times)//2*2] 87 | timestamps = [(times[i], times[i+1]) for i in range(0, len(times), 2)] 88 | results = [] 89 | for (start, end) in timestamps: 90 | if end > start: 91 | results.append([start, end]) 92 | else: 93 | results.append([end, start]) 94 | if len(results) > 1: 95 | results = results[:1] 96 | return results 97 | 98 | 99 | def format_tvg_output(paras): 100 | timestamps = [] 101 | # type 1: directly detect timestamps in generated paragraph to process multi-lines cases like: 102 | timestamps = extract_time(paras) 103 | 104 | return timestamps 105 | 106 | 107 | def format_tvg(datas): 108 | fmt_datas = {} 109 | cnt = 0 110 | for i, jterm in enumerate(datas): 111 | vid = jterm["vname"] 112 | query = jterm["query"] 113 | gcap = jterm["generated_cap"] 114 | qid = int(jterm["id"]) 115 | timestamps = format_tvg_output(gcap) 116 | if len(timestamps) == 0: 117 | cnt += 1 118 | print(vid, query + "\n", gcap, "\n") 119 | fmt_datas[qid] = {"timestamp": timestamps, "query": query, "vid": vid} 120 | print(f'parse failed number: {cnt}') 121 | return fmt_datas 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--inpath', default='/home/yaolinli/code/Ask-Anything/video_chat/output/eval_7b_tvg_charades/charades_test_f8_result.json') 127 | parser.add_argument('--outpath', default='') 128 | args = parser.parse_args() 129 | 130 | datas = read_json(args.inpath) 131 | # example in output file 132 | # { 133 | # "query_idx": 134 | # { 135 | # "timestamp": [47.0, 60.0], 136 | # "query": "a person is shown tying a plant into a bun.", 137 | # "vid": "xHr8X2Wpmno.mp4" 138 | # }, 139 | # ... 
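    # (format_tvg_output extracts the first "<start> - <end>" span in the generated
    #  caption; e.g. "The given query happens in 20.8 - 30.0 seconds." is parsed
    #  to [[20.8, 30.0]].)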
140 | # } 141 | fmt_datas = {} 142 | cnt = 0 143 | for i, jterm in enumerate(datas): 144 | vid = jterm["vname"] 145 | query = jterm["query"] 146 | gcap = jterm["generated_cap"] 147 | qid = jterm["id"] 148 | timestamps = format_tvg_output(gcap) 149 | if len(timestamps) == 0: 150 | cnt += 1 151 | print(vid, query+"\n", gcap+"\n") 152 | # pdb.set_trace() 153 | else: 154 | # print(gcap) 155 | # print(timestamps) 156 | pass 157 | fmt_datas[qid] = {"timestamp": timestamps, "query": query, "vid": vid} 158 | 159 | print(f'parse failed number: {cnt}') 160 | split = args.inpath.split('/')[-1].split('_')[0] 161 | out_file = args.inpath.split('/')[-2] 162 | out_path = f'{out_file}_{split}.json' 163 | if args.outpath != '': 164 | Path(args.outpath).mkdir(parents=True, exist_ok=True) 165 | out_path = os.path.join(args.outpath, out_path) 166 | write_json(os.path.join(os.getcwd(), out_path), fmt_datas) 167 | else: 168 | infile = args.inpath.split('/')[-1] 169 | outfile = "fmt_" + infile 170 | out_path = args.inpath.replace(infile, outfile) 171 | write_json(out_path, fmt_datas) -------------------------------------------------------------------------------- /eval/get_grounding_result.sh: -------------------------------------------------------------------------------- 1 | MODEL_DIR="/path_to_the_timesuite_root_folder/download/parameters" 2 | 3 | 4 | 5 | TASK='tvg' 6 | SPLIT='test' 7 | DATASET='charades' 8 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/temporal_video_grounding/charades/charades_annotation' 9 | GT_FILE="${ANNO_DIR}/${SPLIT}.caption_coco_format.json" 10 | 11 | sleep 1 12 | MODEL_PTH="timesuite" 13 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}" 14 | PRED_FILE="${RESULT_DIR}/fmt_${DATASET}_${SPLIT}_clipF8_result.json" 15 | 16 | if [ -f "${PRED_FILE}" ]; then 17 | cd metrics/${TASK} 18 | python eval_${TASK}.py \ 19 | --gt_file ${GT_FILE} \ 20 | --pred_file ${PRED_FILE} \ 21 | 2>&1 | tee ${RESULT_DIR}/grounding_result.txt 22 | cd ../.. 23 | else 24 | echo "File ${PRED_FILE} not exists. Skipping eval operation." 25 | fi 26 | 27 | 28 | TASK='vhd' 29 | SPLIT='val' 30 | DATASET='qvhighlights' 31 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw' 32 | GT_FILE="${ANNO_DIR}/highlight_${SPLIT}_release.jsonl" 33 | 34 | sleep 1 35 | MODEL_PTH="timesuite" 36 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}" 37 | PRED_FILE="${RESULT_DIR}/fmt_${DATASET}_${SPLIT}_clipF8_result.json" 38 | 39 | if [ -f "${PRED_FILE}" ]; then 40 | cd metrics/${TASK} 41 | python eval_${TASK}.py \ 42 | --gt_file ${GT_FILE} \ 43 | --pred_file ${PRED_FILE} \ 44 | 2>&1 | tee ${RESULT_DIR}/grounding_result.txt 45 | cd ../.. 46 | else 47 | echo "File ${PRED_FILE} not exists. Skipping eval operation." 
48 | fi -------------------------------------------------------------------------------- /eval/test_grounding.sh: -------------------------------------------------------------------------------- 1 | MODEL_DIR="/path_to_the_timesuite_root_folder/download/parameters" 2 | MODEL_TYPE="VideoChat2_it4_mistral_LinearProAda" 3 | 4 | TASK='vhd' 5 | SPLIT='val' 6 | DATASET='qvhighlights' 7 | PROMPT_FILE="/path_to_the_timesuite_root_folder/prompts/vhd_description_zeroshot_new.txt" 8 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw' 9 | GT_FILE="${ANNO_DIR}/highlight_${SPLIT}_release.jsonl" 10 | VIDEO_DIR='pnorm2:s3://qvhighlight/videos' 11 | 12 | sleep 2 13 | MODEL_PTH="timesuite" 14 | PTH_DIR="${MODEL_DIR}/${MODEL_PTH}.pth" 15 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}_new" 16 | 17 | if [ ! -d "${RESULT_DIR}" ] && [ -f "${PTH_DIR}" ]; then 18 | mkdir -p ${RESULT_DIR} 19 | echo "Created directory: ${RESULT_DIR}" 20 | python3 /path_to_the_timesuite_root_folder/eval/eval_infer.py \ 21 | --task=${TASK} \ 22 | --split=${SPLIT} \ 23 | --dataset=${DATASET} \ 24 | --prompt_file=${PROMPT_FILE} \ 25 | --anno_path=${ANNO_DIR} \ 26 | --video_path=${VIDEO_DIR} \ 27 | --model_dir=${MODEL_DIR} \ 28 | --model_pth=${MODEL_PTH} \ 29 | --model_type=${MODEL_TYPE} \ 30 | --output_dir=${RESULT_DIR} \ 31 | 2>&1 | tee ${RESULT_DIR}/eval_inference.log 32 | else 33 | echo "Directory ${RESULT_DIR} already exists or ${PTH_DIR} not exists. Skipping eval operation." 34 | fi 35 | 36 | TASK='tvg' 37 | SPLIT='test' 38 | DATASET='charades' 39 | PROMPT_FILE="/path_to_the_timesuite_root_folder/prompts/tvg_description_zeroshot.txt" 40 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/temporal_video_grounding/charades/charades_annotation' 41 | GT_FILE="${ANNO_DIR}/${SPLIT}.caption_coco_format.json" 42 | VIDEO_DIR="pnorm2zxy:s3://zengxiangyu/Charades/" 43 | 44 | 45 | sleep 2 46 | MODEL_PTH="timesuite" 47 | PTH_DIR="${MODEL_DIR}/${MODEL_PTH}.pth" 48 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}" 49 | 50 | 51 | if [ ! -d "${RESULT_DIR}" ] && [ -f "${PTH_DIR}" ]; then 52 | mkdir -p ${RESULT_DIR} 53 | echo "Created directory: ${RESULT_DIR}" 54 | python3 /path_to_the_timesuite_root_folder/eval/eval_infer.py \ 55 | --task=${TASK} \ 56 | --split=${SPLIT} \ 57 | --dataset=${DATASET} \ 58 | --prompt_file=${PROMPT_FILE} \ 59 | --anno_path=${ANNO_DIR} \ 60 | --video_path=${VIDEO_DIR} \ 61 | --model_dir=${MODEL_DIR} \ 62 | --model_pth=${MODEL_PTH} \ 63 | --model_type=${MODEL_TYPE} \ 64 | --output_dir=${RESULT_DIR} \ 65 | 2>&1 | tee ${RESULT_DIR}/eval_inference.log 66 | else 67 | echo "Directory ${RESULT_DIR} already exists or ${PTH_DIR} not exists. Skipping eval operation." 68 | fi -------------------------------------------------------------------------------- /eval/validate_egoschema.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import requests 3 | import json 4 | 5 | def send_post_request(json_file): 6 | """ 7 | Sends a POST request to the specified URL with the given JSON file. 8 | 9 | Parameters: 10 | - json_file (str): Path to the JSON file to be used in the request body. 11 | 12 | Returns: 13 | - Response object containing server's response. 
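    Example (illustrative; the path is hypothetical):
        response = send_post_request("Egoschema_test/result.json")
        print(response.status_code)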
14 | """ 15 | 16 | url = "https://validation-server.onrender.com/api/upload/" 17 | headers = { 18 | "Content-Type": "application/json" 19 | } 20 | 21 | with open(json_file, 'r') as f: 22 | data = json.load(f) 23 | 24 | response = requests.post(url, headers=headers, json=data) 25 | 26 | return response 27 | 28 | def main(): 29 | """ 30 | Main function that parses command-line arguments and sends a POST request. 31 | """ 32 | 33 | parser = argparse.ArgumentParser(description="Send a POST request with a JSON file.") 34 | parser.add_argument("--f", required=True, help="Path to the JSON file to be sent with the request.") 35 | 36 | args = parser.parse_args() 37 | 38 | response = send_post_request(args.f) 39 | print(f"Response Status Code: {response.status_code}") 40 | print(f"Response Content:\n{response.text}") 41 | 42 | if __name__ == "__main__": 43 | main() -------------------------------------------------------------------------------- /images/123: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/images/abstract.png -------------------------------------------------------------------------------- /images/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/images/data.png -------------------------------------------------------------------------------- /images/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/images/structure.png -------------------------------------------------------------------------------- /metrics/README.md: -------------------------------------------------------------------------------- 1 | # Calculate Metrics 2 | 3 | #### Dense video captioning task 4 | 5 | ``` 6 | cd dvc/ 7 | ``` 8 | 9 | Set the `pred_file` and `gt_file` and run: 10 | 11 | ``` 12 | python eval_dvc.py --pred_file $pred_file --gt_file $gt_file 13 | ``` 14 | 15 | Evaluation for paragraph captioning, run: 16 | 17 | ``` 18 | python eval_dvc.py --pred_file $pred_file --gt_file $gt_file --paragraph 19 | ``` 20 | 21 | #### Temporal video grounding task 22 | 23 | ``` 24 | cd tvg/ 25 | ``` 26 | 27 | Set the `pred_file` and `gt_file` and run: 28 | 29 | ``` 30 | python eval_dvc.py --pred_file $pred_file --gt_file $gt_file 31 | ``` 32 | 33 | 34 | #### Video highlight detection task 35 | 36 | ``` 37 | cd vhd/ 38 | ``` 39 | 40 | Set the `pred_file` and `gt_file` and run: 41 | 42 | ``` 43 | python eval_highlights.py --pred_file $pred_file --gt_file $gt_file 44 | ``` -------------------------------------------------------------------------------- /metrics/dvc/eval_dvc.sh: -------------------------------------------------------------------------------- 1 | # example in pred file 2 | # { 3 | # "xHr8X2Wpmno.mp4": [ 4 | # { 5 | # "timestamp": [47.0, 60.0], 6 | # "caption": "a person is shown tying a plant into a bun and putting the bun into a pink jar." 7 | # }, 8 | # ... 9 | # ] 10 | # ... 
11 | # } 12 | 13 | pred_file='/home/yaolinli/code/Ask-Anything/video_chat/results/eval_7b_instruct111k_timeit-vally-llava-66k_bz8_f8_epoch3_youcook.json' 14 | gt_file='/home/yaolinli/dataset/YouCook2_asr_denseCap/val.caption_coco_format.json' 15 | 16 | # pred_file='/home/yaolinli/code/Ask-Anything/video_chat/results/eval_7b_instruct11.2k_youcook2-anet_bz8_f8_epoch3_anet.json' 17 | # gt_file='/home/yaolinli/dataset/ActivityNet_asr_denseCap/val.caption_coco_format.json' 18 | 19 | python eval_dvc.py --pred_file $pred_file --gt_file $gt_file --analyze 20 | 21 | -------------------------------------------------------------------------------- /metrics/dvc/metrics/README.md: -------------------------------------------------------------------------------- 1 | # captioning-metrics 2 | 3 | This is a fork from https://github.com/salaniz/pycocoevalcap, with several functionalities that make it easier to run for dense video captioning, e.g. not closing the METEOR jar at every call but only once per evaluation. 4 | 5 | To use it, you may download https://github.com/salaniz/pycocoevalcap/tree/master/meteor/data and put in under the data/ folder. 6 | -------------------------------------------------------------------------------- /metrics/dvc/metrics/cider.py: -------------------------------------------------------------------------------- 1 | """Computes the CIDEr (Consensus-Based Image Description Evaluation) Metric.""" 2 | 3 | # Filename: cider.py 4 | # 5 | # Description: Describes the class to compute the CIDEr 6 | # (Consensus-Based Image Description Evaluation) Metric 7 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 8 | # 9 | # Creation Date: Sun Feb 8 14:16:54 2015 10 | # 11 | # Authors: Ramakrishna Vedantam 12 | # and Tsung-Yi Lin 13 | 14 | from .cider_scorer import CiderScorer 15 | 16 | 17 | class Cider: 18 | """Main Class to compute the CIDEr metric.""" 19 | 20 | def __init__(self, n=4, sigma=6.0): 21 | # set cider to sum over 1 to 4-grams 22 | self._n = n 23 | # set the standard deviation parameter for gaussian penalty 24 | self._sigma = sigma 25 | 26 | def compute_score(self, gts, res): 27 | """Main function to compute CIDEr score. 28 | 29 | Args: 30 | gts: dictionary with key and value 32 | res: dictionary with key and value 33 | 34 | Returns: 35 | Computed CIDEr float score for the corpus. 36 | """ 37 | 38 | assert sorted(gts.keys()) == sorted(res.keys()) 39 | imgids = list(gts.keys()) 40 | 41 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 42 | 43 | # Sort the IDs to be able to have control over the order 44 | # of the individual scores. 45 | for iid in sorted(imgids): 46 | hypo = res[iid] 47 | ref = gts[iid] 48 | 49 | # Sanity check. 
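      # gts[iid] holds the list of reference captions for an item, while res[iid]
      # must contain exactly one generated hypothesis; this is enforced below.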
50 | assert isinstance(hypo, list) 51 | assert len(hypo) == 1 52 | assert isinstance(ref, list) 53 | assert ref 54 | 55 | cider_scorer += (hypo[0], ref) 56 | 57 | (score, scores) = cider_scorer.compute_score() 58 | 59 | return score, scores 60 | 61 | def method(self): 62 | return "CIDEr" 63 | -------------------------------------------------------------------------------- /metrics/dvc/metrics/meteor.py: -------------------------------------------------------------------------------- 1 | """Python wrapper for METEOR implementation.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import subprocess 9 | import threading 10 | 11 | import numpy as np 12 | import six 13 | 14 | 15 | class Meteor(object): 16 | """Meteor scorer.""" 17 | 18 | def __init__(self, 19 | meteor_jar_path=None, 20 | java_jre_path=None, 21 | jdk_java_options=None): 22 | if java_jre_path: 23 | self.java_bin = java_jre_path 24 | elif 'JRE_BIN_JAVA' in os.environ: 25 | self.java_bin = os.environ['JRE_BIN_JAVA'] 26 | else: 27 | self.java_bin = 'java' 28 | 29 | if meteor_jar_path: 30 | meteor_jar = meteor_jar_path 31 | else: 32 | meteor_jar = os.path.join( 33 | './metrics', 'meteor-1.5.jar' 34 | ) 35 | 36 | assert os.path.exists(meteor_jar), meteor_jar 37 | 38 | jdk_java_options = jdk_java_options or ['-Xmx2G'] 39 | meteor_cmd = [ 40 | self.java_bin, '-jar', '-Xmx2G', meteor_jar, '-', '-', '-stdio', 41 | '-l', 'en', '-norm' 42 | ] 43 | 44 | self.meteor_p = subprocess.Popen( 45 | meteor_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE) 46 | self.lock = threading.Lock() 47 | 48 | def compute_score(self, gts, res): 49 | """Compute METEOR scores.""" 50 | with self.lock: 51 | assert sorted(gts.keys()) == sorted(res.keys()) 52 | img_ids = sorted(gts.keys()) 53 | scores = [] 54 | 55 | eval_line = 'EVAL ||| ' 56 | stats = self._stat(img_ids, res, gts) 57 | eval_line += ' ||| '.join(stats) 58 | self.meteor_p.stdin.write(six.ensure_binary(eval_line + '\n')) 59 | self.meteor_p.stdin.flush() 60 | scores = [float(six.ensure_str(self.meteor_p.stdout.readline())) 61 | for _ in img_ids] 62 | # get the aggregated value 63 | score = self.meteor_p.stdout.readline() 64 | # do not close the file inside this function to keep it open for full eval 65 | return float(score), np.asarray(scores) 66 | 67 | def method(self): 68 | return 'METEOR' 69 | 70 | def _stat(self, img_ids, hypothesis_str, reference_list): # pylint: disable=missing-function-docstring 71 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 72 | stat_lines = [] 73 | for i in img_ids: 74 | assert len(hypothesis_str[i]) == 1 75 | hypo = hypothesis_str[i][0].replace('|||', '').replace(' ', ' ') 76 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list[i]), 77 | hypo)) 78 | 79 | self.meteor_p.stdin.write(six.ensure_binary(score_line + '\n')) 80 | self.meteor_p.stdin.flush() 81 | stat_lines.append(six.ensure_str(self.meteor_p.stdout.readline()).strip()) 82 | return stat_lines 83 | -------------------------------------------------------------------------------- /metrics/dvc/metrics/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | """PTBTokenizer.""" 2 | 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | # pylint: disable=g-inconsistent-quotes 16 | 17 | # punctuations to be removed from the sentences 18 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", 19 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 20 | 21 | 22 | class PTBTokenizer: 23 | """Python wrapper of Stanford PTBTokenizer.""" 24 | 25 | def __init__(self, 26 | ptbtokenizer_jar_path=None, 27 | java_jre_path=None): 28 | if java_jre_path: 29 | self.java_bin = java_jre_path 30 | elif 'JRE_BIN_JAVA' in os.environ: 31 | self.java_bin = os.environ['JRE_BIN_JAVA'] 32 | else: 33 | self.java_bin = 'java' 34 | 35 | if ptbtokenizer_jar_path: 36 | self.ptbtokenizer_jar = ptbtokenizer_jar_path 37 | else: 38 | self.ptbtokenizer_jar = os.path.join( 39 | "./metrics", 40 | "stanford-corenlp-3.4.1.jar", 41 | ) 42 | 43 | assert os.path.exists(self.ptbtokenizer_jar), self.ptbtokenizer_jar 44 | 45 | def tokenize(self, captions_for_image): 46 | """Tokenization.""" 47 | 48 | cmd = [self.java_bin, '-cp', self.ptbtokenizer_jar, 49 | 'edu.stanford.nlp.process.PTBTokenizer', 50 | '-preserveLines', '-lowerCase'] 51 | 52 | # ====================================================== 53 | # prepare data for PTB Tokenizer 54 | # ====================================================== 55 | final_tokenized_captions_for_image = {} 56 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] # pylint: disable=g-complex-comprehension 57 | sentences = "\n".join( 58 | [ # pylint: disable=g-complex-comprehension 59 | c["caption"].replace("\n", " ") 60 | for k, v in captions_for_image.items() 61 | for c in v 62 | ] 63 | ) 64 | 65 | # ====================================================== 66 | # save sentences to temporary file 67 | # ====================================================== 68 | fd, tmpfname = tempfile.mkstemp() 69 | with os.fdopen(fd, 'w') as f: 70 | f.write(sentences) 71 | 72 | # ====================================================== 73 | # tokenize sentence 74 | # ====================================================== 75 | cmd.append(tmpfname) 76 | p_tokenizer = subprocess.Popen(cmd, stdout=subprocess.PIPE) 77 | token_lines = p_tokenizer.communicate(input=sentences.rstrip().encode())[0] 78 | token_lines = token_lines.decode() 79 | lines = token_lines.split('\n') 80 | # remove temp file 81 | os.remove(tmpfname) 82 | 83 | # ====================================================== 84 | # create dictionary for tokenized captions 85 | # ====================================================== 86 | for k, line in zip(image_id, lines): 87 | if k not in final_tokenized_captions_for_image: 88 | final_tokenized_captions_for_image[k] = [] 89 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') 90 | if w not in PUNCTUATIONS]) 91 | final_tokenized_captions_for_image[k].append(tokenized_caption) 92 | 93 | return final_tokenized_captions_for_image 94 | -------------------------------------------------------------------------------- /metrics/tvg/cd: -------------------------------------------------------------------------------- 1 | GT File: /path_to_the_timesuite_root_folder/dataset/TimeIT/temporal_video_grounding/charades/charades_annotation/test.caption_coco_format.json 2 | # pred video timestamps 3720; # gt video timestamps 3720 3 | IOU 0.3: 56.29032258064516 4 | IOU 0.5: 34.56989247311828 5 | IOU 0.7: 
15.887096774193548 6 | -------------------------------------------------------------------------------- /metrics/tvg/eval_tvg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | import sys 5 | import argparse 6 | import pdb 7 | 8 | def read_json(path): 9 | with open(path, "r") as fin: 10 | datas = json.load(fin) 11 | return datas 12 | 13 | 14 | def iou(A, B): 15 | max0 = max((A[0]), (B[0])) 16 | min0 = min((A[0]), (B[0])) 17 | max1 = max((A[1]), (B[1])) 18 | min1 = min((A[1]), (B[1])) 19 | return max(min1 - max0, 0) / (max1 - min0) 20 | 21 | 22 | def toSec(timeStr): 23 | t = time.strptime(timeStr, "%H:%M:%S") 24 | return t.tm_hour * 3600 + t.tm_min * 60 + t.tm_sec 25 | 26 | def captiondata_modify(steps): 27 | modify_data = {} 28 | for i, step in enumerate(steps[0]): 29 | for key in step["step"].keys(): 30 | name = step["step"][key]["query_idx"] 31 | modify_data[name] = [[step['step'][key]["startime"], step['step'][key]["endtime"]]] 32 | 33 | return modify_data 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--pred_file", type=str, default="/home/yaolinli/code/Ask-Anything/video_chat/output/eval_7b_tvg_charades/fmt_charades_test_f8_result.json") 38 | parser.add_argument('--gt_file', type=str, default='/home/yaolinli/dataset/Charades/charades_annotation/test.caption_coco_format.json') 39 | parser.add_argument('--sample', action='store_true', default=False) 40 | args = parser.parse_args() 41 | ''' 42 | { 43 | "query_idx": [start_time, end_time], 44 | ... 45 | } 46 | ''' 47 | print("GT File:", args.gt_file) 48 | answer = read_json(args.gt_file) 49 | answer = answer["annotations"] 50 | gt_timestamps = {} 51 | for jterm in answer: 52 | gt_timestamps[jterm["id"]] = jterm["timestamp"] 53 | 54 | submission = read_json(args.pred_file) 55 | pred_timestamps = {} 56 | for qid, jterm in submission.items(): 57 | pred_timestamps[int(qid)] = jterm["timestamp"] 58 | 59 | if args.sample: 60 | new = {} 61 | for qid in pred_timestamps.keys(): 62 | new[qid] = gt_timestamps[qid] 63 | gt_timestamps = new 64 | num = len(gt_timestamps) 65 | print(f"# pred video timestamps {len(pred_timestamps)}; # gt video timestamps {len(gt_timestamps)}") 66 | assert len(gt_timestamps) == len(pred_timestamps) 67 | Result = {0.3:0, 0.5:0, 0.7:0} 68 | for c_iou in [0.3, 0.5, 0.7]: 69 | for key in gt_timestamps.keys(): 70 | if len(pred_timestamps[key]) < 1: 71 | continue 72 | if(iou(gt_timestamps[key], pred_timestamps[key][0]) >= c_iou): 73 | Result[c_iou] = Result[c_iou] + 1 74 | print("IOU 0.3: {0}\nIOU 0.5: {1}\nIOU 0.7: {2}".format(Result[0.3]*100/num, Result[0.5]*100/num, Result[0.7]*100/num)) -------------------------------------------------------------------------------- /metrics/tvg/eval_tvg.sh: -------------------------------------------------------------------------------- 1 | # default: eval all instances according to the gt_file 2 | python eval_tvg.py --pred_file your_pred_file --gt_file your_gt_file 3 | 4 | 5 | # use --sample 6 | # eval sampled instances according to the pred_file 7 | # e.g. 
# gt examples:500, # pred examples:50 -> # eval examples:50 8 | python eval_tvg.py --pred_file your_pred_file --gt_file your_gt_file --sample -------------------------------------------------------------------------------- /metrics/vhd/cd: -------------------------------------------------------------------------------- 1 | Calculating highlight scores with min score 2 (Fair) 2 | Time cost 0.35 seconds 3 | Calculating highlight scores with min score 3 (Good) 4 | Time cost 0.32 seconds 5 | Calculating highlight scores with min score 4 (VeryGood) 6 | Time cost 0.26 seconds 7 | HL-min-VeryGood-mAP: 16.0 8 | HL-min-VeryGood-Hit1: 36.09 9 | -------------------------------------------------------------------------------- /metrics/vhd/eval_highlights.sh: -------------------------------------------------------------------------------- 1 | your_gt_file=example_gt_file.json 2 | your_pred_file=example_pred_file.json 3 | python eval_highlights.py --pred_file $your_pred_file --gt_file $your_gt_file 4 | # --save_path your_result_save_path 5 | -------------------------------------------------------------------------------- /metrics/vhd/metrics.json: -------------------------------------------------------------------------------- 1 | { 2 | "brief": { 3 | "HL-min-Fair-mAP": 53.36, 4 | "HL-min-Fair-Hit1": 67.67, 5 | "HL-min-Good-mAP": 43.11, 6 | "HL-min-Good-Hit1": 64.66, 7 | "HL-min-VeryGood-mAP": 25.07, 8 | "HL-min-VeryGood-Hit1": 51.26 9 | }, 10 | "HL-min-Fair": { 11 | "HL-mAP": 53.36, 12 | "HL-Hit1": 67.67 13 | }, 14 | "HL-min-Good": { 15 | "HL-mAP": 43.11, 16 | "HL-Hit1": 64.66 17 | }, 18 | "HL-min-VeryGood": { 19 | "HL-mAP": 25.07, 20 | "HL-Hit1": 51.26 21 | } 22 | } -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .videochat2_qformer import VideoChat2_qformer 2 | 3 | from .videochat_mistra.videochat2_it4_mistral_LinearP import VideoChat2_it4_mistral_LinearP 4 | from .videochat_mistra.videochat2_it4_mistral_LinearProAda import VideoChat2_it4_mistral_LinearProAda 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /models/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/models/bert/__init__.py -------------------------------------------------------------------------------- /models/bert/builder.py: -------------------------------------------------------------------------------- 1 | from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel 2 | 3 | import logging 4 | logger = logging.getLogger(__name__) 5 | 6 | def build_bert(model_config, pretrain, checkpoint): 7 | """build text encoder. 8 | 9 | Args: 10 | model_config (dict): model config. 11 | pretrain (bool): Whether to do pretrain or finetuning. 12 | checkpoint (bool): whether to do gradient_checkpointing. 
13 | 14 | Returns: TODO 15 | 16 | """ 17 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 18 | bert_config.encoder_width = model_config.vision_encoder.d_model 19 | bert_config.gradient_checkpointing = checkpoint 20 | bert_config.fusion_layer = model_config.text_encoder.fusion_layer 21 | 22 | if not model_config.multimodal.enable: 23 | bert_config.fusion_layer = bert_config.num_hidden_layers 24 | 25 | if pretrain: 26 | text_encoder, loading_info = BertForMaskedLM.from_pretrained( 27 | model_config.text_encoder.pretrained, 28 | config=bert_config, 29 | output_loading_info=True, 30 | ) 31 | else: 32 | text_encoder, loading_info = BertModel.from_pretrained( 33 | model_config.text_encoder.pretrained, 34 | config=bert_config, 35 | add_pooling_layer=False, 36 | output_loading_info=True, 37 | ) 38 | 39 | return text_encoder 40 | 41 | 42 | def build_bert_decoder(model_config, checkpoint): 43 | """build text decoder the same as the multimodal encoder. 44 | 45 | Args: 46 | model_config (dict): model config. 47 | pretrain (bool): Whether to do pretrain or finetuning. 48 | checkpoint (bool): whether to do gradient_checkpointing. 49 | 50 | Returns: TODO 51 | 52 | """ 53 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 54 | bert_config.encoder_width = model_config.vision_encoder.d_model 55 | bert_config.gradient_checkpointing = checkpoint 56 | 57 | bert_config.fusion_layer = 0 58 | bert_config.num_hidden_layers = ( 59 | bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer 60 | ) 61 | 62 | text_decoder, loading_info = BertLMHeadModel.from_pretrained( 63 | model_config.text_encoder.pretrained, 64 | config=bert_config, 65 | output_loading_info=True, 66 | ) 67 | 68 | return text_decoder 69 | -------------------------------------------------------------------------------- /models/blip2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/models/blip2/__init__.py -------------------------------------------------------------------------------- /models/blip2/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import contextlib 8 | import os 9 | import logging 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from .Qformer import BertConfig, BertLMHeadModel 15 | from .vit import build_vit 16 | from transformers import BertTokenizer 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Blip2Base(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | @classmethod 26 | def init_tokenizer(cls, truncation_side="right"): 27 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side, local_files_only=True) 28 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 29 | return tokenizer 30 | 31 | @property 32 | def device(self): 33 | return list(self.parameters())[0].device 34 | 35 | def maybe_autocast(self, dtype=torch.float16): 36 | # if on cpu, don't use autocast 37 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 38 | enable_autocast = self.device != torch.device("cpu") 39 | 40 | if enable_autocast: 41 | return torch.cuda.amp.autocast(dtype=dtype) 42 | else: 43 | return contextlib.nullcontext() 44 | 45 | @classmethod 46 | def init_Qformer( 47 | cls, 48 | num_query_token, vision_width, 49 | qformer_hidden_dropout_prob=0.1, 50 | qformer_attention_probs_dropout_prob=0.1, 51 | qformer_drop_path_rate=0., 52 | ): 53 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True) 54 | encoder_config.encoder_width = vision_width 55 | # insert cross-attention layer every other block 56 | encoder_config.add_cross_attention = True 57 | encoder_config.cross_attention_freq = 2 58 | encoder_config.query_length = num_query_token 59 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob 60 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob 61 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)] 62 | logger.info(f"Drop_path:{encoder_config.drop_path_list}") 63 | logger.info(encoder_config) 64 | Qformer = BertLMHeadModel(config=encoder_config) 65 | query_tokens = nn.Parameter( 66 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 67 | ) 68 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 69 | return Qformer, query_tokens 70 | 71 | @classmethod 72 | def init_vision_encoder_umt(self, config): 73 | """build vision encoder 74 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`. 
75 | 76 | """ 77 | vision_encoder = build_vit(config) 78 | 79 | if config.vision_encoder.vit_add_ln: 80 | vision_layernorm = nn.LayerNorm(config.vision_encoder.encoder_embed_dim, eps=1e-12) 81 | else: 82 | vision_layernorm = nn.Identity() 83 | 84 | return vision_encoder, vision_layernorm 85 | 86 | 87 | def disabled_train(self, mode=True): 88 | """Overwrite model.train with this function to make sure train/eval mode 89 | does not change anymore.""" 90 | return self 91 | 92 | 93 | class LayerNorm(nn.LayerNorm): 94 | """Subclass torch's LayerNorm to handle fp16.""" 95 | 96 | def forward(self, x: torch.Tensor): 97 | orig_type = x.dtype 98 | ret = super().forward(x.type(torch.float32)) 99 | return ret.type(orig_type) 100 | -------------------------------------------------------------------------------- /models/blip2/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | 5 | 6 | from .Qformer import BertConfig, BertLMHeadModel 7 | from models.utils import load_temp_embed_with_mismatch 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def build_qformer(num_query_token, vision_width, 13 | qformer_hidden_dropout_prob=0.1, 14 | qformer_attention_probs_dropout_prob=0.1, 15 | drop_path_rate=0., 16 | ): 17 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True) 18 | encoder_config.encoder_width = vision_width 19 | # insert cross-attention layer every other block 20 | encoder_config.add_cross_attention = True 21 | encoder_config.cross_attention_freq = 2 22 | encoder_config.query_length = num_query_token 23 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob 24 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob 25 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_config.num_hidden_layers)] 26 | logger.info(f"Drop_path:{encoder_config.drop_path_list}") 27 | logger.info(encoder_config) 28 | Qformer = BertLMHeadModel.from_pretrained( 29 | "bert-base-uncased", config=encoder_config, local_files_only=True 30 | ) 31 | query_tokens = nn.Parameter( 32 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 33 | ) 34 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 35 | return Qformer, query_tokens 36 | 37 | def interpolate_pos_embed_blip(state_dict, new_model): 38 | if "vision_temp_embed" in state_dict: 39 | vision_temp_embed_new = new_model.state_dict()["vision_temp_embed"] 40 | state_dict["vision_temp_embed"] = load_temp_embed_with_mismatch( 41 | state_dict["vision_temp_embed"], vision_temp_embed_new, add_zero=False 42 | ) 43 | return state_dict 44 | -------------------------------------------------------------------------------- /models/videochat_mistra/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/models/videochat_mistra/__init__.py -------------------------------------------------------------------------------- /prompts/alignment_image.txt: -------------------------------------------------------------------------------- 1 | Describe this image in detail. 2 | Take a look at this image and describe what you notice. 3 | Please provide a detailed description of the picture. 4 | Could you describe the contents of this image for me? 
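Each prompts/*.txt file in this part of the repository holds one instruction variant per line. As a minimal sketch of how such a file can be consumed — assuming the dataloader simply draws one variant at random per sample, which is an assumption since the sampling code is not part of this listing, and sample_instruction is an illustrative helper name:

import random

def sample_instruction(prompt_file: str) -> str:
    # One instruction variant per line; blank lines are ignored.
    with open(prompt_file, "r") as f:
        variants = [line.strip() for line in f if line.strip()]
    return random.choice(variants)

# e.g. sample_instruction("prompts/concise_description.txt")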
-------------------------------------------------------------------------------- /prompts/concise_description.txt: -------------------------------------------------------------------------------- 1 | Describe the following video concisely. 2 | Provide a brief description of the given video clip. 3 | Offer a succinct explanation of the footage presented. 4 | Summarize the visual content of the following video. 5 | Give a short and clear explanation of the subsequent video clip. 6 | Share a concise interpretation of the video provided. 7 | Present a compact description of the clip's key features. 8 | Relay a brief, clear account of the video shown. 9 | Render a clear and concise summary of the video below. 10 | Write a terse but informative summary of the following video clip. 11 | Create a compact narrative representing the video presented. -------------------------------------------------------------------------------- /prompts/concise_image_description.txt: -------------------------------------------------------------------------------- 1 | Describe the following image concisely. 2 | Provide a brief description of the given image. 3 | Offer a succinct explanation of the picture presented. 4 | Summarize the visual content of the following image. 5 | Give a short and clear explanation of the subsequent image. 6 | Share a concise interpretation of the image provided. 7 | Present a compact description of the photo's key features. 8 | Relay a brief, clear account of the picture shown. 9 | Render a clear and concise summary of the photo below. 10 | Write a terse but informative summary of the following picture. 11 | Create a compact narrative representing the image presented. -------------------------------------------------------------------------------- /prompts/dvc_description.txt: -------------------------------------------------------------------------------- 1 | Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences. The output format of each predicted event should be like: 'start - end seconds, event description'. A specific example is : ' 90 - 102 seconds, spread margarine on two slices of white bread in the video' . -------------------------------------------------------------------------------- /prompts/dvc_description_with_asr.txt: -------------------------------------------------------------------------------- 1 | You are able to understand the visual content that the user provides. Localize a series of activity events in the video with the aid of transcribed speech, output the start and end timestamp for each event, and describe each event with sentences. The output format of each predicted event should be like: 'start - end seconds, event description'. A specific example is : ' 90 - 102 seconds, spread margarine on two slices of white bread in the video' . -------------------------------------------------------------------------------- /prompts/dvc_description_zeroshot.txt: -------------------------------------------------------------------------------- 1 | You are given a cooking video from the YouCook2 dataset. Please watch the video and extract a maximum of 10 significant cooking steps. For each step, determine the starting and ending times and provide a concise description. The format should be: 'start time - end time, brief step description'. For example, ' 90 - 102 seconds, spread margarine on two slices of white bread'.
-------------------------------------------------------------------------------- /prompts/dvc_post_check.txt: -------------------------------------------------------------------------------- 1 | Be careful about your output format, you should provide start time and end time for each step. The format should be: 'start time - end time, brief step description', for example, ' 90 - 102 seconds, spread margarine on two slices of white bread'. -------------------------------------------------------------------------------- /prompts/tvg_description.txt: -------------------------------------------------------------------------------- 1 | Localize the visual content described by the given textual query {} in the video, and output the start and end timestamps in seconds. The output format of the predicted timestamp should be like: 'start - end seconds'. A specific example is : '20.8 - 30.0 seconds' . -------------------------------------------------------------------------------- /prompts/tvg_description_zeroshot.txt: -------------------------------------------------------------------------------- 1 | You are given a video from the {} dataset. Please find the visual event described by a sentence in the video, determining its starting and ending times. The format should be: 'The event happens in the start time - end time'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seconds. Now I will give you the textual sentence: {}. Please return its start time and end time. -------------------------------------------------------------------------------- /prompts/tvg_post_check.txt: -------------------------------------------------------------------------------- 1 | Be careful about your output format, you should provide start time and end time for the visual event. The format should be: 'The event happens in the start time - end time'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seconds. -------------------------------------------------------------------------------- /prompts/vhd_description.txt: -------------------------------------------------------------------------------- 1 | You are given a video from the {} dataset. Please find the highlight moments in the video described by a sentence query, determining the highlight moment's timestamp and its saliency score. The output format should be like: 'There are 10 highlight moments in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 second. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. Now I will give you the sentence query: {}. Please return the query-based highlight moments. -------------------------------------------------------------------------------- /prompts/vhd_description_zeroshot.txt: -------------------------------------------------------------------------------- 1 | You are given a video from the {} dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and its saliency score on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 seconds. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. Now I will give you the sentence query: {}. Please return the query-based highlight timestamps and salient scores.
-------------------------------------------------------------------------------- /prompts/vhd_description_zeroshot_new.txt: -------------------------------------------------------------------------------- 1 | You are given a video from the {} dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and its saliency score on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 52, 54, 56, 58 seconds. Their saliency scores are 3.0, 3.0, 3.0, 3.0'. Now I will give you the sentence query: {}. Please return the query-based highlight timestamps and salient scores. -------------------------------------------------------------------------------- /prompts/vhd_description_zeroshot_post.txt: -------------------------------------------------------------------------------- 1 | You are given a video from the {} dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and its saliency score on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 60 - 80 seconds. Their saliency scores are 3.0'. Now I will give you the sentence query: {}. Please return the query-based highlight timestamps and salient scores. -------------------------------------------------------------------------------- /prompts/vhd_post_check.txt: -------------------------------------------------------------------------------- 1 | Be careful about your output format, you should provide timestamp and saliency score for the highlight moment. The output format should be like: 'There are 10 highlight moments in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 second. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. 
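Several of the grounding and highlight templates above carry '{}' placeholders for the dataset name and the textual query. A minimal sketch of how such a template could be filled at inference time, assuming plain str.format substitution; the actual call site is not part of this listing and build_grounding_prompt is an illustrative helper name:

def build_grounding_prompt(template_path: str, dataset_name: str, query: str) -> str:
    # Read the single-line template and substitute its two {} slots in order.
    with open(template_path, "r") as f:
        template = f.read().strip()
    return template.format(dataset_name, query)

# e.g. build_grounding_prompt("prompts/tvg_description_zeroshot.txt",
#                             "Charades-STA", "person turns a light on")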
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.28.0 3 | addict==2.4.0 4 | aiofiles==23.2.1 5 | aiohttp==3.9.3 6 | aiosignal==1.3.1 7 | aliyun-python-sdk-core==2.15.0 8 | aliyun-python-sdk-kms==2.16.2 9 | altair==5.2.0 10 | annotated-types==0.6.0 11 | antlr4-python3-runtime==4.9.3 12 | anyio==4.6.0 13 | anykeystore==0.2 14 | apex==0.9.10.dev0 15 | appdirs==1.4.4 16 | argon2-cffi==23.1.0 17 | argon2-cffi-bindings==21.2.0 18 | arrow==1.3.0 19 | asttokens==2.4.1 20 | astunparse==1.6.3 21 | async-lru==2.0.4 22 | async-timeout==4.0.3 23 | attrs==24.2.0 24 | audioread==3.0.1 25 | av==10.0.0 26 | awscli==1.32.87 27 | Babel==2.14.0 28 | beautifulsoup4==4.12.3 29 | bitsandbytes==0.42.0 30 | bleach==6.1.0 31 | blis==0.7.11 32 | boto3==1.34.91 33 | botocore==1.34.91 34 | braceexpand==0.1.7 35 | catalogue==2.0.10 36 | certifi==2024.8.30 37 | cffi==1.17.1 38 | charset-normalizer==3.3.2 39 | click==8.1.7 40 | cloudpathlib==0.16.0 41 | colorama==0.4.4 42 | coloredlogs==15.0.1 43 | comm==0.2.2 44 | conda-pack==0.7.1 45 | confection==0.1.4 46 | contourpy==1.2.0 47 | crcmod==1.7 48 | cryptacular==1.6.2 49 | cryptography==42.0.5 50 | cycler==0.12.1 51 | cymem==2.0.8 52 | datasets==2.19.0 53 | debugpy==1.8.1 54 | decorator==5.1.1 55 | decord==0.6.0 56 | deepspeed==0.14.0 57 | defusedxml==0.7.1 58 | dill==0.3.8 59 | docker-pycreds==0.4.0 60 | docutils==0.16 61 | einops==0.6.1 62 | environs==11.0.0 63 | exceptiongroup==1.2.2 64 | executing==2.1.0 65 | fairscale==0.4.13 66 | fastapi==0.110.0 67 | fastjsonschema==2.19.1 68 | ffmpeg==1.4 69 | ffmpy==0.3.2 70 | filelock==3.13.1 71 | flash-attn==2.5.6 72 | flatbuffers==24.3.7 73 | fonttools==4.50.0 74 | fqdn==1.5.1 75 | frozenlist==1.4.1 76 | fsspec==2024.3.1 77 | ftfy==6.2.0 78 | fvcore==0.1.5.post20221221 79 | gast==0.5.4 80 | gitdb==4.0.11 81 | GitPython==3.1.42 82 | google-pasta==0.2.0 83 | gradio==3.35.0 84 | gradio_client==0.12.0 85 | greenlet==3.0.3 86 | grpcio==1.62.1 87 | h11==0.14.0 88 | h5py==3.10.0 89 | hjson==3.1.0 90 | httpcore==1.0.5 91 | httpx==0.27.2 92 | huggingface-hub==0.21.4 93 | humanfriendly==10.0 94 | humanize==4.9.0 95 | hupper==1.12.1 96 | idna==3.10 97 | imageio==2.27.0 98 | imageio-ffmpeg==0.4.9 99 | importlib_metadata==8.5.0 100 | importlib_resources==6.3.1 101 | iopath==0.1.10 102 | ipykernel==6.29.3 103 | ipython==8.18.1 104 | ipywidgets==8.1.2 105 | isoduration==20.11.0 106 | jedi==0.19.1 107 | Jinja2==3.1.4 108 | jmespath==0.10.0 109 | joblib==1.4.0 110 | json5==0.9.24 111 | jsonpointer==2.4 112 | jsonschema==4.23.0 113 | jsonschema-specifications==2023.12.1 114 | jupyter==1.0.0 115 | jupyter-console==6.6.3 116 | jupyter-events==0.10.0 117 | jupyter-lsp==2.2.4 118 | jupyter_client==8.6.1 119 | jupyter_core==5.7.2 120 | jupyter_server==2.13.0 121 | jupyter_server_terminals==0.5.3 122 | jupyterlab==4.1.5 123 | jupyterlab_pygments==0.3.0 124 | jupyterlab_server==2.25.4 125 | jupyterlab_widgets==3.0.10 126 | keras==3.1.1 127 | kiwisolver==1.4.5 128 | langcodes==3.3.0 129 | lazy_loader==0.4 130 | libclang==18.1.1 131 | librosa==0.10.1 132 | linkify-it-py==2.0.3 133 | llvmlite==0.42.0 134 | Markdown==3.6 135 | markdown-it-py==2.2.0 136 | MarkupSafe==2.1.5 137 | marshmallow==3.21.1 138 | matplotlib==3.8.3 139 | matplotlib-inline==0.1.7 140 | mdit-py-plugins==0.3.3 141 | mdurl==0.1.2 142 | mistune==3.0.2 143 | mkl-fft==1.3.1 144 | mkl-random @ 
file:///tmp/build/80754af9/mkl_random_1626186066731/work 145 | mkl-service==2.4.0 146 | ml-dtypes==0.3.2 147 | mmcv==2.1.0 148 | mmengine==0.10.3 149 | model-index==0.1.11 150 | moviepy==1.0.3 151 | msgpack==1.0.8 152 | multidict==6.0.5 153 | multiprocess==0.70.16 154 | multiprocessing-logging==0.3.4 155 | murmurhash==1.0.10 156 | namex==0.0.7 157 | nbclient==0.10.0 158 | nbconvert==7.16.2 159 | nbformat==5.10.3 160 | nest-asyncio==1.6.0 161 | ninja==1.11.1.1 162 | notebook==7.1.2 163 | notebook_shim==0.2.4 164 | numba==0.59.1 165 | numpy==1.23.5 166 | oauthlib==3.2.2 167 | omegaconf==2.3.0 168 | opencv-python==4.7.0.72 169 | opendatalab==0.0.10 170 | openmim==0.3.9 171 | openxlab==0.0.36 172 | opt-einsum==3.3.0 173 | optree==0.10.0 174 | ordered-set==4.1.0 175 | orjson==3.9.15 176 | oss2==2.17.0 177 | overrides==7.7.0 178 | packaging==24.1 179 | pandas==1.5.3 180 | pandocfilters==1.5.1 181 | parso==0.8.4 182 | PasteDeploy==3.1.0 183 | pathtools==0.1.2 184 | pbkdf2==1.3 185 | peft==0.3.0 186 | pexpect==4.9.0 187 | Pillow==9.5.0 188 | plaster==1.1.2 189 | plaster-pastedeploy==1.0.1 190 | platformdirs==4.3.6 191 | pooch==1.8.1 192 | portalocker==2.8.2 193 | preshed==3.0.9 194 | proglog==0.1.10 195 | prometheus_client==0.20.0 196 | prompt_toolkit==3.0.47 197 | protobuf==4.25.3 198 | psutil==6.0.0 199 | ptyprocess==0.7.0 200 | pure_eval==0.2.3 201 | py-cpuinfo==9.0.0 202 | pyarrow==16.0.0 203 | pyarrow-hotfix==0.6 204 | pyasn1==0.6.0 205 | pycocoevalcap==1.2 206 | pycocotools==2.0.7 207 | pycparser==2.22 208 | pycryptodome==3.20.0 209 | pydantic==2.6.4 210 | pydantic_core==2.16.3 211 | pydub==0.25.1 212 | Pygments==2.18.0 213 | pynvml==11.5.0 214 | pyparsing==3.1.2 215 | pyramid==2.0.2 216 | pyramid-mailer==0.15.1 217 | pysubs2==1.7.3 218 | python-dateutil==2.9.0.post0 219 | python-dotenv==1.0.1 220 | python-json-logger==2.0.7 221 | python-multipart==0.0.9 222 | python3-openid==3.2.0 223 | pytube==15.0.0 224 | pytz==2023.4 225 | PyYAML==6.0.2 226 | pyzmq==25.1.2 227 | qtconsole==5.5.1 228 | QtPy==2.4.1 229 | redis==5.0.3 230 | referencing==0.35.1 231 | regex==2023.12.25 232 | repoze.sendmail==4.4.1 233 | requests==2.32.3 234 | requests-oauthlib==1.4.0 235 | rfc3339-validator==0.1.4 236 | rfc3986-validator==0.1.1 237 | rich==13.4.2 238 | rpds-py==0.20.0 239 | rsa==4.7.2 240 | s3transfer==0.10.1 241 | safetensors==0.4.2 242 | scikit-learn==1.4.1.post1 243 | scipy==1.10.1 244 | semantic-version==2.10.0 245 | Send2Trash==1.8.2 246 | sentencepiece==0.1.99 247 | sentry-sdk==1.42.0 248 | setproctitle==1.3.3 249 | six==1.16.0 250 | smart-open==6.4.0 251 | smmap==5.0.1 252 | sniffio==1.3.1 253 | soundfile==0.12.1 254 | soupsieve==2.5 255 | soxr==0.3.7 256 | spacy==3.7.4 257 | spacy-legacy==3.0.12 258 | spacy-loggers==1.0.5 259 | SQLAlchemy==2.0.28 260 | srsly==2.4.8 261 | stack-data==0.6.3 262 | starlette==0.36.3 263 | tabulate==0.9.0 264 | tensorboard==2.16.2 265 | tensorboard-data-server==0.7.2 266 | tensorflow==2.16.1 267 | tensorflow-io-gcs-filesystem==0.36.0 268 | termcolor==2.3.0 269 | terminado==0.18.1 270 | thinc==8.2.3 271 | threadpoolctl==3.4.0 272 | timm==0.6.12 273 | tinycss2==1.2.1 274 | tokenizers==0.19.1 275 | tomli==2.0.1 276 | toolz==0.12.1 277 | torch==1.13.1 278 | torchaudio==0.13.1 279 | torchvision==0.14.1 280 | tornado==6.4.1 281 | tqdm==4.65.2 282 | traitlets==5.14.3 283 | transaction==4.0 284 | transformers==4.40.0 285 | translationstring==1.4 286 | typer==0.9.4 287 | types-python-dateutil==2.9.0.20240316 288 | typing_extensions==4.12.2 289 | tzdata==2024.1 290 | 
uc-micro-py==1.0.3 291 | uri-template==1.3.0 292 | urllib3==2.2.3 293 | uvicorn==0.28.0 294 | velruse==1.1.1 295 | venusian==3.1.0 296 | wandb==0.14.0 297 | wasabi==1.1.2 298 | wcwidth==0.2.13 299 | weasel==0.3.4 300 | webcolors==1.13 301 | webdataset==0.2.86 302 | webencodings==0.5.1 303 | WebOb==1.8.7 304 | websocket-client==1.7.0 305 | websockets==11.0.3 306 | webvtt-py==0.5.1 307 | Werkzeug==3.0.1 308 | widgetsnbextension==4.0.10 309 | wrapt==1.16.0 310 | WTForms==3.1.2 311 | wtforms-recaptcha==0.3.2 312 | xxhash==3.4.1 313 | yacs==0.1.8 314 | yapf==0.40.2 315 | yarl==1.9.4 316 | zipp==3.20.2 317 | zope.deprecation==5.0 318 | zope.interface==6.2 319 | zope.sqlalchemy==3.1 320 | -------------------------------------------------------------------------------- /scripts/videochat_mistral/config_LinearP.py: -------------------------------------------------------------------------------- 1 | from configs.instruction_data import * 2 | 3 | # ========================= data ========================== 4 | 5 | train_corpus = "TimePro_Normal" 6 | 7 | train_file = "${available_corpus[${train_corpus}]}" 8 | test_file = dict() 9 | test_types = [] 10 | num_workers = 6 11 | 12 | stop_key = None 13 | 14 | # ========================= input ========================== 15 | num_frames_test = 16 16 | num_frames = 192 17 | clip_frames = 8 18 | good_init = True 19 | 20 | batch_size = 2 21 | max_txt_l = 1536 22 | 23 | save_iter=1000 24 | 25 | pre_text = False 26 | 27 | inputs = dict( 28 | image_res=224, 29 | video_input=dict( 30 | num_frames="${num_frames}", 31 | sample_type="rand", 32 | num_frames_test="${num_frames_test}", 33 | sample_type_test="middle", 34 | random_aug=False, 35 | ), 36 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 37 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 38 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 39 | ) 40 | 41 | # ========================= model ========================== 42 | model = dict( 43 | model_cls="VideoChat2_it4_mistral_LinearP", 44 | vit_blip_model_path="/path_to_the_timesuite_root_folder/download/parameters/umt_l16_qformer.pth", 45 | mistral_model_path="/path_to_the_timesuite_root_folder/download/parameters/Mistral-7B-Instruct-v0.2", 46 | videochat2_model_path="/path_to_the_timesuite_root_folder/download/parameters/videochat2_mistral_7b_stage3.pth", 47 | freeze_vit=True, 48 | freeze_qformer=False, 49 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3 50 | clip_frames="${clip_frames}", 51 | num_frames="${num_frames}", 52 | token_merge_len = 4, 53 | # vit 54 | low_resource=False, 55 | vision_encoder=dict( 56 | name="vit_l14", 57 | img_size=224, 58 | patch_size=16, 59 | d_model=1024, 60 | encoder_embed_dim=1024, 61 | encoder_depth=24, 62 | encoder_num_heads=16, 63 | drop_path_rate=0., 64 | num_frames="${clip_frames}", 65 | tubelet_size=1, 66 | use_checkpoint=True, 67 | checkpoint_num=18, 68 | pretrained="", 69 | return_index=-2, 70 | vit_add_ln=True, 71 | ckpt_num_frame=4, 72 | ), 73 | # qformer 74 | num_query_token=32, 75 | qformer_hidden_dropout_prob=0.1, 76 | qformer_attention_probs_dropout_prob=0.1, 77 | qformer_drop_path_rate=0.2, 78 | extra_num_query_token=64, 79 | qformer_text_input=True, 80 | # prompt 81 | system="", 82 | start_token="<Video>", 83 | end_token="</Video>", 84 | add_second_msg=True, 85 | img_start_token="<Image>", 86 | img_end_token="</Image>", 87 | random_shuffle=True, 88 | return_question_instruction=False, 89 | use_flash_attention=True, 90 | use_lora=True, 91 | lora_r=16, 92 | lora_alpha=32, 93 | lora_dropout=0.1, 94 | #
debug=True, 95 | ) 96 | 97 | optimizer = dict( 98 | opt="adamW", 99 | lr=2e-5, 100 | opt_betas=[0.9, 0.999], # default 101 | weight_decay=0.02, 102 | max_grad_norm=-1, # requires a positive float, use -1 to disable 103 | # use a different lr for some modules, e.g., larger lr for new modules 104 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 105 | ) 106 | 107 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) 108 | 109 | evaluate = False 110 | deep_fusion = False 111 | evaluation = dict( 112 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 113 | eval_x_only=False, 114 | k_test=128, 115 | eval_offload=True, # offload gpu tensors to cpu to save memory. 116 | ) 117 | 118 | fp16 = True 119 | gradient_checkpointing = True 120 | 121 | # ========================= wandb ========================== 122 | wandb = dict( 123 | enable=False, 124 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 125 | project="videogpt", # setup in your command line 126 | ) 127 | dist_url = "env://" 128 | device = "cuda" 129 | mode = "it_mistral" 130 | 131 | # ========================= others ========================== 132 | output_dir = None # output dir 133 | resume = False # if True, load optimizer and scheduler states as well 134 | debug = False 135 | log_freq = 10 136 | seed = 42 137 | 138 | save_latest = False 139 | auto_resume = True 140 | pretrained_path = "" # path to pretrained model weights, for resume only? 141 | 142 | deepspeed = dict( 143 | enable=False, 144 | stage=1, 145 | ) -------------------------------------------------------------------------------- /scripts/videochat_mistral/config_LinearProAda.py: -------------------------------------------------------------------------------- 1 | from configs.instruction_data import * 2 | 3 | # ========================= data ========================== 4 | train_corpus = "TimePro_Normal" 5 | 6 | train_file = "${available_corpus[${train_corpus}]}" 7 | test_file = dict() 8 | test_types = [] 9 | num_workers = 6 10 | 11 | stop_key = None 12 | 13 | # ========================= input ========================== 14 | num_frames_test = 16 15 | num_frames = 128 16 | clip_frames = 8 17 | good_init = True 18 | 19 | batch_size = 2 20 | max_txt_l = 1024 21 | 22 | save_iter=1000 23 | 24 | pre_text = False 25 | 26 | inputs = dict( 27 | image_res=224, 28 | video_input=dict( 29 | num_frames="${num_frames}", 30 | sample_type="rand", 31 | num_frames_test="${num_frames_test}", 32 | sample_type_test="middle", 33 | random_aug=False, 34 | ), 35 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 36 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 37 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 38 | ) 39 | 40 | # ========================= model ========================== 41 | model = dict( 42 | model_cls="VideoChat2_it4_mistral_LinearProAda", 43 | vit_blip_model_path="/path_to_the_timesuite_root_folder/download/parameters/umt_l16_qformer.pth", 44 | mistral_model_path="/path_to_the_timesuite_root_folder/download/parameters/Mistral-7B-Instruct-v0.2", 45 | videochat2_model_path="/path_to_the_timesuite_root_folder/download/parameters/videochat2_mistral_7b_stage3.pth", 46 | pretrained_path="your_LinearP_ckpt_path/ckpt_00.pth", 47 | freeze_vit=True, 48 | freeze_qformer=False, 49 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3 50 | clip_frames="${clip_frames}", 51 | num_frames="${num_frames}", 52 | token_merge_len = 4, 
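    # Note (assumption, not stated in this config): with num_frames=128 and clip_frames=8 the
    # video is presumably split into 128 / 8 = 16 short clips that are encoded clip by clip, and
    # token_merge_len=4 presumably merges every 4 adjacent visual tokens before they reach the
    # LLM; the exact behaviour is defined by VideoChat2_it4_mistral_LinearProAda, not by this file.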
53 | # vit 54 | low_resource=False, 55 | vision_encoder=dict( 56 | name="vit_l14", 57 | img_size=224, 58 | patch_size=16, 59 | d_model=1024, 60 | encoder_embed_dim=1024, 61 | encoder_depth=24, 62 | encoder_num_heads=16, 63 | drop_path_rate=0., 64 | num_frames="${clip_frames}", 65 | tubelet_size=1, 66 | use_checkpoint=True, 67 | checkpoint_num=18, 68 | pretrained="", 69 | return_index=-2, 70 | vit_add_ln=True, 71 | ckpt_num_frame=4, 72 | ), 73 | # qformer 74 | num_query_token=32, 75 | qformer_hidden_dropout_prob=0.1, 76 | qformer_attention_probs_dropout_prob=0.1, 77 | qformer_drop_path_rate=0.2, 78 | extra_num_query_token=64, 79 | qformer_text_input=True, 80 | # prompt 81 | system="", 82 | start_token="<Video>", 83 | end_token="</Video>", 84 | add_second_msg=True, 85 | img_start_token="<Image>", 86 | img_end_token="</Image>", 87 | random_shuffle=True, 88 | return_question_instruction=False, 89 | use_flash_attention=True, 90 | use_lora=True, 91 | lora_r=16, 92 | lora_alpha=32, 93 | lora_dropout=0.1, 94 | # debug=True, 95 | ) 96 | 97 | optimizer = dict( 98 | opt="adamW", 99 | lr=1.5e-5, 100 | opt_betas=[0.9, 0.999], # default 101 | weight_decay=0.02, 102 | max_grad_norm=-1, # requires a positive float, use -1 to disable 103 | # use a different lr for some modules, e.g., larger lr for new modules 104 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 105 | ) 106 | 107 | scheduler = dict(sched="cosine", epochs=2, min_lr_multi=0.2, warmup_epochs=0.05) 108 | 109 | evaluate = False 110 | deep_fusion = False 111 | evaluation = dict( 112 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 113 | eval_x_only=False, 114 | k_test=128, 115 | eval_offload=True, # offload gpu tensors to cpu to save memory. 116 | ) 117 | 118 | fp16 = True 119 | gradient_checkpointing = True 120 | 121 | # ========================= wandb ========================== 122 | wandb = dict( 123 | enable=False, 124 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 125 | project="videogpt", # setup in your command line 126 | ) 127 | dist_url = "env://" 128 | device = "cuda" 129 | mode = "it_mistral" 130 | 131 | # ========================= others ========================== 132 | output_dir = None # output dir 133 | resume = False # if True, load optimizer and scheduler states as well 134 | debug = False 135 | log_freq = 10 136 | seed = 42 137 | 138 | save_latest = False 139 | auto_resume = True 140 | pretrained_path = "" # path to pretrained model weights, for resume only?
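# Note: this top-level pretrained_path is read by tasks/shared_utils.setup_model for (auto-)resume,
# whereas model.pretrained_path above points at the stage-1 (LinearP) Grounded Tuning checkpoint
# and is presumably consumed by the model class itself when it is constructed.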
141 | 142 | deepspeed = dict( 143 | enable=False, 144 | stage=1, 145 | ) -------------------------------------------------------------------------------- /scripts/videochat_mistral/config_LinearProAdaFT.py: -------------------------------------------------------------------------------- 1 | from configs.instruction_data import * 2 | 3 | # ========================= data ========================== 4 | 5 | train_corpus = "FT_Temporal_Grounding_Both" 6 | 7 | 8 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 9 | test_file = dict() 10 | test_types = [] 11 | num_workers = 6 12 | 13 | stop_key = None 14 | 15 | # ========================= input ========================== 16 | num_frames_test = 16 17 | num_frames = 128 18 | clip_frames = 8 19 | good_init = True 20 | 21 | batch_size = 2 22 | max_txt_l = 1024 23 | 24 | save_iter=1000 25 | 26 | pre_text = False 27 | 28 | inputs = dict( 29 | image_res=224, 30 | video_input=dict( 31 | num_frames="${num_frames}", 32 | sample_type="rand", 33 | num_frames_test="${num_frames_test}", 34 | sample_type_test="middle", 35 | random_aug=False, 36 | ), 37 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 38 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 39 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 40 | ) 41 | 42 | # ========================= model ========================== 43 | model = dict( 44 | model_cls="VideoChat2_it4_mistral_LinearProAda", 45 | vit_blip_model_path="/mnt/petrelfs/share_data/likunchang/model/videochat2/umt_l16_qformer.pth", 46 | mistral_model_path="/mnt/petrelfs/share_data/likunchang/model/llm//Mistral-7B-Instruct-v0.2", 47 | videochat2_model_path="/mnt/petrelfs/share_data/likunchang/model/videochat2/videochat2_mistral_7b_stage3.pth", 48 | pretrained_path="your_LinearProAda_ckpt_path/ckpt_01.pth", 49 | freeze_vit=True, 50 | freeze_qformer=False, 51 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3 52 | clip_frames="${clip_frames}", 53 | num_frames="${num_frames}", 54 | token_merge_len = 4, 55 | # vit 56 | low_resource=False, 57 | vision_encoder=dict( 58 | name="vit_l14", 59 | img_size=224, 60 | patch_size=16, 61 | d_model=1024, 62 | encoder_embed_dim=1024, 63 | encoder_depth=24, 64 | encoder_num_heads=16, 65 | drop_path_rate=0., 66 | num_frames="${clip_frames}", 67 | tubelet_size=1, 68 | use_checkpoint=True, 69 | checkpoint_num=18, 70 | pretrained="", 71 | return_index=-2, 72 | vit_add_ln=True, 73 | ckpt_num_frame=4, 74 | ), 75 | # qformer 76 | num_query_token=32, 77 | qformer_hidden_dropout_prob=0.1, 78 | qformer_attention_probs_dropout_prob=0.1, 79 | qformer_drop_path_rate=0.2, 80 | extra_num_query_token=64, 81 | qformer_text_input=True, 82 | # prompt 83 | system="", 84 | start_token="<Video>", 85 | end_token="</Video>", 86 | add_second_msg=True, 87 | img_start_token="<Image>", 88 | img_end_token="</Image>", 89 | random_shuffle=True, 90 | return_question_instruction=False, 91 | use_flash_attention=True, 92 | use_lora=True, 93 | lora_r=16, 94 | lora_alpha=32, 95 | lora_dropout=0.1, 96 | # debug=True, 97 | ) 98 | 99 | optimizer = dict( 100 | opt="adamW", 101 | lr=1e-5, 102 | opt_betas=[0.9, 0.999], # default 103 | weight_decay=0.02, 104 | max_grad_norm=-1, # requires a positive float, use -1 to disable 105 | # use a different lr for some modules, e.g., larger lr for new modules 106 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 107 | ) 108 | 109 | scheduler = dict(sched="cosine", epochs=5, min_lr_multi=0.1, warmup_epochs=0.1) 110 | 111 | evaluate = False 112 | deep_fusion = False
113 | evaluation = dict( 114 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 115 | eval_x_only=False, 116 | k_test=128, 117 | eval_offload=True, # offload gpu tensors to cpu to save memory. 118 | ) 119 | 120 | fp16 = True 121 | gradient_checkpointing = True 122 | 123 | # ========================= wandb ========================== 124 | wandb = dict( 125 | enable=False, 126 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 127 | project="videogpt", # setup in your command line 128 | ) 129 | dist_url = "env://" 130 | device = "cuda" 131 | mode = "it_mistral" 132 | 133 | # ========================= others ========================== 134 | output_dir = None # output dir 135 | resume = False # if True, load optimizer and scheduler states as well 136 | debug = False 137 | log_freq = 10 138 | seed = 42 139 | 140 | save_latest = False 141 | auto_resume = True 142 | pretrained_path = "" # path to pretrained model weights, for resume only? 143 | 144 | deepspeed = dict( 145 | enable=False, 146 | stage=1, 147 | ) -------------------------------------------------------------------------------- /scripts/videochat_mistral/run_7b_stage4.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | ROOT_DIR=/path_to_the_timesuite_root_folder 10 | PARTITION='video5' 11 | TRAIN_TYPE="" 12 | 13 | 14 | # <<< Grounded Tuning Epoch 1 >>> 15 | MODEL_TYPE='LinearP' 16 | JOB_NAME="F192_CF8${TRAIN_TYPE}_${MODEL_TYPE}_TimePro_Normal" 17 | 18 | 19 | # <<< Grounded Tuning Epoch 2 & 3 >>> 20 | # MODEL_TYPE='LinearProAda' 21 | # JOB_NAME="F128_CF8${TRAIN_TYPE}_${MODEL_TYPE}_TimePro_Normal" 22 | 23 | 24 | # <<< Supervised FT on Charades-STA & QVHighlight >>> 25 | # MODEL_TYPE='LinearProAdaFT' 26 | # JOB_NAME="F128_CF8${TRAIN_TYPE}_${MODEL_TYPE}_FT_Both" 27 | 28 | 29 | NNODE=2 30 | NUM_GPUS=8 31 | NUM_CPUS=128 32 | 33 | 34 | OUTPUT_DIR="${ROOT_DIR}/$(dirname $0)/${JOB_NAME}" 35 | echo "Model Dir : ${OUTPUT_DIR}" 36 | mkdir ${OUTPUT_DIR} 37 | 38 | 39 | 40 | 41 | # srun -p ${PARTITION} \ 42 | # --job-name=${JOB_NAME} \ 43 | # -n${NNODE} \ 44 | # --gres=gpu:${NUM_GPUS} \ 45 | # --ntasks-per-node=1 \ 46 | # --cpus-per-task=${NUM_CPUS} \ 47 | 48 | bash torchrun.sh \ 49 | --nnodes=${NNODE} \ 50 | --nproc_per_node=${NUM_GPUS} \ 51 | --rdzv_backend=c10d \ 52 | tasks/train_it4${TRAIN_TYPE}.py \ 53 | $(dirname $0)/config_${MODEL_TYPE}${TRAIN_TYPE}.py \ 54 | output_dir ${OUTPUT_DIR} \ 55 | 2>&1 | tee ${OUTPUT_DIR}/bash_output.log -------------------------------------------------------------------------------- /scripts/videochat_mistral/run_7b_stage4_ds.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 
8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | 10 | PARTITION='video5' 11 | JOB_NAME='stage4_ds_TimeIT' 12 | NNODE=1 13 | NUM_GPUS=8 14 | NUM_CPUS=96 15 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME" 16 | 17 | # srun -p ${PARTITION} \ 18 | # --job-name=${JOB_NAME} \ 19 | # -n${NNODE} \ 20 | # --gres=gpu:${NUM_GPUS} \ 21 | # --ntasks-per-node=1 \ 22 | # --cpus-per-task=${NUM_CPUS} \ 23 | 24 | bash torchrun.sh \ 25 | --nnodes=${NNODE} \ 26 | --nproc_per_node=${NUM_GPUS} \ 27 | --rdzv_backend=c10d \ 28 | tasks/train_it4_ds.py \ 29 | $(dirname $0)/config_7b_stage4_ds.py \ 30 | output_dir ${OUTPUT_DIR} \ 31 | 2>&1 | tee ${OUTPUT_DIR}/${JOB_NAME}_bash_output.log 32 | -------------------------------------------------------------------------------- /tasks/shared_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import os.path as osp 5 | from os.path import join 6 | 7 | import torch 8 | from torch.utils.data import ConcatDataset, DataLoader 9 | 10 | from utils.optimizer import create_optimizer 11 | from utils.scheduler import create_scheduler 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_media_types(datasources): 17 | """get the media types for for all the dataloaders. 18 | 19 | Args: 20 | datasources (List): List of dataloaders or datasets. 21 | 22 | Returns: List. The media_types. 23 | 24 | """ 25 | if isinstance(datasources[0], DataLoader): 26 | datasets = [dataloader.dataset for dataloader in datasources] 27 | else: 28 | datasets = datasources 29 | media_types = [ 30 | dataset.datasets[0].media_type 31 | if isinstance(dataset, ConcatDataset) 32 | else dataset.media_type 33 | for dataset in datasets 34 | ] 35 | 36 | return media_types 37 | 38 | 39 | def setup_model( 40 | config, model_cls, find_unused_parameters=False 41 | ): 42 | logger.info("Creating model") 43 | config = copy.deepcopy(config) 44 | 45 | model = model_cls(config=config.model) 46 | 47 | model = model.to(torch.device(config.device)) 48 | model_without_ddp = model 49 | if config.distributed: 50 | model = torch.nn.parallel.DistributedDataParallel( 51 | model, 52 | device_ids=[config.gpu], 53 | find_unused_parameters=find_unused_parameters, # `False` for image-only task 54 | ) 55 | 56 | optimizer = create_optimizer(config.optimizer, model) 57 | scheduler = create_scheduler(config.scheduler, optimizer) 58 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16) 59 | 60 | start_epoch = 0 61 | global_step = 0 62 | 63 | # auto resume the latest checkpoint 64 | if config.get("auto_resume", False): 65 | logger.info("Auto resuming") 66 | model_latest = join(config.output_dir, "ckpt_latest.pth") 67 | model_best = join(config.output_dir, "ckpt_best.pth") 68 | if not osp.isfile(model_latest): 69 | large_num = -1 70 | for p in os.listdir(config.output_dir): 71 | if 'ckpt' in p: 72 | num = p.split('_')[1].split('.')[0] 73 | if str.isnumeric(num): 74 | if int(num) > large_num: 75 | large_num = int(num) 76 | if large_num != -1: 77 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth") 78 | if osp.isfile(model_latest): 79 | config.pretrained_path = model_latest 80 | config.resume = True 81 | elif osp.isfile(model_best): 82 | config.pretrained_path = model_best 83 | config.resume = True 84 | else: 85 | logger.info(f"Not found checkpoint in {config.output_dir}") 86 | 87 | if osp.isfile(config.pretrained_path): 88 | checkpoint = torch.load(config.pretrained_path, map_location="cpu") 89 | state_dict = checkpoint["model"] 90 | 91 | if 
config.resume: 92 | optimizer.load_state_dict(checkpoint["optimizer"]) 93 | scheduler.load_state_dict(checkpoint["scheduler"]) 94 | scaler.load_state_dict(checkpoint["scaler"]) 95 | start_epoch = checkpoint["epoch"] + 1 96 | global_step = checkpoint["global_step"] 97 | 98 | msg = model_without_ddp.load_state_dict(state_dict, strict=False) 99 | logger.info(msg) 100 | logger.info(f"Loaded checkpoint from {config.pretrained_path}") 101 | else: 102 | logger.warning("No pretrained checkpoint provided, training from scratch") 103 | 104 | return ( 105 | model, 106 | model_without_ddp, 107 | optimizer, 108 | scheduler, 109 | scaler, 110 | start_epoch, 111 | global_step, 112 | ) 113 | -------------------------------------------------------------------------------- /tasks/shared_utils_ds.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import os.path as osp 5 | from os.path import join 6 | 7 | import torch 8 | import deepspeed 9 | from torch.utils.data import ConcatDataset, DataLoader 10 | 11 | from utils.optimizer import create_optimizer 12 | from utils.scheduler import create_scheduler 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def get_media_types(datasources): 18 | """get the media types for for all the dataloaders. 19 | 20 | Args: 21 | datasources (List): List of dataloaders or datasets. 22 | 23 | Returns: List. The media_types. 24 | 25 | """ 26 | if isinstance(datasources[0], DataLoader): 27 | datasets = [dataloader.dataset for dataloader in datasources] 28 | else: 29 | datasets = datasources 30 | media_types = [ 31 | dataset.datasets[0].media_type 32 | if isinstance(dataset, ConcatDataset) 33 | else dataset.media_type 34 | for dataset in datasets 35 | ] 36 | 37 | return media_types 38 | 39 | 40 | def setup_model( 41 | config, model_cls, find_unused_parameters=False, num_steps_per_epoch=-1, 42 | ): 43 | logger.info("Creating model") 44 | config = copy.deepcopy(config) 45 | 46 | model = model_cls(config=config.model) 47 | 48 | model = model.to(torch.device(config.device)) 49 | if config.fp16: 50 | if config.get('bf16', True): 51 | logger.info("Change to bfloat16 for model") 52 | model = model.to(torch.bfloat16) 53 | else: 54 | logger.info("Change to float16 for model") 55 | model = model.half() 56 | model_without_ddp = model 57 | 58 | if hasattr(config, "deepspeed") and config.deepspeed.enable: 59 | optimizer_params = create_optimizer(config.optimizer, model, return_group=True) 60 | scheduler = None 61 | scaler = None 62 | else: 63 | if config.distributed: 64 | model = torch.nn.parallel.DistributedDataParallel( 65 | model, 66 | device_ids=[config.gpu], 67 | find_unused_parameters=find_unused_parameters, # `False` for image-only task 68 | ) 69 | 70 | optimizer = create_optimizer(config.optimizer, model) 71 | scheduler = create_scheduler(config.scheduler, optimizer) 72 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16) 73 | 74 | start_epoch = 0 75 | global_step = 0 76 | 77 | # auto resume the latest checkpoint 78 | if config.get("auto_resume", False): 79 | logger.info("Auto resuming") 80 | model_latest = join(config.output_dir, "ckpt_latest.pth") 81 | model_best = join(config.output_dir, "ckpt_best.pth") 82 | 83 | large_step_num = -1 84 | large_num = -1 85 | for p in os.listdir(config.output_dir): 86 | if 'ckpt_iter' in p: 87 | num = p.split('_iter')[1].split('.')[0] 88 | if str.isnumeric(num): 89 | if int(num) > large_step_num: 90 | large_step_num = int(num) 91 | elif 'ckpt_' in p: 
92 | num = p.split('_')[1].split('.')[0] 93 | if str.isnumeric(num): 94 | if int(num) > large_num: 95 | large_num = int(num) 96 | if large_step_num != -1: 97 | logger.info(f"Load the latest step: {large_step_num}") 98 | model_latest = join(config.output_dir, f"ckpt_iter{large_step_num:02d}.pth") 99 | if large_num != -1 and (large_num + 1) * num_steps_per_epoch > large_step_num: 100 | logger.info(f"Load the latest epoch: {large_num}") 101 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth") 102 | 103 | if hasattr(config, "deepspeed") and config.deepspeed.enable: 104 | if osp.isdir(model_latest): 105 | config.pretrained_path = model_latest 106 | config.resume = True 107 | elif osp.isdir(model_best): 108 | config.pretrained_path = model_best 109 | config.resume = True 110 | else: 111 | logger.info(f"Not found checkpoint in {config.output_dir}") 112 | else: 113 | if osp.isfile(model_latest): 114 | config.pretrained_path = model_latest 115 | config.resume = True 116 | elif osp.isfile(model_best): 117 | config.pretrained_path = model_best 118 | config.resume = True 119 | else: 120 | logger.info(f"Not found checkpoint in {config.output_dir}") 121 | 122 | # load pretrained model 123 | if hasattr(config, "deepspeed") and config.deepspeed.enable: 124 | logger.info('Use deepspeed to initialize model!!!') 125 | model = model_without_ddp 126 | model, optimizer, _, _ = deepspeed.initialize( 127 | args=config, model=model, model_parameters=optimizer_params, dist_init_required=not config.distributed, 128 | lr_scheduler=lambda opt: create_scheduler(config.scheduler, opt) 129 | ) 130 | if osp.isdir(config.pretrained_path): 131 | logger.info(f"Load pretrained model from {config.pretrained_path}") 132 | output_dir, tag = os.path.split(config.pretrained_path) 133 | if config.resume: 134 | _, client_state = model.load_checkpoint(output_dir, tag=tag, load_module_strict=False) 135 | global_step = model.global_steps 136 | assert num_steps_per_epoch > 0, "Please provide num_steps_per_epoch" 137 | start_epoch = global_step // num_steps_per_epoch 138 | else: 139 | _, client_state = model.load_checkpoint( 140 | output_dir, tag=tag, load_module_strict=False, 141 | load_optimizer_states=False, load_lr_scheduler_states=False, 142 | load_module_only=True 143 | ) 144 | else: 145 | if osp.isfile(config.pretrained_path): 146 | checkpoint = torch.load(config.pretrained_path, map_location="cpu") 147 | logger.info(f"Load pretrained model from {config.pretrained_path}") 148 | if 'model' in checkpoint.keys(): 149 | state_dict = checkpoint["model"] 150 | elif 'module' in checkpoint.keys(): 151 | state_dict = checkpoint["module"] 152 | else: 153 | state_dict = checkpoint 154 | # resume optimizer 155 | if config.resume: 156 | optimizer.load_state_dict(checkpoint["optimizer"]) 157 | scheduler.load_state_dict(checkpoint["scheduler"]) 158 | scaler.load_state_dict(checkpoint["scaler"]) 159 | start_epoch = checkpoint["epoch"] + 1 160 | global_step = checkpoint["global_step"] 161 | 162 | msg = model_without_ddp.load_state_dict(state_dict, strict=False) 163 | logger.info(msg) 164 | logger.info(f"Loaded checkpoint from {config.pretrained_path}") 165 | else: 166 | logger.warning("No pretrained checkpoint provided, training from scratch") 167 | 168 | logger.info(f"Cuda memory after create model: {torch.cuda.memory_allocated() // 1024**2}M, Max mem: {torch.cuda.max_memory_allocated() // 1024**2}M") 169 | 170 | return ( 171 | model, 172 | model_without_ddp, 173 | optimizer, 174 | scheduler, 175 | scaler, 176 | start_epoch, 177 
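# strict=False tolerates missing or unexpected keys (e.g. when the checkpoint stores only a
# subset of the model's parameters); the returned msg is logged below so any mismatch stays visible.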
| global_step, 178 | ) 179 | -------------------------------------------------------------------------------- /tasks/shared_utils_qformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import os.path as osp 5 | from os.path import join 6 | 7 | import torch 8 | from torch.utils.data import ConcatDataset, DataLoader 9 | 10 | from models.bert.tokenization_bert import BertTokenizer 11 | from utils.optimizer import create_optimizer 12 | from utils.scheduler import create_scheduler 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def get_media_types(datasources): 18 | """get the media types for for all the dataloaders. 19 | 20 | Args: 21 | datasources (List): List of dataloaders or datasets. 22 | 23 | Returns: List. The media_types. 24 | 25 | """ 26 | if isinstance(datasources[0], DataLoader): 27 | datasets = [dataloader.dataset for dataloader in datasources] 28 | else: 29 | datasets = datasources 30 | media_types = [ 31 | dataset.datasets[0].media_type 32 | if isinstance(dataset, ConcatDataset) 33 | else dataset.media_type 34 | for dataset in datasets 35 | ] 36 | 37 | return media_types 38 | 39 | 40 | def setup_model( 41 | config, model_cls, find_unused_parameters=False 42 | ): 43 | logger.info("Creating model") 44 | config = copy.deepcopy(config) 45 | 46 | if "bert" in config.model.text_encoder.name: 47 | tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained, local_files_only=True) 48 | else: 49 | raise ValueError(f"Not supported text encoder.") 50 | 51 | model = model_cls(config=config, tokenizer=tokenizer) 52 | 53 | model = model.to(torch.device(config.device)) 54 | model_without_ddp = model 55 | if config.distributed: 56 | model = torch.nn.parallel.DistributedDataParallel( 57 | model, 58 | device_ids=[config.gpu], 59 | find_unused_parameters=find_unused_parameters, # `False` for image-only task 60 | ) 61 | 62 | optimizer = create_optimizer(config.optimizer, model) 63 | scheduler = create_scheduler(config.scheduler, optimizer) 64 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16) 65 | 66 | start_epoch = 0 67 | global_step = 0 68 | 69 | # auto resume the latest checkpoint 70 | if config.get("auto_resume", False): 71 | logger.info("Auto resuming") 72 | model_latest = join(config.output_dir, "ckpt_latest.pth") 73 | model_best = join(config.output_dir, "ckpt_best.pth") 74 | large_num = -1 75 | for p in os.listdir(config.output_dir): 76 | if 'ckpt' in p: 77 | num = p.split('_')[1].split('.')[0] 78 | if str.isnumeric(num): 79 | if int(num) > large_num: 80 | large_num = int(num) 81 | if large_num != -1: 82 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth") 83 | if osp.isfile(model_latest): 84 | config.pretrained_path = model_latest 85 | config.resume = True 86 | elif osp.isfile(model_best): 87 | config.pretrained_path = model_best 88 | config.resume = True 89 | else: 90 | logger.info(f"Not found checkpoint in {config.output_dir}") 91 | 92 | if osp.isfile(config.pretrained_path): 93 | checkpoint = torch.load(config.pretrained_path, map_location="cpu") 94 | state_dict = checkpoint["model"] 95 | 96 | if config.resume: 97 | optimizer.load_state_dict(checkpoint["optimizer"]) 98 | scheduler.load_state_dict(checkpoint["scheduler"]) 99 | scaler.load_state_dict(checkpoint["scaler"]) 100 | start_epoch = checkpoint["epoch"] + 1 101 | global_step = checkpoint["global_step"] 102 | 103 | 104 | msg = model_without_ddp.load_state_dict(state_dict, strict=False) 
105 | logger.info(msg) 106 | logger.info(f"Loaded checkpoint from {config.pretrained_path}") 107 | else: 108 | logger.warning("No pretrained checkpoint provided, training from scratch") 109 | 110 | return ( 111 | model, 112 | model_without_ddp, 113 | optimizer, 114 | scheduler, 115 | scaler, 116 | tokenizer, 117 | start_epoch, 118 | global_step, 119 | ) 120 | -------------------------------------------------------------------------------- /tasks/train_it.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | from os.path import join 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.distributed as dist 9 | import wandb 10 | 11 | from dataset import MetaLoader, create_dataset, create_loader, create_sampler 12 | from models import * 13 | from tasks.shared_utils import get_media_types, setup_model 14 | from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed) 15 | from utils.config_utils import setup_main 16 | from utils.distributed import get_rank, get_world_size, is_main_process 17 | from utils.logger import log_dict_to_wandb, setup_wandb 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def train( 23 | model, 24 | train_loaders, 25 | optimizer, 26 | epoch, 27 | global_step, 28 | device, 29 | scheduler, 30 | scaler, 31 | config, 32 | ): 33 | model.train() 34 | 35 | metric_logger = MetricLogger(delimiter=" ") 36 | metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}")) 37 | loss_names = ["loss"] 38 | 39 | media_types = get_media_types(train_loaders) 40 | 41 | for name in loss_names: 42 | for m in media_types: 43 | metric_logger.add_meter( 44 | f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}") 45 | ) 46 | 47 | header = f"Train Epoch: [{epoch}]" 48 | log_freq = config.log_freq 49 | 50 | if config.distributed: 51 | for d in train_loaders: 52 | d.sampler.set_epoch(epoch) 53 | train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders)))) 54 | 55 | iterator = metric_logger.log_every(train_loader, log_freq, header) 56 | for i, (media_type, (image, text, instruction, _)) in enumerate(iterator): 57 | image = image.to(device, non_blocking=True) 58 | 59 | with torch.cuda.amp.autocast(enabled=config.fp16): 60 | loss_dict = model(image, text, instruction) 61 | loss = sum(loss_dict.values()) 62 | 63 | optimizer.zero_grad() 64 | scaler.scale(loss).backward() 65 | if config.optimizer.max_grad_norm > 0: 66 | scaler.unscale_(optimizer) 67 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) 68 | scaler.step(optimizer) 69 | scaler.update() 70 | scheduler.step() 71 | 72 | # logging 73 | for name in loss_names: 74 | value = loss_dict[name] 75 | value = value if isinstance(value, float) else value.item() 76 | metric_logger.update(**{f"{media_type}-{name}": value}) 77 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 78 | 79 | if is_main_process() and config.wandb.enable and global_step % log_freq == 0: 80 | logs = metric_logger.get_global_avg_dict() 81 | log_dict_to_wandb(logs, step=global_step, prefix="train/") 82 | 83 | global_step += 1 84 | 85 | if config.debug and global_step % 20 == 0: 86 | logger.info("debug mode, break training loop") 87 | break 88 | 89 | if config.debug and global_step % (2 * log_freq + 3) == 0: 90 | logger.info("debug mode, break training loop") 91 | break 92 | 93 | # gather the stats from all processes 94 | metric_logger.synchronize_between_processes() 95 
| logger.info(f"Averaged stats: {metric_logger.global_avg()}") 96 | return global_step 97 | 98 | 99 | def setup_dataloaders(config, mode="pt"): 100 | # train datasets, create a list of data loaders 101 | logger.info(f"Creating dataset for {mode}") 102 | train_datasets = create_dataset(f"{mode}_train", config) 103 | media_types = get_media_types(train_datasets) 104 | 105 | if config.distributed: 106 | num_tasks = get_world_size() 107 | global_rank = get_rank() 108 | samplers = create_sampler( 109 | train_datasets, [True] * len(media_types), num_tasks, global_rank 110 | ) 111 | else: 112 | samplers = [None] * len(media_types) 113 | 114 | train_loaders = create_loader( 115 | train_datasets, 116 | samplers, 117 | batch_size=[config.inputs.batch_size[k] for k in media_types], 118 | num_workers=[config.num_workers] * len(media_types), 119 | is_trains=[True] * len(media_types), 120 | collate_fns=[None] * len(media_types), 121 | ) # [0] 122 | 123 | return train_loaders, media_types 124 | 125 | 126 | def main(config): 127 | if is_main_process() and config.wandb.enable: 128 | run = setup_wandb(config) 129 | 130 | logger.info(f"train_file: {config.train_file}") 131 | 132 | setup_seed(config.seed + get_rank()) 133 | device = torch.device(config.device) 134 | 135 | train_loaders, train_media_types = setup_dataloaders( 136 | config, mode=config.mode 137 | ) 138 | num_steps_per_epoch = sum(len(d) for d in train_loaders) 139 | config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs 140 | config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs 141 | # set cudnn.benchmark=True only when input size is fixed 142 | # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3 143 | cudnn.benchmark = len(train_media_types) == 1 144 | 145 | model_cls = eval(config.model.get('model_cls', 'VideoChat2_it_vicuna')) 146 | ( 147 | model, 148 | model_without_ddp, 149 | optimizer, 150 | scheduler, 151 | scaler, 152 | start_epoch, 153 | global_step, 154 | ) = setup_model( 155 | config, 156 | model_cls=model_cls, 157 | # find_unused_parameters=True, 158 | find_unused_parameters=False, 159 | ) 160 | if is_main_process() and config.wandb.enable: 161 | wandb.watch(model) 162 | 163 | logger.info("Start training") 164 | start_time = time.time() 165 | for epoch in range(start_epoch, config.scheduler.epochs): 166 | if not config.evaluate: 167 | global_step = train( 168 | model, 169 | train_loaders, 170 | optimizer, 171 | epoch, 172 | global_step, 173 | device, 174 | scheduler, 175 | scaler, 176 | config, 177 | ) 178 | 179 | if is_main_process(): 180 | logger.info(f"Epoch {epoch}") 181 | param_grad_dic = { 182 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters() 183 | } 184 | state_dict = model_without_ddp.state_dict() 185 | for k in list(state_dict.keys()): 186 | if k in param_grad_dic.keys() and not param_grad_dic[k]: 187 | # delete parameters that do not require gradient 188 | del state_dict[k] 189 | save_obj = { 190 | "model": state_dict, 191 | "optimizer": optimizer.state_dict(), 192 | "scheduler": scheduler.state_dict(), 193 | "scaler": scaler.state_dict(), 194 | "config": config, 195 | "epoch": epoch, 196 | "global_step": global_step, 197 | } 198 | if config.get("save_latest", False): 199 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth")) 200 | else: 201 | torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth")) 202 | 203 | if config.evaluate: 204 | break 205 | 206 | dist.barrier() 207 | 208 | 
total_time = time.time() - start_time 209 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 210 | logger.info(f"Training time {total_time_str}") 211 | logger.info(f"Checkpoints and Logs saved at {config.output_dir}") 212 | 213 | if is_main_process() and config.wandb.enable: 214 | run.finish() 215 | 216 | 217 | if __name__ == "__main__": 218 | cfg = setup_main() 219 | main(cfg) 220 | -------------------------------------------------------------------------------- /tasks/train_pt.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | from os.path import join 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.distributed as dist 9 | import wandb 10 | 11 | from dataset import MetaLoader, create_dataset, create_loader, create_sampler 12 | from models import * 13 | from tasks.shared_utils import get_media_types, setup_model 14 | from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed) 15 | from utils.config_utils import setup_main 16 | from utils.distributed import get_rank, get_world_size, is_main_process 17 | from utils.logger import log_dict_to_wandb, setup_wandb 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def train( 23 | model, 24 | train_loaders, 25 | optimizer, 26 | epoch, 27 | global_step, 28 | device, 29 | scheduler, 30 | scaler, 31 | config, 32 | ): 33 | model.train() 34 | 35 | metric_logger = MetricLogger(delimiter=" ") 36 | metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}")) 37 | loss_names = ["loss"] 38 | 39 | media_types = get_media_types(train_loaders) 40 | 41 | for name in loss_names: 42 | for m in media_types: 43 | metric_logger.add_meter( 44 | f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}") 45 | ) 46 | 47 | header = f"Train Epoch: [{epoch}]" 48 | log_freq = config.log_freq 49 | 50 | if config.distributed: 51 | for d in train_loaders: 52 | d.sampler.set_epoch(epoch) 53 | train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders)))) 54 | 55 | iterator = metric_logger.log_every(train_loader, log_freq, header) 56 | for i, (media_type, (image, text, _)) in enumerate(iterator): 57 | image = image.to(device, non_blocking=True) 58 | 59 | with torch.cuda.amp.autocast(enabled=config.fp16): 60 | loss_dict = model(image, text) 61 | loss = sum(loss_dict.values()) 62 | 63 | optimizer.zero_grad() 64 | scaler.scale(loss).backward() 65 | if config.optimizer.max_grad_norm > 0: 66 | scaler.unscale_(optimizer) 67 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) 68 | scaler.step(optimizer) 69 | scaler.update() 70 | scheduler.step() 71 | 72 | # logging 73 | for name in loss_names: 74 | value = loss_dict[name] 75 | value = value if isinstance(value, float) else value.item() 76 | metric_logger.update(**{f"{media_type}-{name}": value}) 77 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 78 | 79 | if is_main_process() and config.wandb.enable and global_step % log_freq == 0: 80 | logs = metric_logger.get_global_avg_dict() 81 | log_dict_to_wandb(logs, step=global_step, prefix="train/") 82 | 83 | global_step += 1 84 | 85 | if config.debug and global_step % 20 == 0: 86 | logger.info("debug mode, break training loop") 87 | break 88 | 89 | if config.debug and global_step % (2 * log_freq + 3) == 0: 90 | logger.info("debug mode, break training loop") 91 | break 92 | 93 | # gather the stats from all processes 94 | 
metric_logger.synchronize_between_processes() 95 | logger.info(f"Averaged stats: {metric_logger.global_avg()}") 96 | return global_step 97 | 98 | 99 | def setup_dataloaders(config, mode="pt"): 100 | # train datasets, create a list of data loaders 101 | logger.info(f"Creating dataset for {mode}") 102 | train_datasets = create_dataset(f"{mode}_train", config) 103 | media_types = get_media_types(train_datasets) 104 | 105 | if config.distributed: 106 | num_tasks = get_world_size() 107 | global_rank = get_rank() 108 | samplers = create_sampler( 109 | train_datasets, [True] * len(media_types), num_tasks, global_rank 110 | ) 111 | else: 112 | samplers = [None] * len(media_types) 113 | 114 | train_loaders = create_loader( 115 | train_datasets, 116 | samplers, 117 | batch_size=[config.inputs.batch_size[k] for k in media_types], 118 | num_workers=[config.num_workers] * len(media_types), 119 | is_trains=[True] * len(media_types), 120 | collate_fns=[None] * len(media_types), 121 | ) # [0] 122 | 123 | return train_loaders, media_types 124 | 125 | 126 | def main(config): 127 | if is_main_process() and config.wandb.enable: 128 | run = setup_wandb(config) 129 | 130 | logger.info(f"train_file: {config.train_file}") 131 | 132 | setup_seed(config.seed + get_rank()) 133 | device = torch.device(config.device) 134 | 135 | train_loaders, train_media_types = setup_dataloaders( 136 | config, mode=config.mode 137 | ) 138 | num_steps_per_epoch = sum(len(d) for d in train_loaders) 139 | config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs 140 | config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs 141 | # set cudnn.benchmark=True only when input size is fixed 142 | # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3 143 | cudnn.benchmark = len(train_media_types) == 1 144 | 145 | model_cls = eval(config.model.get('model_cls', 'VideoChat2_pt_vicuna')) 146 | ( 147 | model, 148 | model_without_ddp, 149 | optimizer, 150 | scheduler, 151 | scaler, 152 | start_epoch, 153 | global_step, 154 | ) = setup_model( 155 | config, 156 | model_cls=model_cls, 157 | find_unused_parameters=True, 158 | ) 159 | if is_main_process() and config.wandb.enable: 160 | wandb.watch(model) 161 | 162 | logger.info("Start training") 163 | start_time = time.time() 164 | for epoch in range(start_epoch, config.scheduler.epochs): 165 | if not config.evaluate: 166 | global_step = train( 167 | model, 168 | train_loaders, 169 | optimizer, 170 | epoch, 171 | global_step, 172 | device, 173 | scheduler, 174 | scaler, 175 | config, 176 | ) 177 | 178 | if is_main_process(): 179 | logger.info(f"Epoch {epoch}") 180 | param_grad_dic = { 181 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters() 182 | } 183 | state_dict = model_without_ddp.state_dict() 184 | for k in list(state_dict.keys()): 185 | if k in param_grad_dic.keys() and not param_grad_dic[k]: 186 | # delete parameters that do not require gradient 187 | del state_dict[k] 188 | save_obj = { 189 | "model": state_dict, 190 | "optimizer": optimizer.state_dict(), 191 | "scheduler": scheduler.state_dict(), 192 | "scaler": scaler.state_dict(), 193 | "config": config, 194 | "epoch": epoch, 195 | "global_step": global_step, 196 | } 197 | if config.get("save_latest", False): 198 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth")) 199 | else: 200 | torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth")) 201 | 202 | if config.evaluate: 203 | break 204 | 205 | dist.barrier() 
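# all ranks wait here so that none of them starts the next epoch while rank 0 is still writing the checkpoint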
206 | 207 | total_time = time.time() - start_time 208 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 209 | logger.info(f"Training time {total_time_str}") 210 | logger.info(f"Checkpoints and Logs saved at {config.output_dir}") 211 | 212 | if is_main_process() and config.wandb.enable: 213 | run.finish() 214 | 215 | 216 | if __name__ == "__main__": 217 | cfg = setup_main() 218 | main(cfg) 219 | -------------------------------------------------------------------------------- /torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 3 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 4 | MASTER_PORT=$((10000 + $RANDOM % 100)) 5 | 6 | echo "All nodes used:" 7 | echo ${ALL_NODES} 8 | echo "Master node:" 9 | echo ${MASTER_NODE} 10 | echo "Master port:" 11 | echo ${MASTER_PORT} 12 | echo "Args:" 13 | echo $@ 14 | 15 | torchrun --rdzv_endpoint=${MASTER_NODE}:10072 $@ 16 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import torch.distributed as dist 5 | from os.path import dirname, join 6 | 7 | from utils.config import Config 8 | from utils.distributed import init_distributed_mode, is_main_process 9 | from utils.logger import setup_logger 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def setup_config(): 15 | """Conbine yaml config and command line config with OmegaConf. 16 | Also converts types, e.g., `'None'` (str) --> `None` (None) 17 | """ 18 | config = Config.get_config() 19 | if config.debug: 20 | config.wandb.enable = False 21 | return config 22 | 23 | 24 | def setup_evaluate_config(config): 25 | """setup evaluation default settings, e.g., disable wandb""" 26 | assert config.evaluate 27 | config.wandb.enable = False 28 | if config.output_dir is None: 29 | config.output_dir = join(dirname(config.pretrained_path), "eval") 30 | return config 31 | 32 | 33 | def setup_output_dir(output_dir, excludes=["code"]): 34 | """ensure not overwritting an exisiting/non-empty output dir""" 35 | if not os.path.exists(output_dir): 36 | os.makedirs(output_dir, exist_ok=False) 37 | else: 38 | existing_dirs_files = os.listdir(output_dir) # list 39 | remaining = set(existing_dirs_files) - set(excludes) 40 | remaining = [e for e in remaining if "slurm" not in e] 41 | remaining = [e for e in remaining if ".out" not in e] 42 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}" 43 | logger.warn(f"remaining dirs or files: {remaining}") 44 | 45 | 46 | def setup_deepspeed_zero_config(stage): 47 | # We currently set ZeRO based on stage: 48 | if stage == 1: 49 | return {"stage": 1, "reduce_bucket_size": 5e8} 50 | if stage == 2: 51 | return { 52 | "stage": 2, 53 | "contiguous_gradients": False, 54 | "overlap_comm": False, 55 | "reduce_scatter": True, 56 | "reduce_bucket_size": 5e8, 57 | "allgather_bucket_size": 5e8, 58 | "offload_optimizer": { 59 | "device": "cpu" 60 | }, 61 | } 62 | # return { 63 | # "stage": 2, 64 | # "contiguous_gradients": True, 65 | # "overlap_comm": True, 66 | # "reduce_scatter": True, 67 | # "reduce_bucket_size": 5e8, 68 | # "allgather_bucket_size": 5e8, 69 | # "cpu_offload": False, 70 | # } 71 | if stage == 3: 72 | return { 73 | "stage": 3, 74 | "contiguous_gradients": True, 75 | "stage3_max_live_parameters": 1e9, 76 | 
"stage3_max_reuse_distance": 1e9, 77 | "stage3_prefetch_bucket_size": 1e7, 78 | "stage3_param_persistence_threshold": 1e5, 79 | "reduce_bucket_size": 1e7, 80 | "sub_group_size": 1e9, 81 | "offload_optimizer": { 82 | "device": "cpu" 83 | }, 84 | "offload_param": { 85 | "device": "cpu" 86 | } 87 | } 88 | # return { 89 | # "stage": 3, 90 | # "contiguous_gradients": True, 91 | # "overlap_comm": True, 92 | # "reduce_scatter": True, 93 | # "reduce_bucket_size": 5e4, 94 | # "allgather_bucket_size": 5e4, 95 | # "cpu_offload": False, 96 | # "stage3_max_live_parameters": 1e5, 97 | # "stage3_max_reuse_distance": 1e5, 98 | # } 99 | 100 | raise ValueError("Wrong stage for deepspeed {}".format(stage.stage)) 101 | 102 | 103 | def setup_deepspeed_config(config): 104 | config.deepspeed_config = os.path.join(config.output_dir, "deepspeed_config.json") 105 | opts = config.optimizer 106 | logger.info(f'Write deepspeed config to {config.deepspeed_config}') 107 | if not is_main_process(): 108 | return config 109 | 110 | os.makedirs(config.output_dir, exist_ok=True) 111 | 112 | with open(config.deepspeed_config, mode="w") as writer: 113 | ds_config = { 114 | "train_batch_size": config.batch_size * dist.get_world_size(), 115 | "train_micro_batch_size_per_gpu": config.batch_size, 116 | "steps_per_print": 100, 117 | "optimizer": { 118 | "type": "Adam", 119 | "adam_w_mode": True, 120 | "params": { 121 | "lr": opts.lr, 122 | "weight_decay": opts.weight_decay, 123 | "bias_correction": True, 124 | "betas": [ 125 | opts.opt_betas[0], 126 | opts.opt_betas[1], 127 | ], 128 | "eps": 1e-8 129 | } 130 | } 131 | } 132 | if config.deepspeed.stage != 0: 133 | ds_config["zero_optimization"] = setup_deepspeed_zero_config(config.deepspeed.stage) 134 | 135 | if config.fp16: 136 | if config.get('bf16', True): 137 | ds_config["bf16"] = { 138 | "enabled": True 139 | } 140 | else: 141 | ds_config["fp16"] = { 142 | "enabled": True, 143 | "auto_cast": False, 144 | "loss_scale": 0, 145 | "initial_scale_power": 16, 146 | "loss_scale_window": 1000, 147 | "hysteresis": 2, 148 | "consecutive_hysteresis": False, 149 | "min_loss_scale": 1 150 | } 151 | else: 152 | assert config.deepspeed.stage == 0, "You must use fp16 or bf16 when using ZERO!!!" 153 | 154 | if config.get("max_grad_norm", -1) > 0: 155 | ds_config.update({"gradient_clipping", config.max_grad_norm}) 156 | 157 | writer.write(json.dumps(ds_config, indent=2)) 158 | 159 | return config 160 | 161 | 162 | def setup_main(): 163 | """ 164 | Setup config, logger, output_dir, etc. 165 | Shared for pretrain and all downstream tasks. 
166 | """ 167 | config = setup_config() 168 | if hasattr(config, "evaluate") and config.evaluate: 169 | config = setup_evaluate_config(config) 170 | init_distributed_mode(config) 171 | 172 | if hasattr(config, "deepspeed") and config.deepspeed.enable: 173 | config = setup_deepspeed_config(config) 174 | 175 | if is_main_process(): 176 | setup_output_dir(config.output_dir, excludes=["code"]) 177 | setup_logger(output=config.output_dir, color=True, name="vindlu") 178 | logger.info(f"config: {Config.pretty_text(config)}") 179 | Config.dump(config, os.path.join(config.output_dir, "config.json")) 180 | return config 181 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import logging 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def setup_for_distributed(is_master): 11 | import warnings 12 | 13 | builtin_warn = warnings.warn 14 | 15 | def warn(*args, **kwargs): 16 | force = kwargs.pop("force", False) 17 | if is_master or force: 18 | builtin_warn(*args, **kwargs) 19 | 20 | # Log warnings only once 21 | warnings.warn = warn 22 | warnings.simplefilter("once", UserWarning) 23 | 24 | if not is_master: 25 | logging.disable() 26 | 27 | 28 | def is_dist_avail_and_initialized(): 29 | if not dist.is_available(): 30 | return False 31 | if not dist.is_initialized(): 32 | return False 33 | return True 34 | 35 | 36 | def get_world_size(): 37 | if not is_dist_avail_and_initialized(): 38 | return 1 39 | return dist.get_world_size() 40 | 41 | 42 | def get_rank(): 43 | if not is_dist_avail_and_initialized(): 44 | return 0 45 | return dist.get_rank() 46 | 47 | 48 | def is_main_process(): 49 | return get_rank() == 0 50 | 51 | 52 | def save_on_master(*args, **kwargs): 53 | if is_main_process(): 54 | torch.save(*args, **kwargs) 55 | 56 | 57 | def is_port_in_use(port): 58 | import socket 59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 60 | return s.connect_ex(('localhost', port)) == 0 61 | 62 | 63 | def init_distributed_mode(args): 64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 65 | # job started by torch.distributed.launch 66 | args.rank = int(os.environ["RANK"]) 67 | args.world_size = int(os.environ['WORLD_SIZE']) 68 | args.gpu = int(os.environ['LOCAL_RANK']) 69 | elif 'SLURM_PROCID' in os.environ: 70 | # local rank on the current node / global rank 71 | local_rank = int(os.environ['SLURM_LOCALID']) 72 | global_rank = int(os.environ['SLURM_PROCID']) 73 | # number of processes / GPUs per node 74 | world_size = int(os.environ["SLURM_NNODES"]) * \ 75 | int(os.environ["SLURM_TASKS_PER_NODE"][0]) 76 | 77 | print(world_size) 78 | 79 | args.rank = global_rank 80 | args.gpu = local_rank 81 | args.world_size = world_size 82 | else: 83 | logger.info('Not using distributed mode') 84 | args.distributed = False 85 | return 86 | 87 | args.distributed = True 88 | 89 | torch.cuda.set_device(args.gpu) 90 | args.dist_backend = 'nccl' 91 | 92 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node 93 | dist_port = int(args.dist_url.split(":")[-1]) 94 | while is_port_in_use(dist_port): 95 | dist_port += 10 96 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)]) 97 | 98 | logger.info('| distributed init (rank {}): {}'.format( 99 | args.rank, args.dist_url)) 100 | if "SLURM_JOB_ID" in os.environ: 101 | logger.info(f"SLURM_JOB_ID 
{os.environ['SLURM_JOB_ID']}") 102 | torch.distributed.init_process_group( 103 | backend=args.dist_backend, init_method=args.dist_url, 104 | world_size=args.world_size, rank=args.rank) 105 | torch.distributed.barrier() 106 | setup_for_distributed(args.rank == 0) 107 | 108 | 109 | # Copyright (c) Facebook, Inc. and its affiliates. 110 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py 111 | class GatherLayer(torch.autograd.Function): 112 | """ 113 | Gather tensors from all workers with support for backward propagation: 114 | This implementation does not cut the gradients as torch.distributed.all_gather does. 115 | """ 116 | 117 | @staticmethod 118 | def forward(ctx, x): 119 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 120 | dist.all_gather(output, x) 121 | return tuple(output) 122 | 123 | @staticmethod 124 | def backward(ctx, *grads): 125 | all_gradients = torch.stack(grads) 126 | dist.all_reduce(all_gradients) 127 | return all_gradients[dist.get_rank()] 128 | 129 | 130 | # copied from megavlt 131 | def gather_tensor_along_batch_with_backward(tensor, dim=0): 132 | world_size = get_world_size() 133 | 134 | if world_size < 2: 135 | return tensor 136 | 137 | tensor_list = GatherLayer.apply(tensor) 138 | tensor_list = torch.cat(tensor_list, dim=dim) 139 | return tensor_list 140 | 141 | 142 | @torch.no_grad() 143 | def gather_tensor_along_batch(tensor, dim=0): 144 | """ 145 | Performs all_gather operation on the provided tensors. 146 | *** Warning ***: torch.distributed.all_gather has no gradient. 147 | """ 148 | world_size = get_world_size() 149 | 150 | if world_size < 2: 151 | return tensor 152 | 153 | with torch.no_grad(): 154 | tensor_list = [] 155 | 156 | for _ in range(world_size): 157 | tensor_list.append(torch.zeros_like(tensor)) 158 | 159 | dist.all_gather(tensor_list, tensor) 160 | tensor_list = torch.cat(tensor_list, dim=dim) 161 | return tensor_list 162 | -------------------------------------------------------------------------------- /utils/easydict.py: -------------------------------------------------------------------------------- 1 | class EasyDict(dict): 2 | """ 3 | Get attributes 4 | 5 | >>> d = EasyDict({'foo':3}) 6 | >>> d['foo'] 7 | 3 8 | >>> d.foo 9 | 3 10 | >>> d.bar 11 | Traceback (most recent call last): 12 | ... 
13 | AttributeError: 'EasyDict' object has no attribute 'bar' 14 | 15 | Works recursively 16 | 17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 18 | >>> isinstance(d.bar, dict) 19 | True 20 | >>> d.bar.x 21 | 1 22 | 23 | Bullet-proof 24 | 25 | >>> EasyDict({}) 26 | {} 27 | >>> EasyDict(d={}) 28 | {} 29 | >>> EasyDict(None) 30 | {} 31 | >>> d = {'a': 1} 32 | >>> EasyDict(**d) 33 | {'a': 1} 34 | 35 | Set attributes 36 | 37 | >>> d = EasyDict() 38 | >>> d.foo = 3 39 | >>> d.foo 40 | 3 41 | >>> d.bar = {'prop': 'value'} 42 | >>> d.bar.prop 43 | 'value' 44 | >>> d 45 | {'foo': 3, 'bar': {'prop': 'value'}} 46 | >>> d.bar.prop = 'newer' 47 | >>> d.bar.prop 48 | 'newer' 49 | 50 | 51 | Values extraction 52 | 53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 54 | >>> isinstance(d.bar, list) 55 | True 56 | >>> from operator import attrgetter 57 | >>> map(attrgetter('x'), d.bar) 58 | [1, 3] 59 | >>> map(attrgetter('y'), d.bar) 60 | [2, 4] 61 | >>> d = EasyDict() 62 | >>> d.keys() 63 | [] 64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 65 | >>> d.foo 66 | 3 67 | >>> d.bar.x 68 | 1 69 | 70 | Still like a dict though 71 | 72 | >>> o = EasyDict({'clean':True}) 73 | >>> o.items() 74 | [('clean', True)] 75 | 76 | And like a class 77 | 78 | >>> class Flower(EasyDict): 79 | ... power = 1 80 | ... 81 | >>> f = Flower() 82 | >>> f.power 83 | 1 84 | >>> f = Flower({'height': 12}) 85 | >>> f.height 86 | 12 87 | >>> f['power'] 88 | 1 89 | >>> sorted(f.keys()) 90 | ['height', 'power'] 91 | 92 | update and pop items 93 | >>> d = EasyDict(a=1, b='2') 94 | >>> e = EasyDict(c=3.0, a=9.0) 95 | >>> d.update(e) 96 | >>> d.c 97 | 3.0 98 | >>> d['c'] 99 | 3.0 100 | >>> d.get('c') 101 | 3.0 102 | >>> d.update(a=4, b=4) 103 | >>> d.b 104 | 4 105 | >>> d.pop('a') 106 | 4 107 | >>> d.a 108 | Traceback (most recent call last): 109 | ... 
110 | AttributeError: 'EasyDict' object has no attribute 'a' 111 | """ 112 | 113 | def __init__(self, d=None, **kwargs): 114 | if d is None: 115 | d = {} 116 | if kwargs: 117 | d.update(**kwargs) 118 | for k, v in d.items(): 119 | setattr(self, k, v) 120 | # Class attributes 121 | for k in self.__class__.__dict__.keys(): 122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): 123 | setattr(self, k, getattr(self, k)) 124 | 125 | def __setattr__(self, name, value): 126 | if isinstance(value, (list, tuple)): 127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value] 128 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 129 | value = self.__class__(value) 130 | super(EasyDict, self).__setattr__(name, value) 131 | super(EasyDict, self).__setitem__(name, value) 132 | 133 | __setitem__ = __setattr__ 134 | 135 | def update(self, e=None, **f): 136 | d = e or dict() 137 | d.update(f) 138 | for k in d: 139 | setattr(self, k, d[k]) 140 | 141 | def pop(self, k, d=None): 142 | if hasattr(self, k): 143 | delattr(self, k) 144 | return super(EasyDict, self).pop(k, d) 145 | 146 | 147 | if __name__ == "__main__": 148 | import doctest 149 | 150 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | """ Optimizer Factory w/ Custom Weight Decay 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import re 5 | import torch 6 | from torch import optim as optim 7 | from utils.distributed import is_main_process 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | try: 11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 12 | has_apex = True 13 | except ImportError: 14 | has_apex = False 15 | 16 | 17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True): 18 | named_param_tuples = [] 19 | for name, param in model.named_parameters(): 20 | if not param.requires_grad: 21 | continue # frozen weights 22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")): 23 | named_param_tuples.append([name, param, 0]) 24 | elif name in no_decay_list: 25 | named_param_tuples.append([name, param, 0]) 26 | else: 27 | named_param_tuples.append([name, param, weight_decay]) 28 | return named_param_tuples 29 | 30 | 31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr): 32 | """use lr=diff_lr for modules named found in diff_lr_names, 33 | otherwise use lr=default_lr 34 | 35 | Args: 36 | named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module 37 | diff_lr_names: List(str) 38 | diff_lr: float 39 | default_lr: float 40 | Returns: 41 | named_param_tuples_with_lr: List([name, param, weight_decay, lr]) 42 | """ 43 | named_param_tuples_with_lr = [] 44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}") 45 | for name, p, wd in named_param_tuples_or_model: 46 | use_diff_lr = False 47 | for diff_name in diff_lr_names: 48 | # if diff_name in name: 49 | if re.search(diff_name, name) is not None: 50 | logger.info(f"param {name} use different_lr: {diff_lr}") 51 | use_diff_lr = True 52 | break 53 | 54 | named_param_tuples_with_lr.append( 55 | [name, p, wd, diff_lr if use_diff_lr else default_lr] 56 | ) 57 | 58 | if is_main_process(): 59 | for name, _, wd, diff_lr in named_param_tuples_with_lr: 60 | logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}") 61 | 62 | return 
named_param_tuples_with_lr 63 | 64 | 65 | def create_optimizer_params_group(named_param_tuples_with_lr): 66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])""" 67 | group = {} 68 | for name, p, wd, lr in named_param_tuples_with_lr: 69 | if wd not in group: 70 | group[wd] = {} 71 | if lr not in group[wd]: 72 | group[wd][lr] = [] 73 | group[wd][lr].append(p) 74 | 75 | optimizer_params_group = [] 76 | for wd, lr_groups in group.items(): 77 | for lr, p in lr_groups.items(): 78 | optimizer_params_group.append(dict( 79 | params=p, 80 | weight_decay=wd, 81 | lr=lr 82 | )) 83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}") 84 | return optimizer_params_group 85 | 86 | 87 | def create_optimizer(args, model, filter_bias_and_bn=True, return_group=False): 88 | opt_lower = args.opt.lower() 89 | weight_decay = args.weight_decay 90 | # check for modules that requires different lr 91 | if hasattr(args, "different_lr") and args.different_lr.enable: 92 | diff_lr_module_names = args.different_lr.module_names 93 | diff_lr = args.different_lr.lr 94 | else: 95 | diff_lr_module_names = [] 96 | diff_lr = None 97 | 98 | no_decay = {} 99 | if hasattr(model, 'no_weight_decay'): 100 | no_decay = model.no_weight_decay() 101 | named_param_tuples = add_weight_decay( 102 | model, weight_decay, no_decay, filter_bias_and_bn) 103 | named_param_tuples = add_different_lr( 104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr) 105 | parameters = create_optimizer_params_group(named_param_tuples) 106 | 107 | if return_group: 108 | return parameters 109 | 110 | if 'fused' in opt_lower: 111 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 112 | 113 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 114 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 115 | opt_args['eps'] = args.opt_eps 116 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 117 | opt_args['betas'] = args.opt_betas 118 | if hasattr(args, 'opt_args') and args.opt_args is not None: 119 | opt_args.update(args.opt_args) 120 | 121 | opt_split = opt_lower.split('_') 122 | opt_lower = opt_split[-1] 123 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 124 | opt_args.pop('eps', None) 125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 126 | elif opt_lower == 'momentum': 127 | opt_args.pop('eps', None) 128 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 129 | elif opt_lower == 'adam': 130 | optimizer = optim.Adam(parameters, **opt_args) 131 | elif opt_lower == 'adamw': 132 | optimizer = optim.AdamW(parameters, **opt_args) 133 | else: 134 | assert False and "Invalid optimizer" 135 | raise ValueError 136 | return optimizer 137 | -------------------------------------------------------------------------------- /utils/quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.peft import LoraColumnParallelLinear, LoraRowParallelLinear 3 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear,RowParallelLinear 4 | # from accessory.model.meta import MetaModel 5 | from transformers.utils.quantization_config import BitsAndBytesConfig 6 | import bitsandbytes as bnb 7 | 8 | from types import MethodType 9 | from tqdm import tqdm 10 | 11 | from fairscale.nn.model_parallel.mappings import ( 12 | copy_to_model_parallel_region, 13 | gather_from_model_parallel_region, 14 | reduce_from_model_parallel_region, 15 | 
scatter_to_model_parallel_region, 16 | ) 17 | 18 | def forward_ColumnParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore 19 | # Set up backprop all-reduce. 20 | input_parallel = copy_to_model_parallel_region(input_) 21 | # Matrix multiply. 22 | output_parallel = self.quanted_layer(input_parallel) 23 | if self.bias is not None: 24 | output_parallel = output_parallel + self.bias 25 | if self.gather_output: 26 | # All-gather across the partitions. 27 | output = gather_from_model_parallel_region(output_parallel) 28 | else: 29 | output = output_parallel 30 | return output 31 | 32 | def forward_RowParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore 33 | # Set up backprop all-reduce. 34 | if self.input_is_parallel: 35 | input_parallel = input_ 36 | else: 37 | input_parallel = scatter_to_model_parallel_region(input_) 38 | # Matrix multiply. 39 | output_parallel = self.quanted_layer(input_parallel) 40 | # All-reduce across all the partitions. 41 | output_ = reduce_from_model_parallel_region(output_parallel) 42 | if self.bias is not None: 43 | output = output_ + self.bias 44 | else: 45 | output = output_ 46 | return output 47 | 48 | def forward_LoraColumnParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore 49 | # Set up backprop all-reduce. 50 | input_parallel = copy_to_model_parallel_region(input_) 51 | # Matrix multiply. 52 | output_parallel = self.quanted_layer(input_parallel) 53 | if self.bias is not None: 54 | output_parallel = output_parallel + self.bias 55 | if self.lora_a is not None: 56 | modification = self.lora_b(self.lora_a(input_)) 57 | else: 58 | modification = None 59 | 60 | if self.gather_output: 61 | # All-gather across the partitions. 62 | output = gather_from_model_parallel_region(output_parallel) 63 | else: 64 | output = output_parallel 65 | 66 | if modification is not None: 67 | output = output + modification 68 | return output 69 | 70 | def forward_LoraRowParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore 71 | # Set up backprop all-reduce. 72 | if self.input_is_parallel: 73 | input_parallel = input_ 74 | else: 75 | input_parallel = scatter_to_model_parallel_region(input_) 76 | # Matrix multiply. 77 | output_parallel = self.quanted_layer(input_parallel) 78 | # All-reduce across all the partitions. 
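# (each rank holds only the partial product over its input shard; the reduce below sums the shards into the full output)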
79 | output_ = reduce_from_model_parallel_region(output_parallel) 80 | if self.lora_a is not None: 81 | modification = self.lora_b(self.lora_a(input_parallel)) 82 | output_ = output_ + modification 83 | if self.bias is not None: 84 | output = output_ + self.bias 85 | else: 86 | output = output_ 87 | return output 88 | 89 | def forward_Linear(self, input: torch.Tensor) -> torch.Tensor: 90 | output = self.quanted_layer(input) 91 | if self.bias != None: 92 | output += self.bias 93 | return output 94 | 95 | def quantize(model, quant_conf : BitsAndBytesConfig): 96 | module_list = [_ for _ in model.named_modules() if isinstance(_[1], 97 | (LoraColumnParallelLinear, LoraRowParallelLinear, 98 | ColumnParallelLinear, RowParallelLinear, torch.nn.Linear))] 99 | 100 | if hasattr(model, "get_quant_blocklist"): 101 | quant_blocklist = [x for x in model.get_quant_blocklist()] 102 | else: 103 | quant_blocklist = [] 104 | 105 | for name, module in tqdm(module_list, desc="Qunatization Process",mininterval=10): 106 | if "lora" in name or name in quant_blocklist: 107 | continue 108 | if isinstance(module, ( 109 | LoraColumnParallelLinear, 110 | LoraRowParallelLinear, 111 | ColumnParallelLinear, 112 | RowParallelLinear, 113 | torch.nn.Linear 114 | )): 115 | # 1. Initialize quantization operator 116 | if quant_conf.load_in_4bit: 117 | quanted_layer = bnb.nn.Linear4bit( 118 | module.in_features, 119 | module.out_features, 120 | bias=None, 121 | compute_dtype=quant_conf.bnb_4bit_compute_dtype, 122 | compress_statistics=True, 123 | device=None) 124 | if quant_conf.bnb_4bit_compute_dtype != None: 125 | quanted_layer.compute_type_is_set = True 126 | 127 | quanted_layer.weight = bnb.nn.Params4bit( 128 | module.weight.data.clone(), 129 | requires_grad=False, 130 | quant_type=quant_conf.bnb_4bit_quant_type, 131 | ) 132 | 133 | elif quant_conf.load_in_8bit: 134 | quanted_layer= bnb.nn.Linear8bitLt( 135 | module.in_features, 136 | module.out_features, 137 | bias=None, 138 | has_fp16_weights=quant_conf.llm_int8_has_fp16_weight, 139 | threshold=quant_conf.llm_int8_threshold, 140 | ) 141 | quanted_layer.weight = bnb.nn.Int8Params( 142 | module.weight.data.clone(), 143 | requires_grad=False, 144 | #has_fp16_weights=quant_conf.llm_int8_has_fp16_weight, 145 | ) 146 | else: 147 | raise NotImplementedError(f'Please determine the proper quantization type.') 148 | 149 | # 2. 
Convert FP layer to quantized layer 150 | module.quanted_layer = quanted_layer 151 | 152 | if isinstance(module, LoraColumnParallelLinear): 153 | forward_func = forward_LoraColumnParallelLinear 154 | elif isinstance(module, LoraRowParallelLinear): 155 | forward_func = forward_LoraRowParallelLinear 156 | elif isinstance(module, ColumnParallelLinear): 157 | forward_func = forward_ColumnParallelLinear 158 | elif isinstance(module, RowParallelLinear): 159 | forward_func = forward_RowParallelLinear 160 | elif isinstance(module, torch.nn.Linear): 161 | forward_func = forward_Linear 162 | module.forward = MethodType(forward_func, module) 163 | 164 | del module.weight 165 | 166 | -------------------------------------------------------------------------------- /utils/quantization.py: -------------------------------------------------------------------------------- 1 | from transformers.utils.bitsandbytes import * 2 | from transformers import BitsAndBytesConfig 3 | import torch 4 | from torch import nn 5 | import bitsandbytes as bnb 6 | 7 | from fairscale.nn.model_parallel.layers import ( 8 | ParallelEmbedding, 9 | RowParallelLinear, 10 | ColumnParallelLinear, 11 | ) 12 | def _replace_with_bnb_linear( 13 | model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False 14 | ): 15 | """ 16 | Private method that wraps the recursion for module replacement. 17 | 18 | Returns the converted model and a boolean that indicates if the conversion has been successfull or not. 19 | """ 20 | for name, module in model.named_children(): 21 | if current_key_name is None: 22 | current_key_name = [] 23 | current_key_name.append(name) 24 | 25 | if (isinstance(module, nn.Linear) or isinstance(module, ColumnParallelLinear) or isinstance(module, RowParallelLinear) ) and name not in modules_to_not_convert: 26 | # Check if the current key is not in the `modules_to_not_convert` 27 | if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): 28 | # with init_empty_weights(): 29 | if quantization_config.quantization_method() == "llm_int8": 30 | model._modules[name] = bnb.nn.Linear8bitLt( 31 | module.in_features, 32 | module.out_features, 33 | module.bias is not None, 34 | has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, 35 | threshold=quantization_config.llm_int8_threshold, 36 | ) 37 | has_been_replaced = True 38 | else: 39 | if ( 40 | quantization_config.llm_int8_skip_modules is not None 41 | and name in quantization_config.llm_int8_skip_modules 42 | ): 43 | pass 44 | else: 45 | model._modules[name] = bnb.nn.Linear4bit( 46 | module.in_features, 47 | module.out_features, 48 | module.bias is not None, 49 | quantization_config.bnb_4bit_compute_dtype, 50 | compress_statistics=quantization_config.bnb_4bit_use_double_quant, 51 | quant_type=quantization_config.bnb_4bit_quant_type, 52 | ) 53 | has_been_replaced = True 54 | # Force requires grad to False to avoid unexpected errors 55 | model._modules[name].requires_grad_(False) 56 | if len(list(module.children())) > 0: 57 | _, has_been_replaced = _replace_with_bnb_linear( 58 | module, 59 | modules_to_not_convert, 60 | current_key_name, 61 | quantization_config, 62 | has_been_replaced=has_been_replaced, 63 | ) 64 | # Remove the last key for recursion 65 | current_key_name.pop(-1) 66 | return model, has_been_replaced 67 | 68 | 69 | def quant_model_bnb(model, quant_bit='4bit', keep_in_fp32_modules=[], 70 | quantization_config=None): 71 | if quantization_config is None: 72 | # set default quantization config 
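# (defaults below: NF4 4-bit with double quantization and fp16 compute for quant_bit='4bit', LLM.int8 settings for quant_bit='8bit')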
73 | # compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) 74 | quantization_config = BitsAndBytesConfig( 75 | load_in_4bit=quant_bit == '4bit', 76 | load_in_8bit=quant_bit == '8bit', 77 | llm_int8_threshold=6.0, 78 | llm_int8_has_fp16_weight=False, 79 | bnb_4bit_compute_dtype=torch.float16, 80 | bnb_4bit_use_double_quant=True, 81 | bnb_4bit_quant_type='nf4', 82 | ) 83 | 84 | model,_ = _replace_with_bnb_linear( 85 | model, modules_to_not_convert=keep_in_fp32_modules, quantization_config=quantization_config 86 | ) 87 | 88 | return model 89 | 90 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from torch.optim import Optimizer 5 | import math 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def create_scheduler(args, optimizer): 10 | lr_scheduler = None 11 | if args.sched == 'cosine': 12 | lr_scheduler = get_cosine_schedule_with_warmup( 13 | optimizer, 14 | num_warmup_steps=args.num_warmup_steps, 15 | num_training_steps=args.num_training_steps, 16 | num_cycles=0.5, 17 | min_lr_multi=args.min_lr_multi 18 | ) 19 | return lr_scheduler 20 | 21 | 22 | def get_cosine_schedule_with_warmup( 23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 25 | ): 26 | """ 27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py 28 | 29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 31 | initial lr set in the optimizer. 32 | Args: 33 | optimizer ([`~torch.optim.Optimizer`]): 34 | The optimizer for which to schedule the learning rate. 35 | num_warmup_steps (`int`): 36 | The number of steps for the warmup phase. 37 | num_training_steps (`int`): 38 | The total number of training steps. 39 | num_cycles (`float`, *optional*, defaults to 0.5): 40 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 41 | following a half-cosine). 42 | min_lr_multi (`float`, *optional*, defaults to 0): 43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. 44 | last_epoch (`int`, *optional*, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | Return: 47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 48 | """ 49 | 50 | def lr_lambda(current_step): 51 | if current_step < num_warmup_steps: 52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) 53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 55 | 56 | return LambdaLR(optimizer, lr_lambda, last_epoch) 57 | --------------------------------------------------------------------------------
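A minimal sketch of how the optimizer and scheduler factories above are combined, mirroring the calls in tasks/shared_utils_qformer.py. The toy Linear model and the two EasyDict configs are placeholders for the real model and the config.optimizer / config.scheduler sections, and the imports assume the repository root is on PYTHONPATH.

# Minimal usage sketch (placeholder model and configs, not part of the repository).
import torch
from utils.easydict import EasyDict
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler

model = torch.nn.Linear(8, 8)  # stands in for the real model
opt_cfg = EasyDict(opt="adamW", lr=1e-4, weight_decay=0.02)            # ~ config.optimizer
sched_cfg = EasyDict(sched="cosine", num_warmup_steps=100,
                     num_training_steps=1000, min_lr_multi=0.01)       # ~ config.scheduler

optimizer = create_optimizer(opt_cfg, model)        # per-group weight decay / lr, AdamW
scheduler = create_scheduler(sched_cfg, optimizer)  # linear warmup + half-cosine decay

for step in range(1000):
    optimizer.step()   # dummy step; a real loop would backprop a loss first
    scheduler.step()
    if step in (0, 99, 499, 999):
        print(step, optimizer.param_groups[0]["lr"])

The printed learning rate ramps linearly to lr over the first num_warmup_steps steps and then follows a half-cosine decay down to lr * min_lr_multi, which is the schedule the training scripts configure from the dataloader length.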