├── .gitignore
├── LICENSE
├── README.md
├── configs
├── config.json
├── config_bert.json
├── config_mistral.json
├── instruction_data.py
└── model.py
├── dataset
├── TimeIT
│ ├── dense_video_captioning
│ │ ├── anet
│ │ │ ├── instruct_dvc_10.0k_anet.json
│ │ │ ├── instruct_dvc_10.0k_anet_15asr.json
│ │ │ ├── test.caption_coco_format.json
│ │ │ ├── train.caption_coco_format.json
│ │ │ └── val.caption_coco_format.json
│ │ ├── dense_video_captioning_instructions.json
│ │ ├── vitt
│ │ │ ├── instruct_dvc_5.1k_vitt.json
│ │ │ └── instruct_dvc_5.1k_vitt_15asr.json
│ │ ├── youcook2
│ │ │ ├── instruct_dvc_1.2k_youcook2.json
│ │ │ ├── test.caption_coco_format.json
│ │ │ ├── train.caption_coco_format.json
│ │ │ └── val.caption_coco_format.json
│ │ ├── youcook2_new
│ │ │ └── val.caption_coco_format.json
│ │ └── youcook2_new_224_6
│ │ │ └── val.caption_coco_format.json
│ ├── step_localization
│ │ ├── coin
│ │ │ ├── instruct_action_9.0k_coin.json
│ │ │ └── instruct_action_9.0k_coin_15asr.json
│ │ ├── hirest_step
│ │ │ ├── annotations
│ │ │ │ ├── all_data_test.json
│ │ │ │ ├── all_data_test_negative_samples.json
│ │ │ │ ├── all_data_train.json
│ │ │ │ └── all_data_val.json
│ │ │ └── instruct_action_0.5k_hirest.json
│ │ └── step_localization_instructions.json
│ ├── temporal_video_grounding
│ │ ├── charades
│ │ │ ├── charades_annotation
│ │ │ │ ├── charades_sta_test.txt
│ │ │ │ ├── charades_sta_train.txt
│ │ │ │ ├── get_coco_format.py
│ │ │ │ ├── test.caption_coco_format.json
│ │ │ │ └── train.caption_coco_format.json
│ │ │ └── instruct_tvg_12.4k_charades.json
│ │ ├── didemo
│ │ │ ├── instruct_tvg_33.0k_didemo.json
│ │ │ └── instruct_tvg_33.0k_didemo_15asr.json
│ │ ├── hirest
│ │ │ └── instruct_tvg_0.5k_hirest.json
│ │ ├── queryd
│ │ │ ├── instruct_tvg_14.6k_queryd.json
│ │ │ └── instruct_tvg_14.6k_queryd_15asr.json
│ │ └── temporal_video_grounding_instructions.json
│ ├── time
│ │ ├── instruct_time-sensitive_104k.json
│ │ └── instruct_time-sensitive_104k_asr.json
│ ├── transcribed_speech_generation
│ │ ├── transcribed_speech_generation_instructions.json
│ │ └── yttemporal
│ │ │ ├── instruct_tsg_31.6k_yttemporal.json
│ │ │ └── instruct_tsg_31.6k_yttemporal_15asr.json
│ ├── valley
│ │ ├── Valley_instruct_73k.json
│ │ └── instruct_valley_72k.json
│ ├── video_highlight_detection
│ │ ├── qvhighlights
│ │ │ ├── annotations_raw
│ │ │ │ ├── LICENSE
│ │ │ │ ├── README.md
│ │ │ │ ├── highlight_test_release.jsonl
│ │ │ │ ├── highlight_train_release.jsonl
│ │ │ │ ├── highlight_val_release.jsonl
│ │ │ │ ├── subs_train.jsonl
│ │ │ │ └── val.caption_coco_format.json
│ │ │ ├── get_coco_format.py
│ │ │ ├── instruct_vhd_6.9k_qvhighlights.json
│ │ │ ├── train.caption_coco_format.json
│ │ │ └── val.caption_coco_format.json
│ │ └── video_highlight_detection_instructions.json
│ └── video_summarization
│ │ ├── summe
│ │ │ ├── instruct_vhd_25_summe.json
│ │ │ └── instruct_vhd_25_summe_15asr.json
│ │ ├── tvsum
│ │ │ ├── instruct_vhd_50_tvsum.json
│ │ │ └── instruct_vhd_50_tvsum_15asr.json
│ │ └── video_summarization_instructions.json
├── __init__.py
├── base_dataset.py
├── dataloader.py
├── it_dataset.py
├── it_dataset_mistral.py
├── pt_dataset.py
├── sampler.py
├── utils.py
├── video_transforms.py
└── video_utils.py
├── demo
├── demo.ipynb
└── example
│ ├── bear.jpg
│ ├── dog.png
│ ├── jesse_dance.mp4
│ ├── people.jpg
│ └── yoga.mp4
├── download
└── folder_keeper
├── eval
├── Egoschema_trans_csv.py
├── eval_egoschema.sh
├── eval_infer.py
├── eval_qa_tasks.ipynb
├── format_dvc.py
├── format_eval.py
├── format_tvg.py
├── format_vhd.py
├── get_grounding_result.sh
├── test_grounding.sh
└── validate_egoschema.py
├── images
├── 123
├── abstract.png
├── data.png
└── structure.png
├── metrics
├── README.md
├── dvc
│ ├── eval_dvc.py
│ ├── eval_dvc.sh
│ ├── example_gt_file.json
│ ├── example_pred_file.json
│ └── metrics
│ │ ├── README.md
│ │ ├── cider.py
│ │ ├── cider_scorer.py
│ │ ├── meteor.py
│ │ └── ptbtokenizer.py
├── tvg
│ ├── cd
│ ├── eval_tvg.py
│ ├── eval_tvg.sh
│ ├── example_gt_file.json
│ └── example_pred_file.json
└── vhd
│ ├── cd
│ ├── eval_highlights.sh
│ ├── eval_vhd.py
│ ├── example_gt_file.json
│ ├── example_pred_file.json
│ ├── metrics.json
│ └── utils.py
├── models
├── __init__.py
├── bert
│ ├── __init__.py
│ ├── builder.py
│ ├── tokenization_bert.py
│ └── xbert.py
├── blip2
│ ├── Qformer.py
│ ├── __init__.py
│ ├── blip2.py
│ ├── builder.py
│ ├── modeling_llama.py
│ ├── modeling_llama_mem.py
│ ├── utils.py
│ └── vit.py
├── criterions.py
├── utils.py
├── videochat2_qformer.py
└── videochat_mistra
│ ├── __init__.py
│ ├── videochat2_it4_mistral_LinearP.py
│ └── videochat2_it4_mistral_LinearProAda.py
├── prompts
├── alignment_image.txt
├── concise_description.txt
├── concise_image_description.txt
├── dvc_description.txt
├── dvc_description_with_asr.txt
├── dvc_description_zeroshot.txt
├── dvc_post_check.txt
├── tvg_description.txt
├── tvg_description_zeroshot.txt
├── tvg_post_check.txt
├── vhd_description.txt
├── vhd_description_zeroshot.txt
├── vhd_description_zeroshot_new.txt
├── vhd_description_zeroshot_post.txt
└── vhd_post_check.txt
├── requirements.txt
├── scripts
└── videochat_mistral
│ ├── config_LinearP.py
│ ├── config_LinearProAda.py
│ ├── config_LinearProAdaFT.py
│ ├── run_7b_stage4.sh
│ └── run_7b_stage4_ds.sh
├── tasks
├── retrieval_utils.py
├── shared_utils.py
├── shared_utils_ds.py
├── shared_utils_qformer.py
├── train_it.py
├── train_it4.py
├── train_it4_bug.py
├── train_it4_ds.py
├── train_it4_ds_new.py
├── train_it_ds.py
├── train_pt.py
├── train_pt_ds.py
└── train_qformer.py
├── torchrun.sh
└── utils
├── basic_utils.py
├── config.py
├── config_utils.py
├── distributed.py
├── easydict.py
├── logger.py
├── optimizer.py
├── peft.py
├── quant.py
├── quantization.py
└── scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # local #
2 | tmp*/
3 | cache/*
4 | */cache*/
5 | tmp*.py
6 | tmp*
7 | *pickle
8 | data/
9 |
10 | # Zip Files/Packages #
11 | *.7z
12 | *.dmg
13 | *.gz
14 | *.iso
15 | *.jar
16 | *.rar
17 | *.tar
18 | *.zip
19 |
20 | # Logs and databases #
21 | *.log
22 | *.sql
23 | *.sqlite
24 | .ipynb_checkpoints/
25 | *.swp
26 | *.vscode/
27 | *.idea/
28 | *.pyc
29 | __pycache__
30 | slurm*out
31 |
32 | # OS files #
33 | .DS_Store
34 | .DS_Store?
35 | ._*
36 | .Spotlight-V100
37 | .Trashes
38 | ehthumbs.db
39 | Thumbs.db
40 |
41 |
42 | .vim-arsync
43 | scratch.norg
44 | sync_to_red.sh
45 |
46 | anno/
47 | wandb/
48 | logs/
49 | *.pth
50 |
51 | # personal
52 | test.ipynb
53 |
54 | jupyter/
55 |
56 | phoenix-slurm*
57 | batchscript-*
58 |
59 | debug*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 OpenGVLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | [Xiangyu Zeng](https://scholar.google.com/citations?user=jS13DXkAAAAJ&hl=zh-CN), [Kunchang Li](https://scholar.google.com/citations?user=D4tLSbsAAAAJ), [Chenting Wang](https://scholar.google.com/citations?user=f81ulHQAAAAJ&hl=zh-CN), [Xinhao Li](https://scholar.google.com/citations?user=evR3uR0AAAAJ&hl=zh-CN), Tianxiang Jiang, [Ziang Yan](https://scholar.google.com/citations?user=78lx13MAAAAJ&hl=zh-CN), [Songze Li](https://scholar.google.com/citations?user=8rBMUD4AAAAJ&hl=zh-CN), Yansong Shi, Zhengrong Yue, [Yi Wang](https://scholar.google.com.hk/citations?hl=zh-CN&user=Xm2M8UwAAAAJ), [Yali Wang](https://scholar.google.com/citations?user=hD948dkAAAAJ), [Yu Qiao](https://scholar.google.com/citations?user=gFtI-8QAAAAJ&hl), and [Limin Wang](https://scholar.google.com/citations?user=HEuN8PcAAAAJ)
6 |
7 |
8 |
9 | ## :parrot: Introduction
10 |
11 | This paper proposes TimeSuite, a collection of new designs that adapt existing short-form video MLLMs to long video understanding, including a simple yet efficient framework for processing long video sequences, a high-quality video dataset for the grounded tuning of MLLMs, and a carefully designed instruction-tuning task that explicitly incorporates grounding supervision into the traditional QA format.
12 |
13 | **State-of-the-art performance**: VideoChat-T demonstrates high performance for both long-form video question answering and temporal grounding.
14 | 
15 |
16 | **Highly efficient model architecture** with exceptional inference speed: each video frame is encoded into just **3 tokens**, so the FLOPs of our VideoChat-T are only 5.1% of those of LLaVA-OneVision.
17 | 
18 |
19 | **High-quality data**
20 | - We introduced the comprehensive dataset TimePro, which includes 9 task types with video sources from 15 different datasets.
21 | - We designed a novel Temporal Grounded Caption fine-tuning task to effectively mitigate hallucinations in MLLMs.
22 | 
23 |
24 | ## :fire: Updates
25 |
26 | - 2025.02.12 TimeSuite is open-sourced. We welcome everyone to try it out!
27 | - 2025.01.23 TimeSuite has been accepted by ICLR 2025.
28 | - 2024.10.25 The paper of TimeSuite has been uploaded to arXiv.
29 |
30 | ## Preparation
31 |
32 | - Create a new environment and run the following commands to install the necessary dependencies.
33 |
34 | ```
35 | conda create --name TimeSuite
36 | conda activate TimeSuite
37 | pip install -r requirements.txt
38 | ```
39 |
40 | - Download the model and code of TimeSuite from [https://huggingface.co/Lanxingxuan/TimeSuite](https://huggingface.co/Lanxingxuan/TimeSuite) to the `./download` folder. (Please note that you need to additionally download [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) to `./download/parameters`)
41 |
42 | - Search for all instances of `/path_to_the_timesuite_root_folder` and replace them with the absolute path of your TimeSuite root folder (a bulk-replace helper is sketched after this list).
43 |
44 | - Please search for all video dataset paths containing `s3://` and replace them with the corresponding video dataset paths on your server.
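
A minimal bulk-replace sketch (not part of the repo; `NEW_ROOT` is a placeholder you must set to your own checkout path):

```
from pathlib import Path

NEW_ROOT = "/your/absolute/path/to/TimeSuite"  # assumed: set to your checkout
PLACEHOLDER = "/path_to_the_timesuite_root_folder"

# rewrite the placeholder in the text-based code/config files of the repo
for suffix in (".py", ".json", ".sh", ".txt"):
    for path in Path(".").rglob(f"*{suffix}"):
        text = path.read_text(encoding="utf-8", errors="ignore")
        if PLACEHOLDER in text:
            path.write_text(text.replace(PLACEHOLDER, NEW_ROOT), encoding="utf-8")
```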
45 |
46 | ## Inference & Demo
47 |
48 | - Run `demo/demo.ipynb` to see the demo provided in the paper, or try out the videos and questions of your choice.
49 |
50 | - Run `eval/eval_qa_tasks.ipynb` to test the general QA performance of the model.
51 |
52 | - To test the temporal grounding capability of TimeSuite, please follow these two steps.
53 |
54 | ```
55 | bash eval/test_grounding.sh
56 | bash eval/get_grounding_result.sh
57 | ```
58 |
59 | ## Grounded Tuning
60 |
61 | - Please properly configure the video dataset path in `configs/instruction_data.py`.
62 | - Modify `scripts/videochat_mistral/config_LinearP.py` and `scripts/videochat_mistral/config_LinearProAda.py` to adjust the model training parameter settings.
63 | - Please run `bash scripts/videochat_mistral/run_7b_stage4.sh` to initiate the fine-tuning of the model.
64 | - To reproduce the fine-tuning results presented in the paper, you need to initiate the model training in a two-stage manner. For detailed parameter settings, please refer to Appendix D of the paper.
65 |
66 |
67 | ## TimePro Dataset
68 |
69 | ### Annotations
70 |
71 | - All data used for fine-tuning is now open-sourced. Please visit [https://huggingface.co/Lanxingxuan/TimeSuite/tree/main/datasets/TimePro](https://huggingface.co/Lanxingxuan/TimeSuite/tree/main/datasets/TimePro) to download.
72 |
73 | ### Videos
74 |
75 | **_TimePro_**
76 | - DiDeMo: [https://github.com/LisaAnne/LocalizingMoments?tab=readme-ov-file#dataset](https://github.com/LisaAnne/LocalizingMoments?tab=readme-ov-file#dataset)
77 | - QuerYD: [https://www.robots.ox.ac.uk/~vgg/data/queryd/](https://www.robots.ox.ac.uk/~vgg/data/queryd/)
78 | - HiREST: [https://github.com/j-min/HiREST](https://github.com/j-min/HiREST)
79 | - ActivityNet: [http://activity-net.org/download.html](http://activity-net.org/download.html)
80 | - ViTT: [https://github.com/google-research-datasets/Video-Timeline-Tags-ViTT](https://github.com/google-research-datasets/Video-Timeline-Tags-ViTT)
81 | - YouCook2: [http://youcook2.eecs.umich.edu/download](http://youcook2.eecs.umich.edu/download)
82 | - TVSum: [https://github.com/yalesong/tvsum](https://github.com/yalesong/tvsum)
83 | - SumMe: [http://classif.ai/dataset/ethz-cvl-video-summe/](http://classif.ai/dataset/ethz-cvl-video-summe/)
84 | - COIN: [https://github.com/coin-dataset/annotations](https://github.com/coin-dataset/annotations)
85 | - YT-Temporal: [https://rowanzellers.com/merlot/#data](https://rowanzellers.com/merlot/#data)
86 | - Internvid: [https://github.com/OpenGVLab/InternVideo/blob/main/Data/InternVid/README_CN.md](https://github.com/OpenGVLab/InternVideo/blob/main/Data/InternVid/README_CN.md)
87 | - HowTo100M(CosMo): [https://www.di.ens.fr/willow/research/howto100m/](https://www.di.ens.fr/willow/research/howto100m/)
88 |
89 | **_Normal_**
90 | - VideoChatGPT: [https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data)
91 | - VideoChat: [https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data](https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data)
92 | - EgoQA: [https://ego4d-data.org/](https://ego4d-data.org/)
93 | - STAR: [https://bobbywu.com/STAR/](https://bobbywu.com/STAR/)
94 | - MovieChat: [https://huggingface.co/datasets/Enxin/MovieChat-1K_train](https://huggingface.co/datasets/Enxin/MovieChat-1K_train)
95 |
96 | **_FT_**
97 | - Charades-STA: [https://github.com/jiyanggao/TALL#charades-sta-anno-download](https://github.com/jiyanggao/TALL#charades-sta-anno-download)
98 | - QVHighlight: [https://github.com/jayleicn/moment_detr/blob/main/data/README.md](https://github.com/jayleicn/moment_detr/blob/main/data/README.md)
99 |
100 | # :page_facing_up: Citation
101 |
102 | If you find this project useful in your research, please consider citing it:
103 | ```BibTeX
104 | @misc{zeng2024timesuite,
105 | title={TimeSuite: Improving MLLMs for Long Video Understanding via Grounded Tuning},
106 | author={Xiangyu Zeng and Kunchang Li and Chenting Wang and Xinhao Li and Tianxiang Jiang and Ziang Yan and Songze Li and Yansong Shi and Zhengrong Yue and Yi Wang and Yali Wang and Yu Qiao and Limin Wang},
107 | year={2024},
108 | eprint={2410.19702},
109 | archivePrefix={arXiv},
110 | primaryClass={cs.CV},
111 | url={https://arxiv.org/abs/2410.19702},
112 | }
113 | ```
114 |
115 | # :dizzy: Acknowledgement
116 |
117 | Thanks to the following open-source projects:
118 |
119 | - [UMT](https://github.com/OpenGVLab/unmasked_teacher)
120 | - [MVBench](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)
121 | - [TimeChat](https://github.com/RenShuhuai-Andy/TimeChat)
122 | - [LITA](https://github.com/NVlabs/LITA)
123 |
--------------------------------------------------------------------------------
/configs/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "model_cls": "VideoChat2_it",
4 | "vit_blip_model_path": "your_model_path/umt_l16_qformer.pth",
5 | "llama_model_path": "your_model_path/vicuna-7b-v0",
6 | "videochat2_model_path": "your_model_path/videochat2_7b_stage2.pth",
7 | "freeze_vit": false,
8 | "freeze_qformer": false,
9 | "max_txt_len": 512,
10 | "low_resource": false,
11 | "vision_encoder": {
12 | "name": "vit_l14",
13 | "img_size": 224,
14 | "patch_size": 16,
15 | "d_model": 1024,
16 | "encoder_embed_dim": 1024,
17 | "encoder_depth": 24,
18 | "encoder_num_heads": 16,
19 | "drop_path_rate": 0.0,
20 | "num_frames": 8,
21 | "tubelet_size": 1,
22 | "use_checkpoint": false,
23 | "checkpoint_num": 0,
24 | "pretrained": "",
25 | "return_index": -2,
26 | "vit_add_ln": true,
27 | "ckpt_num_frame": 4
28 | },
29 | "num_query_token": 32,
30 | "qformer_hidden_dropout_prob": 0.1,
31 | "qformer_attention_probs_dropout_prob": 0.1,
32 | "qformer_drop_path_rate": 0.2,
33 | "extra_num_query_token": 64,
34 | "qformer_text_input": true,
35 | "system": "",
36 | "start_token": "",
38 | "img_start_token": "",
39 | "img_end_token": "",
40 | "random_shuffle": true,
41 | "use_lora": false,
42 | "lora_r": 16,
43 | "lora_alpha": 32,
44 | "lora_dropout": 0.1
45 | },
46 | "device": "cuda"
47 | }
48 |
--------------------------------------------------------------------------------
/configs/config_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "fusion_layer": 9,
20 | "encoder_width": 768,
21 | "cross_module": "ca"
22 | }
23 |
--------------------------------------------------------------------------------
/configs/config_mistral.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "model_cls": "VideoChat2_it_mistral",
4 | "vit_blip_model_path": "/path_to_the_timesuite_root_folder/download/parameters/umt_l16_qformer.pth",
5 | "mistral_model_path": "/path_to_the_timesuite_root_folder/download/parameters/Mistral-7B-Instruct-v0.2",
6 | "videochat2_model_path": "/path_to_the_timesuite_root_folder/download/parameters/videochat2_mistral_7b_stage2.pth",
7 | "freeze_vit": false,
8 | "freeze_qformer": false,
9 | "max_txt_len": 512,
10 | "low_resource": false,
11 | "vision_encoder": {
12 | "name": "vit_l14",
13 | "img_size": 224,
14 | "patch_size": 16,
15 | "d_model": 1024,
16 | "encoder_embed_dim": 1024,
17 | "encoder_depth": 24,
18 | "encoder_num_heads": 16,
19 | "drop_path_rate": 0.0,
20 | "num_frames": 8,
21 | "tubelet_size": 1,
22 | "use_checkpoint": true,
23 | "checkpoint_num": 18,
24 | "pretrained": "",
25 | "return_index": -2,
26 | "vit_add_ln": true,
27 | "ckpt_num_frame": 4
28 | },
29 | "num_query_token": 32,
30 | "qformer_hidden_dropout_prob": 0.1,
31 | "qformer_attention_probs_dropout_prob": 0.1,
32 | "qformer_drop_path_rate": 0.2,
33 | "extra_num_query_token": 64,
34 | "qformer_text_input": true,
35 | "system": "",
36 | "start_token": "",
38 | "add_second_msg": true,
39 | "img_start_token": "",
40 | "img_end_token": "",
41 | "random_shuffle": true,
42 | "return_question_instruction": false,
43 | "use_flash_attention": true,
44 | "use_lora": false,
45 | "lora_r": 16,
46 | "lora_alpha": 32,
47 | "lora_dropout": 0.1
48 | },
49 | "device": "cuda"
50 | }
51 |
--------------------------------------------------------------------------------
/configs/instruction_data.py:
--------------------------------------------------------------------------------
1 | import os as __os # add "__" if not want to be exported
2 | from copy import deepcopy as __deepcopy
3 |
4 | anno_root_it = "/path_to_the_timesuite_root_folder/download/datasets/TimePro"
5 |
6 |
7 | # ============== pretraining datasets=================
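# Each corpus entry below is a 3-element list:
#   [annotation_json_path, video_root, media_type]
# where video_root is a local directory or a petrel/ceph "s3://" prefix that gets
# joined with the per-sample video filename from the annotation json, and
# media_type is "video" or "image".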
8 | available_corpus = dict(
9 |
10 | caption_youcook2=[
11 | f"{anno_root_it}/caption_youcook2.json",
12 | "pnorm2:s3://youcook2/split_videos",
13 | "video"
14 | ],
15 | conversation_videochat1=[
16 | f"{anno_root_it}/conversation_videochat1.json",
17 | "pnorm2:s3://webvid10m",
18 | "video"
19 | ],
20 | conversation_videochat2=[
21 | f"{anno_root_it}/conversation_videochat2.json",
22 | "pnorm:s3://videointernsegvideos",
23 | "video"
24 | ],
25 | conversation_videochatgpt=[
26 | f"{anno_root_it}/conversation_videochatgpt.json",
27 | "pnorm2:s3://anet/ANet_320p_fps30",
28 | "video"
29 | ],
30 | reasoning_star=[
31 | f"{anno_root_it}/reasoning_star.json",
32 | "pnorm2:s3://star/Charades_v1_480",
33 | "video"
34 | ],
35 | vqa_ego_qa=[
36 | f"{anno_root_it}/vqa_ego_qa.json",
37 | "pnorm2:s3://egoqa/split_videos",
38 | "video"
39 | ],
40 |
41 |
42 |
43 |
44 | # TimeIT
45 | timeit_ANet=[
46 | f"{anno_root_it}/timeit_ANet.json",
47 | "pnorm2:s3://anet",
48 | "video"
49 | ],
50 |
51 | timeit_COIN=[
52 | f"{anno_root_it}/timeit_COIN.json",
53 | "pnorm:s3://COIN_320p",
54 | "video"
55 | ],
56 |
57 | timeit_DiDeMo=[
58 | f"{anno_root_it}/timeit_DiDeMo.json",
59 | "sssd:s3://yjsBucket",
60 | "video"
61 | ],
62 |
63 | timeit_HiREST=[
64 | f"{anno_root_it}/timeit_HiREST.json",
65 | "pnorm2zxy:s3://hirest",
66 | "video"
67 | ],
68 |
69 |
70 | timeit_QuerYD=[
71 | f"{anno_root_it}/timeit_QuerYD.json",
72 | "pnorm2zxy:s3://queryd",
73 | "video"
74 | ],
75 |
76 | timeit_TVSum=[
77 | f"{anno_root_it}/timeit_TVSum.json",
78 | "pnorm2zxy:s3://tvsum",
79 | "video"
80 | ],
81 |
82 | timeit_ViTT=[
83 | f"{anno_root_it}/timeit_ViTT.json",
84 | "sssd:s3://ViTT",
85 | "video"
86 | ],
87 |
88 | timeit_yttemporal180m=[
89 | f"{anno_root_it}/timeit_yttemporal180m.json",
90 | "pnorm:s3://YT-Temporal-180M",
91 | "video"
92 | ],
93 |
94 | grounding_ANetRTL=[
95 | f"{anno_root_it}/grounding_ANetRTL.json",
96 | "pnorm2:s3://anet/ANet_320p_fps30/train",
97 | "video"
98 | ],
99 |
100 | grounding_IntrenvidVTime_100K=[
101 | f"{anno_root_it}/grounding_IntrenvidVTime_100K.json",
102 | "pnorm:s3://youtubeBucket/videos/",
103 | "video"
104 | ],
105 | grounding_ANetHL2=[
106 | f"{anno_root_it}/grounding_ANetHL2.json",
107 | "pnorm2:s3://anet/ANet_320p_fps30/train",
108 | "video"
109 | ],
110 |
111 | grounding_CosmoCap_93K=[
112 | f"{anno_root_it}/grounding_CosmoCap_93K.json",
113 | "pvideo:s3://howto100m/",
114 | "video"
115 | ],
116 | vqa_moviechat = [
117 | f'{anno_root_it}/vqa_moviechat.json',
118 | 'pnorm2:s3://MovieChat/real_video/',
119 | 'video'
120 | ],
121 | caption_moviechat = [
122 | f'{anno_root_it}/caption_moviechat.json',
123 | 'pnorm2:s3://MovieChat/real_video/',
124 | 'video'
125 | ],
126 |
127 |
128 | FT_Charades=[
129 | f"{anno_root_it}/FT_Charades.json",
130 | "s3://zengxiangyu/Charades/",
131 | "video"
132 | ],
133 |
134 | FT_QVHighlights=[
135 | f"{anno_root_it}/FT_QVHighlights.json",
136 | "s3://QVHighlight/videos/",
137 | "video"
138 | ],
139 |
140 | )
141 |
142 |
143 | available_corpus["TimePro_Normal"] = [ #final dataset
144 |     # TimeIT
145 | available_corpus["timeit_ANet"],
146 | available_corpus["timeit_COIN"],
147 | available_corpus["timeit_DiDeMo"],
148 | available_corpus["timeit_HiREST"],
149 | available_corpus["timeit_QuerYD"],
150 | available_corpus["timeit_TVSum"],
151 | available_corpus["timeit_ViTT"],
152 | available_corpus["timeit_yttemporal180m"],
153 | #Conv
154 | available_corpus["conversation_videochatgpt"],
155 | available_corpus["conversation_videochat2"],
156 | available_corpus["conversation_videochat1"],
157 | #DvcVqa
158 | available_corpus["caption_youcook2"],
159 | available_corpus["vqa_ego_qa"],
160 | #Gro
161 | available_corpus["grounding_ANetRTL"],
162 | available_corpus["grounding_IntrenvidVTime_100K"],
163 | available_corpus["grounding_ANetHL2"],
164 | available_corpus["grounding_CosmoCap_93K"],
165 | available_corpus["vqa_moviechat"],
166 | available_corpus["caption_moviechat"],
167 | available_corpus["reasoning_star"],
168 | ]
169 |
170 |
171 |
172 | available_corpus["FT_Temporal_Grounding_Both"] = [
173 | available_corpus["FT_Charades"],
174 | available_corpus["FT_QVHighlights"],
175 | available_corpus["grounding_ANetHL2"],
176 | available_corpus["caption_youcook2"],
177 | ]
--------------------------------------------------------------------------------
/configs/model.py:
--------------------------------------------------------------------------------
1 | TextEncoders = dict()
2 | TextEncoders["bert"] = dict(
3 | name="bert_base",
4 | pretrained="bert-base-uncased",
5 | config="configs/config_bert.json",
6 | d_model=768,
7 | fusion_layer=9,
8 | )
--------------------------------------------------------------------------------
/dataset/TimeIT/dense_video_captioning/anet/instruct_dvc_10.0k_anet.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/36470518c0a555bbc7e7ae0b30393441ec533e03
--------------------------------------------------------------------------------
/dataset/TimeIT/dense_video_captioning/anet/instruct_dvc_10.0k_anet_15asr.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/2d6aad3236b910b1877aa8058dd0be19b3f333b7cefebd1f6c852880d13a6dc3
--------------------------------------------------------------------------------
/dataset/TimeIT/dense_video_captioning/dense_video_captioning_instructions.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences. The output format of each predicted event should be like: 'start - end seconds, event description'. A specific example is: ' 90 - 102 seconds, spread margarine on two slices of white bread in the video'.",
3 | "1": "Determine the start and end times of various activity events in the video, accompanied by descriptions.",
4 | "2": "Capture and describe the activity events in the given video, specifying their respective time intervals, and outputting the time intervals in the 'start - end seconds format'.",
5 | "3": "Identify, timestamp, and describe various activity events occurring in the video. The timestamp should include the start time and end time in seconds.",
6 | "4": "Detect and report the start and end timestamps of activity events in the video, along with descriptions.",
7 | "5": "Pinpoint the time intervals of activity events in the video, and provide detailed descriptions for each event."
8 | }
--------------------------------------------------------------------------------
/dataset/TimeIT/dense_video_captioning/vitt/instruct_dvc_5.1k_vitt_15asr.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/1dc9787ee6fa38f8c3223b14eb10da0efdfa1c17ef9f0dea77fafd5425a5c5dc
--------------------------------------------------------------------------------
/dataset/TimeIT/dense_video_captioning/youcook2/train.caption_coco_format.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/d257625e0c84e7a4c6dcbbed5054e0fc3053e6c6
--------------------------------------------------------------------------------
/dataset/TimeIT/step_localization/step_localization_instructions.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "Localize a series of action steps in the given video, output a start and end timestamp for each step, and briefly describe the step. ",
3 | "1": "Locate and describe a series of actions or steps in the video, including their start and end timestamps.",
4 | "2": "Identify and mark the video segments corresponding to a series of actions or steps, specifying the timestamps and describing the steps.",
5 | "3": "Find, identify, and determine the temporal boundaries of a series of distinct actions or steps occurring throughout the video. For each action, output the corresponding start and end timestamps, accompanied by a concise description.",
6 | "4": "Identify and localize a series of steps or actions occurring in the video, providing start and end timestamps and related descriptions.",
7 | "5": "Locate and pinpoint a sequential series of specific actions or steps in the video, accurately specifying the start and end timestamps for each action. Additionally, provide a succinct description of each action."
8 | }
--------------------------------------------------------------------------------
/dataset/TimeIT/temporal_video_grounding/charades/charades_annotation/get_coco_format.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | from copy import deepcopy
5 | import pdb
6 | import numpy as np
7 | import random
8 | from pathlib import Path
9 |
10 |
11 | # read json files
12 | def read_json(path):
13 | with open(path, "r") as fin:
14 | datas = json.load(fin)
15 | annos = datas["annotations"]
16 | return annos
17 |
18 |
19 | def read_jsonl(path):
20 | anno = []
21 | with open(path, "r") as fin:
22 | datas = fin.readlines()
23 | for data in datas:
24 | anno.append(json.loads(data.strip()))
25 | return anno
26 |
27 |
28 |
29 | def write_json(data, path):
30 | with open(path, "w") as fout:
31 | json.dump(data, fout)
32 | return
33 |
34 |
35 | def read_txt(path):
36 | data = []
37 | with open(path, "r") as fin:
38 | lines = fin.readlines()
39 | for i, line in enumerate(lines):
40 | # e.g. AO8RW 0.0 6.9##a person is putting a book on a shelf.
41 | line = line.strip("\n")
42 | cap = line.split("##")[-1]
43 | if len(cap) < 2:
44 | continue
45 | terms = line.split("##")[0].split(" ")
46 | vid = terms[0] + ".mp4"
47 | start_time = float(terms[1])
48 | end_time = float(terms[2])
49 | data.append({"image_id": vid, "caption": cap, "timestamp": [start_time, end_time], "id": i})
50 | return data
51 |
52 |
53 | def filter_sent(sent):
54 | sent = sent.strip(" ")
55 | if len(sent) < 2:
56 | return False
57 | sent = sent.replace("#", "")
58 | return sent
59 |
60 |
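# Usage sketch (paths are placeholders, adjust to your setup):
#   python get_coco_format.py --dataset charades \
#       --anno_path /path/to/Charades/charades_annotation/ --outpath ./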
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--dataset', default='charades') # anet
64 | parser.add_argument('--anno_path', default='/home/yaolinli/dataset/Charades/charades_annotation/')
65 | parser.add_argument('--video_path', default='/home/yaolinli/dataset/Charades/videos') # ActivityNet_asr_denseCap/anet_6fps_224
66 | parser.add_argument('--outpath', default='./')
67 | args = parser.parse_args()
68 | '''output data example:
69 | {
70 | "annotations": [
71 | {
72 | "image_id": "3MSZA.mp4",
73 | "caption": "person turn a light on.",
74 | "timestamp": [24.3, 30.4],
75 | }],
76 | }
77 | '''
78 |
79 | for split in ["train", "test"]: # "val", "test"
80 | if args.dataset == "charades":
81 | filename = f"charades_sta_{split}.txt"
82 | annos = read_txt(os.path.join(args.anno_path, filename))
83 | data = {}
84 | data["annotations"] = annos
85 |
86 | else:
87 | print("Do not support this dataset!")
88 | exit(0)
89 |
90 | print(f"==> {args.dataset} dataset \t# examples num: {len(annos)}")
91 | out_name = "{}.caption_coco_format.json".format(split)
92 | Path(args.outpath).mkdir(parents=True, exist_ok=True)
93 | write_json(data, os.path.join(args.outpath, out_name))
94 |
--------------------------------------------------------------------------------
/dataset/TimeIT/temporal_video_grounding/didemo/instruct_tvg_33.0k_didemo.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/2f0961ca863a66444d1149b608439787ba413ddf
--------------------------------------------------------------------------------
/dataset/TimeIT/temporal_video_grounding/didemo/instruct_tvg_33.0k_didemo_15asr.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/a7c74718716319c268c762e5db50be0289d8e679467db97e2f8460b2aacd6677
--------------------------------------------------------------------------------
/dataset/TimeIT/temporal_video_grounding/queryd/instruct_tvg_14.6k_queryd.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/77267f35d35c1c44a512e024114b5b6c8b88ac40
--------------------------------------------------------------------------------
/dataset/TimeIT/temporal_video_grounding/queryd/instruct_tvg_14.6k_queryd_15asr.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/c4e98592e6d5e6d604c66ef1923f5385f0e080978fc1945c019ae28cf02bf342
--------------------------------------------------------------------------------
/dataset/TimeIT/temporal_video_grounding/temporal_video_grounding_instructions.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "Localize the visual content described by the given textual query in the video, and output the start and end timestamps in seconds.",
3 | "1": "Detect and report the start and end timestamps of the video segment that semantically matches the given textual query .",
4 | "2": "Give you a textual query: When does the described content occur in the video? Please return the timestamp in seconds.",
5 | "3": "Locate and describe the visual content mentioned in the text query within the video, including timestamps.",
6 | "4": "The given natural language query is semantically aligned with a video moment, please give the start time and end time of the video moment.",
7 | "5": "Find the video segment that corresponds to the given textual query and determine its start and end seconds."
8 | }
--------------------------------------------------------------------------------
/dataset/TimeIT/time/instruct_time-sensitive_104k.json:
--------------------------------------------------------------------------------
1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/f343a7a774a15b3b9cf05d3bca346d75397246c50d50367a0f19cd90c55460ec
--------------------------------------------------------------------------------
/dataset/TimeIT/time/instruct_time-sensitive_104k_asr.json:
--------------------------------------------------------------------------------
1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/01c33679dc0b6f8fda5040441b636765c374709d0103f3c9dc285f47f8d1feff
--------------------------------------------------------------------------------
/dataset/TimeIT/transcribed_speech_generation/transcribed_speech_generation_instructions.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "After watching the video from the YTTemporal dataset, transcribe the spoken content into text and document the start and end time for each segment. The format should be: 'start time - end time, transcribed speech'.",
3 | "1": "Observe the video thoroughly and transcribe the speech in a maximum of 20 segments. Make sure to include the starting and ending times for each segment in the following format: 'start time - end time, transcribed speech'.",
4 | "2": "Watch the provided video and transcribe the audio content. For each transcribed speech segment, note down its duration in the format: 'start time - end time, transcribed speech'.",
5 | "3": "Review the video from the YTTemporal dataset. Identify segments where speech occurs and transcribe those into text. Record the start and end time for each segment in this format: 'start time - end time, transcribed speech'.",
6 | "4": "Transcribe the spoken words in the video and note down the timestamps for each segment. Your output should look like this: 'start time - end time, transcribed speech'.",
7 | "5": "Watch the video, transcribe the speech, and indicate when each segment starts and ends. Follow this format: 'start time - end time, transcribed speech'."
8 | }
--------------------------------------------------------------------------------
/dataset/TimeIT/transcribed_speech_generation/yttemporal/instruct_tsg_31.6k_yttemporal.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/9156641dad2e4c413db3568ae1b409f65ad2c22392de17396955922536a2184a
--------------------------------------------------------------------------------
/dataset/TimeIT/transcribed_speech_generation/yttemporal/instruct_tsg_31.6k_yttemporal_15asr.json:
--------------------------------------------------------------------------------
1 | ../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/a33f3d1606cde5cc3da373e3737f7ba036697aa196d2b85b185ba02d41507c78
--------------------------------------------------------------------------------
/dataset/TimeIT/valley/Valley_instruct_73k.json:
--------------------------------------------------------------------------------
1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/50425f539707f196b66957997d197da33da969ea5054f4884eae7e8ee17490a0
--------------------------------------------------------------------------------
/dataset/TimeIT/valley/instruct_valley_72k.json:
--------------------------------------------------------------------------------
1 | ../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/4370bbe8cb6fd279898e2d7fd6eb0c76e34bb6bad80a631407edaf042bbf09c1
--------------------------------------------------------------------------------
/dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw/README.md:
--------------------------------------------------------------------------------
1 | ## QVHighlights Dataset
2 |
3 | Our annotation files include 3 splits: `train`, `val` and `test`. Each file is in [JSON Line](https://jsonlines.org/) format, each row of the files can be loaded as a single `dict` in Python. Below is an example of the annotation:
4 |
5 | ```
6 | {
7 | "qid": 8737,
8 | "query": "A family is playing basketball together on a green court outside.",
9 | "duration": 126,
10 | "vid": "bP5KfdFJzC4_660.0_810.0",
11 | "relevant_windows": [[0, 16]],
12 | "relevant_clip_ids": [0, 1, 2, 3, 4, 5, 6, 7],
13 | "saliency_scores": [[4, 1, 1], [4, 1, 1], [4, 2, 1], [4, 3, 2], [4, 3, 2], [4, 3, 3], [4, 3, 3], [4, 3, 2]]
14 | }
15 | ```
16 | `qid` is a unique identifier of a `query`. This query corresponds to a video identified by its video id `vid`. The `vid` is formatted as `{youtube_id}_{start_time}_{end_time}`. Using this information, one can retrieve the YouTube video from the url `https://www.youtube.com/embed/{youtube_id}?start={start_time}&end={end_time}&version=3`. For example, the video in this example is `https://www.youtube.com/embed/bP5KfdFJzC4?start=660&end=810&version=3`.
17 | `duration` is an integer indicating the duration of this video.
18 | `relevant_windows` is the list of windows that localize the moments, each window has two numbers, one indicates the start time of the moment, another one indicates the end time. `relevant_clip_ids` is the list of ids to the segmented 2-second clips that fall into the moments specified by `relevant_windows`, starting from 0.
19 | `saliency_scores` contains the saliency scores annotations, each sublist corresponds to a clip in `relevant_clip_ids`. There are 3 elements in each sublist, they are the scores from three different annotators. A score of `4` means `Very Good`, while `0` means `Very Bad`.
20 |
21 | Note that the three fields `relevant_clip_ids`, `relevant_windows` and `saliency_scores` are not included for the `test` split. Please refer to [../standalone_eval/README.md](../standalone_eval/README.md) for details on evaluating predictions on `test`.
22 |
23 | In addition to the annotation files, we also provide the subtitle file for our weakly supervised ASR pre-training: [subs_train.jsonl](./subs_train.jsonl). This file is formatted similarly to our annotation files, but without the `saliency_scores` entry. This file is not needed if you do not plan to pretrain models using it.
24 |
25 |
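A minimal usage sketch (not part of the original release; the filename below is assumed to sit in this directory):

```
import json

with open("highlight_val_release.jsonl") as f:
    ex = json.loads(f.readline())

# rebuild the YouTube embed url from the vid format {youtube_id}_{start_time}_{end_time}
youtube_id, start, end = ex["vid"].rsplit("_", 2)
url = (f"https://www.youtube.com/embed/{youtube_id}"
       f"?start={int(float(start))}&end={int(float(end))}&version=3")

# each id i in relevant_clip_ids refers to the 2-second clip covering [2*i, 2*i + 2) seconds
windows = [(2 * i, 2 * i + 2) for i in ex["relevant_clip_ids"]]
print(url, ex["relevant_windows"], windows)
```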
--------------------------------------------------------------------------------
/dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw/subs_train.jsonl:
--------------------------------------------------------------------------------
1 | ../../../../../../../../.cache/huggingface/hub/datasets--ShuhuaiRen--TimeIT/blobs/33086f20181724a477d7da5b2a063e7935d04d886548917b8ba9912f5a0c7dc5
--------------------------------------------------------------------------------
/dataset/TimeIT/video_highlight_detection/qvhighlights/get_coco_format.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | from copy import deepcopy
5 | import pdb
6 | import numpy as np
7 | import random
8 | from pathlib import Path
9 | from collections import Counter
10 |
11 | # read json files
12 | def read_json(path):
13 | with open(path, "r") as fin:
14 | datas = json.load(fin)
15 | annos = datas["annotations"]
16 | return annos
17 |
18 |
19 | def read_jsonl(path):
20 | anno = []
21 | with open(path, "r") as fin:
22 | datas = fin.readlines()
23 | for data in datas:
24 | anno.append(json.loads(data.strip()))
25 | return anno
26 |
27 |
28 |
29 | def write_json(data, path):
30 | with open(path, "w") as fout:
31 | json.dump(data, fout)
32 | return
33 |
34 |
35 | def read_txt(path):
36 | data = []
37 | with open(path, "r") as fin:
38 | lines = fin.readlines()
39 | for i, line in enumerate(lines):
40 | # e.g. AO8RW 0.0 6.9##a person is putting a book on a shelf.
41 | line = line.strip("\n")
42 | cap = line.split("##")[-1]
43 | if len(cap) < 2:
44 | continue
45 | terms = line.split("##")[0].split(" ")
46 | vid = terms[0] + ".mp4"
47 | start_time = float(terms[1])
48 | end_time = float(terms[2])
49 | data.append({"image_id": vid, "caption": cap, "timestamp": [start_time, end_time], "id": i})
50 | return data
51 |
52 |
53 | def filter_sent(sent):
54 | sent = sent.strip(" ")
55 | if len(sent) < 2:
56 | return False
57 | sent = sent.replace("#", "")
58 | return sent
59 |
60 |
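# Usage sketch (run from this directory; --video_path must point to the local
# QVHighlights videos so missing files can be detected and skipped):
#   python get_coco_format.py --dataset qvhighlights --video_path /path/to/videos/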
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--dataset', default='qvhighlights') # anet
64 | parser.add_argument('--anno_path', default='annotations_raw/')
65 | parser.add_argument('--video_path', default='videos/') # ActivityNet_asr_denseCap/anet_6fps_224
66 | parser.add_argument('--outpath', default='./')
67 | args = parser.parse_args()
68 | '''output data example:
69 | {
70 | "annotations": [
71 | {
72 | "image_id": "3MSZA.mp4",
73 | "caption": "person turn a light on.",
74 | "timestamp": [24.3, 30.4],
75 | }],
76 | }
77 | '''
78 | miss_videos = []
79 | num_clips = []
80 | for split in ["train", "val"]: # "val", "test"
81 | if args.dataset == "charades":
82 | filename = f"charades_sta_{split}.txt"
83 | annos = read_txt(os.path.join(args.anno_path, filename))
84 | data = {}
85 | data["annotations"] = annos
86 | elif args.dataset == "qvhighlights":
87 | filename = f"highlight_{split}_release.jsonl"
88 | annos = read_jsonl(os.path.join(args.anno_path, filename))
89 | new_data = []
90 | for jterm in annos:
91 | new_term = {}
92 | new_term["image_id"] = "v_" + jterm["vid"] + ".mp4"
93 |                 # check the existence of the video
94 | if not os.path.exists(os.path.join(args.video_path, split, new_term["image_id"])):
95 | miss_videos.append(new_term["image_id"])
96 | continue
97 | new_term["id"] = jterm["qid"]
98 | new_term["caption"] = jterm["query"]
99 | new_term["timestamp"] = jterm["relevant_windows"]
100 | new_term["duration"] = jterm["duration"]
101 | new_term["relevant_clip_ids"] = jterm["relevant_clip_ids"]
102 | new_term["saliency_scores"] = jterm["saliency_scores"]
103 | new_data.append(new_term)
104 | num_clips.append(int(jterm["duration"]/2))
105 | data = {}
106 | data["annotations"] = new_data
107 | else:
108 | print("Do not support this dataset!")
109 | exit(0)
110 |
111 | print(f"==> {args.dataset} dataset \t# examples num: {len(new_data)} \t# miss videos num: {len(miss_videos)}\t# raw data num: {len(annos)}")
112 | out_name = "{}.caption_coco_format.json".format(split)
113 | Path(args.outpath).mkdir(parents=True, exist_ok=True)
114 | write_json(data, os.path.join(args.outpath, out_name))
115 |
116 | if len(num_clips) >= 1:
117 | count = Counter(num_clips)
118 | # sort count dict with the clip num
119 | print(count)
120 | print(max(list(count.keys())))
121 |
122 |
--------------------------------------------------------------------------------
/dataset/TimeIT/video_highlight_detection/video_highlight_detection_instructions.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "You are given a video from the QVHighlights dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and its saliency score on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 seconds. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. Now I will give you the sentence query: . Please return the query-based highlight timestamps and salient scores.",
3 | "1": "Watch the provided video and mark out the scenes that stand out based on the description: . Document the timestamps of these highlights and evaluate their saliency scores.",
4 | "2": "Perform a thorough review of the video content, extracting key highlight moments that align with . It is essential to record the times of these moments and assign a distinct saliency value to each.",
5 | "3": "Examine the video and, in accordance with query , highlight the standout moments. You're required to provide the exact timing alongside a saliency rating for each segment.",
6 | "4": "In the video presented, seek moments that are a perfect match with . It's vital to notate their timestamps and to score each based on their level of saliency.",
7 | "5": "Go through the video content, and upon identifying highlight moments that resonate with , list their timestamps. Subsequently, provide a saliency score for each identified highlight."
8 | }
--------------------------------------------------------------------------------
/dataset/TimeIT/video_summarization/video_summarization_instructions.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "From the dataset, generate a summarized version of the video, focusing on extracting key frames that best represent the overall narrative. The output should be a list of timestamps in seconds and their corresponding salient scores",
3 | "1": "You are given a video from the dataset. Please find the highlight contents in the video, determining the highlight timestamps and its saliency score on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 second. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. ",
4 | "2": "Identify and extract the most emotionally impactful moments from the video provided by dataset, rating their intensity on a scale from 1 to 5.",
5 | "3": "Watch the provided video from the dataset and mark out the timestamps with stand-out visual content. Document the timestamps of these highlights and evaluate their saliency scores.",
6 | "4": "In the video presented from dataset, seek moments that could serve as an executive summary for a busy stakeholder. It's vital to notate their timestamps and to score each based on their level of saliency.",
7 | "5": "Go through the video content from dataset, and upon identifying highlight moments, list their timestamps. Subsequently, provide a saliency score for each identified highlight."
8 | }
--------------------------------------------------------------------------------
/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | from torch.utils.data import Dataset
5 |
6 | from dataset.utils import load_image_from_path
7 |
8 | try:
9 | from petrel_client.client import Client
10 | has_client = True
11 | except ImportError:
12 | has_client = False
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class ImageVideoBaseDataset(Dataset):
18 | """Base class that implements the image and video loading methods"""
19 |
20 | media_type = "video"
21 |
22 | def __init__(self):
23 | assert self.media_type in ["image", "video", "only_video"]
24 | self.data_root = None
25 | self.anno_list = (
26 | None # list(dict), each dict contains {"image": str, # image or video path}
27 | )
28 | self.transform = None
29 | self.video_reader = None
30 | self.num_tries = None
31 |
32 | self.client = None
33 | if has_client:
34 | self.client = Client('~/petreloss.conf')
35 |
36 | def __getitem__(self, index):
37 | raise NotImplementedError
38 |
39 | def __len__(self):
40 | raise NotImplementedError
41 |
42 | def get_anno(self, index):
43 | """obtain the annotation for one media (video or image)
44 |
45 | Args:
46 | index (int): The media index.
47 |
48 | Returns: dict.
49 | - "image": the filename, video also use "image".
50 | - "caption": The caption for this file.
51 |
52 | """
53 | anno = self.anno_list[index]
54 | if self.data_root is not None:
55 | anno["image"] = os.path.join(self.data_root, anno["image"])
56 | return anno
57 |
58 | def load_and_transform_media_data(self, index, data_path):
59 | if self.media_type == "image":
60 | return self.load_and_transform_media_data_image(index, data_path)
61 | else:
62 | return self.load_and_transform_media_data_video(index, data_path)
63 |
64 | def load_and_transform_media_data_image(self, index, data_path):
65 | image = load_image_from_path(data_path, client=self.client)
66 | image = self.transform(image)
67 | return image, index
68 |
69 | def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None):
70 | for _ in range(self.num_tries):
71 | try:
72 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
73 | frames, frame_indices, fps = self.video_reader(
74 | data_path, self.num_frames, self.sample_type,
75 | max_num_frames=max_num_frames, client=self.client, clip=clip
76 | )
77 | except Exception as e:
78 | logger.warning(
79 | f"Caught exception {e} when loading video {data_path}, "
80 | f"randomly sample a new video as replacement"
81 | )
82 | index = random.randint(0, len(self) - 1)
83 | ann = self.get_anno(index)
84 | data_path = ann["image"]
85 | continue
86 | # shared aug for video frames
87 | frames = self.transform(frames)
88 | if return_fps:
89 | if clip is None:
90 | sec = [str(round(f / fps, 1)) for f in frame_indices]
91 | else:
92 | sec = [str(round( abs(f / fps - clip[0] + 0.1), 1)) for f in frame_indices]
93 | return frames, index, sec
94 | else:
95 | return frames, index
96 | else:
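            # for/else: reached only when every attempt above raised an exception,
            # i.e. the loop finished without returning a successfully loaded video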
97 | raise RuntimeError(
98 | f"Failed to fetch video after {self.num_tries} tries. "
99 | f"This might indicate that you have many corrupted videos."
100 | )
101 |
--------------------------------------------------------------------------------
/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process
4 | import random
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class MetaLoader(object):
11 |     """ wraps multiple data loaders """
12 | def __init__(self, name2loader):
13 | """Iterates over multiple dataloaders, it ensures all processes
14 | work on data from the same dataloader. This loader will end when
15 | the shorter dataloader raises StopIteration exception.
16 |
17 | loaders: Dict, {name: dataloader}
18 | """
19 | self.name2loader = name2loader
20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()}
21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
22 | index2name = {v: k for k, v in name2index.items()}
23 |
24 | iter_order = []
25 | for n, l in name2loader.items():
26 | iter_order.extend([name2index[n]]*len(l))
27 |
28 | random.shuffle(iter_order)
29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
30 |
31 | # sync
32 | if is_dist_avail_and_initialized():
33 | # make sure all processes have the same order so that
34 | # each step they will have data from the same loader
35 | dist.broadcast(iter_order, src=0)
36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
37 |
38 | logger.info(str(self))
39 |
40 | def __str__(self):
41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
42 | for idx, (name, loader) in enumerate(self.name2loader.items()):
43 | output.append(
44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
45 | )
46 | return "\n".join(output)
47 |
48 | def __len__(self):
49 | return len(self.iter_order)
50 |
51 | def __iter__(self):
52 | """ this iterator will run indefinitely """
53 | for name in self.iter_order:
54 | _iter = self.name2iter[name]
55 | batch = next(_iter)
56 | yield name, batch
57 |
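# Usage sketch (assumed loader names): given torch DataLoaders, MetaLoader mixes
# them into one stream whose random order is broadcast from rank 0, e.g.
#   meta_loader = MetaLoader({"caption": caption_loader, "qa": qa_loader})
#   for media_type, batch in meta_loader:
#       ...
# Note: iter_order is placed on CUDA before dist.broadcast, so a GPU is assumed.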
58 |
59 | class MetaLoader_rs(object):
60 |     """ wraps multiple data loaders """
61 | def __init__(self, name2loader, skip_num=0):
62 | """Iterates over multiple dataloaders, it ensures all processes
63 | work on data from the same dataloader. This loader will end when
64 | the shorter dataloader raises StopIteration exception.
65 |
66 | loaders: Dict, {name: dataloader}
67 | """
68 | self.name2loader = name2loader
69 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
70 | index2name = {v: k for k, v in name2index.items()}
71 |
72 | iter_order = []
73 | for n, l in name2loader.items():
74 | iter_order.extend([name2index[n]]*len(l))
75 |
76 | random.shuffle(iter_order)
77 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
78 |
79 | # sync
80 | if is_dist_avail_and_initialized():
81 | # make sure all processes have the same order so that
82 | # each step they will have data from the same loader
83 | dist.broadcast(iter_order, src=0)
84 |
85 | if skip_num > 0:
86 | iter_order_skip = iter_order[:skip_num]
87 | for k, v in index2name.items():
88 | media_step = (iter_order_skip == k).sum().item()
89 | name2loader[v].sampler.set_start_iter(media_step)
90 |                 logger.info(f"{v} dataloader skip steps: {media_step}")
91 | iter_order = iter_order[skip_num:]
92 | self.name2loader = name2loader
93 | else:
94 | logger.info("Do not skip steps for any dataloader!")
95 | for k, v in index2name.items():
96 | name2loader[v].sampler.set_start_iter(0)
97 |
98 | self.name2iter = {name: iter(l) for name, l in name2loader.items()}
99 | self.iter_idx = iter_order
100 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
101 |
102 | logger.info(str(self))
103 |
104 | def __str__(self):
105 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
106 | for idx, (name, loader) in enumerate(self.name2loader.items()):
107 | length = (self.iter_idx == idx).sum()
108 | output.append(
109 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={length} "
110 | )
111 | return "\n".join(output)
112 |
113 | def __len__(self):
114 | return len(self.iter_order)
115 |
116 | def __iter__(self):
117 | """ this iterator will run indefinitely """
118 | for name in self.iter_order:
119 | _iter = self.name2iter[name]
120 | batch = next(_iter)
121 | yield name, batch
--------------------------------------------------------------------------------
/dataset/it_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 | import random
5 |
6 | import numpy as np
7 |
8 | from dataset.base_dataset import ImageVideoBaseDataset
9 | from dataset.video_utils import VIDEO_READER_FUNCS
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class ITImgTrainDataset(ImageVideoBaseDataset):
15 | media_type = "image"
16 |
17 | def __init__(
18 | self, ann_file, transform,
19 | system="", role=("Human", "Assistant"),
20 | start_token="", end_token="",
21 | random_shuffle=True, # if True, shuffle the QA list
22 | ):
23 | super().__init__()
24 |
25 | if len(ann_file) == 3 and ann_file[2] == "video":
26 | self.media_type = "video"
27 | else:
28 | self.media_type = "image"
29 | self.label_file, self.data_root = ann_file[:2]
30 |
31 | logger.info('Load json file')
32 | with open(self.label_file, 'r') as f:
33 | self.anno = json.load(f)
34 | self.num_examples = len(self.anno)
35 | self.transform = transform
36 |
37 | # prompt parameters
38 | if system:
39 |             assert system[-1] == " ", "' ' should be added at the end of system so that '###' is tokenized into one token."
40 |         # adding start_token and end_token inside the system prompt is currently not supported, since msg must be inserted at the right position
41 | self.begin_signal = "###"
42 | self.end_signal = " "
43 | self.start_token = start_token
44 | self.end_token = end_token
45 | self.system = system
46 | self.role = role
47 | self.random_shuffle = random_shuffle
48 | # instruction location and number
49 | logger.info(f"Random shuffle: {self.random_shuffle}")
50 |
51 | def get_anno(self, index):
52 | filename = self.anno[index][self.media_type]
53 | qa = self.anno[index]["QA"]
54 | if "start" in self.anno[index] and "end" in self.anno[index]:
55 | anno = {
56 | "image": os.path.join(self.data_root, filename), "qa": qa,
57 | "start": self.anno[index]["start"], "end": self.anno[index]["end"],
58 | }
59 | else:
60 | anno = {"image": os.path.join(self.data_root, filename), "qa": qa}
61 | return anno
62 |
63 | def __len__(self):
64 | return self.num_examples
65 |
66 | def process_qa(self, qa, msg=""):
67 | cur_instruction = ""
68 | # randomly shuffle qa for conversation
69 | if self.random_shuffle and len(qa) > 1:
70 | random.shuffle(qa)
71 | if "i" in qa[0].keys() and qa[0]["i"] != "":
72 | cur_instruction = qa[0]["i"] + self.end_signal
73 |
74 | conversation = self.system
75 | # add instruction as system message
76 | if cur_instruction:
77 | conversation += cur_instruction
78 |
79 | # rstrip() for the extra " " in msg
80 | conversation += (
81 | self.begin_signal + self.role[0] + ": " +
82 | self.start_token + self.end_token + msg.rstrip() + self.end_signal
83 | )
84 |
85 | for sentence in qa:
86 | q = sentence["q"]
87 | a = sentence["a"]
88 | if q != "":
89 | conversation += (self.begin_signal + self.role[0] + ": " + q + self.end_signal)
90 | else:
91 | # no question, often in caption dataset
92 | pass
93 | conversation += (self.begin_signal + self.role[1] + ": " + a + self.end_signal)
94 | conversation += self.begin_signal
95 |
96 | if cur_instruction:
97 | cur_instruction += qa[0]["q"]
98 | return conversation, cur_instruction.strip()
99 |
100 | def __getitem__(self, index):
101 | try:
102 | ann = self.get_anno(index)
103 | image, index = self.load_and_transform_media_data_image(index, ann["image"])
104 | conversation, instruction = self.process_qa(ann["qa"])
105 | return image, conversation, instruction, index
106 | except Exception as e:
107 | logger.warning(f"Caught exception {e} when loading image {ann['image']}")
108 | index = np.random.randint(0, len(self))
109 | return self.__getitem__(index)
110 |
111 |
112 | class ITVidTrainDataset(ITImgTrainDataset):
113 | media_type = "video"
114 |
115 | def __init__(
116 | self, ann_file, transform,
117 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
118 | system="", role=("Human", "Assistant"),
119 |             start_token="", end_token="",
120 | add_second_msg=True,
121 | random_shuffle=True,
122 | ):
123 | super().__init__(
124 | ann_file, transform,
125 | system=system, role=role,
126 | start_token=start_token, end_token=end_token,
127 | random_shuffle=random_shuffle,
128 | )
129 | self.num_frames = num_frames
130 | self.video_reader_type = video_reader_type
131 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
132 | self.sample_type = sample_type
133 | self.num_tries = num_tries
134 | self.add_second_msg = add_second_msg
135 |
136 | logger.info(f"Use {video_reader_type} for data in {ann_file}")
137 | if add_second_msg:
138 | logger.info(f"Add second message: The video contains X frames sampled at T seconds.")
139 |
140 | def __getitem__(self, index):
141 | try:
142 | ann = self.get_anno(index)
143 | msg = ""
144 | clip = None
145 | if "start" in ann and "end" in ann:
146 | clip = [ann["start"], ann["end"]]
147 | video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip)
148 | if self.add_second_msg:
149 | # " " should be added in the start and end
150 | msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
151 | conversation, instruction = self.process_qa(ann["qa"], msg)
152 | return video, conversation, instruction, index
153 | except Exception as e:
154 | logger.warning(f"Caught exception {e} when loading video {ann['image']}")
155 | index = np.random.randint(0, len(self))
156 | return self.__getitem__(index)
--------------------------------------------------------------------------------
/dataset/it_dataset_mistral.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 | import random
5 |
6 | import numpy as np
7 |
8 | from dataset.base_dataset import ImageVideoBaseDataset
9 | from dataset.video_utils import VIDEO_READER_FUNCS
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class ITImgTrainDataset_mistral(ImageVideoBaseDataset):
15 | media_type = "image"
16 |
17 | def __init__(
18 | self, ann_file, transform,
19 | system="",
20 | start_token="", end_token="",
21 | random_shuffle=True, # if True, shuffle the QA list
22 |             return_question_instruction=False  # if True, append the question to the returned instruction
23 | ):
24 | super().__init__()
25 |
26 | if len(ann_file) == 3 and ann_file[2] == "video":
27 | self.media_type = "video"
28 | else:
29 | self.media_type = "image"
30 | self.label_file, self.data_root = ann_file[:2]
31 |
32 | logger.info('Load json file')
33 | with open(self.label_file, 'r') as f:
34 | self.anno = json.load(f)
35 | self.num_examples = len(self.anno)
36 | self.transform = transform
37 |
38 | # prompt parameters
39 | if system:
40 |             assert system[-1] == " ", "' ' should be added at the end of system."
41 |
42 | self.human_start = "[INST]"
43 | self.human_end = "[/INST]"
44 | self.assist_end = ""
45 | self.start_token = start_token
46 | self.end_token = end_token
47 | self.system = system
48 | self.random_shuffle = random_shuffle
49 | # instruction location and number
50 | self.return_question_instruction = return_question_instruction
51 | logger.info(f"Random shuffle: {self.random_shuffle}")
52 | logger.info(f"Return question with instruction: {self.return_question_instruction}")
53 |
54 | def get_anno(self, index):
55 | filename = self.anno[index][self.media_type]
56 | qa = self.anno[index]["QA"]
57 | if "start" in self.anno[index] and "end" in self.anno[index]:
58 | anno = {
59 | "image": os.path.join(self.data_root, filename), "qa": qa,
60 | "start": self.anno[index]["start"], "end": self.anno[index]["end"],
61 | }
62 | else:
63 | anno = {"image": os.path.join(self.data_root, filename), "qa": qa}
64 | return anno
65 |
66 | def __len__(self):
67 | return self.num_examples
68 |
69 | def process_qa(self, qa, msg=""):
70 | cur_instruction = ""
71 | # randomly shuffle qa for conversation
72 | if self.random_shuffle and len(qa) > 1:
73 | random.shuffle(qa)
74 | if "i" in qa[0].keys() and qa[0]["i"] != "":
75 | cur_instruction = qa[0]["i"] + " "
76 |
77 | conversation = self.system
78 | # add instruction as system message
79 | if cur_instruction:
80 | conversation += cur_instruction
81 |
82 | # rstrip() for the extra " " in msg
83 | conversation += (
84 | self.human_start + " " + self.start_token + self.end_token + msg.rstrip() + " " + self.human_end
85 | )
86 |
87 | for idx, sentence in enumerate(qa):
88 | q = sentence["q"]
89 | a = sentence["a"]
90 | if q != "":
91 | conversation += (" " + self.human_start + " " + q + " " + self.human_end)
92 | else:
93 | # no question, often in caption dataset
94 | pass
95 | conversation += (" " + a + " " + self.assist_end)
96 |
97 | if self.return_question_instruction and cur_instruction:
98 | cur_instruction += qa[0]["q"]
99 | return conversation.strip(), cur_instruction.strip()
100 |
101 | def __getitem__(self, index):
102 | try:
103 | ann = self.get_anno(index)
104 | image, index = self.load_and_transform_media_data_image(index, ann["image"])
105 | conversation, instruction = self.process_qa(ann["qa"])
106 | return image, conversation, instruction, index
107 | except Exception as e:
108 | logger.warning(f"Caught exception {e} when loading image {ann['image']}")
109 | index = np.random.randint(0, len(self))
110 | return self.__getitem__(index)
111 |
112 |
113 | class ITVidTrainDataset_mistral(ITImgTrainDataset_mistral):
114 | media_type = "video"
115 |
116 | def __init__(
117 | self, ann_file, transform,
118 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
119 |             system="", start_token="", end_token="",
120 | add_second_msg=False,
121 | random_shuffle=True,
122 |             return_question_instruction=False  # if True, append the question to the returned instruction
123 | ):
124 | super().__init__(
125 | ann_file, transform,
126 | system=system,
127 | start_token=start_token, end_token=end_token,
128 | random_shuffle=random_shuffle,
129 | return_question_instruction=return_question_instruction
130 | )
131 | self.num_frames = num_frames
132 | self.video_reader_type = video_reader_type
133 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
134 | self.sample_type = sample_type
135 | self.num_tries = num_tries
136 | self.add_second_msg = add_second_msg
137 |
138 | logger.info(f"Use {video_reader_type} for data in {ann_file}")
139 | if add_second_msg:
140 | logger.info(f"Add second message: The video contains X frames sampled at T seconds.")
141 |
142 | def __getitem__(self, index):
143 | try:
144 | ann = self.get_anno(index)
145 | msg = ""
146 | clip = None
147 | if "start" in ann and "end" in ann:
148 | clip = [ann["start"], ann["end"]]
149 | video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip)
150 | if self.add_second_msg:
151 | # " " should be added in the start and end
152 | msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
153 | conversation, instruction = self.process_qa(ann["qa"], msg)
154 | return video, conversation, instruction, index
155 | except Exception as e:
156 | logger.warning(f"Caught exception {e} when loading video {ann['image']}")
157 | index = np.random.randint(0, len(self))
158 | return self.__getitem__(index)
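
The Mistral variant renders the same kind of QA pair into the `[INST] ... [/INST]` chat format instead of the `###Human/###Assistant` template. With the same hypothetical inputs as above (empty `system`, tokens and `msg`, `return_question_instruction=False`):

```python
# Worked example for ITImgTrainDataset_mistral.process_qa (sketch).
qa = [{"i": "Answer briefly.", "q": "What is the dog doing?", "a": "Running."}]
# conversation = "Answer briefly. [INST]  [/INST] [INST] What is the dog doing? [/INST] Running."
# instruction  = "Answer briefly."  (the question is appended only when return_question_instruction=True)
```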
--------------------------------------------------------------------------------
/dataset/pt_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 |
5 | import numpy as np
6 |
7 | from dataset.base_dataset import ImageVideoBaseDataset
8 | from dataset.utils import load_anno, pre_text
9 | from dataset.video_utils import VIDEO_READER_FUNCS
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class PTImgTrainDataset(ImageVideoBaseDataset):
15 | media_type = "image"
16 |
17 | def __init__(self, ann_file, transform, pre_text=True):
18 | super().__init__()
19 |
20 | if len(ann_file) == 3 and ann_file[2] == "video":
21 | self.media_type = "video"
22 | else:
23 | self.media_type = "image"
24 | self.label_file, self.data_root = ann_file[:2]
25 |
26 | logger.info('Load json file')
27 | with open(self.label_file, 'r') as f:
28 | self.anno = json.load(f)
29 | self.num_examples = len(self.anno)
30 |
31 | self.transform = transform
32 | self.pre_text = pre_text
33 | logger.info(f"Pre-process text: {pre_text}")
34 |
35 | def get_anno(self, index):
36 | filename = self.anno[index][self.media_type]
37 | caption = self.anno[index]["caption"]
38 | anno = {"image": os.path.join(self.data_root, filename), "caption": caption}
39 | return anno
40 |
41 | def __len__(self):
42 | return self.num_examples
43 |
44 | def __getitem__(self, index):
45 | try:
46 | ann = self.get_anno(index)
47 | image, index = self.load_and_transform_media_data(index, ann["image"])
48 | caption = pre_text(ann["caption"], pre_text=self.pre_text)
49 | return image, caption, index
50 | except Exception as e:
51 | logger.warning(f"Caught exception {e} when loading image {ann['image']}")
52 | index = np.random.randint(0, len(self))
53 | return self.__getitem__(index)
54 |
55 |
56 | class PTVidTrainDataset(PTImgTrainDataset):
57 | media_type = "video"
58 |
59 | def __init__(
60 | self,
61 | ann_file,
62 | transform,
63 | num_frames=4,
64 | video_reader_type="decord",
65 | sample_type="rand",
66 | num_tries=3,
67 | pre_text=True
68 | ):
69 | super().__init__(ann_file, transform, pre_text=pre_text)
70 | self.num_frames = num_frames
71 | self.video_reader_type = video_reader_type
72 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
73 | self.sample_type = sample_type
74 | self.num_tries = num_tries
75 |
76 |
77 | class PTImgEvalDataset(ImageVideoBaseDataset):
78 | media_type = "image"
79 |
80 | def __init__(self, ann_file, transform, has_multi_vision_gt=False):
81 | super(PTImgEvalDataset, self).__init__()
82 | self.raw_anno_list = load_anno(ann_file)
83 | self.transform = transform
84 |         self.has_multi_vision_gt = has_multi_vision_gt  # each caption has multiple images as ground truth
85 |
86 | self.text = None
87 | self.image = None
88 | self.txt2img = None
89 | self.img2txt = None
90 | self.build_data()
91 |
92 | def build_data(self):
93 | self.text = []
94 | self.image = []
95 | self.txt2img = {}
96 | self.img2txt = {}
97 | if self.has_multi_vision_gt:
98 | self.build_data_multi_img_gt()
99 | else:
100 | self.build_data_multi_txt_gt()
101 | self.anno_list = [dict(image=e) for e in self.image]
102 |
103 | def build_data_multi_img_gt(self):
104 |         """each text may have multiple ground-truth images, e.g., ssv2"""
105 | img_id = 0
106 | for txt_id, ann in enumerate(self.raw_anno_list):
107 | self.text.append(pre_text(ann["caption"]))
108 | self.txt2img[txt_id] = []
109 | _images = ann["image"] \
110 | if isinstance(ann["image"], list) else [ann["image"], ]
111 | for i, image in enumerate(_images):
112 | self.image.append(image)
113 | self.txt2img[txt_id].append(img_id)
114 | self.img2txt[img_id] = txt_id
115 | img_id += 1
116 |
117 | def build_data_multi_txt_gt(self):
118 |         """each image may have multiple ground-truth captions, e.g., COCO and Flickr30K"""
119 | txt_id = 0
120 | for img_id, ann in enumerate(self.raw_anno_list):
121 | self.image.append(ann["image"])
122 | self.img2txt[img_id] = []
123 | _captions = ann["caption"] \
124 | if isinstance(ann["caption"], list) else [ann["caption"], ]
125 | for i, caption in enumerate(_captions):
126 | self.text.append(pre_text(caption))
127 | self.img2txt[img_id].append(txt_id)
128 | self.txt2img[txt_id] = img_id
129 | txt_id += 1
130 |
131 | def __len__(self):
132 | return len(self.anno_list)
133 |
134 | def __getitem__(self, index):
135 | ann = self.anno_list[index]
136 | image, index = self.load_and_transform_media_data(index, ann["image"])
137 | return image, index
138 |
139 |
140 | def preprocess_para_retrieval_data(anno_list):
141 | processed_anno_list = []
142 | for d in anno_list:
143 | d["caption"] = " ".join(d.pop("caption"))
144 | processed_anno_list.append(d)
145 | return processed_anno_list
146 |
147 |
148 | class PTVidEvalDataset(PTImgEvalDataset):
149 | media_type = "video"
150 |
151 | def __init__(
152 | self, ann_file, transform, num_frames=4,
153 | video_reader_type="decord", sample_type="rand", num_tries=1,
154 | is_paragraph_retrieval=False, has_multi_vision_gt=False,
155 | ):
156 | super(PTVidEvalDataset, self).__init__(ann_file, transform, has_multi_vision_gt)
157 | self.num_frames = num_frames
158 | self.video_reader_type = video_reader_type
159 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
160 | self.sample_type = sample_type
161 | self.num_tries = num_tries
162 | self.is_paragraph_retrieval = is_paragraph_retrieval
163 |
164 | if is_paragraph_retrieval:
165 | self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list)
166 | self.build_data()
167 |
--------------------------------------------------------------------------------
/dataset/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import logging
4 | from torch.utils.data.distributed import DistributedSampler
5 |
6 |
7 | # stolen from https://github.com/facebookresearch/vissl/blob/94def58538d3c7037f5e093196494331eea1a2a2/vissl/data/data_helper.py#L93
8 | class StatefulDistributedSampler(DistributedSampler):
9 | """
10 |     A more fine-grained, stateful DataSampler that uses both the training iteration
11 |     and the epoch for shuffling data. PyTorch's DistributedSampler only uses the epoch
12 |     for shuffling and always starts sampling from the beginning. When training
13 |     on very large data we train for only one epoch, so when resuming training
14 |     we want the data sampler to resume from the training iteration as well.
15 | """
16 |
17 | def __init__(self, dataset, batch_size=None, seed: int = 0):
18 | """
19 |         Initializes the StatefulDistributedSampler. The random seed is set per
20 |         epoch and the data is shuffled. Sampling starts from start_iter (0 by
21 |         default, or restored when resuming from a checkpoint), so that only the
22 |         remaining samples of the epoch are drawn.
23 |
24 | Args:
25 | dataset (Dataset): Pytorch dataset that sampler will shuffle
26 | batch_size (int): batch size we want the sampler to sample
27 | seed (int): Seed for the torch generator.
28 | """
29 | super().__init__(dataset, shuffle=False, seed=seed)
30 |
31 | self.start_iter = 0
32 | self.batch_size = batch_size
33 | self.total_size = len(dataset) - (len(dataset) % self.num_replicas)
34 | self.num_samples = self.total_size // self.num_replicas
35 | print(f"rank: {self.rank}: Sampler created...")
36 |
37 | def __iter__(self):
38 | # partition data into num_replicas and optionally shuffle within a rank
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch + self.seed)
41 | shuffling = torch.randperm(self.num_samples, generator=g).tolist()
42 | indices = np.array(
43 | list(
44 | range(
45 | (self.rank * self.num_samples), (self.rank + 1) * self.num_samples
46 | )
47 | )
48 | )[shuffling].tolist()
49 |
50 | # make sure we have correct number of samples per replica
51 | assert len(indices) == self.num_samples
52 | assert self.batch_size > 0, "batch_size not set for the sampler"
53 |
54 | # resume the sampler
55 | start_index = self.start_iter * self.batch_size
56 | indices = indices[start_index:]
57 | return iter(indices)
58 |
59 | def set_start_iter(self, start_iter):
60 | """
61 | Set the iteration number from which the sampling should start. This is
62 | used to find the marker in the data permutation order from where the
63 | sampler should start sampling.
64 | """
65 | self.start_iter = start_iter
66 |
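
A minimal resume sketch for the sampler above. It assumes the script is launched with `torchrun` so that a process group can be initialized; the list-based toy dataset is for illustration only:

```python
import torch.distributed as dist
from torch.utils.data import DataLoader

from dataset.sampler import StatefulDistributedSampler

dist.init_process_group("gloo")      # rank/world size come from the torchrun environment
dataset = list(range(1000))          # toy map-style dataset
sampler = StatefulDistributedSampler(dataset, batch_size=8)
sampler.set_start_iter(100)          # resume: skip the first 100 * 8 samples of this epoch's order
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
first_batch = next(iter(loader))     # continues from where the checkpoint left off
```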
--------------------------------------------------------------------------------
/dataset/video_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
3 | """
4 | import random
5 | import io
6 | import av
7 | import cv2
8 | import decord
9 | import imageio
10 | from decord import VideoReader
11 | import torch
12 | import numpy as np
13 | import math
14 | decord.bridge.set_bridge("torch")
15 |
16 | import logging
17 | logger = logging.getLogger(__name__)
18 |
19 | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
20 | """
21 | Converts a present time with the given time base and start_pts offset to seconds.
22 |
23 | Returns:
24 | time_in_seconds (float): The corresponding time in seconds.
25 |
26 | https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
27 | """
28 | if pts == math.inf:
29 | return math.inf
30 |
31 | return int(pts - start_pts) * time_base
32 |
33 |
34 | def get_pyav_video_duration(video_reader):
35 | video_stream = video_reader.streams.video[0]
36 | video_duration = pts_to_secs(
37 | video_stream.duration,
38 | video_stream.time_base,
39 | video_stream.start_time
40 | )
41 | return float(video_duration)
42 |
43 |
44 | def get_frame_indices_by_fps():
45 | pass
46 |
47 |
48 | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
49 | if sample in ["rand", "middle"]: # uniform sampling
50 | acc_samples = min(num_frames, vlen)
51 | # split the video into `acc_samples` intervals, and sample from each interval.
52 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
53 | ranges = []
54 | for idx, interv in enumerate(intervals[:-1]):
55 | ranges.append((interv, intervals[idx + 1] - 1))
56 | if sample == 'rand':
57 | try:
58 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
59 |             except Exception:
60 | frame_indices = np.random.permutation(vlen)[:acc_samples]
61 | frame_indices.sort()
62 | frame_indices = list(frame_indices)
63 | elif fix_start is not None:
64 | frame_indices = [x[0] + fix_start for x in ranges]
65 | elif sample == 'middle':
66 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
67 | else:
68 | raise NotImplementedError
69 |
70 | if len(frame_indices) < num_frames: # padded with last frame
71 | padded_frame_indices = [frame_indices[-1]] * num_frames
72 | padded_frame_indices[:len(frame_indices)] = frame_indices
73 | frame_indices = padded_frame_indices
74 | elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
75 | output_fps = float(sample[3:])
76 | duration = float(vlen) / input_fps
77 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
78 | frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
79 | frame_indices = np.around(frame_seconds * input_fps).astype(int)
80 | frame_indices = [e for e in frame_indices if e < vlen]
81 | if max_num_frames > 0 and len(frame_indices) > max_num_frames:
82 | frame_indices = frame_indices[:max_num_frames]
83 | # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
84 | else:
85 | raise ValueError
86 | return frame_indices
87 |
88 |
89 | def read_frames_av(
90 | video_path, num_frames, sample='rand', fix_start=None,
91 | max_num_frames=-1, client=None, clip=None,
92 | ):
93 | reader = av.open(video_path)
94 | frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
95 | vlen = len(frames)
96 | duration = get_pyav_video_duration(reader)
97 | fps = vlen / float(duration)
98 | frame_indices = get_frame_indices(
99 | num_frames, vlen, sample=sample, fix_start=fix_start,
100 | input_fps=fps, max_num_frames=max_num_frames
101 | )
102 | frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8
103 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
104 | return frames, frame_indices, fps
105 |
106 |
107 | def read_frames_gif(
108 | video_path, num_frames, sample='rand', fix_start=None,
109 | max_num_frames=-1, client=None, clip=None,
110 | ):
111 | if video_path.startswith('s3') or video_path.startswith('p2'):
112 | video_bytes = client.get(video_path)
113 | gif = imageio.get_reader(io.BytesIO(video_bytes))
114 | else:
115 | gif = imageio.get_reader(video_path)
116 | vlen = len(gif)
117 | frame_indices = get_frame_indices(
118 | num_frames, vlen, sample=sample, fix_start=fix_start,
119 | max_num_frames=max_num_frames
120 | )
121 | frames = []
122 | for index, frame in enumerate(gif):
123 | # for index in frame_idxs:
124 | if index in frame_indices:
125 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
126 | frame = torch.from_numpy(frame).byte()
127 | # # (H x W x C) to (C x H x W)
128 | frame = frame.permute(2, 0, 1)
129 | frames.append(frame)
130 | frames = torch.stack(frames) # .float() / 255
131 | return frames, frame_indices, 25. # for tgif
132 |
133 |
134 | def read_frames_decord(
135 | video_path, num_frames, sample='rand', fix_start=None,
136 | max_num_frames=-1, client=None, clip=None
137 | ):
138 | if "s3" in video_path:
139 | video_bytes = client.get(video_path)
140 | video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
141 | else:
142 | video_reader = VideoReader(video_path, num_threads=1)
143 | vlen = len(video_reader)
144 | fps = video_reader.get_avg_fps()
145 | duration = vlen / float(fps)
146 |
147 | if clip:
148 | start, end = clip
149 | duration = end - start
150 | vlen = int(duration * fps)
151 | start_index = int(start * fps)
152 |
153 | frame_indices = get_frame_indices(
154 | num_frames, vlen, sample=sample, fix_start=fix_start,
155 | input_fps=fps, max_num_frames=max_num_frames
156 | )
157 | if clip:
158 | frame_indices = [f + start_index for f in frame_indices]
159 |
160 | frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
161 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
162 | return frames, frame_indices, float(fps)
163 |
164 |
165 | VIDEO_READER_FUNCS = {
166 | 'av': read_frames_av,
167 | 'decord': read_frames_decord,
168 | 'gif': read_frames_gif,
169 | }
170 |
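
A minimal usage sketch for the reader registry above, using one of the demo clips shipped with the repo (decord must be installed and the path is relative to the repo root):

```python
from dataset.video_utils import VIDEO_READER_FUNCS

read_frames = VIDEO_READER_FUNCS["decord"]
frames, frame_indices, fps = read_frames(
    "demo/example/yoga.mp4",   # example video from this repo
    num_frames=8,
    sample="middle",           # deterministic middle-of-interval sampling
    clip=(1.0, 5.0),           # optional [start, end] window in seconds
)
print(frames.shape)            # (8, C, H, W), torch.uint8
```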
--------------------------------------------------------------------------------
/demo/example/bear.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/bear.jpg
--------------------------------------------------------------------------------
/demo/example/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/dog.png
--------------------------------------------------------------------------------
/demo/example/jesse_dance.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/jesse_dance.mp4
--------------------------------------------------------------------------------
/demo/example/people.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/people.jpg
--------------------------------------------------------------------------------
/demo/example/yoga.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/demo/example/yoga.mp4
--------------------------------------------------------------------------------
/download/folder_keeper:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/download/folder_keeper
--------------------------------------------------------------------------------
/eval/Egoschema_trans_csv.py:
--------------------------------------------------------------------------------
1 | import json
2 | import csv
3 |
4 | file_dir="/path_to_the_timesuite_root_folder/scripts/Ablation/Token_Shufflue/F128_CF8_PoolMax_Ablation_PoolMax/Egoschema_test_ckpt_01"
5 |
6 | # Load your JSON data
7 | with open(f'{file_dir}/result.json', 'r') as f:
8 | data = json.load(f)
9 |
10 | # Write the data to a CSV file
11 | with open(f'{file_dir}/result_submit_kaggle.csv', 'w', newline='') as f:
12 | writer = csv.writer(f)
13 | writer.writerow(['q_uid', 'answer'])
14 | for key, value in data.items():
15 |         writer.writerow([key, value])  # write one row
16 |
--------------------------------------------------------------------------------
/eval/eval_egoschema.sh:
--------------------------------------------------------------------------------
1 | ROOT_DIR="/path_to_the_timesuite_root_folder/download/parameters"
2 |
3 | python3 eval/validate_egoschema.py \
4 | --f ${ROOT_DIR}/Egoschema_test_timesuite/result.json \
5 | 2>&1 | tee "${ROOT_DIR}/Egoschema_test_timesuite/acc.txt"
6 |
--------------------------------------------------------------------------------
/eval/format_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import json
4 |
5 | def read_txt(path):
6 | with open(path, "r") as fin:
7 | data = fin.readline().strip()
8 | return data
9 |
10 |
11 | def load_data(args, anno_path, split=None):
12 | '''
13 | anno data example:
14 | {"annotations":
15 | [
16 | {
17 | "image_id": "xHr8X2Wpmno.mp4"
18 | ...
19 | },
20 | ...
21 | ]
22 | }
23 | '''
24 | file_path = os.path.join(anno_path, f'{split}.caption_coco_format.json')
25 | with open(file_path, 'r') as f:
26 | data = json.load(f)["annotations"]
27 |
28 | if args.debug:
29 | data = data[:10]
30 | return data
31 |
32 |
33 | def merge_seg_caps(results):
34 | """merge mulple generated captions from a same video into paragraph."""
35 | merge_results = {}
36 | for jterm in results:
37 | vname = jterm["vname"]
38 | cap = jterm["generated_cap"]
39 | postfix = vname.split(".mp4")[-1]
40 | start_time, end_time = float(postfix.split("_")[-2]), float(postfix.split("_")[-1])
41 | vid = vname.split(".mp4")[0] + ".mp4"
42 | if vid not in merge_results:
43 | merge_results[vid] = []
44 | merge_results[vid].append({"timestamp": [start_time, end_time], "caption": cap})
45 | return merge_results
46 |
47 |
48 | def save_result(args, output_dir, results, split_name='test', format=False):
49 | Path(output_dir).mkdir(parents=True, exist_ok=True)
50 | file_name = f'{args.dataset}_{split_name}_clipF{args.infer_clip_frames}_result.json'
51 | if args.timestamp:
52 | if args.timestamp_file != '':
53 | file_name = f'{args.dataset}_{split_name}_clipF{args.infer_clip_frames}_result_with_pred_timestamp.json'
54 | else:
55 | file_name = f'{args.dataset}_{split_name}_clipF{args.infer_clip_frames}_result_with_gt_timestamp.json'
56 | if args.debug:
57 | file_name = 'debug_' + file_name
58 | if format:
59 | file_name = 'fmt_' + file_name
60 | with open(os.path.join(output_dir, file_name), 'w') as f:
61 | json.dump(results, f)
62 | return
63 |
64 |
65 | def get_timestamp_from_file(timestamp_file):
66 | timestamp = {}
67 | with open(timestamp_file, 'r') as f:
68 | data = json.load(f)
69 | for vid, vlist in data.items():
70 | timestamp[vid] = []
71 | for vterm in vlist:
72 | timestamp[vid].append(vterm["timestamp"])
73 | return timestamp
74 |
75 |
76 |
77 |
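
For reference, `merge_seg_caps` above expects each per-segment video name to carry the segment boundaries after the `.mp4` suffix (format inferred from the parsing code; the caption is hypothetical):

```python
# Sketch of the input/output of merge_seg_caps.
results = [{"vname": "xHr8X2Wpmno.mp4_47.0_60.0",
            "generated_cap": "a person ties a plant into a bun."}]
# merge_seg_caps(results) ->
# {"xHr8X2Wpmno.mp4": [{"timestamp": [47.0, 60.0],
#                       "caption": "a person ties a plant into a bun."}]}
```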
--------------------------------------------------------------------------------
/eval/format_tvg.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import re
5 | from copy import deepcopy
6 | import pdb
7 | import numpy as np
8 | from pathlib import Path
9 |
10 | # read json files
11 | def read_json(path):
12 | with open(path, "r") as fin:
13 | datas = json.load(fin)
14 | return datas
15 |
16 |
17 | def write_json(path, data):
18 | with open(path, "w") as fout:
19 | json.dump(data, fout)
20 | print("The format file has been saved at:{}".format(path))
21 | return
22 |
23 |
24 | def extract_time(paragraph):
25 | prompt = 'A specific example is : 20.8 - 30.0 seconds'.lower()
26 | paragraph = paragraph.lower()
27 |     paragraph = paragraph.replace(prompt, '')
28 | # Split text into sentences based on common delimiters
29 | sentences = re.split(r'[!?\n]', paragraph)
30 |
31 | # Keywords that might indicate the presence of time information
32 | keywords = ["starts", "ends", "happens in", "start time", "end time", "start", "end", "happen"]
33 | # filter sentences by keywords
34 | candidates = []
35 | for sentence in sentences:
36 | # If sentence contains one of the keywords
37 | if any(keyword in sentence for keyword in keywords):
38 | candidates.append(sentence)
39 |
40 | timestamps = []
41 | # Check for The given query happens in m - n (seconds)
42 | patterns = [
43 | r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)"
44 | ]
45 |
46 | for time_pattern in patterns:
47 | time_matches = re.findall(time_pattern, paragraph)
48 | if time_matches:
49 | timestamps = [[float(start), float(end)] for start, end in time_matches]
50 |
51 | if len(sentences) == 0:
52 | return []
53 | # check for other formats e.g.:
54 | # 1 .Starting time: 0.8 seconds
55 | # Ending time: 1.1 seconds
56 | # 2. The start time for this event is 0 seconds, and the end time is 12 seconds.
57 | if len(timestamps) == 0:
58 | times = []
59 | time_regex = re.compile(r'\b(\d+\.\d+\b|\b\d+)\b') # time formats (e.g., 18, 18.5)
60 | for sentence in candidates:
61 | time = re.findall(time_regex, sentence)
62 | if time:
63 | time_in_sec = float(time[0])
64 | times.append(time_in_sec)
65 | times = times[:len(times)//2*2]
66 | timestamps = [(times[i], times[i+1]) for i in range(0, len(times), 2)]
67 | # Check for examples like:
68 | # 3. The event 'person flipped the light switch near the door' starts at 00:00:18 and ends at 00:00:23.
69 | if len(timestamps) == 0:
70 | times = []
71 |         time_regex = re.compile(r'\b(\d{1,2}:\d{2}(?::\d{2})?)\b') # time formats (e.g., 18:00, 00:18:05)
72 | for sentence in candidates:
73 | time = re.findall(time_regex, sentence)
74 | if time:
75 | t = time[0]
76 | else:
77 | continue
78 | # If time is in HH:MM:SS format, convert to seconds
79 | if t.count(':') == 2:
80 | h, m, s = map(int, t.split(':'))
81 | time_in_sec = h * 3600 + m * 60 + s
82 | elif t.count(':') == 1:
83 | m, s = map(int, t.split(':'))
84 | time_in_sec = m * 60 + s
85 | times.append(time_in_sec)
86 | times = times[:len(times)//2*2]
87 | timestamps = [(times[i], times[i+1]) for i in range(0, len(times), 2)]
88 | results = []
89 | for (start, end) in timestamps:
90 | if end > start:
91 | results.append([start, end])
92 | else:
93 | results.append([end, start])
94 | if len(results) > 1:
95 | results = results[:1]
96 | return results
97 |
98 |
99 | def format_tvg_output(paras):
100 | timestamps = []
101 | # type 1: directly detect timestamps in generated paragraph to process multi-lines cases like:
102 | timestamps = extract_time(paras)
103 |
104 | return timestamps
105 |
106 |
107 | def format_tvg(datas):
108 | fmt_datas = {}
109 | cnt = 0
110 | for i, jterm in enumerate(datas):
111 | vid = jterm["vname"]
112 | query = jterm["query"]
113 | gcap = jterm["generated_cap"]
114 | qid = int(jterm["id"])
115 | timestamps = format_tvg_output(gcap)
116 | if len(timestamps) == 0:
117 | cnt += 1
118 | print(vid, query + "\n", gcap, "\n")
119 | fmt_datas[qid] = {"timestamp": timestamps, "query": query, "vid": vid}
120 | print(f'parse failed number: {cnt}')
121 | return fmt_datas
122 |
123 |
124 | if __name__ == "__main__":
125 | parser = argparse.ArgumentParser()
126 | parser.add_argument('--inpath', default='/home/yaolinli/code/Ask-Anything/video_chat/output/eval_7b_tvg_charades/charades_test_f8_result.json')
127 | parser.add_argument('--outpath', default='')
128 | args = parser.parse_args()
129 |
130 | datas = read_json(args.inpath)
131 | # example in output file
132 | # {
133 | # "query_idx":
134 | # {
135 | # "timestamp": [47.0, 60.0],
136 | # "query": "a person is shown tying a plant into a bun.",
137 | # "vid": "xHr8X2Wpmno.mp4"
138 | # },
139 | # ...
140 | # }
141 | fmt_datas = {}
142 | cnt = 0
143 | for i, jterm in enumerate(datas):
144 | vid = jterm["vname"]
145 | query = jterm["query"]
146 | gcap = jterm["generated_cap"]
147 | qid = jterm["id"]
148 | timestamps = format_tvg_output(gcap)
149 | if len(timestamps) == 0:
150 | cnt += 1
151 | print(vid, query+"\n", gcap+"\n")
152 | # pdb.set_trace()
153 | else:
154 | # print(gcap)
155 | # print(timestamps)
156 | pass
157 | fmt_datas[qid] = {"timestamp": timestamps, "query": query, "vid": vid}
158 |
159 | print(f'parse failed number: {cnt}')
160 | split = args.inpath.split('/')[-1].split('_')[0]
161 | out_file = args.inpath.split('/')[-2]
162 | out_path = f'{out_file}_{split}.json'
163 | if args.outpath != '':
164 | Path(args.outpath).mkdir(parents=True, exist_ok=True)
165 | out_path = os.path.join(args.outpath, out_path)
166 | write_json(os.path.join(os.getcwd(), out_path), fmt_datas)
167 | else:
168 | infile = args.inpath.split('/')[-1]
169 | outfile = "fmt_" + infile
170 | out_path = args.inpath.replace(infile, outfile)
171 | write_json(out_path, fmt_datas)
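
A quick sanity check of the timestamp extraction above (hypothetical model output; the import assumes the repo root is on PYTHONPATH):

```python
from eval.format_tvg import format_tvg_output

print(format_tvg_output("The given query happens in 20.8 - 30.0 seconds."))
# -> [[20.8, 30.0]]
```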
--------------------------------------------------------------------------------
/eval/get_grounding_result.sh:
--------------------------------------------------------------------------------
1 | MODEL_DIR="/path_to_the_timesuite_root_folder/download/parameters"
2 |
3 |
4 |
5 | TASK='tvg'
6 | SPLIT='test'
7 | DATASET='charades'
8 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/temporal_video_grounding/charades/charades_annotation'
9 | GT_FILE="${ANNO_DIR}/${SPLIT}.caption_coco_format.json"
10 |
11 | sleep 1
12 | MODEL_PTH="timesuite"
13 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}"
14 | PRED_FILE="${RESULT_DIR}/fmt_${DATASET}_${SPLIT}_clipF8_result.json"
15 |
16 | if [ -f "${PRED_FILE}" ]; then
17 | cd metrics/${TASK}
18 | python eval_${TASK}.py \
19 | --gt_file ${GT_FILE} \
20 | --pred_file ${PRED_FILE} \
21 | 2>&1 | tee ${RESULT_DIR}/grounding_result.txt
22 | cd ../..
23 | else
24 | echo "File ${PRED_FILE} not exists. Skipping eval operation."
25 | fi
26 |
27 |
28 | TASK='vhd'
29 | SPLIT='val'
30 | DATASET='qvhighlights'
31 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw'
32 | GT_FILE="${ANNO_DIR}/highlight_${SPLIT}_release.jsonl"
33 |
34 | sleep 1
35 | MODEL_PTH="timesuite"
36 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}"
37 | PRED_FILE="${RESULT_DIR}/fmt_${DATASET}_${SPLIT}_clipF8_result.json"
38 |
39 | if [ -f "${PRED_FILE}" ]; then
40 | cd metrics/${TASK}
41 | python eval_${TASK}.py \
42 | --gt_file ${GT_FILE} \
43 | --pred_file ${PRED_FILE} \
44 | 2>&1 | tee ${RESULT_DIR}/grounding_result.txt
45 | cd ../..
46 | else
47 |     echo "File ${PRED_FILE} does not exist. Skipping eval operation."
48 | fi
--------------------------------------------------------------------------------
/eval/test_grounding.sh:
--------------------------------------------------------------------------------
1 | MODEL_DIR="/path_to_the_timesuite_root_folder/download/parameters"
2 | MODEL_TYPE="VideoChat2_it4_mistral_LinearProAda"
3 |
4 | TASK='vhd'
5 | SPLIT='val'
6 | DATASET='qvhighlights'
7 | PROMPT_FILE="/path_to_the_timesuite_root_folder/prompts/vhd_description_zeroshot_new.txt"
8 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/video_highlight_detection/qvhighlights/annotations_raw'
9 | GT_FILE="${ANNO_DIR}/highlight_${SPLIT}_release.jsonl"
10 | VIDEO_DIR='pnorm2:s3://qvhighlight/videos'
11 |
12 | sleep 2
13 | MODEL_PTH="timesuite"
14 | PTH_DIR="${MODEL_DIR}/${MODEL_PTH}.pth"
15 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}_new"
16 |
17 | if [ ! -d "${RESULT_DIR}" ] && [ -f "${PTH_DIR}" ]; then
18 | mkdir -p ${RESULT_DIR}
19 | echo "Created directory: ${RESULT_DIR}"
20 | python3 /path_to_the_timesuite_root_folder/eval/eval_infer.py \
21 | --task=${TASK} \
22 | --split=${SPLIT} \
23 | --dataset=${DATASET} \
24 | --prompt_file=${PROMPT_FILE} \
25 | --anno_path=${ANNO_DIR} \
26 | --video_path=${VIDEO_DIR} \
27 | --model_dir=${MODEL_DIR} \
28 | --model_pth=${MODEL_PTH} \
29 | --model_type=${MODEL_TYPE} \
30 | --output_dir=${RESULT_DIR} \
31 | 2>&1 | tee ${RESULT_DIR}/eval_inference.log
32 | else
33 |     echo "Directory ${RESULT_DIR} already exists or ${PTH_DIR} does not exist. Skipping eval operation."
34 | fi
35 |
36 | TASK='tvg'
37 | SPLIT='test'
38 | DATASET='charades'
39 | PROMPT_FILE="/path_to_the_timesuite_root_folder/prompts/tvg_description_zeroshot.txt"
40 | ANNO_DIR='/path_to_the_timesuite_root_folder/dataset/TimeIT/temporal_video_grounding/charades/charades_annotation'
41 | GT_FILE="${ANNO_DIR}/${SPLIT}.caption_coco_format.json"
42 | VIDEO_DIR="pnorm2zxy:s3://zengxiangyu/Charades/"
43 |
44 |
45 | sleep 2
46 | MODEL_PTH="timesuite"
47 | PTH_DIR="${MODEL_DIR}/${MODEL_PTH}.pth"
48 | RESULT_DIR="${MODEL_DIR}/${TASK}_${SPLIT}_${MODEL_PTH}"
49 |
50 |
51 | if [ ! -d "${RESULT_DIR}" ] && [ -f "${PTH_DIR}" ]; then
52 | mkdir -p ${RESULT_DIR}
53 | echo "Created directory: ${RESULT_DIR}"
54 | python3 /path_to_the_timesuite_root_folder/eval/eval_infer.py \
55 | --task=${TASK} \
56 | --split=${SPLIT} \
57 | --dataset=${DATASET} \
58 | --prompt_file=${PROMPT_FILE} \
59 | --anno_path=${ANNO_DIR} \
60 | --video_path=${VIDEO_DIR} \
61 | --model_dir=${MODEL_DIR} \
62 | --model_pth=${MODEL_PTH} \
63 | --model_type=${MODEL_TYPE} \
64 | --output_dir=${RESULT_DIR} \
65 | 2>&1 | tee ${RESULT_DIR}/eval_inference.log
66 | else
67 |     echo "Directory ${RESULT_DIR} already exists or ${PTH_DIR} does not exist. Skipping eval operation."
68 | fi
--------------------------------------------------------------------------------
/eval/validate_egoschema.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import requests
3 | import json
4 |
5 | def send_post_request(json_file):
6 | """
7 | Sends a POST request to the specified URL with the given JSON file.
8 |
9 | Parameters:
10 | - json_file (str): Path to the JSON file to be used in the request body.
11 |
12 | Returns:
13 | - Response object containing server's response.
14 | """
15 |
16 | url = "https://validation-server.onrender.com/api/upload/"
17 | headers = {
18 | "Content-Type": "application/json"
19 | }
20 |
21 | with open(json_file, 'r') as f:
22 | data = json.load(f)
23 |
24 | response = requests.post(url, headers=headers, json=data)
25 |
26 | return response
27 |
28 | def main():
29 | """
30 | Main function that parses command-line arguments and sends a POST request.
31 | """
32 |
33 | parser = argparse.ArgumentParser(description="Send a POST request with a JSON file.")
34 | parser.add_argument("--f", required=True, help="Path to the JSON file to be sent with the request.")
35 |
36 | args = parser.parse_args()
37 |
38 | response = send_post_request(args.f)
39 | print(f"Response Status Code: {response.status_code}")
40 | print(f"Response Content:\n{response.text}")
41 |
42 | if __name__ == "__main__":
43 | main()
--------------------------------------------------------------------------------
/images/123:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/images/abstract.png
--------------------------------------------------------------------------------
/images/data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/images/data.png
--------------------------------------------------------------------------------
/images/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/images/structure.png
--------------------------------------------------------------------------------
/metrics/README.md:
--------------------------------------------------------------------------------
1 | # Calculate Metrics
2 |
3 | #### Dense video captioning task
4 |
5 | ```
6 | cd dvc/
7 | ```
8 |
9 | Set the `pred_file` and `gt_file` and run:
10 |
11 | ```
12 | python eval_dvc.py --pred_file $pred_file --gt_file $gt_file
13 | ```
14 |
15 | Evaluation for paragraph captioning, run:
16 |
17 | ```
18 | python eval_dvc.py --pred_file $pred_file --gt_file $gt_file --paragraph
19 | ```
20 |
21 | #### Temporal video grounding task
22 |
23 | ```
24 | cd tvg/
25 | ```
26 |
27 | Set the `pred_file` and `gt_file` and run:
28 |
29 | ```
30 | python eval_tvg.py --pred_file $pred_file --gt_file $gt_file
31 | ```
32 |
33 |
34 | #### Video highlight detection task
35 |
36 | ```
37 | cd vhd/
38 | ```
39 |
40 | Set the `pred_file` and `gt_file` and run:
41 |
42 | ```
43 | python eval_highlights.py --pred_file $pred_file --gt_file $gt_file
44 | ```
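
For the dense video captioning task, the prediction file maps each video name to a list of timestamped captions (see the commented example in `dvc/eval_dvc.sh`). A minimal sketch for writing such a file, with a placeholder video name and caption:

```python
import json

preds = {
    "xHr8X2Wpmno.mp4": [
        {"timestamp": [47.0, 60.0],
         "caption": "a person is shown tying a plant into a bun."},
    ]
}
with open("pred_file.json", "w") as f:
    json.dump(preds, f)
```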
--------------------------------------------------------------------------------
/metrics/dvc/eval_dvc.sh:
--------------------------------------------------------------------------------
1 | # example in pred file
2 | # {
3 | # "xHr8X2Wpmno.mp4": [
4 | # {
5 | # "timestamp": [47.0, 60.0],
6 | # "caption": "a person is shown tying a plant into a bun and putting the bun into a pink jar."
7 | # },
8 | # ...
9 | # ]
10 | # ...
11 | # }
12 |
13 | pred_file='/home/yaolinli/code/Ask-Anything/video_chat/results/eval_7b_instruct111k_timeit-vally-llava-66k_bz8_f8_epoch3_youcook.json'
14 | gt_file='/home/yaolinli/dataset/YouCook2_asr_denseCap/val.caption_coco_format.json'
15 |
16 | # pred_file='/home/yaolinli/code/Ask-Anything/video_chat/results/eval_7b_instruct11.2k_youcook2-anet_bz8_f8_epoch3_anet.json'
17 | # gt_file='/home/yaolinli/dataset/ActivityNet_asr_denseCap/val.caption_coco_format.json'
18 |
19 | python eval_dvc.py --pred_file $pred_file --gt_file $gt_file --analyze
20 |
21 |
--------------------------------------------------------------------------------
/metrics/dvc/metrics/README.md:
--------------------------------------------------------------------------------
1 | # captioning-metrics
2 |
3 | This is a fork of https://github.com/salaniz/pycocoevalcap with several changes that make it easier to run for dense video captioning, e.g. the METEOR jar is closed only once per evaluation instead of at every call.
4 | 
5 | To use it, you may download https://github.com/salaniz/pycocoevalcap/tree/master/meteor/data and put it under the data/ folder.
6 |
--------------------------------------------------------------------------------
/metrics/dvc/metrics/cider.py:
--------------------------------------------------------------------------------
1 | """Computes the CIDEr (Consensus-Based Image Description Evaluation) Metric."""
2 |
3 | # Filename: cider.py
4 | #
5 | # Description: Describes the class to compute the CIDEr
6 | # (Consensus-Based Image Description Evaluation) Metric
7 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
8 | #
9 | # Creation Date: Sun Feb 8 14:16:54 2015
10 | #
11 | # Authors: Ramakrishna Vedantam
12 | # and Tsung-Yi Lin
13 |
14 | from .cider_scorer import CiderScorer
15 |
16 |
17 | class Cider:
18 | """Main Class to compute the CIDEr metric."""
19 |
20 | def __init__(self, n=4, sigma=6.0):
21 | # set cider to sum over 1 to 4-grams
22 | self._n = n
23 | # set the standard deviation parameter for gaussian penalty
24 | self._sigma = sigma
25 |
26 | def compute_score(self, gts, res):
27 | """Main function to compute CIDEr score.
28 |
29 | Args:
30 |       gts: dictionary mapping each image id to a list of tokenized reference sentences.
31 |       res: dictionary mapping each image id to a single-element list containing the
32 |         tokenized candidate sentence.
33 |
34 | Returns:
35 | Computed CIDEr float score for the corpus.
36 | """
37 |
38 | assert sorted(gts.keys()) == sorted(res.keys())
39 | imgids = list(gts.keys())
40 |
41 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
42 |
43 | # Sort the IDs to be able to have control over the order
44 | # of the individual scores.
45 | for iid in sorted(imgids):
46 | hypo = res[iid]
47 | ref = gts[iid]
48 |
49 | # Sanity check.
50 | assert isinstance(hypo, list)
51 | assert len(hypo) == 1
52 | assert isinstance(ref, list)
53 | assert ref
54 |
55 | cider_scorer += (hypo[0], ref)
56 |
57 | (score, scores) = cider_scorer.compute_score()
58 |
59 | return score, scores
60 |
61 | def method(self):
62 | return "CIDEr"
63 |
--------------------------------------------------------------------------------
/metrics/dvc/metrics/meteor.py:
--------------------------------------------------------------------------------
1 | """Python wrapper for METEOR implementation."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | import subprocess
9 | import threading
10 |
11 | import numpy as np
12 | import six
13 |
14 |
15 | class Meteor(object):
16 | """Meteor scorer."""
17 |
18 | def __init__(self,
19 | meteor_jar_path=None,
20 | java_jre_path=None,
21 | jdk_java_options=None):
22 | if java_jre_path:
23 | self.java_bin = java_jre_path
24 | elif 'JRE_BIN_JAVA' in os.environ:
25 | self.java_bin = os.environ['JRE_BIN_JAVA']
26 | else:
27 | self.java_bin = 'java'
28 |
29 | if meteor_jar_path:
30 | meteor_jar = meteor_jar_path
31 | else:
32 | meteor_jar = os.path.join(
33 | './metrics', 'meteor-1.5.jar'
34 | )
35 |
36 | assert os.path.exists(meteor_jar), meteor_jar
37 |
38 | jdk_java_options = jdk_java_options or ['-Xmx2G']
39 | meteor_cmd = [
40 |         self.java_bin, '-jar', *jdk_java_options, meteor_jar, '-', '-', '-stdio',
41 | '-l', 'en', '-norm'
42 | ]
43 |
44 | self.meteor_p = subprocess.Popen(
45 | meteor_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
46 | self.lock = threading.Lock()
47 |
48 | def compute_score(self, gts, res):
49 | """Compute METEOR scores."""
50 | with self.lock:
51 | assert sorted(gts.keys()) == sorted(res.keys())
52 | img_ids = sorted(gts.keys())
53 | scores = []
54 |
55 | eval_line = 'EVAL ||| '
56 | stats = self._stat(img_ids, res, gts)
57 | eval_line += ' ||| '.join(stats)
58 | self.meteor_p.stdin.write(six.ensure_binary(eval_line + '\n'))
59 | self.meteor_p.stdin.flush()
60 | scores = [float(six.ensure_str(self.meteor_p.stdout.readline()))
61 | for _ in img_ids]
62 | # get the aggregated value
63 | score = self.meteor_p.stdout.readline()
64 | # do not close the file inside this function to keep it open for full eval
65 | return float(score), np.asarray(scores)
66 |
67 | def method(self):
68 | return 'METEOR'
69 |
70 | def _stat(self, img_ids, hypothesis_str, reference_list): # pylint: disable=missing-function-docstring
71 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
72 | stat_lines = []
73 | for i in img_ids:
74 | assert len(hypothesis_str[i]) == 1
75 |       hypo = hypothesis_str[i][0].replace('|||', '').replace('  ', ' ')
76 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list[i]),
77 | hypo))
78 |
79 | self.meteor_p.stdin.write(six.ensure_binary(score_line + '\n'))
80 | self.meteor_p.stdin.flush()
81 | stat_lines.append(six.ensure_str(self.meteor_p.stdout.readline()).strip())
82 | return stat_lines
83 |
--------------------------------------------------------------------------------
/metrics/dvc/metrics/ptbtokenizer.py:
--------------------------------------------------------------------------------
1 | """PTBTokenizer."""
2 |
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import subprocess
13 | import tempfile
14 |
15 | # pylint: disable=g-inconsistent-quotes
16 |
17 | # punctuations to be removed from the sentences
18 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
19 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
20 |
21 |
22 | class PTBTokenizer:
23 | """Python wrapper of Stanford PTBTokenizer."""
24 |
25 | def __init__(self,
26 | ptbtokenizer_jar_path=None,
27 | java_jre_path=None):
28 | if java_jre_path:
29 | self.java_bin = java_jre_path
30 | elif 'JRE_BIN_JAVA' in os.environ:
31 | self.java_bin = os.environ['JRE_BIN_JAVA']
32 | else:
33 | self.java_bin = 'java'
34 |
35 | if ptbtokenizer_jar_path:
36 | self.ptbtokenizer_jar = ptbtokenizer_jar_path
37 | else:
38 | self.ptbtokenizer_jar = os.path.join(
39 | "./metrics",
40 | "stanford-corenlp-3.4.1.jar",
41 | )
42 |
43 | assert os.path.exists(self.ptbtokenizer_jar), self.ptbtokenizer_jar
44 |
45 | def tokenize(self, captions_for_image):
46 | """Tokenization."""
47 |
48 | cmd = [self.java_bin, '-cp', self.ptbtokenizer_jar,
49 | 'edu.stanford.nlp.process.PTBTokenizer',
50 | '-preserveLines', '-lowerCase']
51 |
52 | # ======================================================
53 | # prepare data for PTB Tokenizer
54 | # ======================================================
55 | final_tokenized_captions_for_image = {}
56 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] # pylint: disable=g-complex-comprehension
57 | sentences = "\n".join(
58 | [ # pylint: disable=g-complex-comprehension
59 | c["caption"].replace("\n", " ")
60 | for k, v in captions_for_image.items()
61 | for c in v
62 | ]
63 | )
64 |
65 | # ======================================================
66 | # save sentences to temporary file
67 | # ======================================================
68 | fd, tmpfname = tempfile.mkstemp()
69 | with os.fdopen(fd, 'w') as f:
70 | f.write(sentences)
71 |
72 | # ======================================================
73 | # tokenize sentence
74 | # ======================================================
75 | cmd.append(tmpfname)
76 | p_tokenizer = subprocess.Popen(cmd, stdout=subprocess.PIPE)
77 | token_lines = p_tokenizer.communicate(input=sentences.rstrip().encode())[0]
78 | token_lines = token_lines.decode()
79 | lines = token_lines.split('\n')
80 | # remove temp file
81 | os.remove(tmpfname)
82 |
83 | # ======================================================
84 | # create dictionary for tokenized captions
85 | # ======================================================
86 | for k, line in zip(image_id, lines):
87 | if k not in final_tokenized_captions_for_image:
88 | final_tokenized_captions_for_image[k] = []
89 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ')
90 | if w not in PUNCTUATIONS])
91 | final_tokenized_captions_for_image[k].append(tokenized_caption)
92 |
93 | return final_tokenized_captions_for_image
94 |
--------------------------------------------------------------------------------
/metrics/tvg/cd:
--------------------------------------------------------------------------------
1 | GT File: /path_to_the_timesuite_root_folder/dataset/TimeIT/temporal_video_grounding/charades/charades_annotation/test.caption_coco_format.json
2 | # pred video timestamps 3720; # gt video timestamps 3720
3 | IOU 0.3: 56.29032258064516
4 | IOU 0.5: 34.56989247311828
5 | IOU 0.7: 15.887096774193548
6 |
--------------------------------------------------------------------------------
/metrics/tvg/eval_tvg.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import sys
5 | import argparse
6 | import pdb
7 |
8 | def read_json(path):
9 | with open(path, "r") as fin:
10 | datas = json.load(fin)
11 | return datas
12 |
13 |
14 | def iou(A, B):
15 | max0 = max((A[0]), (B[0]))
16 | min0 = min((A[0]), (B[0]))
17 | max1 = max((A[1]), (B[1]))
18 | min1 = min((A[1]), (B[1]))
19 | return max(min1 - max0, 0) / (max1 - min0)
20 |
21 |
22 | def toSec(timeStr):
23 | t = time.strptime(timeStr, "%H:%M:%S")
24 | return t.tm_hour * 3600 + t.tm_min * 60 + t.tm_sec
25 |
26 | def captiondata_modify(steps):
27 | modify_data = {}
28 | for i, step in enumerate(steps[0]):
29 | for key in step["step"].keys():
30 | name = step["step"][key]["query_idx"]
31 | modify_data[name] = [[step['step'][key]["startime"], step['step'][key]["endtime"]]]
32 |
33 | return modify_data
34 |
35 | if __name__ == "__main__":
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--pred_file", type=str, default="/home/yaolinli/code/Ask-Anything/video_chat/output/eval_7b_tvg_charades/fmt_charades_test_f8_result.json")
38 | parser.add_argument('--gt_file', type=str, default='/home/yaolinli/dataset/Charades/charades_annotation/test.caption_coco_format.json')
39 | parser.add_argument('--sample', action='store_true', default=False)
40 | args = parser.parse_args()
41 | '''
42 | {
43 | "query_idx": [start_time, end_time],
44 | ...
45 | }
46 | '''
47 | print("GT File:", args.gt_file)
48 | answer = read_json(args.gt_file)
49 | answer = answer["annotations"]
50 | gt_timestamps = {}
51 | for jterm in answer:
52 | gt_timestamps[jterm["id"]] = jterm["timestamp"]
53 |
54 | submission = read_json(args.pred_file)
55 | pred_timestamps = {}
56 | for qid, jterm in submission.items():
57 | pred_timestamps[int(qid)] = jterm["timestamp"]
58 |
59 | if args.sample:
60 | new = {}
61 | for qid in pred_timestamps.keys():
62 | new[qid] = gt_timestamps[qid]
63 | gt_timestamps = new
64 | num = len(gt_timestamps)
65 | print(f"# pred video timestamps {len(pred_timestamps)}; # gt video timestamps {len(gt_timestamps)}")
66 | assert len(gt_timestamps) == len(pred_timestamps)
67 | Result = {0.3:0, 0.5:0, 0.7:0}
68 | for c_iou in [0.3, 0.5, 0.7]:
69 | for key in gt_timestamps.keys():
70 | if len(pred_timestamps[key]) < 1:
71 | continue
72 | if(iou(gt_timestamps[key], pred_timestamps[key][0]) >= c_iou):
73 | Result[c_iou] = Result[c_iou] + 1
74 | print("IOU 0.3: {0}\nIOU 0.5: {1}\nIOU 0.7: {2}".format(Result[0.3]*100/num, Result[0.5]*100/num, Result[0.7]*100/num))
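As a sanity check on the metric above, a self-contained re-implementation of the interval IoU with one worked example (the numbers are illustrative, taken from the prompt examples elsewhere in this repo):

# Standalone sketch of the interval IoU used by eval_tvg.py.
def interval_iou(a, b):
    # a, b: [start, end] in seconds
    inter = max(min(a[1], b[1]) - max(a[0], b[0]), 0)
    union = max(a[1], b[1]) - min(a[0], b[0])
    return inter / union

# Prediction [20.8, 30.0] vs. ground truth [24.3, 30.4]:
# intersection = 30.0 - 24.3 = 5.7, union = 30.4 - 20.8 = 9.6, IoU ~= 0.594,
# so this prediction counts as a hit at thresholds 0.3 and 0.5 but not at 0.7.
print(round(interval_iou([20.8, 30.0], [24.3, 30.4]), 3))  # 0.594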
--------------------------------------------------------------------------------
/metrics/tvg/eval_tvg.sh:
--------------------------------------------------------------------------------
1 | # default: eval all instances according to the gt_file
2 | python eval_tvg.py --pred_file your_pred_file --gt_file your_gt_file
3 |
4 |
5 | # use --sample
6 | # eval sampled instances according to the pred_file
7 | # e.g. # gt examples:500, # pred examples:50 -> # eval examples:50
8 | python eval_tvg.py --pred_file your_pred_file --gt_file your_gt_file --sample
--------------------------------------------------------------------------------
/metrics/vhd/cd:
--------------------------------------------------------------------------------
1 | Calculating highlight scores with min score 2 (Fair)
2 | Time cost 0.35 seconds
3 | Calculating highlight scores with min score 3 (Good)
4 | Time cost 0.32 seconds
5 | Calculating highlight scores with min score 4 (VeryGood)
6 | Time cost 0.26 seconds
7 | HL-min-VeryGood-mAP: 16.0
8 | HL-min-VeryGood-Hit1: 36.09
9 |
--------------------------------------------------------------------------------
/metrics/vhd/eval_highlights.sh:
--------------------------------------------------------------------------------
1 | your_gt_file=example_gt_file.json
2 | your_pred_file=example_pred_file.json
3 | python eval_highlights.py --pred_file $your_pred_file --gt_file $your_gt_file
4 | # --save_path your_result_save_path
5 |
--------------------------------------------------------------------------------
/metrics/vhd/metrics.json:
--------------------------------------------------------------------------------
1 | {
2 | "brief": {
3 | "HL-min-Fair-mAP": 53.36,
4 | "HL-min-Fair-Hit1": 67.67,
5 | "HL-min-Good-mAP": 43.11,
6 | "HL-min-Good-Hit1": 64.66,
7 | "HL-min-VeryGood-mAP": 25.07,
8 | "HL-min-VeryGood-Hit1": 51.26
9 | },
10 | "HL-min-Fair": {
11 | "HL-mAP": 53.36,
12 | "HL-Hit1": 67.67
13 | },
14 | "HL-min-Good": {
15 | "HL-mAP": 43.11,
16 | "HL-Hit1": 64.66
17 | },
18 | "HL-min-VeryGood": {
19 | "HL-mAP": 25.07,
20 | "HL-Hit1": 51.26
21 | }
22 | }
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .videochat2_qformer import VideoChat2_qformer
2 |
3 | from .videochat_mistra.videochat2_it4_mistral_LinearP import VideoChat2_it4_mistral_LinearP
4 | from .videochat_mistra.videochat2_it4_mistral_LinearProAda import VideoChat2_it4_mistral_LinearProAda
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/models/bert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/models/bert/__init__.py
--------------------------------------------------------------------------------
/models/bert/builder.py:
--------------------------------------------------------------------------------
1 | from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
2 |
3 | import logging
4 | logger = logging.getLogger(__name__)
5 |
6 | def build_bert(model_config, pretrain, checkpoint):
7 | """build text encoder.
8 |
9 | Args:
10 | model_config (dict): model config.
11 | pretrain (bool): Whether to do pretrain or finetuning.
12 | checkpoint (bool): whether to do gradient_checkpointing.
13 |
14 |     Returns: the text encoder (BertForMaskedLM if pretrain else BertModel).
15 |
16 | """
17 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
18 | bert_config.encoder_width = model_config.vision_encoder.d_model
19 | bert_config.gradient_checkpointing = checkpoint
20 | bert_config.fusion_layer = model_config.text_encoder.fusion_layer
21 |
22 | if not model_config.multimodal.enable:
23 | bert_config.fusion_layer = bert_config.num_hidden_layers
24 |
25 | if pretrain:
26 | text_encoder, loading_info = BertForMaskedLM.from_pretrained(
27 | model_config.text_encoder.pretrained,
28 | config=bert_config,
29 | output_loading_info=True,
30 | )
31 | else:
32 | text_encoder, loading_info = BertModel.from_pretrained(
33 | model_config.text_encoder.pretrained,
34 | config=bert_config,
35 | add_pooling_layer=False,
36 | output_loading_info=True,
37 | )
38 |
39 | return text_encoder
40 |
41 |
42 | def build_bert_decoder(model_config, checkpoint):
43 | """build text decoder the same as the multimodal encoder.
44 |
45 | Args:
46 |         model_config (dict): model config.
47 |         checkpoint (bool): whether to do gradient_checkpointing.
48 | 
49 |     Returns:
50 |         BertLMHeadModel: the text decoder.
51 |
52 | """
53 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
54 | bert_config.encoder_width = model_config.vision_encoder.d_model
55 | bert_config.gradient_checkpointing = checkpoint
56 |
57 | bert_config.fusion_layer = 0
58 | bert_config.num_hidden_layers = (
59 | bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
60 | )
61 |
62 | text_decoder, loading_info = BertLMHeadModel.from_pretrained(
63 | model_config.text_encoder.pretrained,
64 | config=bert_config,
65 | output_loading_info=True,
66 | )
67 |
68 | return text_decoder
69 |
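For orientation, a minimal sketch of the attribute-style config that build_bert expects; the concrete values below (config path, fusion layer, hidden sizes) are assumptions for illustration, not this repository's defaults:

from types import SimpleNamespace as NS

# Illustrative config only; field values are assumptions.
model_config = NS(
    text_encoder=NS(
        config="configs/config_bert.json",  # e.g. the BERT config json under configs/
        pretrained="bert-base-uncased",
        fusion_layer=9,
    ),
    vision_encoder=NS(d_model=1024),
    multimodal=NS(enable=True),
)

# text_encoder = build_bert(model_config, pretrain=True, checkpoint=False)
# pretrain=True keeps the masked-LM head (BertForMaskedLM); otherwise a plain
# BertModel without the pooling layer is returned.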
--------------------------------------------------------------------------------
/models/blip2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/models/blip2/__init__.py
--------------------------------------------------------------------------------
/models/blip2/blip2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import contextlib
8 | import os
9 | import logging
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | from .Qformer import BertConfig, BertLMHeadModel
15 | from .vit import build_vit
16 | from transformers import BertTokenizer
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class Blip2Base(nn.Module):
22 | def __init__(self):
23 | super().__init__()
24 |
25 | @classmethod
26 | def init_tokenizer(cls, truncation_side="right"):
27 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side, local_files_only=True)
28 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
29 | return tokenizer
30 |
31 | @property
32 | def device(self):
33 | return list(self.parameters())[0].device
34 |
35 | def maybe_autocast(self, dtype=torch.float16):
36 | # if on cpu, don't use autocast
37 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
38 | enable_autocast = self.device != torch.device("cpu")
39 |
40 | if enable_autocast:
41 | return torch.cuda.amp.autocast(dtype=dtype)
42 | else:
43 | return contextlib.nullcontext()
44 |
45 | @classmethod
46 | def init_Qformer(
47 | cls,
48 | num_query_token, vision_width,
49 | qformer_hidden_dropout_prob=0.1,
50 | qformer_attention_probs_dropout_prob=0.1,
51 | qformer_drop_path_rate=0.,
52 | ):
53 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)
54 | encoder_config.encoder_width = vision_width
55 | # insert cross-attention layer every other block
56 | encoder_config.add_cross_attention = True
57 | encoder_config.cross_attention_freq = 2
58 | encoder_config.query_length = num_query_token
59 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
60 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
61 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)]
62 | logger.info(f"Drop_path:{encoder_config.drop_path_list}")
63 | logger.info(encoder_config)
64 | Qformer = BertLMHeadModel(config=encoder_config)
65 | query_tokens = nn.Parameter(
66 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
67 | )
68 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
69 | return Qformer, query_tokens
70 |
71 | @classmethod
72 |     def init_vision_encoder_umt(cls, config):
73 | """build vision encoder
74 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
75 |
76 | """
77 | vision_encoder = build_vit(config)
78 |
79 | if config.vision_encoder.vit_add_ln:
80 | vision_layernorm = nn.LayerNorm(config.vision_encoder.encoder_embed_dim, eps=1e-12)
81 | else:
82 | vision_layernorm = nn.Identity()
83 |
84 | return vision_encoder, vision_layernorm
85 |
86 |
87 | def disabled_train(self, mode=True):
88 | """Overwrite model.train with this function to make sure train/eval mode
89 | does not change anymore."""
90 | return self
91 |
92 |
93 | class LayerNorm(nn.LayerNorm):
94 | """Subclass torch's LayerNorm to handle fp16."""
95 |
96 | def forward(self, x: torch.Tensor):
97 | orig_type = x.dtype
98 | ret = super().forward(x.type(torch.float32))
99 | return ret.type(orig_type)
100 |
--------------------------------------------------------------------------------
/models/blip2/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 |
5 |
6 | from .Qformer import BertConfig, BertLMHeadModel
7 | from models.utils import load_temp_embed_with_mismatch
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def build_qformer(num_query_token, vision_width,
13 | qformer_hidden_dropout_prob=0.1,
14 | qformer_attention_probs_dropout_prob=0.1,
15 | drop_path_rate=0.,
16 | ):
17 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)
18 | encoder_config.encoder_width = vision_width
19 | # insert cross-attention layer every other block
20 | encoder_config.add_cross_attention = True
21 | encoder_config.cross_attention_freq = 2
22 | encoder_config.query_length = num_query_token
23 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
24 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
25 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_config.num_hidden_layers)]
26 | logger.info(f"Drop_path:{encoder_config.drop_path_list}")
27 | logger.info(encoder_config)
28 | Qformer = BertLMHeadModel.from_pretrained(
29 | "bert-base-uncased", config=encoder_config, local_files_only=True
30 | )
31 | query_tokens = nn.Parameter(
32 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
33 | )
34 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
35 | return Qformer, query_tokens
36 |
37 | def interpolate_pos_embed_blip(state_dict, new_model):
38 | if "vision_temp_embed" in state_dict:
39 | vision_temp_embed_new = new_model.state_dict()["vision_temp_embed"]
40 | state_dict["vision_temp_embed"] = load_temp_embed_with_mismatch(
41 | state_dict["vision_temp_embed"], vision_temp_embed_new, add_zero=False
42 | )
43 | return state_dict
44 |
--------------------------------------------------------------------------------
/models/videochat_mistra/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/TimeSuite/d68f5feb17eec712d8c8b2fd1b44d24b2e67f00a/models/videochat_mistra/__init__.py
--------------------------------------------------------------------------------
/prompts/alignment_image.txt:
--------------------------------------------------------------------------------
1 | Describe this image in detail.
2 | Take a look at this image and describe what you notice.
3 | Please provide a detailed description of the picture.
4 | Could you describe the contents of this image for me?
--------------------------------------------------------------------------------
/prompts/concise_description.txt:
--------------------------------------------------------------------------------
1 | Describe the following video concisely.
2 | Provide a brief description of the given video clip.
3 | Offer a succinct explanation of the footage presented.
4 | Summarize the visual content of the following video.
5 | Give a short and clear explanation of the subsequent video clip.
6 | Share a concise interpretation of the video provided.
7 | Present a compact description of the clip's key features.
8 | Relay a brief, clear account of the video shown.
9 | Render a clear and concise summary of the video below.
10 | Write a terse but informative summary of the following video clip.
11 | Create a compact narrative representing the video presented.
--------------------------------------------------------------------------------
/prompts/concise_image_description.txt:
--------------------------------------------------------------------------------
1 | Describe the following image concisely.
2 | Provide a brief description of the given image.
3 | Offer a succinct explanation of the picture presented.
4 | Summarize the visual content of the following image.
5 | Give a short and clear explanation of the subsequent image.
6 | Share a concise interpretation of the image provided.
7 | Present a compact description of the photo's key features.
8 | Relay a brief, clear account of the picture shown.
9 | Render a clear and concise summary of the photo below.
10 | Write a terse but informative summary of the following picture.
11 | Create a compact narrative representing the image presented.
--------------------------------------------------------------------------------
/prompts/dvc_description.txt:
--------------------------------------------------------------------------------
1 | Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences. The output format of each predicted event should be like: 'start - end seconds, event description'. A specific example is : ' 90 - 102 seconds, spread margarine on two slices of white bread in the video' .
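Since downstream evaluation needs these predictions as (start, end, caption) triples, here is a minimal parsing sketch for the output format described above; the regex and helper name are illustrative and not part of this repository:

import re

# Illustrative parser for lines such as "90 - 102 seconds, spread margarine ...".
EVENT_RE = re.compile(r"(\d+(?:\.\d+)?)\s*-\s*(\d+(?:\.\d+)?)\s*seconds?,\s*(.+)")

def parse_dvc_output(text):
    events = []
    for line in text.splitlines():
        m = EVENT_RE.search(line)
        if m:
            start, end = float(m.group(1)), float(m.group(2))
            events.append((start, end, m.group(3).strip()))
    return events

print(parse_dvc_output("90 - 102 seconds, spread margarine on two slices of white bread"))
# [(90.0, 102.0, 'spread margarine on two slices of white bread')]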
--------------------------------------------------------------------------------
/prompts/dvc_description_with_asr.txt:
--------------------------------------------------------------------------------
1 | You are able to understand the visual content that the user provides. Localize a series of activity events in the video with the aid of transcribed speech, output the start and end timestamp for each event, and describe each event with sentences. The output format of each predicted event should be like: 'start - end seconds, event description'. A specific example is : ' 90 - 102 seconds, spread margarine on two slices of white bread in the video' .
--------------------------------------------------------------------------------
/prompts/dvc_description_zeroshot.txt:
--------------------------------------------------------------------------------
1 | You are given a cooking video from the YouCook2 dataset. Please watch the video and extract a maximum of 10 significant cooking steps. For each step, determine the starting and ending times and provide a concise description. The format should be: 'start time - end time, brief step description'. For example, ' 90 - 102 seconds, spread margarine on two slices of white bread'.
--------------------------------------------------------------------------------
/prompts/dvc_post_check.txt:
--------------------------------------------------------------------------------
1 | Be careful about your output format, you should provide start time and end time for each step. The format should be: 'start time - end time, brief step description', for example, ' 90 - 102 seconds, spread margarine on two slices of white bread'.
--------------------------------------------------------------------------------
/prompts/tvg_description.txt:
--------------------------------------------------------------------------------
1 | Localize the visual content described by the given textual query {} in the video, and output the start and end timestamps in seconds. The output format of the predicted timestamp should be like: 'start - end seconds'. A specific example is : ' 20.8 - 30.0 seconds' .
--------------------------------------------------------------------------------
/prompts/tvg_description_zeroshot.txt:
--------------------------------------------------------------------------------
1 | You are given a video from the {} dataset. Please find the visual event described by a sentence in the video, determining its starting and ending times. The format should be: 'The event happens in the start time - end time'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seconds. Now I will give you the textual sentence: {}. Please return its start time and end time.
--------------------------------------------------------------------------------
/prompts/tvg_post_check.txt:
--------------------------------------------------------------------------------
1 | Be careful about your output format, you should provide start time and end time for the visual event. The format should be: 'The event happens in the start time - end time'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seconds.
--------------------------------------------------------------------------------
/prompts/vhd_description.txt:
--------------------------------------------------------------------------------
1 | You are given a video from the {} dataset. Please find the highlight moments in the video described by a sentence query, determining the highlight moment's timestamp and its saliency score. The output format should be like: 'There are 10 highlight moments in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 second. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. Now I will give you the sentence query: {}. Please return the query-based highlight moments.
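A minimal sketch (not part of the repository) of turning the output format described above into (timestamp, saliency) pairs; the regexes and function name are illustrative:

import re

def parse_vhd_output(text):
    ts = re.search(r"timestamps are in the (.+?) seconds", text)
    sc = re.search(r"saliency scores are (.+)", text)
    if not ts or not sc:
        return []
    times = [float(x) for x in re.findall(r"\d+(?:\.\d+)?", ts.group(1))]
    scores = [float(x) for x in re.findall(r"\d+(?:\.\d+)?", sc.group(1))]
    return list(zip(times, scores))

print(parse_vhd_output(
    "The highlight timestamps are in the 52, 54, 56, 58 seconds. "
    "Their saliency scores are 3.0, 3.0, 3.0, 3.0."
))
# [(52.0, 3.0), (54.0, 3.0), (56.0, 3.0), (58.0, 3.0)]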
--------------------------------------------------------------------------------
/prompts/vhd_description_zeroshot.txt:
--------------------------------------------------------------------------------
1 | You are given a video from the {} dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and their saliency scores on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 seconds. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'. Now I will give you the sentence query: {}. Please return the query-based highlight timestamps and saliency scores.
--------------------------------------------------------------------------------
/prompts/vhd_description_zeroshot_new.txt:
--------------------------------------------------------------------------------
1 | You are given a video from the {} dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and their saliency scores on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 52, 54, 56, 58 seconds. Their saliency scores are 3.0, 3.0, 3.0, 3.0'. Now I will give you the sentence query: {}. Please return the query-based highlight timestamps and saliency scores.
--------------------------------------------------------------------------------
/prompts/vhd_description_zeroshot_post.txt:
--------------------------------------------------------------------------------
1 | You are given a video from the {} dataset. Please find the highlight contents in the video described by a sentence query, determining the highlight timestamps and their saliency scores on a scale from 1 to 5. The output format should be like: 'The highlight timestamps are in the 60 - 80 seconds. Their saliency scores are 3.0'. Now I will give you the sentence query: {}. Please return the query-based highlight timestamps and saliency scores.
--------------------------------------------------------------------------------
/prompts/vhd_post_check.txt:
--------------------------------------------------------------------------------
1 | Be careful about your output format, you should provide timestamp and saliency score for the highlight moment. The output format should be like: 'There are 10 highlight moments in the 82, 84, 86, 88, 90, 92, 94, 96, 98, 100 second. Their saliency scores are 1.3, 1.7, 1.7, 1.7, 1.7, 1.3, 1.7, 2.3, 2.3, 2.3'.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.28.0
3 | addict==2.4.0
4 | aiofiles==23.2.1
5 | aiohttp==3.9.3
6 | aiosignal==1.3.1
7 | aliyun-python-sdk-core==2.15.0
8 | aliyun-python-sdk-kms==2.16.2
9 | altair==5.2.0
10 | annotated-types==0.6.0
11 | antlr4-python3-runtime==4.9.3
12 | anyio==4.6.0
13 | anykeystore==0.2
14 | apex==0.9.10.dev0
15 | appdirs==1.4.4
16 | argon2-cffi==23.1.0
17 | argon2-cffi-bindings==21.2.0
18 | arrow==1.3.0
19 | asttokens==2.4.1
20 | astunparse==1.6.3
21 | async-lru==2.0.4
22 | async-timeout==4.0.3
23 | attrs==24.2.0
24 | audioread==3.0.1
25 | av==10.0.0
26 | awscli==1.32.87
27 | Babel==2.14.0
28 | beautifulsoup4==4.12.3
29 | bitsandbytes==0.42.0
30 | bleach==6.1.0
31 | blis==0.7.11
32 | boto3==1.34.91
33 | botocore==1.34.91
34 | braceexpand==0.1.7
35 | catalogue==2.0.10
36 | certifi==2024.8.30
37 | cffi==1.17.1
38 | charset-normalizer==3.3.2
39 | click==8.1.7
40 | cloudpathlib==0.16.0
41 | colorama==0.4.4
42 | coloredlogs==15.0.1
43 | comm==0.2.2
44 | conda-pack==0.7.1
45 | confection==0.1.4
46 | contourpy==1.2.0
47 | crcmod==1.7
48 | cryptacular==1.6.2
49 | cryptography==42.0.5
50 | cycler==0.12.1
51 | cymem==2.0.8
52 | datasets==2.19.0
53 | debugpy==1.8.1
54 | decorator==5.1.1
55 | decord==0.6.0
56 | deepspeed==0.14.0
57 | defusedxml==0.7.1
58 | dill==0.3.8
59 | docker-pycreds==0.4.0
60 | docutils==0.16
61 | einops==0.6.1
62 | environs==11.0.0
63 | exceptiongroup==1.2.2
64 | executing==2.1.0
65 | fairscale==0.4.13
66 | fastapi==0.110.0
67 | fastjsonschema==2.19.1
68 | ffmpeg==1.4
69 | ffmpy==0.3.2
70 | filelock==3.13.1
71 | flash-attn==2.5.6
72 | flatbuffers==24.3.7
73 | fonttools==4.50.0
74 | fqdn==1.5.1
75 | frozenlist==1.4.1
76 | fsspec==2024.3.1
77 | ftfy==6.2.0
78 | fvcore==0.1.5.post20221221
79 | gast==0.5.4
80 | gitdb==4.0.11
81 | GitPython==3.1.42
82 | google-pasta==0.2.0
83 | gradio==3.35.0
84 | gradio_client==0.12.0
85 | greenlet==3.0.3
86 | grpcio==1.62.1
87 | h11==0.14.0
88 | h5py==3.10.0
89 | hjson==3.1.0
90 | httpcore==1.0.5
91 | httpx==0.27.2
92 | huggingface-hub==0.21.4
93 | humanfriendly==10.0
94 | humanize==4.9.0
95 | hupper==1.12.1
96 | idna==3.10
97 | imageio==2.27.0
98 | imageio-ffmpeg==0.4.9
99 | importlib_metadata==8.5.0
100 | importlib_resources==6.3.1
101 | iopath==0.1.10
102 | ipykernel==6.29.3
103 | ipython==8.18.1
104 | ipywidgets==8.1.2
105 | isoduration==20.11.0
106 | jedi==0.19.1
107 | Jinja2==3.1.4
108 | jmespath==0.10.0
109 | joblib==1.4.0
110 | json5==0.9.24
111 | jsonpointer==2.4
112 | jsonschema==4.23.0
113 | jsonschema-specifications==2023.12.1
114 | jupyter==1.0.0
115 | jupyter-console==6.6.3
116 | jupyter-events==0.10.0
117 | jupyter-lsp==2.2.4
118 | jupyter_client==8.6.1
119 | jupyter_core==5.7.2
120 | jupyter_server==2.13.0
121 | jupyter_server_terminals==0.5.3
122 | jupyterlab==4.1.5
123 | jupyterlab_pygments==0.3.0
124 | jupyterlab_server==2.25.4
125 | jupyterlab_widgets==3.0.10
126 | keras==3.1.1
127 | kiwisolver==1.4.5
128 | langcodes==3.3.0
129 | lazy_loader==0.4
130 | libclang==18.1.1
131 | librosa==0.10.1
132 | linkify-it-py==2.0.3
133 | llvmlite==0.42.0
134 | Markdown==3.6
135 | markdown-it-py==2.2.0
136 | MarkupSafe==2.1.5
137 | marshmallow==3.21.1
138 | matplotlib==3.8.3
139 | matplotlib-inline==0.1.7
140 | mdit-py-plugins==0.3.3
141 | mdurl==0.1.2
142 | mistune==3.0.2
143 | mkl-fft==1.3.1
144 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work
145 | mkl-service==2.4.0
146 | ml-dtypes==0.3.2
147 | mmcv==2.1.0
148 | mmengine==0.10.3
149 | model-index==0.1.11
150 | moviepy==1.0.3
151 | msgpack==1.0.8
152 | multidict==6.0.5
153 | multiprocess==0.70.16
154 | multiprocessing-logging==0.3.4
155 | murmurhash==1.0.10
156 | namex==0.0.7
157 | nbclient==0.10.0
158 | nbconvert==7.16.2
159 | nbformat==5.10.3
160 | nest-asyncio==1.6.0
161 | ninja==1.11.1.1
162 | notebook==7.1.2
163 | notebook_shim==0.2.4
164 | numba==0.59.1
165 | numpy==1.23.5
166 | oauthlib==3.2.2
167 | omegaconf==2.3.0
168 | opencv-python==4.7.0.72
169 | opendatalab==0.0.10
170 | openmim==0.3.9
171 | openxlab==0.0.36
172 | opt-einsum==3.3.0
173 | optree==0.10.0
174 | ordered-set==4.1.0
175 | orjson==3.9.15
176 | oss2==2.17.0
177 | overrides==7.7.0
178 | packaging==24.1
179 | pandas==1.5.3
180 | pandocfilters==1.5.1
181 | parso==0.8.4
182 | PasteDeploy==3.1.0
183 | pathtools==0.1.2
184 | pbkdf2==1.3
185 | peft==0.3.0
186 | pexpect==4.9.0
187 | Pillow==9.5.0
188 | plaster==1.1.2
189 | plaster-pastedeploy==1.0.1
190 | platformdirs==4.3.6
191 | pooch==1.8.1
192 | portalocker==2.8.2
193 | preshed==3.0.9
194 | proglog==0.1.10
195 | prometheus_client==0.20.0
196 | prompt_toolkit==3.0.47
197 | protobuf==4.25.3
198 | psutil==6.0.0
199 | ptyprocess==0.7.0
200 | pure_eval==0.2.3
201 | py-cpuinfo==9.0.0
202 | pyarrow==16.0.0
203 | pyarrow-hotfix==0.6
204 | pyasn1==0.6.0
205 | pycocoevalcap==1.2
206 | pycocotools==2.0.7
207 | pycparser==2.22
208 | pycryptodome==3.20.0
209 | pydantic==2.6.4
210 | pydantic_core==2.16.3
211 | pydub==0.25.1
212 | Pygments==2.18.0
213 | pynvml==11.5.0
214 | pyparsing==3.1.2
215 | pyramid==2.0.2
216 | pyramid-mailer==0.15.1
217 | pysubs2==1.7.3
218 | python-dateutil==2.9.0.post0
219 | python-dotenv==1.0.1
220 | python-json-logger==2.0.7
221 | python-multipart==0.0.9
222 | python3-openid==3.2.0
223 | pytube==15.0.0
224 | pytz==2023.4
225 | PyYAML==6.0.2
226 | pyzmq==25.1.2
227 | qtconsole==5.5.1
228 | QtPy==2.4.1
229 | redis==5.0.3
230 | referencing==0.35.1
231 | regex==2023.12.25
232 | repoze.sendmail==4.4.1
233 | requests==2.32.3
234 | requests-oauthlib==1.4.0
235 | rfc3339-validator==0.1.4
236 | rfc3986-validator==0.1.1
237 | rich==13.4.2
238 | rpds-py==0.20.0
239 | rsa==4.7.2
240 | s3transfer==0.10.1
241 | safetensors==0.4.2
242 | scikit-learn==1.4.1.post1
243 | scipy==1.10.1
244 | semantic-version==2.10.0
245 | Send2Trash==1.8.2
246 | sentencepiece==0.1.99
247 | sentry-sdk==1.42.0
248 | setproctitle==1.3.3
249 | six==1.16.0
250 | smart-open==6.4.0
251 | smmap==5.0.1
252 | sniffio==1.3.1
253 | soundfile==0.12.1
254 | soupsieve==2.5
255 | soxr==0.3.7
256 | spacy==3.7.4
257 | spacy-legacy==3.0.12
258 | spacy-loggers==1.0.5
259 | SQLAlchemy==2.0.28
260 | srsly==2.4.8
261 | stack-data==0.6.3
262 | starlette==0.36.3
263 | tabulate==0.9.0
264 | tensorboard==2.16.2
265 | tensorboard-data-server==0.7.2
266 | tensorflow==2.16.1
267 | tensorflow-io-gcs-filesystem==0.36.0
268 | termcolor==2.3.0
269 | terminado==0.18.1
270 | thinc==8.2.3
271 | threadpoolctl==3.4.0
272 | timm==0.6.12
273 | tinycss2==1.2.1
274 | tokenizers==0.19.1
275 | tomli==2.0.1
276 | toolz==0.12.1
277 | torch==1.13.1
278 | torchaudio==0.13.1
279 | torchvision==0.14.1
280 | tornado==6.4.1
281 | tqdm==4.65.2
282 | traitlets==5.14.3
283 | transaction==4.0
284 | transformers==4.40.0
285 | translationstring==1.4
286 | typer==0.9.4
287 | types-python-dateutil==2.9.0.20240316
288 | typing_extensions==4.12.2
289 | tzdata==2024.1
290 | uc-micro-py==1.0.3
291 | uri-template==1.3.0
292 | urllib3==2.2.3
293 | uvicorn==0.28.0
294 | velruse==1.1.1
295 | venusian==3.1.0
296 | wandb==0.14.0
297 | wasabi==1.1.2
298 | wcwidth==0.2.13
299 | weasel==0.3.4
300 | webcolors==1.13
301 | webdataset==0.2.86
302 | webencodings==0.5.1
303 | WebOb==1.8.7
304 | websocket-client==1.7.0
305 | websockets==11.0.3
306 | webvtt-py==0.5.1
307 | Werkzeug==3.0.1
308 | widgetsnbextension==4.0.10
309 | wrapt==1.16.0
310 | WTForms==3.1.2
311 | wtforms-recaptcha==0.3.2
312 | xxhash==3.4.1
313 | yacs==0.1.8
314 | yapf==0.40.2
315 | yarl==1.9.4
316 | zipp==3.20.2
317 | zope.deprecation==5.0
318 | zope.interface==6.2
319 | zope.sqlalchemy==3.1
320 |
--------------------------------------------------------------------------------
/scripts/videochat_mistral/config_LinearP.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 |
5 | train_corpus = "TimePro_Normal"
6 |
7 | train_file = "${available_corpus[${train_corpus}]}"
8 | test_file = dict()
9 | test_types = []
10 | num_workers = 6
11 |
12 | stop_key = None
13 |
14 | # ========================= input ==========================
15 | num_frames_test = 16
16 | num_frames = 192
17 | clip_frames = 8
18 | good_init = True
19 |
20 | batch_size = 2
21 | max_txt_l = 1536
22 |
23 | save_iter=1000
24 |
25 | pre_text = False
26 |
27 | inputs = dict(
28 | image_res=224,
29 | video_input=dict(
30 | num_frames="${num_frames}",
31 | sample_type="rand",
32 | num_frames_test="${num_frames_test}",
33 | sample_type_test="middle",
34 | random_aug=False,
35 | ),
36 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
37 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
38 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
39 | )
40 |
41 | # ========================= model ==========================
42 | model = dict(
43 | model_cls="VideoChat2_it4_mistral_LinearP",
44 | vit_blip_model_path="/path_to_the_timesuite_root_folder/download/parameters/umt_l16_qformer.pth",
45 | mistral_model_path="/path_to_the_timesuite_root_folder/download/parameters/Mistral-7B-Instruct-v0.2",
46 | videochat2_model_path="/path_to_the_timesuite_root_folder/download/parameters/videochat2_mistral_7b_stage3.pth",
47 | freeze_vit=True,
48 | freeze_qformer=False,
49 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
50 | clip_frames="${clip_frames}",
51 | num_frames="${num_frames}",
52 | token_merge_len = 4,
53 | # vit
54 | low_resource=False,
55 | vision_encoder=dict(
56 | name="vit_l14",
57 | img_size=224,
58 | patch_size=16,
59 | d_model=1024,
60 | encoder_embed_dim=1024,
61 | encoder_depth=24,
62 | encoder_num_heads=16,
63 | drop_path_rate=0.,
64 | num_frames="${clip_frames}",
65 | tubelet_size=1,
66 | use_checkpoint=True,
67 | checkpoint_num=18,
68 | pretrained="",
69 | return_index=-2,
70 | vit_add_ln=True,
71 | ckpt_num_frame=4,
72 | ),
73 | # qformer
74 | num_query_token=32,
75 | qformer_hidden_dropout_prob=0.1,
76 | qformer_attention_probs_dropout_prob=0.1,
77 | qformer_drop_path_rate=0.2,
78 | extra_num_query_token=64,
79 | qformer_text_input=True,
80 | # prompt
81 | system="",
82 |     start_token="<Video>",
83 |     end_token="</Video>",
84 |     add_second_msg=True,
85 |     img_start_token="<Image>",
86 |     img_end_token="</Image>",
87 | random_shuffle=True,
88 | return_question_instruction=False,
89 | use_flash_attention=True,
90 | use_lora=True,
91 | lora_r=16,
92 | lora_alpha=32,
93 | lora_dropout=0.1,
94 | # debug=True,
95 | )
96 |
97 | optimizer = dict(
98 | opt="adamW",
99 | lr=2e-5,
100 | opt_betas=[0.9, 0.999], # default
101 | weight_decay=0.02,
102 | max_grad_norm=-1, # requires a positive float, use -1 to disable
103 | # use a different lr for some modules, e.g., larger lr for new modules
104 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
105 | )
106 |
107 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
108 |
109 | evaluate = False
110 | deep_fusion = False
111 | evaluation = dict(
112 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
113 | eval_x_only=False,
114 | k_test=128,
115 | eval_offload=True, # offload gpu tensors to cpu to save memory.
116 | )
117 |
118 | fp16 = True
119 | gradient_checkpointing = True
120 |
121 | # ========================= wandb ==========================
122 | wandb = dict(
123 | enable=False,
124 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
125 | project="videogpt", # setup in your command line
126 | )
127 | dist_url = "env://"
128 | device = "cuda"
129 | mode = "it_mistral"
130 |
131 | # ========================= others ==========================
132 | output_dir = None # output dir
133 | resume = False # if True, load optimizer and scheduler states as well
134 | debug = False
135 | log_freq = 10
136 | seed = 42
137 |
138 | save_latest = False
139 | auto_resume = True
140 | pretrained_path = "" # path to pretrained model weights, for resume only?
141 |
142 | deepspeed = dict(
143 | enable=False,
144 | stage=1,
145 | )
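For quick orientation on the frame settings above, a trivial arithmetic check; the clip-splitting assumption follows from the parameter names, not from code in this config file:

# Illustrative arithmetic only.
num_frames, clip_frames = 192, 8
print(num_frames // clip_frames)  # 24 consecutive 8-frame clips per video under this config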
--------------------------------------------------------------------------------
/scripts/videochat_mistral/config_LinearProAda.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "TimePro_Normal"
5 |
6 | train_file = "${available_corpus[${train_corpus}]}"
7 | test_file = dict()
8 | test_types = []
9 | num_workers = 6
10 |
11 | stop_key = None
12 |
13 | # ========================= input ==========================
14 | num_frames_test = 16
15 | num_frames = 128
16 | clip_frames = 8
17 | good_init = True
18 |
19 | batch_size = 2
20 | max_txt_l = 1024
21 |
22 | save_iter=1000
23 |
24 | pre_text = False
25 |
26 | inputs = dict(
27 | image_res=224,
28 | video_input=dict(
29 | num_frames="${num_frames}",
30 | sample_type="rand",
31 | num_frames_test="${num_frames_test}",
32 | sample_type_test="middle",
33 | random_aug=False,
34 | ),
35 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
36 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
37 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
38 | )
39 |
40 | # ========================= model ==========================
41 | model = dict(
42 | model_cls="VideoChat2_it4_mistral_LinearProAda",
43 | vit_blip_model_path="/path_to_the_timesuite_root_folder/download/parameters/umt_l16_qformer.pth",
44 | mistral_model_path="/path_to_the_timesuite_root_folder/download/parameters/Mistral-7B-Instruct-v0.2",
45 | videochat2_model_path="/path_to_the_timesuite_root_folder/download/parameters/videochat2_mistral_7b_stage3.pth",
46 | pretrained_path="your_LinearP_ckpt_path/ckpt_00.pth",
47 | freeze_vit=True,
48 | freeze_qformer=False,
49 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
50 | clip_frames="${clip_frames}",
51 | num_frames="${num_frames}",
52 | token_merge_len = 4,
53 | # vit
54 | low_resource=False,
55 | vision_encoder=dict(
56 | name="vit_l14",
57 | img_size=224,
58 | patch_size=16,
59 | d_model=1024,
60 | encoder_embed_dim=1024,
61 | encoder_depth=24,
62 | encoder_num_heads=16,
63 | drop_path_rate=0.,
64 | num_frames="${clip_frames}",
65 | tubelet_size=1,
66 | use_checkpoint=True,
67 | checkpoint_num=18,
68 | pretrained="",
69 | return_index=-2,
70 | vit_add_ln=True,
71 | ckpt_num_frame=4,
72 | ),
73 | # qformer
74 | num_query_token=32,
75 | qformer_hidden_dropout_prob=0.1,
76 | qformer_attention_probs_dropout_prob=0.1,
77 | qformer_drop_path_rate=0.2,
78 | extra_num_query_token=64,
79 | qformer_text_input=True,
80 | # prompt
81 | system="",
82 |     start_token="<Video>",
83 |     end_token="</Video>",
84 |     add_second_msg=True,
85 |     img_start_token="<Image>",
86 |     img_end_token="</Image>",
87 | random_shuffle=True,
88 | return_question_instruction=False,
89 | use_flash_attention=True,
90 | use_lora=True,
91 | lora_r=16,
92 | lora_alpha=32,
93 | lora_dropout=0.1,
94 | # debug=True,
95 | )
96 |
97 | optimizer = dict(
98 | opt="adamW",
99 | lr=1.5e-5,
100 | opt_betas=[0.9, 0.999], # default
101 | weight_decay=0.02,
102 | max_grad_norm=-1, # requires a positive float, use -1 to disable
103 | # use a different lr for some modules, e.g., larger lr for new modules
104 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
105 | )
106 |
107 | scheduler = dict(sched="cosine", epochs=2, min_lr_multi=0.2, warmup_epochs=0.05)
108 |
109 | evaluate = False
110 | deep_fusion = False
111 | evaluation = dict(
112 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
113 | eval_x_only=False,
114 | k_test=128,
115 | eval_offload=True, # offload gpu tensors to cpu to save memory.
116 | )
117 |
118 | fp16 = True
119 | gradient_checkpointing = True
120 |
121 | # ========================= wandb ==========================
122 | wandb = dict(
123 | enable=False,
124 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
125 | project="videogpt", # setup in your command line
126 | )
127 | dist_url = "env://"
128 | device = "cuda"
129 | mode = "it_mistral"
130 |
131 | # ========================= others ==========================
132 | output_dir = None # output dir
133 | resume = False # if True, load optimizer and scheduler states as well
134 | debug = False
135 | log_freq = 10
136 | seed = 42
137 |
138 | save_latest = False
139 | auto_resume = True
140 | pretrained_path = "" # path to pretrained model weights, for resume only?
141 |
142 | deepspeed = dict(
143 | enable=False,
144 | stage=1,
145 | )
--------------------------------------------------------------------------------
/scripts/videochat_mistral/config_LinearProAdaFT.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 |
5 | train_corpus = "FT_Temporal_Grounding_Both"
6 |
7 |
8 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
9 | test_file = dict()
10 | test_types = []
11 | num_workers = 6
12 |
13 | stop_key = None
14 |
15 | # ========================= input ==========================
16 | num_frames_test = 16
17 | num_frames = 128
18 | clip_frames = 8
19 | good_init = True
20 |
21 | batch_size = 2
22 | max_txt_l = 1024
23 |
24 | save_iter=1000
25 |
26 | pre_text = False
27 |
28 | inputs = dict(
29 | image_res=224,
30 | video_input=dict(
31 | num_frames="${num_frames}",
32 | sample_type="rand",
33 | num_frames_test="${num_frames_test}",
34 | sample_type_test="middle",
35 | random_aug=False,
36 | ),
37 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
38 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
39 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
40 | )
41 |
42 | # ========================= model ==========================
43 | model = dict(
44 | model_cls="VideoChat2_it4_mistral_LinearProAda",
45 | vit_blip_model_path="/mnt/petrelfs/share_data/likunchang/model/videochat2/umt_l16_qformer.pth",
46 | mistral_model_path="/mnt/petrelfs/share_data/likunchang/model/llm//Mistral-7B-Instruct-v0.2",
47 | videochat2_model_path="/mnt/petrelfs/share_data/likunchang/model/videochat2/videochat2_mistral_7b_stage3.pth",
48 | pretrained_path="your_LinearProAda_ckpt_path/ckpt_01.pth",
49 | freeze_vit=True,
50 | freeze_qformer=False,
51 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
52 | clip_frames="${clip_frames}",
53 | num_frames="${num_frames}",
54 | token_merge_len = 4,
55 | # vit
56 | low_resource=False,
57 | vision_encoder=dict(
58 | name="vit_l14",
59 | img_size=224,
60 | patch_size=16,
61 | d_model=1024,
62 | encoder_embed_dim=1024,
63 | encoder_depth=24,
64 | encoder_num_heads=16,
65 | drop_path_rate=0.,
66 | num_frames="${clip_frames}",
67 | tubelet_size=1,
68 | use_checkpoint=True,
69 | checkpoint_num=18,
70 | pretrained="",
71 | return_index=-2,
72 | vit_add_ln=True,
73 | ckpt_num_frame=4,
74 | ),
75 | # qformer
76 | num_query_token=32,
77 | qformer_hidden_dropout_prob=0.1,
78 | qformer_attention_probs_dropout_prob=0.1,
79 | qformer_drop_path_rate=0.2,
80 | extra_num_query_token=64,
81 | qformer_text_input=True,
82 | # prompt
83 | system="",
84 |     start_token="<Video>",
85 |     end_token="</Video>",
86 |     add_second_msg=True,
87 |     img_start_token="<Image>",
88 |     img_end_token="</Image>",
89 | random_shuffle=True,
90 | return_question_instruction=False,
91 | use_flash_attention=True,
92 | use_lora=True,
93 | lora_r=16,
94 | lora_alpha=32,
95 | lora_dropout=0.1,
96 | # debug=True,
97 | )
98 |
99 | optimizer = dict(
100 | opt="adamW",
101 | lr=1e-5,
102 | opt_betas=[0.9, 0.999], # default
103 | weight_decay=0.02,
104 | max_grad_norm=-1, # requires a positive float, use -1 to disable
105 | # use a different lr for some modules, e.g., larger lr for new modules
106 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
107 | )
108 |
109 | scheduler = dict(sched="cosine", epochs=5, min_lr_multi=0.1, warmup_epochs=0.1)
110 |
111 | evaluate = False
112 | deep_fusion = False
113 | evaluation = dict(
114 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
115 | eval_x_only=False,
116 | k_test=128,
117 | eval_offload=True, # offload gpu tensors to cpu to save memory.
118 | )
119 |
120 | fp16 = True
121 | gradient_checkpointing = True
122 |
123 | # ========================= wandb ==========================
124 | wandb = dict(
125 | enable=False,
126 | entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
127 | project="videogpt", # setup in your command line
128 | )
129 | dist_url = "env://"
130 | device = "cuda"
131 | mode = "it_mistral"
132 |
133 | # ========================= others ==========================
134 | output_dir = None # output dir
135 | resume = False # if True, load optimizer and scheduler states as well
136 | debug = False
137 | log_freq = 10
138 | seed = 42
139 |
140 | save_latest = False
141 | auto_resume = True
142 | pretrained_path = "" # path to pretrained model weights, for resume only?
143 |
144 | deepspeed = dict(
145 | enable=False,
146 | stage=1,
147 | )
--------------------------------------------------------------------------------
/scripts/videochat_mistral/run_7b_stage4.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 | ROOT_DIR=/path_to_the_timesuite_root_folder
10 | PARTITION='video5'
11 | TRAIN_TYPE=""
12 |
13 |
14 | # <<< Grounded Tuning Epoch 1 >>>
15 | MODEL_TYPE='LinearP'
16 | JOB_NAME="F192_CF8${TRAIN_TYPE}_${MODEL_TYPE}_TimePro_Normal"
17 |
18 |
19 | # <<< Grounded Tuning Epoch 2 & 3 >>>
20 | # MODEL_TYPE='LinearProAda'
21 | # JOB_NAME="F128_CF8${TRAIN_TYPE}_${MODEL_TYPE}_TimePro_Normal"
22 |
23 |
24 | # <<< Supervised FT on Charades-STA & QVHighlight >>>
25 | # MODEL_TYPE='LinearProAdaFT'
26 | # JOB_NAME="F128_CF8${TRAIN_TYPE}_${MODEL_TYPE}_FT_Both"
27 |
28 |
29 | NNODE=2
30 | NUM_GPUS=8
31 | NUM_CPUS=128
32 |
33 |
34 | OUTPUT_DIR="${ROOT_DIR}/$(dirname $0)/${JOB_NAME}"
35 | echo "Model Dir : ${OUTPUT_DIR}"
36 | mkdir ${OUTPUT_DIR}
37 |
38 |
39 |
40 |
41 | # srun -p ${PARTITION} \
42 | # --job-name=${JOB_NAME} \
43 | # -n${NNODE} \
44 | # --gres=gpu:${NUM_GPUS} \
45 | # --ntasks-per-node=1 \
46 | # --cpus-per-task=${NUM_CPUS} \
47 |
48 | bash torchrun.sh \
49 | --nnodes=${NNODE} \
50 | --nproc_per_node=${NUM_GPUS} \
51 | --rdzv_backend=c10d \
52 | tasks/train_it4${TRAIN_TYPE}.py \
53 | $(dirname $0)/config_${MODEL_TYPE}${TRAIN_TYPE}.py \
54 | output_dir ${OUTPUT_DIR} \
55 | 2>&1 | tee ${OUTPUT_DIR}/bash_output.log
--------------------------------------------------------------------------------
/scripts/videochat_mistral/run_7b_stage4_ds.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 | which_python=$(which python)
5 | echo "which python: ${which_python}"
6 | export PYTHONPATH=${PYTHONPATH}:${which_python}
7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}"
9 |
10 | PARTITION='video5'
11 | JOB_NAME='stage4_ds_TimeIT'
12 | NNODE=1
13 | NUM_GPUS=8
14 | NUM_CPUS=96
15 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
16 |
17 | # srun -p ${PARTITION} \
18 | # --job-name=${JOB_NAME} \
19 | # -n${NNODE} \
20 | # --gres=gpu:${NUM_GPUS} \
21 | # --ntasks-per-node=1 \
22 | # --cpus-per-task=${NUM_CPUS} \
23 |
24 | bash torchrun.sh \
25 | --nnodes=${NNODE} \
26 | --nproc_per_node=${NUM_GPUS} \
27 | --rdzv_backend=c10d \
28 | tasks/train_it4_ds.py \
29 | $(dirname $0)/config_7b_stage4_ds.py \
30 | output_dir ${OUTPUT_DIR} \
31 | 2>&1 | tee ${OUTPUT_DIR}/${JOB_NAME}_bash_output.log
32 |
--------------------------------------------------------------------------------
/tasks/shared_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | from torch.utils.data import ConcatDataset, DataLoader
9 |
10 | from utils.optimizer import create_optimizer
11 | from utils.scheduler import create_scheduler
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def get_media_types(datasources):
17 |     """get the media types for all the dataloaders.
18 |
19 | Args:
20 | datasources (List): List of dataloaders or datasets.
21 |
22 | Returns: List. The media_types.
23 |
24 | """
25 | if isinstance(datasources[0], DataLoader):
26 | datasets = [dataloader.dataset for dataloader in datasources]
27 | else:
28 | datasets = datasources
29 | media_types = [
30 | dataset.datasets[0].media_type
31 | if isinstance(dataset, ConcatDataset)
32 | else dataset.media_type
33 | for dataset in datasets
34 | ]
35 |
36 | return media_types
37 |
38 |
39 | def setup_model(
40 | config, model_cls, find_unused_parameters=False
41 | ):
42 | logger.info("Creating model")
43 | config = copy.deepcopy(config)
44 |
45 | model = model_cls(config=config.model)
46 |
47 | model = model.to(torch.device(config.device))
48 | model_without_ddp = model
49 | if config.distributed:
50 | model = torch.nn.parallel.DistributedDataParallel(
51 | model,
52 | device_ids=[config.gpu],
53 | find_unused_parameters=find_unused_parameters, # `False` for image-only task
54 | )
55 |
56 | optimizer = create_optimizer(config.optimizer, model)
57 | scheduler = create_scheduler(config.scheduler, optimizer)
58 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
59 |
60 | start_epoch = 0
61 | global_step = 0
62 |
63 | # auto resume the latest checkpoint
64 | if config.get("auto_resume", False):
65 | logger.info("Auto resuming")
66 | model_latest = join(config.output_dir, "ckpt_latest.pth")
67 | model_best = join(config.output_dir, "ckpt_best.pth")
68 | if not osp.isfile(model_latest):
69 | large_num = -1
70 | for p in os.listdir(config.output_dir):
71 | if 'ckpt' in p:
72 | num = p.split('_')[1].split('.')[0]
73 | if str.isnumeric(num):
74 | if int(num) > large_num:
75 | large_num = int(num)
76 | if large_num != -1:
77 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
78 | if osp.isfile(model_latest):
79 | config.pretrained_path = model_latest
80 | config.resume = True
81 | elif osp.isfile(model_best):
82 | config.pretrained_path = model_best
83 | config.resume = True
84 | else:
85 | logger.info(f"Not found checkpoint in {config.output_dir}")
86 |
87 | if osp.isfile(config.pretrained_path):
88 | checkpoint = torch.load(config.pretrained_path, map_location="cpu")
89 | state_dict = checkpoint["model"]
90 |
91 | if config.resume:
92 | optimizer.load_state_dict(checkpoint["optimizer"])
93 | scheduler.load_state_dict(checkpoint["scheduler"])
94 | scaler.load_state_dict(checkpoint["scaler"])
95 | start_epoch = checkpoint["epoch"] + 1
96 | global_step = checkpoint["global_step"]
97 |
98 | msg = model_without_ddp.load_state_dict(state_dict, strict=False)
99 | logger.info(msg)
100 | logger.info(f"Loaded checkpoint from {config.pretrained_path}")
101 | else:
102 | logger.warning("No pretrained checkpoint provided, training from scratch")
103 |
104 | return (
105 | model,
106 | model_without_ddp,
107 | optimizer,
108 | scheduler,
109 | scaler,
110 | start_epoch,
111 | global_step,
112 | )
113 |
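The auto-resume branch above picks the epoch checkpoint with the largest index when ckpt_latest.pth is absent; a small self-contained sketch of that filename scan (file names are illustrative):

# Mirrors the scan in setup_model(); illustrative file names.
files = ["ckpt_00.pth", "ckpt_01.pth", "ckpt_02.pth", "ckpt_best.pth", "log.txt"]
large_num = -1
for p in files:
    if 'ckpt' in p:
        num = p.split('_')[1].split('.')[0]
        if num.isnumeric():
            large_num = max(large_num, int(num))
print(f"ckpt_{large_num:02d}.pth" if large_num != -1 else "no epoch checkpoint")  # ckpt_02.pth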
--------------------------------------------------------------------------------
/tasks/shared_utils_ds.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | import deepspeed
9 | from torch.utils.data import ConcatDataset, DataLoader
10 |
11 | from utils.optimizer import create_optimizer
12 | from utils.scheduler import create_scheduler
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def get_media_types(datasources):
18 |     """get the media types for all the dataloaders.
19 |
20 | Args:
21 | datasources (List): List of dataloaders or datasets.
22 |
23 | Returns: List. The media_types.
24 |
25 | """
26 | if isinstance(datasources[0], DataLoader):
27 | datasets = [dataloader.dataset for dataloader in datasources]
28 | else:
29 | datasets = datasources
30 | media_types = [
31 | dataset.datasets[0].media_type
32 | if isinstance(dataset, ConcatDataset)
33 | else dataset.media_type
34 | for dataset in datasets
35 | ]
36 |
37 | return media_types
38 |
39 |
40 | def setup_model(
41 | config, model_cls, find_unused_parameters=False, num_steps_per_epoch=-1,
42 | ):
43 | logger.info("Creating model")
44 | config = copy.deepcopy(config)
45 |
46 | model = model_cls(config=config.model)
47 |
48 | model = model.to(torch.device(config.device))
49 | if config.fp16:
50 | if config.get('bf16', True):
51 | logger.info("Change to bfloat16 for model")
52 | model = model.to(torch.bfloat16)
53 | else:
54 | logger.info("Change to float16 for model")
55 | model = model.half()
56 | model_without_ddp = model
57 |
58 | if hasattr(config, "deepspeed") and config.deepspeed.enable:
59 | optimizer_params = create_optimizer(config.optimizer, model, return_group=True)
60 | scheduler = None
61 | scaler = None
62 | else:
63 | if config.distributed:
64 | model = torch.nn.parallel.DistributedDataParallel(
65 | model,
66 | device_ids=[config.gpu],
67 | find_unused_parameters=find_unused_parameters, # `False` for image-only task
68 | )
69 |
70 | optimizer = create_optimizer(config.optimizer, model)
71 | scheduler = create_scheduler(config.scheduler, optimizer)
72 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
73 |
74 | start_epoch = 0
75 | global_step = 0
76 |
77 | # auto resume the latest checkpoint
78 | if config.get("auto_resume", False):
79 | logger.info("Auto resuming")
80 | model_latest = join(config.output_dir, "ckpt_latest.pth")
81 | model_best = join(config.output_dir, "ckpt_best.pth")
82 |
83 | large_step_num = -1
84 | large_num = -1
85 | for p in os.listdir(config.output_dir):
86 | if 'ckpt_iter' in p:
87 | num = p.split('_iter')[1].split('.')[0]
88 | if str.isnumeric(num):
89 | if int(num) > large_step_num:
90 | large_step_num = int(num)
91 | elif 'ckpt_' in p:
92 | num = p.split('_')[1].split('.')[0]
93 | if str.isnumeric(num):
94 | if int(num) > large_num:
95 | large_num = int(num)
96 | if large_step_num != -1:
97 | logger.info(f"Load the latest step: {large_step_num}")
98 | model_latest = join(config.output_dir, f"ckpt_iter{large_step_num:02d}.pth")
99 | if large_num != -1 and (large_num + 1) * num_steps_per_epoch > large_step_num:
100 | logger.info(f"Load the latest epoch: {large_num}")
101 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
102 |
103 | if hasattr(config, "deepspeed") and config.deepspeed.enable:
104 | if osp.isdir(model_latest):
105 | config.pretrained_path = model_latest
106 | config.resume = True
107 | elif osp.isdir(model_best):
108 | config.pretrained_path = model_best
109 | config.resume = True
110 | else:
111 | logger.info(f"Not found checkpoint in {config.output_dir}")
112 | else:
113 | if osp.isfile(model_latest):
114 | config.pretrained_path = model_latest
115 | config.resume = True
116 | elif osp.isfile(model_best):
117 | config.pretrained_path = model_best
118 | config.resume = True
119 | else:
120 | logger.info(f"Not found checkpoint in {config.output_dir}")
121 |
122 | # load pretrained model
123 | if hasattr(config, "deepspeed") and config.deepspeed.enable:
124 | logger.info('Use deepspeed to initialize model!!!')
125 | model = model_without_ddp
126 | model, optimizer, _, _ = deepspeed.initialize(
127 | args=config, model=model, model_parameters=optimizer_params, dist_init_required=not config.distributed,
128 | lr_scheduler=lambda opt: create_scheduler(config.scheduler, opt)
129 | )
130 | if osp.isdir(config.pretrained_path):
131 | logger.info(f"Load pretrained model from {config.pretrained_path}")
132 | output_dir, tag = os.path.split(config.pretrained_path)
133 | if config.resume:
134 | _, client_state = model.load_checkpoint(output_dir, tag=tag, load_module_strict=False)
135 | global_step = model.global_steps
136 | assert num_steps_per_epoch > 0, "Please provide num_steps_per_epoch"
137 | start_epoch = global_step // num_steps_per_epoch
138 | else:
139 | _, client_state = model.load_checkpoint(
140 | output_dir, tag=tag, load_module_strict=False,
141 | load_optimizer_states=False, load_lr_scheduler_states=False,
142 | load_module_only=True
143 | )
144 | else:
145 | if osp.isfile(config.pretrained_path):
146 | checkpoint = torch.load(config.pretrained_path, map_location="cpu")
147 | logger.info(f"Load pretrained model from {config.pretrained_path}")
148 | if 'model' in checkpoint.keys():
149 | state_dict = checkpoint["model"]
150 | elif 'module' in checkpoint.keys():
151 | state_dict = checkpoint["module"]
152 | else:
153 | state_dict = checkpoint
154 | # resume optimizer
155 | if config.resume:
156 | optimizer.load_state_dict(checkpoint["optimizer"])
157 | scheduler.load_state_dict(checkpoint["scheduler"])
158 | scaler.load_state_dict(checkpoint["scaler"])
159 | start_epoch = checkpoint["epoch"] + 1
160 | global_step = checkpoint["global_step"]
161 |
162 | msg = model_without_ddp.load_state_dict(state_dict, strict=False)
163 | logger.info(msg)
164 | logger.info(f"Loaded checkpoint from {config.pretrained_path}")
165 | else:
166 | logger.warning("No pretrained checkpoint provided, training from scratch")
167 |
168 | logger.info(f"Cuda memory after create model: {torch.cuda.memory_allocated() // 1024**2}M, Max mem: {torch.cuda.max_memory_allocated() // 1024**2}M")
169 |
170 | return (
171 | model,
172 | model_without_ddp,
173 | optimizer,
174 | scheduler,
175 | scaler,
176 | start_epoch,
177 | global_step,
178 | )
179 |
--------------------------------------------------------------------------------
/tasks/shared_utils_qformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | from torch.utils.data import ConcatDataset, DataLoader
9 |
10 | from models.bert.tokenization_bert import BertTokenizer
11 | from utils.optimizer import create_optimizer
12 | from utils.scheduler import create_scheduler
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def get_media_types(datasources):
18 |     """get the media types for all the dataloaders.
19 |
20 | Args:
21 | datasources (List): List of dataloaders or datasets.
22 |
23 | Returns: List. The media_types.
24 |
25 | """
26 | if isinstance(datasources[0], DataLoader):
27 | datasets = [dataloader.dataset for dataloader in datasources]
28 | else:
29 | datasets = datasources
30 | media_types = [
31 | dataset.datasets[0].media_type
32 | if isinstance(dataset, ConcatDataset)
33 | else dataset.media_type
34 | for dataset in datasets
35 | ]
36 |
37 | return media_types
38 |
39 |
40 | def setup_model(
41 | config, model_cls, find_unused_parameters=False
42 | ):
43 | logger.info("Creating model")
44 | config = copy.deepcopy(config)
45 |
46 | if "bert" in config.model.text_encoder.name:
47 | tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained, local_files_only=True)
48 | else:
49 | raise ValueError(f"Not supported text encoder.")
50 |
51 | model = model_cls(config=config, tokenizer=tokenizer)
52 |
53 | model = model.to(torch.device(config.device))
54 | model_without_ddp = model
55 | if config.distributed:
56 | model = torch.nn.parallel.DistributedDataParallel(
57 | model,
58 | device_ids=[config.gpu],
59 | find_unused_parameters=find_unused_parameters, # `False` for image-only task
60 | )
61 |
62 | optimizer = create_optimizer(config.optimizer, model)
63 | scheduler = create_scheduler(config.scheduler, optimizer)
64 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
65 |
66 | start_epoch = 0
67 | global_step = 0
68 |
69 | # auto resume the latest checkpoint
70 | if config.get("auto_resume", False):
71 | logger.info("Auto resuming")
72 | model_latest = join(config.output_dir, "ckpt_latest.pth")
73 | model_best = join(config.output_dir, "ckpt_best.pth")
74 | large_num = -1
75 | for p in os.listdir(config.output_dir):
76 | if 'ckpt' in p:
77 | num = p.split('_')[1].split('.')[0]
78 | if str.isnumeric(num):
79 | if int(num) > large_num:
80 | large_num = int(num)
81 | if large_num != -1:
82 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
83 | if osp.isfile(model_latest):
84 | config.pretrained_path = model_latest
85 | config.resume = True
86 | elif osp.isfile(model_best):
87 | config.pretrained_path = model_best
88 | config.resume = True
89 | else:
90 | logger.info(f"Not found checkpoint in {config.output_dir}")
91 |
92 | if osp.isfile(config.pretrained_path):
93 | checkpoint = torch.load(config.pretrained_path, map_location="cpu")
94 | state_dict = checkpoint["model"]
95 |
96 | if config.resume:
97 | optimizer.load_state_dict(checkpoint["optimizer"])
98 | scheduler.load_state_dict(checkpoint["scheduler"])
99 | scaler.load_state_dict(checkpoint["scaler"])
100 | start_epoch = checkpoint["epoch"] + 1
101 | global_step = checkpoint["global_step"]
102 |
103 |
104 | msg = model_without_ddp.load_state_dict(state_dict, strict=False)
105 | logger.info(msg)
106 | logger.info(f"Loaded checkpoint from {config.pretrained_path}")
107 | else:
108 | logger.warning("No pretrained checkpoint provided, training from scratch")
109 |
110 | return (
111 | model,
112 | model_without_ddp,
113 | optimizer,
114 | scheduler,
115 | scaler,
116 | tokenizer,
117 | start_epoch,
118 | global_step,
119 | )
120 |
--------------------------------------------------------------------------------
/tasks/train_it.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 | from os.path import join
5 |
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import torch.distributed as dist
9 | import wandb
10 |
11 | from dataset import MetaLoader, create_dataset, create_loader, create_sampler
12 | from models import *
13 | from tasks.shared_utils import get_media_types, setup_model
14 | from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed)
15 | from utils.config_utils import setup_main
16 | from utils.distributed import get_rank, get_world_size, is_main_process
17 | from utils.logger import log_dict_to_wandb, setup_wandb
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def train(
23 | model,
24 | train_loaders,
25 | optimizer,
26 | epoch,
27 | global_step,
28 | device,
29 | scheduler,
30 | scaler,
31 | config,
32 | ):
33 | model.train()
34 |
35 | metric_logger = MetricLogger(delimiter=" ")
36 | metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}"))
37 | loss_names = ["loss"]
38 |
39 | media_types = get_media_types(train_loaders)
40 |
41 | for name in loss_names:
42 | for m in media_types:
43 | metric_logger.add_meter(
44 | f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}")
45 | )
46 |
47 | header = f"Train Epoch: [{epoch}]"
48 | log_freq = config.log_freq
49 |
50 | if config.distributed:
51 | for d in train_loaders:
52 | d.sampler.set_epoch(epoch)
53 | train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))
54 |
55 | iterator = metric_logger.log_every(train_loader, log_freq, header)
56 | for i, (media_type, (image, text, instruction, _)) in enumerate(iterator):
57 | image = image.to(device, non_blocking=True)
58 |
59 | with torch.cuda.amp.autocast(enabled=config.fp16):
60 | loss_dict = model(image, text, instruction)
61 | loss = sum(loss_dict.values())
62 |
63 | optimizer.zero_grad()
64 | scaler.scale(loss).backward()
65 | if config.optimizer.max_grad_norm > 0:
66 | scaler.unscale_(optimizer)
67 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
68 | scaler.step(optimizer)
69 | scaler.update()
70 | scheduler.step()
71 |
72 | # logging
73 | for name in loss_names:
74 | value = loss_dict[name]
75 | value = value if isinstance(value, float) else value.item()
76 | metric_logger.update(**{f"{media_type}-{name}": value})
77 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
78 |
79 | if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
80 | logs = metric_logger.get_global_avg_dict()
81 | log_dict_to_wandb(logs, step=global_step, prefix="train/")
82 |
83 | global_step += 1
84 |
85 | if config.debug and global_step % 20 == 0:
86 | logger.info("debug mode, break training loop")
87 | break
88 |
89 | if config.debug and global_step % (2 * log_freq + 3) == 0:
90 | logger.info("debug mode, break training loop")
91 | break
92 |
93 | # gather the stats from all processes
94 | metric_logger.synchronize_between_processes()
95 | logger.info(f"Averaged stats: {metric_logger.global_avg()}")
96 | return global_step
97 |
98 |
99 | def setup_dataloaders(config, mode="pt"):
100 | # train datasets, create a list of data loaders
101 | logger.info(f"Creating dataset for {mode}")
102 | train_datasets = create_dataset(f"{mode}_train", config)
103 | media_types = get_media_types(train_datasets)
104 |
105 | if config.distributed:
106 | num_tasks = get_world_size()
107 | global_rank = get_rank()
108 | samplers = create_sampler(
109 | train_datasets, [True] * len(media_types), num_tasks, global_rank
110 | )
111 | else:
112 | samplers = [None] * len(media_types)
113 |
114 | train_loaders = create_loader(
115 | train_datasets,
116 | samplers,
117 | batch_size=[config.inputs.batch_size[k] for k in media_types],
118 | num_workers=[config.num_workers] * len(media_types),
119 | is_trains=[True] * len(media_types),
120 | collate_fns=[None] * len(media_types),
121 | ) # [0]
122 |
123 | return train_loaders, media_types
124 |
125 |
126 | def main(config):
127 | if is_main_process() and config.wandb.enable:
128 | run = setup_wandb(config)
129 |
130 | logger.info(f"train_file: {config.train_file}")
131 |
132 | setup_seed(config.seed + get_rank())
133 | device = torch.device(config.device)
134 |
135 | train_loaders, train_media_types = setup_dataloaders(
136 | config, mode=config.mode
137 | )
138 | num_steps_per_epoch = sum(len(d) for d in train_loaders)
139 | config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
140 | config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
141 | # set cudnn.benchmark=True only when input size is fixed
142 | # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
143 | cudnn.benchmark = len(train_media_types) == 1
144 |
145 | model_cls = eval(config.model.get('model_cls', 'VideoChat2_it_vicuna'))
146 | (
147 | model,
148 | model_without_ddp,
149 | optimizer,
150 | scheduler,
151 | scaler,
152 | start_epoch,
153 | global_step,
154 | ) = setup_model(
155 | config,
156 | model_cls=model_cls,
157 | # find_unused_parameters=True,
158 | find_unused_parameters=False,
159 | )
160 | if is_main_process() and config.wandb.enable:
161 | wandb.watch(model)
162 |
163 | logger.info("Start training")
164 | start_time = time.time()
165 | for epoch in range(start_epoch, config.scheduler.epochs):
166 | if not config.evaluate:
167 | global_step = train(
168 | model,
169 | train_loaders,
170 | optimizer,
171 | epoch,
172 | global_step,
173 | device,
174 | scheduler,
175 | scaler,
176 | config,
177 | )
178 |
179 | if is_main_process():
180 | logger.info(f"Epoch {epoch}")
181 | param_grad_dic = {
182 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
183 | }
184 | state_dict = model_without_ddp.state_dict()
185 | for k in list(state_dict.keys()):
186 | if k in param_grad_dic.keys() and not param_grad_dic[k]:
187 | # delete parameters that do not require gradient
188 | del state_dict[k]
189 | save_obj = {
190 | "model": state_dict,
191 | "optimizer": optimizer.state_dict(),
192 | "scheduler": scheduler.state_dict(),
193 | "scaler": scaler.state_dict(),
194 | "config": config,
195 | "epoch": epoch,
196 | "global_step": global_step,
197 | }
198 | if config.get("save_latest", False):
199 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
200 | else:
201 | torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
202 |
203 | if config.evaluate:
204 | break
205 |
206 | dist.barrier()
207 |
208 | total_time = time.time() - start_time
209 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
210 | logger.info(f"Training time {total_time_str}")
211 | logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
212 |
213 | if is_main_process() and config.wandb.enable:
214 | run.finish()
215 |
216 |
217 | if __name__ == "__main__":
218 | cfg = setup_main()
219 | main(cfg)
220 |
--------------------------------------------------------------------------------
/tasks/train_pt.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 | from os.path import join
5 |
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import torch.distributed as dist
9 | import wandb
10 |
11 | from dataset import MetaLoader, create_dataset, create_loader, create_sampler
12 | from models import *
13 | from tasks.shared_utils import get_media_types, setup_model
14 | from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed)
15 | from utils.config_utils import setup_main
16 | from utils.distributed import get_rank, get_world_size, is_main_process
17 | from utils.logger import log_dict_to_wandb, setup_wandb
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def train(
23 | model,
24 | train_loaders,
25 | optimizer,
26 | epoch,
27 | global_step,
28 | device,
29 | scheduler,
30 | scaler,
31 | config,
32 | ):
33 | model.train()
34 |
35 | metric_logger = MetricLogger(delimiter=" ")
36 | metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}"))
37 | loss_names = ["loss"]
38 |
39 | media_types = get_media_types(train_loaders)
40 |
41 | for name in loss_names:
42 | for m in media_types:
43 | metric_logger.add_meter(
44 | f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}")
45 | )
46 |
47 | header = f"Train Epoch: [{epoch}]"
48 | log_freq = config.log_freq
49 |
50 | if config.distributed:
51 | for d in train_loaders:
52 | d.sampler.set_epoch(epoch)
53 | train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))
54 |
55 | iterator = metric_logger.log_every(train_loader, log_freq, header)
56 | for i, (media_type, (image, text, _)) in enumerate(iterator):
57 | image = image.to(device, non_blocking=True)
58 |
59 | with torch.cuda.amp.autocast(enabled=config.fp16):
60 | loss_dict = model(image, text)
61 | loss = sum(loss_dict.values())
62 |
63 | optimizer.zero_grad()
64 | scaler.scale(loss).backward()
65 | if config.optimizer.max_grad_norm > 0:
66 | scaler.unscale_(optimizer)
67 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
68 | scaler.step(optimizer)
69 | scaler.update()
70 | scheduler.step()
71 |
72 | # logging
73 | for name in loss_names:
74 | value = loss_dict[name]
75 | value = value if isinstance(value, float) else value.item()
76 | metric_logger.update(**{f"{media_type}-{name}": value})
77 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
78 |
79 | if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
80 | logs = metric_logger.get_global_avg_dict()
81 | log_dict_to_wandb(logs, step=global_step, prefix="train/")
82 |
83 | global_step += 1
84 |
85 | if config.debug and global_step % 20 == 0:
86 | logger.info("debug mode, break training loop")
87 | break
88 |
89 | if config.debug and global_step % (2 * log_freq + 3) == 0:
90 | logger.info("debug mode, break training loop")
91 | break
92 |
93 | # gather the stats from all processes
94 | metric_logger.synchronize_between_processes()
95 | logger.info(f"Averaged stats: {metric_logger.global_avg()}")
96 | return global_step
97 |
98 |
99 | def setup_dataloaders(config, mode="pt"):
100 | # train datasets, create a list of data loaders
101 | logger.info(f"Creating dataset for {mode}")
102 | train_datasets = create_dataset(f"{mode}_train", config)
103 | media_types = get_media_types(train_datasets)
104 |
105 | if config.distributed:
106 | num_tasks = get_world_size()
107 | global_rank = get_rank()
108 | samplers = create_sampler(
109 | train_datasets, [True] * len(media_types), num_tasks, global_rank
110 | )
111 | else:
112 | samplers = [None] * len(media_types)
113 |
114 | train_loaders = create_loader(
115 | train_datasets,
116 | samplers,
117 | batch_size=[config.inputs.batch_size[k] for k in media_types],
118 | num_workers=[config.num_workers] * len(media_types),
119 | is_trains=[True] * len(media_types),
120 | collate_fns=[None] * len(media_types),
121 | ) # [0]
122 |
123 | return train_loaders, media_types
124 |
125 |
126 | def main(config):
127 | if is_main_process() and config.wandb.enable:
128 | run = setup_wandb(config)
129 |
130 | logger.info(f"train_file: {config.train_file}")
131 |
132 | setup_seed(config.seed + get_rank())
133 | device = torch.device(config.device)
134 |
135 | train_loaders, train_media_types = setup_dataloaders(
136 | config, mode=config.mode
137 | )
138 | num_steps_per_epoch = sum(len(d) for d in train_loaders)
139 | config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
140 | config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
141 | # set cudnn.benchmark=True only when input size is fixed
142 | # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
143 | cudnn.benchmark = len(train_media_types) == 1
144 |
145 | model_cls = eval(config.model.get('model_cls', 'VideoChat2_pt_vicuna'))
146 | (
147 | model,
148 | model_without_ddp,
149 | optimizer,
150 | scheduler,
151 | scaler,
152 | start_epoch,
153 | global_step,
154 | ) = setup_model(
155 | config,
156 | model_cls=model_cls,
157 | find_unused_parameters=True,
158 | )
159 | if is_main_process() and config.wandb.enable:
160 | wandb.watch(model)
161 |
162 | logger.info("Start training")
163 | start_time = time.time()
164 | for epoch in range(start_epoch, config.scheduler.epochs):
165 | if not config.evaluate:
166 | global_step = train(
167 | model,
168 | train_loaders,
169 | optimizer,
170 | epoch,
171 | global_step,
172 | device,
173 | scheduler,
174 | scaler,
175 | config,
176 | )
177 |
178 | if is_main_process():
179 | logger.info(f"Epoch {epoch}")
180 | param_grad_dic = {
181 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
182 | }
183 | state_dict = model_without_ddp.state_dict()
184 | for k in list(state_dict.keys()):
185 | if k in param_grad_dic.keys() and not param_grad_dic[k]:
186 | # delete parameters that do not require gradient
187 | del state_dict[k]
188 | save_obj = {
189 | "model": state_dict,
190 | "optimizer": optimizer.state_dict(),
191 | "scheduler": scheduler.state_dict(),
192 | "scaler": scaler.state_dict(),
193 | "config": config,
194 | "epoch": epoch,
195 | "global_step": global_step,
196 | }
197 | if config.get("save_latest", False):
198 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
199 | else:
200 | torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
201 |
202 | if config.evaluate:
203 | break
204 |
205 | dist.barrier()
206 |
207 | total_time = time.time() - start_time
208 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
209 | logger.info(f"Training time {total_time_str}")
210 | logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
211 |
212 | if is_main_process() and config.wandb.enable:
213 | run.finish()
214 |
215 |
216 | if __name__ == "__main__":
217 | cfg = setup_main()
218 | main(cfg)
219 |
--------------------------------------------------------------------------------
/torchrun.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
3 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
4 | MASTER_PORT=$((10000 + $RANDOM % 100))
5 |
6 | echo "All nodes used:"
7 | echo ${ALL_NODES}
8 | echo "Master node:"
9 | echo ${MASTER_NODE}
10 | echo "Master port:"
11 | echo ${MASTER_PORT}
12 | echo "Args:"
13 | echo "$@"
14 |
15 | torchrun --rdzv_endpoint=${MASTER_NODE}:${MASTER_PORT} "$@"
16 |
--------------------------------------------------------------------------------
/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 | import torch.distributed as dist
5 | from os.path import dirname, join
6 |
7 | from utils.config import Config
8 | from utils.distributed import init_distributed_mode, is_main_process
9 | from utils.logger import setup_logger
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def setup_config():
15 | """Conbine yaml config and command line config with OmegaConf.
16 | Also converts types, e.g., `'None'` (str) --> `None` (None)
17 | """
18 | config = Config.get_config()
19 | if config.debug:
20 | config.wandb.enable = False
21 | return config
22 |
23 |
24 | def setup_evaluate_config(config):
25 | """setup evaluation default settings, e.g., disable wandb"""
26 | assert config.evaluate
27 | config.wandb.enable = False
28 | if config.output_dir is None:
29 | config.output_dir = join(dirname(config.pretrained_path), "eval")
30 | return config
31 |
32 |
33 | def setup_output_dir(output_dir, excludes=["code"]):
34 | """ensure not overwritting an exisiting/non-empty output dir"""
35 | if not os.path.exists(output_dir):
36 | os.makedirs(output_dir, exist_ok=False)
37 | else:
38 | existing_dirs_files = os.listdir(output_dir) # list
39 | remaining = set(existing_dirs_files) - set(excludes)
40 | remaining = [e for e in remaining if "slurm" not in e]
41 | remaining = [e for e in remaining if ".out" not in e]
42 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
43 | logger.warn(f"remaining dirs or files: {remaining}")
44 |
45 |
46 | def setup_deepspeed_zero_config(stage):
47 | # We currently set ZeRO based on stage:
48 | if stage == 1:
49 | return {"stage": 1, "reduce_bucket_size": 5e8}
50 | if stage == 2:
51 | return {
52 | "stage": 2,
53 | "contiguous_gradients": False,
54 | "overlap_comm": False,
55 | "reduce_scatter": True,
56 | "reduce_bucket_size": 5e8,
57 | "allgather_bucket_size": 5e8,
58 | "offload_optimizer": {
59 | "device": "cpu"
60 | },
61 | }
62 | # return {
63 | # "stage": 2,
64 | # "contiguous_gradients": True,
65 | # "overlap_comm": True,
66 | # "reduce_scatter": True,
67 | # "reduce_bucket_size": 5e8,
68 | # "allgather_bucket_size": 5e8,
69 | # "cpu_offload": False,
70 | # }
71 | if stage == 3:
72 | return {
73 | "stage": 3,
74 | "contiguous_gradients": True,
75 | "stage3_max_live_parameters": 1e9,
76 | "stage3_max_reuse_distance": 1e9,
77 | "stage3_prefetch_bucket_size": 1e7,
78 | "stage3_param_persistence_threshold": 1e5,
79 | "reduce_bucket_size": 1e7,
80 | "sub_group_size": 1e9,
81 | "offload_optimizer": {
82 | "device": "cpu"
83 | },
84 | "offload_param": {
85 | "device": "cpu"
86 | }
87 | }
88 | # return {
89 | # "stage": 3,
90 | # "contiguous_gradients": True,
91 | # "overlap_comm": True,
92 | # "reduce_scatter": True,
93 | # "reduce_bucket_size": 5e4,
94 | # "allgather_bucket_size": 5e4,
95 | # "cpu_offload": False,
96 | # "stage3_max_live_parameters": 1e5,
97 | # "stage3_max_reuse_distance": 1e5,
98 | # }
99 |
100 | raise ValueError("Wrong stage for deepspeed {}".format(stage.stage))
101 |
102 |
103 | def setup_deepspeed_config(config):
104 | config.deepspeed_config = os.path.join(config.output_dir, "deepspeed_config.json")
105 | opts = config.optimizer
106 | logger.info(f'Write deepspeed config to {config.deepspeed_config}')
107 | if not is_main_process():
108 | return config
109 |
110 | os.makedirs(config.output_dir, exist_ok=True)
111 |
112 | with open(config.deepspeed_config, mode="w") as writer:
113 | ds_config = {
114 | "train_batch_size": config.batch_size * dist.get_world_size(),
115 | "train_micro_batch_size_per_gpu": config.batch_size,
116 | "steps_per_print": 100,
117 | "optimizer": {
118 | "type": "Adam",
119 | "adam_w_mode": True,
120 | "params": {
121 | "lr": opts.lr,
122 | "weight_decay": opts.weight_decay,
123 | "bias_correction": True,
124 | "betas": [
125 | opts.opt_betas[0],
126 | opts.opt_betas[1],
127 | ],
128 | "eps": 1e-8
129 | }
130 | }
131 | }
132 | if config.deepspeed.stage != 0:
133 | ds_config["zero_optimization"] = setup_deepspeed_zero_config(config.deepspeed.stage)
134 |
135 | if config.fp16:
136 | if config.get('bf16', True):
137 | ds_config["bf16"] = {
138 | "enabled": True
139 | }
140 | else:
141 | ds_config["fp16"] = {
142 | "enabled": True,
143 | "auto_cast": False,
144 | "loss_scale": 0,
145 | "initial_scale_power": 16,
146 | "loss_scale_window": 1000,
147 | "hysteresis": 2,
148 | "consecutive_hysteresis": False,
149 | "min_loss_scale": 1
150 | }
151 | else:
152 | assert config.deepspeed.stage == 0, "You must use fp16 or bf16 when using ZeRO"
153 |
154 | if config.get("max_grad_norm", -1) > 0:
155 | ds_config.update({"gradient_clipping", config.max_grad_norm})
156 |
157 | writer.write(json.dumps(ds_config, indent=2))
158 |
159 | return config
160 |
161 |
162 | def setup_main():
163 | """
164 | Setup config, logger, output_dir, etc.
165 | Shared for pretrain and all downstream tasks.
166 | """
167 | config = setup_config()
168 | if hasattr(config, "evaluate") and config.evaluate:
169 | config = setup_evaluate_config(config)
170 | init_distributed_mode(config)
171 |
172 | if hasattr(config, "deepspeed") and config.deepspeed.enable:
173 | config = setup_deepspeed_config(config)
174 |
175 | if is_main_process():
176 | setup_output_dir(config.output_dir, excludes=["code"])
177 | setup_logger(output=config.output_dir, color=True, name="vindlu")
178 | logger.info(f"config: {Config.pretty_text(config)}")
179 | Config.dump(config, os.path.join(config.output_dir, "config.json"))
180 | return config
181 |
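A minimal sketch (assuming the repo's utils package is importable) showing what setup_deepspeed_zero_config returns; printing the stage-2 block makes the offload and bucket settings above easy to inspect:

import json
from utils.config_utils import setup_deepspeed_zero_config

# Stage-2 ZeRO block with CPU optimizer offload, as defined above.
print(json.dumps(setup_deepspeed_zero_config(2), indent=2))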
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | import logging
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def setup_for_distributed(is_master):
11 | import warnings
12 |
13 | builtin_warn = warnings.warn
14 |
15 | def warn(*args, **kwargs):
16 | force = kwargs.pop("force", False)
17 | if is_master or force:
18 | builtin_warn(*args, **kwargs)
19 |
20 | # Log warnings only once
21 | warnings.warn = warn
22 | warnings.simplefilter("once", UserWarning)
23 |
24 | if not is_master:
25 | logging.disable()
26 |
27 |
28 | def is_dist_avail_and_initialized():
29 | if not dist.is_available():
30 | return False
31 | if not dist.is_initialized():
32 | return False
33 | return True
34 |
35 |
36 | def get_world_size():
37 | if not is_dist_avail_and_initialized():
38 | return 1
39 | return dist.get_world_size()
40 |
41 |
42 | def get_rank():
43 | if not is_dist_avail_and_initialized():
44 | return 0
45 | return dist.get_rank()
46 |
47 |
48 | def is_main_process():
49 | return get_rank() == 0
50 |
51 |
52 | def save_on_master(*args, **kwargs):
53 | if is_main_process():
54 | torch.save(*args, **kwargs)
55 |
56 |
57 | def is_port_in_use(port):
58 | import socket
59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
60 | return s.connect_ex(('localhost', port)) == 0
61 |
62 |
63 | def init_distributed_mode(args):
64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
65 | # job started by torch.distributed.launch
66 | args.rank = int(os.environ["RANK"])
67 | args.world_size = int(os.environ['WORLD_SIZE'])
68 | args.gpu = int(os.environ['LOCAL_RANK'])
69 | elif 'SLURM_PROCID' in os.environ:
70 | # local rank on the current node / global rank
71 | local_rank = int(os.environ['SLURM_LOCALID'])
72 | global_rank = int(os.environ['SLURM_PROCID'])
73 | # number of processes / GPUs per node
74 | world_size = int(os.environ["SLURM_NNODES"]) * \
75 | int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])  # handles e.g. "8" or "8(x4)"
76 |
77 | logger.info(f"world_size: {world_size}")
78 |
79 | args.rank = global_rank
80 | args.gpu = local_rank
81 | args.world_size = world_size
82 | else:
83 | logger.info('Not using distributed mode')
84 | args.distributed = False
85 | return
86 |
87 | args.distributed = True
88 |
89 | torch.cuda.set_device(args.gpu)
90 | args.dist_backend = 'nccl'
91 |
92 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node
93 | dist_port = int(args.dist_url.split(":")[-1])
94 | while is_port_in_use(dist_port):
95 | dist_port += 10
96 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
97 |
98 | logger.info('| distributed init (rank {}): {}'.format(
99 | args.rank, args.dist_url))
100 | if "SLURM_JOB_ID" in os.environ:
101 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
102 | torch.distributed.init_process_group(
103 | backend=args.dist_backend, init_method=args.dist_url,
104 | world_size=args.world_size, rank=args.rank)
105 | torch.distributed.barrier()
106 | setup_for_distributed(args.rank == 0)
107 |
108 |
109 | # Copyright (c) Facebook, Inc. and its affiliates.
110 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
111 | class GatherLayer(torch.autograd.Function):
112 | """
113 | Gather tensors from all workers with support for backward propagation:
114 | This implementation does not cut the gradients as torch.distributed.all_gather does.
115 | """
116 |
117 | @staticmethod
118 | def forward(ctx, x):
119 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
120 | dist.all_gather(output, x)
121 | return tuple(output)
122 |
123 | @staticmethod
124 | def backward(ctx, *grads):
125 | all_gradients = torch.stack(grads)
126 | dist.all_reduce(all_gradients)
127 | return all_gradients[dist.get_rank()]
128 |
129 |
130 | # copied from megavlt
131 | def gather_tensor_along_batch_with_backward(tensor, dim=0):
132 | world_size = get_world_size()
133 |
134 | if world_size < 2:
135 | return tensor
136 |
137 | tensor_list = GatherLayer.apply(tensor)
138 | tensor_list = torch.cat(tensor_list, dim=dim)
139 | return tensor_list
140 |
141 |
142 | @torch.no_grad()
143 | def gather_tensor_along_batch(tensor, dim=0):
144 | """
145 | Performs all_gather operation on the provided tensors.
146 | *** Warning ***: torch.distributed.all_gather has no gradient.
147 | """
148 | world_size = get_world_size()
149 |
150 | if world_size < 2:
151 | return tensor
152 |
153 | with torch.no_grad():
154 | tensor_list = []
155 |
156 | for _ in range(world_size):
157 | tensor_list.append(torch.zeros_like(tensor))
158 |
159 | dist.all_gather(tensor_list, tensor)
160 | tensor_list = torch.cat(tensor_list, dim=dim)
161 | return tensor_list
162 |
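A minimal sketch of how gather_tensor_along_batch_with_backward is typically used, e.g. for a global contrastive loss; with a single process the gather is a no-op, so this runs without init_process_group:

import torch
from utils.distributed import gather_tensor_along_batch_with_backward

local_feats = torch.randn(4, 256, requires_grad=True)
# With world_size == 1 this returns local_feats unchanged; under DDP it
# concatenates features from all ranks while keeping gradients intact.
global_feats = gather_tensor_along_batch_with_backward(local_feats, dim=0)
sim = local_feats @ global_feats.t()  # (local_bs, global_bs) similarities
sim.mean().backward()                 # gradients flow back through GatherLayer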
--------------------------------------------------------------------------------
/utils/easydict.py:
--------------------------------------------------------------------------------
1 | class EasyDict(dict):
2 | """
3 | Get attributes
4 |
5 | >>> d = EasyDict({'foo':3})
6 | >>> d['foo']
7 | 3
8 | >>> d.foo
9 | 3
10 | >>> d.bar
11 | Traceback (most recent call last):
12 | ...
13 | AttributeError: 'EasyDict' object has no attribute 'bar'
14 |
15 | Works recursively
16 |
17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
18 | >>> isinstance(d.bar, dict)
19 | True
20 | >>> d.bar.x
21 | 1
22 |
23 | Bullet-proof
24 |
25 | >>> EasyDict({})
26 | {}
27 | >>> EasyDict(d={})
28 | {}
29 | >>> EasyDict(None)
30 | {}
31 | >>> d = {'a': 1}
32 | >>> EasyDict(**d)
33 | {'a': 1}
34 |
35 | Set attributes
36 |
37 | >>> d = EasyDict()
38 | >>> d.foo = 3
39 | >>> d.foo
40 | 3
41 | >>> d.bar = {'prop': 'value'}
42 | >>> d.bar.prop
43 | 'value'
44 | >>> d
45 | {'foo': 3, 'bar': {'prop': 'value'}}
46 | >>> d.bar.prop = 'newer'
47 | >>> d.bar.prop
48 | 'newer'
49 |
50 |
51 | Values extraction
52 |
53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
54 | >>> isinstance(d.bar, list)
55 | True
56 | >>> from operator import attrgetter
57 | >>> map(attrgetter('x'), d.bar)
58 | [1, 3]
59 | >>> map(attrgetter('y'), d.bar)
60 | [2, 4]
61 | >>> d = EasyDict()
62 | >>> d.keys()
63 | []
64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
65 | >>> d.foo
66 | 3
67 | >>> d.bar.x
68 | 1
69 |
70 | Still like a dict though
71 |
72 | >>> o = EasyDict({'clean':True})
73 | >>> o.items()
74 | [('clean', True)]
75 |
76 | And like a class
77 |
78 | >>> class Flower(EasyDict):
79 | ... power = 1
80 | ...
81 | >>> f = Flower()
82 | >>> f.power
83 | 1
84 | >>> f = Flower({'height': 12})
85 | >>> f.height
86 | 12
87 | >>> f['power']
88 | 1
89 | >>> sorted(f.keys())
90 | ['height', 'power']
91 |
92 | update and pop items
93 | >>> d = EasyDict(a=1, b='2')
94 | >>> e = EasyDict(c=3.0, a=9.0)
95 | >>> d.update(e)
96 | >>> d.c
97 | 3.0
98 | >>> d['c']
99 | 3.0
100 | >>> d.get('c')
101 | 3.0
102 | >>> d.update(a=4, b=4)
103 | >>> d.b
104 | 4
105 | >>> d.pop('a')
106 | 4
107 | >>> d.a
108 | Traceback (most recent call last):
109 | ...
110 | AttributeError: 'EasyDict' object has no attribute 'a'
111 | """
112 |
113 | def __init__(self, d=None, **kwargs):
114 | if d is None:
115 | d = {}
116 | if kwargs:
117 | d.update(**kwargs)
118 | for k, v in d.items():
119 | setattr(self, k, v)
120 | # Class attributes
121 | for k in self.__class__.__dict__.keys():
122 | if not (k.startswith("__") and k.endswith("__")) and k not in ("update", "pop"):
123 | setattr(self, k, getattr(self, k))
124 |
125 | def __setattr__(self, name, value):
126 | if isinstance(value, (list, tuple)):
127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
128 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
129 | value = self.__class__(value)
130 | super(EasyDict, self).__setattr__(name, value)
131 | super(EasyDict, self).__setitem__(name, value)
132 |
133 | __setitem__ = __setattr__
134 |
135 | def update(self, e=None, **f):
136 | d = e or dict()
137 | d.update(f)
138 | for k in d:
139 | setattr(self, k, d[k])
140 |
141 | def pop(self, k, d=None):
142 | if hasattr(self, k):
143 | delattr(self, k)
144 | return super(EasyDict, self).pop(k, d)
145 |
146 |
147 | if __name__ == "__main__":
148 | import doctest
149 | doctest.testmod()
150 |
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | """ Optimizer Factory w/ Custom Weight Decay
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import re
5 | import torch
6 | from torch import optim as optim
7 | from utils.distributed import is_main_process
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | try:
11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
12 | has_apex = True
13 | except ImportError:
14 | has_apex = False
15 |
16 |
17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
18 | named_param_tuples = []
19 | for name, param in model.named_parameters():
20 | if not param.requires_grad:
21 | continue # frozen weights
22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
23 | named_param_tuples.append([name, param, 0])
24 | elif name in no_decay_list:
25 | named_param_tuples.append([name, param, 0])
26 | else:
27 | named_param_tuples.append([name, param, weight_decay])
28 | return named_param_tuples
29 |
30 |
31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
32 | """use lr=diff_lr for modules named found in diff_lr_names,
33 | otherwise use lr=default_lr
34 |
35 | Args:
36 | named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module
37 | diff_lr_names: List(str)
38 | diff_lr: float
39 | default_lr: float
40 | Returns:
41 | named_param_tuples_with_lr: List([name, param, weight_decay, lr])
42 | """
43 | named_param_tuples_with_lr = []
44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
45 | for name, p, wd in named_param_tuples_or_model:
46 | use_diff_lr = False
47 | for diff_name in diff_lr_names:
48 | # if diff_name in name:
49 | if re.search(diff_name, name) is not None:
50 | logger.info(f"param {name} use different_lr: {diff_lr}")
51 | use_diff_lr = True
52 | break
53 |
54 | named_param_tuples_with_lr.append(
55 | [name, p, wd, diff_lr if use_diff_lr else default_lr]
56 | )
57 |
58 | if is_main_process():
59 | for name, _, wd, diff_lr in named_param_tuples_with_lr:
60 | logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}")
61 |
62 | return named_param_tuples_with_lr
63 |
64 |
65 | def create_optimizer_params_group(named_param_tuples_with_lr):
66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
67 | group = {}
68 | for name, p, wd, lr in named_param_tuples_with_lr:
69 | if wd not in group:
70 | group[wd] = {}
71 | if lr not in group[wd]:
72 | group[wd][lr] = []
73 | group[wd][lr].append(p)
74 |
75 | optimizer_params_group = []
76 | for wd, lr_groups in group.items():
77 | for lr, p in lr_groups.items():
78 | optimizer_params_group.append(dict(
79 | params=p,
80 | weight_decay=wd,
81 | lr=lr
82 | ))
83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
84 | return optimizer_params_group
85 |
86 |
87 | def create_optimizer(args, model, filter_bias_and_bn=True, return_group=False):
88 | opt_lower = args.opt.lower()
89 | weight_decay = args.weight_decay
90 | # check for modules that requires different lr
91 | if hasattr(args, "different_lr") and args.different_lr.enable:
92 | diff_lr_module_names = args.different_lr.module_names
93 | diff_lr = args.different_lr.lr
94 | else:
95 | diff_lr_module_names = []
96 | diff_lr = None
97 |
98 | no_decay = {}
99 | if hasattr(model, 'no_weight_decay'):
100 | no_decay = model.no_weight_decay()
101 | named_param_tuples = add_weight_decay(
102 | model, weight_decay, no_decay, filter_bias_and_bn)
103 | named_param_tuples = add_different_lr(
104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
105 | parameters = create_optimizer_params_group(named_param_tuples)
106 |
107 | if return_group:
108 | return parameters
109 |
110 | if 'fused' in opt_lower:
111 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
112 |
113 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
114 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
115 | opt_args['eps'] = args.opt_eps
116 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
117 | opt_args['betas'] = args.opt_betas
118 | if hasattr(args, 'opt_args') and args.opt_args is not None:
119 | opt_args.update(args.opt_args)
120 |
121 | opt_split = opt_lower.split('_')
122 | opt_lower = opt_split[-1]
123 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
124 | opt_args.pop('eps', None)
125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
126 | elif opt_lower == 'momentum':
127 | opt_args.pop('eps', None)
128 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
129 | elif opt_lower == 'adam':
130 | optimizer = optim.Adam(parameters, **opt_args)
131 | elif opt_lower == 'adamw':
132 | optimizer = optim.AdamW(parameters, **opt_args)
133 | else:
134 | # unknown optimizer name
135 | raise ValueError(f"Invalid optimizer: {opt_lower}")
136 | return optimizer
137 |
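A minimal usage sketch of create_optimizer; the args namespace below is hypothetical (real runs pass config.optimizer), and it shows biases and LayerNorm weights landing in a zero-weight-decay group:

from types import SimpleNamespace
import torch.nn as nn
from utils.optimizer import create_optimizer

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
args = SimpleNamespace(opt="adamw", lr=1e-4, weight_decay=0.02,
                       opt_betas=[0.9, 0.999])
optimizer = create_optimizer(args, model)
for group in optimizer.param_groups:
    # two groups: decayed 2-D weights vs. no-decay biases/LayerNorm params
    print(group["lr"], group["weight_decay"], len(group["params"]))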
--------------------------------------------------------------------------------
/utils/quant.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.peft import LoraColumnParallelLinear, LoraRowParallelLinear
3 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear,RowParallelLinear
4 | # from accessory.model.meta import MetaModel
5 | from transformers.utils.quantization_config import BitsAndBytesConfig
6 | import bitsandbytes as bnb
7 |
8 | from types import MethodType
9 | from tqdm import tqdm
10 |
11 | from fairscale.nn.model_parallel.mappings import (
12 | copy_to_model_parallel_region,
13 | gather_from_model_parallel_region,
14 | reduce_from_model_parallel_region,
15 | scatter_to_model_parallel_region,
16 | )
17 |
18 | def forward_ColumnParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
19 | # Set up backprop all-reduce.
20 | input_parallel = copy_to_model_parallel_region(input_)
21 | # Matrix multiply.
22 | output_parallel = self.quanted_layer(input_parallel)
23 | if self.bias is not None:
24 | output_parallel = output_parallel + self.bias
25 | if self.gather_output:
26 | # All-gather across the partitions.
27 | output = gather_from_model_parallel_region(output_parallel)
28 | else:
29 | output = output_parallel
30 | return output
31 |
32 | def forward_RowParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
33 | # Set up backprop all-reduce.
34 | if self.input_is_parallel:
35 | input_parallel = input_
36 | else:
37 | input_parallel = scatter_to_model_parallel_region(input_)
38 | # Matrix multiply.
39 | output_parallel = self.quanted_layer(input_parallel)
40 | # All-reduce across all the partitions.
41 | output_ = reduce_from_model_parallel_region(output_parallel)
42 | if self.bias is not None:
43 | output = output_ + self.bias
44 | else:
45 | output = output_
46 | return output
47 |
48 | def forward_LoraColumnParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
49 | # Set up backprop all-reduce.
50 | input_parallel = copy_to_model_parallel_region(input_)
51 | # Matrix multiply.
52 | output_parallel = self.quanted_layer(input_parallel)
53 | if self.bias is not None:
54 | output_parallel = output_parallel + self.bias
55 | if self.lora_a is not None:
56 | modification = self.lora_b(self.lora_a(input_))
57 | else:
58 | modification = None
59 |
60 | if self.gather_output:
61 | # All-gather across the partitions.
62 | output = gather_from_model_parallel_region(output_parallel)
63 | else:
64 | output = output_parallel
65 |
66 | if modification is not None:
67 | output = output + modification
68 | return output
69 |
70 | def forward_LoraRowParallelLinear(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
71 | # Set up backprop all-reduce.
72 | if self.input_is_parallel:
73 | input_parallel = input_
74 | else:
75 | input_parallel = scatter_to_model_parallel_region(input_)
76 | # Matrix multiply.
77 | output_parallel = self.quanted_layer(input_parallel)
78 | # All-reduce across all the partitions.
79 | output_ = reduce_from_model_parallel_region(output_parallel)
80 | if self.lora_a is not None:
81 | modification = self.lora_b(self.lora_a(input_parallel))
82 | output_ = output_ + modification
83 | if self.bias is not None:
84 | output = output_ + self.bias
85 | else:
86 | output = output_
87 | return output
88 |
89 | def forward_Linear(self, input: torch.Tensor) -> torch.Tensor:
90 | output = self.quanted_layer(input)
91 | if self.bias is not None:
92 | output += self.bias
93 | return output
94 |
95 | def quantize(model, quant_conf : BitsAndBytesConfig):
96 | module_list = [_ for _ in model.named_modules() if isinstance(_[1],
97 | (LoraColumnParallelLinear, LoraRowParallelLinear,
98 | ColumnParallelLinear, RowParallelLinear, torch.nn.Linear))]
99 |
100 | if hasattr(model, "get_quant_blocklist"):
101 | quant_blocklist = [x for x in model.get_quant_blocklist()]
102 | else:
103 | quant_blocklist = []
104 |
105 | for name, module in tqdm(module_list, desc="Quantization Process", mininterval=10):
106 | if "lora" in name or name in quant_blocklist:
107 | continue
108 | if isinstance(module, (
109 | LoraColumnParallelLinear,
110 | LoraRowParallelLinear,
111 | ColumnParallelLinear,
112 | RowParallelLinear,
113 | torch.nn.Linear
114 | )):
115 | # 1. Initialize quantization operator
116 | if quant_conf.load_in_4bit:
117 | quanted_layer = bnb.nn.Linear4bit(
118 | module.in_features,
119 | module.out_features,
120 | bias=None,
121 | compute_dtype=quant_conf.bnb_4bit_compute_dtype,
122 | compress_statistics=True,
123 | device=None)
124 | if quant_conf.bnb_4bit_compute_dtype is not None:
125 | quanted_layer.compute_type_is_set = True
126 |
127 | quanted_layer.weight = bnb.nn.Params4bit(
128 | module.weight.data.clone(),
129 | requires_grad=False,
130 | quant_type=quant_conf.bnb_4bit_quant_type,
131 | )
132 |
133 | elif quant_conf.load_in_8bit:
134 | quanted_layer= bnb.nn.Linear8bitLt(
135 | module.in_features,
136 | module.out_features,
137 | bias=None,
138 | has_fp16_weights=quant_conf.llm_int8_has_fp16_weight,
139 | threshold=quant_conf.llm_int8_threshold,
140 | )
141 | quanted_layer.weight = bnb.nn.Int8Params(
142 | module.weight.data.clone(),
143 | requires_grad=False,
144 | #has_fp16_weights=quant_conf.llm_int8_has_fp16_weight,
145 | )
146 | else:
147 | raise NotImplementedError('Please set load_in_4bit or load_in_8bit in the quantization config.')
148 |
149 | # 2. Convert FP layer to quantized layer
150 | module.quanted_layer = quanted_layer
151 |
152 | if isinstance(module, LoraColumnParallelLinear):
153 | forward_func = forward_LoraColumnParallelLinear
154 | elif isinstance(module, LoraRowParallelLinear):
155 | forward_func = forward_LoraRowParallelLinear
156 | elif isinstance(module, ColumnParallelLinear):
157 | forward_func = forward_ColumnParallelLinear
158 | elif isinstance(module, RowParallelLinear):
159 | forward_func = forward_RowParallelLinear
160 | elif isinstance(module, torch.nn.Linear):
161 | forward_func = forward_Linear
162 | module.forward = MethodType(forward_func, module)
163 |
164 | del module.weight
165 |
166 |
--------------------------------------------------------------------------------
/utils/quantization.py:
--------------------------------------------------------------------------------
1 | from transformers.utils.bitsandbytes import *
2 | from transformers import BitsAndBytesConfig
3 | import torch
4 | from torch import nn
5 | import bitsandbytes as bnb
6 |
7 | from fairscale.nn.model_parallel.layers import (
8 | ParallelEmbedding,
9 | RowParallelLinear,
10 | ColumnParallelLinear,
11 | )
12 | def _replace_with_bnb_linear(
13 | model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
14 | ):
15 | """
16 | Private method that wraps the recursion for module replacement.
17 |
18 | Returns the converted model and a boolean that indicates if the conversion has been successful or not.
19 | """
20 | for name, module in model.named_children():
21 | if current_key_name is None:
22 | current_key_name = []
23 | current_key_name.append(name)
24 |
25 | if (isinstance(module, nn.Linear) or isinstance(module, ColumnParallelLinear) or isinstance(module, RowParallelLinear) ) and name not in modules_to_not_convert:
26 | # Check if the current key is not in the `modules_to_not_convert`
27 | if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
28 | # with init_empty_weights():
29 | if quantization_config.quantization_method() == "llm_int8":
30 | model._modules[name] = bnb.nn.Linear8bitLt(
31 | module.in_features,
32 | module.out_features,
33 | module.bias is not None,
34 | has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
35 | threshold=quantization_config.llm_int8_threshold,
36 | )
37 | has_been_replaced = True
38 | else:
39 | if (
40 | quantization_config.llm_int8_skip_modules is not None
41 | and name in quantization_config.llm_int8_skip_modules
42 | ):
43 | pass
44 | else:
45 | model._modules[name] = bnb.nn.Linear4bit(
46 | module.in_features,
47 | module.out_features,
48 | module.bias is not None,
49 | quantization_config.bnb_4bit_compute_dtype,
50 | compress_statistics=quantization_config.bnb_4bit_use_double_quant,
51 | quant_type=quantization_config.bnb_4bit_quant_type,
52 | )
53 | has_been_replaced = True
54 | # Force requires grad to False to avoid unexpected errors
55 | model._modules[name].requires_grad_(False)
56 | if len(list(module.children())) > 0:
57 | _, has_been_replaced = _replace_with_bnb_linear(
58 | module,
59 | modules_to_not_convert,
60 | current_key_name,
61 | quantization_config,
62 | has_been_replaced=has_been_replaced,
63 | )
64 | # Remove the last key for recursion
65 | current_key_name.pop(-1)
66 | return model, has_been_replaced
67 |
68 |
69 | def quant_model_bnb(model, quant_bit='4bit', keep_in_fp32_modules=[],
70 | quantization_config=None):
71 | if quantization_config is None:
72 | # set default quantization config
73 | # compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
74 | quantization_config = BitsAndBytesConfig(
75 | load_in_4bit=quant_bit == '4bit',
76 | load_in_8bit=quant_bit == '8bit',
77 | llm_int8_threshold=6.0,
78 | llm_int8_has_fp16_weight=False,
79 | bnb_4bit_compute_dtype=torch.float16,
80 | bnb_4bit_use_double_quant=True,
81 | bnb_4bit_quant_type='nf4',
82 | )
83 |
84 | model,_ = _replace_with_bnb_linear(
85 | model, modules_to_not_convert=keep_in_fp32_modules, quantization_config=quantization_config
86 | )
87 |
88 | return model
89 |
90 |
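A minimal sketch of quant_model_bnb on a toy module (assumes bitsandbytes and a transformers version providing BitsAndBytesConfig.quantization_method are installed; the module names '0' and '1' come from nn.Sequential): the first linear is swapped for a 4-bit layer, the second is kept in full precision:

import torch.nn as nn
from utils.quantization import quant_model_bnb

toy = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 8))
# Keep the second linear ('1') in full precision, quantize the rest to NF4.
toy = quant_model_bnb(toy, quant_bit='4bit', keep_in_fp32_modules=['1'])
print(toy)  # Sequential(Linear4bit(...), Linear(...))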
--------------------------------------------------------------------------------
/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from torch.optim import Optimizer
5 | import math
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 |
9 | def create_scheduler(args, optimizer):
10 | lr_scheduler = None
11 | if args.sched == 'cosine':
12 | lr_scheduler = get_cosine_schedule_with_warmup(
13 | optimizer,
14 | num_warmup_steps=args.num_warmup_steps,
15 | num_training_steps=args.num_training_steps,
16 | num_cycles=0.5,
17 | min_lr_multi=args.min_lr_multi
18 | )
19 | return lr_scheduler
20 |
21 |
22 | def get_cosine_schedule_with_warmup(
23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
25 | ):
26 | """
27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
28 |
29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
31 | initial lr set in the optimizer.
32 | Args:
33 | optimizer ([`~torch.optim.Optimizer`]):
34 | The optimizer for which to schedule the learning rate.
35 | num_warmup_steps (`int`):
36 | The number of steps for the warmup phase.
37 | num_training_steps (`int`):
38 | The total number of training steps.
39 | num_cycles (`float`, *optional*, defaults to 0.5):
40 | The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
41 | following a half-cosine).
42 | min_lr_multi (`float`, *optional*, defaults to 0):
43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
44 | last_epoch (`int`, *optional*, defaults to -1):
45 | The index of the last epoch when resuming training.
46 | Return:
47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
48 | """
49 |
50 | def lr_lambda(current_step):
51 | if current_step < num_warmup_steps:
52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
55 |
56 | return LambdaLR(optimizer, lr_lambda, last_epoch)
57 |
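A minimal usage sketch of create_scheduler; the args namespace is hypothetical (in training it is config.scheduler, with num_warmup_steps/num_training_steps filled in by main), and it shows the LR warming up linearly before following a half-cosine down to lr * min_lr_multi:

from types import SimpleNamespace
import torch
from utils.scheduler import create_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
args = SimpleNamespace(sched="cosine", num_warmup_steps=10,
                       num_training_steps=100, min_lr_multi=0.01)
scheduler = create_scheduler(args, optimizer)
for step in range(100):
    optimizer.step()   # in the trainers this follows scaler.step(optimizer)
    scheduler.step()
    if step % 25 == 0:
        print(step, optimizer.param_groups[0]["lr"])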
--------------------------------------------------------------------------------