├── models ├── __init__.py ├── bert │ ├── __init__.py │ └── builder.py ├── blip2 │ ├── __init__.py │ ├── builder.py │ ├── blip2.py │ └── utils.py └── utils.py ├── assets ├── hawk.png ├── dataset.jpg ├── example1.jpg ├── example4.jpg └── performance.jpg ├── .gitignore ├── internvid_g ├── scripts │ ├── run_1121 │ │ ├── check_videos.sh │ │ ├── blip2.sh │ │ ├── text2tag.sh │ │ ├── llama2.sh │ │ ├── ground_data_construction.sh │ │ ├── config_7b.json │ │ ├── videochat.sh │ │ └── clip_filter.sh │ ├── extract_of.sh │ ├── run_llama2.sh │ └── run.sh ├── code │ ├── download_videos.py │ └── ground_data_construction.py └── README.md ├── scripts ├── train │ ├── run_7b_stage3.sh │ ├── anetc.sh │ ├── charades_sta.sh │ ├── config_7b_stage3.py │ ├── charades_sta.py │ └── anetc.py └── test │ ├── videoqa.sh │ └── recursive_grounding.sh ├── requirements.txt ├── configs ├── config.json ├── dataset_utils.py └── instruction_data.py ├── data_preparing ├── nextgqa.py ├── internvid.py ├── charades.py ├── check_grounding_results.ipynb ├── anetc.py └── videoqa.py ├── dataset ├── dataloader.py ├── base_dataset.py ├── pt_dataset.py ├── video_utils.py ├── utils.py └── __init__.py ├── utils ├── config_utils.py ├── scheduler.py ├── easydict.py ├── distributed.py ├── optimizer.py ├── config.py ├── logger.py └── basic_utils.py ├── tasks ├── shared_utils.py └── train_it.py └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/bert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/blip2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/hawk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/hawk.png -------------------------------------------------------------------------------- /assets/dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/dataset.jpg -------------------------------------------------------------------------------- /assets/example1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/example1.jpg -------------------------------------------------------------------------------- /assets/example4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/example4.jpg -------------------------------------------------------------------------------- /assets/performance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/HawkEye/HEAD/assets/performance.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | model/* 2 | data/* 3 | outputs/* 4 | 5 | .ipynb_checkpoints/ 6 | __pycache__ 7 | 8 | *.pyc 9 | *.pth 10 | *.log 11 | *.out 12 | test.json 13 | debug.py 14 | 
debug.ipynb 15 | -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/check_videos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=internvid 3 | #SBATCH --qos=lv0b 4 | #SBATCH -p HGX 5 | #SBATCH --time=24:00:00 6 | #SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:0 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | python code/check_videos.py -------------------------------------------------------------------------------- /scripts/train/run_7b_stage3.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | NNODE=1 3 | NUM_GPUS=8 4 | 5 | OUTPUT_DIR=$1 6 | mkdir -vp $OUTPUT_DIR 7 | 8 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} --master_port=${MASTER_PORT} \ 9 | tasks/train_it.py scripts/train/config_7b_stage3.py \ 10 | output_dir ${OUTPUT_DIR} \ 11 | freeze_dataset_folder ${OUTPUT_DIR}/training_data \ 12 | > ${OUTPUT_DIR}/train.log 2>&1 & 13 | 14 | wait 15 | -------------------------------------------------------------------------------- /scripts/train/anetc.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | NNODE=1 3 | NUM_GPUS=2 4 | 5 | OUTPUT_DIR=$1 6 | PRETRAINED_PATH=$2 7 | mkdir -vp ${OUTPUT_DIR} 8 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} --master_port=${MASTER_PORT} \ 9 | tasks/train_it.py scripts/train/anetc.py \ 10 | output_dir ${OUTPUT_DIR} \ 11 | freeze_dataset_folder ${OUTPUT_DIR}/training_data \ 12 | pretrained_path ${PRETRAINED_PATH} \ 13 | > ${OUTPUT_DIR}/train.log 2>&1 & 14 | wait -------------------------------------------------------------------------------- /scripts/train/charades_sta.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | NNODE=1 3 | NUM_GPUS=2 4 | 5 | OUTPUT_DIR=$1 6 | PRETRAINED_PATH=$2 7 | mkdir -vp ${OUTPUT_DIR} 8 | torchrun --nnodes=${NNODE} --nproc_per_node=${NUM_GPUS} --master_port=${MASTER_PORT} \ 9 | tasks/train_it.py scripts/train/charades_sta.py \ 10 | output_dir ${OUTPUT_DIR} \ 11 | freeze_dataset_folder ${OUTPUT_DIR}/training_data \ 12 | pretrained_path ${PRETRAINED_PATH} \ 13 | > ${OUTPUT_DIR}/train.log 2>&1 & 14 | wait 15 | -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/blip2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=internvid 3 | #SBATCH --qos=lv0b 4 | #SBATCH -p HGX 5 | #SBATCH --time=24:00:00 6 | #SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | for i in 0 1 2 3 12 | do 13 | CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func blip2 \ 14 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 15 | --blip2-fname temp/1121/scene_captions_blip2.jsonl.${i} \ 16 | > code/nohup/blip2.log.${i} 2>&1 & 17 | done 18 | wait -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/text2tag.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=internvid 3 | #SBATCH --qos=lv0b 
4 | #SBATCH -p HGX 5 | #SBATCH --time=24:00:00 6 | #SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | for i in 0 1 2 3 12 | do 13 | CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func tag2text \ 14 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 15 | --tag2text-fname temp/1121/scene_captions_tag2text.jsonl.${i} \ 16 | > code/nohup/tag2text.log.${i} 2>&1 & 17 | done 18 | wait 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | apex==0.1 3 | arrow==1.3.0 4 | av==10.0.0 5 | datasets==2.14.6 6 | decorator==5.1.1 7 | decord==0.6.0 8 | einops==0.6.1 9 | einops-exts==0.0.4 10 | flash-attn==2.3.6 11 | gradio==3.23.0 12 | gradio_client==0.2.6 13 | h5py==3.8.0 14 | huggingface-hub==0.19.4 15 | imageio==2.27.0 16 | jupyter==1.0.0 17 | matplotlib==3.7.1 18 | notebook==7.0.6 19 | opencv-python==4.7.0.72 20 | pandas==1.5.3 21 | PyYAML==6.0 22 | scipy==1.10.1 23 | spacy==3.7.2 24 | tensorboard==2.12.3 25 | transformers==4.31.0 26 | torch==1.13.1 27 | torchvision==0.14.1+cu116 28 | tqdm==4.64.1 29 | -e git+https://github.com/facebookresearch/xformers.git@e153e4b4f5d0d821d707696029d84faed11a92bf#egg=xformers -------------------------------------------------------------------------------- /internvid_g/scripts/extract_of.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=of 3 | #SBATCH --qos=lv4 4 | #SBATCH --nodes=1 5 | #SBATCH --time=10:00:00 6 | #SBATCH --cpus-per-gpu=4 7 | #SBATCH --gres=gpu:1 8 | #SBATCH --output=./code/nohup/webdvd_of.out 9 | #SBATCH --error=./code/nohup/webvid_of.out 10 | 11 | # extract optical flow 12 | 13 | cd /home/wangyuxuan1/codes/video_features 14 | 15 | python main.py \ 16 | feature_type=raft \ 17 | device="cuda:0" \ 18 | on_extraction=save_numpy \ 19 | batch_size=4 \ 20 | side_size=224 \ 21 | file_with_video_paths=/scratch2/nlp/wangyueqian/InternVid/code/scripts/webvid_of_test_path.txt \ 22 | output_path=/scratch2/nlp/wangyueqian/InternVid/code/webvid_of \ 23 | # extraction_fps=2 -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/llama2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=internvid 3 | #SBATCH --qos=lv0b 4 | #SBATCH -p HGX 5 | #SBATCH --time=24:00:00 6 | #SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | for i in 0 1 2 3 12 | do 13 | CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func llama2 \ 14 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 15 | --blip2-fname temp/1121/scene_captions_blip2.jsonl.${i} \ 16 | --tag2text-fname temp/1121/scene_captions_tag2text.jsonl.${i} \ 17 | --llama2-fname temp/1121/scene_captions_llama2.jsonl.${i} \ 18 | > code/nohup/llama2.log.${i} 2>&1 & 19 | done 20 | wait -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/ground_data_construction.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=internvid 3 | #SBATCH --qos=lv4 4 | #SBATCH -p HGX 5 | 
#SBATCH --time=10:00:00 6 | #SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:0 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | for i in 0 1 2 3 12 | do 13 | python code/ground_data_construction.py \ 14 | --video-base-folder videos-lowres \ 15 | --caption-fname temp/1121/scene_captions_tag2text_high_sim.jsonl.$i \ 16 | --scene-fname temp/1121/scenes_merged.jsonl.0$i \ 17 | --scene-sim-fname temp/1121/scenes_merged_similarity.jsonl \ 18 | --caption-with-neg-interval-fname temp/1121/scene_captions_tag2text_high_sim-with_neg.jsonl.$i & 19 | done 20 | wait -------------------------------------------------------------------------------- /scripts/test/videoqa.sh: -------------------------------------------------------------------------------- 1 | # test on mvbench 2 | python tasks/test.py --tasks "MVBench" --config configs/config.json \ 3 | --ckpt model/hawkeye.pth --data-dir data \ 4 | --save-path outputs/MVBench.jsonl \ 5 | > outputs/MVBench.log 2>&1 & 6 | wait 7 | 8 | # test on NExT-QA 9 | python tasks/test.py --tasks "NExTQA" --config configs/config.json \ 10 | --ckpt model/hawkeye.pth --data-dir data \ 11 | --save-path outputs/NExTQA.jsonl \ 12 | > outputs/NExTQA.log 2>&1 & 13 | wait 14 | 15 | 16 | # test on TVQA 17 | python tasks/test.py --tasks "TVQA" --config configs/config.json \ 18 | --ckpt model/hawkeye.pth --data-dir data \ 19 | --save-path outputs/TVQA.jsonl \ 20 | > outputs/TVQA.log 2>&1 & 21 | wait 22 | 23 | 24 | # test on STAR 25 | python tasks/test.py --tasks "STAR" --config configs/config.json \ 26 | --ckpt model/hawkeye.pth --data-dir data \ 27 | --save-path outputs/STAR.jsonl \ 28 | > outputs/STAR.log 2>&1 & 29 | wait -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/config_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "vit_model": "eva_clip_g", 4 | "vit_model_path": "code/video_chat/model/eva_vit_g.pth", 5 | "q_former_model_path": "code/video_chat/model/blip2_pretrained_flant5xxl.pth", 6 | "llama_model_path": "code/video_chat/model/vicuna-7b-v0", 7 | "videochat_model_path": "code/video_chat/model/VideoChat/videochat_7b_stage1.pth", 8 | "img_size": 224, 9 | "num_query_token": 32, 10 | "drop_path_rate": 0.0, 11 | "use_grad_checkpoint": false, 12 | "vit_precision": "fp32", 13 | "freeze_vit": true, 14 | "freeze_mhra": false, 15 | "freeze_qformer": true, 16 | "low_resource": false, 17 | "max_txt_len": 320, 18 | "temporal_downsample": false, 19 | "no_lmhra": true, 20 | "double_lmhra": false, 21 | "lmhra_reduction": 2.0, 22 | "gmhra_layers": 8, 23 | "gmhra_drop_path_rate": 0.0, 24 | "gmhra_dropout": 0.5, 25 | "extra_num_query_token": 64 26 | }, 27 | "device": "cuda" 28 | } 29 | -------------------------------------------------------------------------------- /scripts/test/recursive_grounding.sh: -------------------------------------------------------------------------------- 1 | max_turns=4 2 | 3 | # test on charades-sta 4 | python tasks/test_recursive_grounding.py --max-turns ${max_turns} --config configs/config.json \ 5 | --ckpt model/hawkeye.pth \ 6 | --video-path data/videos/charades \ 7 | --data-path data/test-anno/charades_sta-recursive_grounding.json \ 8 | --save-path outputs/charades_sta-recursive_grounding-${max_turns}_turns.jsonl \ 9 | > outputs/charades_sta-recursive_grounding-${max_turns}_turns.log 2>&1 & 10 | wait 11 | 12 | # test on anet-captions 13 | python 
tasks/test_recursive_grounding.py --max-turns ${max_turns} --config configs/config.json \ 14 | --ckpt model/hawkeye.pth \ 15 | --video-path data/videos/activitynet \ 16 | --data-path data/test-anno/anetc-recursive_grounding.json \ 17 | --save-path outputs/anetc-recursive_grounding-${max_turns}_turns.jsonl \ 18 | > outputs/anetc-recursive_grounding-${max_turns}_turns.log 2>&1 & 19 | wait 20 | 21 | # test on nextgqa 22 | python tasks/test_recursive_grounding.py --max-turns ${max_turns} --config configs/config.json \ 23 | --ckpt model/hawkeye.pth \ 24 | --video-path data/videos/nextqa \ 25 | --data-path data/test-anno/nextgqa-recursive_grounding.json \ 26 | --save-path outputs/nextgqa-recursive_grounding-${max_turns}_turns.jsonl \ 27 | > outputs/nextgqa-recursive_grounding-${max_turns}_turns.log 2>&1 & 28 | wait 29 | -------------------------------------------------------------------------------- /configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "model_cls": "HawkEye_it", 4 | "bert_path": "bert-base-uncased", 5 | "llama_model_path": "model/vicuna-7b-v0", 6 | "freeze_vit": false, 7 | "freeze_qformer": false, 8 | "max_txt_len": 512, 9 | "low_resource": false, 10 | "vision_encoder": { 11 | "name": "vit_l14", 12 | "img_size": 224, 13 | "patch_size": 16, 14 | "d_model": 1024, 15 | "encoder_embed_dim": 1024, 16 | "encoder_depth": 24, 17 | "encoder_num_heads": 16, 18 | "drop_path_rate": 0.0, 19 | "num_frames": 32, 20 | "tubelet_size": 1, 21 | "use_checkpoint": false, 22 | "checkpoint_num": 0, 23 | "pretrained": "", 24 | "return_index": -2, 25 | "vit_add_ln": true, 26 | "ckpt_num_frame": 4 27 | }, 28 | "num_query_token": 32, 29 | "qformer_hidden_dropout_prob": 0.1, 30 | "qformer_attention_probs_dropout_prob": 0.1, 31 | "qformer_drop_path_rate": 0.2, 32 | "extra_num_query_token": 64, 33 | "qformer_text_input": true, 34 | "system": "", 35 | "start_token": "", 37 | "img_start_token": "", 38 | "img_end_token": "", 39 | "random_shuffle": true, 40 | "use_lora": true, 41 | "lora_r": 16, 42 | "lora_alpha": 32, 43 | "lora_dropout": 0.1 44 | }, 45 | "device": "cuda" 46 | } 47 | -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/videochat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=i3-blip2-text2tag 3 | #SBATCH --qos=lv4 4 | #SBATCH -p HGX 5 | #SBATCH --time=10:00:00 6 | #SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:2 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | # for i in 0 # 1 2 3 12 | # do 13 | # CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func videochat \ 14 | # --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 15 | # --videochat-fname temp/1121/scene_captions_videochat.jsonl.${i} \ 16 | # --videochat-config-fname code/scripts/run_1121/config_7b.json \ 17 | # # > code/nohup/videochat.log.${i} 2>&1 & 18 | # done 19 | # wait 20 | # The commented run above annotated some clips; the finetuned videochat characteristically generates verbose captions that often contain hallucinations. 21 | # We also tried the pre-train-only videochat and found its outputs show obvious webvid noise, i.e. it always adds times, places, etc., so it is basically unusable as well.
22 | 23 | i=3 24 | CUDA_VISIBLE_DEVICES=0 python code/caption_clips.py --func blip2 \ 25 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 26 | --blip2-fname temp/1121/scene_captions_blip2.jsonl.${i} \ 27 | >> code/nohup/blip2.log.${i} 2>&1 & 28 | 29 | CUDA_VISIBLE_DEVICES=1 python code/caption_clips.py --func tag2text \ 30 | --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 31 | --tag2text-fname temp/1121/scene_captions_tag2text.jsonl.${i} \ 32 | >> code/nohup/tag2text.log.${i} 2>&1 & 33 | wait -------------------------------------------------------------------------------- /internvid_g/scripts/run_1121/clip_filter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=filter-internvid 3 | #SBATCH --qos=lv0b 4 | #SBATCH -p HGX 5 | #SBATCH --time=24:00:00 6 | #SBATCH --cpus-per-gpu=6 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | 12 | # compute similarity scores for the llama2-summarized results 13 | # for i in 0 1 2 3 14 | # do 15 | # CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func filter \ 16 | # --batch-size 32 \ 17 | # --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 18 | # --llama2-fname temp/1121/scene_captions_llama2.jsonl.${i} \ 19 | # --filtered-fname temp/1121/scene_captions_filtered.jsonl.${i} \ 20 | # > code/nohup/clip_filter.log.${i} 2>&1 & 21 | # done 22 | # wait 23 | 24 | # compute similarity scores for each caption annotated by tag2text 25 | # for i in 0 1 2 3 26 | # do 27 | # CUDA_VISIBLE_DEVICES=$i python code/caption_clips.py --func filter \ 28 | # --batch-size 32 \ 29 | # --video-folder videos-lowres --scene-fname temp/1121/scenes_merged.jsonl.0${i} \ 30 | # --filter-input-fname temp/1121/scene_captions_tag2text.jsonl.${i} --merge-method max \ 31 | # --filtered-fname temp/1121/scene_captions_tag2text_clip_sim.jsonl.${i} \ 32 | # > code/nohup/tag2text_clip_sim.log.${i} 2>&1 & 33 | # done 34 | # wait 35 | 36 | # select tag2text captions with clip_sim score above median 37 | for i in 0 1 2 3 38 | do 39 | python code/caption_clips.py --func select_filtered_tag2text_captions \ 40 | --filter-input-fname temp/1121/scene_captions_tag2text_clip_sim.jsonl.$i \ 41 | --filtered-fname temp/1121/scene_captions_tag2text_high_sim.jsonl.$i 42 | done 43 | -------------------------------------------------------------------------------- /configs/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # some logic for loading data 2 | 3 | import os 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class WebvidPathMapping: 9 | def __init__(self, video_ids_fname) -> None: 10 | self.video_ids = set() 11 | self.video_ids_fname = video_ids_fname 12 | 13 | def __call__(self, input_fname): 14 | if not self.video_ids: 15 | self.video_ids = set([line.strip() for line in open(self.video_ids_fname)]) 16 | logger.info("In WebvidPathMapping, there are %d available videos in file %s" % (len(self.video_ids), self.video_ids_fname)) 17 | fname = input_fname.split("/")[-1].split('.')[0] 18 | if fname not in self.video_ids: 19 | return None 20 | return os.path.join(fname[:3], fname) 21 | 22 | 23 | def anet_path_mapping(input_fname): 24 | fname = input_fname.split("/")[-1].split('.')[0] 25 | return fname 26 | 27 | 28 | def clevrer_path_mapping(input_fname): 29 | ''' 30 | video_02238.mp4 to video_02000-03000/video_02238.mp4 31 | ''' 32 | interval = 
int(input_fname.split('.')[0].split('_')[-1]) // 1000 33 | folder = 'video_%05d-%05d' % (interval * 1000, interval * 1000 + 1000) 34 | return os.path.join(folder, input_fname) 35 | 36 | 37 | webvid_path_mapping = WebvidPathMapping('data/WebVid/video_ids.txt') 38 | 39 | VIDEO_PATH_MAPPING = { 40 | 'caption_webvid': webvid_path_mapping, 41 | 'caption_videochat': webvid_path_mapping, 42 | 'conversation_videochat1': webvid_path_mapping, 43 | 'vqa_webvid_qa': webvid_path_mapping, 44 | 'conversation_videochatgpt': anet_path_mapping, 45 | 'reasoning_clevrer_qa': clevrer_path_mapping, 46 | 'reasoning_clevrer_mc': clevrer_path_mapping, 47 | } -------------------------------------------------------------------------------- /data_preparing/nextgqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import pandas as pd 4 | 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--func', type=str, default='grounding_qa') 9 | parser.add_argument('--annotation-fname', type=str, default='data/NExTQA/nextgqa/test.csv') 10 | parser.add_argument('--grounding-fname', type=str, default='data/NExTQA/nextgqa/gsub_test.json') 11 | parser.add_argument('--pred-span-fname', type=str) 12 | parser.add_argument('--video-mapping-fname', type=str, default='data/NExTQA/nextgqa/map_vid_vidorID.json') 13 | parser.add_argument('--output-fname', type=str, default='data/MVBench/json/nextgqa.json') 14 | args = parser.parse_args() 15 | print(args) 16 | 17 | video_mapping_dict = json.load(open(args.video_mapping_fname)) 18 | video_mapping_dict = {key: val + '.mp4' for key, val in video_mapping_dict.items()} 19 | df = pd.read_csv(args.annotation_fname) 20 | 21 | grounding_dict, video_lengths = dict(), dict() 22 | for video_key, data in json.load(open(args.grounding_fname)).items(): 23 | video_fname = video_mapping_dict[video_key] 24 | grounding_dict[video_fname] = dict() 25 | video_lengths[video_fname] = data['duration'] 26 | for question_key, spans in data['location'].items(): 27 | grounding_dict[video_fname][question_key] = spans 28 | 29 | if args.func in ["test_grounding"]: 30 | res_list = list() 31 | for line_i, row in df.iterrows(): 32 | video_fname = video_mapping_dict[str(row['video_id'])] 33 | res_dict = {'video': video_fname, 'question': row['question'] + '?', 'duration': video_lengths[video_fname], 'qid': row['qid'], 'answer': grounding_dict[video_fname][str(row['qid'])]} 34 | res_list.append(res_dict) 35 | json.dump(res_list, open(args.output_fname, 'w')) -------------------------------------------------------------------------------- /models/blip2/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | 5 | 6 | from .Qformer import BertConfig, BertLMHeadModel 7 | from models.utils import load_temp_embed_with_mismatch 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def build_qformer(num_query_token, vision_width, 13 | qformer_hidden_dropout_prob=0.1, 14 | qformer_attention_probs_dropout_prob=0.1, 15 | drop_path_rate=0., 16 | ): 17 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True) 18 | encoder_config.encoder_width = vision_width 19 | # insert cross-attention layer every other block 20 | encoder_config.add_cross_attention = True 21 | encoder_config.cross_attention_freq = 2 22 | encoder_config.query_length = num_query_token 23 | 
encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob 24 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob 25 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_config.num_hidden_layers)] 26 | logger.info(f"Drop_path:{encoder_config.drop_path_list}") 27 | logger.info(encoder_config) 28 | Qformer = BertLMHeadModel.from_pretrained( 29 | "bert-base-uncased", config=encoder_config, local_files_only=True 30 | ) 31 | query_tokens = nn.Parameter( 32 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 33 | ) 34 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 35 | return Qformer, query_tokens 36 | 37 | def interpolate_pos_embed_blip(state_dict, new_model): 38 | if "vision_temp_embed" in state_dict: 39 | vision_temp_embed_new = new_model.state_dict()["vision_temp_embed"] 40 | state_dict["vision_temp_embed"] = load_temp_embed_with_mismatch( 41 | state_dict["vision_temp_embed"], vision_temp_embed_new, add_zero=False 42 | ) 43 | return state_dict 44 | -------------------------------------------------------------------------------- /internvid_g/scripts/run_llama2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=internvid 3 | #SBATCH --qos=lv4 4 | #SBATCH -p HGX 5 | #SBATCH --time=10:00:00 6 | #SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:1 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | # for i in 0 1 2 3 4 12 | # do 13 | # python -u code/clip_sim.py --func split_scene \ 14 | # --video-sample-fname temp/1121/video_ids.txt \ 15 | # --scene-fname temp/1121/scenes.jsonl.$i \ 16 | # --start-idx $((i*1800)) --end-idx $((i*1800+1800)) \ 17 | # > ./code/nohup/split_scene.log.$i 2>&1 & 18 | # done 19 | # wait 20 | 21 | # cat temp/1121/scenes.jsonl.* > temp/1121/scenes.jsonl 22 | # GPUS=(0 0 1 1 2 2) 23 | # for i in 0 1 2 3 4 5 24 | # do 25 | # CUDA_VISIBLE_DEVICES=${GPUS[i]} \ 26 | # python code/clip_sim.py --func scene_sim \ 27 | # --scene-fname temp/1121/scenes.jsonl \ 28 | # --scene-sim-fname temp/1121/scenes_similarity.jsonl.$i \ 29 | # --start-idx $((i*1500)) --end-idx $((i*1500+1500)) \ 30 | # > ./code/nohup/scene_sim.log.$i 2>&1 & 31 | # done 32 | # wait 33 | 34 | # cat scenes_similarity.jsonl.* > scenes_similarity.jsonl 35 | # python code/ground_data_construction.py \ 36 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \ 37 | # --caption-with-neg-interval-fname temp/1121/caption_with_neg_interval.jsonl 38 | 39 | 40 | # merge scene 41 | # python code/clip_sim.py --func merge_scene \ 42 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \ 43 | # --scene-merged-fname temp/1121/scenes_merged.jsonl --scene-merged-sim-fname temp/1121/scenes_merged_similarity.jsonl \ 44 | 45 | # test scene captioning with blip2 46 | CUDA_VISIBLE_DEVICES=0 \ 47 | python code/caption_clips.py --func llama2 \ 48 | --blip2-fname temp/scene_captions_blip2.jsonl \ 49 | --tag2text-fname temp/scene_captions_tag2text.jsonl \ 50 | --llama2-fname temp/scene_captions_llama2.jsonl 51 | -------------------------------------------------------------------------------- /internvid_g/scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=internvid 3 | #SBATCH --qos=lv4 4 | #SBATCH -p HGX 5 | #SBATCH --time=10:00:00 6 | 
#SBATCH --cpus-per-task=6 7 | #SBATCH --gres=gpu:1 8 | #SBATCH --output=./code/nohup/internvid.log 9 | #SBATCH --error=./code/nohup/internvid.log 10 | 11 | # for i in 0 1 2 3 4 12 | # do 13 | # python -u code/clip_sim.py --func split_scene \ 14 | # --video-sample-fname temp/1121/video_ids.txt \ 15 | # --scene-fname temp/1121/scenes.jsonl.$i \ 16 | # --start-idx $((i*1800)) --end-idx $((i*1800+1800)) \ 17 | # > ./code/nohup/split_scene.log.$i 2>&1 & 18 | # done 19 | # wait 20 | 21 | # cat temp/1121/scenes.jsonl.* > temp/1121/scenes.jsonl 22 | # GPUS=(0 0 1 1 2 2) 23 | # for i in 0 1 2 3 4 5 24 | # do 25 | # CUDA_VISIBLE_DEVICES=${GPUS[i]} \ 26 | # python code/clip_sim.py --func scene_sim \ 27 | # --scene-fname temp/1121/scenes.jsonl \ 28 | # --scene-sim-fname temp/1121/scenes_similarity.jsonl.$i \ 29 | # --start-idx $((i*1500)) --end-idx $((i*1500+1500)) \ 30 | # > ./code/nohup/scene_sim.log.$i 2>&1 & 31 | # done 32 | # wait 33 | 34 | # cat scenes_similarity.jsonl.* > scenes_similarity.jsonl 35 | # python code/ground_data_construction.py \ 36 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \ 37 | # --caption-with-neg-interval-fname temp/1121/caption_with_neg_interval.jsonl 38 | 39 | 40 | # merge scene 41 | # python code/clip_sim.py --func merge_scene \ 42 | # --scene-fname temp/1121/scenes.jsonl --scene-sim-fname temp/1121/scenes_similarity.jsonl \ 43 | # --scene-merged-fname temp/1121/scenes_merged.jsonl --scene-merged-sim-fname temp/1121/scenes_merged_similarity.jsonl \ 44 | 45 | # scene captioning with blip2 46 | # CUDA_VISIBLE_DEVICES=0 \ 47 | # python code/caption_clips.py --func blip2 \ 48 | # --blip2-fname temp/scene_captions_blip2.jsonl 49 | 50 | # filter similarity, though the summary now does not include blip2 51 | # CUDA_VISIBLE_DEVICES=0 \ 52 | # python code/caption_clips.py --func filter \ 53 | # --llama2-fname temp/scene_captions_llama2.jsonl \ 54 | # --filtered-fname temp/scene_captions_blip_filtered.jsonl \ 55 | # > code/nohup/scene_captions_filter.log 2>&1 & 56 | 57 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process 4 | import random 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class MetaLoader(object): 11 | """ wraps multiple data loader """ 12 | def __init__(self, name2loader): 13 | """Iterates over multiple dataloaders, it ensures all processes 14 | work on data from the same dataloader. This loader will end when 15 | the shorter dataloader raises StopIteration exception. 
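        A minimal usage sketch (illustrative only, not taken from this repo's configs;
        a CUDA device is assumed because __init__ moves iter_order to cuda):

            import torch
            from torch.utils.data import DataLoader, TensorDataset
            caption_loader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=2)  # 4 batches
            qa_loader = DataLoader(TensorDataset(torch.zeros(4, 4)), batch_size=2)       # 2 batches
            meta_loader = MetaLoader({"caption": caption_loader, "qa": qa_loader})
            for name, batch in meta_loader:  # yields 6 (name, batch) pairs in a shuffled but fixed order
                print(name, batch[0].shape)  # name tells which dataloader produced this batch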
16 | 17 | loaders: Dict, {name: dataloader} 18 | """ 19 | self.name2loader = name2loader 20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()} 21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())} 22 | index2name = {v: k for k, v in name2index.items()} 23 | 24 | iter_order = [] 25 | for n, l in name2loader.items(): 26 | iter_order.extend([name2index[n]]*len(l)) 27 | 28 | random.shuffle(iter_order) 29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 30 | 31 | # sync 32 | if is_dist_avail_and_initialized(): 33 | # make sure all processes have the same order so that 34 | # each step they will have data from the same loader 35 | dist.broadcast(iter_order, src=0) 36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()] 37 | 38 | logger.info(str(self)) 39 | 40 | def __str__(self): 41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"] 42 | for idx, (name, loader) in enumerate(self.name2loader.items()): 43 | output.append( 44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} " 45 | ) 46 | return "\n".join(output) 47 | 48 | def __len__(self): 49 | return len(self.iter_order) 50 | 51 | def __iter__(self): 52 | """ this iterator will run indefinitely """ 53 | for name in self.iter_order: 54 | _iter = self.name2iter[name] 55 | batch = next(_iter) 56 | yield name, batch 57 | -------------------------------------------------------------------------------- /models/bert/builder.py: -------------------------------------------------------------------------------- 1 | from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel 2 | 3 | import logging 4 | logger = logging.getLogger(__name__) 5 | 6 | def build_bert(model_config, pretrain, checkpoint): 7 | """build text encoder. 8 | 9 | Args: 10 | model_config (dict): model config. 11 | pretrain (bool): Whether to do pretrain or finetuning. 12 | checkpoint (bool): whether to do gradient_checkpointing. 13 | 14 | Returns: TODO 15 | 16 | """ 17 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 18 | bert_config.encoder_width = model_config.vision_encoder.d_model 19 | bert_config.gradient_checkpointing = checkpoint 20 | bert_config.fusion_layer = model_config.text_encoder.fusion_layer 21 | 22 | if not model_config.multimodal.enable: 23 | bert_config.fusion_layer = bert_config.num_hidden_layers 24 | 25 | if pretrain: 26 | text_encoder, loading_info = BertForMaskedLM.from_pretrained( 27 | model_config.text_encoder.pretrained, 28 | config=bert_config, 29 | output_loading_info=True, 30 | ) 31 | else: 32 | text_encoder, loading_info = BertModel.from_pretrained( 33 | model_config.text_encoder.pretrained, 34 | config=bert_config, 35 | add_pooling_layer=False, 36 | output_loading_info=True, 37 | ) 38 | 39 | return text_encoder 40 | 41 | 42 | def build_bert_decoder(model_config, checkpoint): 43 | """build text decoder the same as the multimodal encoder. 44 | 45 | Args: 46 | model_config (dict): model config. 47 | pretrain (bool): Whether to do pretrain or finetuning. 48 | checkpoint (bool): whether to do gradient_checkpointing. 
49 | 50 | Returns: TODO 51 | 52 | """ 53 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 54 | bert_config.encoder_width = model_config.vision_encoder.d_model 55 | bert_config.gradient_checkpointing = checkpoint 56 | 57 | bert_config.fusion_layer = 0 58 | bert_config.num_hidden_layers = ( 59 | bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer 60 | ) 61 | 62 | text_decoder, loading_info = BertLMHeadModel.from_pretrained( 63 | model_config.text_encoder.pretrained, 64 | config=bert_config, 65 | output_loading_info=True, 66 | ) 67 | 68 | return text_decoder 69 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from os.path import dirname, join 5 | 6 | from utils.config import Config 7 | from utils.distributed import init_distributed_mode, is_main_process 8 | from utils.logger import setup_logger 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_config(): 14 | """Conbine yaml config and command line config with OmegaConf. 15 | Also converts types, e.g., `'None'` (str) --> `None` (None) 16 | """ 17 | config = Config.get_config() 18 | if config.debug: 19 | config.wandb.enable = False 20 | return config 21 | 22 | 23 | def setup_evaluate_config(config): 24 | """setup evaluation default settings, e.g., disable wandb""" 25 | assert config.evaluate 26 | config.wandb.enable = False 27 | if config.output_dir is None: 28 | config.output_dir = join(dirname(config.pretrained_path), "eval") 29 | return config 30 | 31 | 32 | def setup_output_dir(output_dir, excludes=["code"]): 33 | """ensure not overwritting an exisiting/non-empty output dir""" 34 | if not os.path.exists(output_dir): 35 | os.makedirs(output_dir, exist_ok=False) 36 | else: 37 | existing_dirs_files = os.listdir(output_dir) # list 38 | remaining = set(existing_dirs_files) - set(excludes) 39 | remaining = [e for e in remaining if "slurm" not in e] 40 | remaining = [e for e in remaining if ".out" not in e] 41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}" 42 | logger.warn(f"remaining dirs or files: {remaining}") 43 | 44 | 45 | def setup_main(): 46 | """ 47 | Setup config, logger, output_dir, etc. 48 | Shared for pretrain and all downstream tasks. 
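    For example (an illustrative case, not from a shipped config): if one entry in
    config.train_file sets grounding_method == "frame_token" and config.num_frames is 32,
    the loop below sets config.model.num_frame_tokens = 32, i.e. one special frame token
    per sampled frame; otherwise num_frame_tokens stays 0.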
49 | """ 50 | config = setup_config() 51 | 52 | config.model.num_frame_tokens = 0 53 | for dataset_cfg in config.train_file: 54 | if dataset_cfg.get("grounding_method", None) == "frame_token": 55 | logger.info("dataset %s is using frame_token as grounding method, add %d special frame tokens to the model" % (dataset_cfg['dataset_name'], config.num_frames)) 56 | config.model.num_frame_tokens = config.num_frames 57 | 58 | if hasattr(config, "evaluate") and config.evaluate: 59 | config = setup_evaluate_config(config) 60 | init_distributed_mode(config) 61 | 62 | if is_main_process(): 63 | setup_output_dir(config.output_dir, excludes=["code"]) 64 | setup_logger(output=config.output_dir, color=True, name="vindlu") 65 | logger.info(f"config: {Config.pretty_text(config)}") 66 | Config.dump(config, os.path.join(config.output_dir, "config.json")) 67 | return config 68 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from torch.optim import Optimizer 5 | import math 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def create_scheduler(args, optimizer): 10 | lr_scheduler = None 11 | if args.sched == 'cosine': 12 | lr_scheduler = get_cosine_schedule_with_warmup( 13 | optimizer, 14 | num_warmup_steps=args.num_warmup_steps, 15 | num_training_steps=args.num_training_steps, 16 | num_cycles=0.5, 17 | min_lr_multi=args.min_lr_multi 18 | ) 19 | return lr_scheduler 20 | 21 | 22 | def get_cosine_schedule_with_warmup( 23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 25 | ): 26 | """ 27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py 28 | 29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 31 | initial lr set in the optimizer. 32 | Args: 33 | optimizer ([`~torch.optim.Optimizer`]): 34 | The optimizer for which to schedule the learning rate. 35 | num_warmup_steps (`int`): 36 | The number of steps for the warmup phase. 37 | num_training_steps (`int`): 38 | The total number of training steps. 39 | num_cycles (`float`, *optional*, defaults to 0.5): 40 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 41 | following a half-cosine). 42 | min_lr_multi (`float`, *optional*, defaults to 0): 43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. 44 | last_epoch (`int`, *optional*, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | Return: 47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
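    Worked example (illustrative values, not from any config in this repo): with
    num_warmup_steps=1000, num_training_steps=11000, num_cycles=0.5 and min_lr_multi=0,
    lr_lambda returns 0.5 at step 500 (linear warmup, 500/1000), 1.0 at step 1000, then
    0.5 * (1 + cos(pi * 0.5 * 2.0 * 0.5)) = 0.5 at step 6000 (progress = 0.5), and 0.0 at
    step 11000.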
48 | """ 49 | 50 | def lr_lambda(current_step): 51 | if current_step < num_warmup_steps: 52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) 53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 55 | 56 | return LambdaLR(optimizer, lr_lambda, last_epoch) 57 | -------------------------------------------------------------------------------- /data_preparing/internvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import random 5 | import argparse 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--func', type=str, default='grounding', choices=['grounding', 'caption', 'choice', 'choice_caption']) 11 | parser.add_argument('--annotation-fname', type=str, default='data/InternVid-G/train.jsonl') 12 | parser.add_argument('--instruction-fname', type=str, default='data/VideoChat2-IT/video/temporal/internvid_grounding/instructions.json') 13 | parser.add_argument('--question', type=str, default='data/VideoChat2-IT/video/temporal/internvid_grounding/questions.json') 14 | parser.add_argument('--output-fname', type=str, default='data/VideoChat2-IT/video/temporal/internvid_grounding/train.json') 15 | parser.add_argument('--time-span-sent', type=str, default='From second %.1f to second %.1f.') # used for frame-level / second-level rep. 16 | args = parser.parse_args() 17 | print(args) 18 | 19 | res = list() 20 | 21 | instructions = json.load(open(args.instruction_fname)) 22 | input_fnames = glob.glob(args.annotation_fname) 23 | 24 | if os.path.exists(args.question): 25 | args.question = json.load(open(args.question)) 26 | else: 27 | args.question = [args.question] 28 | 29 | for input_fname in input_fnames: 30 | print('loading data from', input_fname) 31 | for line in open(input_fname).readlines(): 32 | example = json.loads(line) 33 | 34 | example['caption'] = example['caption'][0].upper() + example['caption'][1:] 35 | if example['caption'].endswith('.'): 36 | example['caption'] = example['caption'][:-1] 37 | 38 | new_example = {k: example[k] for k in ['start_sec', 'end_sec', 'neg_start_sec', 'neg_end_sec']} 39 | new_example['i'] = random.choice(instructions) 40 | if args.func == 'grounding': 41 | new_example['q'] = random.choice(args.question) % example['caption'] 42 | new_example['a'] = args.time_span_sent 43 | elif args.func == 'caption': 44 | new_example['a'] = example['caption'] 45 | new_example['q'] = args.time_span_sent 46 | elif args.func == 'choice': 47 | options = ["In the middle of the video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."] 48 | random.shuffle(options) 49 | options = ["\n(%s) %s" % ("ABCD"[i], opt) for i, opt in enumerate(options)] 50 | new_example['q'] = "Question: " + random.choice(args.question) % example['caption'] + "\nOptions:" + "".join(options) 51 | new_example['a'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options] 52 | elif args.func == 'choice_caption': 53 | options = ["In the middle of the video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."] 54 | new_example['q'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options] 55 
| new_example['a'] = example['caption'] 56 | res.append({'video': example['video'], 'QA': [new_example]}) 57 | 58 | with open(args.output_fname, 'w') as f_out: 59 | json.dump(res, f_out) 60 | -------------------------------------------------------------------------------- /internvid_g/code/download_videos.py: -------------------------------------------------------------------------------- 1 | # download videos from youtube 2 | 3 | import time 4 | import os 5 | import json 6 | from multiprocessing import Pool 7 | import argparse 8 | 9 | 10 | def download_video(command): 11 | # time.sleep(2) 12 | video_id, target_dir = command 13 | # download videos with resolution nearest to 720p 14 | to_execute = f'yt-dlp "http://youtu.be/{video_id}" -o "{target_dir}/{video_id}.%(ext)s" -S "+res:720,fps" -q --no-check-certificate' 15 | res_code = os.system(to_execute) 16 | if res_code: 17 | error_video_fout.write(video_id + '\n') # though this is not multi-threaded safe, but this is less important 18 | error_video_fout.flush() 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--old_dirs", type=str, nargs="+", default=["videos"], help="folders with videos already downloaded, so we dont need to download them again") 24 | parser.add_argument("--target_dir", type=str, default='videos', help="folder to save downloaded videos") 25 | parser.add_argument("--src_file", type=str, default='temp/video_ids.txt', help="folder that contains video ids to download") 26 | parser.add_argument("--num_workers", type=int, default=5) 27 | parser.add_argument("--error_file", type=str, default='temp/error_video_ids.txt') 28 | args = parser.parse_args() 29 | 30 | video_ids = set() 31 | if args.src_file.endswith('jsonl'): # download from InternVid annotations or InternVid-G annotations 32 | for line in open(args.src_file): 33 | example = json.loads(line) 34 | if 'YoutubeID' in example: 35 | video_ids.add(example['YoutubeID']) 36 | if 'video_fname' in example: 37 | video_ids.add(example['video_fname'].split('.')[0]) 38 | 39 | elif args.src_file.endswith('json'): # download from InternVid annotations or InternVid-G annotations 40 | data = json.load(open(args.src_file)) 41 | for example in data: 42 | video_ids.add(example['video'][:11]) # the first 11 letters is the youtube id 43 | 44 | elif args.src_file.endswith('txt'): # download from InternVid annotations or InternVid-G annotations 45 | for line in open(args.src_file): 46 | video_ids.add(line.strip()) 47 | 48 | else: # change annotation loading code as you wish 49 | raise NotImplementedError 50 | 51 | error_video_ids = set([line.strip() for line in open(args.error_file)]) if os.path.exists(args.error_file) else set() 52 | videos_to_download = video_ids - error_video_ids 53 | 54 | for dir in args.old_dirs + [args.target_dir]: 55 | if not os.path.exists(dir): 56 | print('path does not exist, skipping stat video in this folder', dir) 57 | continue 58 | videos_downloaded = set(['.'.join(fname.split('.')[:-1]) for fname in os.listdir(dir)]) 59 | videos_to_download = videos_to_download - videos_downloaded 60 | 61 | videos_to_download = list(videos_to_download) 62 | videos_to_download.sort() 63 | 64 | print(f'{len(video_ids)} videos in total') 65 | print(f'{len(error_video_ids)} videos error') 66 | print(f'{len(videos_downloaded)} videos already downloaded') 67 | print(f'{len(videos_to_download)} videos to download') 68 | 69 | error_video_fout = open(args.error_file, 'a') 70 | commands = [(video_id, args.target_dir) for video_id in 
videos_to_download] 71 | with Pool(args.num_workers) as p: 72 | p.map(download_video, commands) 73 | -------------------------------------------------------------------------------- /data_preparing/charades.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | 6 | 7 | def load_data(fname): 8 | data_list = list() 9 | for line in open(fname): 10 | info, sent = line.split('##') 11 | vid, start, end = info.split(' ') 12 | start, end = float(start), float(end) 13 | data_list.append({'video': vid + '.mp4', 'start': start, 'end': end, 'caption': sent.strip()}) 14 | return data_list 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--func', type=str, default='test_grounding') 20 | parser.add_argument('--video-lengths-fname', type=str, default='data/Charades-STA/video_lengths.json') 21 | parser.add_argument('--annotation-fname', type=str, default='data/Charades-STA/charades_sta_train.txt') 22 | parser.add_argument('--instruction-fname', type=str, default='data/VideoChat2-IT/temporal/charades_sta_grounding-choice/instructions.json') 23 | parser.add_argument('--question', type=str, default='data/VideoChat2-IT/temporal/charades_sta_grounding-choice/questions.json') 24 | parser.add_argument('--output-fname', type=str, default='data/VideoChat2-IT/video/temporal/charades_sta_grounding-choice/train.json') 25 | parser.add_argument('--time-span-sent', type=str, default='From frame %d to frame %d.') 26 | args = parser.parse_args() 27 | print(args) 28 | 29 | anno = load_data(args.annotation_fname) 30 | if os.path.exists(args.question): 31 | args.question = json.load(open(args.question)) 32 | else: 33 | args.question = [args.question] 34 | 35 | if args.func in ['grounding', 'caption', 'choice']: 36 | video_lengths = json.load(open(args.video_lengths_fname)) 37 | instructions = json.load(open(args.instruction_fname)) 38 | res = list() 39 | for example in anno: 40 | sent = example['caption'].replace('.', '') 41 | new_example = {'i': random.choice(instructions), 'start_sec': example['start'], 'end_sec': example['end'], 'neg_start_sec': 0, 'neg_end_sec': video_lengths[example['video'][:5]]} 42 | if args.func == 'grounding': 43 | new_example['q'] = random.choice(args.question) % sent 44 | new_example['a'] = args.time_span_sent 45 | elif args.func == 'caption': 46 | new_example['a'] = sent 47 | new_example['q'] = args.time_span_sent 48 | elif args.func == 'choice': 49 | options = ["In the middle of the video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."] 50 | random.shuffle(options) 51 | options = ["\n(%s) %s" % ("ABCD"[i], opt) for i, opt in enumerate(options)] 52 | new_example['q'] = "Question: " + random.choice(args.question) % sent + "\nOptions:" + "".join(options) 53 | new_example['a'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options] 54 | res.append({'video': example['video'], 'QA': [new_example]}) 55 | 56 | if args.func in ["test_grounding"]: 57 | res = list() 58 | for example in anno: 59 | sent = example['caption'].replace('.', '') 60 | new_example = {'video': example['video'], 'question': random.choice(args.question) % sent, 'answer': "%.1f-%.1f" % (example['start'], example['end'])} 61 | new_example = {'video': example['video'], 'question': sent, 'answer': [example['start'], example['end']]} 62 | res.append(new_example) 63 | 64 
with open(args.output_fname, 'w') as f_out: 65 | json.dump(res, f_out) 66 | -------------------------------------------------------------------------------- /utils/easydict.py: -------------------------------------------------------------------------------- 1 | class EasyDict(dict): 2 | """ 3 | Get attributes 4 | 5 | >>> d = EasyDict({'foo':3}) 6 | >>> d['foo'] 7 | 3 8 | >>> d.foo 9 | 3 10 | >>> d.bar 11 | Traceback (most recent call last): 12 | ... 13 | AttributeError: 'EasyDict' object has no attribute 'bar' 14 | 15 | Works recursively 16 | 17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 18 | >>> isinstance(d.bar, dict) 19 | True 20 | >>> d.bar.x 21 | 1 22 | 23 | Bullet-proof 24 | 25 | >>> EasyDict({}) 26 | {} 27 | >>> EasyDict(d={}) 28 | {} 29 | >>> EasyDict(None) 30 | {} 31 | >>> d = {'a': 1} 32 | >>> EasyDict(**d) 33 | {'a': 1} 34 | 35 | Set attributes 36 | 37 | >>> d = EasyDict() 38 | >>> d.foo = 3 39 | >>> d.foo 40 | 3 41 | >>> d.bar = {'prop': 'value'} 42 | >>> d.bar.prop 43 | 'value' 44 | >>> d 45 | {'foo': 3, 'bar': {'prop': 'value'}} 46 | >>> d.bar.prop = 'newer' 47 | >>> d.bar.prop 48 | 'newer' 49 | 50 | 51 | Values extraction 52 | 53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 54 | >>> isinstance(d.bar, list) 55 | True 56 | >>> from operator import attrgetter 57 | >>> map(attrgetter('x'), d.bar) 58 | [1, 3] 59 | >>> map(attrgetter('y'), d.bar) 60 | [2, 4] 61 | >>> d = EasyDict() 62 | >>> d.keys() 63 | [] 64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 65 | >>> d.foo 66 | 3 67 | >>> d.bar.x 68 | 1 69 | 70 | Still like a dict though 71 | 72 | >>> o = EasyDict({'clean':True}) 73 | >>> o.items() 74 | [('clean', True)] 75 | 76 | And like a class 77 | 78 | >>> class Flower(EasyDict): 79 | ... power = 1 80 | ... 81 | >>> f = Flower() 82 | >>> f.power 83 | 1 84 | >>> f = Flower({'height': 12}) 85 | >>> f.height 86 | 12 87 | >>> f['power'] 88 | 1 89 | >>> sorted(f.keys()) 90 | ['height', 'power'] 91 | 92 | update and pop items 93 | >>> d = EasyDict(a=1, b='2') 94 | >>> e = EasyDict(c=3.0, a=9.0) 95 | >>> d.update(e) 96 | >>> d.c 97 | 3.0 98 | >>> d['c'] 99 | 3.0 100 | >>> d.get('c') 101 | 3.0 102 | >>> d.update(a=4, b=4) 103 | >>> d.b 104 | 4 105 | >>> d.pop('a') 106 | 4 107 | >>> d.a 108 | Traceback (most recent call last): 109 | ... 
110 | AttributeError: 'EasyDict' object has no attribute 'a' 111 | """ 112 | 113 | def __init__(self, d=None, **kwargs): 114 | if d is None: 115 | d = {} 116 | if kwargs: 117 | d.update(**kwargs) 118 | for k, v in d.items(): 119 | setattr(self, k, v) 120 | # Class attributes 121 | for k in self.__class__.__dict__.keys(): 122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): 123 | setattr(self, k, getattr(self, k)) 124 | 125 | def __setattr__(self, name, value): 126 | if isinstance(value, (list, tuple)): 127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value] 128 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 129 | value = self.__class__(value) 130 | super(EasyDict, self).__setattr__(name, value) 131 | super(EasyDict, self).__setitem__(name, value) 132 | 133 | __setitem__ = __setattr__ 134 | 135 | def update(self, e=None, **f): 136 | d = e or dict() 137 | d.update(f) 138 | for k in d: 139 | setattr(self, k, d[k]) 140 | 141 | def pop(self, k, d=None): 142 | if hasattr(self, k): 143 | delattr(self, k) 144 | return super(EasyDict, self).pop(k, d) 145 | 146 | 147 | if __name__ == "__main__": 148 | import doctest 149 | 150 | -------------------------------------------------------------------------------- /data_preparing/check_grounding_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import json\n", 11 | "\n", 12 | "\n", 13 | "def calculate_iou(pred_span, gold_span):\n", 14 | " gold_start, gold_end, pred_start, pred_end = gold_span[0], gold_span[1], pred_span[0], pred_span[1]\n", 15 | " intersection = max(0, min(gold_end, pred_end) - max(gold_start, pred_start))\n", 16 | " union = max(0, max(gold_end, pred_end) - min(gold_start, pred_start))\n", 17 | " if union <= 0 or intersection <= 0:\n", 18 | " return 0\n", 19 | " return intersection / union\n", 20 | "\n", 21 | "\n", 22 | "def check_ans(pred_span, gold_spans):\n", 23 | " if not isinstance(gold_spans[0], (list, tuple)):\n", 24 | " gold_spans = [gold_spans]\n", 25 | " return max([calculate_iou(pred_span, gold_span) for gold_span in gold_spans])\n", 26 | "\n", 27 | "\n", 28 | "def get_iou_at_different_turns(example, max_turns=4):\n", 29 | " pred_answer_list = example['pred_answer_list'] + ['throughout'] * (max_turns - len(example['pred_answer_list']))\n", 30 | " start, end = 0, example['duration']\n", 31 | " \n", 32 | " res_list = list()\n", 33 | " for pred in pred_answer_list:\n", 34 | " interval = (end - start) / 4\n", 35 | " if pred == 'beginning':\n", 36 | " end = end - 2*interval\n", 37 | " elif pred == 'middle':\n", 38 | " start, end = start + interval, end - interval\n", 39 | " elif pred == 'end':\n", 40 | " start = start + 2 * interval\n", 41 | " res_list.append(check_ans((start, end), example['gt_span']))\n", 42 | " return res_list\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# change output_fname at your need\n", 52 | "output_fname = '../outputs/charades_sta-recursive_grounding-4_turns.jsonl'\n", 53 | "gold_fname = '../data/test-anno/charades_sta-recursive_grounding.json'\n", 54 | "gold_data = json.load(open(gold_fname))\n", 55 | "max_turns = 4\n", 56 | "\n", 57 | "res_list_all_turns = list()\n", 58 | "\n", 59 | "for line, gold in 
zip(open(output_fname), gold_data):\n", 60 | " example = json.loads(line)\n", 61 | " if example['duration'] is None: continue\n", 62 | " example['gt_span'] = gold['answer']\n", 63 | " res_list = get_iou_at_different_turns(example, max_turns)\n", 64 | " res_list_all_turns.append(res_list)\n", 65 | "\n", 66 | "print(output_fname, 'num examples:', len(res_list_all_turns))\n", 67 | "for turns in range(max_turns):\n", 68 | " print('turns: %d' % (turns + 1))\n", 69 | " iou_list = [ious[turns] for ious in res_list_all_turns]\n", 70 | " print('mean iou: %.4f' % np.mean(iou_list))\n", 71 | " print('iou@0.3/0.5/0.7: %.4f/%.4f/%.4f' % tuple([len([i for i in iou_list if i > thres]) / len(iou_list) for thres in [0.3, 0.5, 0.7]]))\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [] 80 | } 81 | ], 82 | "metadata": { 83 | "kernelspec": { 84 | "display_name": "videochat", 85 | "language": "python", 86 | "name": "python3" 87 | }, 88 | "language_info": { 89 | "codemirror_mode": { 90 | "name": "ipython", 91 | "version": 3 92 | }, 93 | "file_extension": ".py", 94 | "mimetype": "text/x-python", 95 | "name": "python", 96 | "nbconvert_exporter": "python", 97 | "pygments_lexer": "ipython3", 98 | "version": "3.8.16" 99 | } 100 | }, 101 | "nbformat": 4, 102 | "nbformat_minor": 2 103 | } 104 | -------------------------------------------------------------------------------- /dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | 6 | from dataset.utils import load_image_from_path 7 | 8 | try: 9 | from petrel_client.client import Client 10 | has_client = True 11 | except ImportError: 12 | has_client = False 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ImageVideoBaseDataset(Dataset): 18 | """Base class that implements the image and video loading methods""" 19 | 20 | media_type = "video" 21 | 22 | def __init__(self): 23 | assert self.media_type in ["image", "video", "only_video"] 24 | self.data_root = None 25 | self.anno_list = ( 26 | None # list(dict), each dict contains {"image": str, # image or video path} 27 | ) 28 | self.transform = None 29 | self.video_reader = None 30 | self.num_tries = None 31 | 32 | self.client = None 33 | if has_client: 34 | self.client = Client('~/petreloss.conf') 35 | 36 | def __getitem__(self, index): 37 | raise NotImplementedError 38 | 39 | def __len__(self): 40 | raise NotImplementedError 41 | 42 | def get_anno(self, index): 43 | """obtain the annotation for one media (video or image) 44 | 45 | Args: 46 | index (int): The media index. 47 | 48 | Returns: dict. 49 | - "image": the filename, video also use "image". 50 | - "caption": The caption for this file. 
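            e.g. (a hypothetical entry, path and text are illustrative): {"image": "videos/xxx.mp4", "caption": "a person opens a door"};
            if self.data_root is set, it is prepended to the "image" path below.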
51 | 52 | """ 53 | anno = self.anno_list[index] 54 | if self.data_root is not None: 55 | anno["image"] = os.path.join(self.data_root, anno["image"]) 56 | return anno 57 | 58 | def load_and_transform_media_data(self, index, data_path): 59 | if self.media_type == "image": 60 | return self.load_and_transform_media_data_image(index, data_path) 61 | else: 62 | return self.load_and_transform_media_data_video(index, data_path) 63 | 64 | def load_and_transform_media_data_image(self, index, data_path): 65 | image = load_image_from_path(data_path, client=self.client) 66 | image = self.transform(image) 67 | return image, index 68 | 69 | def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None): 70 | for _ in range(self.num_tries): 71 | try: 72 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1 73 | frames, frame_indices, fps = self.video_reader( 74 | data_path, self.num_frames, self.sample_type, 75 | max_num_frames=max_num_frames, client=self.client, clip=clip, fps=getattr(self, "fps", None) 76 | ) 77 | except Exception as e: 78 | logger.warning( 79 | f"Caught exception {e} when loading video {data_path}, " 80 | f"randomly sample a new video as replacement" 81 | ) 82 | index = random.randint(0, len(self) - 1) 83 | ann = self.get_anno(index) 84 | data_path = ann["image"] 85 | continue 86 | # shared aug for video frames 87 | frames = self.transform(frames) 88 | if return_fps: 89 | assert fps is not None or hasattr(self, 'fps') 90 | fps = fps if fps is not None else self.fps # for some case that fps cannot be inferred from the video, we can set it manually 91 | sec = [str(round(f / fps, 1)) for f in frame_indices] 92 | return frames, index, sec 93 | else: 94 | return frames, index 95 | else: 96 | raise RuntimeError( 97 | f"Failed to fetch video after {self.num_tries} tries. " 98 | f"This might indicate that you have many corrupted videos." 99 | ) 100 | -------------------------------------------------------------------------------- /models/blip2/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import contextlib 8 | import os 9 | import logging 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from .Qformer import BertConfig, BertLMHeadModel 15 | from .vit import build_vit 16 | from transformers import BertTokenizer 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Blip2Base(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | # @classmethod 26 | # def init_tokenizer(cls, truncation_side="right"): 27 | def init_tokenizer(self, truncation_side="right"): 28 | tokenizer = BertTokenizer.from_pretrained(self.bert_path, truncation_side=truncation_side, local_files_only=True) 29 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 30 | return tokenizer 31 | 32 | @property 33 | def device(self): 34 | return list(self.parameters())[0].device 35 | 36 | def maybe_autocast(self, dtype=torch.float16): 37 | # if on cpu, don't use autocast 38 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 39 | enable_autocast = self.device != torch.device("cpu") 40 | 41 | if enable_autocast: 42 | return torch.cuda.amp.autocast(dtype=dtype) 43 | else: 44 | return contextlib.nullcontext() 45 | 46 | # @classmethod 47 | def init_Qformer( 48 | # cls, 49 | self, 50 | num_query_token, num_frame_tokens, vision_width, 51 | qformer_hidden_dropout_prob=0.1, 52 | qformer_attention_probs_dropout_prob=0.1, 53 | qformer_drop_path_rate=0., 54 | ): 55 | encoder_config = BertConfig.from_pretrained(self.bert_path, local_files_only=True) 56 | encoder_config.num_frame_tokens = num_frame_tokens 57 | encoder_config.encoder_width = vision_width 58 | # insert cross-attention layer every other block 59 | encoder_config.add_cross_attention = True 60 | encoder_config.cross_attention_freq = 2 61 | encoder_config.query_length = num_query_token 62 | encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob 63 | encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob 64 | encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)] 65 | logger.info(f"Drop_path:{encoder_config.drop_path_list}") 66 | logger.info(encoder_config) 67 | Qformer = BertLMHeadModel(config=encoder_config) 68 | query_tokens = nn.Parameter( 69 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 70 | ) 71 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 72 | return Qformer, query_tokens 73 | 74 | @classmethod 75 | def init_vision_encoder_umt(self, config): 76 | """build vision encoder 77 | Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`. 
78 | 79 | """ 80 | vision_encoder = build_vit(config) 81 | 82 | if config.vision_encoder.vit_add_ln: 83 | vision_layernorm = nn.LayerNorm(config.vision_encoder.encoder_embed_dim, eps=1e-12) 84 | else: 85 | vision_layernorm = nn.Identity() 86 | 87 | return vision_encoder, vision_layernorm 88 | 89 | 90 | def disabled_train(self, mode=True): 91 | """Overwrite model.train with this function to make sure train/eval mode 92 | does not change anymore.""" 93 | return self 94 | 95 | 96 | class LayerNorm(nn.LayerNorm): 97 | """Subclass torch's LayerNorm to handle fp16.""" 98 | 99 | def forward(self, x: torch.Tensor): 100 | orig_type = x.dtype 101 | ret = super().forward(x.type(torch.float32)) 102 | return ret.type(orig_type) 103 | -------------------------------------------------------------------------------- /data_preparing/anetc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--func', type=str, default='test_grounding') 10 | parser.add_argument('--annotation-fname', type=str, default='data/ActivityNet/captions/train.json') 11 | parser.add_argument('--annotation-fname2', type=str, default=None) 12 | parser.add_argument('--instruction-fname', type=str, default='data/VideoChat2-IT/video/temporal/anetc_grounding-choice/instructions.json') 13 | parser.add_argument('--question', type=str, default='data/VideoChat2-IT/video/temporal/anetc_grounding-choice/questions.json') 14 | parser.add_argument('--output-fname', type=str, default='data/VideoChat2-IT/video/temporal/anetc_grounding-choice/train.json') 15 | parser.add_argument('--time-span-sent', type=str, default='From second %.1f to second %.1f.') # for frame-level and second-level 16 | parser.add_argument('--sample-ratio', type=float, default=1.0, help='random sample ratio for the test set') 17 | args = parser.parse_args() 18 | print(args) 19 | 20 | os.makedirs(os.path.dirname(args.output_fname), exist_ok=True) 21 | anno = json.load(open(args.annotation_fname)) 22 | if args.annotation_fname2 is not None: 23 | anno1, anno = anno, dict() 24 | anno2 = json.load(open(args.annotation_fname2)) 25 | keys = anno1.keys() | anno2.keys() 26 | for key in keys: 27 | if key in anno1: 28 | anno[key] = anno1[key] 29 | if key in anno2: 30 | anno[key]['sentences'] = anno1[key]['sentences'] + anno2[key]['sentences'] 31 | anno[key]['timestamps'] = anno1[key]['timestamps'] + anno2[key]['timestamps'] 32 | else: 33 | anno[key] = anno2[key] 34 | 35 | if os.path.exists(args.question): 36 | args.question = json.load(open(args.question)) 37 | else: 38 | args.question = [args.question] 39 | 40 | if args.func in ['grounding', 'caption', 'choice']: 41 | instructions = json.load(open(args.instruction_fname)) 42 | res = list() 43 | for video_id, video_data in anno.items(): 44 | for (start_sec, end_sec), sent in zip(video_data['timestamps'], video_data['sentences']): 45 | if sent.endswith('.'): 46 | sent = sent[:-1] 47 | new_example = {'i': random.choice(instructions), 'start_sec': start_sec, 'end_sec': end_sec, 'neg_start_sec': 0, 'neg_end_sec': video_data['duration']} 48 | if args.func == 'grounding': 49 | new_example['q'] = random.choice(args.question) % sent 50 | new_example['a'] = args.time_span_sent 51 | elif args.func == 'caption': 52 | new_example['a'] = sent 53 | new_example['q'] = args.time_span_sent 54 | elif args.func == 'choice': 55 | options = ["In the middle of the 
video.", "At the end of the video.", "Throughout the entire video.", "At the beginning of the video."] 56 | random.shuffle(options) 57 | options = ["\n(%s) %s" % ("ABCD"[i], opt) for i, opt in enumerate(options)] 58 | new_example['q'] = "Question: " + random.choice(args.question) % sent + "\nOptions:" + "".join(options) 59 | new_example['a'] = [([i for i in ["middle", "end", "throughout", "beginning"] if i in opt.lower()][0], opt.strip()) for opt in options] 60 | res.append({'video': video_id, 'QA': [new_example]}) 61 | 62 | if args.func in ["test_grounding"]: 63 | res = list() 64 | for video_id, video_data in anno.items(): 65 | for (start_sec, end_sec), sent in zip(video_data['timestamps'], video_data['sentences']): 66 | if random.random() > args.sample_ratio: continue 67 | sent = sent.replace('.', '').strip() 68 | new_example = {'video': video_id, 'question': sent, "answer": [start_sec, end_sec]} 69 | res.append(new_example) 70 | 71 | print('saved %d examples' % len(res)) 72 | with open(args.output_fname, 'w') as f_out: 73 | json.dump(res, f_out) 74 | -------------------------------------------------------------------------------- /tasks/shared_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import os.path as osp 5 | from os.path import join 6 | 7 | import torch 8 | from torch.utils.data import ConcatDataset, DataLoader 9 | 10 | from utils.optimizer import create_optimizer 11 | from utils.scheduler import create_scheduler 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_media_types(datasources): 17 | """get the media types for for all the dataloaders. 18 | 19 | Args: 20 | datasources (List): List of dataloaders or datasets. 21 | 22 | Returns: List. The media_types. 
23 | 24 | """ 25 | if isinstance(datasources[0], DataLoader): 26 | datasets = [dataloader.dataset for dataloader in datasources] 27 | else: 28 | datasets = datasources 29 | media_types = [ 30 | dataset.datasets[0].media_type 31 | if isinstance(dataset, ConcatDataset) 32 | else dataset.media_type 33 | for dataset in datasets 34 | ] 35 | 36 | return media_types 37 | 38 | 39 | def setup_model( 40 | config, model_cls, find_unused_parameters=False 41 | ): 42 | logger.info("Creating model") 43 | config = copy.deepcopy(config) 44 | 45 | model = model_cls(config=config.model) 46 | 47 | # model = model.to(torch.device(config.device)) 48 | if config.model.get('model_parallel', False): 49 | model.set_device_ids([int(os.environ["LOCAL_RANK"]), int(os.environ["LOCAL_RANK"]) + torch.distributed.get_world_size()]) 50 | else: 51 | model.set_device_ids([int(os.environ["LOCAL_RANK"])]) 52 | 53 | model_without_ddp = model 54 | if config.distributed: 55 | model = torch.nn.parallel.DistributedDataParallel( 56 | model, 57 | device_ids=[config.gpu] if not config.model.get('model_parallel', False) else None, 58 | find_unused_parameters=find_unused_parameters, # `False` for image-only task 59 | ) 60 | 61 | optimizer = create_optimizer(config.optimizer, model) 62 | scheduler = create_scheduler(config.scheduler, optimizer) 63 | scaler = torch.cuda.amp.GradScaler(enabled=config.fp16) 64 | 65 | start_epoch = 0 66 | global_step = 0 67 | 68 | # auto resume the latest checkpoint 69 | if config.get("auto_resume", False): 70 | logger.info("Auto resuming") 71 | model_latest = join(config.output_dir, "ckpt_latest.pth") 72 | model_best = join(config.output_dir, "ckpt_best.pth") 73 | large_num = -1 74 | for p in os.listdir(config.output_dir): 75 | if 'ckpt' in p: 76 | num = p.split('_')[1].split('.')[0] 77 | if str.isnumeric(num): 78 | if int(num) > large_num: 79 | large_num = int(num) 80 | if large_num != -1: 81 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth") 82 | if osp.isfile(model_latest): 83 | config.pretrained_path = model_latest 84 | config.resume = True 85 | elif osp.isfile(model_best): 86 | config.pretrained_path = model_best 87 | config.resume = True 88 | else: 89 | logger.info(f"Not found checkpoint in {config.output_dir}") 90 | 91 | if osp.isfile(config.pretrained_path): 92 | checkpoint = torch.load(config.pretrained_path, map_location="cpu") 93 | state_dict = checkpoint["model"] if model in checkpoint else checkpoint 94 | 95 | if config.resume: 96 | optimizer.load_state_dict(checkpoint["optimizer"]) 97 | scheduler.load_state_dict(checkpoint["scheduler"]) 98 | scaler.load_state_dict(checkpoint["scaler"]) 99 | start_epoch = checkpoint["epoch"] + 1 100 | global_step = checkpoint["global_step"] 101 | 102 | msg = model_without_ddp.load_state_dict(state_dict, strict=False) 103 | logger.info(msg) 104 | logger.info(f"Loaded checkpoint from {config.pretrained_path}") 105 | else: 106 | logger.warning("No pretrained checkpoint provided, training from scratch") 107 | 108 | return ( 109 | model, 110 | model_without_ddp, 111 | optimizer, 112 | scheduler, 113 | scaler, 114 | start_epoch, 115 | global_step, 116 | ) 117 | -------------------------------------------------------------------------------- /scripts/train/config_7b_stage3.py: -------------------------------------------------------------------------------- 1 | from configs.instruction_data import * 2 | 3 | # ========================= data ========================== 4 | train_corpus = "hawkeye_instruction" 5 | train_file = 
"${available_corpus[${train_corpus}]}" # for lazy evaluation 6 | test_file = dict() 7 | test_types = [] 8 | num_workers = 6 9 | 10 | stop_key = None 11 | 12 | # ========================= input ========================== 13 | num_frames = 12 14 | num_frames_test = 12 15 | batch_size = 2 16 | max_txt_l = 512 17 | 18 | pre_text = False 19 | 20 | inputs = dict( 21 | image_res=224, 22 | video_input=dict( 23 | num_frames="${num_frames}", 24 | sample_type="rand", 25 | num_frames_test="${num_frames_test}", 26 | sample_type_test="middle", 27 | random_aug=False, 28 | ), 29 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 30 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 31 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 32 | ) 33 | 34 | # ========================= model ========================== 35 | model = dict( 36 | model_cls="HawkEye_it", 37 | vit_blip_model_path="model/VideoChat2/umt_l16_qformer.pth", 38 | llama_model_path="model/vicuna-7b-v0", 39 | videochat2_model_path="model/VideoChat2/videochat2_7b_stage2.pth", 40 | freeze_vit=True, 41 | freeze_qformer=False, 42 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3 43 | # vit 44 | low_resource=False, 45 | add_temp_embed=False, 46 | vision_encoder=dict( 47 | name="vit_l14", 48 | img_size=224, 49 | patch_size=16, 50 | d_model=1024, 51 | encoder_embed_dim=1024, 52 | encoder_depth=24, 53 | encoder_num_heads=16, 54 | drop_path_rate=0., 55 | num_frames="${num_frames}", 56 | tubelet_size=1, 57 | use_checkpoint=False, 58 | checkpoint_num=0, 59 | pretrained="", 60 | return_index=-2, 61 | vit_add_ln=True, 62 | ckpt_num_frame=4, 63 | ), 64 | # qformer 65 | num_query_token=32, 66 | qformer_hidden_dropout_prob=0.1, 67 | qformer_attention_probs_dropout_prob=0.1, 68 | qformer_drop_path_rate=0.2, 69 | extra_num_query_token=64, 70 | qformer_text_input=True, 71 | # prompt 72 | system="", 73 | start_token="", 75 | add_second_msg=True, 76 | img_start_token="", 77 | img_end_token="", 78 | random_shuffle=True, 79 | use_flash_attention=False, 80 | use_lora=True, 81 | lora_r=16, 82 | lora_alpha=32, 83 | lora_dropout=0.1, 84 | # debug=True, 85 | ) 86 | 87 | optimizer = dict( 88 | opt="adamW", 89 | lr=2e-5, 90 | opt_betas=[0.9, 0.999], # default 91 | weight_decay=0.02, 92 | max_grad_norm=-1, # requires a positive float, use -1 to disable 93 | # use a different lr for some modules, e.g., larger lr for new modules 94 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 95 | ) 96 | 97 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) 98 | 99 | evaluate = False 100 | deep_fusion = False 101 | evaluation = dict( 102 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 103 | eval_x_only=False, 104 | k_test=128, 105 | eval_offload=True, # offload gpu tensors to cpu to save memory. 
106 | ) 107 | 108 | fp16 = True 109 | gradient_checkpointing = True 110 | 111 | # ========================= log ========================== 112 | tensorboard = dict( 113 | enable=True 114 | ) 115 | 116 | wandb = dict( 117 | enable=False, 118 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 119 | project="project", 120 | ) 121 | 122 | # ========================= others ========================== 123 | dist_url = "env://" 124 | device = "cuda" 125 | mode = "it" 126 | 127 | output_dir = None # output dir 128 | resume = False # if True, load optimizer and scheduler states as well 129 | debug = False 130 | log_freq = 100 131 | seed = 42 132 | 133 | save_latest = True 134 | auto_resume = True 135 | pretrained_path = "" # path to pretrained model weights, for resume only 136 | freeze_dataset_folder = None # save all data used in training 137 | -------------------------------------------------------------------------------- /scripts/train/charades_sta.py: -------------------------------------------------------------------------------- 1 | from configs.instruction_data import * 2 | 3 | available_corpus['charades_sta_ft'] = [ 4 | { 5 | "dataset_name": 'charades_sta_grounding', 'grounding_method': 'choice', 6 | 'label_file': f'{anno_root_it}/video/temporal/charades_sta_grounding-choice/train.json', 7 | 'data_root': f'data/videos/charades', 'media_type': 'video' 8 | } 9 | ] 10 | 11 | # ========================= data ========================== 12 | train_corpus = "charades_sta_ft" 13 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 14 | test_file = dict() 15 | test_types = [] 16 | num_workers = 6 17 | 18 | stop_key = None 19 | 20 | # ========================= input ========================== 21 | num_frames = 12 22 | num_frames_test = 12 23 | batch_size = 1 24 | max_txt_l = 512 25 | 26 | pre_text = False 27 | 28 | inputs = dict( 29 | image_res=224, 30 | video_input=dict( 31 | num_frames="${num_frames}", 32 | sample_type="rand", 33 | num_frames_test="${num_frames_test}", 34 | sample_type_test="middle", 35 | random_aug=False, 36 | ), 37 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 38 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 39 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 40 | ) 41 | 42 | # ========================= model ========================== 43 | model = dict( 44 | model_cls="HawkEye_it", 45 | vit_blip_model_path="model/VideoChat2/umt_l16_qformer.pth", 46 | llama_model_path="model/vicuna-7b-v0", 47 | videochat2_model_path="model/VideoChat2/videochat2_7b_stage2.pth", 48 | freeze_vit=True, 49 | freeze_qformer=False, 50 | max_txt_len="${max_txt_l}", # use large max_txt_len on stage3 51 | # vit 52 | low_resource=False, 53 | add_temp_embed=False, 54 | vision_encoder=dict( 55 | name="vit_l14", 56 | img_size=224, 57 | patch_size=16, 58 | d_model=1024, 59 | encoder_embed_dim=1024, 60 | encoder_depth=24, 61 | encoder_num_heads=16, 62 | drop_path_rate=0., 63 | num_frames="${num_frames}", 64 | tubelet_size=1, 65 | use_checkpoint=False, 66 | checkpoint_num=0, 67 | pretrained="", 68 | return_index=-2, 69 | vit_add_ln=True, 70 | ckpt_num_frame=4, 71 | ), 72 | # qformer 73 | num_query_token=32, 74 | qformer_hidden_dropout_prob=0.1, 75 | qformer_attention_probs_dropout_prob=0.1, 76 | qformer_drop_path_rate=0.2, 77 | extra_num_query_token=64, 78 | qformer_text_input=True, 79 | # prompt 80 | system="", 81 | start_token="", 83 | add_second_msg=True, 84 | img_start_token="", 
85 | img_end_token="", 86 | random_shuffle=True, 87 | use_flash_attention=False, 88 | use_lora=True, 89 | lora_r=16, 90 | lora_alpha=32, 91 | lora_dropout=0.1, 92 | debug=False, 93 | ) 94 | 95 | optimizer = dict( 96 | opt="adamW", 97 | lr=2e-5, 98 | opt_betas=[0.9, 0.999], # default 99 | weight_decay=0.02, 100 | max_grad_norm=-1, # requires a positive float, use -1 to disable 101 | # use a different lr for some modules, e.g., larger lr for new modules 102 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 103 | ) 104 | 105 | scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.25, warmup_epochs=0.6) 106 | 107 | evaluate = False 108 | deep_fusion = False 109 | evaluation = dict( 110 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 111 | eval_x_only=False, 112 | k_test=128, 113 | eval_offload=True, # offload gpu tensors to cpu to save memory. 114 | ) 115 | 116 | fp16 = True 117 | gradient_checkpointing = False 118 | 119 | # ========================= log ========================== 120 | tensorboard = dict( 121 | enable=True 122 | ) 123 | 124 | wandb = dict( 125 | enable=False, 126 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 127 | project="project", 128 | ) 129 | 130 | # ========================= others ========================== 131 | dist_url = "env://" 132 | device = "cuda" 133 | mode = "it" 134 | 135 | output_dir = None # output dir 136 | resume = False # if True, load optimizer and scheduler states as well 137 | debug = False 138 | log_freq = 50 139 | save_freq = 1000 140 | seed = 42 141 | 142 | save_latest = False 143 | auto_resume = True 144 | pretrained_path = "" # path to pretrained model weights, for resume only? 145 | freeze_dataset_folder = None 146 | -------------------------------------------------------------------------------- /scripts/train/anetc.py: -------------------------------------------------------------------------------- 1 | from configs.instruction_data import * 2 | 3 | available_corpus['anetc_ft'] = [ 4 | { 5 | "dataset_name": 'charades_sta_grounding', 'grounding_method': 'choice', 6 | 'label_file': f'{anno_root_it}/video/temporal/anetc_grounding/train.json', 7 | 'data_root': f'data/videos/activitynet', 'media_type': 'video', 'fps': 3 8 | } 9 | ] 10 | 11 | # ========================= data ========================== 12 | train_corpus = "anetc_ft" 13 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 14 | test_file = dict() 15 | test_types = [] 16 | num_workers = 6 17 | 18 | stop_key = None 19 | 20 | # ========================= input ========================== 21 | num_frames = 12 22 | num_frames_test = 12 23 | batch_size = 4 24 | max_txt_l = 512 25 | 26 | pre_text = False 27 | 28 | inputs = dict( 29 | image_res=224, 30 | video_input=dict( 31 | num_frames="${num_frames}", 32 | sample_type="rand", 33 | num_frames_test="${num_frames_test}", 34 | sample_type_test="middle", 35 | random_aug=False, 36 | ), 37 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 38 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 39 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 40 | ) 41 | 42 | # ========================= model ========================== 43 | model = dict( 44 | model_cls="VideoChat2_it", 45 | vit_blip_model_path="model/VideoChat2/umt_l16_qformer.pth", 46 | llama_model_path="model/vicuna-7b-v0", 47 | videochat2_model_path="model/VideoChat2/videochat2_7b_stage2.pth", 48 | freeze_vit=True, 49 | freeze_qformer=False, 50 | 
max_txt_len="${max_txt_l}", # use large max_txt_len on stage3 51 | # vit 52 | low_resource=False, 53 | add_temp_embed=False, 54 | vision_encoder=dict( 55 | name="vit_l14", 56 | img_size=224, 57 | patch_size=16, 58 | d_model=1024, 59 | encoder_embed_dim=1024, 60 | encoder_depth=24, 61 | encoder_num_heads=16, 62 | drop_path_rate=0., 63 | num_frames="${num_frames}", 64 | tubelet_size=1, 65 | use_checkpoint=False, 66 | checkpoint_num=0, 67 | pretrained="", 68 | return_index=-2, 69 | vit_add_ln=True, 70 | ckpt_num_frame=4, 71 | ), 72 | # qformer 73 | num_query_token=32, 74 | qformer_hidden_dropout_prob=0.1, 75 | qformer_attention_probs_dropout_prob=0.1, 76 | qformer_drop_path_rate=0.2, 77 | extra_num_query_token=64, 78 | qformer_text_input=True, 79 | # prompt 80 | system="", 81 | start_token="", 83 | add_second_msg=True, 84 | img_start_token="", 85 | img_end_token="", 86 | random_shuffle=True, 87 | use_flash_attention=False, 88 | use_lora=True, 89 | lora_r=16, 90 | lora_alpha=32, 91 | lora_dropout=0.1, 92 | debug=False, 93 | ) 94 | 95 | optimizer = dict( 96 | opt="adamW", 97 | lr=2e-5, 98 | opt_betas=[0.9, 0.999], # default 99 | weight_decay=0.02, 100 | max_grad_norm=-1, # requires a positive float, use -1 to disable 101 | # use a different lr for some modules, e.g., larger lr for new modules 102 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 103 | ) 104 | 105 | scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.25, warmup_epochs=0.6) 106 | 107 | evaluate = False 108 | deep_fusion = False 109 | evaluation = dict( 110 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 111 | eval_x_only=False, 112 | k_test=128, 113 | eval_offload=True, # offload gpu tensors to cpu to save memory. 114 | ) 115 | 116 | fp16 = True 117 | # gradient_checkpointing = True 118 | gradient_checkpointing = False 119 | 120 | # ========================= log ========================== 121 | tensorboard = dict( 122 | enable=True 123 | ) 124 | 125 | wandb = dict( 126 | enable=False, 127 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 128 | project="project", 129 | ) 130 | 131 | # ========================= others ========================== 132 | dist_url = "env://" 133 | device = "cuda" 134 | mode = "it" 135 | 136 | output_dir = None # output dir 137 | resume = False # if True, load optimizer and scheduler states as well 138 | debug = False 139 | log_freq = 50 140 | save_freq = 2000 141 | seed = 42 142 | 143 | save_latest = False 144 | auto_resume = True 145 | pretrained_path = "" # path to pretrained model weights, for resume only? 146 | freeze_dataset_folder = None 147 | -------------------------------------------------------------------------------- /internvid_g/README.md: -------------------------------------------------------------------------------- 1 | InternVid-G: A Large-Scale Video-Text Dataset with Scene-Level Annotations for Temporal Grounding 2 | === 3 | ![dataset overview](../assets/dataset.jpg) 4 | 5 | InternVid-G is a dataset based on a fraction of videos from [InternVid](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid). It has segment-level annotation, where each segment is annotated with: 6 | - a caption, which is related to the semantic content of this segment; 7 | - `start_sec` and `end_sec` of this segment; 8 | - a `neg_start_sec` and a `neg_end_sec` that defines a "negative span" of this segment. 
All other segments in the negative span are not semantically similar to this segment (and its caption), so the negative span can serve as the context from which temporal video grounding models retrieve this segment. 9 | 10 | # Data Download 11 | You can download the annotations from [🤗HuggingFace](https://huggingface.co/datasets/wangyueqian/InternVid-G). 12 | 13 | You can use `code/download_videos.py` to download the videos from YouTube. 14 | 15 | # Processing (Optional) 16 | We also provide code for reproducing (downloading and processing) InternVid-G in `code/`. We hope this code will help you work with your own datasets. 17 | 18 | This code is only for your reference, and we strongly suggest you read it and become familiar with it before using it. 19 | 20 | 1. Save a list of the YouTube video ids you are interested in to `temp/video_ids.txt`, one id per line, which looks like this: 21 | ```text 22 | UkQzZimp7Qs 23 | vJxjY1-0120 24 | 11mh9C7RCqg 25 | dcIuh7JaxWw 26 | d5i7JiBAMtc 27 | J0aVIs-eLMA 28 | ... 29 | ``` 30 | 31 | 2. Download the videos 32 | ```shell 33 | python code/download_videos.py --src_file temp/video_ids.txt --target_dir videos --num_workers 5 34 | ``` 35 | 36 | 3. Split the videos into scenes using [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) 37 | ```shell 38 | python code/clip_sim.py --func split_scene --scene-fname temp/scenes.jsonl 39 | ``` 40 | The detected scenes are saved at `temp/scenes.jsonl`. This process may take a long time, as PySceneDetect processes the videos frame by frame. 41 | 42 | 4. Calculate the similarities between all pairs of scenes, and get a scene similarity matrix 43 | ```shell 44 | python code/clip_sim.py --func scene_sim --scene-fname temp/scenes.jsonl --scene-sim-fname temp/scenes_similarity.jsonl 45 | ``` 46 | The similarity matrices of the segments are saved at `temp/scenes_similarity.jsonl`. 47 | 48 | 5. Merge the most similar consecutive scenes, and get the merged scene similarity matrix 49 | ```shell 50 | python code/clip_sim.py --func merge_scene \ 51 | --scene-fname temp/scenes.jsonl --scene-sim-fname temp/scenes_similarity.jsonl \ 52 | --scene-merged-fname temp/scenes-merged.jsonl --scene-merged-sim-fname temp/scenes_merged_similarity.jsonl 53 | ``` 54 | The merged scenes and similarity matrices are saved at `temp/scenes-merged.jsonl` and `temp/scenes_merged_similarity.jsonl`. 55 | 56 | 6. Use a captioning model to get the caption of each scene (we find BLIP-2 works best in our case) 57 | ```shell 58 | python code/caption_clips.py --func blip2 \ 59 | --scene-fname temp/scenes-merged.jsonl --blip2-fname temp/scenes-blip2.jsonl 60 | ``` 61 | The segment captions are saved at `temp/scenes-blip2.jsonl`. 62 | 63 | 7. Calculate the similarity between each segment and its caption, and remove the ones with low similarity 64 | ```shell 65 | python code/caption_clips.py --func filter --merge-method max \ 66 | --filter-input-fname temp/scenes-blip2.jsonl \ 67 | --filtered-fname temp/scenes-blip2-filtered.jsonl 68 | 69 | python code/caption_clips.py --func merge_filtered_captions \ 70 | --filter-input-fname temp/scenes-blip2-filtered.jsonl \ 71 | --filtered-fname temp/scenes-blip2-filtered-high_sims.jsonl 72 | ``` 73 | All segment-caption pairs with high similarity that are not filtered out are saved in `temp/scenes-blip2-filtered-high_sims.jsonl`. 74 | 75 | 8. (Finally, the last step!) Get a negative span for each segment.
76 | ```shell 77 | python code/ground_data_construction.py \ 78 | --caption-fname temp/scenes-blip2-filtered-high_sims.jsonl \ 79 | --scene-fname temp/scenes-merged.jsonl --scene-sim-fname temp/scenes_similarity-merged.jsonl \ 80 | --caption-with-neg-interval-fname temp/final_dataset.jsonl 81 | ``` 82 | Finally you can get all annotations in `temp/final_dataset.jsonl`. 83 | 84 | # Citation 85 | If you find this code useful in your research, please consider citing: 86 | ```bibtex 87 | @misc{wang2024hawkeye, 88 | title={HawkEye: Training Video-Text LLMs for Grounding Text in Videos}, 89 | author={Yueqian Wang and Xiaojun Meng and Jianxin Liang and Yuxuan Wang and Qun Liu and Dongyan Zhao}, 90 | year={2024}, 91 | eprint={2403.10228}, 92 | archivePrefix={arXiv}, 93 | primaryClass={cs.CV} 94 | } 95 | ``` -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import logging 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def setup_for_distributed(is_master): 11 | import warnings 12 | 13 | builtin_warn = warnings.warn 14 | 15 | def warn(*args, **kwargs): 16 | force = kwargs.pop("force", False) 17 | if is_master or force: 18 | builtin_warn(*args, **kwargs) 19 | 20 | # Log warnings only once 21 | warnings.warn = warn 22 | warnings.simplefilter("once", UserWarning) 23 | 24 | if not is_master: 25 | logging.disable() 26 | 27 | 28 | def is_dist_avail_and_initialized(): 29 | if not dist.is_available(): 30 | return False 31 | if not dist.is_initialized(): 32 | return False 33 | return True 34 | 35 | 36 | def get_world_size(): 37 | if not is_dist_avail_and_initialized(): 38 | return 1 39 | return dist.get_world_size() 40 | 41 | 42 | def get_rank(): 43 | if not is_dist_avail_and_initialized(): 44 | return 0 45 | return dist.get_rank() 46 | 47 | 48 | def is_main_process(): 49 | return get_rank() == 0 50 | 51 | 52 | def save_on_master(*args, **kwargs): 53 | if is_main_process(): 54 | torch.save(*args, **kwargs) 55 | 56 | 57 | def is_port_in_use(port): 58 | import socket 59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 60 | return s.connect_ex(('localhost', port)) == 0 61 | 62 | 63 | def init_distributed_mode(args): 64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 65 | # job started by torch.distributed.launch 66 | args.rank = int(os.environ["RANK"]) 67 | args.world_size = int(os.environ['WORLD_SIZE']) 68 | args.gpu = int(os.environ['LOCAL_RANK']) 69 | elif 'SLURM_PROCID' in os.environ: 70 | # local rank on the current node / global rank 71 | local_rank = int(os.environ['SLURM_LOCALID']) 72 | global_rank = int(os.environ['SLURM_PROCID']) 73 | # number of processes / GPUs per node 74 | world_size = int(os.environ["SLURM_NNODES"]) * \ 75 | int(os.environ["SLURM_TASKS_PER_NODE"][0]) 76 | 77 | print(world_size) 78 | 79 | args.rank = global_rank 80 | args.gpu = local_rank 81 | args.world_size = world_size 82 | else: 83 | logger.info('Not using distributed mode') 84 | args.distributed = False 85 | return 86 | 87 | args.distributed = True 88 | 89 | logger.info('torch.cuda.set_device', args.gpu) 90 | torch.cuda.set_device(args.gpu) 91 | args.dist_backend = 'nccl' 92 | 93 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node 94 | dist_port = int(args.dist_url.split(":")[-1]) 95 | while is_port_in_use(dist_port): 96 | dist_port += 10 97 | args.dist_url 
= ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)]) 98 | 99 | logger.info('| distributed init (rank {}): {}'.format( 100 | args.rank, args.dist_url)) 101 | if "SLURM_JOB_ID" in os.environ: 102 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}") 103 | torch.distributed.init_process_group( 104 | backend=args.dist_backend, init_method=args.dist_url, 105 | world_size=args.world_size, rank=args.rank) 106 | torch.distributed.barrier() 107 | setup_for_distributed(args.rank == 0) 108 | 109 | 110 | # Copyright (c) Facebook, Inc. and its affiliates. 111 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py 112 | class GatherLayer(torch.autograd.Function): 113 | """ 114 | Gather tensors from all workers with support for backward propagation: 115 | This implementation does not cut the gradients as torch.distributed.all_gather does. 116 | """ 117 | 118 | @staticmethod 119 | def forward(ctx, x): 120 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 121 | dist.all_gather(output, x) 122 | return tuple(output) 123 | 124 | @staticmethod 125 | def backward(ctx, *grads): 126 | all_gradients = torch.stack(grads) 127 | dist.all_reduce(all_gradients) 128 | return all_gradients[dist.get_rank()] 129 | 130 | 131 | # copied from megavlt 132 | def gather_tensor_along_batch_with_backward(tensor, dim=0): 133 | world_size = get_world_size() 134 | 135 | if world_size < 2: 136 | return tensor 137 | 138 | tensor_list = GatherLayer.apply(tensor) 139 | tensor_list = torch.cat(tensor_list, dim=dim) 140 | return tensor_list 141 | 142 | 143 | @torch.no_grad() 144 | def gather_tensor_along_batch(tensor, dim=0): 145 | """ 146 | Performs all_gather operation on the provided tensors. 147 | *** Warning ***: torch.distributed.all_gather has no gradient. 
148 | """ 149 | world_size = get_world_size() 150 | 151 | if world_size < 2: 152 | return tensor 153 | 154 | with torch.no_grad(): 155 | tensor_list = [] 156 | 157 | for _ in range(world_size): 158 | tensor_list.append(torch.zeros_like(tensor)) 159 | 160 | dist.all_gather(tensor_list, tensor) 161 | tensor_list = torch.cat(tensor_list, dim=dim) 162 | return tensor_list 163 | -------------------------------------------------------------------------------- /configs/instruction_data.py: -------------------------------------------------------------------------------- 1 | import os as __os # add "__" if not want to be exported 2 | 3 | anno_root_it = "data/HawkEye-IT" 4 | 5 | # ============== pretraining datasets================= 6 | 7 | available_corpus = dict( 8 | caption_textvr={ 9 | "dataset_name": "caption_textvr", 10 | "label_file": f"{anno_root_it}/video/caption/textvr/train.json", 11 | "data_root": "data/videos/textvr", 12 | "media_type": "video", 13 | }, 14 | 15 | caption_videochat={ 16 | "dataset_name": "caption_videochat", 17 | "label_file": f"{anno_root_it}/video/caption/videochat/train.json", 18 | "data_root": "data/videos/webvid", 19 | "media_type": "video", 20 | }, 21 | 22 | caption_webvid={ 23 | "dataset_name": "caption_webvid", 24 | "label_file": f"{anno_root_it}/video/caption/webvid/train.json", 25 | "data_root": "data/videos/webvid", 26 | "media_type": "video", 27 | }, 28 | 29 | caption_youcook2={ 30 | "dataset_name": "caption_youcook2", 31 | "label_file": f"{anno_root_it}/video/caption/youcook2/train.json", 32 | "data_root": "data/videos/youcook2", 33 | "media_type": "video", 34 | }, 35 | 36 | classification_k710={ 37 | "dataset_name": "classification_k710", 38 | "label_file": f"{anno_root_it}/video/classification/k710/train.json", 39 | "data_root": "data/videos/kinetics", 40 | "media_type": "video", 41 | }, 42 | 43 | classification_ssv2={ 44 | "dataset_name": "classification_ssv2", 45 | "label_file": f"{anno_root_it}/video/classification/ssv2/train.json", 46 | "data_root": "data/videos/ssv2", 47 | "media_type": "video", 48 | }, 49 | 50 | conversation_videochat1={ 51 | "dataset_name": "conversation_videochat1", 52 | "label_file": f"{anno_root_it}/video/conversation/videochat1/train.json", 53 | "data_root": "data/videos/webvid", 54 | "media_type": "video", 55 | }, 56 | 57 | conversation_videochatgpt={ 58 | "dataset_name": "conversation_videochatgpt", 59 | "label_file": f"{anno_root_it}/video/conversation/videochatgpt/train.json", 60 | "data_root": "data/videos/activitynet", 61 | "media_type": "video", 62 | }, 63 | 64 | reasoning_next_qa={ 65 | "dataset_name": "reasoning_next_qa", 66 | "label_file": f"{anno_root_it}/video/reasoning/next_qa/train.json", 67 | "data_root": "data/videos/nextqa", 68 | "media_type": "video", 69 | }, 70 | 71 | reasoning_clevrer_qa={ 72 | "dataset_name": "reasoning_clevrer_qa", 73 | "label_file": f"{anno_root_it}/video/reasoning/clevrer_qa/train.json", 74 | "data_root": "data/videos/clevrer", 75 | "media_type": "video", 76 | }, 77 | 78 | reasoning_clevrer_mc={ 79 | "dataset_name": "reasoning_clevrer_mc", 80 | "label_file": f"{anno_root_it}/video/reasoning/clevrer_mc/train.json", 81 | "data_root": "data/videos/clevrer", 82 | "media_type": "video", 83 | }, 84 | 85 | vqa_tgif_frame_qa={ 86 | "dataset_name": "vqa_tgif_frame_qa", 87 | "label_file": f"{anno_root_it}/video/vqa/tgif_frame_qa/train.json", 88 | "data_root": "data/videos/tgif", 89 | "media_type": "video", 90 | }, 91 | 92 | vqa_tgif_transition_qa={ 93 | "dataset_name": "vqa_tgif_transition_qa", 94 
| "label_file": f"{anno_root_it}/video/vqa/tgif_transition_qa/train.json", 95 | "data_root": "data/videos/tgif", 96 | "media_type": "video", 97 | }, 98 | 99 | vqa_webvid_qa={ 100 | "dataset_name": "vqa_webvid_qa", 101 | "label_file": f"{anno_root_it}/video/vqa/webvid_qa/train.json", 102 | "data_root": "data/videos/webvid", 103 | "media_type": "video", 104 | }, 105 | 106 | internvid_grounding={ 107 | "dataset_name": "internvid_grounding", 'grounding_method': 'choice', 108 | "label_file": f"{anno_root_it}/video/temporal/internvid_grounding/train.json", 109 | "data_root": "data/videos/internvid-g", 110 | "media_type": "video", 111 | }, 112 | 113 | internvid_caption={ 114 | "dataset_name": "internvid_grounding", 'grounding_method': 'choice', 115 | "label_file": f"{anno_root_it}/video/temporal/internvid_caption/train.json", 116 | "data_root": "data/videos/internvid-g", 117 | "media_type": "video", 118 | }, 119 | 120 | ) 121 | 122 | # select the instruction training data you have 123 | available_corpus["hawkeye_instruction"] = [ 124 | available_corpus["caption_textvr"], 125 | available_corpus["caption_videochat"], 126 | available_corpus["caption_webvid"], 127 | available_corpus["caption_youcook2"], 128 | available_corpus["classification_k710"], 129 | available_corpus["classification_ssv2"], 130 | available_corpus["conversation_videochat1"], 131 | available_corpus["conversation_videochat2"], 132 | available_corpus["conversation_videochatgpt"], 133 | available_corpus["reasoning_next_qa"], 134 | available_corpus["reasoning_clevrer_qa"], 135 | available_corpus["reasoning_clevrer_mc"], 136 | available_corpus["vqa_ego_qa"], 137 | available_corpus["vqa_tgif_frame_qa"], 138 | available_corpus["vqa_tgif_transition_qa"], 139 | available_corpus["vqa_webvid_qa"], 140 | available_corpus["internvid_grounding"], 141 | available_corpus["internvid_caption"], 142 | ] 143 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | """ Optimizer Factory w/ Custom Weight Decay 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import re 5 | import torch 6 | from torch import optim as optim 7 | from utils.distributed import is_main_process 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | try: 11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 12 | has_apex = True 13 | except ImportError: 14 | has_apex = False 15 | 16 | 17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True): 18 | named_param_tuples = [] 19 | for name, param in model.named_parameters(): 20 | if not param.requires_grad: 21 | continue # frozen weights 22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")): 23 | named_param_tuples.append([name, param, 0]) 24 | elif name in no_decay_list: 25 | named_param_tuples.append([name, param, 0]) 26 | else: 27 | named_param_tuples.append([name, param, weight_decay]) 28 | return named_param_tuples 29 | 30 | 31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr): 32 | """use lr=diff_lr for modules named found in diff_lr_names, 33 | otherwise use lr=default_lr 34 | 35 | Args: 36 | named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module 37 | diff_lr_names: List(str) 38 | diff_lr: float 39 | default_lr: float 40 | Returns: 41 | named_param_tuples_with_lr: List([name, param, weight_decay, lr]) 42 | """ 43 | 
named_param_tuples_with_lr = [] 44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}") 45 | for name, p, wd in named_param_tuples_or_model: 46 | use_diff_lr = False 47 | for diff_name in diff_lr_names: 48 | # if diff_name in name: 49 | if re.search(diff_name, name) is not None: 50 | logger.info(f"param {name} use different_lr: {diff_lr}") 51 | use_diff_lr = True 52 | break 53 | 54 | named_param_tuples_with_lr.append( 55 | [name, p, wd, diff_lr if use_diff_lr else default_lr] 56 | ) 57 | 58 | if is_main_process(): 59 | for name, _, wd, diff_lr in named_param_tuples_with_lr: 60 | logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}") 61 | 62 | return named_param_tuples_with_lr 63 | 64 | 65 | def create_optimizer_params_group(named_param_tuples_with_lr): 66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])""" 67 | group = {} 68 | for name, p, wd, lr in named_param_tuples_with_lr: 69 | if wd not in group: 70 | group[wd] = {} 71 | if lr not in group[wd]: 72 | group[wd][lr] = [] 73 | group[wd][lr].append(p) 74 | 75 | optimizer_params_group = [] 76 | for wd, lr_groups in group.items(): 77 | for lr, p in lr_groups.items(): 78 | optimizer_params_group.append(dict( 79 | params=p, 80 | weight_decay=wd, 81 | lr=lr 82 | )) 83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}") 84 | return optimizer_params_group 85 | 86 | 87 | def create_optimizer(args, model, filter_bias_and_bn=True): 88 | opt_lower = args.opt.lower() 89 | weight_decay = args.weight_decay 90 | # check for modules that requires different lr 91 | if hasattr(args, "different_lr") and args.different_lr.enable: 92 | diff_lr_module_names = args.different_lr.module_names 93 | diff_lr = args.different_lr.lr 94 | else: 95 | diff_lr_module_names = [] 96 | diff_lr = None 97 | 98 | no_decay = {} 99 | if hasattr(model, 'no_weight_decay'): 100 | no_decay = model.no_weight_decay() 101 | named_param_tuples = add_weight_decay( 102 | model, weight_decay, no_decay, filter_bias_and_bn) 103 | named_param_tuples = add_different_lr( 104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr) 105 | parameters = create_optimizer_params_group(named_param_tuples) 106 | 107 | if 'fused' in opt_lower: 108 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 109 | 110 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 111 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 112 | opt_args['eps'] = args.opt_eps 113 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 114 | opt_args['betas'] = args.opt_betas 115 | if hasattr(args, 'opt_args') and args.opt_args is not None: 116 | opt_args.update(args.opt_args) 117 | 118 | opt_split = opt_lower.split('_') 119 | opt_lower = opt_split[-1] 120 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 121 | opt_args.pop('eps', None) 122 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 123 | elif opt_lower == 'momentum': 124 | opt_args.pop('eps', None) 125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 126 | elif opt_lower == 'adam': 127 | optimizer = optim.Adam(parameters, **opt_args) 128 | elif opt_lower == 'adamw': 129 | optimizer = optim.AdamW(parameters, **opt_args) 130 | else: 131 | assert False and "Invalid optimizer" 132 | raise ValueError 133 | return optimizer 134 | -------------------------------------------------------------------------------- /data_preparing/videoqa.py: 
-------------------------------------------------------------------------------- 1 | # convert many qa datasets to the input format 2 | import os 3 | import random 4 | import argparse 5 | import json 6 | import pandas as pd 7 | import numpy as np 8 | 9 | 10 | 11 | # nextqa, tvqa and star also have a time span label, but we omit them? 12 | def convert_nextqa_test(input_fname, output_fname, video_mapping_fname): 13 | video_mapping_dict = json.load(open(video_mapping_fname)) 14 | video_mapping_dict = {key: val + '.mp4' for key, val in video_mapping_dict.items()} 15 | df = pd.read_csv(input_fname) 16 | 17 | res_list = list() 18 | for line_i, row in df.iterrows(): 19 | video_fname = video_mapping_dict[str(row['video'])] 20 | res_dict = {'video': video_fname, 'question': row['question'] + '?', 'qid': row['qid'], 21 | 'candidates': [row['a%d' % i] + '.' for i in range(5)], 'answer': row['a%d' % row['answer']]} 22 | res_list.append(res_dict) 23 | json.dump(res_list, open(output_fname, 'w')) 24 | 25 | 26 | def convert_tvqa(input_fname, output_fname): 27 | tv_name_to_folder = { 28 | 'The Big Bang Theory': 'bbt_frames', 'Castle': 'castle_frames', 'How I Met You Mother': 'met_frames', 29 | "Grey's Anatomy": 'grey_frames', 'Friends': 'friends_frames', 'House M.D.': 'house_frames' 30 | } 31 | 32 | res_list = list() 33 | for line in open(input_fname): 34 | row = json.loads(line) 35 | start_sec, end_sec = row['ts'].split('-') 36 | res_dict = {'video': os.path.join(tv_name_to_folder[row['show_name']], row['vid_name']), 37 | 'question': row['q'], 'qid': row['qid'], 'candidates': [row['a%d' % i] for i in range(5)]} 38 | 39 | if not np.isnan(float(start_sec)) and not np.isnan(float(end_sec)): 40 | res_dict['start'], res_dict['end'] = float(start_sec), float(end_sec) 41 | 42 | if 'answer_idx' in row: 43 | res_dict['answer'] = row['a%d' % row['answer_idx']] 44 | res_list.append(res_dict) 45 | json.dump(res_list, open(output_fname, 'w')) 46 | 47 | 48 | def convert_star(input_fname, output_fname): 49 | res_list = list() 50 | for row in json.load(open(input_fname)): 51 | res_dict = {'video': row['video_id'] + '.mp4', 52 | 'question': row['question'], 'qid': row['question_id'], 53 | 'candidates': [c['choice'] for c in row['choices']]} 54 | if 'answer' in row: 55 | res_dict['answer'] = row['answer'] 56 | res_list.append(res_dict) 57 | json.dump(res_list, open(output_fname, 'w')) 58 | 59 | 60 | def convert_star_output(input_fname, output_fname, src_fname): 61 | ''' 62 | convert videochat2 output on star dataset to the submission format 63 | ''' 64 | res_dict = {key: [] for key in ['Interaction', 'Sequence', 'Prediction', 'Feasibility']} 65 | src_data = json.load(open(src_fname)) 66 | pred_data = [json.loads(line) for line in open(input_fname)] 67 | assert len(src_data) == len(pred_data) 68 | 69 | for example, src_example in zip(pred_data, src_data): 70 | for key in res_dict.keys(): 71 | if src_example['qid'].startswith(key): 72 | if example['pred'][1] in 'ABCD': 73 | res_dict[key].append({'question_id': src_example['qid'], 'answer': 'ABCD'.index(example['pred'][1])}) 74 | else: 75 | print('no choice letter found!') 76 | res_dict[key].append({'question_id': src_example['qid'], 'answer': random.choice([0, 1, 2, 3])}) 77 | json.dump(res_dict, open(output_fname, 'w')) 78 | 79 | 80 | def convert_tvqa_output(input_fname, output_folder, test_src_fname, val_src_fname): 81 | res_dict = dict() 82 | src_data = [json.loads(line) for line in open(test_src_fname)] 83 | pred_data = [json.loads(line) for line in open(input_fname)] 
84 | assert len(src_data) == len(pred_data) 85 | 86 | os.makedirs(output_folder, exist_ok=True) 87 | for example, src_example in zip(pred_data, src_data): 88 | pred_id = 'ABCDE'.index(example['pred'][1]) if example['pred'][1] in 'ABCDE' else random.randint(0, 4) 89 | res_dict[src_example['qid']] = pred_id 90 | json.dump(res_dict, open(os.path.join(output_folder, 'prediction_test_public.json'), 'w')) 91 | 92 | # generate a random result for the val set 93 | src_data = [json.loads(line) for line in open(val_src_fname)] 94 | res_dict = dict() 95 | for src_example in src_data: 96 | pred_id = random.randint(0, 4) 97 | res_dict[src_example['qid']] = pred_id 98 | json.dump(res_dict, open(os.path.join(output_folder, 'prediction_val.json'), 'w')) 99 | 100 | # generatethe metadata 101 | metadata = { 102 | 'model_name': '/'.join(input_fname.split('/')[-2:]), 'is_ensemble': False, 'with_ts': True, 'show_on_leaderboard': False, 103 | 'author': 'null', 'institution': 'null', 'description': 'null', 'paper_link': 'null', 'code_link': 'null' 104 | } 105 | json.dump(metadata, open(os.path.join(output_folder, 'meta.json'), 'w')) 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--func', type=str) 112 | parser.add_argument('--input', type=str) 113 | parser.add_argument('--input2', type=str) 114 | parser.add_argument('--input3', type=str) 115 | parser.add_argument('--output', type=str) 116 | args = parser.parse_args() 117 | print(args) 118 | 119 | if args.func == 'nextqa-test': 120 | convert_nextqa_test(args.input, args.output, args.input2) 121 | 122 | if args.func == 'tvqa': 123 | convert_tvqa(args.input, args.output) 124 | 125 | if args.func == 'star': 126 | convert_star(args.input, args.output) 127 | 128 | if args.func == 'star_output': 129 | convert_star_output(args.input, args.output, args.input2) 130 | 131 | if args.func == 'tvqa_output': 132 | convert_tvqa_output(args.input, args.output, args.input2, args.input3) 133 | 134 | -------------------------------------------------------------------------------- /dataset/pt_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import sqlite3 5 | import random 6 | from os.path import basename 7 | 8 | import numpy as np 9 | 10 | from dataset.base_dataset import ImageVideoBaseDataset 11 | from dataset.utils import load_anno, pre_text 12 | from dataset.video_utils import VIDEO_READER_FUNCS 13 | from utils.distributed import is_main_process 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def get_anno_by_id(cur: sqlite3.Cursor, id: int): 19 | """TODO: Docstring for get_anno_by_id. 20 | 21 | Args: 22 | cur (sqlite3.Cursor): The dataset cursor. 23 | id (int): The annotation id. 
24 | 25 | Returns: 26 | 27 | """ 28 | pass 29 | 30 | 31 | class PTImgTrainDataset(ImageVideoBaseDataset): 32 | media_type = "image" 33 | 34 | def __init__(self, ann_file, transform, pre_text=True): 35 | super().__init__() 36 | 37 | if len(ann_file) == 3 and ann_file[2] == "video": 38 | self.media_type = "video" 39 | else: 40 | self.media_type = "image" 41 | self.label_file, self.data_root = ann_file[:2] 42 | 43 | logger.info('Load json file') 44 | with open(self.label_file, 'r') as f: 45 | self.anno = json.load(f) 46 | self.num_examples = len(self.anno) 47 | 48 | self.transform = transform 49 | self.pre_text = pre_text 50 | logger.info(f"Pre-process text: {pre_text}") 51 | 52 | def get_anno(self, index): 53 | filename = self.anno[index][self.media_type] 54 | caption = self.anno[index]["caption"] 55 | anno = {"image": os.path.join(self.data_root, filename), "caption": caption} 56 | return anno 57 | 58 | def __len__(self): 59 | return self.num_examples 60 | 61 | def __getitem__(self, index): 62 | try: 63 | ann = self.get_anno(index) 64 | image, index = self.load_and_transform_media_data(index, ann["image"]) 65 | caption = pre_text(ann["caption"], pre_text=self.pre_text) 66 | return image, caption, index 67 | except Exception as e: 68 | logger.warning(f"Caught exception {e} when loading image {ann['image']}") 69 | index = np.random.randint(0, len(self)) 70 | return self.__getitem__(index) 71 | 72 | 73 | class PTVidTrainDataset(PTImgTrainDataset): 74 | media_type = "video" 75 | 76 | def __init__( 77 | self, 78 | ann_file, 79 | transform, 80 | num_frames=4, 81 | video_reader_type="decord", 82 | sample_type="rand", 83 | num_tries=3, 84 | pre_text=True 85 | ): 86 | super().__init__(ann_file, transform, pre_text=pre_text) 87 | self.num_frames = num_frames 88 | self.video_reader_type = video_reader_type 89 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 90 | self.sample_type = sample_type 91 | self.num_tries = num_tries 92 | 93 | 94 | class PTImgEvalDataset(ImageVideoBaseDataset): 95 | media_type = "image" 96 | 97 | def __init__(self, ann_file, transform, has_multi_vision_gt=False): 98 | super(PTImgEvalDataset, self).__init__() 99 | self.raw_anno_list = load_anno(ann_file) 100 | self.transform = transform 101 | self.has_multi_vision_gt = has_multi_vision_gt # each caption has multiple image as ground_truth 102 | 103 | self.text = None 104 | self.image = None 105 | self.txt2img = None 106 | self.img2txt = None 107 | self.build_data() 108 | 109 | def build_data(self): 110 | self.text = [] 111 | self.image = [] 112 | self.txt2img = {} 113 | self.img2txt = {} 114 | if self.has_multi_vision_gt: 115 | self.build_data_multi_img_gt() 116 | else: 117 | self.build_data_multi_txt_gt() 118 | self.anno_list = [dict(image=e) for e in self.image] 119 | 120 | def build_data_multi_img_gt(self): 121 | """each text may have multiple ground_truth image, e.g., ssv2""" 122 | img_id = 0 123 | for txt_id, ann in enumerate(self.raw_anno_list): 124 | self.text.append(pre_text(ann["caption"])) 125 | self.txt2img[txt_id] = [] 126 | _images = ann["image"] \ 127 | if isinstance(ann["image"], list) else [ann["image"], ] 128 | for i, image in enumerate(_images): 129 | self.image.append(image) 130 | self.txt2img[txt_id].append(img_id) 131 | self.img2txt[img_id] = txt_id 132 | img_id += 1 133 | 134 | def build_data_multi_txt_gt(self): 135 | """each image may have multiple ground_truth text, e.g., COCO and Flickr30K""" 136 | txt_id = 0 137 | for img_id, ann in enumerate(self.raw_anno_list): 138 | 
self.image.append(ann["image"]) 139 | self.img2txt[img_id] = [] 140 | _captions = ann["caption"] \ 141 | if isinstance(ann["caption"], list) else [ann["caption"], ] 142 | for i, caption in enumerate(_captions): 143 | self.text.append(pre_text(caption)) 144 | self.img2txt[img_id].append(txt_id) 145 | self.txt2img[txt_id] = img_id 146 | txt_id += 1 147 | 148 | def __len__(self): 149 | return len(self.anno_list) 150 | 151 | def __getitem__(self, index): 152 | ann = self.anno_list[index] 153 | image, index = self.load_and_transform_media_data(index, ann["image"]) 154 | return image, index 155 | 156 | 157 | def preprocess_para_retrieval_data(anno_list): 158 | processed_anno_list = [] 159 | for d in anno_list: 160 | d["caption"] = " ".join(d.pop("caption")) 161 | processed_anno_list.append(d) 162 | return processed_anno_list 163 | 164 | 165 | class PTVidEvalDataset(PTImgEvalDataset): 166 | media_type = "video" 167 | 168 | def __init__( 169 | self, ann_file, transform, num_frames=4, 170 | video_reader_type="decord", sample_type="rand", num_tries=1, 171 | is_paragraph_retrieval=False, has_multi_vision_gt=False, 172 | ): 173 | super(PTVidEvalDataset, self).__init__(ann_file, transform, has_multi_vision_gt) 174 | self.num_frames = num_frames 175 | self.video_reader_type = video_reader_type 176 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 177 | self.sample_type = sample_type 178 | self.num_tries = num_tries 179 | self.is_paragraph_retrieval = is_paragraph_retrieval 180 | 181 | if is_paragraph_retrieval: 182 | self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list) 183 | self.build_data() 184 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | #
HawkEye: Training Video-Text LLMs for Grounding Text in Videos
3 | 4 | [**[Paper]**](https://arxiv.org/abs/2403.10228) 5 | [**[Checkpoint]**](https://huggingface.co/wangyueqian/HawkEye) 6 | [**[Dataset]**](internvid_g/README.md) 7 | 8 | 9 | ## Updates 10 | - 2024/04/29: Updated the model loading process and merged the trained params of VideoChat2 into `hawkeye.pth`. Now only the checkpoints of vicuna-7b-v0 and `hawkeye.pth` are needed to load HawkEye. 11 | 12 | ## Introduction 13 | ![performance](assets/performance.jpg) 14 | Video-text Large Language Models (video-text LLMs) have shown remarkable performance in answering questions and holding conversations on simple videos. 15 | However, they perform almost the same as random guessing when grounding text queries in long and complicated videos, showing little ability to understand and reason about temporal information, which is the most fundamental difference between videos and images. 16 | 17 | We propose HawkEye, one of the first video-text LLMs that can perform temporal video grounding in a fully text-to-text manner. To collect training data applicable to temporal video grounding, we construct InternVid-G, a large-scale video-text corpus with segment-level captions and negative spans, with which we introduce two new time-aware training objectives to video-text LLMs. We also propose a coarse-grained method of representing segments in videos, which is more robust and easier for LLMs to learn and follow than other alternatives. 18 | 19 | ### Datasets and Models 20 | We release the [**Model Checkpoints**](https://huggingface.co/wangyueqian/HawkEye) of HawkEye and our implementation of VideoChat2, as well as the [**InternVid-G Dataset**](internvid_g/README.md), at 🤗HuggingFace. 21 | 22 | ## Demo 23 | ![example1](assets/example1.jpg) 24 | ![example4](assets/example4.jpg) 25 | 26 | **Live demo: in progress** 27 | 28 | You can use `demo.ipynb` to test HawkEye on your own data. 29 | 30 | ## Training 31 | ### Download model checkpoints 32 | 33 | - Create a directory `model/` for model checkpoints: `mkdir model/` 34 | - Follow [here](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat#running-usage) to prepare vicuna-7b-v0 35 | - Download the [HawkEye checkpoint](https://huggingface.co/wangyueqian/HawkEye) 36 | - (Optional) If you want to reproduce the instruction tuning process, download [umt_l16_qformer.pth](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/umt_l16_qformer.pth) and [videochat2_7b_stage2.pth](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/videochat2_7b_stage2.pth) from [VideoChat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) 37 | 38 | After downloading all model checkpoints, the `model/` folder should look like this: 39 | ``` 40 | └── hawkeye.pth 41 | └── vicuna-7b-v0/ 42 | └── VideoChat2/ (optional) 43 | └── umt_l16_qformer.pth 44 | └── videochat2_7b_stage2.pth 45 | ``` 46 | 47 | ### Data preparation 48 | 49 | Download the data from the [Dataset Homepage at 🤗HuggingFace](https://huggingface.co/datasets/wangyueqian/HawkEye-IT) and save it in the `data/HawkEye-IT/` folder. We also provide data processing code in `data_preparing/` for reference. 50 | 51 | Note that you also need to download the videos of each dataset from their original sources, as further explained on the dataset homepage (this may take quite a while 🙃). Use soft links to place the video folders under `data/videos/`, as in the example below.
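A minimal sketch of such a link, assuming the Charades videos were downloaded to `/path/to/charades_videos` (a hypothetical path; repeat for each dataset):

```shell
mkdir -p data/videos
ln -s /path/to/charades_videos data/videos/charades
```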
52 | 53 | After data preparation, the `data/` folder should look like this: 54 | ``` 55 | └── HawkEye-IT/ 56 | └── image/ # inherited from VideoChat2-IT, but not used in training HawkEye 57 | └── video/ 58 | └── temporal/ 59 | └── internvid_grounding/, charades_sta_grounding/, anetc_grounding/ 60 | └── instructions.json, questions.json, train.json 61 | └── internvid_caption/ 62 | └── instructions.json, train.json 63 | └── caption/, classification/, conversation/, vqa/, reasoning/ 64 | └── videos/ 65 | └── internvid-g/, clevrer/, webvid/, activitynet/, tgif/, 66 | └── nextqa/, textvr/, youcook2/, kinetics/, ssv2/, charades/ 67 | ``` 68 | Note that the `image/, caption/, classification/, conversation/, vqa/, reasoning/` folders of HawkEye-IT are identical to [VideoChat2-IT](https://huggingface.co/datasets/OpenGVLab/VideoChat2-IT). 69 | 70 | ### Run the instruction tuning process 71 | ```shell 72 | bash ./scripts/train/run_7b_stage3.sh OUTPUT_PATH 73 | ``` 74 | The instruction-tuned HawkEye checkpoint will be saved at `OUTPUT_PATH/ckpt_${ckpt}.pth`, where `${ckpt}` is the number of training iterations. 75 | 76 | Check the script to ensure the hyperparameters fit your computing device. 77 | 78 | ### Run the finetuning process 79 | We also provide scripts to finetune on Charades-STA and ActivityNet-Captions: 80 | ```shell 81 | # IT_CKPT: the instruction-tuned HawkEye checkpoint 82 | bash ./scripts/train/charades_sta.sh OUTPUT_PATH IT_CKPT 83 | bash ./scripts/train/anetc.sh OUTPUT_PATH IT_CKPT 84 | ``` 85 | Check the script to ensure the hyperparameters fit your computing device. 86 | 87 | ## Testing 88 | ### Data preparation 89 | 1. Download [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench) and save it in the `data/MVBench/` folder. 90 | 91 | 2. Download the annotations of the other benchmarks from [Google Drive](https://drive.google.com/file/d/1WVx_UGAnCmBIp8GgCpvtxHHenpYVur5u/view?usp=sharing) and unzip them to `data/test-anno/`. We also provide data processing code in `data_preparing/` for reference. 92 | 93 | 3. Download the [TVQA videos](https://tvqa.cs.unc.edu/) and link them at `data/videos/tvqa`. 94 | 95 | After downloading all benchmarks, the `data/` folder should look like this: 96 | ``` 97 | └── HawkEye-IT/ # instruction tuning datasets 98 | └── MVBench/ 99 | └── test-anno/ 100 | └── charades_sta-recursive_grounding.json, anetc-recursive_grounding.json 101 | └── nextgqa-recursive_grounding.json 102 | └── nextqa-test.json, tvqa-test.json, star-test.json 103 | └── videos/ 104 | └── nextqa/, tvqa/, charades/, activitynet/, ... 105 | ``` 106 | 107 | #### Test on video QA benchmarks 108 | ``` 109 | bash ./scripts/test/videoqa.sh 110 | ``` 111 | Refer to `data_preparing/videoqa.py` to convert the model outputs to the formats required by [STAR evaluation](https://eval.ai/web/challenges/challenge-page/1325/overview) and [TVQA evaluation w/ ts](https://codalab.lisn.upsaclay.fr/competitions/6978). 112 | 113 | #### Test on temporal video grounding benchmarks with recursive grounding 114 | ``` 115 | bash ./scripts/test/recursive_grounding.sh 116 | ``` 117 | To analyze the results of each recursive grounding step, refer to `data_preparing/check_grounding_results.ipynb`.
118 | 119 | ## Citation 120 | If you find this code useful in your research, please consider citing: 121 | ```bibtex 122 | @misc{wang2024hawkeye, 123 | title={HawkEye: Training Video-Text LLMs for Grounding Text in Videos}, 124 | author={Yueqian Wang and Xiaojun Meng and Jianxin Liang and Yuxuan Wang and Qun Liu and Dongyan Zhao}, 125 | year={2024}, 126 | eprint={2403.10228}, 127 | archivePrefix={arXiv}, 128 | primaryClass={cs.CV} 129 | } 130 | ``` 131 | 132 | ## Acknowledgments 133 | This project is based on [VideoChat and VideoChat2](https://github.com/OpenGVLab/Ask-Anything). Thanks for their great work! 134 | -------------------------------------------------------------------------------- /dataset/video_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py 3 | """ 4 | import random 5 | import os 6 | import io 7 | import av 8 | import cv2 9 | import decord 10 | import imageio 11 | from decord import VideoReader 12 | import torch 13 | import numpy as np 14 | import math 15 | decord.bridge.set_bridge("torch") 16 | 17 | import logging 18 | logger = logging.getLogger(__name__) 19 | 20 | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: 21 | """ 22 | Converts a present time with the given time base and start_pts offset to seconds. 23 | 24 | Returns: 25 | time_in_seconds (float): The corresponding time in seconds. 26 | 27 | https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64 28 | """ 29 | if pts == math.inf: 30 | return math.inf 31 | 32 | return int(pts - start_pts) * time_base 33 | 34 | 35 | def get_pyav_video_duration(video_reader): 36 | video_stream = video_reader.streams.video[0] 37 | video_duration = pts_to_secs( 38 | video_stream.duration, 39 | video_stream.time_base, 40 | video_stream.start_time 41 | ) 42 | return float(video_duration) 43 | 44 | 45 | def get_frame_indices_by_fps(): 46 | pass 47 | 48 | 49 | def get_frame_indices(num_frames, vlen, start_pos=0, end_pos=1, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): 50 | assert 0 <= start_pos <= 1, "start_pos must be in [0, 1]" 51 | assert 0 <= end_pos <= 1, "end_pos must be in [0, 1]" 52 | 53 | start_frame, end_frame = round(start_pos * vlen), round(end_pos * vlen) 54 | 55 | if sample in ["rand", "middle"]: # uniform sampling 56 | acc_samples = min(num_frames, vlen) 57 | # split the video into `acc_samples` intervals, and sample from each interval. 
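# (illustration) with start_frame=0, end_frame=100 and acc_samples=4, intervals = [0, 25, 50, 75, 100]
# and ranges = [(0, 24), (25, 49), (50, 74), (75, 99)]; 'rand' picks one frame per range at random,
# while 'middle' picks the centre of each range.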
58 | intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int) 59 | ranges = [] 60 | for idx, interv in enumerate(intervals[:-1]): 61 | ranges.append((interv, intervals[idx + 1] - 1)) 62 | if sample == 'rand': 63 | try: 64 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] 65 | except: 66 | frame_indices = np.random.permutation(vlen)[:acc_samples] 67 | frame_indices.sort() 68 | frame_indices = list(frame_indices) 69 | elif fix_start is not None: 70 | frame_indices = [x[0] + fix_start for x in ranges] 71 | elif sample == 'middle': 72 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges] 73 | else: 74 | raise NotImplementedError 75 | 76 | if len(frame_indices) < num_frames: # padded with last frame 77 | padded_frame_indices = [frame_indices[-1]] * num_frames 78 | padded_frame_indices[:len(frame_indices)] = frame_indices 79 | frame_indices = padded_frame_indices 80 | 81 | elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps 82 | output_fps = float(sample[3:]) 83 | duration = float(end_frame - start_frame) / input_fps 84 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents 85 | frame_seconds = np.arange(start_frame + delta / 2, end_frame + delta / 2, delta) 86 | frame_indices = np.around(frame_seconds * input_fps).astype(int) 87 | frame_indices = [e for e in frame_indices if start_frame < e < end_frame] 88 | if max_num_frames > 0 and len(frame_indices) > max_num_frames: 89 | frame_indices = frame_indices[:max_num_frames] 90 | # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) 91 | else: 92 | raise ValueError 93 | return frame_indices 94 | 95 | 96 | def read_frames_av(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, clip=None, client=None, fps=None): 97 | reader = av.open(video_path) 98 | frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)] 99 | vlen = len(frames) 100 | duration_sec = get_pyav_video_duration(reader) 101 | fps = vlen / float(duration_sec) 102 | 103 | if clip is None: 104 | start_pos, end_pos = 0, 1 105 | else: 106 | video_start_sec, video_end_sec = clip 107 | start_pos, end_pos = video_start_sec / vlen, video_end_sec / vlen 108 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1)) 109 | 110 | frame_indices = get_frame_indices( 111 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start, 112 | input_fps=fps, max_num_frames=max_num_frames 113 | ) 114 | frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8 115 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 116 | 117 | start_frame_index = round(start_pos * vlen) 118 | frame_indices = [i - start_frame_index for i in frame_indices] 119 | return frames, frame_indices, fps 120 | 121 | 122 | def read_frames_gif(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, clip=None, client=None, fps=None): 123 | gif = imageio.get_reader(video_path) 124 | vlen = len(gif) 125 | duration_sec = vlen / fps 126 | 127 | if clip is None: 128 | start_pos, end_pos = 0, 1 129 | else: 130 | video_start_sec, video_end_sec = clip 131 | start_pos, end_pos = video_start_sec / duration_sec, video_end_sec / duration_sec 132 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1)) 133 | 134 | frame_indices = get_frame_indices( 135 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start, 136 | 
max_num_frames=max_num_frames 137 | ) 138 | frames = [] 139 | for index, frame in enumerate(gif): 140 | # for index in frame_idxs: 141 | if index in frame_indices: 142 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) 143 | frame = torch.from_numpy(frame).byte() 144 | # # (H x W x C) to (C x H x W) 145 | frame = frame.permute(2, 0, 1) 146 | frames.append(frame) 147 | frames = torch.stack(frames) # .float() / 255 148 | 149 | start_frame_index = round(start_pos * vlen) 150 | frame_indices = [i - start_frame_index for i in frame_indices] 151 | return frames, frame_indices, fps # for tgif 152 | 153 | 154 | def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, clip=None, client=None, fps=None): 155 | video_reader = VideoReader(video_path, num_threads=1) 156 | vlen = len(video_reader) 157 | fps = video_reader.get_avg_fps() 158 | duration_sec = vlen / fps 159 | 160 | if clip is None: 161 | start_pos, end_pos = 0, 1 162 | else: 163 | video_start_sec, video_end_sec = clip 164 | start_pos, end_pos = video_start_sec / duration_sec, video_end_sec / duration_sec 165 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1)) 166 | 167 | frame_indices = get_frame_indices( 168 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start, 169 | input_fps=fps, max_num_frames=max_num_frames 170 | ) 171 | frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 172 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 173 | 174 | start_frame_index = round(start_pos * vlen) 175 | frame_indices = [i - start_frame_index for i in frame_indices] 176 | return frames, frame_indices, float(fps) 177 | 178 | 179 | def read_frames_images(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1, interval_sec=None, clip=None, client=None, fps=3): 180 | frame_fnames = [os.path.join(video_path, f) for f in sorted(os.listdir(video_path))] 181 | vlen = len(frame_fnames) 182 | duration_sec = vlen / fps 183 | 184 | if clip is None: 185 | start_pos, end_pos = 0, 1 186 | else: 187 | video_start_sec, video_end_sec = clip 188 | start_pos, end_pos = video_start_sec / duration_sec, video_end_sec / duration_sec 189 | start_pos, end_pos = max(0, min(start_pos, 1)), max(0, min(end_pos, 1)) 190 | 191 | frame_indices = get_frame_indices( 192 | num_frames, vlen, start_pos, end_pos, sample=sample, fix_start=fix_start, 193 | input_fps=fps, max_num_frames=max_num_frames 194 | ) 195 | selected_fnames = [frame_fnames[i] for i in frame_indices] 196 | frames = np.stack([cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB) for fname in selected_fnames]) 197 | # (T x H x W x C) to (T x C x H x W) 198 | frames = frames.transpose(0, 3, 1, 2) 199 | frames = torch.from_numpy(frames).to(torch.uint8) 200 | 201 | start_frame_index = round(start_pos * vlen) 202 | frame_indices = [i - start_frame_index for i in frame_indices] 203 | return frames, frame_indices, float(fps) 204 | 205 | 206 | VIDEO_READER_FUNCS = { 207 | 'av': read_frames_av, 208 | 'decord': read_frames_decord, 209 | 'gif': read_frames_gif, 210 | 'frames': read_frames_images, 211 | } 212 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | from utils.distributed import is_main_process, get_rank, get_world_size 2 | import logging 3 | import torch.distributed as dist 4 | import torch 5 | import io 6 | import os 7 | import json 8 | import re 9 | 
import numpy as np 10 | from os.path import join 11 | from tqdm import trange 12 | from PIL import Image 13 | from PIL import ImageFile 14 | from torchvision.transforms import PILToTensor 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | Image.MAX_IMAGE_PIXELS = None 17 | 18 | 19 | def load_image_from_path(image_path, client): 20 | if image_path.startswith('s3') or image_path.startswith('p2'): 21 | value = client.Get(image_path) 22 | img_bytes = np.frombuffer(value, dtype=np.uint8) 23 | buff = io.BytesIO(img_bytes) 24 | image = Image.open(buff).convert('RGB') 25 | else: 26 | image = Image.open(image_path).convert('RGB') # PIL Image 27 | image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), torch.uint8 28 | return image 29 | 30 | 31 | def load_anno(ann_file_list): 32 | """[summary] 33 | 34 | Args: 35 | ann_file_list (List[List[str, str]] or List[str, str]): 36 | the latter will be automatically converted to the former. 37 | Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video']) 38 | which specifies the data type, video or image 39 | 40 | Returns: 41 | List(dict): each dict is { 42 | image: str or List[str], # image_path, 43 | caption: str or List[str] # caption text string 44 | } 45 | """ 46 | if isinstance(ann_file_list[0], str): 47 | ann_file_list = [ann_file_list] 48 | 49 | ann = [] 50 | for d in ann_file_list: 51 | data_root = d[1] 52 | fp = d[0] 53 | is_video = len(d) == 3 and d[2] == "video" 54 | cur_ann = json.load(open(fp, "r")) 55 | iterator = trange(len(cur_ann), desc=f"Loading {fp}") \ 56 | if is_main_process() else range(len(cur_ann)) 57 | for idx in iterator: 58 | key = "video" if is_video else "image" 59 | # unified to have the same key for data path 60 | if isinstance(cur_ann[idx][key], str): 61 | cur_ann[idx]["image"] = join(data_root, cur_ann[idx][key]) 62 | else: # list 63 | cur_ann[idx]["image"] = [join(data_root, e) for e in cur_ann[idx][key]] 64 | ann += cur_ann 65 | return ann 66 | 67 | 68 | def pre_text(text, max_l=None, pre_text=True): 69 | if pre_text: 70 | text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower()) 71 | text = text.replace('-', ' ').replace('/', ' ').replace('', 'person') 72 | 73 | text = re.sub(r"\s{2,}", ' ', text) 74 | text = text.rstrip('\n').strip(' ') 75 | 76 | if max_l: # truncate 77 | words = text.split(' ') 78 | if len(words) > max_l: 79 | text = ' '.join(words[:max_l]) 80 | else: 81 | pass 82 | return text 83 | 84 | 85 | logger = logging.getLogger(__name__) 86 | 87 | 88 | def collect_result(result, result_dir, filename, is_json=True, is_list=True): 89 | if is_json: 90 | result_file = os.path.join( 91 | result_dir, '%s_rank%d.json' % (filename, get_rank())) 92 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 93 | json.dump(result, open(result_file, 'w')) 94 | else: 95 | result_file = os.path.join( 96 | result_dir, '%s_rank%d.pth' % (filename, get_rank())) 97 | final_result_file = os.path.join(result_dir, '%s.pth' % filename) 98 | torch.save(result, result_file) 99 | 100 | dist.barrier() 101 | 102 | result = None 103 | if is_main_process(): 104 | # combine results from all processes 105 | if is_list: 106 | result = [] 107 | else: 108 | result = {} 109 | for rank in range(get_world_size()): 110 | if is_json: 111 | result_file = os.path.join( 112 | result_dir, '%s_rank%d.json' % (filename, rank)) 113 | res = json.load(open(result_file, 'r')) 114 | else: 115 | result_file = os.path.join( 116 | result_dir, '%s_rank%d.pth' % (filename, rank)) 117 | res = torch.load(result_file) 118 | if is_list: 
119 | result += res 120 | else: 121 | result.update(res) 122 | 123 | return result 124 | 125 | 126 | def sync_save_result(result, result_dir, filename, is_json=True, is_list=True): 127 | """gather results from multiple GPUs""" 128 | if is_json: 129 | result_file = os.path.join( 130 | result_dir, "dist_res", '%s_rank%d.json' % (filename, get_rank())) 131 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 132 | os.makedirs(os.path.dirname(result_file), exist_ok=True) 133 | json.dump(result, open(result_file, 'w')) 134 | else: 135 | result_file = os.path.join( 136 | result_dir, "dist_res", '%s_rank%d.pth' % (filename, get_rank())) 137 | os.makedirs(os.path.dirname(result_file), exist_ok=True) 138 | final_result_file = os.path.join(result_dir, '%s.pth' % filename) 139 | torch.save(result, result_file) 140 | 141 | dist.barrier() 142 | 143 | if is_main_process(): 144 | # combine results from all processes 145 | if is_list: 146 | result = [] 147 | else: 148 | result = {} 149 | for rank in range(get_world_size()): 150 | if is_json: 151 | result_file = os.path.join( 152 | result_dir, "dist_res", '%s_rank%d.json' % (filename, rank)) 153 | res = json.load(open(result_file, 'r')) 154 | else: 155 | result_file = os.path.join( 156 | result_dir, "dist_res", '%s_rank%d.pth' % (filename, rank)) 157 | res = torch.load(result_file) 158 | if is_list: 159 | result += res 160 | else: 161 | result.update(res) 162 | if is_json: 163 | json.dump(result, open(final_result_file, 'w')) 164 | else: 165 | torch.save(result, final_result_file) 166 | 167 | logger.info('result file saved to %s' % final_result_file) 168 | dist.barrier() 169 | return final_result_file, result 170 | 171 | 172 | def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None): 173 | """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray) 174 | into a (n+1)-d array, only allow the first dim has variable lengths. 175 | Args: 176 | sequences: list(n-d tensor or list) 177 | dtype: np.dtype or torch.dtype 178 | device: 179 | fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length. 180 | return will be of shape [len(sequences), fixed_length, ...] 
181 | Returns: 182 | padded_seqs: ((n+1)-d tensor) padded with zeros 183 | mask: (2d tensor) of the same shape as the first two dims of padded_seqs, 184 | 1 indicate valid, 0 otherwise 185 | Examples: 186 | >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] 187 | >>> pad_sequences_1d(test_data_list, dtype=torch.long) 188 | >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)] 189 | >>> pad_sequences_1d(test_data_3d, dtype=torch.float) 190 | >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] 191 | >>> pad_sequences_1d(test_data_list, dtype=np.float32) 192 | >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)] 193 | >>> pad_sequences_1d(test_data_3d, dtype=np.float32) 194 | """ 195 | if isinstance(sequences[0], list): 196 | if "torch" in str(dtype): 197 | sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences] 198 | else: 199 | sequences = [np.asarray(s, dtype=dtype) for s in sequences] 200 | 201 | extra_dims = sequences[0].shape[1:] # the extra dims should be the same for all elements 202 | lengths = [len(seq) for seq in sequences] 203 | if fixed_length is not None: 204 | max_length = fixed_length 205 | else: 206 | max_length = max(lengths) 207 | if isinstance(sequences[0], torch.Tensor): 208 | assert "torch" in str(dtype), "dtype and input type does not match" 209 | padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device) 210 | mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device) 211 | else: # np 212 | assert "numpy" in str(dtype), "dtype and input type does not match" 213 | padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype) 214 | mask = np.zeros((len(sequences), max_length), dtype=np.float32) 215 | 216 | for idx, seq in enumerate(sequences): 217 | end = lengths[idx] 218 | padded_seqs[idx, :end] = seq 219 | mask[idx, :end] = 1 220 | return padded_seqs, mask # , lengths 221 | 222 | 223 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import ast 5 | import json 6 | import os 7 | import os.path as osp 8 | import re 9 | import shutil 10 | import sys 11 | import tempfile 12 | from copy import deepcopy 13 | from importlib import import_module 14 | 15 | import yaml 16 | 17 | from .easydict import EasyDict 18 | 19 | __all__ = ["Config", "pretty_text"] 20 | 21 | 22 | BASE_KEY = "_base_" 23 | # BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"} 24 | BASE_CONFIG = {} 25 | 26 | cfg = None 27 | 28 | 29 | class Config(object): 30 | """config""" 31 | 32 | @classmethod 33 | def pretty_text(cls, cfg: dict, indent=2) -> str: 34 | """format dict to a string 35 | 36 | Args: 37 | cfg (EasyDict): the params. 38 | 39 | Returns: The string to display. 40 | 41 | """ 42 | msg = "{\n" 43 | for i, (k, v) in enumerate(cfg.items()): 44 | if isinstance(v, dict): 45 | v = cls.pretty_text(v, indent + 4) 46 | spaces = " " * indent 47 | msg += spaces + "{}: {}".format(k, v) 48 | if i == len(cfg) - 1: 49 | msg += " }" 50 | else: 51 | msg += "\n" 52 | return msg 53 | 54 | @classmethod 55 | def dump(cls, cfg, savepath=None): 56 | """dump cfg to `json` file. 57 | 58 | Args: 59 | cfg (dict): The dict to dump. 60 | savepath (str): The filepath to save the dumped dict. 
61 | 62 | Returns: TODO 63 | 64 | """ 65 | if savepath is None: 66 | savepath = osp.join(cfg.WORKSPACE, "config.json") 67 | json.dump(cfg, open(savepath, "w"), indent=2) 68 | 69 | @classmethod 70 | def get_config(cls, default_config: dict = None): 71 | """get a `Config` instance. 72 | 73 | Args: 74 | default_config (dict): The default config. `default_config` will be overrided 75 | by config file `--cfg`, `--cfg` will be overrided by commandline args. 76 | 77 | Returns: an EasyDict. 78 | """ 79 | global cfg 80 | if cfg is not None: 81 | return cfg 82 | 83 | # define arg parser. 84 | parser = argparse.ArgumentParser() 85 | # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str) 86 | parser.add_argument( 87 | "config_file", help="the configuration file to load. support: .yaml, .json, .py" 88 | ) 89 | parser.add_argument( 90 | "opts", 91 | default=None, 92 | nargs="*", 93 | help="overrided configs. List. Format: 'key1 name1 key2 name2'", 94 | ) 95 | args = parser.parse_args() 96 | 97 | cfg = EasyDict(BASE_CONFIG) 98 | if osp.isfile(args.config_file): 99 | cfg_from_file = cls.from_file(args.config_file) 100 | cfg = merge_a_into_b(cfg_from_file, cfg) 101 | cfg = cls.merge_list(cfg, args.opts) 102 | cfg = eval_dict_leaf(cfg) 103 | 104 | # update some keys to make them show at the last 105 | for k in BASE_CONFIG: 106 | cfg[k] = cfg.pop(k) 107 | return cfg 108 | 109 | @classmethod 110 | def from_file(cls, filepath: str) -> EasyDict: 111 | """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`. 112 | 113 | Args: 114 | filepath (str): The config file path. 115 | 116 | Returns: TODO 117 | 118 | """ 119 | filepath = osp.abspath(osp.expanduser(filepath)) 120 | if not osp.isfile(filepath): 121 | raise IOError(f"File does not exist: {filepath}") 122 | if filepath.endswith(".py"): 123 | with tempfile.TemporaryDirectory() as temp_config_dir: 124 | 125 | shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config")) 126 | sys.path.insert(0, temp_config_dir) 127 | mod = import_module("tmp_config." 
+ osp.splitext(osp.basename(filepath))[0]) 128 | # mod = import_module(temp_module_name) 129 | sys.path.pop(0) 130 | cfg_dict = { 131 | name: value 132 | for name, value in mod.__dict__.items() 133 | if not name.startswith("__") 134 | } 135 | for k in list(sys.modules.keys()): 136 | if "tmp_config" in k: 137 | del sys.modules[k] 138 | elif filepath.endswith((".yml", ".yaml")): 139 | cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader) 140 | elif filepath.endswith(".json"): 141 | cfg_dict = json.load(open(filepath, "r")) 142 | else: 143 | raise IOError("Only py/yml/yaml/json type are supported now!") 144 | 145 | cfg_text = filepath + "\n" 146 | with open(filepath, "r") as f: 147 | cfg_text += f.read() 148 | 149 | if BASE_KEY in cfg_dict: # load configs in `BASE_KEY` 150 | cfg_dir = osp.dirname(filepath) 151 | base_filename = cfg_dict.pop(BASE_KEY) 152 | base_filename = ( 153 | base_filename if isinstance(base_filename, list) else [base_filename] 154 | ) 155 | 156 | cfg_dict_list = list() 157 | for f in base_filename: 158 | _cfg_dict = Config.from_file(osp.join(cfg_dir, f)) 159 | cfg_dict_list.append(_cfg_dict) 160 | 161 | base_cfg_dict = dict() 162 | for c in cfg_dict_list: 163 | if len(base_cfg_dict.keys() & c.keys()) > 0: 164 | raise KeyError("Duplicate key is not allowed among bases") 165 | base_cfg_dict.update(c) 166 | 167 | cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict) 168 | 169 | return EasyDict(cfg_dict) 170 | 171 | @classmethod 172 | def merge_list(cls, cfg, opts: list): 173 | """merge commandline opts. 174 | 175 | Args: 176 | cfg: (dict): The config to be merged. 177 | opts (list): The list to merge. Format: [key1, name1, key2, name2,...]. 178 | The keys can be nested. For example, ["a.b", v] will be considered 179 | as `dict(a=dict(b=v))`. 180 | 181 | Returns: dict. 182 | 183 | """ 184 | assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}" 185 | for i in range(0, len(opts), 2): 186 | full_k, v = opts[i], opts[i + 1] 187 | keys = full_k.split(".") 188 | sub_d = cfg 189 | for i, k in enumerate(keys): 190 | if not hasattr(sub_d, k): 191 | raise ValueError(f"The key {k} not exist in the config. Full key:{full_k}") 192 | if i != len(keys) - 1: 193 | sub_d = sub_d[k] 194 | else: 195 | sub_d[k] = v 196 | return cfg 197 | 198 | 199 | def merge_a_into_b(a, b, inplace=False): 200 | """The values in a will override values in b. 201 | 202 | Args: 203 | a (dict): source dict. 204 | b (dict): target dict. 205 | 206 | Returns: dict. recursively merge dict a into dict b. 207 | 208 | """ 209 | if not inplace: 210 | b = deepcopy(b) 211 | for key in a: 212 | if key in b: 213 | if isinstance(a[key], dict) and isinstance(b[key], dict): 214 | b[key] = merge_a_into_b(a[key], b[key], inplace=True) 215 | else: 216 | b[key] = a[key] 217 | else: 218 | b[key] = a[key] 219 | return b 220 | 221 | 222 | def eval_dict_leaf(d, orig_dict=None): 223 | """eval values of dict leaf. 224 | 225 | Args: 226 | d (dict): The dict to eval. 227 | 228 | Returns: dict. 229 | 230 | """ 231 | if orig_dict is None: 232 | orig_dict = d 233 | for k, v in d.items(): 234 | if not isinstance(v, dict): 235 | d[k] = eval_string(v, orig_dict) 236 | else: 237 | eval_dict_leaf(v, orig_dict) 238 | return d 239 | 240 | 241 | def eval_string(string, d): 242 | """automatically evaluate string to corresponding types. 
243 | 244 | For example: 245 | not a string -> return the original input 246 | '0' -> 0 247 | '0.2' -> 0.2 248 | '[0, 1, 2]' -> [0,1,2] 249 | 'eval(1+2)' -> 3 250 | 'eval(range(5))' -> [0,1,2,3,4] 251 | '${a}' -> d.a 252 | 253 | 254 | 255 | Args: 256 | string (str): The value to evaluate. 257 | d (dict): The 258 | 259 | Returns: the corresponding type 260 | 261 | """ 262 | if not isinstance(string, str): 263 | return string 264 | # if len(string) > 1 and string[0] == "[" and string[-1] == "]": 265 | # return eval(string) 266 | if string[0:5] == "eval(": 267 | return eval(string[5:-1]) 268 | 269 | s0 = string 270 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0) 271 | if s1 != s0: 272 | while s1 != s0: 273 | s0 = s1 274 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0) 275 | return eval(s1) 276 | 277 | try: 278 | v = ast.literal_eval(string) 279 | except: 280 | v = string 281 | return v 282 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import functools 5 | import logging 6 | import os 7 | import sys 8 | import time 9 | import wandb 10 | from typing import Any, Dict, Union 11 | 12 | import torch 13 | from .distributed import get_rank, is_main_process 14 | from termcolor import colored 15 | 16 | 17 | def log_dict_to_wandb(log_dict, step, prefix=""): 18 | """include a separator `/` at the end of `prefix`""" 19 | if not is_main_process(): 20 | return 21 | 22 | log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()} 23 | wandb.log(log_dict, step) 24 | 25 | 26 | def log_dict_to_tensorboard(log_dict, step, prefix="", writer=None): 27 | for key, val in log_dict.items(): 28 | writer.add_scalar(prefix + key, val, step) 29 | 30 | 31 | def setup_wandb(config): 32 | if not (config.wandb.enable and is_main_process()): 33 | return 34 | 35 | run = wandb.init( 36 | config=config, 37 | project=config.wandb.project, 38 | entity=config.wandb.entity, 39 | name=os.path.basename(config.output_dir), 40 | reinit=True 41 | ) 42 | return run 43 | 44 | 45 | def setup_output_folder(save_dir: str, folder_only: bool = False): 46 | """Sets up and returns the output file where the logs will be placed 47 | based on the configuration passed. Usually "save_dir/logs/log_.txt". 48 | If env.log_dir is passed, logs will be directly saved in this folder. 49 | Args: 50 | folder_only (bool, optional): If folder should be returned and not the file. 51 | Defaults to False. 52 | Returns: 53 | str: folder or file path depending on folder_only flag 54 | """ 55 | log_filename = "train_" 56 | log_filename += time.strftime("%Y_%m_%dT%H_%M_%S") 57 | log_filename += ".log" 58 | 59 | log_folder = os.path.join(save_dir, "logs") 60 | 61 | if not os.path.exists(log_folder): 62 | os.path.mkdirs(log_folder) 63 | 64 | if folder_only: 65 | return log_folder 66 | 67 | log_filename = os.path.join(log_folder, log_filename) 68 | 69 | return log_filename 70 | 71 | 72 | def setup_logger( 73 | output: str = None, 74 | color: bool = True, 75 | name: str = "mmf", 76 | disable: bool = False, 77 | clear_handlers=True, 78 | *args, 79 | **kwargs, 80 | ): 81 | """ 82 | Initialize the MMF logger and set its verbosity level to "INFO". 83 | Outside libraries shouldn't call this in case they have set there 84 | own logging handlers and setup. 
If they do, and don't want to 85 | clear handlers, pass clear_handlers options. 86 | The initial version of this function was taken from D2 and adapted 87 | for MMF. 88 | Args: 89 | output (str): a file name or a directory to save log. 90 | If ends with ".txt" or ".log", assumed to be a file name. 91 | Default: Saved to file 92 | color (bool): If false, won't log colored logs. Default: true 93 | name (str): the root module name of this logger. Defaults to "mmf". 94 | disable: do not use 95 | clear_handlers (bool): If false, won't clear existing handlers. 96 | Returns: 97 | logging.Logger: a logger 98 | """ 99 | if disable: 100 | return None 101 | logger = logging.getLogger(name) 102 | logger.propagate = False 103 | 104 | logging.captureWarnings(True) 105 | warnings_logger = logging.getLogger("py.warnings") 106 | 107 | plain_formatter = logging.Formatter( 108 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s", 109 | datefmt="%Y-%m-%dT%H:%M:%S", 110 | ) 111 | 112 | distributed_rank = get_rank() 113 | handlers = [] 114 | 115 | logging_level = logging.INFO 116 | # logging_level = logging.DEBUG 117 | 118 | if distributed_rank == 0: 119 | logger.setLevel(logging_level) 120 | ch = logging.StreamHandler(stream=sys.stdout) 121 | ch.setLevel(logging_level) 122 | if color: 123 | formatter = ColorfulFormatter( 124 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", 125 | datefmt="%Y-%m-%dT%H:%M:%S", 126 | ) 127 | else: 128 | formatter = plain_formatter 129 | ch.setFormatter(formatter) 130 | logger.addHandler(ch) 131 | warnings_logger.addHandler(ch) 132 | handlers.append(ch) 133 | 134 | # file logging: all workers 135 | if output is None: 136 | output = setup_output_folder() 137 | 138 | if output is not None: 139 | if output.endswith(".txt") or output.endswith(".log"): 140 | filename = output 141 | else: 142 | filename = os.path.join(output, "train.log") 143 | if distributed_rank > 0: 144 | filename = filename + f".rank{distributed_rank}" 145 | os.makedirs(os.path.dirname(filename), exist_ok=True) 146 | 147 | fh = logging.StreamHandler(_cached_log_stream(filename)) 148 | fh.setLevel(logging_level) 149 | fh.setFormatter(plain_formatter) 150 | logger.addHandler(fh) 151 | warnings_logger.addHandler(fh) 152 | handlers.append(fh) 153 | 154 | # Slurm/FB output, only log the main process 155 | # save_dir = get_mmf_env(key="save_dir") 156 | if "train.log" not in filename and distributed_rank == 0: 157 | filename = os.path.join(output, "train.log") 158 | sh = logging.StreamHandler(_cached_log_stream(filename)) 159 | sh.setLevel(logging_level) 160 | sh.setFormatter(plain_formatter) 161 | logger.addHandler(sh) 162 | warnings_logger.addHandler(sh) 163 | handlers.append(sh) 164 | 165 | logger.info(f"Logging to: {filename}") 166 | 167 | # Remove existing handlers to add MMF specific handlers 168 | if clear_handlers: 169 | for handler in logging.root.handlers[:]: 170 | logging.root.removeHandler(handler) 171 | # Now, add our handlers. 
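# Note: passing `handlers=` to logging.basicConfig below attaches these same console/file handlers
# to the root logger, so records emitted through the root logger reach the same outputs.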
172 | logging.basicConfig(level=logging_level, handlers=handlers) 173 | 174 | return logger 175 | 176 | 177 | def setup_very_basic_config(color=True): 178 | plain_formatter = logging.Formatter( 179 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s", 180 | datefmt="%Y-%m-%dT%H:%M:%S", 181 | ) 182 | ch = logging.StreamHandler(stream=sys.stdout) 183 | ch.setLevel(logging.INFO) 184 | if color: 185 | formatter = ColorfulFormatter( 186 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", 187 | datefmt="%Y-%m-%dT%H:%M:%S", 188 | ) 189 | else: 190 | formatter = plain_formatter 191 | ch.setFormatter(formatter) 192 | # Setup a minimal configuration for logging in case something tries to 193 | # log a message even before logging is setup by MMF. 194 | logging.basicConfig(level=logging.INFO, handlers=[ch]) 195 | 196 | 197 | # cache the opened file object, so that different calls to `setup_logger` 198 | # with the same file name can safely write to the same file. 199 | @functools.lru_cache(maxsize=None) 200 | def _cached_log_stream(filename): 201 | return open(filename, "a") 202 | 203 | 204 | # ColorfulFormatter is adopted from Detectron2 and adapted for MMF 205 | class ColorfulFormatter(logging.Formatter): 206 | def __init__(self, *args, **kwargs): 207 | super().__init__(*args, **kwargs) 208 | 209 | def formatMessage(self, record): 210 | log = super().formatMessage(record) 211 | if record.levelno == logging.WARNING: 212 | prefix = colored("WARNING", "red", attrs=["blink"]) 213 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 214 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 215 | else: 216 | return log 217 | return prefix + " " + log 218 | 219 | 220 | class TensorboardLogger: 221 | def __init__(self, log_folder="./logs", iteration=0): 222 | # This would handle warning of missing tensorboard 223 | from torch.utils.tensorboard import SummaryWriter 224 | 225 | self.summary_writer = None 226 | self._is_master = is_main_process() 227 | # self.timer = Timer() 228 | self.log_folder = log_folder 229 | 230 | if self._is_master: 231 | # current_time = self.timer.get_time_hhmmss(None, format=self.time_format) 232 | current_time = time.strftime("%Y-%m-%dT%H:%M:%S") 233 | # self.timer.get_time_hhmmss(None, format=self.time_format) 234 | tensorboard_folder = os.path.join( 235 | self.log_folder, f"tensorboard_{current_time}" 236 | ) 237 | self.summary_writer = SummaryWriter(tensorboard_folder) 238 | 239 | def __del__(self): 240 | if getattr(self, "summary_writer", None) is not None: 241 | self.summary_writer.close() 242 | 243 | def _should_log_tensorboard(self): 244 | if self.summary_writer is None or not self._is_master: 245 | return False 246 | else: 247 | return True 248 | 249 | def add_scalar(self, key, value, iteration): 250 | if not self._should_log_tensorboard(): 251 | return 252 | 253 | self.summary_writer.add_scalar(key, value, iteration) 254 | 255 | def add_scalars(self, scalar_dict, iteration): 256 | if not self._should_log_tensorboard(): 257 | return 258 | 259 | for key, val in scalar_dict.items(): 260 | self.summary_writer.add_scalar(key, val, iteration) 261 | 262 | def add_histogram_for_model(self, model, iteration): 263 | if not self._should_log_tensorboard(): 264 | return 265 | 266 | for name, param in model.named_parameters(): 267 | np_param = param.clone().cpu().data.numpy() 268 | self.summary_writer.add_histogram(name, np_param, iteration) 269 | -------------------------------------------------------------------------------- 
/utils/basic_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | import os 4 | import json 5 | import logging 6 | import random 7 | import time 8 | from collections import defaultdict, deque 9 | import datetime 10 | from pathlib import Path 11 | from typing import List, Union 12 | 13 | import torch 14 | import torch.distributed as dist 15 | from .distributed import is_dist_avail_and_initialized 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class SmoothedValue(object): 22 | """Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 24 | """ 25 | 26 | def __init__(self, window=20, fmt=None): 27 | if fmt is None: 28 | fmt = "{median:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque! 42 | """ 43 | if not is_dist_avail_and_initialized(): 44 | return 45 | t = torch.tensor([self.count, self.total], 46 | dtype=torch.float64, device='cuda') 47 | dist.barrier() 48 | dist.all_reduce(t) 49 | t = t.tolist() 50 | self.count = int(t[0]) 51 | self.total = t[1] 52 | 53 | @property 54 | def median(self): 55 | d = torch.tensor(list(self.deque)) 56 | return d.median().item() 57 | 58 | @property 59 | def avg(self): 60 | d = torch.tensor(list(self.deque), dtype=torch.float32) 61 | return d.mean().item() 62 | 63 | @property 64 | def global_avg(self): 65 | return self.total / self.count 66 | 67 | @property 68 | def max(self): 69 | return max(self.deque) 70 | 71 | @property 72 | def value(self): 73 | return self.deque[-1] 74 | 75 | def __str__(self): 76 | return self.fmt.format( 77 | median=self.median, 78 | avg=self.avg, 79 | global_avg=self.global_avg, 80 | max=self.max, 81 | value=self.value) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if isinstance(v, torch.Tensor): 92 | v = v.item() 93 | assert isinstance(v, (float, int)) 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format( 102 | type(self).__name__, attr)) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | if meter.count == 0: # skip empty meter 108 | loss_str.append( 109 | "{}: {}".format(name, "No data") 110 | ) 111 | else: 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def global_avg(self): 118 | loss_str = [] 119 | for name, meter in self.meters.items(): 120 | if meter.count == 0: 121 | loss_str.append( 122 | "{}: {}".format(name, "No data") 123 | ) 124 | else: 125 | loss_str.append( 126 | "{}: {:.4f}".format(name, meter.global_avg) 127 | ) 128 | return self.delimiter.join(loss_str) 129 | 130 | def get_global_avg_dict(self, prefix=""): 131 | """include a separator (e.g., `/`, or "_") at the end of `prefix`""" 132 | d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. 
for k, m in self.meters.items()} 133 | return d 134 | 135 | def get_avg_dict(self, prefix=""): 136 | """same as above, but only get the average value in the window, not the epoch""" 137 | d = {f"{prefix}{k}": m.avg if m.count > 0 else 0. for k, m in self.meters.items()} 138 | return d 139 | 140 | def synchronize_between_processes(self): 141 | for meter in self.meters.values(): 142 | meter.synchronize_between_processes() 143 | 144 | def add_meter(self, name, meter): 145 | self.meters[name] = meter 146 | 147 | def log_every(self, iterable, log_freq, header=None): 148 | i = 0 149 | if not header: 150 | header = '' 151 | start_time = time.time() 152 | end = time.time() 153 | iter_time = SmoothedValue(fmt='{avg:.4f}') 154 | data_time = SmoothedValue(fmt='{avg:.4f}') 155 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 156 | log_msg = [ 157 | header, 158 | '[{0' + space_fmt + '}/{1}]', 159 | 'eta: {eta}', 160 | '{meters}', 161 | 'time: {time}', 162 | 'data: {data}' 163 | ] 164 | if torch.cuda.is_available(): 165 | log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}') 166 | log_msg = self.delimiter.join(log_msg) 167 | MB = 1024.0 * 1024.0 168 | for obj in iterable: 169 | data_time.update(time.time() - end) 170 | yield obj 171 | iter_time.update(time.time() - end) 172 | if i % log_freq == 0 or i == len(iterable) - 1: 173 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 174 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 175 | if torch.cuda.is_available(): 176 | logger.info(log_msg.format( 177 | i, len(iterable), eta=eta_string, 178 | meters=str(self), 179 | time=str(iter_time), data=str(data_time), 180 | memory=torch.cuda.max_memory_allocated() / MB, 181 | res_mem=torch.cuda.max_memory_reserved() / MB, 182 | )) 183 | else: 184 | logger.info(log_msg.format( 185 | i, len(iterable), eta=eta_string, 186 | meters=str(self), 187 | time=str(iter_time), data=str(data_time))) 188 | i += 1 189 | end = time.time() 190 | total_time = time.time() - start_time 191 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 192 | logger.info('{} Total time: {} ({:.4f} s / it)'.format( 193 | header, total_time_str, total_time / len(iterable))) 194 | 195 | 196 | class AttrDict(dict): 197 | def __init__(self, *args, **kwargs): 198 | super(AttrDict, self).__init__(*args, **kwargs) 199 | self.__dict__ = self 200 | 201 | 202 | def compute_acc(logits, label, reduction='mean'): 203 | ret = (torch.argmax(logits, dim=1) == label).float() 204 | if reduction == 'none': 205 | return ret.detach() 206 | elif reduction == 'mean': 207 | return ret.mean().item() 208 | 209 | 210 | def compute_n_params(model, return_str=True): 211 | tot = 0 212 | for p in model.parameters(): 213 | w = 1 214 | for x in p.shape: 215 | w *= x 216 | tot += w 217 | if return_str: 218 | if tot >= 1e6: 219 | return '{:.1f}M'.format(tot / 1e6) 220 | else: 221 | return '{:.1f}K'.format(tot / 1e3) 222 | else: 223 | return tot 224 | 225 | 226 | def setup_seed(seed): 227 | torch.manual_seed(seed) 228 | np.random.seed(seed) 229 | random.seed(seed) 230 | 231 | 232 | def remove_files_if_exist(file_paths): 233 | for fp in file_paths: 234 | if os.path.isfile(fp): 235 | os.remove(fp) 236 | 237 | 238 | def save_json(data, filename, save_pretty=False, sort_keys=False): 239 | with open(filename, "w") as f: 240 | if save_pretty: 241 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) 242 | else: 243 | json.dump(data, f) 244 | 245 | 246 | def load_json(filename): 247 | with open(filename, "r") as f: 248 | return 
json.load(f) 249 | 250 | 251 | def flat_list_of_lists(l): 252 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" 253 | return [item for sublist in l for item in sublist] 254 | 255 | 256 | def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]): 257 | """ 258 | Args: 259 | root: path to the directory to start searching for files 260 | suffix: any str as suffix, or can match multiple such strings 261 | when input is List[str]. 262 | Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`] 263 | Example 2, e.g., use a `*` in the `suffix`: `START*.jpg.`. 264 | """ 265 | if isinstance(suffix, str): 266 | suffix = [suffix, ] 267 | filepaths = flat_list_of_lists( 268 | [list(Path(root).rglob(f"*{e}")) for e in suffix]) 269 | return filepaths 270 | 271 | 272 | def match_key_and_shape(state_dict1, state_dict2): 273 | keys1 = set(state_dict1.keys()) 274 | keys2 = set(state_dict2.keys()) 275 | print(f"keys1 - keys2: {keys1 - keys2}") 276 | print(f"keys2 - keys1: {keys2 - keys1}") 277 | 278 | mismatch = 0 279 | for k in list(keys1): 280 | if state_dict1[k].shape != state_dict2[k].shape: 281 | print( 282 | f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}") 283 | mismatch += 1 284 | print(f"mismatch {mismatch}") 285 | 286 | 287 | def merge_dicts(list_dicts): 288 | merged_dict = list_dicts[0].copy() 289 | for i in range(1, len(list_dicts)): 290 | merged_dict.update(list_dicts[i]) 291 | return merged_dict 292 | -------------------------------------------------------------------------------- /internvid_g/code/ground_data_construction.py: -------------------------------------------------------------------------------- 1 | # construct grounding data from caption, scene, and scene similarity data 2 | 3 | import os 4 | import json 5 | import subprocess 6 | import argparse 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | 11 | 12 | def parse_sec(sec_str): 13 | """ 14 | Parse a timestamp string of the form '00:00:00.000' into seconds as a float 15 | """ 16 | sec_str, ms_str = sec_str.split('.') 17 | h, m, s = sec_str.split(':') 18 | res = float(h) * 3600 + float(m) * 60 + float(s) + float(ms_str) / 1000 19 | return round(res, 3) 20 | 21 | 22 | def find_scene_id_start(value, arr): 23 | # binary search: index of the smallest element in arr that is >= value 24 | low, high = 0, len(arr) - 1 25 | result = 0 26 | while low <= high: 27 | mid = (low + high) // 2 28 | if arr[mid] >= value: 29 | result = mid # update the result to the current position 30 | high = mid - 1 # continue searching in the left half 31 | else: 32 | low = mid + 1 # continue searching in the right half 33 | return result 34 | 35 | 36 | def find_scene_id_end(value, arr): 37 | # binary search: index of the largest element in arr that is <= value 38 | low, high = 0, len(arr) - 1 39 | result = len(arr) - 1 40 | while low <= high: 41 | mid = (low + high) // 2 42 | if arr[mid] <= value: 43 | result = mid # update the result to the current position 44 | low = mid + 1 # continue searching in the right half 45 | else: 46 | high = mid - 1 # continue searching in the left half 47 | return result 48 | 49 | 50 | def get_video_length(video_fname): 51 | result = subprocess.run(["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", video_fname], 52 | stdout=subprocess.PIPE, 53 | stderr=subprocess.STDOUT) 54 | return float(result.stdout) 55 | 56 | 57 | def reformat_caption_data(caption_data): 58 | ''' 59 | ret: { 60 | video_id: [{'start_sec': s1, 'end_sec': e1, 'captions': [c11, c12, ...]}, ...], 61 | ...
62 | } 63 | ''' 64 | ret = dict() 65 | for example in caption_data: 66 | key = example.get('YoutubeID', example.get('video_fname').split('.')[0]) 67 | if key not in ret: 68 | ret[key] = list() 69 | if 'Start_timestamp' in example: 70 | example['start_sec'] = parse_sec(example['Start_timestamp']) 71 | del example['Start_timestamp'] 72 | if 'End_timestamp' in example: 73 | example['end_sec'] = parse_sec(example['End_timestamp']) 74 | del example['End_timestamp'] 75 | if 'Caption' in example: 76 | example['captions'] = [example['Caption']] 77 | del example['Caption'] 78 | ret[key].append(example) 79 | return ret 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument("--video-base-folder", type=str, default="videos") 85 | parser.add_argument("--caption-fname", type=str) 86 | parser.add_argument("--scene-fname", type=str, default="temp/scenes-merged.jsonl") 87 | parser.add_argument("--scene-sim-fname", type=str, default="temp/scenes_similarity-merged.jsonl") 88 | parser.add_argument("--caption-with-neg-interval-fname", type=str, default="temp/final_dataset.jsonl") 89 | 90 | parser.add_argument("--scene-boundary-shift", type=float, default=0.25) 91 | parser.add_argument("--span-same-threshold", type=float, default=0.8) 92 | parser.add_argument("--min-caption-span-secs", type=float, default=0.5) # about only 1% of the examples in caption.jsonl is shorter than 0.5 sec 93 | parser.add_argument("--max-caption-span-secs", type=float, default=10) # if the span is too long, the caption may not precise, and the neg span can be very small. so this is a bad case for constructing grounding data 94 | parser.add_argument("--max-num-scenes", type=int, default=100) 95 | args = parser.parse_args() 96 | 97 | caption_data = [json.loads(line) for line in open(args.caption_fname)] 98 | caption_data = reformat_caption_data(caption_data) 99 | scene_data = [json.loads(line) for line in open(args.scene_fname)] 100 | video_id_to_fname = {e['video_fname'].split('.')[0]: e['video_fname'] for e in scene_data} 101 | scene_data = {e['video_fname'].split('.')[0]: e['scenes'][:args.max_num_scenes] for e in scene_data} 102 | scene_sim_data = [json.loads(line) for line in open(args.scene_sim_fname)] 103 | scene_sim_data = {e['video_fname'].split('.')[0]: e['similarity_matrix'] for e in scene_sim_data} 104 | 105 | video_ids = caption_data.keys() & scene_data.keys() & scene_sim_data.keys() 106 | print(len(video_ids)) 107 | 108 | f_out = open(args.caption_with_neg_interval_fname, 'w') 109 | for i, video_id in enumerate(video_ids): 110 | video_length_sec = get_video_length(os.path.join(args.video_base_folder, video_id_to_fname[video_id])) 111 | # find the scene of caption segment 112 | for caption_data_by_id in caption_data[video_id]: 113 | caption_start_sec, caption_end_sec = caption_data_by_id['start_sec'], caption_data_by_id['end_sec'] 114 | if caption_end_sec - caption_start_sec < args.min_caption_span_secs or caption_end_sec - caption_start_sec > args.max_caption_span_secs: 115 | continue 116 | caption_start_scene = find_scene_id_start(caption_start_sec - args.scene_boundary_shift, [span[0] for span in scene_data[video_id]]) 117 | caption_end_scene = find_scene_id_end(caption_end_sec + args.scene_boundary_shift, [span[1] for span in scene_data[video_id]]) 118 | 119 | unsimilar_start_scene, unsimilar_end_scene = None, None 120 | unsimilar_scenes = [] 121 | if caption_end_scene < caption_start_scene: # cannot find the scene of the caption segment 122 | caption_start_scene, 
caption_end_scene = None, None 123 | unsimilar_start_sec, unsimilar_end_sec = 0, video_length_sec 124 | 125 | else: 126 | # find if the neighbouring scene is also similar. if yes, merge them as the positive span 127 | new_caption_start_scene, new_caption_end_scene = None, None 128 | simiarity_matrix = np.array(scene_sim_data[video_id]) 129 | for idx in range(caption_start_scene - 1, -1, -1): 130 | if simiarity_matrix[idx, caption_start_scene] > args.span_same_threshold: 131 | new_caption_start_scene = idx 132 | else: 133 | break 134 | if new_caption_start_scene is not None: 135 | caption_start_scene = new_caption_start_scene 136 | caption_start_sec = scene_data[video_id][caption_start_scene][0] 137 | 138 | for idx in range(caption_end_scene + 1, len(simiarity_matrix)): 139 | if simiarity_matrix[idx, caption_end_scene] > args.span_same_threshold: 140 | new_caption_end_scene = idx 141 | else: 142 | break 143 | if new_caption_end_scene is not None: 144 | caption_end_scene = new_caption_end_scene 145 | caption_end_sec = scene_data[video_id][caption_end_scene][1] 146 | 147 | # find the unsimilar scenes with the caption scenes 148 | scene_sims = np.max(simiarity_matrix[caption_start_scene: caption_end_scene + 1], axis=0) 149 | unsimilar_scenes = scene_sims < args.span_same_threshold 150 | 151 | if caption_start_scene == 0: 152 | unsimilar_start_sec = 0 153 | elif np.all(unsimilar_scenes[:caption_start_scene]): # no similar segment before this segment 154 | unsimilar_start_sec = 0 155 | elif not np.any(unsimilar_scenes[:caption_start_scene]): # all segments are similar segments before this segment 156 | unsimilar_start_sec = caption_start_sec 157 | else: 158 | # get the last similar segment 159 | unsimilar_start_scene = int(caption_start_scene - np.argmin(unsimilar_scenes[caption_start_scene-1::-1]) - 1) 160 | unsimilar_start_sec = scene_data[video_id][unsimilar_start_scene][1] 161 | 162 | if caption_end_scene == len(scene_data[video_id]) - 1: 163 | unsimilar_end_sec = video_length_sec 164 | elif np.all(unsimilar_scenes[caption_end_scene+1:]): # no similar segment after this segment 165 | unsimilar_end_sec = video_length_sec 166 | elif not np.any(unsimilar_scenes[caption_end_scene+1:]): # all segments are similar segments after this segment 167 | unsimilar_end_sec = caption_end_sec 168 | else: 169 | # get the first similar segment 170 | unsimilar_end_scene = int(np.argmin(unsimilar_scenes[caption_end_scene+1:]) + caption_end_scene + 1) 171 | unsimilar_end_sec = scene_data[video_id][unsimilar_end_scene][0] 172 | unsimilar_scenes = unsimilar_scenes.tolist() 173 | 174 | if unsimilar_end_scene == caption_end_scene and unsimilar_start_scene == caption_start_scene: 175 | continue # no neg interval found, do not use this example for grounding 176 | if caption_end_sec - caption_start_sec > args.max_caption_span_secs: 177 | continue 178 | 179 | for caption in caption_data_by_id['captions']: 180 | res_to_write = {'video': video_id_to_fname[video_id], 'duration': video_length_sec, 181 | 'start_sec': caption_start_sec, 'end_sec': caption_end_sec, 182 | 'neg_start_sec': unsimilar_start_sec, 'neg_end_sec': unsimilar_end_sec, 183 | 'caption': caption, 184 | 'start_scene': caption_start_scene, 'end_scene': caption_end_scene, 185 | 'neg_start_scene': unsimilar_start_scene, 'neg_end_scene': unsimilar_end_scene,} 186 | 187 | json.dump(res_to_write, f_out) 188 | f_out.write('\n') 189 | if i % 100 == 0: 190 | f_out.flush() 191 | f_out.close() 192 | 
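# Example invocation (a sketch; file names are hypothetical and follow the argparse defaults above):
#   python code/ground_data_construction.py \
#       --video-base-folder videos \
#       --caption-fname data/caption.jsonl \
#       --scene-fname temp/scenes-merged.jsonl \
#       --scene-sim-fname temp/scenes_similarity-merged.jsonl \
#       --caption-with-neg-interval-fname temp/final_dataset.jsonl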
-------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import ConcatDataset, DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms import InterpolationMode 5 | 6 | from dataset.dataloader import MetaLoader 7 | from dataset.pt_dataset import PTImgTrainDataset, PTVidTrainDataset, PTImgEvalDataset, PTVidEvalDataset 8 | from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset 9 | 10 | 11 | def create_dataset(dataset_type, config): 12 | if "clip" in config.model.get("vit_model", 'vit'): 13 | mean = (0.485, 0.456, 0.406) 14 | std = (0.229, 0.224, 0.225) 15 | else: 16 | vision_enc_name = config.model.vision_encoder.name 17 | if "swin" in vision_enc_name or "vit" in vision_enc_name: 18 | mean = (0.485, 0.456, 0.406) 19 | std = (0.229, 0.224, 0.225) 20 | elif "beit" in vision_enc_name: 21 | mean = (0.5, 0.5, 0.5) # for all beit model except IN1K finetuning 22 | std = (0.5, 0.5, 0.5) 23 | elif "clip" in vision_enc_name: 24 | mean = (0.48145466, 0.4578275, 0.40821073) 25 | std = (0.26862954, 0.26130258, 0.27577711) 26 | else: 27 | raise ValueError 28 | 29 | normalize = transforms.Normalize(mean, std) 30 | 31 | # loaded images and videos are torch.Tensor of torch.uint8 format, 32 | # ordered as (T, 1 or 3, H, W) where T=1 for image 33 | type_transform = transforms.Lambda(lambda x: x.float().div(255.0)) 34 | 35 | if config.inputs.video_input.random_aug: 36 | aug_transform = transforms.RandAugment() 37 | else: 38 | aug_transform = transforms.Lambda(lambda x: x) 39 | 40 | train_transform = transforms.Compose( 41 | [ 42 | aug_transform, 43 | transforms.RandomResizedCrop( 44 | config.inputs.image_res, 45 | scale=(0.5, 1.0), 46 | interpolation=InterpolationMode.BICUBIC, 47 | ), 48 | transforms.RandomHorizontalFlip(), 49 | type_transform, 50 | normalize, 51 | ] 52 | ) 53 | test_transform = transforms.Compose( 54 | [ 55 | transforms.Resize( 56 | (config.inputs.image_res, config.inputs.image_res), 57 | interpolation=InterpolationMode.BICUBIC, 58 | ), 59 | type_transform, 60 | normalize, 61 | ] 62 | ) 63 | 64 | video_reader_type = config.inputs.video_input.get("video_reader_type", "decord") 65 | video_only_dataset_kwargs_train = dict( 66 | video_reader_type=video_reader_type, 67 | sample_type=config.inputs.video_input.sample_type, 68 | num_frames=config.inputs.video_input.num_frames, 69 | num_tries=3, # false tolerance 70 | ) 71 | video_only_dataset_kwargs_eval = dict( 72 | video_reader_type=video_reader_type, 73 | sample_type=config.inputs.video_input.sample_type_test, 74 | num_frames=config.inputs.video_input.num_frames_test, 75 | num_tries=1, # we want to have predictions for all videos 76 | ) 77 | 78 | if dataset_type == "pt_train": 79 | # convert to list of lists 80 | train_files = ( 81 | [config.train_file] if isinstance(config.train_file[0], str) else config.train_file 82 | ) 83 | train_media_types = sorted(list({e['media_type'] for e in train_files})) 84 | 85 | train_datasets = [] 86 | for m in train_media_types: 87 | dataset_cls = PTImgTrainDataset if m == "image" else PTVidTrainDataset 88 | # dataset of the same media_type will be mixed in a single Dataset object 89 | _train_files = [e for e in train_files if e['media_type'] == m] 90 | 91 | datasets = [] 92 | for train_file in _train_files: 93 | dataset_kwargs = train_file.copy() 94 | dataset_kwargs.update(dict( 95 | transform=train_transform, 
96 | pre_text=config.get( 97 | "pre_text", True 98 | ), 99 | )) 100 | if m == "video": 101 | dataset_kwargs.update(video_only_dataset_kwargs_train) 102 | datasets.append(dataset_cls(**dataset_kwargs)) 103 | dataset = ConcatDataset(datasets) 104 | train_datasets.append(dataset) 105 | return train_datasets 106 | 107 | elif dataset_type in ["it_train"]: 108 | # convert to list of lists 109 | train_files = ( 110 | [config.train_file] if isinstance(config.train_file[0], str) else config.train_file 111 | ) 112 | train_media_types = sorted(list({e['media_type'] for e in train_files})) 113 | 114 | train_datasets = [] 115 | every_single_dataset = [] 116 | 117 | for m in train_media_types: 118 | dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset 119 | # dataset_cls = {"image" : ITImgTrainDataset, "video": MyITVidTrainDataset, "video_truncate": MyITVidTruncateTrainDataset}[m] 120 | # dataset of the same media_type will be mixed in a single Dataset object 121 | _train_files = [e for e in train_files if e['media_type'] == m] 122 | 123 | datasets = [] 124 | for train_file in _train_files: 125 | dataset_kwargs = train_file.copy() 126 | dataset_kwargs.update(dict( 127 | transform=train_transform, 128 | system=config.model.get("system", ""), 129 | start_token=config.model.get("img_start_token", ""), 130 | end_token=config.model.get("img_end_token", ""), 131 | )) 132 | if "video" in m: 133 | video_only_dataset_kwargs_train.update({ 134 | "start_token": config.model.get("start_token", ""), 136 | "add_second_msg": config.model.get("add_second_msg", True), 137 | "grounding_method": train_file.get("grounding_method", None), 138 | "min_gold_clips": train_file.get("min_gold_clips", 1), 139 | "max_gold_clips": train_file.get("max_gold_clips", None), 140 | "num_examples": train_file.get("num_examples", None), 141 | }) 142 | dataset_kwargs.update(video_only_dataset_kwargs_train) 143 | if "tgif" in train_file['dataset_name'].lower(): 144 | video_only_dataset_kwargs_train.update({ 145 | "video_reader_type": "gif" 146 | }) 147 | dataset_kwargs.update(video_only_dataset_kwargs_train) 148 | else: 149 | video_only_dataset_kwargs_train.update({ 150 | "video_reader_type": "decord" 151 | }) 152 | dataset_kwargs.update(video_only_dataset_kwargs_train) 153 | dataset = dataset_cls(**dataset_kwargs) 154 | every_single_dataset.append(dataset) 155 | datasets.append(dataset) 156 | dataset = ConcatDataset(datasets) 157 | train_datasets.append(dataset) 158 | return train_datasets, every_single_dataset 159 | 160 | elif dataset_type == "pt_eval": 161 | test_datasets = [] 162 | test_dataset_names = [] 163 | # multiple test datasets, all separate 164 | for name, data_cfg in config.test_file.items(): 165 | media_type = get_media_type(data_cfg) 166 | test_dataset_cls = ( 167 | PTImgEvalDataset if media_type == "image" else PTVidEvalDataset 168 | ) 169 | test_dataset_names.append(name) 170 | dataset_kwargs = dict( 171 | ann_file=[data_cfg], 172 | transform=test_transform, 173 | has_multi_vision_gt=config.get( 174 | "has_multi_vision_gt", False 175 | ), # true for ssv2 ret 176 | ) 177 | if media_type == "video": 178 | dataset_kwargs.update(video_only_dataset_kwargs_eval) 179 | test_datasets.append(test_dataset_cls(**dataset_kwargs)) 180 | return test_datasets, test_dataset_names 181 | 182 | 183 | 184 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 185 | samplers = [] 186 | for dataset, shuffle in zip(datasets, shuffles): 187 | sampler = torch.utils.data.DistributedSampler( 188 | dataset, 
num_replicas=num_tasks, rank=global_rank, shuffle=shuffle 189 | ) 190 | samplers.append(sampler) 191 | return samplers 192 | 193 | 194 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 195 | loaders = [] 196 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip( 197 | datasets, samplers, batch_size, num_workers, is_trains, collate_fns 198 | ): 199 | if is_train: 200 | shuffle = sampler is None 201 | drop_last = True 202 | else: 203 | shuffle = False 204 | drop_last = False 205 | loader = DataLoader( 206 | dataset, 207 | batch_size=bs, 208 | num_workers=n_worker, 209 | pin_memory=False, 210 | sampler=sampler, 211 | shuffle=shuffle, 212 | collate_fn=collate_fn, 213 | drop_last=drop_last, 214 | persistent_workers=True if n_worker > 0 else False, 215 | ) 216 | loaders.append(loader) 217 | return loaders 218 | 219 | 220 | def iterate_dataloaders(dataloaders): 221 | """Alternatively generate data from multiple dataloaders, 222 | since we use `zip` to concat multiple dataloaders, 223 | the loop will end when the smaller dataloader runs out. 224 | 225 | Args: 226 | dataloaders List(DataLoader): can be a single or multiple dataloaders 227 | """ 228 | for data_tuples in zip(*dataloaders): 229 | for idx, data in enumerate(data_tuples): 230 | yield dataloaders[idx].dataset.media_type, data 231 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy import interpolate 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def _init_transformer_weights(module, initializer_range=0.02): 13 | """Initialize the weights. Copied from transformers ViT/Bert model init""" 14 | if isinstance(module, (nn.Linear, nn.Conv2d)): 15 | # Slightly different from the TF version which uses truncated_normal for initialization 16 | # cf https://github.com/pytorch/pytorch/pull/5617 17 | module.weight.data.normal_(mean=0.0, std=initializer_range) 18 | if module.bias is not None: 19 | module.bias.data.zero_() 20 | elif isinstance(module, nn.Embedding): 21 | module.weight.data.normal_(mean=0.0, std=initializer_range) 22 | if module.padding_idx is not None: 23 | module.weight.data[module.padding_idx].zero_() 24 | elif isinstance(module, nn.LayerNorm): 25 | module.bias.data.zero_() 26 | module.weight.data.fill_(1.0) 27 | 28 | 29 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 30 | """ 31 | Add/Remove extra temporal_embeddings as needed. 32 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 33 | 34 | temp_embed_old: (1, num_frames_old, 1, d) 35 | temp_embed_new: (1, num_frames_new, 1, d) 36 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 37 | """ 38 | # TODO zero pad 39 | num_frms_new = temp_embed_new.shape[1] 40 | num_frms_old = temp_embed_old.shape[1] 41 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 42 | if num_frms_new > num_frms_old: 43 | if add_zero: 44 | temp_embed_new[ 45 | :, :num_frms_old 46 | ] = temp_embed_old # untrained embeddings are zeros. 
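            # (illustration with assumed shapes, not part of the repository) e.g. with
            # temp_embed_old of shape (1, 4, 1, 768) and a zero-initialised temp_embed_new of
            # shape (1, 8, 1, 768), the copy above fills frames 0-3 with the trained embeddings
            # while frames 4-7 stay zero -- the zero-padding scheme of arxiv.org/abs/2104.00650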
47 | else: 48 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 49 | elif num_frms_new < num_frms_old: 50 | temp_embed_new = temp_embed_old[:, :num_frms_new] 51 | else: # = 52 | temp_embed_new = temp_embed_old 53 | return temp_embed_new 54 | 55 | 56 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 57 | """ 58 | Add/Remove extra temporal_embeddings as needed. 59 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 60 | 61 | temp_embed_old: (1, num_frames_old, 1, d) 62 | temp_embed_new: (1, num_frames_new, 1, d) 63 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 64 | """ 65 | # TODO zero pad 66 | num_frms_new = temp_embed_new.shape[1] 67 | num_frms_old = temp_embed_old.shape[1] 68 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 69 | if num_frms_new > num_frms_old: 70 | if add_zero: 71 | temp_embed_new[ 72 | :, :num_frms_old 73 | ] = temp_embed_old # untrained embeddings are zeros. 74 | else: 75 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 76 | elif num_frms_new < num_frms_old: 77 | temp_embed_new = temp_embed_old[:, :num_frms_new] 78 | else: # = 79 | temp_embed_new = temp_embed_old 80 | return temp_embed_new 81 | 82 | 83 | def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new): 84 | """ 85 | temp_embed_old: (1, num_frames_old, 1, d) 86 | Returns: 87 | temp_embed_new: (1, num_frames_new, 1, d) 88 | """ 89 | temp_embed_old = temp_embed_old.squeeze(2).permute( 90 | 0, 2, 1 91 | ) # (1, d, num_frames_old) 92 | temp_embed_new = F.interpolate( 93 | temp_embed_old, num_frames_new, mode="linear" 94 | ) # (1, d, num_frames_new) 95 | temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze( 96 | 2 97 | ) # (1, num_frames_new, 1, d) 98 | return temp_embed_new 99 | 100 | 101 | def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new): 102 | """ 103 | Args: 104 | pos_embed_old: (1, L_old, d), pre-trained 105 | pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights 106 | num_patches_new: 107 | """ 108 | # interpolate position embedding 109 | embedding_size = pos_embed_old.shape[-1] 110 | num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new 111 | # height (== width) for the checkpoint position embedding 112 | orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5) 113 | # height (== width) for the new position embedding 114 | new_size = int(num_patches_new ** 0.5) 115 | 116 | if orig_size != new_size: 117 | # class_token and dist_token are kept unchanged 118 | # the extra tokens seems always at the beginning of the position embedding 119 | extra_tokens = pos_embed_old[:, :num_extra_tokens] 120 | # only the position tokens are interpolated 121 | pos_tokens = pos_embed_old[:, num_extra_tokens:] 122 | pos_tokens = pos_tokens.reshape( 123 | -1, orig_size, orig_size, embedding_size 124 | ).permute(0, 3, 1, 2) 125 | pos_tokens = torch.nn.functional.interpolate( 126 | pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False 127 | ) 128 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 129 | interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 130 | logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2") 131 | return interpolated_pos_embed 132 | else: 133 | return pos_embed_old 134 | 135 | 136 | def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new): 137 | 
""" 138 | Args: 139 | state_dict_old: loaded state dict 140 | state_dict_new: state dict for model with new image size 141 | patch_shape_new: new model patch_shape 142 | ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py 143 | """ 144 | all_keys = list(state_dict_old.keys()) 145 | for key in all_keys: 146 | if "relative_position_index" in key: 147 | state_dict_old.pop(key) 148 | 149 | if "relative_position_bias_table" in key: 150 | rel_pos_bias = state_dict_old[key] 151 | src_num_pos, num_attn_heads = rel_pos_bias.size() 152 | dst_num_pos, _ = state_dict_new[key].size() 153 | dst_patch_shape = patch_shape_new 154 | if dst_patch_shape[0] != dst_patch_shape[1]: 155 | raise NotImplementedError() 156 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( 157 | dst_patch_shape[1] * 2 - 1 158 | ) 159 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 160 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 161 | if src_size != dst_size: 162 | # logger.info("Position interpolate for %s from %dx%d to %dx%d" % ( 163 | # key, src_size, src_size, dst_size, dst_size)) 164 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 165 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 166 | 167 | def geometric_progression(a, r, n): 168 | return a * (1.0 - r ** n) / (1.0 - r) 169 | 170 | left, right = 1.01, 1.5 171 | while right - left > 1e-6: 172 | q = (left + right) / 2.0 173 | gp = geometric_progression(1, q, src_size // 2) 174 | if gp > dst_size // 2: 175 | right = q 176 | else: 177 | left = q 178 | 179 | # if q > 1.090307: 180 | # q = 1.090307 181 | 182 | dis = [] 183 | cur = 1 184 | for i in range(src_size // 2): 185 | dis.append(cur) 186 | cur += q ** (i + 1) 187 | 188 | r_ids = [-_ for _ in reversed(dis)] 189 | 190 | x = r_ids + [0] + dis 191 | y = r_ids + [0] + dis 192 | 193 | t = dst_size // 2.0 194 | dx = np.arange(-t, t + 0.1, 1.0) 195 | dy = np.arange(-t, t + 0.1, 1.0) 196 | 197 | # logger.info("Original positions = %s" % str(x)) 198 | # logger.info("Target positions = %s" % str(dx)) 199 | 200 | all_rel_pos_bias = [] 201 | 202 | for i in range(num_attn_heads): 203 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 204 | f = interpolate.interp2d(x, y, z, kind="cubic") 205 | all_rel_pos_bias.append( 206 | torch.Tensor(f(dx, dy)) 207 | .contiguous() 208 | .view(-1, 1) 209 | .to(rel_pos_bias.device) 210 | ) 211 | 212 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 213 | 214 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 215 | state_dict_old[key] = new_rel_pos_bias 216 | return state_dict_old 217 | 218 | 219 | def tile(x, dim, n_tile): 220 | init_dim = x.size(dim) 221 | repeat_idx = [1] * x.dim() 222 | repeat_idx[dim] = n_tile 223 | x = x.repeat(*repeat_idx) 224 | order_index = torch.LongTensor( 225 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 226 | ) 227 | return torch.index_select(x, dim, order_index.to(x.device)) 228 | 229 | 230 | def mask_logits(target, mask): 231 | return target * mask + (1 - mask) * (-1e10) 232 | 233 | 234 | class AllGather(torch.autograd.Function): 235 | """An autograd function that performs allgather on a tensor.""" 236 | 237 | @staticmethod 238 | def forward(ctx, tensor, args): 239 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 240 | torch.distributed.all_gather(output, tensor) 241 | ctx.rank = args.rank 242 | ctx.batch_size = tensor.shape[0] 243 | return torch.cat(output, dim=0) 244 | 245 | @staticmethod 246 | def 
backward(ctx, grad_output): 247 | return ( 248 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 249 | None, 250 | ) 251 | 252 | 253 | allgather_wgrad = AllGather.apply 254 | -------------------------------------------------------------------------------- /models/blip2/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy import interpolate 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def _init_transformer_weights(module, initializer_range=0.02): 13 | """Initialize the weights. Copied from transformers ViT/Bert model init""" 14 | if isinstance(module, (nn.Linear, nn.Conv2d)): 15 | # Slightly different from the TF version which uses truncated_normal for initialization 16 | # cf https://github.com/pytorch/pytorch/pull/5617 17 | module.weight.data.normal_(mean=0.0, std=initializer_range) 18 | if module.bias is not None: 19 | module.bias.data.zero_() 20 | elif isinstance(module, nn.Embedding): 21 | module.weight.data.normal_(mean=0.0, std=initializer_range) 22 | if module.padding_idx is not None: 23 | module.weight.data[module.padding_idx].zero_() 24 | elif isinstance(module, nn.LayerNorm): 25 | module.bias.data.zero_() 26 | module.weight.data.fill_(1.0) 27 | 28 | 29 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 30 | """ 31 | Add/Remove extra temporal_embeddings as needed. 32 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 33 | 34 | temp_embed_old: (1, num_frames_old, 1, d) 35 | temp_embed_new: (1, num_frames_new, 1, d) 36 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 37 | """ 38 | # TODO zero pad 39 | num_frms_new = temp_embed_new.shape[1] 40 | num_frms_old = temp_embed_old.shape[1] 41 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 42 | if num_frms_new > num_frms_old: 43 | if add_zero: 44 | temp_embed_new[ 45 | :, :num_frms_old 46 | ] = temp_embed_old # untrained embeddings are zeros. 47 | else: 48 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 49 | elif num_frms_new < num_frms_old: 50 | temp_embed_new = temp_embed_old[:, :num_frms_new] 51 | else: # = 52 | temp_embed_new = temp_embed_old 53 | return temp_embed_new 54 | 55 | 56 | def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): 57 | """ 58 | Add/Remove extra temporal_embeddings as needed. 59 | https://arxiv.org/abs/2104.00650 shows adding zero paddings works. 60 | 61 | temp_embed_old: (1, num_frames_old, 1, d) 62 | temp_embed_new: (1, num_frames_new, 1, d) 63 | add_zero: bool, if True, add zero, else, interpolate trained embeddings. 64 | """ 65 | # TODO zero pad 66 | num_frms_new = temp_embed_new.shape[1] 67 | num_frms_old = temp_embed_old.shape[1] 68 | logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") 69 | if num_frms_new > num_frms_old: 70 | if add_zero: 71 | temp_embed_new[ 72 | :, :num_frms_old 73 | ] = temp_embed_old # untrained embeddings are zeros. 
74 | else: 75 | temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) 76 | elif num_frms_new < num_frms_old: 77 | temp_embed_new = temp_embed_old[:, :num_frms_new] 78 | else: # = 79 | temp_embed_new = temp_embed_old 80 | return temp_embed_new 81 | 82 | 83 | def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new): 84 | """ 85 | temp_embed_old: (1, num_frames_old, 1, d) 86 | Returns: 87 | temp_embed_new: (1, num_frames_new, 1, d) 88 | """ 89 | temp_embed_old = temp_embed_old.squeeze(2).permute( 90 | 0, 2, 1 91 | ) # (1, d, num_frames_old) 92 | temp_embed_new = F.interpolate( 93 | temp_embed_old, num_frames_new, mode="linear" 94 | ) # (1, d, num_frames_new) 95 | temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze( 96 | 2 97 | ) # (1, num_frames_new, 1, d) 98 | return temp_embed_new 99 | 100 | 101 | def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new): 102 | """ 103 | Args: 104 | pos_embed_old: (1, L_old, d), pre-trained 105 | pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights 106 | num_patches_new: 107 | """ 108 | # interpolate position embedding 109 | embedding_size = pos_embed_old.shape[-1] 110 | num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new 111 | # height (== width) for the checkpoint position embedding 112 | orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5) 113 | # height (== width) for the new position embedding 114 | new_size = int(num_patches_new ** 0.5) 115 | 116 | if orig_size != new_size: 117 | # class_token and dist_token are kept unchanged 118 | # the extra tokens seems always at the beginning of the position embedding 119 | extra_tokens = pos_embed_old[:, :num_extra_tokens] 120 | # only the position tokens are interpolated 121 | pos_tokens = pos_embed_old[:, num_extra_tokens:] 122 | pos_tokens = pos_tokens.reshape( 123 | -1, orig_size, orig_size, embedding_size 124 | ).permute(0, 3, 1, 2) 125 | pos_tokens = torch.nn.functional.interpolate( 126 | pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False 127 | ) 128 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 129 | interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 130 | logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2") 131 | return interpolated_pos_embed 132 | else: 133 | return pos_embed_old 134 | 135 | 136 | def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new): 137 | """ 138 | Args: 139 | state_dict_old: loaded state dict 140 | state_dict_new: state dict for model with new image size 141 | patch_shape_new: new model patch_shape 142 | ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py 143 | """ 144 | all_keys = list(state_dict_old.keys()) 145 | for key in all_keys: 146 | if "relative_position_index" in key: 147 | state_dict_old.pop(key) 148 | 149 | if "relative_position_bias_table" in key: 150 | rel_pos_bias = state_dict_old[key] 151 | src_num_pos, num_attn_heads = rel_pos_bias.size() 152 | dst_num_pos, _ = state_dict_new[key].size() 153 | dst_patch_shape = patch_shape_new 154 | if dst_patch_shape[0] != dst_patch_shape[1]: 155 | raise NotImplementedError() 156 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( 157 | dst_patch_shape[1] * 2 - 1 158 | ) 159 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 160 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 161 | if src_size != dst_size: 162 | # logger.info("Position 
interpolate for %s from %dx%d to %dx%d" % ( 163 | # key, src_size, src_size, dst_size, dst_size)) 164 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 165 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 166 | 167 | def geometric_progression(a, r, n): 168 | return a * (1.0 - r ** n) / (1.0 - r) 169 | 170 | left, right = 1.01, 1.5 171 | while right - left > 1e-6: 172 | q = (left + right) / 2.0 173 | gp = geometric_progression(1, q, src_size // 2) 174 | if gp > dst_size // 2: 175 | right = q 176 | else: 177 | left = q 178 | 179 | # if q > 1.090307: 180 | # q = 1.090307 181 | 182 | dis = [] 183 | cur = 1 184 | for i in range(src_size // 2): 185 | dis.append(cur) 186 | cur += q ** (i + 1) 187 | 188 | r_ids = [-_ for _ in reversed(dis)] 189 | 190 | x = r_ids + [0] + dis 191 | y = r_ids + [0] + dis 192 | 193 | t = dst_size // 2.0 194 | dx = np.arange(-t, t + 0.1, 1.0) 195 | dy = np.arange(-t, t + 0.1, 1.0) 196 | 197 | # logger.info("Original positions = %s" % str(x)) 198 | # logger.info("Target positions = %s" % str(dx)) 199 | 200 | all_rel_pos_bias = [] 201 | 202 | for i in range(num_attn_heads): 203 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 204 | f = interpolate.interp2d(x, y, z, kind="cubic") 205 | all_rel_pos_bias.append( 206 | torch.Tensor(f(dx, dy)) 207 | .contiguous() 208 | .view(-1, 1) 209 | .to(rel_pos_bias.device) 210 | ) 211 | 212 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 213 | 214 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 215 | state_dict_old[key] = new_rel_pos_bias 216 | return state_dict_old 217 | 218 | 219 | def tile(x, dim, n_tile): 220 | init_dim = x.size(dim) 221 | repeat_idx = [1] * x.dim() 222 | repeat_idx[dim] = n_tile 223 | x = x.repeat(*repeat_idx) 224 | order_index = torch.LongTensor( 225 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 226 | ) 227 | return torch.index_select(x, dim, order_index.to(x.device)) 228 | 229 | 230 | def mask_logits(target, mask): 231 | return target * mask + (1 - mask) * (-1e10) 232 | 233 | 234 | class AllGather(torch.autograd.Function): 235 | """An autograd function that performs allgather on a tensor.""" 236 | 237 | @staticmethod 238 | def forward(ctx, tensor, args): 239 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 240 | torch.distributed.all_gather(output, tensor) 241 | ctx.rank = args.rank 242 | ctx.batch_size = tensor.shape[0] 243 | return torch.cat(output, dim=0) 244 | 245 | @staticmethod 246 | def backward(ctx, grad_output): 247 | return ( 248 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 249 | None, 250 | ) 251 | 252 | 253 | allgather_wgrad = AllGather.apply 254 | -------------------------------------------------------------------------------- /tasks/train_it.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import datetime 4 | import logging 5 | import time 6 | from os.path import join 7 | 8 | import sys 9 | sys.path.append('.') 10 | 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import wandb 15 | 16 | from dataset import MetaLoader, create_dataset, create_loader, create_sampler 17 | from models.hawkeye_it import HawkEye_it 18 | from tasks.shared_utils import get_media_types, setup_model 19 | from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed) 20 | from utils.config_utils import setup_main 21 | from utils.distributed import get_rank, 
get_world_size, is_main_process 22 | from utils.logger import log_dict_to_wandb, setup_wandb, log_dict_to_tensorboard 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def train( 28 | model, model_without_ddp, 29 | train_loaders, 30 | optimizer, 31 | epoch, 32 | global_step, 33 | device, 34 | scheduler, 35 | scaler, 36 | config, 37 | tb_writer, 38 | ): 39 | model.train() 40 | 41 | metric_logger = MetricLogger(delimiter=" ") 42 | metric_logger.add_meter("lr", SmoothedValue(window=10, fmt="{value:.6f}")) 43 | loss_names = ["loss"] 44 | 45 | media_types = get_media_types(train_loaders) 46 | 47 | for name in loss_names: 48 | for m in media_types: 49 | metric_logger.add_meter( 50 | f"{m}-{name}", SmoothedValue(window=10, fmt="{value:.4f}") 51 | ) 52 | 53 | header = f"Train Epoch: [{epoch}]" 54 | log_freq = config.log_freq 55 | save_freq = config.get('save_freq', None) 56 | 57 | if config.distributed: 58 | for d in train_loaders: 59 | d.sampler.set_epoch(epoch) 60 | train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders)))) 61 | 62 | iterator = metric_logger.log_every(train_loader, log_freq, header) 63 | for i, (media_type, (image, text, instruction, _)) in enumerate(iterator): 64 | image = image.to(device, non_blocking=True) 65 | 66 | with torch.cuda.amp.autocast(enabled=config.fp16): 67 | loss_dict = model(image, text, instruction) 68 | loss = sum(loss_dict.values()) 69 | 70 | optimizer.zero_grad() 71 | if not torch.isnan(loss): 72 | scaler.scale(loss).backward() 73 | if config.optimizer.max_grad_norm > 0: 74 | scaler.unscale_(optimizer) 75 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) 76 | scaler.step(optimizer) 77 | scaler.update() 78 | else: 79 | print('nan loss encountered at turn', global_step) 80 | scheduler.step() 81 | 82 | # logging 83 | for name in loss_names: 84 | value = loss_dict[name] 85 | value = value if isinstance(value, float) else value.item() 86 | metric_logger.update(**{f"{media_type}-{name}": value}) 87 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 88 | 89 | if is_main_process() and config.wandb.enable and global_step % log_freq == 0: 90 | logs = metric_logger.get_avg_dict() 91 | log_dict_to_wandb(logs, step=global_step, prefix="train/") 92 | 93 | if is_main_process() and config.get('tensorboard', {'enable': False})['enable'] and global_step % log_freq == 0: 94 | logs = metric_logger.get_avg_dict() 95 | log_dict_to_tensorboard(logs, global_step, "train/", tb_writer) 96 | 97 | # save model every x steps 98 | if is_main_process() and save_freq is not None and global_step % save_freq == 0: 99 | logger.info(f"Epoch {epoch}") 100 | param_grad_dic = { 101 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters() 102 | } 103 | state_dict = model_without_ddp.state_dict() 104 | for k in list(state_dict.keys()): 105 | if k in param_grad_dic.keys() and not param_grad_dic[k]: 106 | # delete parameters that do not require gradient 107 | del state_dict[k] 108 | save_obj = { 109 | "model": state_dict, 110 | "optimizer": optimizer.state_dict(), 111 | "scheduler": scheduler.state_dict(), 112 | "scaler": scaler.state_dict(), 113 | "config": config, 114 | "epoch": epoch, 115 | "global_step": global_step, 116 | } 117 | if config.get("save_latest", False): 118 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth")) 119 | else: 120 | torch.save(save_obj, join(config.output_dir, f"ckpt_{global_step}.pth")) 121 | 122 | global_step += 1 123 | 124 | # gather the stats from all processes 
125 | metric_logger.synchronize_between_processes() 126 | logger.info(f"Averaged stats: {metric_logger.global_avg()}") 127 | return global_step 128 | 129 | 130 | def setup_dataloaders(config, mode="pt"): 131 | # train datasets, create a list of data loaders 132 | logger.info(f"Creating dataset for {mode}") 133 | train_datasets, every_single_dataset = create_dataset(f"{mode}_train", config) 134 | 135 | if config.get("freeze_dataset_folder", None) is not None: 136 | # save self.anno in current dataset to "freeze_dataset_folder" 137 | # this is used to record record the dataset when training data changes 138 | os.makedirs(config.freeze_dataset_folder, exist_ok=True) 139 | for dataset in every_single_dataset: 140 | dest_path = os.path.join(config.freeze_dataset_folder, dataset.dataset_name + '.json') 141 | json.dump(dataset.anno, open(dest_path, 'w')) 142 | 143 | media_types = get_media_types(train_datasets) 144 | 145 | if config.distributed: 146 | num_tasks = get_world_size() 147 | global_rank = get_rank() 148 | samplers = create_sampler( 149 | train_datasets, [True] * len(media_types), num_tasks, global_rank 150 | ) 151 | else: 152 | samplers = [None] * len(media_types) 153 | 154 | train_loaders = create_loader( 155 | train_datasets, 156 | samplers, 157 | batch_size=[config.inputs.batch_size[k] for k in media_types], 158 | num_workers=[config.num_workers] * len(media_types), 159 | is_trains=[True] * len(media_types), 160 | collate_fns=[None] * len(media_types), 161 | ) # [0] 162 | 163 | return train_loaders, media_types 164 | 165 | 166 | def main(config): 167 | if is_main_process() and config.wandb.enable: 168 | run = setup_wandb(config) 169 | 170 | if is_main_process() and config.get('tensorboard', {'enable': False})['enable']: 171 | from torch.utils.tensorboard import SummaryWriter 172 | tb_writer = SummaryWriter(log_dir=config.tensorboard.get('log_dir', os.path.join(config.output_dir, 'tb_logs'))) 173 | else: 174 | tb_writer = None 175 | 176 | logger.info(f"train_file: {config.train_file}") 177 | 178 | setup_seed(config.seed + get_rank()) 179 | device = torch.device(config.device) 180 | 181 | train_loaders, train_media_types = setup_dataloaders( 182 | config, mode=config.mode 183 | ) 184 | num_steps_per_epoch = sum(len(d) for d in train_loaders) 185 | config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs 186 | config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs 187 | # set cudnn.benchmark=True only when input size is fixed 188 | # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3 189 | cudnn.benchmark = len(train_media_types) == 1 190 | 191 | model_cls = eval(config.model.get('model_cls', 'HawkEye_it')) 192 | ( 193 | model, 194 | model_without_ddp, 195 | optimizer, 196 | scheduler, 197 | scaler, 198 | start_epoch, 199 | global_step, 200 | ) = setup_model( 201 | config, 202 | model_cls=model_cls, 203 | find_unused_parameters=True, 204 | ) 205 | if is_main_process() and config.wandb.enable: 206 | wandb.watch(model) 207 | 208 | logger.info("Start training") 209 | start_time = time.time() 210 | for epoch in range(start_epoch, config.scheduler.epochs): 211 | if not config.evaluate: 212 | global_step = train( 213 | model, model_without_ddp, 214 | train_loaders, 215 | optimizer, 216 | epoch, 217 | global_step, 218 | device, 219 | scheduler, 220 | scaler, 221 | config, 222 | tb_writer 223 | ) 224 | 225 | if is_main_process() and config.get('save_freq', None) is None: # does not save every x steps, so 
we save every epoch 226 | logger.info(f"Epoch {epoch}") 227 | param_grad_dic = { 228 | k: v.requires_grad for (k, v) in model_without_ddp.named_parameters() 229 | } 230 | state_dict = model_without_ddp.state_dict() 231 | for k in list(state_dict.keys()): 232 | if k in param_grad_dic.keys() and not param_grad_dic[k]: 233 | # delete parameters that do not require gradient 234 | del state_dict[k] 235 | save_obj = { 236 | "model": state_dict, 237 | "optimizer": optimizer.state_dict(), 238 | "scheduler": scheduler.state_dict(), 239 | "scaler": scaler.state_dict(), 240 | "config": config, 241 | "epoch": epoch, 242 | "global_step": global_step, 243 | } 244 | if config.get("save_latest", False): 245 | torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth")) 246 | else: 247 | torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth")) 248 | 249 | if config.evaluate: 250 | break 251 | 252 | dist.barrier() 253 | 254 | total_time = time.time() - start_time 255 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 256 | logger.info(f"Training time {total_time_str}") 257 | logger.info(f"Checkpoints and Logs saved at {config.output_dir}") 258 | 259 | if is_main_process() and config.wandb.enable: 260 | run.finish() 261 | 262 | 263 | if __name__ == "__main__": 264 | cfg = setup_main() 265 | main(cfg) 266 | --------------------------------------------------------------------------------
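A practical note on the checkpoints written by train() and main() in tasks/train_it.py: parameters that do not require gradients are deleted from the saved state_dict, so a reloaded model should be populated with strict=False. A minimal sketch, assuming a checkpoint produced by the loop above (the path is a placeholder, and the model is expected to be rebuilt the same way as in setup_model):

import torch

def load_hawkeye_checkpoint(ckpt_path):
    """Return the (partial) state_dict stored in a checkpoint written by tasks/train_it.py."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    print("epoch:", ckpt["epoch"], "global_step:", ckpt["global_step"])
    # only parameters with requires_grad=True were saved, so load into a freshly
    # built model with strict=False
    return ckpt["model"]

# usage (placeholder path):
# state_dict = load_hawkeye_checkpoint("outputs/stage3/ckpt_latest.pth")
# model.load_state_dict(state_dict, strict=False)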