├── models
│   ├── __init__.py
│   ├── bert
│   │   ├── __init__.py
│   │   └── builder.py
│   ├── blip2
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── blip2.py
│   │   └── utils.py
│   └── utils.py
├── assets
│   ├── hawk.png
│   ├── dataset.jpg
│   ├── example1.jpg
│   ├── example4.jpg
│   └── performance.jpg
├── .gitignore
├── internvid_g
│   ├── scripts
│   │   ├── run_1121
│   │   │   ├── check_videos.sh
│   │   │   ├── blip2.sh
│   │   │   ├── text2tag.sh
│   │   │   ├── llama2.sh
│   │   │   ├── ground_data_construction.sh
│   │   │   ├── config_7b.json
│   │   │   ├── videochat.sh
│   │   │   └── clip_filter.sh
│   │   ├── extract_of.sh
│   │   ├── run_llama2.sh
│   │   └── run.sh
│   ├── code
│   │   ├── download_videos.py
│   │   └── ground_data_construction.py
│   └── README.md
├── scripts
│   ├── train
│   │   ├── run_7b_stage3.sh
│   │   ├── anetc.sh
│   │   ├── charades_sta.sh
│   │   ├── config_7b_stage3.py
│   │   ├── charades_sta.py
│   │   └── anetc.py
│   └── test
│       ├── videoqa.sh
│       └── recursive_grounding.sh
├── requirements.txt
├── configs
│   ├── config.json
│   ├── dataset_utils.py
│   └── instruction_data.py
├── data_preparing
│   ├── nextgqa.py
│   ├── internvid.py
│   ├── charades.py
│   ├── check_grounding_results.ipynb
│   ├── anetc.py
│   └── videoqa.py
├── dataset
│   ├── dataloader.py
│   ├── base_dataset.py
│   ├── pt_dataset.py
│   ├── video_utils.py
│   ├── utils.py
│   └── __init__.py
├── utils
│   ├── config_utils.py
│   ├── scheduler.py
│   ├── easydict.py
│   ├── distributed.py
│   ├── optimizer.py
│   ├── config.py
│   ├── logger.py
│   └── basic_utils.py
├── tasks
│   ├── shared_utils.py
│   └── train_it.py
└── README.md
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/bert/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/blip2/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/hawk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/hawk.png
--------------------------------------------------------------------------------
/assets/dataset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/dataset.jpg
--------------------------------------------------------------------------------
/assets/example1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/example1.jpg
--------------------------------------------------------------------------------
/assets/example4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/example4.jpg
--------------------------------------------------------------------------------
/assets/performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/performance.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | model/*
2 | data/*
3 | outputs/*
4 |
5 | .ipynb_checkpoints/
6 | __pycache__
7 |
8 | *.pyc
9 | *.pth
10 | *.log
11 | *.out
12 | test.json
13 | debug.py
14 | debug.ipynb
15 |
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/check_videos.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=internvid
3 | #SBATCH --qos=lv0b
4 | #SBATCH -p HGX
5 | #SBATCH --time=24:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:0
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | python code/check_videos.py
--------------------------------------------------------------------------------
/scripts/train/run_7b_stage3.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | NNODE=1
3 | NUM_GPUS=8
4 |
5 | OUTPUT_DIR=$1
6 | mkdir -vp $OUTPUT_DIR
7 |
8 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} --master_port=${MASTER_PORT} \
9 | tasks/train_it.py scripts/train/config_7b_stage3.py \
10 | output_dir ${OUTPUT_DIR} \
11 | freeze_dataset_folder ${OUTPUT_DIR}/training_data \
12 | > ${OUTPUT_DIR}/train.log 2>&1 &
13 |
14 | wait
15 |
--------------------------------------------------------------------------------
/scripts/train/anetc.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | NNODE=1
3 | NUM_GPUS=2
4 |
5 | OUTPUT_DIR=$1
6 | PRETRAINED_PATH=$2
7 | mkdir -vp ${OUTPUT_DIR}
8 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} --master_port=${MASTER_PORT} \
9 | tasks/train_it.py scripts/train/anetc.py \
10 | output_dir ${OUTPUT_DIR} \
11 | freeze_dataset_folder ${OUTPUT_DIR}/training_data \
12 | pretrained_path ${PRETRAINED_PATH} \
13 | > ${OUTPUT_DIR}/train.log 2>&1 &
14 | wait
--------------------------------------------------------------------------------
/scripts/train/charades_sta.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | NNODE=1
3 | NUM_GPUS=2
4 |
5 | OUTPUT_DIR=$1
6 | PRETRAINED_PATH=$2
7 | mkdir -vp ${OUTPUT_DIR}
8 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} --master_port=${MASTER_PORT} \
9 | tasks/train_it.py scripts/train/charades_sta.py \
10 | output_dir ${OUTPUT_DIR} \
11 | freeze_dataset_folder ${OUTPUT_DIR}/training_data \
12 | pretrained_path ${PRETRAINED_PATH} \
13 | > ${OUTPUT_DIR}/train.log 2>&1 &
14 | wait
15 |
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/blip2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=internvid
3 | #SBATCH --qos=lv0b
4 | #SBATCH -p HGX
5 | #SBATCH --time=24:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:4
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | for i in 0 1 2 3
12 | do
13 | CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func blip2 \
14 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
15 | --blip2-fname temp/1121/scene_captions_blip2.jsonl.${i} \
16 | > code/nohup/blip2.log.${i} 2>&1 &
17 | done
18 | wait
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/text2tag.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=internvid
3 | #SBATCH --qos=lv0b
4 | #SBATCH -p HGX
5 | #SBATCH --time=24:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:4
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | for i in 0 1 2 3
12 | do
13 | CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func tag2text \
14 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
15 | --tag2text-fname temp/1121/scene_captions_tag2text.jsonl.${i} \
16 | > code/nohup/tag2text.log.${i} 2>&1 &
17 | done
18 | wait
19 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | apex==0.1
3 | arrow==1.3.0
4 | av==10.0.0
5 | datasets==2.14.6
6 | decorator==5.1.1
7 | decord==0.6.0
8 | einops==0.6.1
9 | einops-exts==0.0.4
10 | flash-attn==2.3.6
11 | gradio==3.23.0
12 | gradio_client==0.2.6
13 | h5py==3.8.0
14 | huggingface-hub==0.19.4
15 | imageio==2.27.0
16 | jupyter==1.0.0
17 | matplotlib==3.7.1
18 | notebook==7.0.6
19 | opencv-python==4.7.0.72
20 | pandas==1.5.3
21 | PyYAML==6.0
22 | scipy==1.10.1
23 | spacy==3.7.2
24 | tensorboard==2.12.3
25 | transformers==4.31.0
26 | torch==1.13.1
27 | torchvision==0.14.1+cu116
28 | tqdm==4.64.1
29 | -e git+https://github.com/facebookresearch/xformers.git@e153e4b4f5d0d821d707696029d84faed11a92bf#egg=xformers
--------------------------------------------------------------------------------
/internvid_g/scripts/extract_of.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=of
3 | #SBATCH --qos=lv4
4 | #SBATCH --nodes=1
5 | #SBATCH --time=10:00:00
6 | #SBATCH --cpus-per-gpu=4
7 | #SBATCH --gres=gpu:1
8 | #SBATCH --output=./code/nohup/webvid_of.out
9 | #SBATCH --error=./code/nohup/webvid_of.out
10 |
11 | # extract optical flow
12 |
13 | cd /home/wangyuxuan1/codes/video_features
14 |
15 | python main.py \
16 | feature_type=raft \
17 | device="cuda:0" \
18 | on_extraction=save_numpy \
19 | batch_size=4 \
20 | side_size=224 \
21 | file_with_video_paths=/scratch2/nlp/wangyueqian/InternVid/code/scripts/webvid_of_test_path.txt \
22 | output_path=/scratch2/nlp/wangyueqian/InternVid/code/webvid_of \
23 | # extraction_fps=2
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/llama2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=internvid
3 | #SBATCH --qos=lv0b
4 | #SBATCH -p HGX
5 | #SBATCH --time=24:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:4
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | for i in 0 1 2 3
12 | do
13 | CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func llama2 \
14 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
15 | --blip2-fname temp/1121/scene_captions_blip2.jsonl.${i} \
16 | --tag2text-fname temp/1121/scene_captions_tag2text.jsonl.${i} \
17 | --llama2-fname temp/1121/scene_captions_llama2.jsonl.${i} \
18 | > code/nohup/llama2.log.${i} 2>&1 &
19 | done
20 | wait
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/ground_data_construction.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=internvid
3 | #SBATCH --qos=lv4
4 | #SBATCH -p HGX
5 | #SBATCH --time=10:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:0
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | for i in 0 1 2 3
12 | do
13 | python code/ground_data_construction.py \
14 | --video-base-folder videos-lowres \
15 | --caption-fname temp/1121/scene_captions_tag2text_high_sim.jsonl.$i \
16 | --scene-fname temp/1121/scenes_merged.jsonl.0$i \
17 | --scene-sim-fname temp/1121/scenes_merged_similarity.jsonl \
18 | --caption-with-neg-interval-fname temp/1121/scene_captions_tag2text_high_sim-with_neg.jsonl.$i &
19 | done
20 | wait
--------------------------------------------------------------------------------
/scripts/test/videoqa.sh:
--------------------------------------------------------------------------------
1 | # test on mvbench
2 | python tasks/test.py --tasks "MVBench" --config configs/config.json \
3 | --ckpt model/hawkeye.pth --data-dir data \
4 | --save-path outputs/MVBench.jsonl \
5 | > outputs/MVBench.log 2>&1 &
6 | wait
7 |
8 | # test on NExT-QA
9 | python tasks/test.py --tasks "NExTQA" --config configs/config.json \
10 | --ckpt model/hawkeye.pth --data-dir data \
11 | --save-path outputs/NExTQA.jsonl \
12 | > outputs/NExTQA.log 2>&1 &
13 | wait
14 |
15 |
16 | # test on TVQA
17 | python tasks/test.py --tasks "TVQA" --config configs/config.json \
18 | --ckpt model/hawkeye.pth --data-dir data \
19 | --save-path outputs/TVQA.jsonl \
20 | > outputs/TVQA.log 2>&1 &
21 | wait
22 |
23 |
24 | # test on STAR
25 | python tasks/test.py --tasks "STAR" --config configs/config.json \
26 | --ckpt model/hawkeye.pth --data-dir data \
27 | --save-path outputs/STAR.jsonl \
28 | > outputs/STAR.log 2>&1 &
29 | wait
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/config_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "vit_model": "eva_clip_g",
4 | "vit_model_path": "code/video_chat/model/eva_vit_g.pth",
5 | "q_former_model_path": "code/video_chat/model/blip2_pretrained_flant5xxl.pth",
6 | "llama_model_path": "code/video_chat/model/vicuna-7b-v0",
7 | "videochat_model_path": "code/video_chat/model/VideoChat/videochat_7b_stage1.pth",
8 | "img_size": 224,
9 | "num_query_token": 32,
10 | "drop_path_rate": 0.0,
11 | "use_grad_checkpoint": false,
12 | "vit_precision": "fp32",
13 | "freeze_vit": true,
14 | "freeze_mhra": false,
15 | "freeze_qformer": true,
16 | "low_resource": false,
17 | "max_txt_len": 320,
18 | "temporal_downsample": false,
19 | "no_lmhra": true,
20 | "double_lmhra": false,
21 | "lmhra_reduction": 2.0,
22 | "gmhra_layers": 8,
23 | "gmhra_drop_path_rate": 0.0,
24 | "gmhra_dropout": 0.5,
25 | "extra_num_query_token": 64
26 | },
27 | "device": "cuda"
28 | }
29 |
--------------------------------------------------------------------------------
/scripts/test/recursive_grounding.sh:
--------------------------------------------------------------------------------
1 | max_turns=4
2 |
3 | # test on charades-sta
4 | python tasks/test_recursive_grounding.py --max-turns ${max_turns} --config configs/config.json \
5 | --ckpt model/hawkeye.pth \
6 | --video-path data/videos/charades \
7 | --data-path data/test-anno/charades_sta-recursive_grounding.json \
8 | --save-path outputs/charades_sta-recursive_grounding-${max_turns}_turns.jsonl \
9 | > outputs/charades_sta-recursive_grounding-${max_turns}_turns.log 2>&1 &
10 | wait
11 |
12 | # test on anet-captions
13 | python tasks/test_recursive_grounding.py --max-turns ${max_turns} --config configs/config.json \
14 | --ckpt model/hawkeye.pth \
15 | --video-path data/videos/activitynet \
16 | --data-path data/test-anno/anetc-recursive_grounding.json \
17 | --save-path outputs/anetc-recursive_grounding-${max_turns}_turns.jsonl \
18 | > outputs/anetc-recursive_grounding-${max_turns}_turns.log 2>&1 &
19 | wait
20 |
21 | # test on nextgqa
22 | python tasks/test_recursive_grounding.py --max-turns ${max_turns} --config configs/config.json \
23 | --ckpt model/hawkeye.pth \
24 | --video-path data/videos/nextqa \
25 | --data-path data/test-anno/nextgqa-recursive_grounding.json \
26 | --save-path outputs/nextgqa-recursive_grounding-${max_turns}_turns.jsonl \
27 | > outputs/nextgqa-recursive_grounding-${max_turns}_turns.log 2>&1 &
28 | wait
29 |
--------------------------------------------------------------------------------
/configs/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "model_cls": "HawkEye_it",
4 | "bert_path": "bert-base-uncased",
5 | "llama_model_path": "model/vicuna-7b-v0",
6 | "freeze_vit": false,
7 | "freeze_qformer": false,
8 | "max_txt_len": 512,
9 | "low_resource": false,
10 | "vision_encoder": {
11 | "name": "vit_l14",
12 | "img_size": 224,
13 | "patch_size": 16,
14 | "d_model": 1024,
15 | "encoder_embed_dim": 1024,
16 | "encoder_depth": 24,
17 | "encoder_num_heads": 16,
18 | "drop_path_rate": 0.0,
19 | "num_frames": 32,
20 | "tubelet_size": 1,
21 | "use_checkpoint": false,
22 | "checkpoint_num": 0,
23 | "pretrained": "",
24 | "return_index": -2,
25 | "vit_add_ln": true,
26 | "ckpt_num_frame": 4
27 | },
28 | "num_query_token": 32,
29 | "qformer_hidden_dropout_prob": 0.1,
30 | "qformer_attention_probs_dropout_prob": 0.1,
31 | "qformer_drop_path_rate": 0.2,
32 | "extra_num_query_token": 64,
33 | "qformer_text_input": true,
34 | "system": "",
35 | "start_token": "",
36 | "end_token": "",
37 | "img_start_token": "",
38 | "img_end_token": "",
39 | "random_shuffle": true,
40 | "use_lora": true,
41 | "lora_r": 16,
42 | "lora_alpha": 32,
43 | "lora_dropout": 0.1
44 | },
45 | "device": "cuda"
46 | }
47 |
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/videochat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=i3-blip2-text2tag
3 | #SBATCH --qos=lv4
4 | #SBATCH -p HGX
5 | #SBATCH --time=10:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:2
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | # for i in 0 # 1 2 3
12 | # do
13 | # CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func videochat \
14 | # --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
15 | # --videochat-fname temp/1121/scene_captions_videochat.jsonl.${i} \
16 | # --videochat-config-fname code/scripts/run_1121/config_7b.json \
17 | # # > code/nohup/videochat.log.${i} 2>&1 &
18 | # done
19 | # wait
20 | # The commented-out program above annotated some data; the finetuned videochat tends to generate verbose captions that often contain hallucinations.
21 | # We also tried the pretrain-only videochat; its outputs show obvious webvid noise (it always adds time, place, etc.), so it is basically unusable as well.
22 |
23 | i=3
24 | CUDA_VISIBLE_DEVICES=0 python code/caption_clips.py --func blip2 \
25 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
26 | --blip2-fname temp/1121/scene_captions_blip2.jsonl.${i} \
27 | >> code/nohup/blip2.log.${i} 2>&1 &
28 |
29 | CUDA_VISIBLE_DEVICES=1 python code/caption_clips.py --func tag2text \
30 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
31 | --tag2text-fname temp/1121/scene_captions_tag2text.jsonl.${i} \
32 | >> code/nohup/tag2text.log.${i} 2>&1 &
33 | wait
--------------------------------------------------------------------------------
/internvid_g/scripts/run_1121/clip_filter.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=filter-internvid
3 | #SBATCH --qos=lv0b
4 | #SBATCH -p HGX
5 | #SBATCH --time=24:00:00
6 | #SBATCH --cpus-per-gpu=6
7 | #SBATCH --gres=gpu:4
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 |
12 | # compute similarity scores for the results after llama2 summarization
13 | # for i in 0 1 2 3
14 | # do
15 | # CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func filter \
16 | # --batch-size 32 \
17 | # --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
18 | # --llama2-fname temp/1121/scene_captions_llama2.jsonl.${i} \
19 | # --filtered-fname temp/1121/scene_captions_filtered.jsonl.${i} \
20 | # > code/nohup/clip_filter.log.${i} 2>&1 &
21 | # done
22 | # wait
23 |
24 | # compute similarity scores for each caption annotated by tag2text
25 | # for i in 0 1 2 3
26 | # do
27 | # CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func filter \
28 | # --batch-size 32 \
29 | # --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \
30 | # --filter-input-fname temp/1121/scene_captions_tag2text.jsonl.${i} --merge-method max \
31 | # --filtered-fname temp/1121/scene_captions_tag2text_clip_sim.jsonl.${i} \
32 | # > code/nohup/tag2text_clip_sim.log.${i} 2>&1 &
33 | # done
34 | # wait
35 |
36 | # select tag2text captions with clip_sim score above median
37 | for i in 0 1 2 3
38 | do
39 | python code/caption_clips.py --func select_filtered_tag2text_captions \
40 | --filter-input-fname temp/1121/scene_captions_tag2text_clip_sim.jsonl.$i \
41 | --filtered-fname temp/1121/scene_captions_tag2text_high_sim.jsonl.$i
42 | done
43 |
--------------------------------------------------------------------------------
/configs/dataset_utils.py:
--------------------------------------------------------------------------------
1 | # some logic for loading data
2 |
3 | import os
4 | import logging
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | class WebvidPathMapping:
9 | def __init__(self, video_ids_fname) -> None:
10 | self.video_ids = set()
11 | self.video_ids_fname = video_ids_fname
12 |
13 | def __call__(self, input_fname):
14 | if not self.video_ids:
15 | self.video_ids = set([line.strip() for line in open(self.video_ids_fname)])
16 | logger.info("In WebvidPathMapping, there are %d available videos listed in %s" % (len(self.video_ids), self.video_ids_fname))
17 | fname = input_fname.split("/")[-1].split('.')[0]
18 | if fname not in self.video_ids:
19 | return None
20 | return os.path.join(fname[:3], fname)
21 |
22 |
23 | def anet_path_mapping(input_fname):
24 | fname = input_fname.split("/")[-1].split('.')[0]
25 | return fname
26 |
27 |
28 | def clevrer_path_mapping(input_fname):
29 | '''
30 | video_02238.mp4 to video_02000-03000/video_02238.mp4
31 | '''
32 | interval = int(input_fname.split('.')[0].split('_')[-1]) // 1000
33 | folder = 'video_%05d-%05d' % (interval * 1000, interval * 1000 + 1000)
34 | return os.path.join(folder, input_fname)
35 |
36 |
37 | webvid_path_mapping = WebvidPathMapping('data/WebVid/video_ids.txt')
38 |
39 | VIDEO_PATH_MAPPING = {
40 | 'caption_webvid': webvid_path_mapping,
41 | 'caption_videochat': webvid_path_mapping,
42 | 'conversation_videochat1': webvid_path_mapping,
43 | 'vqa_webvid_qa': webvid_path_mapping,
44 | 'conversation_videochatgpt': anet_path_mapping,
45 | 'reasoning_clevrer_qa': clevrer_path_mapping,
46 | 'reasoning_clevrer_mc': clevrer_path_mapping,
47 | }
--------------------------------------------------------------------------------
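Note (illustrative, not a file from this repo): a minimal sketch of how the VIDEO_PATH_MAPPING table above could be consumed when resolving annotation filenames to on-disk paths. The helper name resolve_video_path and the video_root argument are assumptions for illustration; a WebvidPathMapping result of None signals that the video is unavailable and the example should be skipped.

import os
from configs.dataset_utils import VIDEO_PATH_MAPPING

def resolve_video_path(dataset_name, video_fname, video_root):
    # datasets without a registered mapping keep the annotation filename as-is
    mapping = VIDEO_PATH_MAPPING.get(dataset_name)
    rel_path = mapping(video_fname) if mapping is not None else video_fname
    if rel_path is None:  # video not listed in WebVid's video_ids.txt -> skip this example
        return None
    return os.path.join(video_root, rel_path)

# e.g. resolve_video_path('reasoning_clevrer_qa', 'video_02238.mp4', 'data/CLEVRER')
# -> 'data/CLEVRER/video_02000-03000/video_02238.mp4'
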
/data_preparing/nextgqa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import pandas as pd
4 |
5 |
6 | if __name__ == '__main__':
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--func', type=str, default='grounding_qa')
9 | parser.add_argument('--annotation-fname', type=str, default='data/NExTQA/nextgqa/test.csv')
10 | parser.add_argument('--grounding-fname', type=str, default='data/NExTQA/nextgqa/gsub_test.json')
11 | parser.add_argument('--pred-span-fname', type=str)
12 | parser.add_argument('--video-mapping-fname', type=str, default='data/NExTQA/nextgqa/map_vid_vidorID.json')
13 | parser.add_argument('--output-fname', type=str, default='data/MVBench/json/nextgqa.json')
14 | args = parser.parse_args()
15 | print(args)
16 |
17 | video_mapping_dict = json.load(open(args.video_mapping_fname))
18 | video_mapping_dict = {key: val + '.mp4' for key, val in video_mapping_dict.items()}
19 | df = pd.read_csv(args.annotation_fname)
20 |
21 | grounding_dict, video_lengths = dict(), dict()
22 | for video_key, data in json.load(open(args.grounding_fname)).items():
23 | video_fname = video_mapping_dict[video_key]
24 | grounding_dict[video_fname] = dict()
25 | video_lengths[video_fname] = data['duration']
26 | for question_key, spans in data['location'].items():
27 | grounding_dict[video_fname][question_key] = spans
28 |
29 | if args.func in ["test_grounding"]:
30 | res_list = list()
31 | for line_i, row in df.iterrows():
32 | video_fname = video_mapping_dict[str(row['video_id'])]
33 | res_dict = {'video': video_fname, 'question': row['question'] + '?', 'duration': video_lengths[video_fname], 'qid': row['qid'], 'answer': grounding_dict[video_fname][str(row['qid'])]}
34 | res_list.append(res_dict)
35 | json.dump(res_list, open(args.output_fname, 'w'))
--------------------------------------------------------------------------------
/models/blip2/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 |
5 |
6 | from .Qformer import BertConfig, BertLMHeadModel
7 | from models.utils import load_temp_embed_with_mismatch
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def build_qformer(num_query_token, vision_width,
13 | qformer_hidden_dropout_prob=0.1,
14 | qformer_attention_probs_dropout_prob=0.1,
15 | drop_path_rate=0.,
16 | ):
17 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)
18 | encoder_config.encoder_width = vision_width
19 | # insert cross-attention layer every other block
20 | encoder_config.add_cross_attention = True
21 | encoder_config.cross_attention_freq = 2
22 | encoder_config.query_length = num_query_token
23 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
24 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
25 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_config.num_hidden_layers)]
26 | logger.info(f"Drop_path:{encoder_config.drop_path_list}")
27 | logger.info(encoder_config)
28 | Qformer = BertLMHeadModel.from_pretrained(
29 | "bert-base-uncased", config=encoder_config, local_files_only=True
30 | )
31 | query_tokens = nn.Parameter(
32 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
33 | )
34 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
35 | return Qformer, query_tokens
36 |
37 | def interpolate_pos_embed_blip(state_dict, new_model):
38 | if "vision_temp_embed" in state_dict:
39 | vision_temp_embed_new = new_model.state_dict()["vision_temp_embed"]
40 | state_dict["vision_temp_embed"] = load_temp_embed_with_mismatch(
41 | state_dict["vision_temp_embed"], vision_temp_embed_new, add_zero=False
42 | )
43 | return state_dict
44 |
--------------------------------------------------------------------------------
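A rough usage sketch of build_qformer (an assumption, not code from this repo). It requires a local copy of bert-base-uncased, since local_files_only=True is passed above.

from models.blip2.builder import build_qformer

# 32 learnable query tokens, cross-attending every other layer to 1024-d vision features
qformer, query_tokens = build_qformer(num_query_token=32, vision_width=1024)
print(query_tokens.shape)  # torch.Size([1, 32, 768]) with the bert-base hidden size
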
/internvid_g/scripts/run_llama2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=internvid
3 | #SBATCH --qos=lv4
4 | #SBATCH -p HGX
5 | #SBATCH --time=10:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:1
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | # for i in 0 1 2 3 4
12 | # do
13 | # python -u code/clip_sim.py --func split_scene \
14 | # --video-sample-fname temp/1121/video_ids.txt \
15 | # --scene-fname temp/1121/scenes.jsonl.$i \
16 | # --start-idx $((i*1800)) --end-idx $((i*1800+1800)) \
17 | # > ./code/nohup/split_scene.log.$i 2>&1 &
18 | # done
19 | # wait
20 |
21 | # cat temp/1121/scenes.jsonl.* > temp/1121/scenes.jsonl
22 | # GPUS=(0 0 1 1 2 2)
23 | # for i in 0 1 2 3 4 5
24 | # do
25 | # CUDA_VISIBLE_DEVICES=${GPUS[i]} \
26 | # python code/clip_sim.py --func scene_sim \
27 | # --scene-fname temp/1121/scenes.jsonl \
28 | # --scene-sim-fname temp/1121/scenes_similarity.jsonl.$i \
29 | # --start-idx $((i*1500)) --end-idx $((i*1500+1500)) \
30 | # > ./code/nohup/scene_sim.log.$i 2>&1 &
31 | # done
32 | # wait
33 |
34 | # cat scenes_similarity.jsonl.* > scenes_similarity.jsonl
35 | # python code/ground_data_construction.py \
36 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \
37 | # --caption-with-neg-interval-fname temp/1121/caption_with_neg_interval.jsonl
38 |
39 |
40 | # merge scene
41 | # python code/clip_sim.py --func merge_scene \
42 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \
43 | # --scene-merged-fname temp/1121/scenes_merged.jsonl --scene-merged-sim-fname temp/1121/scenes_merged_similarity.jsonl \
44 |
45 | # test scene captioning with blip2
46 | CUDA_VISIBLE_DEVICES=0 \
47 | python code/caption_clips.py --func llama2 \
48 | --blip2-fname temp/scene_captions_blip2.jsonl \
49 | --tag2text-fname temp/scene_captions_tag2text.jsonl \
50 | --llama2-fname temp/scene_captions_llama2.jsonl
51 |
--------------------------------------------------------------------------------
/internvid_g/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=internvid
3 | #SBATCH --qos=lv4
4 | #SBATCH -p HGX
5 | #SBATCH --time=10:00:00
6 | #SBATCH --cpus-per-task=6
7 | #SBATCH --gres=gpu:1
8 | #SBATCH --output=./code/nohup/internvid.log
9 | #SBATCH --error=./code/nohup/internvid.log
10 |
11 | # for i in 0 1 2 3 4
12 | # do
13 | # python -u code/clip_sim.py --func split_scene \
14 | # --video-sample-fname temp/1121/video_ids.txt \
15 | # --scene-fname temp/1121/scenes.jsonl.$i \
16 | # --start-idx $((i*1800)) --end-idx $((i*1800+1800)) \
17 | # > ./code/nohup/split_scene.log.$i 2>&1 &
18 | # done
19 | # wait
20 |
21 | # cat temp/1121/scenes.jsonl.* > temp/1121/scenes.jsonl
22 | # GPUS=(0 0 1 1 2 2)
23 | # for i in 0 1 2 3 4 5
24 | # do
25 | # CUDA_VISIBLE_DEVICES=${GPUS[i]} \
26 | # python code/clip_sim.py --func scene_sim \
27 | # --scene-fname temp/1121/scenes.jsonl \
28 | # --scene-sim-fname temp/1121/scenes_similarity.jsonl.$i \
29 | # --start-idx $((i*1500)) --end-idx $((i*1500+1500)) \
30 | # > ./code/nohup/scene_sim.log.$i 2>&1 &
31 | # done
32 | # wait
33 |
34 | # cat scenes_similarity.jsonl.* > scenes_similarity.jsonl
35 | # python code/ground_data_construction.py \
36 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \
37 | # --caption-with-neg-interval-fname temp/1121/caption_with_neg_interval.jsonl
38 |
39 |
40 | # merge scene
41 | # python code/clip_sim.py --func merge_scene \
42 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \
43 | # --scene-merged-fname temp/1121/scenes_merged.jsonl --scene-merged-sim-fname temp/1121/scenes_merged_similarity.jsonl \
44 |
45 | # scene captioning with blip2
46 | # CUDA_VISIBLE_DEVICES=0 \
47 | # python code/caption_clips.py --func blip2 \
48 | # --blip2-fname temp/scene_captions_blip2.jsonl
49 |
50 | # filter similarity, though the summary now does not include blip2
51 | # CUDA_VISIBLE_DEVICES=0 \
52 | # python code/caption_clips.py --func filter \
53 | # --llama2-fname temp/scene_captions_llama2.jsonl \
54 | # --filtered-fname temp/scene_captions_blip_filtered.jsonl \
55 | # > code/nohup/scene_captions_filter.log 2>&1 &
56 |
57 |
--------------------------------------------------------------------------------
/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process
4 | import random
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class MetaLoader(object):
11 | """ wraps multiple data loaders """
12 | def __init__(self, name2loader):
13 | """Iterates over multiple dataloaders; it ensures all processes
14 | work on data from the same dataloader at each step. This loader ends when
15 | the shortest dataloader raises a StopIteration exception.
16 |
17 | loaders: Dict, {name: dataloader}
18 | """
19 | self.name2loader = name2loader
20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()}
21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
22 | index2name = {v: k for k, v in name2index.items()}
23 |
24 | iter_order = []
25 | for n, l in name2loader.items():
26 | iter_order.extend([name2index[n]]*len(l))
27 |
28 | random.shuffle(iter_order)
29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
30 |
31 | # sync
32 | if is_dist_avail_and_initialized():
33 | # make sure all processes have the same order so that
34 | # each step they will have data from the same loader
35 | dist.broadcast(iter_order, src=0)
36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
37 |
38 | logger.info(str(self))
39 |
40 | def __str__(self):
41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
42 | for idx, (name, loader) in enumerate(self.name2loader.items()):
43 | output.append(
44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
45 | )
46 | return "\n".join(output)
47 |
48 | def __len__(self):
49 | return len(self.iter_order)
50 |
51 | def __iter__(self):
52 | """ iterate over all dataloaders once, following the pre-computed random order """
53 | for name in self.iter_order:
54 | _iter = self.name2iter[name]
55 | batch = next(_iter)
56 | yield name, batch
57 |
--------------------------------------------------------------------------------
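A hypothetical usage sketch of MetaLoader (not part of the repo): two toy dataloaders are interleaved in a random but length-proportional order, so the total number of steps is the sum of their lengths. A CUDA device is required, since the shuffled order is placed on GPU for dist.broadcast.

import torch
from torch.utils.data import DataLoader, TensorDataset
from dataset.dataloader import MetaLoader

loader_a = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)  # 4 batches
loader_b = DataLoader(TensorDataset(torch.arange(6)), batch_size=3)  # 2 batches
meta = MetaLoader({"caption": loader_a, "grounding": loader_b})      # 6 steps in total
for name, batch in meta:
    print(name, batch)
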
/models/bert/builder.py:
--------------------------------------------------------------------------------
1 | from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
2 |
3 | import logging
4 | logger = logging.getLogger(__name__)
5 |
6 | def build_bert(model_config, pretrain, checkpoint):
7 | """build text encoder.
8 |
9 | Args:
10 | model_config (dict): model config.
11 | pretrain (bool): Whether to do pretrain or finetuning.
12 | checkpoint (bool): whether to do gradient_checkpointing.
13 |
14 | Returns: the constructed BERT text encoder (nn.Module).
15 |
16 | """
17 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
18 | bert_config.encoder_width = model_config.vision_encoder.d_model
19 | bert_config.gradient_checkpointing = checkpoint
20 | bert_config.fusion_layer = model_config.text_encoder.fusion_layer
21 |
22 | if not model_config.multimodal.enable:
23 | bert_config.fusion_layer = bert_config.num_hidden_layers
24 |
25 | if pretrain:
26 | text_encoder, loading_info = BertForMaskedLM.from_pretrained(
27 | model_config.text_encoder.pretrained,
28 | config=bert_config,
29 | output_loading_info=True,
30 | )
31 | else:
32 | text_encoder, loading_info = BertModel.from_pretrained(
33 | model_config.text_encoder.pretrained,
34 | config=bert_config,
35 | add_pooling_layer=False,
36 | output_loading_info=True,
37 | )
38 |
39 | return text_encoder
40 |
41 |
42 | def build_bert_decoder(model_config, checkpoint):
43 | """build text decoder the same as the multimodal encoder.
44 |
45 | Args:
46 | model_config (dict): model config.
47 | pretrain (bool): Whether to do pretrain or finetuning.
48 | checkpoint (bool): whether to do gradient_checkpointing.
49 |
50 | Returns: the constructed BERT text decoder (nn.Module).
51 |
52 | """
53 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
54 | bert_config.encoder_width = model_config.vision_encoder.d_model
55 | bert_config.gradient_checkpointing = checkpoint
56 |
57 | bert_config.fusion_layer = 0
58 | bert_config.num_hidden_layers = (
59 | bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
60 | )
61 |
62 | text_decoder, loading_info = BertLMHeadModel.from_pretrained(
63 | model_config.text_encoder.pretrained,
64 | config=bert_config,
65 | output_loading_info=True,
66 | )
67 |
68 | return text_decoder
69 |
--------------------------------------------------------------------------------
/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from os.path import dirname, join
5 |
6 | from utils.config import Config
7 | from utils.distributed import init_distributed_mode, is_main_process
8 | from utils.logger import setup_logger
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def setup_config():
14 | """Combine yaml config and command line config with OmegaConf.
15 | Also converts types, e.g., `'None'` (str) --> `None` (None)
16 | """
17 | config = Config.get_config()
18 | if config.debug:
19 | config.wandb.enable = False
20 | return config
21 |
22 |
23 | def setup_evaluate_config(config):
24 | """setup evaluation default settings, e.g., disable wandb"""
25 | assert config.evaluate
26 | config.wandb.enable = False
27 | if config.output_dir is None:
28 | config.output_dir = join(dirname(config.pretrained_path), "eval")
29 | return config
30 |
31 |
32 | def setup_output_dir(output_dir, excludes=["code"]):
33 | """ensure we are not overwriting an existing/non-empty output dir"""
34 | if not os.path.exists(output_dir):
35 | os.makedirs(output_dir, exist_ok=False)
36 | else:
37 | existing_dirs_files = os.listdir(output_dir) # list
38 | remaining = set(existing_dirs_files) - set(excludes)
39 | remaining = [e for e in remaining if "slurm" not in e]
40 | remaining = [e for e in remaining if ".out" not in e]
41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
42 | logger.warn(f"remaining dirs or files: {remaining}")
43 |
44 |
45 | def setup_main():
46 | """
47 | Setup config, logger, output_dir, etc.
48 | Shared for pretrain and all downstream tasks.
49 | """
50 | config = setup_config()
51 |
52 | config.model.num_frame_tokens = 0
53 | for dataset_cfg in config.train_file:
54 | if dataset_cfg.get("grounding_method", None) == "frame_token":
55 | logger.info("dataset %s is using frame_token as grounding method, adding %d special frame tokens to the model" % (dataset_cfg['dataset_name'], config.num_frames))
56 | config.model.num_frame_tokens = config.num_frames
57 |
58 | if hasattr(config, "evaluate") and config.evaluate:
59 | config = setup_evaluate_config(config)
60 | init_distributed_mode(config)
61 |
62 | if is_main_process():
63 | setup_output_dir(config.output_dir, excludes=["code"])
64 | setup_logger(output=config.output_dir, color=True, name="vindlu")
65 | logger.info(f"config: {Config.pretty_text(config)}")
66 | Config.dump(config, os.path.join(config.output_dir, "config.json"))
67 | return config
68 |
--------------------------------------------------------------------------------
/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from torch.optim import Optimizer
5 | import math
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 |
9 | def create_scheduler(args, optimizer):
10 | lr_scheduler = None
11 | if args.sched == 'cosine':
12 | lr_scheduler = get_cosine_schedule_with_warmup(
13 | optimizer,
14 | num_warmup_steps=args.num_warmup_steps,
15 | num_training_steps=args.num_training_steps,
16 | num_cycles=0.5,
17 | min_lr_multi=args.min_lr_multi
18 | )
19 | return lr_scheduler
20 |
21 |
22 | def get_cosine_schedule_with_warmup(
23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
25 | ):
26 | """
27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
28 |
29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
31 | initial lr set in the optimizer.
32 | Args:
33 | optimizer ([`~torch.optim.Optimizer`]):
34 | The optimizer for which to schedule the learning rate.
35 | num_warmup_steps (`int`):
36 | The number of steps for the warmup phase.
37 | num_training_steps (`int`):
38 | The total number of training steps.
39 | num_cycles (`float`, *optional*, defaults to 0.5):
40 | The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
41 | following a half-cosine).
42 | min_lr_multi (`float`, *optional*, defaults to 0):
43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
44 | last_epoch (`int`, *optional*, defaults to -1):
45 | The index of the last epoch when resuming training.
46 | Return:
47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
48 | """
49 |
50 | def lr_lambda(current_step):
51 | if current_step < num_warmup_steps:
52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
55 |
56 | return LambdaLR(optimizer, lr_lambda, last_epoch)
57 |
--------------------------------------------------------------------------------
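For intuition, a small standalone check (illustrative, mirroring the lr_lambda above) of the warmup-plus-cosine multiplier: linear warmup to 1.0 over num_warmup_steps, then a half-cosine decay floored at min_lr_multi.

import math

def lr_multiplier(step, num_warmup_steps=10, num_training_steps=100,
                  num_cycles=0.5, min_lr_multi=0.01):
    if step < num_warmup_steps:
        return max(min_lr_multi, step / max(1, num_warmup_steps))
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

print([round(lr_multiplier(s), 3) for s in (0, 5, 10, 55, 100)])
# -> [0.01, 0.5, 1.0, 0.5, 0.01]
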
/data_preparing/internvid.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import random
5 | import argparse
6 |
7 |
8 | if __name__ == '__main__':
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--func', type=str, default='grounding', choices=['grounding', 'caption', 'choice', 'choice_caption'])
11 | parser.add_argument('--annotation-fname', type=str, default='data/InternVid-G/train.jsonl')
12 | parser.add_argument('--instruction-fname', type=str, default='data/VideoChat2-IT/video/temporal/internvid_grounding/instructions.json')
13 | parser.add_argument('--question', type=str, default='data/VideoChat2-IT/video/temporal/internvid_grounding/questions.json')
14 | parser.add_argument('--output-fname', type=str, default='data/VideoChat2-IT/video/temporal/internvid_grounding/train.json')
15 | parser.add_argument('--time-span-sent', type=str, default='From second %.1f to second %.1f.') # used for frame-level / second-level rep.
16 | args = parser.parse_args()
17 | print(args)
18 |
19 | res = list()
20 |
21 | instructions = json.load(open(args.instruction_fname))
22 | input_fnames = glob.glob(args.annotation_fname)
23 |
24 | if os.path.exists(args.question):
25 | args.question = json.load(open(args.question))
26 | else:
27 | args.question = [args.question]
28 |
29 | for input_fname in input_fnames:
30 | print('loading data from', input_fname)
31 | for line in open(input_fname).readlines():
32 | example = json.loads(line)
33 |
34 | example['caption'] = example['caption'][0].upper() + example['caption'][1:]
35 | if example['caption'].endswith('.'):
36 | example['caption'] = example['caption'][:-1]
37 |
38 | new_example = {k: example[k] for k in ['start_sec', 'end_sec', 'neg_start_sec', 'neg_end_sec']}
39 | new_example['i'] = random.choice(instructions)
40 | if args.func == 'grounding':
41 | new_example['q'] = random.choice(args.question) % example['caption']
42 | new_example['a'] = args.time_span_sent
43 | elif args.func == 'caption':
44 | new_example['a'] = example['caption']
45 | new_example['q'] = args.time_span_sent
46 | elif args.func == 'choice':
47 | options = ["In the middle of the video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."]
48 | random.shuffle(options)
49 | options = ["\n(%s) %s" % ("ABCD"[i], opt) for i, opt in enumerate(options)]
50 | new_example['q'] = "Question: " + random.choice(args.question) % example['caption'] + "\nOptions:" + "".join(options)
51 | new_example['a'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options]
52 | elif args.func == 'choice_caption':
53 | options = ["In the middle of the video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."]
54 | new_example['q'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options]
55 | new_example['a'] = example['caption']
56 | res.append({'video': example['video'], 'QA': [new_example]})
57 |
58 | with open(args.output_fname, 'w') as f_out:
59 | json.dump(res, f_out)
60 |
--------------------------------------------------------------------------------
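To make the 'choice' construction above concrete, here is a standalone re-run of the same option-building logic (illustrative): the four position options are shuffled, lettered A-D, and the answer keeps (keyword, option) pairs so the correct position can be matched later.

import random

options = ["In the middle of the video.", "At the end of the video.",
           "Throughout the entire video.", "At the beginning of the video."]
random.shuffle(options)
lettered = ["\n(%s) %s" % ("ABCD"[i], opt) for i, opt in enumerate(options)]
answer = [([k for k in ["middle", "end", "throughout", "beginning"] if k in opt.lower()][0], opt.strip())
          for opt in lettered]
print("Options:" + "".join(lettered))
print(answer)  # e.g. [('end', '(A) At the end of the video.'), ...]
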
/internvid_g/code/download_videos.py:
--------------------------------------------------------------------------------
1 | # download videos from youtube
2 |
3 | import time
4 | import os
5 | import json
6 | from multiprocessing import Pool
7 | import argparse
8 |
9 |
10 | def download_video(command):
11 | # time.sleep(2)
12 | video_id, target_dir = command
13 | # download videos with resolution nearest to 720p
14 | to_execute = f'yt-dlp "http://youtu.be/{video_id}" -o "{target_dir}/{video_id}.%(ext)s" -S "+res:720,fps" -q --no-check-certificate'
15 | res_code = os.system(to_execute)
16 | if res_code:
17 | error_video_fout.write(video_id + '\n') # writing from multiple workers is not strictly safe, but acceptable here
18 | error_video_fout.flush()
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--old_dirs", type=str, nargs="+", default=["videos"], help="folders with videos already downloaded, so we don't need to download them again")
24 | parser.add_argument("--target_dir", type=str, default='videos', help="folder to save downloaded videos")
25 | parser.add_argument("--src_file", type=str, default='temp/video_ids.txt', help="file that contains video ids to download")
26 | parser.add_argument("--num_workers", type=int, default=5)
27 | parser.add_argument("--error_file", type=str, default='temp/error_video_ids.txt')
28 | args = parser.parse_args()
29 |
30 | video_ids = set()
31 | if args.src_file.endswith('jsonl'): # download from InternVid annotations or InternVid-G annotations
32 | for line in open(args.src_file):
33 | example = json.loads(line)
34 | if 'YoutubeID' in example:
35 | video_ids.add(example['YoutubeID'])
36 | if 'video_fname' in example:
37 | video_ids.add(example['video_fname'].split('.')[0])
38 |
39 | elif args.src_file.endswith('json'): # download from training data json where each entry has a 'video' field
40 | data = json.load(open(args.src_file))
41 | for example in data:
42 | video_ids.add(example['video'][:11]) # the first 11 characters are the youtube id
43 |
44 | elif args.src_file.endswith('txt'): # download from a plain text file of video ids
45 | for line in open(args.src_file):
46 | video_ids.add(line.strip())
47 |
48 | else: # change annotation loading code as you wish
49 | raise NotImplementedError
50 |
51 | error_video_ids = set([line.strip() for line in open(args.error_file)]) if os.path.exists(args.error_file) else set()
52 | videos_to_download = video_ids - error_video_ids
53 |
54 | for dir in args.old_dirs + [args.target_dir]:
55 | if not os.path.exists(dir):
56 | print('path does not exist, skipping video check for this folder:', dir)
57 | continue
58 | videos_downloaded = set(['.'.join(fname.split('.')[:-1]) for fname in os.listdir(dir)])
59 | videos_to_download = videos_to_download - videos_downloaded
60 |
61 | videos_to_download = list(videos_to_download)
62 | videos_to_download.sort()
63 |
64 | print(f'{len(video_ids)} videos in total')
65 | print(f'{len(error_video_ids)} videos error')
66 | print(f'{len(videos_downloaded)} videos already downloaded')
67 | print(f'{len(videos_to_download)} videos to download')
68 |
69 | error_video_fout = open(args.error_file, 'a')
70 | commands = [(video_id, args.target_dir) for video_id in videos_to_download]
71 | with Pool(args.num_workers) as p:
72 | p.map(download_video, commands)
73 |
--------------------------------------------------------------------------------
/data_preparing/charades.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import argparse
5 |
6 |
7 | def load_data(fname):
8 | data_list = list()
9 | for line in open(fname):
10 | info, sent = line.split('##')
11 | vid, start, end = info.split(' ')
12 | start, end = float(start), float(end)
13 | data_list.append({'video': vid + '.mp4', 'start': start, 'end': end, 'caption': sent.strip()})
14 | return data_list
15 |
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--func', type=str, default='test_grounding')
20 | parser.add_argument('--video-lengths-fname', type=str, default='data/Charades-STA/video_lengths.json')
21 | parser.add_argument('--annotation-fname', type=str, default='data/Charades-STA/charades_sta_train.txt')
22 | parser.add_argument('--instruction-fname', type=str, default='data/VideoChat2-IT/temporal/charades_sta_grounding-choice/instructions.json')
23 | parser.add_argument('--question', type=str, default='data/VideoChat2-IT/temporal/charades_sta_grounding-choice/questions.json')
24 | parser.add_argument('--output-fname', type=str, default='data/VideoChat2-IT/video/temporal/charades_sta_grounding-choice/train.json')
25 | parser.add_argument('--time-span-sent', type=str, default='From frame %d to frame %d.')
26 | args = parser.parse_args()
27 | print(args)
28 |
29 | anno = load_data(args.annotation_fname)
30 | if os.path.exists(args.question):
31 | args.question = json.load(open(args.question))
32 | else:
33 | args.question = [args.question]
34 |
35 | if args.func in ['grounding', 'caption', 'choice']:
36 | video_lengths = json.load(open(args.video_lengths_fname))
37 | instructions = json.load(open(args.instruction_fname))
38 | res = list()
39 | for example in anno:
40 | sent = example['caption'].replace('.', '')
41 | new_example = {'i': random.choice(instructions), 'start_sec': example['start'], 'end_sec': example['end'], 'neg_start_sec': 0, 'neg_end_sec': video_lengths[example['video'][:5]]}
42 | if args.func == 'grounding':
43 | new_example['q'] = random.choice(args.question) % sent
44 | new_example['a'] = args.time_span_sent
45 | elif args.func == 'caption':
46 | new_example['a'] = sent
47 | new_example['q'] = args.time_span_sent
48 | elif args.func == 'choice':
49 | options = ["In the middle of the video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."]
50 | random.shuffle(options)
51 | options = ["\n(%s) %s" % ("ABCD"[i], opt) for i, opt in enumerate(options)]
52 | new_example['q'] = "Question: " + random.choice(args.question) % sent + "\nOptions:" + "".join(options)
53 | new_example['a'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options]
54 | res.append({'video': example['video'], 'QA': [new_example]})
55 |
56 | if args.func in ["test_grounding"]:
57 | res = list()
58 | for example in anno:
59 | sent = example['caption'].replace('.', '')
60 | # new_example = {'video': example['video'], 'question': random.choice(args.question) % sent, 'answer': "%.1f-%.1f" % (example['start'], example['end'])}  # unused: overwritten by the next line
61 | new_example = {'video': example['video'], 'question': sent, 'answer': [example['start'], example['end']]}
62 | res.append(new_example)
63 |
64 | with open(args.output_fname, 'w') as f_out:
65 | json.dump(res, f_out)
66 |
--------------------------------------------------------------------------------
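For reference, an illustrative parse of one annotation line in the Charades-STA format consumed by load_data above ("<video_id> <start_sec> <end_sec>##<sentence>"); the example line is only for illustration.

line = "AO8RW 0.0 6.9##a person is putting a book on a shelf."
info, sent = line.split('##')
vid, start, end = info.split(' ')
print({'video': vid + '.mp4', 'start': float(start), 'end': float(end), 'caption': sent.strip()})
# {'video': 'AO8RW.mp4', 'start': 0.0, 'end': 6.9, 'caption': 'a person is putting a book on a shelf.'}
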
/utils/easydict.py:
--------------------------------------------------------------------------------
1 | class EasyDict(dict):
2 | """
3 | Get attributes
4 |
5 | >>> d = EasyDict({'foo':3})
6 | >>> d['foo']
7 | 3
8 | >>> d.foo
9 | 3
10 | >>> d.bar
11 | Traceback (most recent call last):
12 | ...
13 | AttributeError: 'EasyDict' object has no attribute 'bar'
14 |
15 | Works recursively
16 |
17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
18 | >>> isinstance(d.bar, dict)
19 | True
20 | >>> d.bar.x
21 | 1
22 |
23 | Bullet-proof
24 |
25 | >>> EasyDict({})
26 | {}
27 | >>> EasyDict(d={})
28 | {}
29 | >>> EasyDict(None)
30 | {}
31 | >>> d = {'a': 1}
32 | >>> EasyDict(**d)
33 | {'a': 1}
34 |
35 | Set attributes
36 |
37 | >>> d = EasyDict()
38 | >>> d.foo = 3
39 | >>> d.foo
40 | 3
41 | >>> d.bar = {'prop': 'value'}
42 | >>> d.bar.prop
43 | 'value'
44 | >>> d
45 | {'foo': 3, 'bar': {'prop': 'value'}}
46 | >>> d.bar.prop = 'newer'
47 | >>> d.bar.prop
48 | 'newer'
49 |
50 |
51 | Values extraction
52 |
53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
54 | >>> isinstance(d.bar, list)
55 | True
56 | >>> from operator import attrgetter
57 | >>> map(attrgetter('x'), d.bar)
58 | [1, 3]
59 | >>> map(attrgetter('y'), d.bar)
60 | [2, 4]
61 | >>> d = EasyDict()
62 | >>> d.keys()
63 | []
64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
65 | >>> d.foo
66 | 3
67 | >>> d.bar.x
68 | 1
69 |
70 | Still like a dict though
71 |
72 | >>> o = EasyDict({'clean':True})
73 | >>> o.items()
74 | [('clean', True)]
75 |
76 | And like a class
77 |
78 | >>> class Flower(EasyDict):
79 | ... power = 1
80 | ...
81 | >>> f = Flower()
82 | >>> f.power
83 | 1
84 | >>> f = Flower({'height': 12})
85 | >>> f.height
86 | 12
87 | >>> f['power']
88 | 1
89 | >>> sorted(f.keys())
90 | ['height', 'power']
91 |
92 | update and pop items
93 | >>> d = EasyDict(a=1, b='2')
94 | >>> e = EasyDict(c=3.0, a=9.0)
95 | >>> d.update(e)
96 | >>> d.c
97 | 3.0
98 | >>> d['c']
99 | 3.0
100 | >>> d.get('c')
101 | 3.0
102 | >>> d.update(a=4, b=4)
103 | >>> d.b
104 | 4
105 | >>> d.pop('a')
106 | 4
107 | >>> d.a
108 | Traceback (most recent call last):
109 | ...
110 | AttributeError: 'EasyDict' object has no attribute 'a'
111 | """
112 |
113 | def __init__(self, d=None, **kwargs):
114 | if d is None:
115 | d = {}
116 | if kwargs:
117 | d.update(**kwargs)
118 | for k, v in d.items():
119 | setattr(self, k, v)
120 | # Class attributes
121 | for k in self.__class__.__dict__.keys():
122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
123 | setattr(self, k, getattr(self, k))
124 |
125 | def __setattr__(self, name, value):
126 | if isinstance(value, (list, tuple)):
127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
128 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
129 | value = self.__class__(value)
130 | super(EasyDict, self).__setattr__(name, value)
131 | super(EasyDict, self).__setitem__(name, value)
132 |
133 | __setitem__ = __setattr__
134 |
135 | def update(self, e=None, **f):
136 | d = e or dict()
137 | d.update(f)
138 | for k in d:
139 | setattr(self, k, d[k])
140 |
141 | def pop(self, k, d=None):
142 | if hasattr(self, k):
143 | delattr(self, k)
144 | return super(EasyDict, self).pop(k, d)
145 |
146 |
147 | if __name__ == "__main__":
148 | import doctest
149 | doctest.testmod()
150 |
--------------------------------------------------------------------------------
/data_preparing/check_grounding_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import json\n",
11 | "\n",
12 | "\n",
13 | "def calculate_iou(pred_span, gold_span):\n",
14 | " gold_start, gold_end, pred_start, pred_end = gold_span[0], gold_span[1], pred_span[0], pred_span[1]\n",
15 | " intersection = max(0, min(gold_end, pred_end) - max(gold_start, pred_start))\n",
16 | " union = max(0, max(gold_end, pred_end) - min(gold_start, pred_start))\n",
17 | " if union <= 0 or intersection <= 0:\n",
18 | " return 0\n",
19 | " return intersection / union\n",
20 | "\n",
21 | "\n",
22 | "def check_ans(pred_span, gold_spans):\n",
23 | " if not isinstance(gold_spans[0], (list, tuple)):\n",
24 | " gold_spans = [gold_spans]\n",
25 | " return max([calculate_iou(pred_span, gold_span) for gold_span in gold_spans])\n",
26 | "\n",
27 | "\n",
28 | "def get_iou_at_different_turns(example, max_turns=4):\n",
29 | " pred_answer_list = example['pred_answer_list'] + ['throughout'] * (max_turns - len(example['pred_answer_list']))\n",
30 | " start, end = 0, example['duration']\n",
31 | " \n",
32 | " res_list = list()\n",
33 | " for pred in pred_answer_list:\n",
34 | " interval = (end - start) / 4\n",
35 | " if pred == 'beginning':\n",
36 | " end = end - 2*interval\n",
37 | " elif pred == 'middle':\n",
38 | " start, end = start + interval, end - interval\n",
39 | " elif pred == 'end':\n",
40 | " start = start + 2 * interval\n",
41 | " res_list.append(check_ans((start, end), example['gt_span']))\n",
42 | " return res_list\n"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "# change output_fname at your need\n",
52 | "output_fname = '../outputs/charades_sta-recursive_grounding-4_turns.jsonl'\n",
53 | "gold_fname = '../data/test-anno/charades_sta-recursive_grounding.json'\n",
54 | "gold_data = json.load(open(gold_fname))\n",
55 | "max_turns = 4\n",
56 | "\n",
57 | "res_list_all_turns = list()\n",
58 | "\n",
59 | "for line, gold in zip(open(output_fname), gold_data):\n",
60 | " example = json.loads(line)\n",
61 | " if example['duration'] is None: continue\n",
62 | " example['gt_span'] = gold['answer']\n",
63 | " res_list = get_iou_at_different_turns(example, max_turns)\n",
64 | " res_list_all_turns.append(res_list)\n",
65 | "\n",
66 | "print(output_fname, 'num examples:', len(res_list_all_turns))\n",
67 | "for turns in range(max_turns):\n",
68 | " print('turns: %d' % (turns + 1))\n",
69 | " iou_list = [ious[turns] for ious in res_list_all_turns]\n",
70 | " print('mean iou: %.4f' % np.mean(iou_list))\n",
71 | " print('iou@0.3/0.5/0.7: %.4f/%.4f/%.4f' % tuple([len([i for i in iou_list if i > thres]) / len(iou_list) for thres in [0.3, 0.5, 0.7]]))\n"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": []
80 | }
81 | ],
82 | "metadata": {
83 | "kernelspec": {
84 | "display_name": "videochat",
85 | "language": "python",
86 | "name": "python3"
87 | },
88 | "language_info": {
89 | "codemirror_mode": {
90 | "name": "ipython",
91 | "version": 3
92 | },
93 | "file_extension": ".py",
94 | "mimetype": "text/x-python",
95 | "name": "python",
96 | "nbconvert_exporter": "python",
97 | "pygments_lexer": "ipython3",
98 | "version": "3.8.16"
99 | }
100 | },
101 | "nbformat": 4,
102 | "nbformat_minor": 2
103 | }
104 |
--------------------------------------------------------------------------------
/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | from torch.utils.data import Dataset
5 |
6 | from dataset.utils import load_image_from_path
7 |
8 | try:
9 | from petrel_client.client import Client
10 | has_client = True
11 | except ImportError:
12 | has_client = False
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class ImageVideoBaseDataset(Dataset):
18 | """Base class that implements the image and video loading methods"""
19 |
20 | media_type = "video"
21 |
22 | def __init__(self):
23 | assert self.media_type in ["image", "video", "only_video"]
24 | self.data_root = None
25 |         # list(dict); each dict contains {"image": str}, where "image" is the image or video path
26 |         self.anno_list = None
27 |
28 | self.transform = None
29 | self.video_reader = None
30 | self.num_tries = None
31 |
32 | self.client = None
33 | if has_client:
34 | self.client = Client('~/petreloss.conf')
35 |
36 | def __getitem__(self, index):
37 | raise NotImplementedError
38 |
39 | def __len__(self):
40 | raise NotImplementedError
41 |
42 | def get_anno(self, index):
43 | """obtain the annotation for one media (video or image)
44 |
45 | Args:
46 | index (int): The media index.
47 |
48 | Returns: dict.
49 |             - "image": the filename; videos also use the "image" key.
50 | - "caption": The caption for this file.
51 |
52 | """
53 | anno = self.anno_list[index]
54 | if self.data_root is not None:
55 | anno["image"] = os.path.join(self.data_root, anno["image"])
56 | return anno
57 |
58 | def load_and_transform_media_data(self, index, data_path):
59 | if self.media_type == "image":
60 | return self.load_and_transform_media_data_image(index, data_path)
61 | else:
62 | return self.load_and_transform_media_data_video(index, data_path)
63 |
64 | def load_and_transform_media_data_image(self, index, data_path):
65 | image = load_image_from_path(data_path, client=self.client)
66 | image = self.transform(image)
67 | return image, index
68 |
69 | def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None):
70 | for _ in range(self.num_tries):
71 | try:
72 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
73 | frames, frame_indices, fps = self.video_reader(
74 | data_path, self.num_frames, self.sample_type,
75 | max_num_frames=max_num_frames, client=self.client, clip=clip, fps=getattr(self, "fps", None)
76 | )
77 | except Exception as e:
78 | logger.warning(
79 | f"Caught exception {e} when loading video {data_path}, "
80 | f"randomly sample a new video as replacement"
81 | )
82 | index = random.randint(0, len(self) - 1)
83 | ann = self.get_anno(index)
84 | data_path = ann["image"]
85 | continue
86 | # shared aug for video frames
87 | frames = self.transform(frames)
88 | if return_fps:
89 | assert fps is not None or hasattr(self, 'fps')
90 |                 fps = fps if fps is not None else self.fps  # for cases where fps cannot be inferred from the video, set self.fps manually
91 | sec = [str(round(f / fps, 1)) for f in frame_indices]
92 | return frames, index, sec
93 | else:
94 | return frames, index
95 | else:
96 | raise RuntimeError(
97 | f"Failed to fetch video after {self.num_tries} tries. "
98 | f"This might indicate that you have many corrupted videos."
99 | )
100 |
--------------------------------------------------------------------------------
/models/blip2/blip2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import contextlib
8 | import os
9 | import logging
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | from .Qformer import BertConfig, BertLMHeadModel
15 | from .vit import build_vit
16 | from transformers import BertTokenizer
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class Blip2Base(nn.Module):
22 | def __init__(self):
23 | super().__init__()
24 |
25 | # @classmethod
26 | # def init_tokenizer(cls, truncation_side="right"):
27 | def init_tokenizer(self, truncation_side="right"):
28 | tokenizer = BertTokenizer.from_pretrained(self.bert_path, truncation_side=truncation_side, local_files_only=True)
29 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
30 | return tokenizer
31 |
32 | @property
33 | def device(self):
34 | return list(self.parameters())[0].device
35 |
36 | def maybe_autocast(self, dtype=torch.float16):
37 | # if on cpu, don't use autocast
38 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
39 | enable_autocast = self.device != torch.device("cpu")
40 |
41 | if enable_autocast:
42 | return torch.cuda.amp.autocast(dtype=dtype)
43 | else:
44 | return contextlib.nullcontext()
45 |
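    # Illustrative usage (hypothetical caller, not part of this file): `maybe_autocast`
    # returns a context manager, so a forward pass can be wrapped as:
    #     with self.maybe_autocast():
    #         image_embeds = self.vision_encoder(image)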
46 | # @classmethod
47 | def init_Qformer(
48 | # cls,
49 | self,
50 | num_query_token, num_frame_tokens, vision_width,
51 | qformer_hidden_dropout_prob=0.1,
52 | qformer_attention_probs_dropout_prob=0.1,
53 | qformer_drop_path_rate=0.,
54 | ):
55 | encoder_config = BertConfig.from_pretrained(self.bert_path, local_files_only=True)
56 | encoder_config.num_frame_tokens = num_frame_tokens
57 | encoder_config.encoder_width = vision_width
58 | # insert cross-attention layer every other block
59 | encoder_config.add_cross_attention = True
60 | encoder_config.cross_attention_freq = 2
61 | encoder_config.query_length = num_query_token
62 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
63 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
64 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)]
65 | logger.info(f"Drop_path:{encoder_config.drop_path_list}")
66 | logger.info(encoder_config)
67 | Qformer = BertLMHeadModel(config=encoder_config)
68 | query_tokens = nn.Parameter(
69 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
70 | )
71 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
72 | return Qformer, query_tokens
73 |
74 | @classmethod
75 | def init_vision_encoder_umt(self, config):
76 | """build vision encoder
77 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
78 |
79 | """
80 | vision_encoder = build_vit(config)
81 |
82 | if config.vision_encoder.vit_add_ln:
83 | vision_layernorm = nn.LayerNorm(config.vision_encoder.encoder_embed_dim, eps=1e-12)
84 | else:
85 | vision_layernorm = nn.Identity()
86 |
87 | return vision_encoder, vision_layernorm
88 |
89 |
90 | def disabled_train(self, mode=True):
91 | """Overwrite model.train with this function to make sure train/eval mode
92 | does not change anymore."""
93 | return self
94 |
95 |
96 | class LayerNorm(nn.LayerNorm):
97 | """Subclass torch's LayerNorm to handle fp16."""
98 |
99 | def forward(self, x: torch.Tensor):
100 | orig_type = x.dtype
101 | ret = super().forward(x.type(torch.float32))
102 | return ret.type(orig_type)
103 |
--------------------------------------------------------------------------------
/data_preparing/anetc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import argparse
5 |
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--func', type=str, default='test_grounding')
10 | parser.add_argument('--annotation-fname', type=str, default='data/ActivityNet/captions/train.json')
11 | parser.add_argument('--annotation-fname2', type=str, default=None)
12 | parser.add_argument('--instruction-fname', type=str, default='data/VideoChat2-IT/video/temporal/anetc_grounding-choice/instructions.json')
13 | parser.add_argument('--question', type=str, default='data/VideoChat2-IT/video/temporal/anetc_grounding-choice/questions.json')
14 | parser.add_argument('--output-fname', type=str, default='data/VideoChat2-IT/video/temporal/anetc_grounding-choice/train.json')
15 | parser.add_argument('--time-span-sent', type=str, default='From second %.1f to second %.1f.') # for frame-level and second-level
16 | parser.add_argument('--sample-ratio', type=float, default=1.0, help='random sample ratio for the test set')
17 | args = parser.parse_args()
18 | print(args)
19 |
20 | os.makedirs(os.path.dirname(args.output_fname), exist_ok=True)
21 | anno = json.load(open(args.annotation_fname))
22 | if args.annotation_fname2 is not None:
23 | anno1, anno = anno, dict()
24 | anno2 = json.load(open(args.annotation_fname2))
25 | keys = anno1.keys() | anno2.keys()
26 | for key in keys:
27 | if key in anno1:
28 | anno[key] = anno1[key]
29 | if key in anno2:
30 | anno[key]['sentences'] = anno1[key]['sentences'] + anno2[key]['sentences']
31 | anno[key]['timestamps'] = anno1[key]['timestamps'] + anno2[key]['timestamps']
32 | else:
33 | anno[key] = anno2[key]
34 |
35 | if os.path.exists(args.question):
36 | args.question = json.load(open(args.question))
37 | else:
38 | args.question = [args.question]
39 |
40 | if args.func in ['grounding', 'caption', 'choice']:
41 | instructions = json.load(open(args.instruction_fname))
42 | res = list()
43 | for video_id, video_data in anno.items():
44 | for (start_sec, end_sec), sent in zip(video_data['timestamps'], video_data['sentences']):
45 | if sent.endswith('.'):
46 | sent = sent[:-1]
47 | new_example = {'i': random.choice(instructions), 'start_sec': start_sec, 'end_sec': end_sec, 'neg_start_sec': 0, 'neg_end_sec': video_data['duration']}
48 | if args.func == 'grounding':
49 | new_example['q'] = random.choice(args.question) % sent
50 | new_example['a'] = args.time_span_sent
51 | elif args.func == 'caption':
52 | new_example['a'] = sent
53 | new_example['q'] = args.time_span_sent
54 | elif args.func == 'choice':
55 | options = ["In the middle of the video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."]
56 | random.shuffle(options)
57 | options = ["\n(%s) %s" % ("ABCD"[i], opt) for i, opt in enumerate(options)]
58 | new_example['q'] = "Question: " + random.choice(args.question) % sent + "\nOptions:" + "".join(options)
59 | new_example['a'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options]
60 | res.append({'video': video_id, 'QA': [new_example]})
61 |
62 | if args.func in ["test_grounding"]:
63 | res = list()
64 | for video_id, video_data in anno.items():
65 | for (start_sec, end_sec), sent in zip(video_data['timestamps'], video_data['sentences']):
66 | if random.random() > args.sample_ratio: continue
67 | sent = sent.replace('.', '').strip()
68 | new_example = {'video': video_id, 'question': sent, "answer": [start_sec, end_sec]}
69 | res.append(new_example)
70 |
71 | print('saved %d examples' % len(res))
72 | with open(args.output_fname, 'w') as f_out:
73 | json.dump(res, f_out)
74 |
--------------------------------------------------------------------------------
/tasks/shared_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | from torch.utils.data import ConcatDataset, DataLoader
9 |
10 | from utils.optimizer import create_optimizer
11 | from utils.scheduler import create_scheduler
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def get_media_types(datasources):
17 |     """get the media types for all the dataloaders.
18 |
19 | Args:
20 | datasources (List): List of dataloaders or datasets.
21 |
22 | Returns: List. The media_types.
23 |
24 | """
25 | if isinstance(datasources[0], DataLoader):
26 | datasets = [dataloader.dataset for dataloader in datasources]
27 | else:
28 | datasets = datasources
29 | media_types = [
30 | dataset.datasets[0].media_type
31 | if isinstance(dataset, ConcatDataset)
32 | else dataset.media_type
33 | for dataset in datasets
34 | ]
35 |
36 | return media_types
37 |
38 |
39 | def setup_model(
40 | config, model_cls, find_unused_parameters=False
41 | ):
42 | logger.info("Creating model")
43 | config = copy.deepcopy(config)
44 |
45 | model = model_cls(config=config.model)
46 |
47 | # model = model.to(torch.device(config.device))
48 | if config.model.get('model_parallel', False):
49 | model.set_device_ids([int(os.environ["LOCAL_RANK"]), int(os.environ["LOCAL_RANK"]) + torch.distributed.get_world_size()])
50 | else:
51 | model.set_device_ids([int(os.environ["LOCAL_RANK"])])
52 |
53 | model_without_ddp = model
54 | if config.distributed:
55 | model = torch.nn.parallel.DistributedDataParallel(
56 | model,
57 | device_ids=[config.gpu] if not config.model.get('model_parallel', False) else None,
58 | find_unused_parameters=find_unused_parameters, # `False` for image-only task
59 | )
60 |
61 | optimizer = create_optimizer(config.optimizer, model)
62 | scheduler = create_scheduler(config.scheduler, optimizer)
63 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
64 |
65 | start_epoch = 0
66 | global_step = 0
67 |
68 | # auto resume the latest checkpoint
69 | if config.get("auto_resume", False):
70 | logger.info("Auto resuming")
71 | model_latest = join(config.output_dir, "ckpt_latest.pth")
72 | model_best = join(config.output_dir, "ckpt_best.pth")
73 | large_num = -1
74 | for p in os.listdir(config.output_dir):
75 | if 'ckpt' in p:
76 | num = p.split('_')[1].split('.')[0]
77 | if str.isnumeric(num):
78 | if int(num) > large_num:
79 | large_num = int(num)
80 | if large_num != -1:
81 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
82 | if osp.isfile(model_latest):
83 | config.pretrained_path = model_latest
84 | config.resume = True
85 | elif osp.isfile(model_best):
86 | config.pretrained_path = model_best
87 | config.resume = True
88 | else:
89 | logger.info(f"Not found checkpoint in {config.output_dir}")
90 |
91 | if osp.isfile(config.pretrained_path):
92 | checkpoint = torch.load(config.pretrained_path, map_location="cpu")
93 |         state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
94 |
95 | if config.resume:
96 | optimizer.load_state_dict(checkpoint["optimizer"])
97 | scheduler.load_state_dict(checkpoint["scheduler"])
98 | scaler.load_state_dict(checkpoint["scaler"])
99 | start_epoch = checkpoint["epoch"] + 1
100 | global_step = checkpoint["global_step"]
101 |
102 | msg = model_without_ddp.load_state_dict(state_dict, strict=False)
103 | logger.info(msg)
104 | logger.info(f"Loaded checkpoint from {config.pretrained_path}")
105 | else:
106 | logger.warning("No pretrained checkpoint provided, training from scratch")
107 |
108 | return (
109 | model,
110 | model_without_ddp,
111 | optimizer,
112 | scheduler,
113 | scaler,
114 | start_epoch,
115 | global_step,
116 | )
117 |
--------------------------------------------------------------------------------
/scripts/train/config_7b_stage3.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | # ========================= data ==========================
4 | train_corpus = "hawkeye_instruction"
5 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
6 | test_file = dict()
7 | test_types = []
8 | num_workers = 6
9 |
10 | stop_key = None
11 |
12 | # ========================= input ==========================
13 | num_frames = 12
14 | num_frames_test = 12
15 | batch_size = 2
16 | max_txt_l = 512
17 |
18 | pre_text = False
19 |
20 | inputs = dict(
21 | image_res=224,
22 | video_input=dict(
23 | num_frames="${num_frames}",
24 | sample_type="rand",
25 | num_frames_test="${num_frames_test}",
26 | sample_type_test="middle",
27 | random_aug=False,
28 | ),
29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
30 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
32 | )
33 |
34 | # ========================= model ==========================
35 | model = dict(
36 | model_cls="HawkEye_it",
37 | vit_blip_model_path="model/VideoChat2/umt_l16_qformer.pth",
38 | llama_model_path="model/vicuna-7b-v0",
39 | videochat2_model_path="model/VideoChat2/videochat2_7b_stage2.pth",
40 | freeze_vit=True,
41 | freeze_qformer=False,
42 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
43 | # vit
44 | low_resource=False,
45 | add_temp_embed=False,
46 | vision_encoder=dict(
47 | name="vit_l14",
48 | img_size=224,
49 | patch_size=16,
50 | d_model=1024,
51 | encoder_embed_dim=1024,
52 | encoder_depth=24,
53 | encoder_num_heads=16,
54 | drop_path_rate=0.,
55 | num_frames="${num_frames}",
56 | tubelet_size=1,
57 | use_checkpoint=False,
58 | checkpoint_num=0,
59 | pretrained="",
60 | return_index=-2,
61 | vit_add_ln=True,
62 | ckpt_num_frame=4,
63 | ),
64 | # qformer
65 | num_query_token=32,
66 | qformer_hidden_dropout_prob=0.1,
67 | qformer_attention_probs_dropout_prob=0.1,
68 | qformer_drop_path_rate=0.2,
69 | extra_num_query_token=64,
70 | qformer_text_input=True,
71 | # prompt
72 | system="",
73 |     start_token="<Video>",
74 |     end_token="</Video>",
75 |     add_second_msg=True,
76 |     img_start_token="<Image>",
77 |     img_end_token="</Image>",
78 | random_shuffle=True,
79 | use_flash_attention=False,
80 | use_lora=True,
81 | lora_r=16,
82 | lora_alpha=32,
83 | lora_dropout=0.1,
84 | # debug=True,
85 | )
86 |
87 | optimizer = dict(
88 | opt="adamW",
89 | lr=2e-5,
90 | opt_betas=[0.9, 0.999], # default
91 | weight_decay=0.02,
92 | max_grad_norm=-1, # requires a positive float, use -1 to disable
93 | # use a different lr for some modules, e.g., larger lr for new modules
94 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
95 | )
96 |
97 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
98 |
99 | evaluate = False
100 | deep_fusion = False
101 | evaluation = dict(
102 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
103 | eval_x_only=False,
104 | k_test=128,
105 | eval_offload=True, # offload gpu tensors to cpu to save memory.
106 | )
107 |
108 | fp16 = True
109 | gradient_checkpointing = True
110 |
111 | # ========================= log ==========================
112 | tensorboard = dict(
113 | enable=True
114 | )
115 |
116 | wandb = dict(
117 | enable=False,
118 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
119 | project="project",
120 | )
121 |
122 | # ========================= others ==========================
123 | dist_url = "env://"
124 | device = "cuda"
125 | mode = "it"
126 |
127 | output_dir = None # output dir
128 | resume = False # if True, load optimizer and scheduler states as well
129 | debug = False
130 | log_freq = 100
131 | seed = 42
132 |
133 | save_latest = True
134 | auto_resume = True
135 | pretrained_path = "" # path to pretrained model weights, for resume only
136 | freeze_dataset_folder = None # save all data used in training
137 |
--------------------------------------------------------------------------------
/scripts/train/charades_sta.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | available_corpus['charades_sta_ft'] = [
4 | {
5 | "dataset_name": 'charades_sta_grounding', 'grounding_method': 'choice',
6 | 'label_file': f'{anno_root_it}/video/temporal/charades_sta_grounding-choice/train.json',
7 | 'data_root': f'data/videos/charades', 'media_type': 'video'
8 | }
9 | ]
10 |
11 | # ========================= data ==========================
12 | train_corpus = "charades_sta_ft"
13 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
14 | test_file = dict()
15 | test_types = []
16 | num_workers = 6
17 |
18 | stop_key = None
19 |
20 | # ========================= input ==========================
21 | num_frames = 12
22 | num_frames_test = 12
23 | batch_size = 1
24 | max_txt_l = 512
25 |
26 | pre_text = False
27 |
28 | inputs = dict(
29 | image_res=224,
30 | video_input=dict(
31 | num_frames="${num_frames}",
32 | sample_type="rand",
33 | num_frames_test="${num_frames_test}",
34 | sample_type_test="middle",
35 | random_aug=False,
36 | ),
37 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
38 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
39 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
40 | )
41 |
42 | # ========================= model ==========================
43 | model = dict(
44 | model_cls="HawkEye_it",
45 | vit_blip_model_path="model/VideoChat2/umt_l16_qformer.pth",
46 | llama_model_path="model/vicuna-7b-v0",
47 | videochat2_model_path="model/VideoChat2/videochat2_7b_stage2.pth",
48 | freeze_vit=True,
49 | freeze_qformer=False,
50 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
51 | # vit
52 | low_resource=False,
53 | add_temp_embed=False,
54 | vision_encoder=dict(
55 | name="vit_l14",
56 | img_size=224,
57 | patch_size=16,
58 | d_model=1024,
59 | encoder_embed_dim=1024,
60 | encoder_depth=24,
61 | encoder_num_heads=16,
62 | drop_path_rate=0.,
63 | num_frames="${num_frames}",
64 | tubelet_size=1,
65 | use_checkpoint=False,
66 | checkpoint_num=0,
67 | pretrained="",
68 | return_index=-2,
69 | vit_add_ln=True,
70 | ckpt_num_frame=4,
71 | ),
72 | # qformer
73 | num_query_token=32,
74 | qformer_hidden_dropout_prob=0.1,
75 | qformer_attention_probs_dropout_prob=0.1,
76 | qformer_drop_path_rate=0.2,
77 | extra_num_query_token=64,
78 | qformer_text_input=True,
79 | # prompt
80 | system="",
81 |     start_token="<Video>",
82 |     end_token="</Video>",
83 |     add_second_msg=True,
84 |     img_start_token="<Image>",
85 |     img_end_token="</Image>",
86 | random_shuffle=True,
87 | use_flash_attention=False,
88 | use_lora=True,
89 | lora_r=16,
90 | lora_alpha=32,
91 | lora_dropout=0.1,
92 | debug=False,
93 | )
94 |
95 | optimizer = dict(
96 | opt="adamW",
97 | lr=2e-5,
98 | opt_betas=[0.9, 0.999], # default
99 | weight_decay=0.02,
100 | max_grad_norm=-1, # requires a positive float, use -1 to disable
101 | # use a different lr for some modules, e.g., larger lr for new modules
102 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
103 | )
104 |
105 | scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.25, warmup_epochs=0.6)
106 |
107 | evaluate = False
108 | deep_fusion = False
109 | evaluation = dict(
110 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
111 | eval_x_only=False,
112 | k_test=128,
113 | eval_offload=True, # offload gpu tensors to cpu to save memory.
114 | )
115 |
116 | fp16 = True
117 | gradient_checkpointing = False
118 |
119 | # ========================= log ==========================
120 | tensorboard = dict(
121 | enable=True
122 | )
123 |
124 | wandb = dict(
125 | enable=False,
126 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
127 | project="project",
128 | )
129 |
130 | # ========================= others ==========================
131 | dist_url = "env://"
132 | device = "cuda"
133 | mode = "it"
134 |
135 | output_dir = None # output dir
136 | resume = False # if True, load optimizer and scheduler states as well
137 | debug = False
138 | log_freq = 50
139 | save_freq = 1000
140 | seed = 42
141 |
142 | save_latest = False
143 | auto_resume = True
144 | pretrained_path = "" # path to pretrained model weights, for resume only
145 | freeze_dataset_folder = None
146 |
--------------------------------------------------------------------------------
/scripts/train/anetc.py:
--------------------------------------------------------------------------------
1 | from configs.instruction_data import *
2 |
3 | available_corpus['anetc_ft'] = [
4 | {
5 | "dataset_name": 'charades_sta_grounding', 'grounding_method': 'choice',
6 | 'label_file': f'{anno_root_it}/video/temporal/anetc_grounding/train.json',
7 | 'data_root': f'data/videos/activitynet', 'media_type': 'video', 'fps': 3
8 | }
9 | ]
10 |
11 | # ========================= data ==========================
12 | train_corpus = "anetc_ft"
13 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
14 | test_file = dict()
15 | test_types = []
16 | num_workers = 6
17 |
18 | stop_key = None
19 |
20 | # ========================= input ==========================
21 | num_frames = 12
22 | num_frames_test = 12
23 | batch_size = 4
24 | max_txt_l = 512
25 |
26 | pre_text = False
27 |
28 | inputs = dict(
29 | image_res=224,
30 | video_input=dict(
31 | num_frames="${num_frames}",
32 | sample_type="rand",
33 | num_frames_test="${num_frames_test}",
34 | sample_type_test="middle",
35 | random_aug=False,
36 | ),
37 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
38 | batch_size=dict(image="${batch_size}", video="${batch_size}"),
39 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
40 | )
41 |
42 | # ========================= model ==========================
43 | model = dict(
44 | model_cls="VideoChat2_it",
45 | vit_blip_model_path="model/VideoChat2/umt_l16_qformer.pth",
46 | llama_model_path="model/vicuna-7b-v0",
47 | videochat2_model_path="model/VideoChat2/videochat2_7b_stage2.pth",
48 | freeze_vit=True,
49 | freeze_qformer=False,
50 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3
51 | # vit
52 | low_resource=False,
53 | add_temp_embed=False,
54 | vision_encoder=dict(
55 | name="vit_l14",
56 | img_size=224,
57 | patch_size=16,
58 | d_model=1024,
59 | encoder_embed_dim=1024,
60 | encoder_depth=24,
61 | encoder_num_heads=16,
62 | drop_path_rate=0.,
63 | num_frames="${num_frames}",
64 | tubelet_size=1,
65 | use_checkpoint=False,
66 | checkpoint_num=0,
67 | pretrained="",
68 | return_index=-2,
69 | vit_add_ln=True,
70 | ckpt_num_frame=4,
71 | ),
72 | # qformer
73 | num_query_token=32,
74 | qformer_hidden_dropout_prob=0.1,
75 | qformer_attention_probs_dropout_prob=0.1,
76 | qformer_drop_path_rate=0.2,
77 | extra_num_query_token=64,
78 | qformer_text_input=True,
79 | # prompt
80 | system="",
81 |     start_token="<Video>",
82 |     end_token="</Video>",
83 |     add_second_msg=True,
84 |     img_start_token="<Image>",
85 |     img_end_token="</Image>",
86 | random_shuffle=True,
87 | use_flash_attention=False,
88 | use_lora=True,
89 | lora_r=16,
90 | lora_alpha=32,
91 | lora_dropout=0.1,
92 | debug=False,
93 | )
94 |
95 | optimizer = dict(
96 | opt="adamW",
97 | lr=2e-5,
98 | opt_betas=[0.9, 0.999], # default
99 | weight_decay=0.02,
100 | max_grad_norm=-1, # requires a positive float, use -1 to disable
101 | # use a different lr for some modules, e.g., larger lr for new modules
102 | different_lr=dict(enable=False, module_names=[], lr=1e-3),
103 | )
104 |
105 | scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.25, warmup_epochs=0.6)
106 |
107 | evaluate = False
108 | deep_fusion = False
109 | evaluation = dict(
110 | eval_frame_ensemble="concat", # [concat, max, mean, lse]
111 | eval_x_only=False,
112 | k_test=128,
113 | eval_offload=True, # offload gpu tensors to cpu to save memory.
114 | )
115 |
116 | fp16 = True
117 | # gradient_checkpointing = True
118 | gradient_checkpointing = False
119 |
120 | # ========================= log ==========================
121 | tensorboard = dict(
122 | enable=True
123 | )
124 |
125 | wandb = dict(
126 | enable=False,
127 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
128 | project="project",
129 | )
130 |
131 | # ========================= others ==========================
132 | dist_url = "env://"
133 | device = "cuda"
134 | mode = "it"
135 |
136 | output_dir = None # output dir
137 | resume = False # if True, load optimizer and scheduler states as well
138 | debug = False
139 | log_freq = 50
140 | save_freq = 2000
141 | seed = 42
142 |
143 | save_latest = False
144 | auto_resume = True
145 | pretrained_path = "" # path to pretrained model weights, for resume only
146 | freeze_dataset_folder = None
147 |
--------------------------------------------------------------------------------
/internvid_g/README.md:
--------------------------------------------------------------------------------
1 | InternVid-G: A Large-Scale Video-Text Dataset with Scene-Level Annotations for Temporal Grounding
2 | ===
3 | 
4 |
5 | InternVid-G is a dataset based on a fraction of videos from [InternVid](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid). It has segment-level annotations, where each segment is annotated with:
6 | - a caption describing the semantic content of the segment;
7 | - the `start_sec` and `end_sec` of the segment;
8 | - a `neg_start_sec` and a `neg_end_sec` that define a "negative span" for the segment. No other segment in the negative span is semantically similar to this segment (and its caption), so the negative span can serve as the context from which temporal video grounding models retrieve the segment. An illustrative record is shown below.
9 |
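For illustration, a single annotation record could look like the following sketch (field names follow the description above; the video id reuses one from the id-list example below, and all values are made up):
```python
# a made-up InternVid-G record, for illustration only
example_record = {
    "video": "UkQzZimp7Qs",   # YouTube video id
    "caption": "a person is slicing vegetables on a cutting board",
    "start_sec": 12.4,        # segment boundaries, in seconds
    "end_sec": 18.9,
    "neg_start_sec": 0.0,     # negative span: contains the segment, but no other
    "neg_end_sec": 45.0,      # segment in it is semantically similar to the caption
}
```
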
10 | # Data Download
11 | You can download the annotations from [🤗HuggingFace](https://huggingface.co/datasets/wangyueqian/InternVid-G)
12 |
13 | You can use `code/download_videos.py` to download the videos from YouTube.
14 |
15 | # Processing (Optional)
16 | We also provide code for reproducing (downloading and processing) InternVid-G in `code/`. We hope this code helps you build your own datasets.
17 |
18 | This code is provided for reference only; we strongly suggest reading it and becoming familiar with it before use.
19 |
20 | 1. Save a list of the YouTube video ids you are interested in to `temp/video_ids.txt`, one id per line, like this:
21 | ```text
22 | UkQzZimp7Qs
23 | vJxjY1-0120
24 | 11mh9C7RCqg
25 | dcIuh7JaxWw
26 | d5i7JiBAMtc
27 | J0aVIs-eLMA
28 | ...
29 | ```
30 |
31 | 2. Download the videos
32 | ```shell
33 | python code/download_videos.py --src_file temp/video_ids.txt --target_dir videos --num_workers 5
34 | ```
35 |
36 | 3. Split the videos into scenes using [PySceneDetect](https://github.com/Breakthrough/PySceneDetect)
37 | ```shell
38 | python code/clip_sim.py --func split_scene --scene-fname temp/scenes.jsonl
39 | ```
40 | The split scenes are saved to `temp/scenes.jsonl`. This process may take a long time, as PySceneDetect processes the videos frame by frame.
41 |
42 | 4. Calculate the pairwise similarities between all scenes to get a scene similarity matrix
43 | ```shell
44 | python code/clip_sim.py --func scene_sim --scene-fname temp/scenes.jsonl --scene-sim-fname temp/scenes_similarity.jsonl
45 | ```
46 | The scene similarity matrices are saved to `temp/scenes_similarity.jsonl`.
47 |
48 | 5. Merge the most similar consecutive scenes, and compute the similarity matrix of the merged scenes
49 | ```shell
50 | python code/clip_sim.py --func merge_scene \
51 | --scene-fname temp/scenes.jsonl --scene-sim-fname temp/scenes_similarity.jsonl \
52 | --scene-merged-fname temp/scenes-merged.jsonl --scene-merged-sim-fname temp/scenes_merged_similarity.jsonl
53 | ```
54 | The merged scenes and similarity matrices are saved at `temp/scenes-merged.jsonl` and `temp/scenes_merged_similarity.jsonl`.
55 |
56 | 6. Use a captioning model to get the caption of each scene (we find BLIP-2 works best in our case)
57 | ```shell
58 | python code/caption_clips.py --func blip2 \
59 | --scene-fname temp/scenes-merged.jsonl --blip2-fname temp/scenes-blip2.jsonl
60 | ```
61 | The segment captions are saved at `temp/scenes-blip2.jsonl`.
62 |
63 | 7. Calculate the similarity between each segment and its caption, and remove the pairs with low similarity
64 | ```shell
65 | python code/caption_clips.py --func filter --merge-method max \
66 | --filter-input-fname temp/scenes-blip2.jsonl \
67 |     --filtered-fname temp/scenes-blip2-filtered.jsonl
68 |
69 | python code/caption_clips.py --func merge_filtered_captions \
70 |     --filter-input-fname temp/scenes-blip2-filtered.jsonl \
71 |     --filtered-fname temp/scenes-blip2-filtered-high_sims.jsonl
72 | ```
73 | All segment-caption pairs with high similarity that are not filtered out are saved in `temp/scenes-blip2-filtered-high_sims.jsonl`.
74 |
75 | 8. (Finally the last step!) Get a negative span for each segment.
76 | ```shell
77 | python code/ground_data_construction.py \
78 | --caption-fname temp/scenes-blip2-filtered-high_sims.jsonl \
79 |     --scene-fname temp/scenes-merged.jsonl --scene-sim-fname temp/scenes_merged_similarity.jsonl \
80 | --caption-with-neg-interval-fname temp/final_dataset.jsonl
81 | ```
82 | Finally, all annotations are saved in `temp/final_dataset.jsonl`; a minimal loading sketch is shown below.
83 |
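As a minimal sketch (assuming one JSON object per line, like the other `.jsonl` files above), the result can be loaded and inspected like this:
```python
import json

# load the constructed grounding annotations (jsonl: one record per line)
with open("temp/final_dataset.jsonl") as f:
    annotations = [json.loads(line) for line in f]

print("num segments:", len(annotations))
print(annotations[0])  # inspect one record
```
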
84 | # Citation
85 | If you find this code useful in your research, please consider citing:
86 | ```bibtex
87 | @misc{wang2024hawkeye,
88 | title={HawkEye: Training Video-Text LLMs for Grounding Text in Videos},
89 | author={Yueqian Wang and Xiaojun Meng and Jianxin Liang and Yuxuan Wang and Qun Liu and Dongyan Zhao},
90 | year={2024},
91 | eprint={2403.10228},
92 | archivePrefix={arXiv},
93 | primaryClass={cs.CV}
94 | }
95 | ```
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | import logging
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def setup_for_distributed(is_master):
11 | import warnings
12 |
13 | builtin_warn = warnings.warn
14 |
15 | def warn(*args, **kwargs):
16 | force = kwargs.pop("force", False)
17 | if is_master or force:
18 | builtin_warn(*args, **kwargs)
19 |
20 | # Log warnings only once
21 | warnings.warn = warn
22 | warnings.simplefilter("once", UserWarning)
23 |
24 | if not is_master:
25 | logging.disable()
26 |
27 |
28 | def is_dist_avail_and_initialized():
29 | if not dist.is_available():
30 | return False
31 | if not dist.is_initialized():
32 | return False
33 | return True
34 |
35 |
36 | def get_world_size():
37 | if not is_dist_avail_and_initialized():
38 | return 1
39 | return dist.get_world_size()
40 |
41 |
42 | def get_rank():
43 | if not is_dist_avail_and_initialized():
44 | return 0
45 | return dist.get_rank()
46 |
47 |
48 | def is_main_process():
49 | return get_rank() == 0
50 |
51 |
52 | def save_on_master(*args, **kwargs):
53 | if is_main_process():
54 | torch.save(*args, **kwargs)
55 |
56 |
57 | def is_port_in_use(port):
58 | import socket
59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
60 | return s.connect_ex(('localhost', port)) == 0
61 |
62 |
63 | def init_distributed_mode(args):
64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
65 | # job started by torch.distributed.launch
66 | args.rank = int(os.environ["RANK"])
67 | args.world_size = int(os.environ['WORLD_SIZE'])
68 | args.gpu = int(os.environ['LOCAL_RANK'])
69 | elif 'SLURM_PROCID' in os.environ:
70 | # local rank on the current node / global rank
71 | local_rank = int(os.environ['SLURM_LOCALID'])
72 | global_rank = int(os.environ['SLURM_PROCID'])
73 | # number of processes / GPUs per node
74 |         world_size = int(os.environ["SLURM_NNODES"]) * \
75 |             int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0])  # handles values like "16" or "4(x2)"
76 |
77 |         logger.info(f"world_size: {world_size}")
78 |
79 | args.rank = global_rank
80 | args.gpu = local_rank
81 | args.world_size = world_size
82 | else:
83 | logger.info('Not using distributed mode')
84 | args.distributed = False
85 | return
86 |
87 | args.distributed = True
88 |
89 |     logger.info(f'torch.cuda.set_device: {args.gpu}')
90 | torch.cuda.set_device(args.gpu)
91 | args.dist_backend = 'nccl'
92 |
93 |     if "tcp" in args.dist_url:  # in slurm, multiple programs may run on a single node
94 | dist_port = int(args.dist_url.split(":")[-1])
95 | while is_port_in_use(dist_port):
96 | dist_port += 10
97 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
98 |
99 | logger.info('| distributed init (rank {}): {}'.format(
100 | args.rank, args.dist_url))
101 | if "SLURM_JOB_ID" in os.environ:
102 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
103 | torch.distributed.init_process_group(
104 | backend=args.dist_backend, init_method=args.dist_url,
105 | world_size=args.world_size, rank=args.rank)
106 | torch.distributed.barrier()
107 | setup_for_distributed(args.rank == 0)
108 |
109 |
110 | # Copyright (c) Facebook, Inc. and its affiliates.
111 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
112 | class GatherLayer(torch.autograd.Function):
113 | """
114 | Gather tensors from all workers with support for backward propagation:
115 | This implementation does not cut the gradients as torch.distributed.all_gather does.
116 | """
117 |
118 | @staticmethod
119 | def forward(ctx, x):
120 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
121 | dist.all_gather(output, x)
122 | return tuple(output)
123 |
124 | @staticmethod
125 | def backward(ctx, *grads):
126 | all_gradients = torch.stack(grads)
127 | dist.all_reduce(all_gradients)
128 | return all_gradients[dist.get_rank()]
129 |
130 |
131 | # copied from megavlt
132 | def gather_tensor_along_batch_with_backward(tensor, dim=0):
133 | world_size = get_world_size()
134 |
135 | if world_size < 2:
136 | return tensor
137 |
138 | tensor_list = GatherLayer.apply(tensor)
139 | tensor_list = torch.cat(tensor_list, dim=dim)
140 | return tensor_list
141 |
142 |
143 | @torch.no_grad()
144 | def gather_tensor_along_batch(tensor, dim=0):
145 | """
146 | Performs all_gather operation on the provided tensors.
147 | *** Warning ***: torch.distributed.all_gather has no gradient.
148 | """
149 | world_size = get_world_size()
150 |
151 | if world_size < 2:
152 | return tensor
153 |
154 | with torch.no_grad():
155 | tensor_list = []
156 |
157 | for _ in range(world_size):
158 | tensor_list.append(torch.zeros_like(tensor))
159 |
160 | dist.all_gather(tensor_list, tensor)
161 | tensor_list = torch.cat(tensor_list, dim=dim)
162 | return tensor_list
163 |
--------------------------------------------------------------------------------
/configs/instruction_data.py:
--------------------------------------------------------------------------------
1 | import os as __os # add "__" if not want to be exported
2 |
3 | anno_root_it = "data/HawkEye-IT"
4 |
5 | # ============== pretraining datasets=================
6 |
7 | available_corpus = dict(
8 | caption_textvr={
9 | "dataset_name": "caption_textvr",
10 | "label_file": f"{anno_root_it}/video/caption/textvr/train.json",
11 | "data_root": "data/videos/textvr",
12 | "media_type": "video",
13 | },
14 |
15 | caption_videochat={
16 | "dataset_name": "caption_videochat",
17 | "label_file": f"{anno_root_it}/video/caption/videochat/train.json",
18 | "data_root": "data/videos/webvid",
19 | "media_type": "video",
20 | },
21 |
22 | caption_webvid={
23 | "dataset_name": "caption_webvid",
24 | "label_file": f"{anno_root_it}/video/caption/webvid/train.json",
25 | "data_root": "data/videos/webvid",
26 | "media_type": "video",
27 | },
28 |
29 | caption_youcook2={
30 | "dataset_name": "caption_youcook2",
31 | "label_file": f"{anno_root_it}/video/caption/youcook2/train.json",
32 | "data_root": "data/videos/youcook2",
33 | "media_type": "video",
34 | },
35 |
36 | classification_k710={
37 | "dataset_name": "classification_k710",
38 | "label_file": f"{anno_root_it}/video/classification/k710/train.json",
39 | "data_root": "data/videos/kinetics",
40 | "media_type": "video",
41 | },
42 |
43 | classification_ssv2={
44 | "dataset_name": "classification_ssv2",
45 | "label_file": f"{anno_root_it}/video/classification/ssv2/train.json",
46 | "data_root": "data/videos/ssv2",
47 | "media_type": "video",
48 | },
49 |
50 | conversation_videochat1={
51 | "dataset_name": "conversation_videochat1",
52 | "label_file": f"{anno_root_it}/video/conversation/videochat1/train.json",
53 | "data_root": "data/videos/webvid",
54 | "media_type": "video",
55 | },
56 |
57 | conversation_videochatgpt={
58 | "dataset_name": "conversation_videochatgpt",
59 | "label_file": f"{anno_root_it}/video/conversation/videochatgpt/train.json",
60 | "data_root": "data/videos/activitynet",
61 | "media_type": "video",
62 | },
63 |
64 | reasoning_next_qa={
65 | "dataset_name": "reasoning_next_qa",
66 | "label_file": f"{anno_root_it}/video/reasoning/next_qa/train.json",
67 | "data_root": "data/videos/nextqa",
68 | "media_type": "video",
69 | },
70 |
71 | reasoning_clevrer_qa={
72 | "dataset_name": "reasoning_clevrer_qa",
73 | "label_file": f"{anno_root_it}/video/reasoning/clevrer_qa/train.json",
74 | "data_root": "data/videos/clevrer",
75 | "media_type": "video",
76 | },
77 |
78 | reasoning_clevrer_mc={
79 | "dataset_name": "reasoning_clevrer_mc",
80 | "label_file": f"{anno_root_it}/video/reasoning/clevrer_mc/train.json",
81 | "data_root": "data/videos/clevrer",
82 | "media_type": "video",
83 | },
84 |
85 | vqa_tgif_frame_qa={
86 | "dataset_name": "vqa_tgif_frame_qa",
87 | "label_file": f"{anno_root_it}/video/vqa/tgif_frame_qa/train.json",
88 | "data_root": "data/videos/tgif",
89 | "media_type": "video",
90 | },
91 |
92 | vqa_tgif_transition_qa={
93 | "dataset_name": "vqa_tgif_transition_qa",
94 | "label_file": f"{anno_root_it}/video/vqa/tgif_transition_qa/train.json",
95 | "data_root": "data/videos/tgif",
96 | "media_type": "video",
97 | },
98 |
99 | vqa_webvid_qa={
100 | "dataset_name": "vqa_webvid_qa",
101 | "label_file": f"{anno_root_it}/video/vqa/webvid_qa/train.json",
102 | "data_root": "data/videos/webvid",
103 | "media_type": "video",
104 | },
105 |
106 | internvid_grounding={
107 | "dataset_name": "internvid_grounding", 'grounding_method': 'choice',
108 | "label_file": f"{anno_root_it}/video/temporal/internvid_grounding/train.json",
109 | "data_root": "data/videos/internvid-g",
110 | "media_type": "video",
111 | },
112 |
113 | internvid_caption={
114 | "dataset_name": "internvid_grounding", 'grounding_method': 'choice',
115 | "label_file": f"{anno_root_it}/video/temporal/internvid_caption/train.json",
116 | "data_root": "data/videos/internvid-g",
117 | "media_type": "video",
118 | },
119 |
120 | )
121 |
122 | # select the instruction training data you have
123 | available_corpus["hawkeye_instruction"] = [
124 | available_corpus["caption_textvr"],
125 | available_corpus["caption_videochat"],
126 | available_corpus["caption_webvid"],
127 | available_corpus["caption_youcook2"],
128 | available_corpus["classification_k710"],
129 | available_corpus["classification_ssv2"],
130 | available_corpus["conversation_videochat1"],
131 | available_corpus["conversation_videochat2"],
132 | available_corpus["conversation_videochatgpt"],
133 | available_corpus["reasoning_next_qa"],
134 | available_corpus["reasoning_clevrer_qa"],
135 | available_corpus["reasoning_clevrer_mc"],
136 | available_corpus["vqa_ego_qa"],
137 | available_corpus["vqa_tgif_frame_qa"],
138 | available_corpus["vqa_tgif_transition_qa"],
139 | available_corpus["vqa_webvid_qa"],
140 | available_corpus["internvid_grounding"],
141 | available_corpus["internvid_caption"],
142 | ]
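# Illustrative note (not part of the original config): a smaller corpus can be
# registered the same way by listing only the datasets you actually have, e.g.
# available_corpus["grounding_only"] = [
#     available_corpus["internvid_grounding"],
#     available_corpus["internvid_caption"],
# ]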
143 |
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | """ Optimizer Factory w/ Custom Weight Decay
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import re
5 | import torch
6 | from torch import optim as optim
7 | from utils.distributed import is_main_process
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | try:
11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
12 | has_apex = True
13 | except ImportError:
14 | has_apex = False
15 |
16 |
17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
18 | named_param_tuples = []
19 | for name, param in model.named_parameters():
20 | if not param.requires_grad:
21 | continue # frozen weights
22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
23 | named_param_tuples.append([name, param, 0])
24 | elif name in no_decay_list:
25 | named_param_tuples.append([name, param, 0])
26 | else:
27 | named_param_tuples.append([name, param, weight_decay])
28 | return named_param_tuples
29 |
30 |
31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
32 | """use lr=diff_lr for modules named found in diff_lr_names,
33 | otherwise use lr=default_lr
34 |
35 | Args:
36 | named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module
37 | diff_lr_names: List(str)
38 | diff_lr: float
39 | default_lr: float
40 | Returns:
41 | named_param_tuples_with_lr: List([name, param, weight_decay, lr])
42 | """
43 | named_param_tuples_with_lr = []
44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
45 | for name, p, wd in named_param_tuples_or_model:
46 | use_diff_lr = False
47 | for diff_name in diff_lr_names:
48 | # if diff_name in name:
49 | if re.search(diff_name, name) is not None:
50 | logger.info(f"param {name} use different_lr: {diff_lr}")
51 | use_diff_lr = True
52 | break
53 |
54 | named_param_tuples_with_lr.append(
55 | [name, p, wd, diff_lr if use_diff_lr else default_lr]
56 | )
57 |
58 | if is_main_process():
59 | for name, _, wd, diff_lr in named_param_tuples_with_lr:
60 | logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}")
61 |
62 | return named_param_tuples_with_lr
63 |
64 |
65 | def create_optimizer_params_group(named_param_tuples_with_lr):
66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
67 | group = {}
68 | for name, p, wd, lr in named_param_tuples_with_lr:
69 | if wd not in group:
70 | group[wd] = {}
71 | if lr not in group[wd]:
72 | group[wd][lr] = []
73 | group[wd][lr].append(p)
74 |
75 | optimizer_params_group = []
76 | for wd, lr_groups in group.items():
77 | for lr, p in lr_groups.items():
78 | optimizer_params_group.append(dict(
79 | params=p,
80 | weight_decay=wd,
81 | lr=lr
82 | ))
83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
84 | return optimizer_params_group
85 |
86 |
87 | def create_optimizer(args, model, filter_bias_and_bn=True):
88 | opt_lower = args.opt.lower()
89 | weight_decay = args.weight_decay
90 | # check for modules that requires different lr
91 | if hasattr(args, "different_lr") and args.different_lr.enable:
92 | diff_lr_module_names = args.different_lr.module_names
93 | diff_lr = args.different_lr.lr
94 | else:
95 | diff_lr_module_names = []
96 | diff_lr = None
97 |
98 | no_decay = {}
99 | if hasattr(model, 'no_weight_decay'):
100 | no_decay = model.no_weight_decay()
101 | named_param_tuples = add_weight_decay(
102 | model, weight_decay, no_decay, filter_bias_and_bn)
103 | named_param_tuples = add_different_lr(
104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
105 | parameters = create_optimizer_params_group(named_param_tuples)
106 |
107 | if 'fused' in opt_lower:
108 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
109 |
110 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
111 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
112 | opt_args['eps'] = args.opt_eps
113 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
114 | opt_args['betas'] = args.opt_betas
115 | if hasattr(args, 'opt_args') and args.opt_args is not None:
116 | opt_args.update(args.opt_args)
117 |
118 | opt_split = opt_lower.split('_')
119 | opt_lower = opt_split[-1]
120 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
121 | opt_args.pop('eps', None)
122 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
123 | elif opt_lower == 'momentum':
124 | opt_args.pop('eps', None)
125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
126 | elif opt_lower == 'adam':
127 | optimizer = optim.Adam(parameters, **opt_args)
128 | elif opt_lower == 'adamw':
129 | optimizer = optim.AdamW(parameters, **opt_args)
130 | else:
131 |         # unknown optimizer name
132 |         raise ValueError(f"Invalid optimizer: {opt_lower}")
133 | return optimizer
134 |
--------------------------------------------------------------------------------
/data_preparing/videoqa.py:
--------------------------------------------------------------------------------
1 | # convert many qa datasets to the input format
2 | import os
3 | import random
4 | import argparse
5 | import json
6 | import pandas as pd
7 | import numpy as np
8 |
9 |
10 |
11 | # nextqa, tvqa and star also provide time span labels (only tvqa's are kept in the converted files)
12 | def convert_nextqa_test(input_fname, output_fname, video_mapping_fname):
13 | video_mapping_dict = json.load(open(video_mapping_fname))
14 | video_mapping_dict = {key: val + '.mp4' for key, val in video_mapping_dict.items()}
15 | df = pd.read_csv(input_fname)
16 |
17 | res_list = list()
18 | for line_i, row in df.iterrows():
19 | video_fname = video_mapping_dict[str(row['video'])]
20 | res_dict = {'video': video_fname, 'question': row['question'] + '?', 'qid': row['qid'],
21 | 'candidates': [row['a%d' % i] + '.' for i in range(5)], 'answer': row['a%d' % row['answer']]}
22 | res_list.append(res_dict)
23 | json.dump(res_list, open(output_fname, 'w'))
24 |
25 |
26 | def convert_tvqa(input_fname, output_fname):
27 | tv_name_to_folder = {
28 | 'The Big Bang Theory': 'bbt_frames', 'Castle': 'castle_frames', 'How I Met You Mother': 'met_frames',
29 | "Grey's Anatomy": 'grey_frames', 'Friends': 'friends_frames', 'House M.D.': 'house_frames'
30 | }
31 |
32 | res_list = list()
33 | for line in open(input_fname):
34 | row = json.loads(line)
35 | start_sec, end_sec = row['ts'].split('-')
36 | res_dict = {'video': os.path.join(tv_name_to_folder[row['show_name']], row['vid_name']),
37 | 'question': row['q'], 'qid': row['qid'], 'candidates': [row['a%d' % i] for i in range(5)]}
38 |
39 | if not np.isnan(float(start_sec)) and not np.isnan(float(end_sec)):
40 | res_dict['start'], res_dict['end'] = float(start_sec), float(end_sec)
41 |
42 | if 'answer_idx' in row:
43 | res_dict['answer'] = row['a%d' % row['answer_idx']]
44 | res_list.append(res_dict)
45 | json.dump(res_list, open(output_fname, 'w'))
46 |
47 |
48 | def convert_star(input_fname, output_fname):
49 | res_list = list()
50 | for row in json.load(open(input_fname)):
51 | res_dict = {'video': row['video_id'] + '.mp4',
52 | 'question': row['question'], 'qid': row['question_id'],
53 | 'candidates': [c['choice'] for c in row['choices']]}
54 | if 'answer' in row:
55 | res_dict['answer'] = row['answer']
56 | res_list.append(res_dict)
57 | json.dump(res_list, open(output_fname, 'w'))
58 |
59 |
60 | def convert_star_output(input_fname, output_fname, src_fname):
61 | '''
62 | convert videochat2 output on star dataset to the submission format
63 | '''
64 | res_dict = {key: [] for key in ['Interaction', 'Sequence', 'Prediction', 'Feasibility']}
65 | src_data = json.load(open(src_fname))
66 | pred_data = [json.loads(line) for line in open(input_fname)]
67 | assert len(src_data) == len(pred_data)
68 |
69 | for example, src_example in zip(pred_data, src_data):
70 | for key in res_dict.keys():
71 | if src_example['qid'].startswith(key):
72 | if example['pred'][1] in 'ABCD':
73 | res_dict[key].append({'question_id': src_example['qid'], 'answer': 'ABCD'.index(example['pred'][1])})
74 | else:
75 | print('no choice letter found!')
76 | res_dict[key].append({'question_id': src_example['qid'], 'answer': random.choice([0, 1, 2, 3])})
77 | json.dump(res_dict, open(output_fname, 'w'))
78 |
79 |
80 | def convert_tvqa_output(input_fname, output_folder, test_src_fname, val_src_fname):
81 | res_dict = dict()
82 | src_data = [json.loads(line) for line in open(test_src_fname)]
83 | pred_data = [json.loads(line) for line in open(input_fname)]
84 | assert len(src_data) == len(pred_data)
85 |
86 | os.makedirs(output_folder, exist_ok=True)
87 | for example, src_example in zip(pred_data, src_data):
88 | pred_id = 'ABCDE'.index(example['pred'][1]) if example['pred'][1] in 'ABCDE' else random.randint(0, 4)
89 | res_dict[src_example['qid']] = pred_id
90 | json.dump(res_dict, open(os.path.join(output_folder, 'prediction_test_public.json'), 'w'))
91 |
92 | # generate a random result for the val set
93 | src_data = [json.loads(line) for line in open(val_src_fname)]
94 | res_dict = dict()
95 | for src_example in src_data:
96 | pred_id = random.randint(0, 4)
97 | res_dict[src_example['qid']] = pred_id
98 | json.dump(res_dict, open(os.path.join(output_folder, 'prediction_val.json'), 'w'))
99 |
100 |     # generate the metadata
101 | metadata = {
102 | 'model_name': '/'.join(input_fname.split('/')[-2:]), 'is_ensemble': False, 'with_ts': True, 'show_on_leaderboard': False,
103 | 'author': 'null', 'institution': 'null', 'description': 'null', 'paper_link': 'null', 'code_link': 'null'
104 | }
105 | json.dump(metadata, open(os.path.join(output_folder, 'meta.json'), 'w'))
106 |
107 |
108 |
109 | if __name__ == '__main__':
110 | parser = argparse.ArgumentParser()
111 | parser.add_argument('--func', type=str)
112 | parser.add_argument('--input', type=str)
113 | parser.add_argument('--input2', type=str)
114 | parser.add_argument('--input3', type=str)
115 | parser.add_argument('--output', type=str)
116 | args = parser.parse_args()
117 | print(args)
118 |
119 | if args.func == 'nextqa-test':
120 | convert_nextqa_test(args.input, args.output, args.input2)
121 |
122 | if args.func == 'tvqa':
123 | convert_tvqa(args.input, args.output)
124 |
125 | if args.func == 'star':
126 | convert_star(args.input, args.output)
127 |
128 | if args.func == 'star_output':
129 | convert_star_output(args.input, args.output, args.input2)
130 |
131 | if args.func == 'tvqa_output':
132 | convert_tvqa_output(args.input, args.output, args.input2, args.input3)
133 |
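# Example invocations (illustrative only; the paths are placeholders for your local data layout):
#   python data_preparing/videoqa.py --func star --input STAR_val.json --output data/test-anno/star.json
#   python data_preparing/videoqa.py --func nextqa-test --input nextqa_test.csv \
#       --input2 map_vid_vidorID.json --output data/test-anno/nextqa.json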
134 |
--------------------------------------------------------------------------------
/dataset/pt_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 | import sqlite3
5 | import random
6 | from os.path import basename
7 |
8 | import numpy as np
9 |
10 | from dataset.base_dataset import ImageVideoBaseDataset
11 | from dataset.utils import load_anno, pre_text
12 | from dataset.video_utils import VIDEO_READER_FUNCS
13 | from utils.distributed import is_main_process
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def get_anno_by_id(cur: sqlite3.Cursor, id: int):
19 | """TODO: Docstring for get_anno_by_id.
20 |
21 | Args:
22 | cur (sqlite3.Cursor): The dataset cursor.
23 | id (int): The annotation id.
24 |
25 | Returns:
26 |
27 | """
28 | pass
29 |
30 |
31 | class PTImgTrainDataset(ImageVideoBaseDataset):
32 | media_type = "image"
33 |
34 | def __init__(self, ann_file, transform, pre_text=True):
35 | super().__init__()
36 |
37 | if len(ann_file) == 3 and ann_file[2] == "video":
38 | self.media_type = "video"
39 | else:
40 | self.media_type = "image"
41 | self.label_file, self.data_root = ann_file[:2]
42 |
43 | logger.info('Load json file')
44 | with open(self.label_file, 'r') as f:
45 | self.anno = json.load(f)
46 | self.num_examples = len(self.anno)
47 |
48 | self.transform = transform
49 | self.pre_text = pre_text
50 | logger.info(f"Pre-process text: {pre_text}")
51 |
52 | def get_anno(self, index):
53 | filename = self.anno[index][self.media_type]
54 | caption = self.anno[index]["caption"]
55 | anno = {"image": os.path.join(self.data_root, filename), "caption": caption}
56 | return anno
57 |
58 | def __len__(self):
59 | return self.num_examples
60 |
61 | def __getitem__(self, index):
62 | try:
63 | ann = self.get_anno(index)
64 | image, index = self.load_and_transform_media_data(index, ann["image"])
65 | caption = pre_text(ann["caption"], pre_text=self.pre_text)
66 | return image, caption, index
67 | except Exception as e:
68 | logger.warning(f"Caught exception {e} when loading image {ann['image']}")
69 | index = np.random.randint(0, len(self))
70 | return self.__getitem__(index)
71 |
72 |
73 | class PTVidTrainDataset(PTImgTrainDataset):
74 | media_type = "video"
75 |
76 | def __init__(
77 | self,
78 | ann_file,
79 | transform,
80 | num_frames=4,
81 | video_reader_type="decord",
82 | sample_type="rand",
83 | num_tries=3,
84 | pre_text=True
85 | ):
86 | super().__init__(ann_file, transform, pre_text=pre_text)
87 | self.num_frames = num_frames
88 | self.video_reader_type = video_reader_type
89 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
90 | self.sample_type = sample_type
91 | self.num_tries = num_tries
92 |
93 |
94 | class PTImgEvalDataset(ImageVideoBaseDataset):
95 | media_type = "image"
96 |
97 | def __init__(self, ann_file, transform, has_multi_vision_gt=False):
98 | super(PTImgEvalDataset, self).__init__()
99 | self.raw_anno_list = load_anno(ann_file)
100 | self.transform = transform
101 |         self.has_multi_vision_gt = has_multi_vision_gt  # each caption may have multiple images as ground truth
102 |
103 | self.text = None
104 | self.image = None
105 | self.txt2img = None
106 | self.img2txt = None
107 | self.build_data()
108 |
109 | def build_data(self):
110 | self.text = []
111 | self.image = []
112 | self.txt2img = {}
113 | self.img2txt = {}
114 | if self.has_multi_vision_gt:
115 | self.build_data_multi_img_gt()
116 | else:
117 | self.build_data_multi_txt_gt()
118 | self.anno_list = [dict(image=e) for e in self.image]
119 |
120 | def build_data_multi_img_gt(self):
121 | """each text may have multiple ground_truth image, e.g., ssv2"""
122 | img_id = 0
123 | for txt_id, ann in enumerate(self.raw_anno_list):
124 | self.text.append(pre_text(ann["caption"]))
125 | self.txt2img[txt_id] = []
126 | _images = ann["image"] \
127 | if isinstance(ann["image"], list) else [ann["image"], ]
128 | for i, image in enumerate(_images):
129 | self.image.append(image)
130 | self.txt2img[txt_id].append(img_id)
131 | self.img2txt[img_id] = txt_id
132 | img_id += 1
133 |
134 | def build_data_multi_txt_gt(self):
135 | """each image may have multiple ground_truth text, e.g., COCO and Flickr30K"""
136 | txt_id = 0
137 | for img_id, ann in enumerate(self.raw_anno_list):
138 | self.image.append(ann["image"])
139 | self.img2txt[img_id] = []
140 | _captions = ann["caption"] \
141 | if isinstance(ann["caption"], list) else [ann["caption"], ]
142 | for i, caption in enumerate(_captions):
143 | self.text.append(pre_text(caption))
144 | self.img2txt[img_id].append(txt_id)
145 | self.txt2img[txt_id] = img_id
146 | txt_id += 1
147 |
148 | def __len__(self):
149 | return len(self.anno_list)
150 |
151 | def __getitem__(self, index):
152 | ann = self.anno_list[index]
153 | image, index = self.load_and_transform_media_data(index, ann["image"])
154 | return image, index
155 |
156 |
157 | def preprocess_para_retrieval_data(anno_list):
158 | processed_anno_list = []
159 | for d in anno_list:
160 | d["caption"] = " ".join(d.pop("caption"))
161 | processed_anno_list.append(d)
162 | return processed_anno_list
163 |
164 |
165 | class PTVidEvalDataset(PTImgEvalDataset):
166 | media_type = "video"
167 |
168 | def __init__(
169 | self, ann_file, transform, num_frames=4,
170 | video_reader_type="decord", sample_type="rand", num_tries=1,
171 | is_paragraph_retrieval=False, has_multi_vision_gt=False,
172 | ):
173 | super(PTVidEvalDataset, self).__init__(ann_file, transform, has_multi_vision_gt)
174 | self.num_frames = num_frames
175 | self.video_reader_type = video_reader_type
176 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
177 | self.sample_type = sample_type
178 | self.num_tries = num_tries
179 | self.is_paragraph_retrieval = is_paragraph_retrieval
180 |
181 | if is_paragraph_retrieval:
182 | self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list)
183 | self.build_data()
184 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # HawkEye: Training Video-Text LLMs for Grounding Text in Videos
3 |
4 | [**[Paper]**](https://arxiv.org/abs/2403.10228)
5 | [**[Checkpoint]**](https://huggingface.co/wangyueqian/HawkEye)
6 | [**[Dataset]**](internvid_g/README.md)
7 |
8 |
9 | ## Updates
10 | - 2024/04/29: Updated the model loading process and merged the trained params of VideoChat2 into `hawkeye.pth`. Now only the checkpoints of vicuna-7b-v0 and `hawkeye.pth` are needed to load HawkEye.
11 |
12 | ## Introduction
13 | 
14 | Video-text Large Language Models (video-text LLMs) have shown remarkable performance in answering questions and holding conversations on simple videos.
15 | However, they perform almost at random when grounding text queries in long and complicated videos, showing little ability to understand and reason about temporal information, which is the most fundamental difference between videos and images.
16 |
17 | We propose HawkEye, one of the first video-text LLMs that can perform temporal video grounding in a fully text-to-text manner. To collect training data that is applicable for temporal video grounding, we construct InternVid-G, a large-scale video-text corpus with segment-level captions and negative spans, with which we introduce two new time-aware training objectives to video-text LLMs. We also propose a coarse-grained method of representing segments in videos, which is more robust and easier for LLMs to learn and follow than other alternatives.
18 |
19 | ### Datasets and Models
20 | We release the [**Model Checkpoints**](https://huggingface.co/wangyueqian/HawkEye) of HawkEye and of our implementation of VideoChat2, as well as the [**InternVid-G Dataset**](internvid_g/README.md), at 🤗HuggingFace.
21 |
22 | ## Demo
23 | 
24 | 
25 |
26 | **Live demo: in progress**
27 |
28 | You can use `demo.ipynb` to test HawkEye on your data.
29 |
30 | ## Training
31 | ### Download model checkpoints
32 |
33 | - Create a directory `model/` for model checkpoints: `mkdir model/`
34 | - Follow the instructions [here](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat#running-usage) to prepare vicuna-7b-v0
35 | - Download the [HawkEye checkpoint](https://huggingface.co/wangyueqian/HawkEye) (see the example after the folder layout below)
36 | - (Optional) If you want to reproduce the instruction tuning process, download [umt_l16_qformer.pth](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/umt_l16_qformer.pth) and [videochat2_7b_stage2.pth](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/videochat2_7b_stage2.pth) from [VideoChat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)
37 |
38 | After downloading all model checkpoints, the `model/` folder should look like this:
39 | ```
40 | └── hawkeye.pth
41 | └── vicuna-7b-v0/
42 | └── VideoChat2/ (optional)
43 | └── umt_l16_qformer.pth
44 | └── videochat2_7b_stage2.pth
45 | ```
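
For reference, one possible way to fetch the HawkEye checkpoint is sketched below; it assumes `git-lfs` is installed and that the checkpoint is stored as `hawkeye.pth` at the root of the HuggingFace repo:
```shell
cd model/
git lfs install
git clone https://huggingface.co/wangyueqian/HawkEye hawkeye_repo
mv hawkeye_repo/hawkeye.pth .   # keep only the checkpoint, matching the layout above
cd ..
```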
46 |
47 | ### Data preparation
48 |
49 | Download from the [Dataset Homepage at 🤗HuggingFace](https://huggingface.co/datasets/wangyueqian/HawkEye-IT) and save it in the `data/HawkEye-IT/` folder. We also provide data processing code in `data_preparing/`, which you can use for reference.
50 |
51 | Note that you also need to download the videos of each dataset from their original sources, which is further explained on the dataset homepage (this may take quite a while 🙃). Use soft links to place each video folder under `data/videos/`, as sketched below.
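
For example (the source paths are placeholders for wherever you stored each downloaded dataset):
```shell
mkdir -p data/videos
ln -s /path/to/downloaded/charades_videos data/videos/charades
ln -s /path/to/downloaded/activitynet_videos data/videos/activitynet
# ... repeat for the other video datasets listed in the layout below
```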
52 |
53 | After data preparation, the `data/` folder should look like this:
54 | ```
55 | └── HawkEye-IT/
56 | └── image/ # inherited from VideoChat2-IT, but not used in training HawkEye
57 | └── video/
58 | └── temporal/
59 | └── internvid_grounding/, charades_sta_grounding/, anetc_grounding/
60 | └── instructions.json, questions.json, train.json
61 | └── internvid_caption/
62 | └── instructions.json, train.json
63 | └── caption/, classification/, conversation/, vqa/, reasoning/
64 | └── videos/
65 | └── internvid-g/, clevrer/, webvid/, activitynet/, tgif/,
66 | └── nextqa/, textvr/, youcook2/, kinetics/, ssv2/, charades/
67 | ```
68 | Note that `image/, caption/, classification/, conversation/, vqa/, reasoning/` folders of HawkEye-IT are identical to [VideoChat2-IT](https://huggingface.co/datasets/OpenGVLab/VideoChat2-IT).
69 |
70 | ### Run the instruction tuning process
71 | ```shell
72 | bash ./scripts/train/run_7b_stage3.sh OUTPUT_PATH
73 | ```
74 | The instruction-tuned HawkEye checkpoint will be saved as `OUTPUT_PATH/ckpt_${ckpt}.pth`, where `${ckpt}` is the number of training iterations.
75 |
76 | Check the script to ensure the hyperparameters fit your computing device.
77 |
78 | ### Run the finetuning process
79 | We also provide the scripts to finetune on Charades-STA and ActivityNet-Captions:
80 | ```shell
81 | # IT_CKPT: the instruction-tuned HawkEye checkpoint
82 | bash ./scripts/train/charades_sta.sh OUTPUT_PATH IT_CKPT
83 | bash ./scripts/train/anetc.sh OUTPUT_PATH IT_CKPT
84 | ```
85 | Check the script to ensure the hyperparameters fit your computing device.
86 |
87 | ## Testing
88 | ### Data preparation
89 | 1. Download [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench) and save it in the `data/MVBench/` folder.
90 |
91 | 2. Download the annotations of the other benchmarks from [Google Drive](https://drive.google.com/file/d/1WVx_UGAnCmBIp8GgCpvtxHHenpYVur5u/view?usp=sharing) and unzip them to `data/test-anno/`. We also provide data processing code in `data_preparing/`, which you can use for reference.
92 |
93 | 3. Download the [TVQA videos](https://tvqa.cs.unc.edu/) and link them at `data/videos/tvqa`
94 |
95 | After downloading all benchmarks, the `data/` folder should look like this:
96 | ```
97 | └── HawkEye-IT/ # instruct tuning datasets
98 | └── MVBench/
99 | └── test-anno/
100 | └── charades_sta-recursive_grounding.json, anetc-recursive_grounding.json
101 | └── nextgqa-recursive_grounding.json
102 | └── nextqa-test.json, tvqa-test.json, star-test.json
103 | └── videos/
104 | └── nextqa/, tvqa/, charades/, activitynet/, ...
105 | ```
106 |
107 | #### Test on video qa benchmarks
108 | ```
109 | bash ./scripts/test/videoqa.sh
110 | ```
111 | Refer to `data_preparing/videoqa.py` to convert the model outputs to the format required by [STAR evaluation](https://eval.ai/web/challenges/challenge-page/1325/overview) and [TVQA evaluation w/ ts](https://codalab.lisn.upsaclay.fr/competitions/6978); a sketch of a possible invocation is shown below.
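
For example (the file paths are placeholders, and what `--input2`/`--input3` expect should be checked against the corresponding functions in `data_preparing/videoqa.py`):
```shell
python data_preparing/videoqa.py --func star_output \
    --input outputs/star_model_outputs.json \
    --input2 data/test-anno/star-test.json \
    --output outputs/star_submission.json
```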
112 |
113 | #### Test on temporal video grounding benchmarks with recursive grounding
114 | ```
115 | bash ./scripts/test/recursive_grounding.sh
116 | ```
117 | To analyze the results of each recursive grounding step, refer to `data_preparing/check_grounding_results.ipynb`.
118 |
119 | ## Citation
120 | If you find this code useful in your research, please consider citing:
121 | ```bibtex
122 | @misc{wang2024hawkeye,
123 | title={HawkEye: Training Video-Text LLMs for Grounding Text in Videos},
124 | author={Yueqian Wang and Xiaojun Meng and Jianxin Liang and Yuxuan Wang and Qun Liu and Dongyan Zhao},
125 | year={2024},
126 | eprint={2403.10228},
127 | archivePrefix={arXiv},
128 | primaryClass={cs.CV}
129 | }
130 | ```
131 |
132 | ## Acknowledgments
133 | This project is based on [VideoChat and VideoChat2](https://github.com/OpenGVLab/Ask-Anything). Thanks for their great work!
134 |
--------------------------------------------------------------------------------
/dataset/video_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
3 | """
4 | import random
5 | import os
6 | import io
7 | import av
8 | import cv2
9 | import decord
10 | import imageio
11 | from decord import VideoReader
12 | import torch
13 | import numpy as np
14 | import math
15 | decord.bridge.set_bridge("torch")
16 |
17 | import logging
18 | logger = logging.getLogger(__name__)
19 |
20 | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
21 | """
22 | Converts a present time with the given time base and start_pts offset to seconds.
23 |
24 | Returns:
25 | time_in_seconds (float): The corresponding time in seconds.
26 |
27 | https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
28 | """
29 | if pts == math.inf:
30 | return math.inf
31 |
32 | return int(pts - start_pts) * time_base
33 |
34 |
35 | def get_pyav_video_duration(video_reader):
36 | video_stream = video_reader.streams.video[0]
37 | video_duration = pts_to_secs(
38 | video_stream.duration,
39 | video_stream.time_base,
40 | video_stream.start_time
41 | )
42 | return float(video_duration)
43 |
44 |
45 | def get_frame_indices_by_fps():
46 | pass
47 |
48 |
49 | def get_frame_indices(num_frames, vlen, start_pos=0, end_pos=1, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
50 | assert 0 <= start_pos <= 1, "start_pos must be in [0, 1]"
51 | assert 0 <= end_pos <= 1, "end_pos must be in [0, 1]"
52 |
53 | start_frame, end_frame = round(start_pos * vlen), round(end_pos * vlen)
54 |
55 | if sample in ["rand", "middle"]: # uniform sampling
56 | acc_samples = min(num_frames, vlen)
57 | # split the video into `acc_samples` intervals, and sample from each interval.
58 | intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
59 | ranges = []
60 | for idx, interv in enumerate(intervals[:-1]):
61 | ranges.append((interv, intervals[idx + 1] - 1))
62 | if sample == 'rand':
63 | try:
64 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
65 | except:
66 | frame_indices = np.random.permutation(vlen)[:acc_samples]
67 | frame_indices.sort()
68 | frame_indices = list(frame_indices)
69 | elif fix_start is not None:
70 | frame_indices = [x[0] + fix_start for x in ranges]
71 | elif sample == 'middle':
72 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
73 | else:
74 | raise NotImplementedError
75 |
76 | if len(frame_indices) < num_frames: # padded with last frame
77 | padded_frame_indices = [frame_indices[-1]] * num_frames
78 | padded_frame_indices[:len(frame_indices)] = frame_indices
79 | frame_indices = padded_frame_indices
80 |
81 | elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
82 | output_fps = float(sample[3:])
83 | duration = float(end_frame - start_frame) / input_fps
84 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
85 | frame_seconds = np.arange(start_frame + delta / 2, end_frame + delta / 2, delta)
86 | frame_indices = np.around(frame_seconds * input_fps).astype(int)
87 | frame_indices = [e for e in frame_indices if start_frame < e < end_frame]
88 | if max_num_frames > 0 and len(frame_indices) > max_num_frames:
89 | frame_indices = frame_indices[:max_num_frames]
90 | # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
91 | else:
92 | raise ValueError
93 | return frame_indices
94 |
95 |
96 | def read_frames_av(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, clip=None, client=None, fps=None):
97 | reader = av.open(video_path)
98 | frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
99 | vlen = len(frames)
100 | duration_sec = get_pyav_video_duration(reader)
101 | fps = vlen / float(duration_sec)
102 |
103 | if clip is None:
104 | start_pos, end_pos = 0, 1
105 | else:
106 | video_start_sec, video_end_sec = clip
107 |         start_pos, end_pos = video_start_sec / duration_sec, video_end_sec / duration_sec  # normalize by the duration in seconds, as in the other readers
108 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1))
109 |
110 | frame_indices = get_frame_indices(
111 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start,
112 | input_fps=fps, max_num_frames=max_num_frames
113 | )
114 | frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8
115 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
116 |
117 | start_frame_index = round(start_pos * vlen)
118 | frame_indices = [i - start_frame_index for i in frame_indices]
119 | return frames, frame_indices, fps
120 |
121 |
122 | def read_frames_gif(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, clip=None, client=None, fps=None):
123 | gif = imageio.get_reader(video_path)
124 | vlen = len(gif)
125 | duration_sec = vlen / fps
126 |
127 | if clip is None:
128 | start_pos, end_pos = 0, 1
129 | else:
130 | video_start_sec, video_end_sec = clip
131 | start_pos, end_pos = video_start_sec / duration_sec, video_end_sec / duration_sec
132 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1))
133 |
134 | frame_indices = get_frame_indices(
135 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start,
136 | max_num_frames=max_num_frames
137 | )
138 | frames = []
139 | for index, frame in enumerate(gif):
140 | # for index in frame_idxs:
141 | if index in frame_indices:
142 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
143 | frame = torch.from_numpy(frame).byte()
144 | # # (H x W x C) to (C x H x W)
145 | frame = frame.permute(2, 0, 1)
146 | frames.append(frame)
147 | frames = torch.stack(frames) # .float() / 255
148 |
149 | start_frame_index = round(start_pos * vlen)
150 | frame_indices = [i - start_frame_index for i in frame_indices]
151 | return frames, frame_indices, fps # for tgif
152 |
153 |
154 | def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, clip=None, client=None, fps=None):
155 | video_reader = VideoReader(video_path, num_threads=1)
156 | vlen = len(video_reader)
157 | fps = video_reader.get_avg_fps()
158 | duration_sec = vlen / fps
159 |
160 | if clip is None:
161 | start_pos, end_pos = 0, 1
162 | else:
163 | video_start_sec, video_end_sec = clip
164 | start_pos, end_pos = video_start_sec / duration_sec, video_end_sec / duration_sec
165 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1))
166 |
167 | frame_indices = get_frame_indices(
168 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start,
169 | input_fps=fps, max_num_frames=max_num_frames
170 | )
171 | frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
172 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
173 |
174 | start_frame_index = round(start_pos * vlen)
175 | frame_indices = [i - start_frame_index for i in frame_indices]
176 | return frames, frame_indices, float(fps)
177 |
178 |
179 | def read_frames_images(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, interval_sec=None, clip=None, client=None, fps=3):
180 | frame_fnames = [os.path.join(video_path, f) for f in sorted(os.listdir(video_path))]
181 | vlen = len(frame_fnames)
182 | duration_sec = vlen / fps
183 |
184 | if clip is None:
185 | start_pos, end_pos = 0, 1
186 | else:
187 | video_start_sec, video_end_sec = clip
188 | start_pos, end_pos = video_start_sec / duration_sec, video_end_sec / duration_sec
189 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1))
190 |
191 | frame_indices = get_frame_indices(
192 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start,
193 | input_fps=fps, max_num_frames=max_num_frames
194 | )
195 | selected_fnames = [frame_fnames[i] for i in frame_indices]
196 | frames = np.stack([cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB) for fname in selected_fnames])
197 | # (T x H x W x C) to (T x C x H x W)
198 | frames = frames.transpose(0, 3, 1, 2)
199 | frames = torch.from_numpy(frames).to(torch.uint8)
200 |
201 | start_frame_index = round(start_pos * vlen)
202 | frame_indices = [i - start_frame_index for i in frame_indices]
203 | return frames, frame_indices, float(fps)
204 |
205 |
206 | VIDEO_READER_FUNCS = {
207 | 'av': read_frames_av,
208 | 'decord': read_frames_decord,
209 | 'gif': read_frames_gif,
210 | 'frames': read_frames_images,
211 | }
212 |
--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | from utils.distributed import is_main_process, get_rank, get_world_size
2 | import logging
3 | import torch.distributed as dist
4 | import torch
5 | import io
6 | import os
7 | import json
8 | import re
9 | import numpy as np
10 | from os.path import join
11 | from tqdm import trange
12 | from PIL import Image
13 | from PIL import ImageFile
14 | from torchvision.transforms import PILToTensor
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 | Image.MAX_IMAGE_PIXELS = None
17 |
18 |
19 | def load_image_from_path(image_path, client):
20 | if image_path.startswith('s3') or image_path.startswith('p2'):
21 | value = client.Get(image_path)
22 | img_bytes = np.frombuffer(value, dtype=np.uint8)
23 | buff = io.BytesIO(img_bytes)
24 | image = Image.open(buff).convert('RGB')
25 | else:
26 | image = Image.open(image_path).convert('RGB') # PIL Image
27 | image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), torch.uint8
28 | return image
29 |
30 |
31 | def load_anno(ann_file_list):
32 | """[summary]
33 |
34 | Args:
35 | ann_file_list (List[List[str, str]] or List[str, str]):
36 | the latter will be automatically converted to the former.
37 | Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video'])
38 | which specifies the data type, video or image
39 |
40 | Returns:
41 | List(dict): each dict is {
42 | image: str or List[str], # image_path,
43 | caption: str or List[str] # caption text string
44 | }
45 | """
46 | if isinstance(ann_file_list[0], str):
47 | ann_file_list = [ann_file_list]
48 |
49 | ann = []
50 | for d in ann_file_list:
51 | data_root = d[1]
52 | fp = d[0]
53 | is_video = len(d) == 3 and d[2] == "video"
54 | cur_ann = json.load(open(fp, "r"))
55 | iterator = trange(len(cur_ann), desc=f"Loading {fp}") \
56 | if is_main_process() else range(len(cur_ann))
57 | for idx in iterator:
58 | key = "video" if is_video else "image"
59 | # unified to have the same key for data path
60 | if isinstance(cur_ann[idx][key], str):
61 | cur_ann[idx]["image"] = join(data_root, cur_ann[idx][key])
62 | else: # list
63 | cur_ann[idx]["image"] = [join(data_root, e) for e in cur_ann[idx][key]]
64 | ann += cur_ann
65 | return ann
66 |
67 |
68 | def pre_text(text, max_l=None, pre_text=True):
69 | if pre_text:
70 | text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
71 |         text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')  # map the <person> placeholder tag to a plain word
72 |
73 | text = re.sub(r"\s{2,}", ' ', text)
74 | text = text.rstrip('\n').strip(' ')
75 |
76 | if max_l: # truncate
77 | words = text.split(' ')
78 | if len(words) > max_l:
79 | text = ' '.join(words[:max_l])
80 | else:
81 | pass
82 | return text
83 |
84 |
85 | logger = logging.getLogger(__name__)
86 |
87 |
88 | def collect_result(result, result_dir, filename, is_json=True, is_list=True):
89 | if is_json:
90 | result_file = os.path.join(
91 | result_dir, '%s_rank%d.json' % (filename, get_rank()))
92 | final_result_file = os.path.join(result_dir, '%s.json' % filename)
93 | json.dump(result, open(result_file, 'w'))
94 | else:
95 | result_file = os.path.join(
96 | result_dir, '%s_rank%d.pth' % (filename, get_rank()))
97 | final_result_file = os.path.join(result_dir, '%s.pth' % filename)
98 | torch.save(result, result_file)
99 |
100 | dist.barrier()
101 |
102 | result = None
103 | if is_main_process():
104 | # combine results from all processes
105 | if is_list:
106 | result = []
107 | else:
108 | result = {}
109 | for rank in range(get_world_size()):
110 | if is_json:
111 | result_file = os.path.join(
112 | result_dir, '%s_rank%d.json' % (filename, rank))
113 | res = json.load(open(result_file, 'r'))
114 | else:
115 | result_file = os.path.join(
116 | result_dir, '%s_rank%d.pth' % (filename, rank))
117 | res = torch.load(result_file)
118 | if is_list:
119 | result += res
120 | else:
121 | result.update(res)
122 |
123 | return result
124 |
125 |
126 | def sync_save_result(result, result_dir, filename, is_json=True, is_list=True):
127 | """gather results from multiple GPUs"""
128 | if is_json:
129 | result_file = os.path.join(
130 | result_dir, "dist_res", '%s_rank%d.json' % (filename, get_rank()))
131 | final_result_file = os.path.join(result_dir, '%s.json' % filename)
132 | os.makedirs(os.path.dirname(result_file), exist_ok=True)
133 | json.dump(result, open(result_file, 'w'))
134 | else:
135 | result_file = os.path.join(
136 | result_dir, "dist_res", '%s_rank%d.pth' % (filename, get_rank()))
137 | os.makedirs(os.path.dirname(result_file), exist_ok=True)
138 | final_result_file = os.path.join(result_dir, '%s.pth' % filename)
139 | torch.save(result, result_file)
140 |
141 | dist.barrier()
142 |
143 | if is_main_process():
144 | # combine results from all processes
145 | if is_list:
146 | result = []
147 | else:
148 | result = {}
149 | for rank in range(get_world_size()):
150 | if is_json:
151 | result_file = os.path.join(
152 | result_dir, "dist_res", '%s_rank%d.json' % (filename, rank))
153 | res = json.load(open(result_file, 'r'))
154 | else:
155 | result_file = os.path.join(
156 | result_dir, "dist_res", '%s_rank%d.pth' % (filename, rank))
157 | res = torch.load(result_file)
158 | if is_list:
159 | result += res
160 | else:
161 | result.update(res)
162 | if is_json:
163 | json.dump(result, open(final_result_file, 'w'))
164 | else:
165 | torch.save(result, final_result_file)
166 |
167 | logger.info('result file saved to %s' % final_result_file)
168 | dist.barrier()
169 | return final_result_file, result
170 |
171 |
172 | def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None):
173 | """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray)
174 | into a (n+1)-d array, only allow the first dim has variable lengths.
175 | Args:
176 | sequences: list(n-d tensor or list)
177 | dtype: np.dtype or torch.dtype
178 | device:
179 | fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length.
180 | return will be of shape [len(sequences), fixed_length, ...]
181 | Returns:
182 | padded_seqs: ((n+1)-d tensor) padded with zeros
183 | mask: (2d tensor) of the same shape as the first two dims of padded_seqs,
184 | 1 indicate valid, 0 otherwise
185 | Examples:
186 | >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
187 | >>> pad_sequences_1d(test_data_list, dtype=torch.long)
188 | >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)]
189 | >>> pad_sequences_1d(test_data_3d, dtype=torch.float)
190 | >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
191 | >>> pad_sequences_1d(test_data_list, dtype=np.float32)
192 | >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)]
193 | >>> pad_sequences_1d(test_data_3d, dtype=np.float32)
194 | """
195 | if isinstance(sequences[0], list):
196 | if "torch" in str(dtype):
197 | sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences]
198 | else:
199 | sequences = [np.asarray(s, dtype=dtype) for s in sequences]
200 |
201 | extra_dims = sequences[0].shape[1:] # the extra dims should be the same for all elements
202 | lengths = [len(seq) for seq in sequences]
203 | if fixed_length is not None:
204 | max_length = fixed_length
205 | else:
206 | max_length = max(lengths)
207 | if isinstance(sequences[0], torch.Tensor):
208 | assert "torch" in str(dtype), "dtype and input type does not match"
209 | padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device)
210 | mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device)
211 | else: # np
212 | assert "numpy" in str(dtype), "dtype and input type does not match"
213 | padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype)
214 | mask = np.zeros((len(sequences), max_length), dtype=np.float32)
215 |
216 | for idx, seq in enumerate(sequences):
217 | end = lengths[idx]
218 | padded_seqs[idx, :end] = seq
219 | mask[idx, :end] = 1
220 | return padded_seqs, mask # , lengths
221 |
222 |
223 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 | import ast
5 | import json
6 | import os
7 | import os.path as osp
8 | import re
9 | import shutil
10 | import sys
11 | import tempfile
12 | from copy import deepcopy
13 | from importlib import import_module
14 |
15 | import yaml
16 |
17 | from .easydict import EasyDict
18 |
19 | __all__ = ["Config", "pretty_text"]
20 |
21 |
22 | BASE_KEY = "_base_"
23 | # BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
24 | BASE_CONFIG = {}
25 |
26 | cfg = None
27 |
28 |
29 | class Config(object):
30 | """config"""
31 |
32 | @classmethod
33 | def pretty_text(cls, cfg: dict, indent=2) -> str:
34 | """format dict to a string
35 |
36 | Args:
37 | cfg (EasyDict): the params.
38 |
39 | Returns: The string to display.
40 |
41 | """
42 | msg = "{\n"
43 | for i, (k, v) in enumerate(cfg.items()):
44 | if isinstance(v, dict):
45 | v = cls.pretty_text(v, indent + 4)
46 | spaces = " " * indent
47 | msg += spaces + "{}: {}".format(k, v)
48 | if i == len(cfg) - 1:
49 | msg += " }"
50 | else:
51 | msg += "\n"
52 | return msg
53 |
54 | @classmethod
55 | def dump(cls, cfg, savepath=None):
56 | """dump cfg to `json` file.
57 |
58 | Args:
59 | cfg (dict): The dict to dump.
60 | savepath (str): The filepath to save the dumped dict.
61 |
62 |         Returns: None. The config is written to `savepath` as a json file.
63 |
64 | """
65 | if savepath is None:
66 | savepath = osp.join(cfg.WORKSPACE, "config.json")
67 | json.dump(cfg, open(savepath, "w"), indent=2)
68 |
69 | @classmethod
70 | def get_config(cls, default_config: dict = None):
71 | """get a `Config` instance.
72 |
73 | Args:
74 |             default_config (dict): The default config. `default_config` will be overridden
75 |                 by the config file, which in turn is overridden by the commandline opts.
76 |
77 | Returns: an EasyDict.
78 | """
79 | global cfg
80 | if cfg is not None:
81 | return cfg
82 |
83 | # define arg parser.
84 | parser = argparse.ArgumentParser()
85 | # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str)
86 | parser.add_argument(
87 | "config_file", help="the configuration file to load. support: .yaml, .json, .py"
88 | )
89 | parser.add_argument(
90 | "opts",
91 | default=None,
92 | nargs="*",
93 |             help="config overrides. List. Format: 'key1 value1 key2 value2'",
94 | )
95 | args = parser.parse_args()
96 |
97 | cfg = EasyDict(BASE_CONFIG)
98 | if osp.isfile(args.config_file):
99 | cfg_from_file = cls.from_file(args.config_file)
100 | cfg = merge_a_into_b(cfg_from_file, cfg)
101 | cfg = cls.merge_list(cfg, args.opts)
102 | cfg = eval_dict_leaf(cfg)
103 |
104 | # update some keys to make them show at the last
105 | for k in BASE_CONFIG:
106 | cfg[k] = cfg.pop(k)
107 | return cfg
108 |
109 | @classmethod
110 | def from_file(cls, filepath: str) -> EasyDict:
111 | """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`.
112 |
113 | Args:
114 | filepath (str): The config file path.
115 |
116 |         Returns: EasyDict. The parsed config.
117 |
118 | """
119 | filepath = osp.abspath(osp.expanduser(filepath))
120 | if not osp.isfile(filepath):
121 | raise IOError(f"File does not exist: {filepath}")
122 | if filepath.endswith(".py"):
123 | with tempfile.TemporaryDirectory() as temp_config_dir:
124 |
125 | shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config"))
126 | sys.path.insert(0, temp_config_dir)
127 | mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0])
128 | # mod = import_module(temp_module_name)
129 | sys.path.pop(0)
130 | cfg_dict = {
131 | name: value
132 | for name, value in mod.__dict__.items()
133 | if not name.startswith("__")
134 | }
135 | for k in list(sys.modules.keys()):
136 | if "tmp_config" in k:
137 | del sys.modules[k]
138 | elif filepath.endswith((".yml", ".yaml")):
139 | cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
140 | elif filepath.endswith(".json"):
141 | cfg_dict = json.load(open(filepath, "r"))
142 | else:
143 | raise IOError("Only py/yml/yaml/json type are supported now!")
144 |
145 | cfg_text = filepath + "\n"
146 | with open(filepath, "r") as f:
147 | cfg_text += f.read()
148 |
149 | if BASE_KEY in cfg_dict: # load configs in `BASE_KEY`
150 | cfg_dir = osp.dirname(filepath)
151 | base_filename = cfg_dict.pop(BASE_KEY)
152 | base_filename = (
153 | base_filename if isinstance(base_filename, list) else [base_filename]
154 | )
155 |
156 | cfg_dict_list = list()
157 | for f in base_filename:
158 | _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
159 | cfg_dict_list.append(_cfg_dict)
160 |
161 | base_cfg_dict = dict()
162 | for c in cfg_dict_list:
163 | if len(base_cfg_dict.keys() & c.keys()) > 0:
164 | raise KeyError("Duplicate key is not allowed among bases")
165 | base_cfg_dict.update(c)
166 |
167 | cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)
168 |
169 | return EasyDict(cfg_dict)
170 |
171 | @classmethod
172 | def merge_list(cls, cfg, opts: list):
173 | """merge commandline opts.
174 |
175 | Args:
176 | cfg: (dict): The config to be merged.
177 |             opts (list): The list to merge. Format: [key1, value1, key2, value2, ...].
178 | The keys can be nested. For example, ["a.b", v] will be considered
179 | as `dict(a=dict(b=v))`.
180 |
181 | Returns: dict.
182 |
183 | """
184 | assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
185 | for i in range(0, len(opts), 2):
186 | full_k, v = opts[i], opts[i + 1]
187 | keys = full_k.split(".")
188 | sub_d = cfg
189 | for i, k in enumerate(keys):
190 | if not hasattr(sub_d, k):
191 |                 raise ValueError(f"The key {k} does not exist in the config. Full key: {full_k}")
192 | if i != len(keys) - 1:
193 | sub_d = sub_d[k]
194 | else:
195 | sub_d[k] = v
196 | return cfg
197 |
198 |
199 | def merge_a_into_b(a, b, inplace=False):
200 | """The values in a will override values in b.
201 |
202 | Args:
203 | a (dict): source dict.
204 | b (dict): target dict.
205 |
206 | Returns: dict. recursively merge dict a into dict b.
207 |
208 | """
209 | if not inplace:
210 | b = deepcopy(b)
211 | for key in a:
212 | if key in b:
213 | if isinstance(a[key], dict) and isinstance(b[key], dict):
214 | b[key] = merge_a_into_b(a[key], b[key], inplace=True)
215 | else:
216 | b[key] = a[key]
217 | else:
218 | b[key] = a[key]
219 | return b
220 |
221 |
222 | def eval_dict_leaf(d, orig_dict=None):
223 | """eval values of dict leaf.
224 |
225 | Args:
226 | d (dict): The dict to eval.
227 |
228 | Returns: dict.
229 |
230 | """
231 | if orig_dict is None:
232 | orig_dict = d
233 | for k, v in d.items():
234 | if not isinstance(v, dict):
235 | d[k] = eval_string(v, orig_dict)
236 | else:
237 | eval_dict_leaf(v, orig_dict)
238 | return d
239 |
240 |
241 | def eval_string(string, d):
242 | """automatically evaluate string to corresponding types.
243 |
244 | For example:
245 | not a string -> return the original input
246 | '0' -> 0
247 | '0.2' -> 0.2
248 | '[0, 1, 2]' -> [0,1,2]
249 | 'eval(1+2)' -> 3
250 | 'eval(range(5))' -> [0,1,2,3,4]
251 | '${a}' -> d.a
252 |
253 |
254 |
255 | Args:
256 | string (str): The value to evaluate.
257 |         d (dict): The dict used to resolve `${key}` references.
258 |
259 | Returns: the corresponding type
260 |
261 | """
262 | if not isinstance(string, str):
263 | return string
264 | # if len(string) > 1 and string[0] == "[" and string[-1] == "]":
265 | # return eval(string)
266 | if string[0:5] == "eval(":
267 | return eval(string[5:-1])
268 |
269 | s0 = string
270 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
271 | if s1 != s0:
272 | while s1 != s0:
273 | s0 = s1
274 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
275 | return eval(s1)
276 |
277 | try:
278 | v = ast.literal_eval(string)
279 | except:
280 | v = string
281 | return v
282 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | import functools
5 | import logging
6 | import os
7 | import sys
8 | import time
9 | import wandb
10 | from typing import Any, Dict, Union
11 |
12 | import torch
13 | from .distributed import get_rank, is_main_process
14 | from termcolor import colored
15 |
16 |
17 | def log_dict_to_wandb(log_dict, step, prefix=""):
18 | """include a separator `/` at the end of `prefix`"""
19 | if not is_main_process():
20 | return
21 |
22 | log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()}
23 | wandb.log(log_dict, step)
24 |
25 |
26 | def log_dict_to_tensorboard(log_dict, step, prefix="", writer=None):
27 | for key, val in log_dict.items():
28 | writer.add_scalar(prefix + key, val, step)
29 |
30 |
31 | def setup_wandb(config):
32 | if not (config.wandb.enable and is_main_process()):
33 | return
34 |
35 | run = wandb.init(
36 | config=config,
37 | project=config.wandb.project,
38 | entity=config.wandb.entity,
39 | name=os.path.basename(config.output_dir),
40 | reinit=True
41 | )
42 | return run
43 |
44 |
45 | def setup_output_folder(save_dir: str, folder_only: bool = False):
46 | """Sets up and returns the output file where the logs will be placed
47 | based on the configuration passed. Usually "save_dir/logs/log_.txt".
48 | If env.log_dir is passed, logs will be directly saved in this folder.
49 | Args:
50 | folder_only (bool, optional): If folder should be returned and not the file.
51 | Defaults to False.
52 | Returns:
53 | str: folder or file path depending on folder_only flag
54 | """
55 | log_filename = "train_"
56 | log_filename += time.strftime("%Y_%m_%dT%H_%M_%S")
57 | log_filename += ".log"
58 |
59 | log_folder = os.path.join(save_dir, "logs")
60 |
61 | if not os.path.exists(log_folder):
62 |         os.makedirs(log_folder)
63 |
64 | if folder_only:
65 | return log_folder
66 |
67 | log_filename = os.path.join(log_folder, log_filename)
68 |
69 | return log_filename
70 |
71 |
72 | def setup_logger(
73 | output: str = None,
74 | color: bool = True,
75 | name: str = "mmf",
76 | disable: bool = False,
77 | clear_handlers=True,
78 | *args,
79 | **kwargs,
80 | ):
81 | """
82 | Initialize the MMF logger and set its verbosity level to "INFO".
83 |     Outside libraries shouldn't call this in case they have set their
84 |     own logging handlers and setup. If they do, and don't want to
85 |     clear handlers, pass the clear_handlers option.
86 | The initial version of this function was taken from D2 and adapted
87 | for MMF.
88 | Args:
89 | output (str): a file name or a directory to save log.
90 |             If it ends with ".txt" or ".log", it is assumed to be a file name;
91 |             otherwise the log is saved to "train.log" inside this directory.
92 | color (bool): If false, won't log colored logs. Default: true
93 | name (str): the root module name of this logger. Defaults to "mmf".
94 | disable: do not use
95 | clear_handlers (bool): If false, won't clear existing handlers.
96 | Returns:
97 | logging.Logger: a logger
98 | """
99 | if disable:
100 | return None
101 | logger = logging.getLogger(name)
102 | logger.propagate = False
103 |
104 | logging.captureWarnings(True)
105 | warnings_logger = logging.getLogger("py.warnings")
106 |
107 | plain_formatter = logging.Formatter(
108 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
109 | datefmt="%Y-%m-%dT%H:%M:%S",
110 | )
111 |
112 | distributed_rank = get_rank()
113 | handlers = []
114 |
115 | logging_level = logging.INFO
116 | # logging_level = logging.DEBUG
117 |
118 | if distributed_rank == 0:
119 | logger.setLevel(logging_level)
120 | ch = logging.StreamHandler(stream=sys.stdout)
121 | ch.setLevel(logging_level)
122 | if color:
123 | formatter = ColorfulFormatter(
124 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
125 | datefmt="%Y-%m-%dT%H:%M:%S",
126 | )
127 | else:
128 | formatter = plain_formatter
129 | ch.setFormatter(formatter)
130 | logger.addHandler(ch)
131 | warnings_logger.addHandler(ch)
132 | handlers.append(ch)
133 |
134 | # file logging: all workers
135 | if output is None:
136 | output = setup_output_folder()
137 |
138 | if output is not None:
139 | if output.endswith(".txt") or output.endswith(".log"):
140 | filename = output
141 | else:
142 | filename = os.path.join(output, "train.log")
143 | if distributed_rank > 0:
144 | filename = filename + f".rank{distributed_rank}"
145 | os.makedirs(os.path.dirname(filename), exist_ok=True)
146 |
147 | fh = logging.StreamHandler(_cached_log_stream(filename))
148 | fh.setLevel(logging_level)
149 | fh.setFormatter(plain_formatter)
150 | logger.addHandler(fh)
151 | warnings_logger.addHandler(fh)
152 | handlers.append(fh)
153 |
154 | # Slurm/FB output, only log the main process
155 | # save_dir = get_mmf_env(key="save_dir")
156 | if "train.log" not in filename and distributed_rank == 0:
157 | filename = os.path.join(output, "train.log")
158 | sh = logging.StreamHandler(_cached_log_stream(filename))
159 | sh.setLevel(logging_level)
160 | sh.setFormatter(plain_formatter)
161 | logger.addHandler(sh)
162 | warnings_logger.addHandler(sh)
163 | handlers.append(sh)
164 |
165 | logger.info(f"Logging to: {filename}")
166 |
167 | # Remove existing handlers to add MMF specific handlers
168 | if clear_handlers:
169 | for handler in logging.root.handlers[:]:
170 | logging.root.removeHandler(handler)
171 | # Now, add our handlers.
172 | logging.basicConfig(level=logging_level, handlers=handlers)
173 |
174 | return logger
175 |
176 |
177 | def setup_very_basic_config(color=True):
178 | plain_formatter = logging.Formatter(
179 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
180 | datefmt="%Y-%m-%dT%H:%M:%S",
181 | )
182 | ch = logging.StreamHandler(stream=sys.stdout)
183 | ch.setLevel(logging.INFO)
184 | if color:
185 | formatter = ColorfulFormatter(
186 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
187 | datefmt="%Y-%m-%dT%H:%M:%S",
188 | )
189 | else:
190 | formatter = plain_formatter
191 | ch.setFormatter(formatter)
192 | # Setup a minimal configuration for logging in case something tries to
193 | # log a message even before logging is setup by MMF.
194 | logging.basicConfig(level=logging.INFO, handlers=[ch])
195 |
196 |
197 | # cache the opened file object, so that different calls to `setup_logger`
198 | # with the same file name can safely write to the same file.
199 | @functools.lru_cache(maxsize=None)
200 | def _cached_log_stream(filename):
201 | return open(filename, "a")
202 |
203 |
204 | # ColorfulFormatter is adopted from Detectron2 and adapted for MMF
205 | class ColorfulFormatter(logging.Formatter):
206 | def __init__(self, *args, **kwargs):
207 | super().__init__(*args, **kwargs)
208 |
209 | def formatMessage(self, record):
210 | log = super().formatMessage(record)
211 | if record.levelno == logging.WARNING:
212 | prefix = colored("WARNING", "red", attrs=["blink"])
213 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
214 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
215 | else:
216 | return log
217 | return prefix + " " + log
218 |
219 |
220 | class TensorboardLogger:
221 | def __init__(self, log_folder="./logs", iteration=0):
222 | # This would handle warning of missing tensorboard
223 | from torch.utils.tensorboard import SummaryWriter
224 |
225 | self.summary_writer = None
226 | self._is_master = is_main_process()
227 | # self.timer = Timer()
228 | self.log_folder = log_folder
229 |
230 | if self._is_master:
231 | # current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
232 | current_time = time.strftime("%Y-%m-%dT%H:%M:%S")
233 | # self.timer.get_time_hhmmss(None, format=self.time_format)
234 | tensorboard_folder = os.path.join(
235 | self.log_folder, f"tensorboard_{current_time}"
236 | )
237 | self.summary_writer = SummaryWriter(tensorboard_folder)
238 |
239 | def __del__(self):
240 | if getattr(self, "summary_writer", None) is not None:
241 | self.summary_writer.close()
242 |
243 | def _should_log_tensorboard(self):
244 | if self.summary_writer is None or not self._is_master:
245 | return False
246 | else:
247 | return True
248 |
249 | def add_scalar(self, key, value, iteration):
250 | if not self._should_log_tensorboard():
251 | return
252 |
253 | self.summary_writer.add_scalar(key, value, iteration)
254 |
255 | def add_scalars(self, scalar_dict, iteration):
256 | if not self._should_log_tensorboard():
257 | return
258 |
259 | for key, val in scalar_dict.items():
260 | self.summary_writer.add_scalar(key, val, iteration)
261 |
262 | def add_histogram_for_model(self, model, iteration):
263 | if not self._should_log_tensorboard():
264 | return
265 |
266 | for name, param in model.named_parameters():
267 | np_param = param.clone().cpu().data.numpy()
268 | self.summary_writer.add_histogram(name, np_param, iteration)
269 |
--------------------------------------------------------------------------------
/utils/basic_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import io
3 | import os
4 | import json
5 | import logging
6 | import random
7 | import time
8 | from collections import defaultdict, deque
9 | import datetime
10 | from pathlib import Path
11 | from typing import List, Union
12 |
13 | import torch
14 | import torch.distributed as dist
15 | from .distributed import is_dist_avail_and_initialized
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class SmoothedValue(object):
22 | """Track a series of values and provide access to smoothed values over a
23 | window or the global series average.
24 | """
25 |
26 | def __init__(self, window=20, fmt=None):
27 | if fmt is None:
28 | fmt = "{median:.4f} ({global_avg:.4f})"
29 | self.deque = deque(maxlen=window)
30 | self.total = 0.0
31 | self.count = 0
32 | self.fmt = fmt
33 |
34 | def update(self, value, n=1):
35 | self.deque.append(value)
36 | self.count += n
37 | self.total += value * n
38 |
39 | def synchronize_between_processes(self):
40 | """
41 | Warning: does not synchronize the deque!
42 | """
43 | if not is_dist_avail_and_initialized():
44 | return
45 | t = torch.tensor([self.count, self.total],
46 | dtype=torch.float64, device='cuda')
47 | dist.barrier()
48 | dist.all_reduce(t)
49 | t = t.tolist()
50 | self.count = int(t[0])
51 | self.total = t[1]
52 |
53 | @property
54 | def median(self):
55 | d = torch.tensor(list(self.deque))
56 | return d.median().item()
57 |
58 | @property
59 | def avg(self):
60 | d = torch.tensor(list(self.deque), dtype=torch.float32)
61 | return d.mean().item()
62 |
63 | @property
64 | def global_avg(self):
65 | return self.total / self.count
66 |
67 | @property
68 | def max(self):
69 | return max(self.deque)
70 |
71 | @property
72 | def value(self):
73 | return self.deque[-1]
74 |
75 | def __str__(self):
76 | return self.fmt.format(
77 | median=self.median,
78 | avg=self.avg,
79 | global_avg=self.global_avg,
80 | max=self.max,
81 | value=self.value)
82 |
83 |
84 | class MetricLogger(object):
85 | def __init__(self, delimiter="\t"):
86 | self.meters = defaultdict(SmoothedValue)
87 | self.delimiter = delimiter
88 |
89 | def update(self, **kwargs):
90 | for k, v in kwargs.items():
91 | if isinstance(v, torch.Tensor):
92 | v = v.item()
93 | assert isinstance(v, (float, int))
94 | self.meters[k].update(v)
95 |
96 | def __getattr__(self, attr):
97 | if attr in self.meters:
98 | return self.meters[attr]
99 | if attr in self.__dict__:
100 | return self.__dict__[attr]
101 | raise AttributeError("'{}' object has no attribute '{}'".format(
102 | type(self).__name__, attr))
103 |
104 | def __str__(self):
105 | loss_str = []
106 | for name, meter in self.meters.items():
107 | if meter.count == 0: # skip empty meter
108 | loss_str.append(
109 | "{}: {}".format(name, "No data")
110 | )
111 | else:
112 | loss_str.append(
113 | "{}: {}".format(name, str(meter))
114 | )
115 | return self.delimiter.join(loss_str)
116 |
117 | def global_avg(self):
118 | loss_str = []
119 | for name, meter in self.meters.items():
120 | if meter.count == 0:
121 | loss_str.append(
122 | "{}: {}".format(name, "No data")
123 | )
124 | else:
125 | loss_str.append(
126 | "{}: {:.4f}".format(name, meter.global_avg)
127 | )
128 | return self.delimiter.join(loss_str)
129 |
130 | def get_global_avg_dict(self, prefix=""):
131 | """include a separator (e.g., `/`, or "_") at the end of `prefix`"""
132 | d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()}
133 | return d
134 |
135 | def get_avg_dict(self, prefix=""):
136 | """same as above, but only get the average value in the window, not the epoch"""
137 | d = {f"{prefix}{k}": m.avg if m.count > 0 else 0. for k, m in self.meters.items()}
138 | return d
139 |
140 | def synchronize_between_processes(self):
141 | for meter in self.meters.values():
142 | meter.synchronize_between_processes()
143 |
144 | def add_meter(self, name, meter):
145 | self.meters[name] = meter
146 |
147 | def log_every(self, iterable, log_freq, header=None):
148 | i = 0
149 | if not header:
150 | header = ''
151 | start_time = time.time()
152 | end = time.time()
153 | iter_time = SmoothedValue(fmt='{avg:.4f}')
154 | data_time = SmoothedValue(fmt='{avg:.4f}')
155 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
156 | log_msg = [
157 | header,
158 | '[{0' + space_fmt + '}/{1}]',
159 | 'eta: {eta}',
160 | '{meters}',
161 | 'time: {time}',
162 | 'data: {data}'
163 | ]
164 | if torch.cuda.is_available():
165 | log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}')
166 | log_msg = self.delimiter.join(log_msg)
167 | MB = 1024.0 * 1024.0
168 | for obj in iterable:
169 | data_time.update(time.time() - end)
170 | yield obj
171 | iter_time.update(time.time() - end)
172 | if i % log_freq == 0 or i == len(iterable) - 1:
173 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
174 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
175 | if torch.cuda.is_available():
176 | logger.info(log_msg.format(
177 | i, len(iterable), eta=eta_string,
178 | meters=str(self),
179 | time=str(iter_time), data=str(data_time),
180 | memory=torch.cuda.max_memory_allocated() / MB,
181 | res_mem=torch.cuda.max_memory_reserved() / MB,
182 | ))
183 | else:
184 | logger.info(log_msg.format(
185 | i, len(iterable), eta=eta_string,
186 | meters=str(self),
187 | time=str(iter_time), data=str(data_time)))
188 | i += 1
189 | end = time.time()
190 | total_time = time.time() - start_time
191 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
192 | logger.info('{} Total time: {} ({:.4f} s / it)'.format(
193 | header, total_time_str, total_time / len(iterable)))
194 |
195 |
196 | class AttrDict(dict):
197 | def __init__(self, *args, **kwargs):
198 | super(AttrDict, self).__init__(*args, **kwargs)
199 | self.__dict__ = self
200 |
201 |
202 | def compute_acc(logits, label, reduction='mean'):
203 | ret = (torch.argmax(logits, dim=1) == label).float()
204 | if reduction == 'none':
205 | return ret.detach()
206 | elif reduction == 'mean':
207 | return ret.mean().item()
208 |
209 |
210 | def compute_n_params(model, return_str=True):
211 | tot = 0
212 | for p in model.parameters():
213 | w = 1
214 | for x in p.shape:
215 | w *= x
216 | tot += w
217 | if return_str:
218 | if tot >= 1e6:
219 | return '{:.1f}M'.format(tot / 1e6)
220 | else:
221 | return '{:.1f}K'.format(tot / 1e3)
222 | else:
223 | return tot
224 |
225 |
226 | def setup_seed(seed):
227 | torch.manual_seed(seed)
228 | np.random.seed(seed)
229 | random.seed(seed)
230 |
231 |
232 | def remove_files_if_exist(file_paths):
233 | for fp in file_paths:
234 | if os.path.isfile(fp):
235 | os.remove(fp)
236 |
237 |
238 | def save_json(data, filename, save_pretty=False, sort_keys=False):
239 | with open(filename, "w") as f:
240 | if save_pretty:
241 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
242 | else:
243 | json.dump(data, f)
244 |
245 |
246 | def load_json(filename):
247 | with open(filename, "r") as f:
248 | return json.load(f)
249 |
250 |
251 | def flat_list_of_lists(l):
252 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
253 | return [item for sublist in l for item in sublist]
254 |
255 |
256 | def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]):
257 | """
258 | Args:
259 | root: path to the directory to start search files
260 | suffix: any str as suffix, or can match multiple such strings
261 | when input is List[str].
262 |             Example 1: suffix `.jpg` or [`.jpg`, `.png`]
263 |             Example 2: use a `*` in the suffix, e.g., `START*.jpg`.
264 | """
265 | if isinstance(suffix, str):
266 | suffix = [suffix, ]
267 | filepaths = flat_list_of_lists(
268 | [list(Path(root).rglob(f"*{e}")) for e in suffix])
269 | return filepaths
270 |
271 |
272 | def match_key_and_shape(state_dict1, state_dict2):
273 | keys1 = set(state_dict1.keys())
274 | keys2 = set(state_dict2.keys())
275 | print(f"keys1 - keys2: {keys1 - keys2}")
276 | print(f"keys2 - keys1: {keys2 - keys1}")
277 |
278 | mismatch = 0
279 | for k in list(keys1):
280 | if state_dict1[k].shape != state_dict2[k].shape:
281 | print(
282 | f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}")
283 | mismatch += 1
284 | print(f"mismatch {mismatch}")
285 |
286 |
287 | def merge_dicts(list_dicts):
288 | merged_dict = list_dicts[0].copy()
289 | for i in range(1, len(list_dicts)):
290 | merged_dict.update(list_dicts[i])
291 | return merged_dict
292 |
--------------------------------------------------------------------------------
/internvid_g/code/ground_data_construction.py:
--------------------------------------------------------------------------------
1 | # construct grounding data from caption, scene, and scene sim data
2 |
3 | import os
4 | import json
5 | import subprocess
6 | import argparse
7 |
8 | from tqdm import tqdm
9 | import numpy as np
10 |
11 |
12 | def parse_sec(sec_str):
13 | """
14 | Parse a string of the form '00:00:00.000' into a float
15 | """
16 | sec_str, ms_str = sec_str.split('.')
17 | h, m, s = sec_str.split(':')
18 | res = float(h) * 3600 + float(m) * 60 + float(s) + float(ms_str) / 1000
19 | return round(res, 3)
20 |
21 |
22 | def find_scene_id_start(value, arr):
23 |     # returns the index of the smallest element that is >= value (binary search)
24 | low, high = 0, len(arr) - 1
25 | result = 0
26 | while low <= high:
27 | mid = (low + high) // 2
28 | if arr[mid] >= value:
29 |             result = mid  # record the current position as a candidate
30 |             high = mid - 1  # keep searching the left half
31 |         else:
32 |             low = mid + 1  # keep searching the right half
33 | return result
34 |
35 |
36 | def find_scene_id_end(value, arr):
37 |     # returns the index of the largest element that is <= value (binary search)
38 | low, high = 0, len(arr) - 1
39 | result = len(arr) - 1
40 | while low <= high:
41 | mid = (low + high) // 2
42 | if arr[mid] <= value:
43 |             result = mid  # record the current position as a candidate
44 |             low = mid + 1  # keep searching the right half
45 |         else:
46 |             high = mid - 1  # keep searching the left half
47 | return result
48 |
49 |
50 | def get_video_length(video_fname):
51 | result = subprocess.run(["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", video_fname],
52 | stdout=subprocess.PIPE,
53 | stderr=subprocess.STDOUT)
54 | return float(result.stdout)
55 |
56 |
57 | def reformat_caption_data(caption_data):
58 | '''
59 | ret: {
60 | video_id: [{'start_sec': s1, 'end_sec': e1, 'captions': [c11, c12,. ...]}, ...],
61 | ...
62 | }
63 | '''
64 | ret = dict()
65 | for example in caption_data:
66 | key = example.get('YoutubeID', example.get('video_fname').split('.')[0])
67 | if key not in ret:
68 | ret[key] = list()
69 | if 'Start_timestamp' in example:
70 | example['start_sec'] = parse_sec(example['Start_timestamp'])
71 | del example['Start_timestamp']
72 | if 'End_timestamp' in example:
73 | example['end_sec'] = parse_sec(example['End_timestamp'])
74 | del example['End_timestamp']
75 | if 'Caption' in example:
76 | example['captions'] = [example['Caption']]
77 | del example['Caption']
78 | ret[key].append(example)
79 | return ret
80 |
81 |
82 | if __name__ == "__main__":
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument("--video-base-folder", type=str, default="videos")
85 | parser.add_argument("--caption-fname", type=str)
86 | parser.add_argument("--scene-fname", type=str, default="temp/scenes-merged.jsonl")
87 | parser.add_argument("--scene-sim-fname", type=str, default="temp/scenes_similarity-merged.jsonl")
88 | parser.add_argument("--caption-with-neg-interval-fname", type=str, default="temp/final_dataset.jsonl")
89 |
90 | parser.add_argument("--scene-boundary-shift", type=float, default=0.25)
91 | parser.add_argument("--span-same-threshold", type=float, default=0.8)
92 |     parser.add_argument("--min-caption-span-secs", type=float, default=0.5)  # only about 1% of the examples in caption.jsonl are shorter than 0.5 sec
93 |     parser.add_argument("--max-caption-span-secs", type=float, default=10)  # if the span is too long the caption may not be precise and the negative span can be very small, which makes a bad case for constructing grounding data
94 | parser.add_argument("--max-num-scenes", type=int, default=100)
95 | args = parser.parse_args()
96 |
97 | caption_data = [json.loads(line) for line in open(args.caption_fname)]
98 | caption_data = reformat_caption_data(caption_data)
99 | scene_data = [json.loads(line) for line in open(args.scene_fname)]
100 | video_id_to_fname = {e['video_fname'].split('.')[0]: e['video_fname'] for e in scene_data}
101 | scene_data = {e['video_fname'].split('.')[0]: e['scenes'][:args.max_num_scenes] for e in scene_data}
102 | scene_sim_data = [json.loads(line) for line in open(args.scene_sim_fname)]
103 | scene_sim_data = {e['video_fname'].split('.')[0]: e['similarity_matrix'] for e in scene_sim_data}
104 |
105 | video_ids = caption_data.keys() & scene_data.keys() & scene_sim_data.keys()
106 | print(len(video_ids))
107 |
108 | f_out = open(args.caption_with_neg_interval_fname, 'w')
109 | for i, video_id in enumerate(video_ids):
110 | video_length_sec = get_video_length(os.path.join(args.video_base_folder, video_id_to_fname[video_id]))
111 |         # find the scenes that cover the caption segment
112 | for caption_data_by_id in caption_data[video_id]:
113 | caption_start_sec, caption_end_sec = caption_data_by_id['start_sec'], caption_data_by_id['end_sec']
114 | if caption_end_sec - caption_start_sec < args.min_caption_span_secs or caption_end_sec - caption_start_sec > args.max_caption_span_secs:
115 | continue
116 | caption_start_scene = find_scene_id_start(caption_start_sec - args.scene_boundary_shift, [span[0] for span in scene_data[video_id]])
117 | caption_end_scene = find_scene_id_end(caption_end_sec + args.scene_boundary_shift, [span[1] for span in scene_data[video_id]])
118 |
119 | unsimilar_start_scene, unsimilar_end_scene = None, None
120 | unsimilar_scenes = []
121 | if caption_end_scene < caption_start_scene: # cannot find the scene of the caption segment
122 | caption_start_scene, caption_end_scene = None, None
123 | unsimilar_start_sec, unsimilar_end_sec = 0, video_length_sec
124 |
125 | else:
126 |                 # check whether neighbouring scenes are also similar; if so, merge them into the positive span
127 | new_caption_start_scene, new_caption_end_scene = None, None
128 |                 similarity_matrix = np.array(scene_sim_data[video_id])
129 | for idx in range(caption_start_scene - 1, -1, -1):
130 |                     if similarity_matrix[idx, caption_start_scene] > args.span_same_threshold:
131 | new_caption_start_scene = idx
132 | else:
133 | break
134 | if new_caption_start_scene is not None:
135 | caption_start_scene = new_caption_start_scene
136 | caption_start_sec = scene_data[video_id][caption_start_scene][0]
137 |
138 |                 for idx in range(caption_end_scene + 1, len(similarity_matrix)):
139 |                     if similarity_matrix[idx, caption_end_scene] > args.span_same_threshold:
140 | new_caption_end_scene = idx
141 | else:
142 | break
143 | if new_caption_end_scene is not None:
144 | caption_end_scene = new_caption_end_scene
145 | caption_end_sec = scene_data[video_id][caption_end_scene][1]
146 |
147 | # find the unsimilar scenes with the caption scenes
148 |                 scene_sims = np.max(similarity_matrix[caption_start_scene: caption_end_scene + 1], axis=0)
149 | unsimilar_scenes = scene_sims < args.span_same_threshold
150 |
151 | if caption_start_scene == 0:
152 | unsimilar_start_sec = 0
153 | elif np.all(unsimilar_scenes[:caption_start_scene]): # no similar segment before this segment
154 | unsimilar_start_sec = 0
155 |                 elif not np.any(unsimilar_scenes[:caption_start_scene]):  # all segments before this segment are similar
156 | unsimilar_start_sec = caption_start_sec
157 | else:
158 | # get the last similar segment
159 | unsimilar_start_scene = int(caption_start_scene - np.argmin(unsimilar_scenes[caption_start_scene-1::-1]) - 1)
160 | unsimilar_start_sec = scene_data[video_id][unsimilar_start_scene][1]
161 |
162 | if caption_end_scene == len(scene_data[video_id]) - 1:
163 | unsimilar_end_sec = video_length_sec
164 | elif np.all(unsimilar_scenes[caption_end_scene+1:]): # no similar segment after this segment
165 | unsimilar_end_sec = video_length_sec
166 |                 elif not np.any(unsimilar_scenes[caption_end_scene+1:]):  # all segments after this segment are similar
167 | unsimilar_end_sec = caption_end_sec
168 | else:
169 | # get the first similar segment
170 | unsimilar_end_scene = int(np.argmin(unsimilar_scenes[caption_end_scene+1:]) + caption_end_scene + 1)
171 | unsimilar_end_sec = scene_data[video_id][unsimilar_end_scene][0]
172 | unsimilar_scenes = unsimilar_scenes.tolist()
173 |
174 | if unsimilar_end_scene == caption_end_scene and unsimilar_start_scene == caption_start_scene:
175 | continue # no neg interval found, do not use this example for grounding
176 | if caption_end_sec - caption_start_sec > args.max_caption_span_secs:
177 | continue
178 |
179 | for caption in caption_data_by_id['captions']:
180 | res_to_write = {'video': video_id_to_fname[video_id], 'duration': video_length_sec,
181 | 'start_sec': caption_start_sec, 'end_sec': caption_end_sec,
182 | 'neg_start_sec': unsimilar_start_sec, 'neg_end_sec': unsimilar_end_sec,
183 | 'caption': caption,
184 | 'start_scene': caption_start_scene, 'end_scene': caption_end_scene,
185 | 'neg_start_scene': unsimilar_start_scene, 'neg_end_scene': unsimilar_end_scene,}
186 |
187 | json.dump(res_to_write, f_out)
188 | f_out.write('\n')
189 | if i % 100 == 0:
190 | f_out.flush()
191 | f_out.close()
192 |
--------------------------------------------------------------------------------
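A toy sketch (not part of the repo) of what the two binary-search helpers above compute: find_scene_id_start returns the index of the smallest scene start >= value, and find_scene_id_end the index of the largest scene end <= value. Under that assumption they behave like bisect lookups over sorted boundaries, except at the edges, where the helpers clamp to 0 and len(arr) - 1 instead of going out of range. All timestamps below are made up for illustration.

from bisect import bisect_left, bisect_right

scene_starts = [0.0, 3.2, 7.5, 12.0]   # hypothetical scene start times (sec)
scene_ends = [3.2, 7.5, 12.0, 15.4]    # hypothetical scene end times (sec)

caption_start_sec, caption_end_sec = 4.0, 11.0

# equivalent of find_scene_id_start: first scene whose start is >= 4.0
start_scene = bisect_left(scene_starts, caption_start_sec)     # -> 2
# equivalent of find_scene_id_end: last scene whose end is <= 11.0
end_scene = bisect_right(scene_ends, caption_end_sec) - 1      # -> 1

# end_scene < start_scene here: the caption straddles a scene cut, which the
# main script treats as "cannot find the scene of the caption segment".
print(start_scene, end_scene)
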
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import ConcatDataset, DataLoader
3 | from torchvision import transforms
4 | from torchvision.transforms import InterpolationMode
5 |
6 | from dataset.dataloader import MetaLoader
7 | from dataset.pt_dataset import PTImgTrainDataset, PTVidTrainDataset, PTImgEvalDataset, PTVidEvalDataset
8 | from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset
9 |
10 |
11 | def create_dataset(dataset_type, config):
12 | if "clip" in config.model.get("vit_model", 'vit'):
13 | mean = (0.485, 0.456, 0.406)
14 | std = (0.229, 0.224, 0.225)
15 | else:
16 | vision_enc_name = config.model.vision_encoder.name
17 | if "swin" in vision_enc_name or "vit" in vision_enc_name:
18 | mean = (0.485, 0.456, 0.406)
19 | std = (0.229, 0.224, 0.225)
20 | elif "beit" in vision_enc_name:
21 | mean = (0.5, 0.5, 0.5) # for all beit model except IN1K finetuning
22 | std = (0.5, 0.5, 0.5)
23 | elif "clip" in vision_enc_name:
24 | mean = (0.48145466, 0.4578275, 0.40821073)
25 | std = (0.26862954, 0.26130258, 0.27577711)
26 | else:
27 | raise ValueError
28 |
29 | normalize = transforms.Normalize(mean, std)
30 |
31 | # loaded images and videos are torch.Tensor of torch.uint8 format,
32 | # ordered as (T, 1 or 3, H, W) where T=1 for image
33 | type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
34 |
35 | if config.inputs.video_input.random_aug:
36 | aug_transform = transforms.RandAugment()
37 | else:
38 | aug_transform = transforms.Lambda(lambda x: x)
39 |
40 | train_transform = transforms.Compose(
41 | [
42 | aug_transform,
43 | transforms.RandomResizedCrop(
44 | config.inputs.image_res,
45 | scale=(0.5, 1.0),
46 | interpolation=InterpolationMode.BICUBIC,
47 | ),
48 | transforms.RandomHorizontalFlip(),
49 | type_transform,
50 | normalize,
51 | ]
52 | )
53 | test_transform = transforms.Compose(
54 | [
55 | transforms.Resize(
56 | (config.inputs.image_res, config.inputs.image_res),
57 | interpolation=InterpolationMode.BICUBIC,
58 | ),
59 | type_transform,
60 | normalize,
61 | ]
62 | )
63 |
64 | video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
65 | video_only_dataset_kwargs_train = dict(
66 | video_reader_type=video_reader_type,
67 | sample_type=config.inputs.video_input.sample_type,
68 | num_frames=config.inputs.video_input.num_frames,
69 |         num_tries=3,  # fault tolerance
70 | )
71 | video_only_dataset_kwargs_eval = dict(
72 | video_reader_type=video_reader_type,
73 | sample_type=config.inputs.video_input.sample_type_test,
74 | num_frames=config.inputs.video_input.num_frames_test,
75 | num_tries=1, # we want to have predictions for all videos
76 | )
77 |
78 | if dataset_type == "pt_train":
79 | # convert to list of lists
80 | train_files = (
81 | [config.train_file] if isinstance(config.train_file[0], str) else config.train_file
82 | )
83 | train_media_types = sorted(list({e['media_type'] for e in train_files}))
84 |
85 | train_datasets = []
86 | for m in train_media_types:
87 | dataset_cls = PTImgTrainDataset if m == "image" else PTVidTrainDataset
88 | # dataset of the same media_type will be mixed in a single Dataset object
89 | _train_files = [e for e in train_files if e['media_type'] == m]
90 |
91 | datasets = []
92 | for train_file in _train_files:
93 | dataset_kwargs = train_file.copy()
94 | dataset_kwargs.update(dict(
95 | transform=train_transform,
96 | pre_text=config.get(
97 | "pre_text", True
98 | ),
99 | ))
100 | if m == "video":
101 | dataset_kwargs.update(video_only_dataset_kwargs_train)
102 | datasets.append(dataset_cls(**dataset_kwargs))
103 | dataset = ConcatDataset(datasets)
104 | train_datasets.append(dataset)
105 | return train_datasets
106 |
107 | elif dataset_type in ["it_train"]:
108 | # convert to list of lists
109 | train_files = (
110 | [config.train_file] if isinstance(config.train_file[0], str) else config.train_file
111 | )
112 | train_media_types = sorted(list({e['media_type'] for e in train_files}))
113 |
114 | train_datasets = []
115 | every_single_dataset = []
116 |
117 | for m in train_media_types:
118 | dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset
119 | # dataset_cls = {"image" : ITImgTrainDataset, "video": MyITVidTrainDataset, "video_truncate": MyITVidTruncateTrainDataset}[m]
120 | # dataset of the same media_type will be mixed in a single Dataset object
121 | _train_files = [e for e in train_files if e['media_type'] == m]
122 |
123 | datasets = []
124 | for train_file in _train_files:
125 | dataset_kwargs = train_file.copy()
126 | dataset_kwargs.update(dict(
127 | transform=train_transform,
128 | system=config.model.get("system", ""),
129 | start_token=config.model.get("img_start_token", ""),
130 | end_token=config.model.get("img_end_token", ""),
131 | ))
132 | if "video" in m:
133 | video_only_dataset_kwargs_train.update({
134 | "start_token": config.model.get("start_token", ""),
136 | "add_second_msg": config.model.get("add_second_msg", True),
137 | "grounding_method": train_file.get("grounding_method", None),
138 | "min_gold_clips": train_file.get("min_gold_clips", 1),
139 | "max_gold_clips": train_file.get("max_gold_clips", None),
140 | "num_examples": train_file.get("num_examples", None),
141 | })
142 | dataset_kwargs.update(video_only_dataset_kwargs_train)
143 | if "tgif" in train_file['dataset_name'].lower():
144 | video_only_dataset_kwargs_train.update({
145 | "video_reader_type": "gif"
146 | })
147 | dataset_kwargs.update(video_only_dataset_kwargs_train)
148 | else:
149 | video_only_dataset_kwargs_train.update({
150 | "video_reader_type": "decord"
151 | })
152 | dataset_kwargs.update(video_only_dataset_kwargs_train)
153 | dataset = dataset_cls(**dataset_kwargs)
154 | every_single_dataset.append(dataset)
155 | datasets.append(dataset)
156 | dataset = ConcatDataset(datasets)
157 | train_datasets.append(dataset)
158 | return train_datasets, every_single_dataset
159 |
160 | elif dataset_type == "pt_eval":
161 | test_datasets = []
162 | test_dataset_names = []
163 | # multiple test datasets, all separate
164 | for name, data_cfg in config.test_file.items():
165 | media_type = get_media_type(data_cfg)
166 | test_dataset_cls = (
167 | PTImgEvalDataset if media_type == "image" else PTVidEvalDataset
168 | )
169 | test_dataset_names.append(name)
170 | dataset_kwargs = dict(
171 | ann_file=[data_cfg],
172 | transform=test_transform,
173 | has_multi_vision_gt=config.get(
174 | "has_multi_vision_gt", False
175 | ), # true for ssv2 ret
176 | )
177 | if media_type == "video":
178 | dataset_kwargs.update(video_only_dataset_kwargs_eval)
179 | test_datasets.append(test_dataset_cls(**dataset_kwargs))
180 | return test_datasets, test_dataset_names
181 |
182 |
183 |
184 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
185 | samplers = []
186 | for dataset, shuffle in zip(datasets, shuffles):
187 | sampler = torch.utils.data.DistributedSampler(
188 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
189 | )
190 | samplers.append(sampler)
191 | return samplers
192 |
193 |
194 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
195 | loaders = []
196 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
197 | datasets, samplers, batch_size, num_workers, is_trains, collate_fns
198 | ):
199 | if is_train:
200 | shuffle = sampler is None
201 | drop_last = True
202 | else:
203 | shuffle = False
204 | drop_last = False
205 | loader = DataLoader(
206 | dataset,
207 | batch_size=bs,
208 | num_workers=n_worker,
209 | pin_memory=False,
210 | sampler=sampler,
211 | shuffle=shuffle,
212 | collate_fn=collate_fn,
213 | drop_last=drop_last,
214 | persistent_workers=True if n_worker > 0 else False,
215 | )
216 | loaders.append(loader)
217 | return loaders
218 |
219 |
220 | def iterate_dataloaders(dataloaders):
221 |     """Alternately generate data from multiple dataloaders.
222 |     Since we use `zip` to combine multiple dataloaders,
223 |     the loop will end when the smallest dataloader runs out.
224 |
225 | Args:
226 | dataloaders List(DataLoader): can be a single or multiple dataloaders
227 | """
228 | for data_tuples in zip(*dataloaders):
229 | for idx, data in enumerate(data_tuples):
230 | yield dataloaders[idx].dataset.media_type, data
231 |
--------------------------------------------------------------------------------
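A minimal sketch (assuming the repo root is on PYTHONPATH so `from dataset import iterate_dataloaders` resolves to the file above) of how iterate_dataloaders interleaves batches: each round yields one batch per loader, tagged with the dataset's media_type, and iteration stops when the shortest loader is exhausted. The TaggedDataset class and the sizes below are hypothetical stand-ins, not repo data.

import torch
from torch.utils.data import DataLoader, TensorDataset

from dataset import iterate_dataloaders

class TaggedDataset(TensorDataset):
    # tiny dataset that carries the media_type attribute iterate_dataloaders reads
    def __init__(self, media_type, n):
        super().__init__(torch.arange(n).float())
        self.media_type = media_type

img_loader = DataLoader(TaggedDataset("image", 8), batch_size=2)   # 4 batches
vid_loader = DataLoader(TaggedDataset("video", 4), batch_size=2)   # 2 batches

for media_type, batch in iterate_dataloaders([img_loader, vid_loader]):
    # prints image/video alternately; stops after 2 rounds (the shorter loader)
    print(media_type, batch[0].shape)
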
/models/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from scipy import interpolate
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def _init_transformer_weights(module, initializer_range=0.02):
13 | """Initialize the weights. Copied from transformers ViT/Bert model init"""
14 | if isinstance(module, (nn.Linear, nn.Conv2d)):
15 | # Slightly different from the TF version which uses truncated_normal for initialization
16 | # cf https://github.com/pytorch/pytorch/pull/5617
17 | module.weight.data.normal_(mean=0.0, std=initializer_range)
18 | if module.bias is not None:
19 | module.bias.data.zero_()
20 | elif isinstance(module, nn.Embedding):
21 | module.weight.data.normal_(mean=0.0, std=initializer_range)
22 | if module.padding_idx is not None:
23 | module.weight.data[module.padding_idx].zero_()
24 | elif isinstance(module, nn.LayerNorm):
25 | module.bias.data.zero_()
26 | module.weight.data.fill_(1.0)
27 |
28 |
29 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
30 | """
31 | Add/Remove extra temporal_embeddings as needed.
32 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
33 |
34 | temp_embed_old: (1, num_frames_old, 1, d)
35 | temp_embed_new: (1, num_frames_new, 1, d)
36 | add_zero: bool, if True, add zero, else, interpolate trained embeddings.
37 | """
38 | # TODO zero pad
39 | num_frms_new = temp_embed_new.shape[1]
40 | num_frms_old = temp_embed_old.shape[1]
41 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
42 | if num_frms_new > num_frms_old:
43 | if add_zero:
44 | temp_embed_new[
45 | :, :num_frms_old
46 | ] = temp_embed_old # untrained embeddings are zeros.
47 | else:
48 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
49 | elif num_frms_new < num_frms_old:
50 | temp_embed_new = temp_embed_old[:, :num_frms_new]
51 | else: # =
52 | temp_embed_new = temp_embed_old
53 | return temp_embed_new
54 |
55 |
56 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
57 | """
58 | Add/Remove extra temporal_embeddings as needed.
59 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
60 |
61 | temp_embed_old: (1, num_frames_old, 1, d)
62 | temp_embed_new: (1, num_frames_new, 1, d)
63 | add_zero: bool, if True, add zero, else, interpolate trained embeddings.
64 | """
65 | # TODO zero pad
66 | num_frms_new = temp_embed_new.shape[1]
67 | num_frms_old = temp_embed_old.shape[1]
68 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
69 | if num_frms_new > num_frms_old:
70 | if add_zero:
71 | temp_embed_new[
72 | :, :num_frms_old
73 | ] = temp_embed_old # untrained embeddings are zeros.
74 | else:
75 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
76 | elif num_frms_new < num_frms_old:
77 | temp_embed_new = temp_embed_old[:, :num_frms_new]
78 | else: # =
79 | temp_embed_new = temp_embed_old
80 | return temp_embed_new
81 |
82 |
83 | def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
84 | """
85 | temp_embed_old: (1, num_frames_old, 1, d)
86 | Returns:
87 | temp_embed_new: (1, num_frames_new, 1, d)
88 | """
89 | temp_embed_old = temp_embed_old.squeeze(2).permute(
90 | 0, 2, 1
91 | ) # (1, d, num_frames_old)
92 | temp_embed_new = F.interpolate(
93 | temp_embed_old, num_frames_new, mode="linear"
94 | ) # (1, d, num_frames_new)
95 | temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze(
96 | 2
97 | ) # (1, num_frames_new, 1, d)
98 | return temp_embed_new
99 |
100 |
101 | def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
102 | """
103 | Args:
104 | pos_embed_old: (1, L_old, d), pre-trained
105 | pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
106 | num_patches_new:
107 | """
108 | # interpolate position embedding
109 | embedding_size = pos_embed_old.shape[-1]
110 | num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
111 | # height (== width) for the checkpoint position embedding
112 | orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
113 | # height (== width) for the new position embedding
114 | new_size = int(num_patches_new ** 0.5)
115 |
116 | if orig_size != new_size:
117 | # class_token and dist_token are kept unchanged
118 |         # the extra tokens seem to always be at the beginning of the position embedding
119 | extra_tokens = pos_embed_old[:, :num_extra_tokens]
120 | # only the position tokens are interpolated
121 | pos_tokens = pos_embed_old[:, num_extra_tokens:]
122 | pos_tokens = pos_tokens.reshape(
123 | -1, orig_size, orig_size, embedding_size
124 | ).permute(0, 3, 1, 2)
125 | pos_tokens = torch.nn.functional.interpolate(
126 | pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
127 | )
128 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
129 | interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
130 | logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2")
131 | return interpolated_pos_embed
132 | else:
133 | return pos_embed_old
134 |
135 |
136 | def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
137 | """
138 | Args:
139 | state_dict_old: loaded state dict
140 | state_dict_new: state dict for model with new image size
141 | patch_shape_new: new model patch_shape
142 | ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
143 | """
144 | all_keys = list(state_dict_old.keys())
145 | for key in all_keys:
146 | if "relative_position_index" in key:
147 | state_dict_old.pop(key)
148 |
149 | if "relative_position_bias_table" in key:
150 | rel_pos_bias = state_dict_old[key]
151 | src_num_pos, num_attn_heads = rel_pos_bias.size()
152 | dst_num_pos, _ = state_dict_new[key].size()
153 | dst_patch_shape = patch_shape_new
154 | if dst_patch_shape[0] != dst_patch_shape[1]:
155 | raise NotImplementedError()
156 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
157 | dst_patch_shape[1] * 2 - 1
158 | )
159 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
160 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
161 | if src_size != dst_size:
162 | # logger.info("Position interpolate for %s from %dx%d to %dx%d" % (
163 | # key, src_size, src_size, dst_size, dst_size))
164 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
165 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
166 |
167 | def geometric_progression(a, r, n):
168 | return a * (1.0 - r ** n) / (1.0 - r)
169 |
170 | left, right = 1.01, 1.5
171 | while right - left > 1e-6:
172 | q = (left + right) / 2.0
173 | gp = geometric_progression(1, q, src_size // 2)
174 | if gp > dst_size // 2:
175 | right = q
176 | else:
177 | left = q
178 |
179 | # if q > 1.090307:
180 | # q = 1.090307
181 |
182 | dis = []
183 | cur = 1
184 | for i in range(src_size // 2):
185 | dis.append(cur)
186 | cur += q ** (i + 1)
187 |
188 | r_ids = [-_ for _ in reversed(dis)]
189 |
190 | x = r_ids + [0] + dis
191 | y = r_ids + [0] + dis
192 |
193 | t = dst_size // 2.0
194 | dx = np.arange(-t, t + 0.1, 1.0)
195 | dy = np.arange(-t, t + 0.1, 1.0)
196 |
197 | # logger.info("Original positions = %s" % str(x))
198 | # logger.info("Target positions = %s" % str(dx))
199 |
200 | all_rel_pos_bias = []
201 |
202 | for i in range(num_attn_heads):
203 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
204 | f = interpolate.interp2d(x, y, z, kind="cubic")
205 | all_rel_pos_bias.append(
206 | torch.Tensor(f(dx, dy))
207 | .contiguous()
208 | .view(-1, 1)
209 | .to(rel_pos_bias.device)
210 | )
211 |
212 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
213 |
214 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
215 | state_dict_old[key] = new_rel_pos_bias
216 | return state_dict_old
217 |
218 |
219 | def tile(x, dim, n_tile):
220 | init_dim = x.size(dim)
221 | repeat_idx = [1] * x.dim()
222 | repeat_idx[dim] = n_tile
223 | x = x.repeat(*repeat_idx)
224 | order_index = torch.LongTensor(
225 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
226 | )
227 | return torch.index_select(x, dim, order_index.to(x.device))
228 |
229 |
230 | def mask_logits(target, mask):
231 | return target * mask + (1 - mask) * (-1e10)
232 |
233 |
234 | class AllGather(torch.autograd.Function):
235 | """An autograd function that performs allgather on a tensor."""
236 |
237 | @staticmethod
238 | def forward(ctx, tensor, args):
239 | output = [torch.empty_like(tensor) for _ in range(args.world_size)]
240 | torch.distributed.all_gather(output, tensor)
241 | ctx.rank = args.rank
242 | ctx.batch_size = tensor.shape[0]
243 | return torch.cat(output, dim=0)
244 |
245 | @staticmethod
246 | def backward(ctx, grad_output):
247 | return (
248 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
249 | None,
250 | )
251 |
252 |
253 | allgather_wgrad = AllGather.apply
254 |
--------------------------------------------------------------------------------
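A small shape check, illustrative only, for the temporal-embedding helpers above; it assumes the repo root is importable and uses hypothetical sizes (4 pretrained frames grown to 8, hidden size 768). With add_zero=True the pretrained embeddings are copied into the first frames and the rest stay zero; interpolate_temporal_pos_embed instead resamples them linearly along the frame axis.

import torch
from models.utils import load_temp_embed_with_mismatch, interpolate_temporal_pos_embed

d = 768
temp_embed_old = torch.randn(1, 4, 1, d)   # pretrained, 4 frames
temp_embed_new = torch.zeros(1, 8, 1, d)   # target model, 8 frames

padded = load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True)
interp = interpolate_temporal_pos_embed(temp_embed_old, num_frames_new=8)

print(padded.shape, interp.shape)          # both torch.Size([1, 8, 1, 768])
assert torch.equal(padded[:, :4], temp_embed_old)   # first 4 frames copied, rest remain zero
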
/models/blip2/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from scipy import interpolate
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def _init_transformer_weights(module, initializer_range=0.02):
13 | """Initialize the weights. Copied from transformers ViT/Bert model init"""
14 | if isinstance(module, (nn.Linear, nn.Conv2d)):
15 | # Slightly different from the TF version which uses truncated_normal for initialization
16 | # cf https://github.com/pytorch/pytorch/pull/5617
17 | module.weight.data.normal_(mean=0.0, std=initializer_range)
18 | if module.bias is not None:
19 | module.bias.data.zero_()
20 | elif isinstance(module, nn.Embedding):
21 | module.weight.data.normal_(mean=0.0, std=initializer_range)
22 | if module.padding_idx is not None:
23 | module.weight.data[module.padding_idx].zero_()
24 | elif isinstance(module, nn.LayerNorm):
25 | module.bias.data.zero_()
26 | module.weight.data.fill_(1.0)
27 |
28 |
29 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
30 | """
31 | Add/Remove extra temporal_embeddings as needed.
32 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
33 |
34 | temp_embed_old: (1, num_frames_old, 1, d)
35 | temp_embed_new: (1, num_frames_new, 1, d)
36 | add_zero: bool, if True, add zero, else, interpolate trained embeddings.
37 | """
38 | # TODO zero pad
39 | num_frms_new = temp_embed_new.shape[1]
40 | num_frms_old = temp_embed_old.shape[1]
41 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
42 | if num_frms_new > num_frms_old:
43 | if add_zero:
44 | temp_embed_new[
45 | :, :num_frms_old
46 | ] = temp_embed_old # untrained embeddings are zeros.
47 | else:
48 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
49 | elif num_frms_new < num_frms_old:
50 | temp_embed_new = temp_embed_old[:, :num_frms_new]
51 | else: # =
52 | temp_embed_new = temp_embed_old
53 | return temp_embed_new
54 |
55 |
56 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
57 | """
58 | Add/Remove extra temporal_embeddings as needed.
59 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
60 |
61 | temp_embed_old: (1, num_frames_old, 1, d)
62 | temp_embed_new: (1, num_frames_new, 1, d)
63 | add_zero: bool, if True, add zero, else, interpolate trained embeddings.
64 | """
65 | # TODO zero pad
66 | num_frms_new = temp_embed_new.shape[1]
67 | num_frms_old = temp_embed_old.shape[1]
68 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
69 | if num_frms_new > num_frms_old:
70 | if add_zero:
71 | temp_embed_new[
72 | :, :num_frms_old
73 | ] = temp_embed_old # untrained embeddings are zeros.
74 | else:
75 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
76 | elif num_frms_new < num_frms_old:
77 | temp_embed_new = temp_embed_old[:, :num_frms_new]
78 | else: # =
79 | temp_embed_new = temp_embed_old
80 | return temp_embed_new
81 |
82 |
83 | def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
84 | """
85 | temp_embed_old: (1, num_frames_old, 1, d)
86 | Returns:
87 | temp_embed_new: (1, num_frames_new, 1, d)
88 | """
89 | temp_embed_old = temp_embed_old.squeeze(2).permute(
90 | 0, 2, 1
91 | ) # (1, d, num_frames_old)
92 | temp_embed_new = F.interpolate(
93 | temp_embed_old, num_frames_new, mode="linear"
94 | ) # (1, d, num_frames_new)
95 | temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze(
96 | 2
97 | ) # (1, num_frames_new, 1, d)
98 | return temp_embed_new
99 |
100 |
101 | def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
102 | """
103 | Args:
104 | pos_embed_old: (1, L_old, d), pre-trained
105 | pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
106 | num_patches_new:
107 | """
108 | # interpolate position embedding
109 | embedding_size = pos_embed_old.shape[-1]
110 | num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
111 | # height (== width) for the checkpoint position embedding
112 | orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
113 | # height (== width) for the new position embedding
114 | new_size = int(num_patches_new ** 0.5)
115 |
116 | if orig_size != new_size:
117 | # class_token and dist_token are kept unchanged
118 |         # the extra tokens seem to always be at the beginning of the position embedding
119 | extra_tokens = pos_embed_old[:, :num_extra_tokens]
120 | # only the position tokens are interpolated
121 | pos_tokens = pos_embed_old[:, num_extra_tokens:]
122 | pos_tokens = pos_tokens.reshape(
123 | -1, orig_size, orig_size, embedding_size
124 | ).permute(0, 3, 1, 2)
125 | pos_tokens = torch.nn.functional.interpolate(
126 | pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
127 | )
128 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
129 | interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
130 | logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2")
131 | return interpolated_pos_embed
132 | else:
133 | return pos_embed_old
134 |
135 |
136 | def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
137 | """
138 | Args:
139 | state_dict_old: loaded state dict
140 | state_dict_new: state dict for model with new image size
141 | patch_shape_new: new model patch_shape
142 | ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
143 | """
144 | all_keys = list(state_dict_old.keys())
145 | for key in all_keys:
146 | if "relative_position_index" in key:
147 | state_dict_old.pop(key)
148 |
149 | if "relative_position_bias_table" in key:
150 | rel_pos_bias = state_dict_old[key]
151 | src_num_pos, num_attn_heads = rel_pos_bias.size()
152 | dst_num_pos, _ = state_dict_new[key].size()
153 | dst_patch_shape = patch_shape_new
154 | if dst_patch_shape[0] != dst_patch_shape[1]:
155 | raise NotImplementedError()
156 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
157 | dst_patch_shape[1] * 2 - 1
158 | )
159 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
160 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
161 | if src_size != dst_size:
162 | # logger.info("Position interpolate for %s from %dx%d to %dx%d" % (
163 | # key, src_size, src_size, dst_size, dst_size))
164 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
165 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
166 |
167 | def geometric_progression(a, r, n):
168 | return a * (1.0 - r ** n) / (1.0 - r)
169 |
170 | left, right = 1.01, 1.5
171 | while right - left > 1e-6:
172 | q = (left + right) / 2.0
173 | gp = geometric_progression(1, q, src_size // 2)
174 | if gp > dst_size // 2:
175 | right = q
176 | else:
177 | left = q
178 |
179 | # if q > 1.090307:
180 | # q = 1.090307
181 |
182 | dis = []
183 | cur = 1
184 | for i in range(src_size // 2):
185 | dis.append(cur)
186 | cur += q ** (i + 1)
187 |
188 | r_ids = [-_ for _ in reversed(dis)]
189 |
190 | x = r_ids + [0] + dis
191 | y = r_ids + [0] + dis
192 |
193 | t = dst_size // 2.0
194 | dx = np.arange(-t, t + 0.1, 1.0)
195 | dy = np.arange(-t, t + 0.1, 1.0)
196 |
197 | # logger.info("Original positions = %s" % str(x))
198 | # logger.info("Target positions = %s" % str(dx))
199 |
200 | all_rel_pos_bias = []
201 |
202 | for i in range(num_attn_heads):
203 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
204 | f = interpolate.interp2d(x, y, z, kind="cubic")
205 | all_rel_pos_bias.append(
206 | torch.Tensor(f(dx, dy))
207 | .contiguous()
208 | .view(-1, 1)
209 | .to(rel_pos_bias.device)
210 | )
211 |
212 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
213 |
214 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
215 | state_dict_old[key] = new_rel_pos_bias
216 | return state_dict_old
217 |
218 |
219 | def tile(x, dim, n_tile):
220 | init_dim = x.size(dim)
221 | repeat_idx = [1] * x.dim()
222 | repeat_idx[dim] = n_tile
223 | x = x.repeat(*repeat_idx)
224 | order_index = torch.LongTensor(
225 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
226 | )
227 | return torch.index_select(x, dim, order_index.to(x.device))
228 |
229 |
230 | def mask_logits(target, mask):
231 | return target * mask + (1 - mask) * (-1e10)
232 |
233 |
234 | class AllGather(torch.autograd.Function):
235 | """An autograd function that performs allgather on a tensor."""
236 |
237 | @staticmethod
238 | def forward(ctx, tensor, args):
239 | output = [torch.empty_like(tensor) for _ in range(args.world_size)]
240 | torch.distributed.all_gather(output, tensor)
241 | ctx.rank = args.rank
242 | ctx.batch_size = tensor.shape[0]
243 | return torch.cat(output, dim=0)
244 |
245 | @staticmethod
246 | def backward(ctx, grad_output):
247 | return (
248 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
249 | None,
250 | )
251 |
252 |
253 | allgather_wgrad = AllGather.apply
254 |
--------------------------------------------------------------------------------
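An illustrative use of the two small helpers above, with toy tensors rather than repo data: tile() repeats each slice n_tile times consecutively along a dimension, and mask_logits() pushes masked-out positions to a large negative value so a following softmax gives them essentially zero weight. The import assumes the repo root is on PYTHONPATH.

import torch
from models.blip2.utils import tile, mask_logits

x = torch.tensor([[1., 2.], [3., 4.]])
print(tile(x, dim=0, n_tile=2))
# tensor([[1., 2.],
#         [1., 2.],
#         [3., 4.],
#         [3., 4.]])

logits = torch.tensor([2.0, 1.0, 0.5])
mask = torch.tensor([1.0, 1.0, 0.0])       # last position is padding
print(torch.softmax(mask_logits(logits, mask), dim=-1))   # ~zero weight on the masked slot
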
/tasks/train_it.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import datetime
4 | import logging
5 | import time
6 | from os.path import join
7 |
8 | import sys
9 | sys.path.append('.')
10 |
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import wandb
15 |
16 | from dataset import MetaLoader, create_dataset, create_loader, create_sampler
17 | from models.hawkeye_it import HawkEye_it
18 | from tasks.shared_utils import get_media_types, setup_model
19 | from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed)
20 | from utils.config_utils import setup_main
21 | from utils.distributed import get_rank, get_world_size, is_main_process
22 | from utils.logger import log_dict_to_wandb, setup_wandb, log_dict_to_tensorboard
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def train(
28 | model, model_without_ddp,
29 | train_loaders,
30 | optimizer,
31 | epoch,
32 | global_step,
33 | device,
34 | scheduler,
35 | scaler,
36 | config,
37 | tb_writer,
38 | ):
39 | model.train()
40 |
41 | metric_logger = MetricLogger(delimiter=" ")
42 | metric_logger.add_meter("lr", SmoothedValue(window=10, fmt="{value:.6f}"))
43 | loss_names = ["loss"]
44 |
45 | media_types = get_media_types(train_loaders)
46 |
47 | for name in loss_names:
48 | for m in media_types:
49 | metric_logger.add_meter(
50 | f"{m}-{name}", SmoothedValue(window=10, fmt="{value:.4f}")
51 | )
52 |
53 | header = f"Train Epoch: [{epoch}]"
54 | log_freq = config.log_freq
55 | save_freq = config.get('save_freq', None)
56 |
57 | if config.distributed:
58 | for d in train_loaders:
59 | d.sampler.set_epoch(epoch)
60 | train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))
61 |
62 | iterator = metric_logger.log_every(train_loader, log_freq, header)
63 | for i, (media_type, (image, text, instruction, _)) in enumerate(iterator):
64 | image = image.to(device, non_blocking=True)
65 |
66 | with torch.cuda.amp.autocast(enabled=config.fp16):
67 | loss_dict = model(image, text, instruction)
68 | loss = sum(loss_dict.values())
69 |
70 | optimizer.zero_grad()
71 | if not torch.isnan(loss):
72 | scaler.scale(loss).backward()
73 | if config.optimizer.max_grad_norm > 0:
74 | scaler.unscale_(optimizer)
75 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
76 | scaler.step(optimizer)
77 | scaler.update()
78 | else:
79 |             print('nan loss encountered at step', global_step)
80 | scheduler.step()
81 |
82 | # logging
83 | for name in loss_names:
84 | value = loss_dict[name]
85 | value = value if isinstance(value, float) else value.item()
86 | metric_logger.update(**{f"{media_type}-{name}": value})
87 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
88 |
89 | if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
90 | logs = metric_logger.get_avg_dict()
91 | log_dict_to_wandb(logs, step=global_step, prefix="train/")
92 |
93 | if is_main_process() and config.get('tensorboard', {'enable': False})['enable'] and global_step % log_freq == 0:
94 | logs = metric_logger.get_avg_dict()
95 | log_dict_to_tensorboard(logs, global_step, "train/", tb_writer)
96 |
97 | # save model every x steps
98 | if is_main_process() and save_freq is not None and global_step % save_freq == 0:
99 | logger.info(f"Epoch {epoch}")
100 | param_grad_dic = {
101 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
102 | }
103 | state_dict = model_without_ddp.state_dict()
104 | for k in list(state_dict.keys()):
105 | if k in param_grad_dic.keys() and not param_grad_dic[k]:
106 | # delete parameters that do not require gradient
107 | del state_dict[k]
108 | save_obj = {
109 | "model": state_dict,
110 | "optimizer": optimizer.state_dict(),
111 | "scheduler": scheduler.state_dict(),
112 | "scaler": scaler.state_dict(),
113 | "config": config,
114 | "epoch": epoch,
115 | "global_step": global_step,
116 | }
117 | if config.get("save_latest", False):
118 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
119 | else:
120 | torch.save(save_obj, join(config.output_dir, f"ckpt_{global_step}.pth"))
121 |
122 | global_step += 1
123 |
124 | # gather the stats from all processes
125 | metric_logger.synchronize_between_processes()
126 | logger.info(f"Averaged stats: {metric_logger.global_avg()}")
127 | return global_step
128 |
129 |
130 | def setup_dataloaders(config, mode="pt"):
131 | # train datasets, create a list of data loaders
132 | logger.info(f"Creating dataset for {mode}")
133 | train_datasets, every_single_dataset = create_dataset(f"{mode}_train", config)
134 |
135 | if config.get("freeze_dataset_folder", None) is not None:
136 | # save self.anno in current dataset to "freeze_dataset_folder"
137 |         # this is used to record the dataset when the training data changes
138 | os.makedirs(config.freeze_dataset_folder, exist_ok=True)
139 | for dataset in every_single_dataset:
140 | dest_path = os.path.join(config.freeze_dataset_folder, dataset.dataset_name + '.json')
141 | json.dump(dataset.anno, open(dest_path, 'w'))
142 |
143 | media_types = get_media_types(train_datasets)
144 |
145 | if config.distributed:
146 | num_tasks = get_world_size()
147 | global_rank = get_rank()
148 | samplers = create_sampler(
149 | train_datasets, [True] * len(media_types), num_tasks, global_rank
150 | )
151 | else:
152 | samplers = [None] * len(media_types)
153 |
154 | train_loaders = create_loader(
155 | train_datasets,
156 | samplers,
157 | batch_size=[config.inputs.batch_size[k] for k in media_types],
158 | num_workers=[config.num_workers] * len(media_types),
159 | is_trains=[True] * len(media_types),
160 | collate_fns=[None] * len(media_types),
161 | ) # [0]
162 |
163 | return train_loaders, media_types
164 |
165 |
166 | def main(config):
167 | if is_main_process() and config.wandb.enable:
168 | run = setup_wandb(config)
169 |
170 | if is_main_process() and config.get('tensorboard', {'enable': False})['enable']:
171 | from torch.utils.tensorboard import SummaryWriter
172 | tb_writer = SummaryWriter(log_dir=config.tensorboard.get('log_dir', os.path.join(config.output_dir, 'tb_logs')))
173 | else:
174 | tb_writer = None
175 |
176 | logger.info(f"train_file: {config.train_file}")
177 |
178 | setup_seed(config.seed + get_rank())
179 | device = torch.device(config.device)
180 |
181 | train_loaders, train_media_types = setup_dataloaders(
182 | config, mode=config.mode
183 | )
184 | num_steps_per_epoch = sum(len(d) for d in train_loaders)
185 | config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
186 | config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
187 | # set cudnn.benchmark=True only when input size is fixed
188 | # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
189 | cudnn.benchmark = len(train_media_types) == 1
190 |
191 | model_cls = eval(config.model.get('model_cls', 'HawkEye_it'))
192 | (
193 | model,
194 | model_without_ddp,
195 | optimizer,
196 | scheduler,
197 | scaler,
198 | start_epoch,
199 | global_step,
200 | ) = setup_model(
201 | config,
202 | model_cls=model_cls,
203 | find_unused_parameters=True,
204 | )
205 | if is_main_process() and config.wandb.enable:
206 | wandb.watch(model)
207 |
208 | logger.info("Start training")
209 | start_time = time.time()
210 | for epoch in range(start_epoch, config.scheduler.epochs):
211 | if not config.evaluate:
212 | global_step = train(
213 | model, model_without_ddp,
214 | train_loaders,
215 | optimizer,
216 | epoch,
217 | global_step,
218 | device,
219 | scheduler,
220 | scaler,
221 | config,
222 | tb_writer
223 | )
224 |
225 | if is_main_process() and config.get('save_freq', None) is None: # does not save every x steps, so we save every epoch
226 | logger.info(f"Epoch {epoch}")
227 | param_grad_dic = {
228 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
229 | }
230 | state_dict = model_without_ddp.state_dict()
231 | for k in list(state_dict.keys()):
232 | if k in param_grad_dic.keys() and not param_grad_dic[k]:
233 | # delete parameters that do not require gradient
234 | del state_dict[k]
235 | save_obj = {
236 | "model": state_dict,
237 | "optimizer": optimizer.state_dict(),
238 | "scheduler": scheduler.state_dict(),
239 | "scaler": scaler.state_dict(),
240 | "config": config,
241 | "epoch": epoch,
242 | "global_step": global_step,
243 | }
244 | if config.get("save_latest", False):
245 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
246 | else:
247 | torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
248 |
249 | if config.evaluate:
250 | break
251 |
252 | dist.barrier()
253 |
254 | total_time = time.time() - start_time
255 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
256 | logger.info(f"Training time {total_time_str}")
257 | logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
258 |
259 | if is_main_process() and config.wandb.enable:
260 | run.finish()
261 |
262 |
263 | if __name__ == "__main__":
264 | cfg = setup_main()
265 | main(cfg)
266 |
--------------------------------------------------------------------------------
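A standalone sketch (hypothetical two-layer model, not the repo's HawkEye_it) of the checkpoint-pruning idea used in train() and main() above: parameters whose requires_grad is False are dropped from the state_dict before saving, so checkpoints only carry the trainable weights.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():            # pretend the first layer is frozen
    p.requires_grad = False

param_grad_dic = {k: v.requires_grad for k, v in model.named_parameters()}
state_dict = model.state_dict()
for k in list(state_dict.keys()):
    if k in param_grad_dic and not param_grad_dic[k]:
        del state_dict[k]                  # keep only trainable parameters

print(sorted(state_dict.keys()))           # ['1.bias', '1.weight']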