├── utils ├── __init__.py ├── timestamp_extraction.py ├── shift_video.py ├── prompts.py └── cons_utils.py ├── timechat ├── common │ ├── __init__.py │ ├── gradcam.py │ ├── optims.py │ ├── dist_utils.py │ ├── logger.py │ └── registry.py ├── runners │ ├── test.py │ └── __init__.py ├── conversation │ └── __init__.py ├── datasets │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── laion_dataset.py │ │ ├── cc_sbu_dataset.py │ │ ├── base_dataset.py │ │ ├── caption_datasets.py │ │ ├── webvid_datasets.py │ │ ├── dataloader_utils.py │ │ └── llava_instruct_dataset.py │ ├── builders │ │ ├── video_caption_builder.py │ │ ├── __init__.py │ │ ├── image_text_pair_builder.py │ │ ├── instruct_builder.py │ │ └── base_dataset_builder.py │ └── data_utils.py ├── configs │ ├── datasets │ │ ├── cc_sbu │ │ │ ├── align.yaml │ │ │ └── defaults.yaml │ │ ├── laion │ │ │ └── defaults.yaml │ │ ├── instruct │ │ │ ├── valley72k_instruct.yaml │ │ │ ├── llava_instruct.yaml │ │ │ ├── time_instruct.yaml │ │ │ ├── charades_instruct.yaml │ │ │ ├── qvhighlights_instruct.yaml │ │ │ ├── webvid_instruct.yaml │ │ │ └── youcook2_instruct.yaml │ │ └── webvid │ │ │ └── defaults.yaml │ ├── default.yaml │ └── models │ │ ├── minigpt4.yaml │ │ └── timechat.yaml ├── tasks │ ├── image_text_pretrain.py │ ├── video_text_pretrain.py │ ├── __init__.py │ └── base_task.py ├── processors │ ├── base_processor.py │ ├── __init__.py │ ├── functional_video.py │ ├── blip_processors.py │ ├── transforms_video.py │ └── video_processor.py ├── __init__.py ├── eval_configs │ └── timechat.yaml ├── train_configs │ ├── stage2_finetune_charades.yaml │ └── stage2_finetune_activitynet.yaml ├── models │ ├── blip2_outputs.py │ ├── __init__.py │ ├── blip2.py │ └── base_model.py └── utils.py ├── .gitignore ├── scripts ├── test_predictions.sh └── run_eval.sh ├── eval └── run_eval.sh ├── run.py ├── task └── grounding.py └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /timechat/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /timechat/runners/test.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /timechat/conversation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /timechat/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .pth 3 | 4 | -------------------------------------------------------------------------------- /timechat/datasets/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /timechat/configs/datasets/cc_sbu/align.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu_align: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/cc_sbu_align_dataset 6 | 
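      # The cc_sbu_align storage directory is expected to hold a filter_cap.json annotation file
      # and an image/ folder of {image_id}.jpg files, as read by CCSBUAlignBuilder.build_datasets
      # and CCSBUAlignDataset later in this repo.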
-------------------------------------------------------------------------------- /timechat/configs/datasets/cc_sbu/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/cc_sbu_dataset/{00000..00001}.tar 6 | -------------------------------------------------------------------------------- /timechat/configs/datasets/laion/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | laion: 3 | data_type: images 4 | build_info: 5 | storage: path/laion/laion_dataset/{00000..00001}.tar 6 | -------------------------------------------------------------------------------- /timechat/configs/default.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | # For default users 3 | # cache_root: "cache" 4 | # For internal use with persistent storage 5 | cache_root: "/export/home/.cache/minigpt4" 6 | -------------------------------------------------------------------------------- /timechat/configs/datasets/instruct/valley72k_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | valley72k_instruct: 3 | data_type: video 4 | build_info: 5 | anno_dir: "data/valley/instruct_valley_72k.json" 6 | videos_dir: "data/" 7 | -------------------------------------------------------------------------------- /timechat/configs/datasets/instruct/llava_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | llava_instruct: 3 | data_type: image 4 | build_info: 5 | anno_dir: /path/llava_instruct_150k.json 6 | videos_dir: /path/train2014/train2014/ 7 | -------------------------------------------------------------------------------- /timechat/configs/datasets/instruct/time_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | time_instruct: 3 | data_type: video 4 | build_info: 5 | anno_dir: "data/instructions/instruct_time-sensitive_73k.json" 6 | videos_dir: "data/" 7 | -------------------------------------------------------------------------------- /timechat/configs/datasets/instruct/charades_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | charades_instruct: 3 | data_type: video 4 | build_info: 5 | anno_dir: "data/Charades/instruct_tvg_12.4k_charades.json" 6 | videos_dir: "data/" 7 | -------------------------------------------------------------------------------- /timechat/configs/datasets/webvid/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | webvid: 3 | data_type: video 4 | build_info: 5 | anno_dir: path/webvid/webvid_tain_data/annotations/ 6 | videos_dir: path//webvid/webvid_tain_data/videos/ 7 | -------------------------------------------------------------------------------- /timechat/configs/datasets/instruct/qvhighlights_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | qvhighlights_instruct: 3 | data_type: video 4 | build_info: 5 | anno_dir: "data/QVhighlights/instruct_vhd_6.9k_qvhighlights.json" 6 | videos_dir: "data/" 7 | -------------------------------------------------------------------------------- /timechat/configs/datasets/instruct/webvid_instruct.yaml: 
-------------------------------------------------------------------------------- 1 | datasets: 2 | webvid_instruct: 3 | data_type: image 4 | build_info: 5 | anno_dir: /path/webvid_align/videochat_instruct_11k.json 6 | videos_dir: /path/webvid_align/videos/ 7 | -------------------------------------------------------------------------------- /scripts/test_predictions.sh: -------------------------------------------------------------------------------- 1 | which_python=$(which python) 2 | echo "which python: ${which_python}" 3 | export PYTHONPATH=${PYTHONPATH}:${which_python} 4 | export PYTHONPATH=${PYTHONPATH}:. 5 | echo "PYTHONPATH: ${PYTHONPATH}" 6 | 7 | python eval/eval.py \ 8 | ${@:1} -------------------------------------------------------------------------------- /timechat/configs/datasets/instruct/youcook2_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | youcook2_instruct: 3 | data_type: video 4 | build_info: 5 | anno_dir: "data/YouCook2-BB/YouCook2_asr_denseCap/instruct_dvc_1.2k_youcook2.json" 6 | videos_dir: "data/YouCook2-BB/" 7 | -------------------------------------------------------------------------------- /eval/run_eval.sh: -------------------------------------------------------------------------------- 1 | # bash eval/run_eval.sh --test_path {dir} --task consistency 2 | 3 | which_python=$(which python) 4 | echo "which python: ${which_python}" 5 | export PYTHONPATH=${PYTHONPATH}:${which_python} 6 | export PYTHONPATH=${PYTHONPATH}:. 7 | echo "PYTHONPATH: ${PYTHONPATH}" 8 | 9 | python eval/eval.py \ 10 | ${@:1} -------------------------------------------------------------------------------- /timechat/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from timechat.runners.runner_base import RunnerBase 9 | 10 | __all__ = ["RunnerBase"] 11 | -------------------------------------------------------------------------------- /timechat/tasks/image_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from timechat.common.registry import registry 9 | from timechat.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("image_text_pretrain") 13 | class ImageTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /timechat/tasks/video_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from timechat.common.registry import registry 9 | from timechat.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("video_text_pretrain") 13 | class VideoTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /timechat/processors/base_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /timechat/configs/models/minigpt4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: mini_gpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | # Q-Former 13 | num_query_token: 32 14 | 15 | # Vicuna 16 | llama_model: "ckpt/vicuna-13b/" 17 | 18 | # generation configs 19 | prompt: "" 20 | 21 | preprocess: 22 | vis_processor: 23 | train: 24 | name: "blip2_image_train" 25 | image_size: 224 26 | eval: 27 | name: "blip2_image_eval" 28 | image_size: 224 29 | text_processor: 30 | train: 31 | name: "blip_caption" 32 | eval: 33 | name: "blip_caption" 34 | -------------------------------------------------------------------------------- /timechat/configs/models/timechat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: timechat 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | # Q-Former 13 | num_query_token: 32 14 | 15 | # llama-2 16 | llama_model: "ckpt/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/" 17 | 18 | # generation configs 19 | prompt: "" 20 | 21 | preprocess: 22 | vis_processor: 23 | train: 24 | name: "alpro_video_train" 25 | image_size: 224 26 | n_frms: 8 27 | eval: 28 | name: "alpro_video_eval" 29 | image_size: 224 30 | n_frms: 8 31 | text_processor: 32 | train: 33 | name: "blip_caption" 34 | eval: 35 | name: "blip_caption" 36 | -------------------------------------------------------------------------------- /timechat/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap 
/= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /timechat/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from timechat.common.registry import registry 9 | from timechat.tasks.base_task import BaseTask 10 | from timechat.tasks.image_text_pretrain import ImageTextPretrainTask 11 | from timechat.tasks.video_text_pretrain import VideoTextPretrainTask 12 | 13 | 14 | def setup_task(cfg): 15 | assert "task" in cfg.run_cfg, "Task name must be provided." 16 | 17 | task_name = cfg.run_cfg.task 18 | task = registry.get_task_class(task_name).setup_task(cfg=cfg) 19 | assert task is not None, "Task {} not properly registered.".format(task_name) 20 | 21 | return task 22 | 23 | 24 | __all__ = [ 25 | "BaseTask", 26 | "ImageTextPretrainTask", 27 | "VideoTextPretrainTask" 28 | ] 29 | -------------------------------------------------------------------------------- /scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | which_python=$(which python) 2 | echo "which python: ${which_python}" 3 | export PYTHONPATH=${PYTHONPATH}:${which_python} 4 | export PYTHONPATH=${PYTHONPATH}:. 5 | echo "PYTHONPATH: ${PYTHONPATH}" 6 | 7 | model_type=$1 8 | dset_name=$2 9 | video_root="/data/video_datasets" 10 | 11 | case ${dset_name} in 12 | tvr) 13 | if [[ ${model_type} == "Gemini" ]]; then 14 | extra_args+=(--skip) 15 | elif [[ ${model_type} == "GPT-4o" ]]; then 16 | n_frames=10 17 | extra_args+=(--skip) 18 | 19 | elif [[ ${model_type} == "Video-LLaMA" ]]; then 20 | n_frames=8 21 | 22 | elif [[ ${model_type} == "Video-ChatGPT" ]]; then 23 | n_frames=100 24 | elif [[ ${model_type} == "TimeChat" ]]; then 25 | n_frames=96 26 | elif [[ ${model_type} == "VTimeLLM" ]]; then 27 | n_frames=100 28 | fi 29 | esac 30 | 31 | python run.py \ 32 | --model_type=${model_type} \ 33 | --dset_name=${dset_name} \ 34 | --video_root=${video_root} \ 35 | --n_frames=${n_frames} \ 36 | ${extra_args[@]} \ 37 | ${@:3} -------------------------------------------------------------------------------- /timechat/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | 11 | from omegaconf import OmegaConf 12 | 13 | from timechat.common.registry import registry 14 | 15 | from timechat.datasets.builders import * 16 | from timechat.models import * 17 | from timechat.processors import * 18 | from timechat.tasks import * 19 | 20 | 21 | root_dir = os.path.dirname(os.path.abspath(__file__)) 22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) 23 | 24 | registry.register_path("library_root", root_dir) 25 | repo_root = os.path.join(root_dir, "..") 26 | registry.register_path("repo_root", repo_root) 27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root) 28 | registry.register_path("cache_root", cache_root) 29 | 30 | registry.register("MAX_INT", sys.maxsize) 31 | registry.register("SPLIT_NAMES", ["train", "val", "test"]) 32 | -------------------------------------------------------------------------------- /timechat/datasets/builders/video_caption_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from timechat.common.registry import registry 6 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from timechat.datasets.datasets.webvid_datasets import WebvidDataset 8 | 9 | @registry.register_builder("webvid") 10 | class WebvidBuilder(BaseDatasetBuilder): 11 | train_dataset_cls = WebvidDataset 12 | DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"} 13 | 14 | def _download_ann(self): 15 | pass 16 | 17 | def _download_vis(self): 18 | pass 19 | 20 | def build(self): 21 | self.build_processors() 22 | datasets = dict() 23 | split = "train" 24 | 25 | build_info = self.config.build_info 26 | dataset_cls = self.train_dataset_cls 27 | datasets[split] = dataset_cls( 28 | vis_processor=self.vis_processors[split], 29 | text_processor=self.text_processors[split], 30 | vis_root=build_info.videos_dir, 31 | ann_root=build_info.anno_dir 32 | ) 33 | 34 | return datasets -------------------------------------------------------------------------------- /timechat/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from timechat.processors.base_processor import BaseProcessor 9 | from timechat.processors.blip_processors import ( 10 | Blip2ImageTrainProcessor, 11 | Blip2ImageEvalProcessor, 12 | BlipCaptionProcessor, 13 | ) 14 | from timechat.processors.video_processor import ( 15 | AlproVideoTrainProcessor, 16 | AlproVideoEvalProcessor 17 | ) 18 | from timechat.common.registry import registry 19 | 20 | __all__ = [ 21 | "BaseProcessor", 22 | "Blip2ImageTrainProcessor", 23 | "Blip2ImageEvalProcessor", 24 | "BlipCaptionProcessor", 25 | "AlproVideoTrainProcessor", 26 | "AlproVideoEvalProcessor", 27 | ] 28 | 29 | 30 | def load_processor(name, cfg=None): 31 | """ 32 | Example 33 | 34 | >>> processor = load_processor("alpro_video_train", cfg=None) 35 | """ 36 | processor = registry.get_processor_class(name).from_config(cfg) 37 | 38 | return processor 39 | -------------------------------------------------------------------------------- /timechat/datasets/datasets/laion_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import webdataset as wds 9 | from timechat.datasets.datasets.base_dataset import BaseDataset 10 | 11 | 12 | class LaionDataset(BaseDataset): 13 | def __init__(self, vis_processor, text_processor, location): 14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 15 | 16 | self.inner_dataset = wds.DataPipeline( 17 | wds.ResampledShards(location), 18 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 19 | wds.shuffle(1000, handler=wds.warn_and_continue), 20 | wds.decode("pilrgb", handler=wds.warn_and_continue), 21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 23 | wds.map(self.to_dict, handler=wds.warn_and_continue), 24 | ) 25 | 26 | def to_dict(self, sample): 27 | return { 28 | "image": sample[0], 29 | "text_input": self.text_processor(sample[1]["caption"]), 30 | } 31 | 32 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trouble shootings: 3 | $ pip install pydantic==1.10.8 4 | 5 | """ 6 | 7 | import torch 8 | import random 9 | import numpy as np 10 | from utils.cons_utils import BaseOptions 11 | from task.grounding import run_grounding 12 | from task.consistency import run_consistency 13 | 14 | eval_func = { 15 | "grounding": run_grounding, 16 | "consistency": run_consistency, 17 | } 18 | 19 | OPT_FILE_NAME = "opt.json" 20 | PREDICTION_FILE_NAME = "predictions.jsonl" 21 | EVALUATION_FILE_NAME = "eval_results.json" 22 | 23 | 24 | def set_seed(seed, use_cuda=True): 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | if use_cuda: 29 | torch.cuda.manual_seed_all(seed) 30 | 31 | print("Set seed ", seed) 32 | 33 | 34 | if __name__ == "__main__": 35 | """ 36 | Usage: python run.py --model_type TimeChat --dset_name charades --task grounding --debug 37 | """ 38 | base_options = BaseOptions().parse() 39 | set_seed(base_options.seed) 40 | 41 | if 
base_options.model_type == "TimeChat": 42 | from timechat.utils import TimeChat, TimeChat_Options 43 | args = TimeChat_Options().parse() 44 | model = TimeChat(args) 45 | args.ckpt = model.model_config["ckpt"] 46 | 47 | # TODO: Add Custom Models 48 | 49 | else: 50 | raise NotImplementedError 51 | 52 | eval_func[args.task](model=model, args=args) 53 | 54 | -------------------------------------------------------------------------------- /timechat/eval_configs/timechat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: timechat 3 | model_type: pretrain_llama_v2 4 | freeze_vit: True 5 | freeze_qformer: True 6 | max_txt_len: 2048 7 | end_sym: "</s>" 8 | low_resource: False 9 | 10 | frozen_llama_proj: True 11 | frozen_video_Qformer: True 12 | 13 | vit_model: "ckpt/timechat/eva_vit_g.pth" 14 | llama_model: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/" 15 | q_former_model: "ckpt/timechat/instruct_blip_vicuna7b_trimmed.pth" 16 | ckpt: "ckpt/timechat/timechat_7b_paper.pth" 17 | charades_ckpt: "ckpt/timechat/TimeChat-7B-Charades-VTune.pth" 18 | activitynet_ckpt: "ckpt/timechat/TimeChat-7B-ActivityNet-VTune.pth" 19 | fusion_head_layers: 2 20 | max_frame_pos: 96 21 | fusion_header_type: "seqTransf" 22 | 23 | use_grad_checkpoint: True 24 | lora: True 25 | lora_inference_mode: True 26 | qformer_text_input: True 27 | window_size: 32 28 | stride: 32 29 | 30 | datasets: 31 | webvid: 32 | vis_processor: 33 | train: 34 | name: "alpro_video_eval" 35 | n_frms: 96 36 | image_size: 224 37 | text_processor: 38 | train: 39 | name: "blip_caption" 40 | num_video_query_token: 32 41 | tokenizer_name: "ckpt/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/" 42 | model_type: "llama_v2" 43 | num_frm: 96 44 | sample_type: 'uniform' 45 | max_txt_len: 2048 46 | stride: 32 47 | 48 | run: 49 | task: video_text_pretrain 50 | -------------------------------------------------------------------------------- /timechat/datasets/datasets/cc_sbu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import webdataset as wds 4 | from timechat.datasets.datasets.base_dataset import BaseDataset 5 | from timechat.datasets.datasets.caption_datasets import CaptionDataset 6 | 7 | 8 | class CCSBUDataset(BaseDataset): 9 | def __init__(self, vis_processor, text_processor, location): 10 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 11 | 12 | self.inner_dataset = wds.DataPipeline( 13 | wds.ResampledShards(location), 14 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 15 | wds.shuffle(1000, handler=wds.warn_and_continue), 16 | wds.decode("pilrgb", handler=wds.warn_and_continue), 17 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 18 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 19 | wds.map(self.to_dict, handler=wds.warn_and_continue), 20 | ) 21 | 22 | def to_dict(self, sample): 23 | return { 24 | "image": sample[0], 25 | "text_input": self.text_processor(sample[1]["caption"]), 26 | "type":'image', 27 | } 28 | 29 | 30 | class CCSBUAlignDataset(CaptionDataset): 31 | 32 | def __getitem__(self, index): 33 | 34 | # TODO this assumes image input, not general enough 35 | ann = self.annotation[index] 36 | 37 | img_file = '{}.jpg'.format(ann["image_id"]) 38 | image_path = os.path.join(self.vis_root, img_file) 39 | image = Image.open(image_path).convert("RGB") 40 | 41 | image = self.vis_processor(image) 42 | caption = ann["caption"]
43 | 44 | return { 45 | "image": image, 46 | "text_input": caption, 47 | "image_id": self.img_ids[ann["image_id"]], 48 | "type":'image', 49 | } -------------------------------------------------------------------------------- /timechat/train_configs/stage2_finetune_charades.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: timechat 3 | model_type: pretrain_llama_v2 4 | freeze_vit: True 5 | freeze_qformer: False 6 | 7 | # Q-Former 8 | num_query_token: 32 9 | 10 | vit_model: "ckpt/timechat/eva_vit_g.pth" 11 | llama_model: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/" 12 | q_former_model: "ckpt/timechat/instruct_blip_vicuna7b_trimmed.pth" 13 | ckpt: "ckpt/timechat/timechat_7b.pth" # continue fine-tuning from TimeChat-7B ckpt 14 | 15 | # only train vision branch 16 | frozen_llama_proj: False 17 | frozen_video_Qformer: False 18 | 19 | fusion_head_layers: 2 20 | max_frame_pos: 96 21 | fusion_header_type: "seqTransf" 22 | 23 | max_txt_len: 2048 24 | 25 | # for llama_2_chat: 26 | end_sym: "</s>" 27 | prompt_path: "" 28 | prompt_template: '[INST] <<SYS>>\n \n<</SYS>>\n\n{} [/INST] ' 29 | 30 | use_grad_checkpoint: True 31 | lora: True 32 | lora_inference_mode: False 33 | qformer_text_input: True 34 | window_size: 32 35 | stride: 32 36 | 37 | datasets: 38 | charades_instruct: 39 | data_type: video 40 | build_info: 41 | anno_dir: "data/charades_for_VTune.json" 42 | videos_dir: "data/" 43 | vis_processor: 44 | train: 45 | name: "alpro_video_train" 46 | n_frms: 96 47 | image_size: 224 48 | text_processor: 49 | train: 50 | name: "blip_caption" 51 | num_video_query_token: 32 52 | tokenizer_name: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/" 53 | model_type: "llama_v2" 54 | num_frm: 96 55 | sample_type: 'rand' 56 | max_txt_len: 2048 57 | stride: 32 58 | 59 | run: 60 | task: video_text_pretrain 61 | # optimizer 62 | lr_sched: "linear_warmup_cosine_lr" 63 | init_lr: 3e-5 64 | min_lr: 1e-5 65 | warmup_lr: 1e-6 66 | 67 | weight_decay: 0.05 68 | max_epoch: 3 69 | iters_per_epoch: 24811 70 | batch_size_train: 1 71 | batch_size_eval: 4 72 | num_workers: 4 73 | warmup_steps: 14916 74 | accum_grad_iters: 8 75 | 76 | seed: 42 77 | output_dir: "ckpt/timechat/train_stage2_charades" 78 | 79 | amp: True 80 | resume_ckpt_path: null 81 | 82 | evaluate: False 83 | train_splits: ["train"] 84 | 85 | device: "cuda" 86 | world_size: 1 87 | dist_url: "env://" 88 | distributed: True -------------------------------------------------------------------------------- /timechat/datasets/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class BaseDataset(Dataset): 16 | def __init__( 17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] 18 | ): 19 | """ 20 | vis_root (string): Root directory of images (e.g.
coco/images/) 21 | ann_root (string): directory to store the annotation file 22 | """ 23 | self.vis_root = vis_root 24 | 25 | self.annotation = [] 26 | for ann_path in ann_paths: 27 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) 28 | 29 | self.vis_processor = vis_processor 30 | self.text_processor = text_processor 31 | 32 | self._add_instance_ids() 33 | 34 | def __len__(self): 35 | return len(self.annotation) 36 | 37 | def collater(self, samples): 38 | return default_collate(samples) 39 | 40 | def set_processors(self, vis_processor, text_processor): 41 | self.vis_processor = vis_processor 42 | self.text_processor = text_processor 43 | 44 | def _add_instance_ids(self, key="instance_id"): 45 | for idx, ann in enumerate(self.annotation): 46 | ann[key] = str(idx) 47 | 48 | 49 | class ConcatDataset(ConcatDataset): 50 | def __init__(self, datasets: Iterable[Dataset]) -> None: 51 | super().__init__(datasets) 52 | 53 | def collater(self, samples): 54 | # TODO For now only supports datasets with same underlying collater implementations 55 | 56 | all_keys = set() 57 | for s in samples: 58 | all_keys.update(s) 59 | 60 | shared_keys = all_keys 61 | for s in samples: 62 | shared_keys = shared_keys & set(s.keys()) 63 | 64 | samples_shared_keys = [] 65 | for s in samples: 66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 67 | 68 | return self.datasets[0].collater(samples_shared_keys) 69 | -------------------------------------------------------------------------------- /timechat/train_configs/stage2_finetune_activitynet.yaml: -------------------------------------------------------------------------------- 1 | # Usage: torchrun --nproc_per_node=4 train.py --cfg-path train_configs/stage2_finetune_activitynet.yaml 2 | model: 3 | arch: timechat 4 | model_type: pretrain_llama_v2 5 | freeze_vit: True 6 | freeze_qformer: False 7 | 8 | 9 | # Q-Former 10 | num_query_token: 32 11 | 12 | vit_model: "ckpt/timechat/eva_vit_g.pth" 13 | llama_model: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/" 14 | q_former_model: "ckpt/timechat/instruct_blip_vicuna7b_trimmed.pth" 15 | ckpt: "ckpt/timechat_7b.pth" # continue fine-tuning from TimeChat-7B ckpt 16 | 17 | # only train vision branch 18 | frozen_llama_proj: False 19 | frozen_video_Qformer: False 20 | 21 | fusion_head_layers: 2 22 | max_frame_pos: 96 23 | fusion_header_type: "seqTransf" 24 | 25 | max_txt_len: 2048 26 | 27 | # for llama_2_chat: 28 | end_sym: "</s>" 29 | prompt_path: "" 30 | prompt_template: '[INST] <<SYS>>\n \n<</SYS>>\n\n{} [/INST] ' 31 | 32 | use_grad_checkpoint: True 33 | lora: True 34 | lora_inference_mode: False 35 | qformer_text_input: True 36 | window_size: 32 37 | stride: 32 38 | 39 | datasets: 40 | charades_instruct: 41 | data_type: video 42 | build_info: 43 | anno_dir: "data/activitynet_for_VTune.json" 44 | videos_dir: "data/" 45 | vis_processor: 46 | train: 47 | name: "alpro_video_train" 48 | n_frms: 96 49 | image_size: 224 50 | text_processor: 51 | train: 52 | name: "blip_caption" 53 | num_video_query_token: 32 54 | tokenizer_name: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/" 55 | model_type: "llama_v2" 56 | num_frm: 96 57 | sample_type: 'rand' 58 | max_txt_len: 2048 59 | stride: 32 60 | 61 | run: 62 | task: video_text_pretrain 63 | # optimizer 64 | lr_sched: "linear_warmup_cosine_lr" 65 | init_lr: 3e-5 66 | min_lr: 1e-5 67 | warmup_lr: 1e-6 68 | 69 | weight_decay: 0.05 70 | max_epoch: 3 71 | iters_per_epoch: 51377 72 | batch_size_train: 1 73 | batch_size_eval: 4 74 |
num_workers: 4 75 | warmup_steps: 25688 76 | accum_grad_iters: 4 77 | 78 | seed: 42 79 | output_dir: "ckpt/timechat/activitynet_vtune" 80 | 81 | amp: True 82 | resume_ckpt_path: null 83 | 84 | evaluate: False 85 | train_splits: ["train"] 86 | 87 | device: "cuda" 88 | world_size: 1 89 | dist_url: "env://" 90 | distributed: True 91 | -------------------------------------------------------------------------------- /timechat/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from timechat.datasets.builders.base_dataset_builder import load_dataset_config 9 | from timechat.datasets.builders.image_text_pair_builder import ( 10 | CCSBUBuilder, 11 | LaionBuilder, 12 | CCSBUAlignBuilder 13 | ) 14 | from timechat.datasets.builders.video_caption_builder import WebvidBuilder 15 | from timechat.common.registry import registry 16 | from timechat.datasets.builders.instruct_builder import WebvidInstruct_Builder, LlavaInstruct_Builder, \ 17 | Youcook2Instruct_Builder, TimeInstruct_Builder, Valley72kInstruct_Builder, QVhighlightsInstruct_Builder, \ 18 | CharadesInstruct_Builder 19 | __all__ = [ 20 | "CCSBUBuilder", 21 | "LaionBuilder", 22 | "CCSBUAlignBuilder", 23 | "WebvidBuilder", 24 | "LlavaInstruct_Builder", 25 | "WebvidInstruct_Builder", 26 | "Youcook2Instruct_Builder", 27 | "TimeInstruct_Builder", 28 | "Valley72kInstruct_Builder", 29 | "QVhighlightsInstruct_Builder", 30 | "CharadesInstruct_Builder", 31 | ] 32 | 33 | 34 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 35 | """ 36 | Example 37 | 38 | >>> dataset = load_dataset("coco_caption", cfg=None) 39 | >>> splits = dataset.keys() 40 | >>> print([len(dataset[split]) for split in splits]) 41 | 42 | """ 43 | if cfg_path is None: 44 | cfg = None 45 | else: 46 | cfg = load_dataset_config(cfg_path) 47 | 48 | try: 49 | builder = registry.get_builder_class(name)(cfg) 50 | except TypeError: 51 | print( 52 | f"Dataset {name} not found. Available datasets:\n" 53 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 54 | ) 55 | exit(1) 56 | 57 | if vis_path is not None: 58 | if data_type is None: 59 | # use default data type in the config 60 | data_type = builder.config.data_type 61 | 62 | assert ( 63 | data_type in builder.config.build_info 64 | ), f"Invalid data_type {data_type} for {name}." 65 | 66 | builder.config.build_info.get(data_type).storage = vis_path 67 | 68 | dataset = builder.build_datasets() 69 | return dataset 70 | 71 | 72 | class DatasetZoo: 73 | def __init__(self) -> None: 74 | self.dataset_zoo = { 75 | k: list(v.DATASET_CONFIG_DICT.keys()) 76 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 77 | } 78 | 79 | def get_names(self): 80 | return list(self.dataset_zoo.keys()) 81 | 82 | 83 | dataset_zoo = DatasetZoo() 84 | -------------------------------------------------------------------------------- /timechat/datasets/datasets/caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from collections import OrderedDict 10 | 11 | from timechat.datasets.datasets.base_dataset import BaseDataset 12 | from PIL import Image 13 | 14 | 15 | class __DisplMixin: 16 | def displ_item(self, index): 17 | sample, ann = self.__getitem__(index), self.annotation[index] 18 | 19 | return OrderedDict( 20 | { 21 | "file": ann["image"], 22 | "caption": ann["caption"], 23 | "image": sample["image"], 24 | } 25 | ) 26 | 27 | 28 | class CaptionDataset(BaseDataset, __DisplMixin): 29 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 30 | """ 31 | vis_root (string): Root directory of images (e.g. coco/images/) 32 | ann_root (string): directory to store the annotation file 33 | """ 34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 35 | 36 | self.img_ids = {} 37 | n = 0 38 | for ann in self.annotation: 39 | img_id = ann["image_id"] 40 | if img_id not in self.img_ids.keys(): 41 | self.img_ids[img_id] = n 42 | n += 1 43 | 44 | def __getitem__(self, index): 45 | 46 | # TODO this assumes image input, not general enough 47 | ann = self.annotation[index] 48 | 49 | img_file = '{:0>12}.jpg'.format(ann["image_id"]) 50 | image_path = os.path.join(self.vis_root, img_file) 51 | image = Image.open(image_path).convert("RGB") 52 | 53 | image = self.vis_processor(image) 54 | caption = self.text_processor(ann["caption"]) 55 | 56 | return { 57 | "image": image, 58 | "text_input": caption, 59 | "image_id": self.img_ids[ann["image_id"]], 60 | } 61 | 62 | 63 | class CaptionEvalDataset(BaseDataset, __DisplMixin): 64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 65 | """ 66 | vis_root (string): Root directory of images (e.g. 
coco/images/) 67 | ann_root (string): directory to store the annotation file 68 | split (string): val or test 69 | """ 70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.vis_root, ann["image"]) 77 | image = Image.open(image_path).convert("RGB") 78 | 79 | image = self.vis_processor(image) 80 | 81 | return { 82 | "image": image, 83 | "image_id": ann["image_id"], 84 | "instance_id": ann["instance_id"], 85 | } 86 | -------------------------------------------------------------------------------- /utils/timestamp_extraction.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def extract_time(paragraph): 5 | prompt = 'A specific example is : 20.8 - 30.0 seconds'.lower() 6 | paragraph = paragraph.lower() 7 | paragraph = paragraph.replace(prompt, '') 8 | # Split text into sentences based on common delimiters 9 | sentences = re.split(r'[!?\n]', paragraph) 10 | 11 | # Keywords that might indicate the presence of time information 12 | keywords = ["starts", "ends", "happens in", "start time", "end time", "start", "end", "happen"] 13 | # filter sentences by keywords 14 | candidates = [] 15 | for sentence in sentences: 16 | # If sentence contains one of the keywords 17 | if any(keyword in sentence for keyword in keywords): 18 | candidates.append(sentence) 19 | 20 | timestamps = [] 21 | # Check for The given query happens in m - n (seconds) 22 | patterns = [ 23 | r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)" 24 | ] 25 | 26 | for time_pattern in patterns: 27 | time_matches = re.findall(time_pattern, paragraph) 28 | if time_matches: 29 | timestamps = [[float(start), float(end)] for start, end in time_matches] 30 | 31 | if len(sentences) == 0: 32 | return [] 33 | # check for other formats e.g.: 34 | # 1 .Starting time: 0.8 seconds 35 | # Ending time: 1.1 seconds 36 | # 2. The start time for this event is 0 seconds, and the end time is 12 seconds. 37 | if len(timestamps) == 0: 38 | times = [] 39 | time_regex = re.compile(r'\b(\d+\.\d+\b|\b\d+)\b') # time formats (e.g., 18, 18.5) 40 | for sentence in candidates: 41 | time = re.findall(time_regex, sentence) 42 | if time: 43 | time_in_sec = float(time[0]) 44 | times.append(time_in_sec) 45 | times = times[:len(times) // 2 * 2] 46 | timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)] 47 | # Check for examples like: 48 | # 3. The event 'person flipped the light switch near the door' starts at 00:00:18 and ends at 00:00:23.
49 | if len(timestamps) == 0: 50 | times = [] 51 | time_regex = re.compile(r'\b(\d{1,2}:\d{2}(?::\d{2})?)\b') # time formats (e.g., 18:00, 00:18:05) 52 | for sentence in candidates: 53 | time = re.findall(time_regex, sentence) 54 | if time: 55 | t = time[0] 56 | else: 57 | continue 58 | # If time is in HH:MM:SS format, convert to seconds 59 | if t.count(':') == 2: 60 | h, m, s = map(int, t.split(':')) 61 | time_in_sec = h * 3600 + m * 60 + s 62 | elif t.count(':') == 1: 63 | m, s = map(int, t.split(':')) 64 | time_in_sec = m * 60 + s 65 | times.append(time_in_sec) 66 | times = times[:len(times) // 2 * 2] 67 | timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)] 68 | 69 | results = [] 70 | for (start, end) in timestamps: 71 | if end > start: 72 | results.append([start, end]) 73 | else: 74 | results.append([end, start]) 75 | 76 | if len(results) == 0: 77 | return [0, 0] 78 | else: 79 | return results[0] -------------------------------------------------------------------------------- /timechat/datasets/builders/image_text_pair_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from timechat.common.registry import registry 6 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from timechat.datasets.datasets.laion_dataset import LaionDataset 8 | from timechat.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset 9 | 10 | 11 | @registry.register_builder("cc_sbu") 12 | class CCSBUBuilder(BaseDatasetBuilder): 13 | train_dataset_cls = CCSBUDataset 14 | 15 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"} 16 | 17 | def _download_ann(self): 18 | pass 19 | 20 | def _download_vis(self): 21 | pass 22 | 23 | def build(self): 24 | self.build_processors() 25 | 26 | build_info = self.config.build_info 27 | 28 | datasets = dict() 29 | split = "train" 30 | 31 | # create datasets 32 | # [NOTE] return inner_datasets (wds.DataPipeline) 33 | dataset_cls = self.train_dataset_cls 34 | datasets[split] = dataset_cls( 35 | vis_processor=self.vis_processors[split], 36 | text_processor=self.text_processors[split], 37 | location=build_info.storage, 38 | ).inner_dataset 39 | 40 | return datasets 41 | 42 | 43 | @registry.register_builder("laion") 44 | class LaionBuilder(BaseDatasetBuilder): 45 | train_dataset_cls = LaionDataset 46 | 47 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} 48 | 49 | def _download_ann(self): 50 | pass 51 | 52 | def _download_vis(self): 53 | pass 54 | 55 | def build(self): 56 | self.build_processors() 57 | 58 | build_info = self.config.build_info 59 | 60 | datasets = dict() 61 | split = "train" 62 | 63 | # create datasets 64 | # [NOTE] return inner_datasets (wds.DataPipeline) 65 | dataset_cls = self.train_dataset_cls 66 | datasets[split] = dataset_cls( 67 | vis_processor=self.vis_processors[split], 68 | text_processor=self.text_processors[split], 69 | location=build_info.storage, 70 | ).inner_dataset 71 | 72 | return datasets 73 | 74 | 75 | @registry.register_builder("cc_sbu_align") 76 | class CCSBUAlignBuilder(BaseDatasetBuilder): 77 | train_dataset_cls = CCSBUAlignDataset 78 | 79 | DATASET_CONFIG_DICT = { 80 | "default": "configs/datasets/cc_sbu/align.yaml", 81 | } 82 | 83 | def build_datasets(self): 84 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
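        # Note: if the storage path is missing, the check below only emits a warning;
        # the dataset is still constructed, and any failure surfaces later when
        # filter_cap.json or the images under image/ are actually opened.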
85 | logging.info("Building datasets...") 86 | self.build_processors() 87 | 88 | build_info = self.config.build_info 89 | storage_path = build_info.storage 90 | 91 | datasets = dict() 92 | 93 | if not os.path.exists(storage_path): 94 | warnings.warn("storage path {} does not exist.".format(storage_path)) 95 | 96 | # create datasets 97 | dataset_cls = self.train_dataset_cls 98 | datasets['train'] = dataset_cls( 99 | vis_processor=self.vis_processors["train"], 100 | text_processor=self.text_processors["train"], 101 | ann_paths=[os.path.join(storage_path, 'filter_cap.json')], 102 | vis_root=os.path.join(storage_path, 'image'), 103 | ) 104 | 105 | return datasets 106 | 107 | -------------------------------------------------------------------------------- /timechat/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from timechat.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | **kwargs 25 | ): 26 | self.optimizer = optimizer 27 | 28 | self.max_epoch = max_epoch 29 | self.min_lr = min_lr 30 | 31 | self.decay_rate = decay_rate 32 | 33 | self.init_lr = init_lr 34 | self.warmup_steps = warmup_steps 35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 36 | 37 | def step(self, cur_epoch, cur_step): 38 | if cur_epoch == 0: 39 | warmup_lr_schedule( 40 | step=cur_step, 41 | optimizer=self.optimizer, 42 | max_step=self.warmup_steps, 43 | init_lr=self.warmup_start_lr, 44 | max_lr=self.init_lr, 45 | ) 46 | else: 47 | step_lr_schedule( 48 | epoch=cur_epoch, 49 | optimizer=self.optimizer, 50 | init_lr=self.init_lr, 51 | min_lr=self.min_lr, 52 | decay_rate=self.decay_rate, 53 | ) 54 | 55 | 56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 57 | class LinearWarmupCosineLRScheduler: 58 | def __init__( 59 | self, 60 | optimizer, 61 | max_epoch, 62 | iters_per_epoch, 63 | min_lr, 64 | init_lr, 65 | warmup_steps=0, 66 | warmup_start_lr=-1, 67 | **kwargs 68 | ): 69 | self.optimizer = optimizer 70 | 71 | self.max_epoch = max_epoch 72 | self.iters_per_epoch = iters_per_epoch 73 | self.min_lr = min_lr 74 | 75 | self.init_lr = init_lr 76 | self.warmup_steps = warmup_steps 77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 78 | 79 | def step(self, cur_epoch, cur_step): 80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 81 | if total_cur_step < self.warmup_steps: 82 | warmup_lr_schedule( 83 | step=cur_step, 84 | optimizer=self.optimizer, 85 | max_step=self.warmup_steps, 86 | init_lr=self.warmup_start_lr, 87 | max_lr=self.init_lr, 88 | ) 89 | else: 90 | cosine_lr_schedule( 91 | epoch=total_cur_step, 92 | optimizer=self.optimizer, 93 | max_epoch=self.max_epoch * self.iters_per_epoch, 94 | init_lr=self.init_lr, 95 | min_lr=self.min_lr, 96 | ) 97 | 98 | 99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 100 | """Decay the learning rate""" 101 | lr = (init_lr - min_lr) * 0.5 * ( 102 | 1.0 + math.cos(math.pi * epoch / max_epoch) 103 | ) + min_lr 104 | 
for param_group in optimizer.param_groups: 105 | param_group["lr"] = lr 106 | 107 | 108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 109 | """Warmup the learning rate""" 110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 111 | for param_group in optimizer.param_groups: 112 | param_group["lr"] = lr 113 | 114 | 115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 116 | """Decay the learning rate""" 117 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 118 | for param_group in optimizer.param_groups: 119 | param_group["lr"] = lr 120 | -------------------------------------------------------------------------------- /task/grounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from easydict import EasyDict as edict 4 | import time 5 | import os 6 | import random 7 | import numpy as np 8 | from utils.cons_utils import (save_jsonl, save_json, load_json, get_iou, 9 | BaseOptions, load_logger, display) 10 | from eval.eval import evaluate_grounding 11 | logger = load_logger("[Grounding Evaluation]") 12 | 13 | OPT_FILE_NAME = "opt.json" 14 | PREDICTION_FILE_NAME = "predictions.jsonl" 15 | EVALUATION_FILE_NAME = "eval_results.json" 16 | 17 | 18 | def set_seed(seed, use_cuda=True): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | if use_cuda: 23 | torch.cuda.manual_seed_all(seed) 24 | 25 | print("Set seed ", seed) 26 | 27 | 28 | def main(args, model): 29 | results = [] 30 | 31 | test_data = load_json(args.test_path) 32 | target_vid_list = [file.split(".")[0] for file in os.listdir(args.video_root)] 33 | target_vid_list = [vid for vid in target_vid_list if vid in list(test_data.keys())] 34 | print(f"Total {len(target_vid_list)} videos in {args.video_root}") 35 | 36 | path_to_predictions = f"{args.task}_{PREDICTION_FILE_NAME}" 37 | path_to_eval_results = f"{args.task}_{EVALUATION_FILE_NAME}" 38 | 39 | for n_data, (vid, data) in tqdm(enumerate(test_data.items()), total=len(target_vid_list), desc="Evaluating.."): 40 | duration = data['duration'] 41 | video_path = os.path.join(args.video_root, f"{vid}.mp4") 42 | 43 | # Load video frame features 44 | if os.path.exists(video_path): 45 | video_features, msg = model.load_video_features(video_path) 46 | else: 47 | print(f"Video {vid} not found") 48 | continue 49 | 50 | for i, (query, gt_moment) in enumerate(zip(data['sentences'], data['timestamps'])): 51 | gt_moment = [min(gt_moment[0], duration), min(gt_moment[1], duration)] 52 | pred_moment = model.run(task="grounding", video_features=video_features, query=query, 53 | duration=duration, msg=msg) 54 | 55 | # Save the results. 
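            # Each saved record pairs the query metadata (vid, sentence, GT timestamp, duration)
            # with the model output: `qa` keeps the raw prediction dict (its "t" field holds the
            # predicted [start, end] span) and `iou` is its overlap with the ground-truth moment.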
56 | result = edict( 57 | meta=edict( 58 | vid=vid, 59 | sentence=data['sentences'][i], 60 | timestamp=data['timestamps'][i], 61 | duration=data['duration'] 62 | ), 63 | prediction=edict( 64 | qa=pred_moment, 65 | iou=get_iou(gt_moment, pred_moment["t"]), 66 | ), 67 | ) 68 | results.append(result) 69 | 70 | if args.debug and n_data == 1: 71 | break 72 | 73 | if n_data % 50 == 0: 74 | logger.info(f"{len(results)} results are saved") 75 | save_jsonl(results, os.path.join(args.output_dir, f"{path_to_predictions}")) 76 | 77 | logger.info(f"{len(results)} predictions will be saved at {path_to_predictions}") 78 | save_jsonl(results, os.path.join(args.output_dir, path_to_predictions)) 79 | 80 | logger.info("============ Save Performance ============") 81 | grounding_results = evaluate_grounding(results, verbos=True) 82 | save_json(grounding_results, os.path.join(args.output_dir, path_to_eval_results), save_pretty=True) 83 | logger.info("Done.") 84 | 85 | 86 | def run_grounding(model, args): 87 | logger.info("Measuring Grounding") 88 | cur_time = time.strftime("%Y_%m_%d_%H_%M_%S") 89 | if args.exp_id is None: 90 | args.exp_id = args.model_type 91 | 92 | if args.debug: 93 | output_dir = f"debug_results/{args.model_type}/{args.exp_id}_{args.dset_name}_{cur_time}" 94 | else: 95 | if args.fine_tuned: 96 | output_dir = f"grounding_results/{args.model_type}/fine_tuned_{args.exp_id}_{args.dset_name}_{cur_time}" 97 | else: 98 | output_dir = f"grounding_results/{args.model_type}/{args.exp_id}_{args.dset_name}_{cur_time}" 99 | 100 | os.makedirs(output_dir, exist_ok=True) 101 | args.output_dir = output_dir 102 | args.test_path = f"data/{args.dset_name}_test.json" 103 | 104 | # display and save results 105 | display(args) 106 | save_json(vars(args), os.path.join(args.output_dir, OPT_FILE_NAME), save_pretty=True) 107 | 108 | main(args, model) 109 | 110 | -------------------------------------------------------------------------------- /timechat/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | torch.distributed.init_process_group( 81 | backend=args.dist_backend, 82 | init_method=args.dist_url, 83 | world_size=args.world_size, 84 | rank=args.rank, 85 | timeout=datetime.timedelta( 86 | days=365 87 | ), # allow auto-downloading and de-compressing 88 | ) 89 | torch.distributed.barrier() 90 | setup_for_distributed(args.rank == 0) 91 | 92 | 93 | def get_dist_info(): 94 | if torch.__version__ < "1.0": 95 | initialized = dist._initialized 96 | else: 97 | initialized = dist.is_initialized() 98 | if initialized: 99 | rank = dist.get_rank() 100 | world_size = dist.get_world_size() 101 | else: # non-distributed training 102 | rank = 0 103 | world_size = 1 104 | return rank, world_size 105 | 106 | 107 | def main_process(func): 108 | @functools.wraps(func) 109 | def wrapper(*args, **kwargs): 110 | rank, _ = get_dist_info() 111 | if rank == 0: 112 | return func(*args, **kwargs) 113 | 114 | return wrapper 115 | 116 | 117 | def download_cached_file(url, check_hash=True, progress=False): 118 | """ 119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 
121 | """ 122 | 123 | def get_cached_file_path(): 124 | # a hack to sync the file path across processes 125 | parts = torch.hub.urlparse(url) 126 | filename = os.path.basename(parts.path) 127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 128 | 129 | return cached_file 130 | 131 | if is_main_process(): 132 | timm_hub.download_cached_file(url, check_hash, progress) 133 | 134 | if is_dist_avail_and_initialized(): 135 | dist.barrier() 136 | 137 | return get_cached_file_path() 138 | -------------------------------------------------------------------------------- /timechat/processors/functional_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import warnings 9 | 10 | import torch 11 | 12 | 13 | def _is_tensor_video_clip(clip): 14 | if not torch.is_tensor(clip): 15 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 16 | 17 | if not clip.ndimension() == 4: 18 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 19 | 20 | return True 21 | 22 | 23 | def crop(clip, i, j, h, w): 24 | """ 25 | Args: 26 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 27 | """ 28 | if len(clip.size()) != 4: 29 | raise ValueError("clip should be a 4D tensor") 30 | return clip[..., i : i + h, j : j + w] 31 | 32 | 33 | def resize(clip, target_size, interpolation_mode): 34 | if len(target_size) != 2: 35 | raise ValueError( 36 | f"target size should be tuple (height, width), instead got {target_size}" 37 | ) 38 | return torch.nn.functional.interpolate( 39 | clip, size=target_size, mode=interpolation_mode, align_corners=False 40 | ) 41 | 42 | 43 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 44 | """ 45 | Do spatial cropping and resizing to the video clip 46 | Args: 47 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 48 | i (int): i in (i,j) i.e coordinates of the upper left corner. 49 | j (int): j in (i,j) i.e coordinates of the upper left corner. 50 | h (int): Height of the cropped region. 51 | w (int): Width of the cropped region. 52 | size (tuple(int, int)): height and width of resized clip 53 | Returns: 54 | clip (torch.tensor): Resized and cropped clip. 
Size is (C, T, H, W) 55 | """ 56 | if not _is_tensor_video_clip(clip): 57 | raise ValueError("clip should be a 4D torch.tensor") 58 | clip = crop(clip, i, j, h, w) 59 | clip = resize(clip, size, interpolation_mode) 60 | return clip 61 | 62 | 63 | def center_crop(clip, crop_size): 64 | if not _is_tensor_video_clip(clip): 65 | raise ValueError("clip should be a 4D torch.tensor") 66 | h, w = clip.size(-2), clip.size(-1) 67 | th, tw = crop_size 68 | if h < th or w < tw: 69 | raise ValueError("height and width must be no smaller than crop_size") 70 | 71 | i = int(round((h - th) / 2.0)) 72 | j = int(round((w - tw) / 2.0)) 73 | return crop(clip, i, j, th, tw) 74 | 75 | 76 | def to_tensor(clip): 77 | """ 78 | Convert tensor data type from uint8 to float, divide value by 255.0 and 79 | permute the dimensions of clip tensor 80 | Args: 81 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 82 | Return: 83 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 84 | """ 85 | _is_tensor_video_clip(clip) 86 | if not clip.dtype == torch.uint8: 87 | raise TypeError( 88 | "clip tensor should have data type uint8. Got %s" % str(clip.dtype) 89 | ) 90 | return clip.float().permute(3, 0, 1, 2) / 255.0 91 | 92 | 93 | def normalize(clip, mean, std, inplace=False): 94 | """ 95 | Args: 96 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) 97 | mean (tuple): pixel RGB mean. Size is (3) 98 | std (tuple): pixel standard deviation. Size is (3) 99 | Returns: 100 | normalized clip (torch.tensor): Size is (C, T, H, W) 101 | """ 102 | if not _is_tensor_video_clip(clip): 103 | raise ValueError("clip should be a 4D torch.tensor") 104 | if not inplace: 105 | clip = clip.clone() 106 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 107 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 108 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 109 | return clip 110 | 111 | 112 | def hflip(clip): 113 | """ 114 | Args: 115 | clip (torch.tensor): Video clip to be normalized. 
Size is (C, T, H, W) 116 | Returns: 117 | flipped clip (torch.tensor): Size is (C, T, H, W) 118 | """ 119 | if not _is_tensor_video_clip(clip): 120 | raise ValueError("clip should be a 4D torch.tensor") 121 | return clip.flip(-1) 122 | -------------------------------------------------------------------------------- /utils/shift_video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | from utils.cons_utils import load_json 6 | import copy 7 | 8 | 9 | def save_video_with_fps(frames, output_path, frame_size, original_fps, target_fps=None): 10 | if target_fps is None: 11 | target_fps = original_fps 12 | """Save frames into a video file at the specified target frame rate (1 FPS).""" 13 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Change to XVID for better compatibility 14 | out = cv2.VideoWriter(output_path, fourcc, target_fps, frame_size) 15 | 16 | # Calculate the step to downsample frames from original FPS to target FPS 17 | frame_step = max(1, int(round(original_fps / target_fps))) # Safeguard against rounding issues 18 | 19 | # Write only frames sampled at the correct step interval 20 | for i in range(0, len(frames), frame_step): 21 | out.write(np.uint8(frames[i])) # Ensure frames are of type np.uint8 22 | 23 | out.release() 24 | 25 | 26 | def swap_frames(video_path, timestamps, shifted_timestamps, output_dir, vid): 27 | """Swap frames between original and shifted timestamps and save both original and swapped videos at 1 FPS.""" 28 | 29 | # Open the video 30 | cap = cv2.VideoCapture(video_path) 31 | fps = cap.get(cv2.CAP_PROP_FPS) 32 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 33 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 34 | frame_size = (frame_width, frame_height) 35 | 36 | # Buffer to store frames 37 | original_frames = [] 38 | while cap.isOpened(): 39 | ret, frame = cap.read() 40 | if not ret: 41 | break 42 | 43 | # Store the frame in the buffer 44 | original_frames.append(frame) 45 | 46 | # # Save the original video (before swapping) with 1 FPS 47 | # original_output_path = os.path.join(output_dir, f"{vid}.mp4") 48 | # save_video_with_fps(original_frames, original_output_path, frame_size, original_fps=fps) 49 | 50 | # Swap frames as per timestamps 51 | for i, (timestamp, shifted_timestamp) in enumerate(zip(timestamps, shifted_timestamps)): 52 | swapped_output_path = os.path.join(output_dir, f"{vid}_{i}.mp4") 53 | frames = copy.deepcopy(original_frames) 54 | start1, end1 = [int(timestamp[0] * fps), int(timestamp[1] * fps)] 55 | start2, end2 = [int(shifted_timestamp[0] * fps), int(shifted_timestamp[1] * fps)] 56 | 57 | # Ensure the ranges have the same length by extending or truncating 58 | length1 = end1 - start1 59 | length2 = end2 - start2 60 | if length1 > length2: 61 | # Extend second range by repeating the last frame 62 | extra_frames = [frames[end2 - 1]] * (length1 - length2) 63 | frames2 = np.concatenate([frames[start2:end2], extra_frames], axis=0) 64 | elif length2 > length1: 65 | # Truncate second range to match the first range's length 66 | frames2 = frames[start2:start2 + length1] 67 | else: 68 | frames2 = frames[start2:end2] 69 | 70 | # Swap frames between the two ranges 71 | temp = np.copy(frames[start1:end1]) # Copy frames from first range 72 | frames[start1:end1] = frames2 # Place frames from second range to first 73 | frames[start2:start2 + len(temp)] = temp # Place copied frames into second range 74 | 75 | # Save the swapped video 
with 1 FPS 76 | save_video_with_fps(frames, swapped_output_path, frame_size, original_fps=fps) 77 | cap.release() # Release outside the loop 78 | 79 | 80 | if __name__ == "__main__": 81 | # Load and process each video from annotations 82 | dset_name = "activitynet" # or charades 83 | video_root_path = "" # set the right path to your video files 84 | annotations = load_json(f"data/{dset_name}_consistency_test.json") 85 | output_dir = f'{video_root_path}/{dset_name}/shifted_videos/' 86 | 87 | if not os.path.exists(output_dir): 88 | os.makedirs(output_dir) 89 | 90 | print("Output dir:", output_dir) 91 | 92 | for vid, details in tqdm(annotations.items()): 93 | video_path = f"{video_root_path}/{dset_name}/{vid}.mp4" # Assuming the videos are in a "data" directory 94 | if os.path.exists(video_path): 95 | original_timestamps = details["timestamps"] 96 | shifted_timestamps = details["shifted_timestamps"] 97 | swap_frames(video_path, original_timestamps, shifted_timestamps, output_dir, vid) 98 | 99 | else: 100 | print(f"Video file {video_path} not found.") 101 | -------------------------------------------------------------------------------- /timechat/models/blip2_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from salesforce@LAVIS. Below is the original copyright: 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | from dataclasses import dataclass 10 | from typing import Optional 11 | 12 | import torch 13 | from transformers.modeling_outputs import ( 14 | ModelOutput, 15 | BaseModelOutputWithPoolingAndCrossAttentions, 16 | CausalLMOutputWithCrossAttentions, 17 | ) 18 | 19 | 20 | @dataclass 21 | class BlipSimilarity(ModelOutput): 22 | sim_i2t: torch.FloatTensor = None 23 | sim_t2i: torch.FloatTensor = None 24 | 25 | sim_i2t_m: Optional[torch.FloatTensor] = None 26 | sim_t2i_m: Optional[torch.FloatTensor] = None 27 | 28 | sim_i2t_targets: Optional[torch.FloatTensor] = None 29 | sim_t2i_targets: Optional[torch.FloatTensor] = None 30 | 31 | 32 | @dataclass 33 | class BlipIntermediateOutput(ModelOutput): 34 | """ 35 | Data class for intermediate outputs of BLIP models. 36 | 37 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). 38 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). 39 | 40 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). 41 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). 42 | 43 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. 44 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. 45 | 46 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. 47 | decoder_labels (torch.LongTensor): labels for the captioning loss. 48 | 49 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). 
50 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) 51 | 52 | """ 53 | 54 | # uni-modal features 55 | image_embeds: torch.FloatTensor = None 56 | text_embeds: Optional[torch.FloatTensor] = None 57 | 58 | image_embeds_m: Optional[torch.FloatTensor] = None 59 | text_embeds_m: Optional[torch.FloatTensor] = None 60 | 61 | # intermediate outputs of multimodal encoder 62 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 63 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 64 | 65 | itm_logits: Optional[torch.FloatTensor] = None 66 | itm_labels: Optional[torch.LongTensor] = None 67 | 68 | # intermediate outputs of multimodal decoder 69 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None 70 | decoder_labels: Optional[torch.LongTensor] = None 71 | 72 | 73 | @dataclass 74 | class BlipOutput(ModelOutput): 75 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. 76 | sims: Optional[BlipSimilarity] = None 77 | 78 | intermediate_output: BlipIntermediateOutput = None 79 | 80 | loss: Optional[torch.FloatTensor] = None 81 | 82 | loss_itc: Optional[torch.FloatTensor] = None 83 | 84 | loss_itm: Optional[torch.FloatTensor] = None 85 | 86 | loss_lm: Optional[torch.FloatTensor] = None 87 | 88 | 89 | @dataclass 90 | class BlipOutputFeatures(ModelOutput): 91 | """ 92 | Data class of features from BlipFeatureExtractor. 93 | 94 | Args: 95 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional 96 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional 97 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional 98 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional 99 | 100 | The first embedding or feature is for the [CLS] token. 101 | 102 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. 103 | """ 104 | 105 | image_embeds: Optional[torch.FloatTensor] = None 106 | image_embeds_proj: Optional[torch.FloatTensor] = None 107 | 108 | text_embeds: Optional[torch.FloatTensor] = None 109 | text_embeds_proj: Optional[torch.FloatTensor] = None 110 | 111 | multimodal_embeds: Optional[torch.FloatTensor] = None 112 | -------------------------------------------------------------------------------- /timechat/processors/blip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import re 9 | 10 | from timechat.common.registry import registry 11 | from timechat.processors.base_processor import BaseProcessor 12 | from timechat.processors.randaugment import RandomAugment 13 | from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | class BlipImageBaseProcessor(BaseProcessor): 19 | def __init__(self, mean=None, std=None): 20 | if mean is None: 21 | mean = (0.48145466, 0.4578275, 0.40821073) 22 | if std is None: 23 | std = (0.26862954, 0.26130258, 0.27577711) 24 | 25 | self.normalize = transforms.Normalize(mean, std) 26 | 27 | 28 | @registry.register_processor("blip_caption") 29 | class BlipCaptionProcessor(BaseProcessor): 30 | def __init__(self, prompt="", max_words=50): 31 | self.prompt = prompt 32 | self.max_words = max_words 33 | 34 | def __call__(self, caption): 35 | caption = self.prompt + self.pre_caption(caption) 36 | 37 | return caption 38 | 39 | @classmethod 40 | def from_config(cls, cfg=None): 41 | if cfg is None: 42 | cfg = OmegaConf.create() 43 | 44 | prompt = cfg.get("prompt", "") 45 | max_words = cfg.get("max_words", 50) 46 | 47 | return cls(prompt=prompt, max_words=max_words) 48 | 49 | def pre_caption(self, caption): 50 | caption = re.sub( 51 | r"([.!\"()*#:;~])", 52 | " ", 53 | caption.lower(), 54 | ) 55 | caption = re.sub( 56 | r"\s{2,}", 57 | " ", 58 | caption, 59 | ) 60 | caption = caption.rstrip("\n") 61 | caption = caption.strip(" ") 62 | 63 | # truncate caption 64 | caption_words = caption.split(" ") 65 | if len(caption_words) > self.max_words: 66 | caption = " ".join(caption_words[: self.max_words]) 67 | 68 | return caption 69 | 70 | 71 | @registry.register_processor("blip2_image_train") 72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor): 73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 74 | super().__init__(mean=mean, std=std) 75 | 76 | self.transform = transforms.Compose( 77 | [ 78 | transforms.RandomResizedCrop( 79 | image_size, 80 | scale=(min_scale, max_scale), 81 | interpolation=InterpolationMode.BICUBIC, 82 | ), 83 | transforms.ToTensor(), 84 | self.normalize, 85 | ] 86 | ) 87 | 88 | def __call__(self, item): 89 | return self.transform(item) 90 | 91 | @classmethod 92 | def from_config(cls, cfg=None): 93 | if cfg is None: 94 | cfg = OmegaConf.create() 95 | 96 | image_size = cfg.get("image_size", 224) 97 | 98 | mean = cfg.get("mean", None) 99 | std = cfg.get("std", None) 100 | 101 | min_scale = cfg.get("min_scale", 0.5) 102 | max_scale = cfg.get("max_scale", 1.0) 103 | 104 | return cls( 105 | image_size=image_size, 106 | mean=mean, 107 | std=std, 108 | min_scale=min_scale, 109 | max_scale=max_scale, 110 | ) 111 | 112 | 113 | @registry.register_processor("blip2_image_eval") 114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor): 115 | def __init__(self, image_size=224, mean=None, std=None): 116 | super().__init__(mean=mean, std=std) 117 | 118 | self.transform = transforms.Compose( 119 | [ 120 | transforms.Resize( 121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 122 | ), 123 | transforms.ToTensor(), 124 | self.normalize, 125 | ] 126 | ) 127 | 128 | def __call__(self, item): 129 | return self.transform(item) 130 | 131 | @classmethod 132 | def from_config(cls, cfg=None): 133 | if cfg is None: 
134 | cfg = OmegaConf.create() 135 | 136 | image_size = cfg.get("image_size", 224) 137 | 138 | mean = cfg.get("mean", None) 139 | std = cfg.get("std", None) 140 | 141 | return cls(image_size=image_size, mean=mean, std=std) 142 | 143 | -------------------------------------------------------------------------------- /utils/prompts.py: -------------------------------------------------------------------------------- 1 | neg_occurrence = [ 2 | "Is the event '{event}' absent from {st} to {ed} seconds in the video?", 3 | "Is the event '{event}' not present from {st} to {ed} seconds in the video?", 4 | "Does the event '{event}' not happen from {st} to {ed} seconds in the video?", 5 | "Is the event '{event}' missing from {st} to {ed} seconds in the video?" 6 | ] 7 | 8 | pos_occurrence = [ 9 | "Is the event '{event}' present from {st} to {ed} seconds in the video?", 10 | "Is the event '{event}' occurring from {st} to {ed} seconds in the video?", 11 | "Does the event '{event}' happen from {st} to {ed} seconds in the video?", 12 | "Is the event '{event}' included from {st} to {ed} seconds in the video?" 13 | ] 14 | 15 | grounding_prompts = [ 16 | "When does the event '{event}' happen in the video? Please only return its start time and end time.", 17 | "Please find the visual contents in the video described by a given event, determining its starting and ending times. Now I will give you the event: '{event}'. Please only return its start time and end time.", 18 | "Please answer when the event '{event}' occurs in the video. The output format should be: start - end seconds’. Please return its start time and end time." 19 | ] 20 | 21 | # default grounding prompt is third one in the list 'grounding_prompts'. 22 | prompt = { 23 | "grounding": "Please answer when the event '{event}' occurs in the video. The output format should be: 'start - end seconds'. Please return its start time and end time.", 24 | "pos": pos_occurrence, 25 | "neg": neg_occurrence, 26 | "add_detail": "Please answer with 'Yes' or 'No'.", 27 | "description": "Please describe the given video in detail.", 28 | "compositional": "{question} from {st} to {ed} seconds in the video?", 29 | } 30 | 31 | cot = { 32 | "grounding": """Your task is to predict the start and end times of an action or event described by a query sentence based on the visual content of the video. Use Chain-of-Thought reasoning to break down the query, analyze key moments, and accurately identify the time range where the action occurs. 33 | ### Chain-of-Thought Reasoning: 34 | 1. **Step 1: Parse the Query**: Break down the query sentence to understand the key action or event that you need to locate. 35 | 2. **Step 2: Analyze the Video Features**: Examine the sequence of video frames to detect patterns that match the key action described in the query. 36 | 3. **Step 3: Identify the Temporal Boundaries**: Use temporal reasoning to find the start and end frames of the action based on the video features. 37 | 4. **Step 4: Predict Start and End Times**: Map the identified frames to timestamps in the video, making sure the start and end times align with the query. 38 | 5. **Step 5: Verify the Answer**: Check if the predicted time range accurately captures the action described in the query. 39 | """, 40 | "occurrence": """You are a model designed to predict when specific events occur in a video based on a query sentence. Your task is to verify whether the event described in the query occurs in the given moment of the video. 41 | ### Chain-of-Thought Reasoning: 42 | 1. 
**Step 1: Verify the Event in the Predicted Time Range**: Analyze the video features from the predicted start time to the end time. Determine if the event described in the query occurs within this time range. 43 | - Example: For the query "The person is cooking," check for visual patterns such as a stove or kitchen utensils during the predicted moment. 44 | 2. **Step 2: Answer the Verification Question**: Respond to the question: 45 | - **"Is the event '{event}' present from {start_time} to {end_time} seconds in the video?"** 46 | - Example: "Is the event 'The person is cooking' present from 30.0 to 40.0 seconds in the video?" 47 | - If you find the event in the given moment, your answer should be "Yes."; if it does not happen in the given moment, your answer should be "No.". 48 | """, 49 | "compositional": """You are a model designed to analyze the compositional elements of an event in a video. Your task is to verify whether each compositional element occurs during the given moment in the video based on the specific question you receive. Instead of analyzing the entire event at once, you will answer questions about individual components of the scene. 50 | ### Chain-of-Thought Reasoning: 51 | 1. **Step 1: Analyze the Video Features for the Specific Element**: Analyze the video features from the start time to the end time. Look for visual evidence of the specific compositional element described in the question. 52 | 2. **Step 2: Answer the Compositional Question**: Respond to the question: 53 | - Example: "Is there a young girl from 0.0 to 5.0 seconds in the video?" 54 | - If you find a young girl in the given video moment, your answer should be "Yes.". If it is not present, your answer should be "No.". 55 | """, 56 | } -------------------------------------------------------------------------------- /timechat/datasets/datasets/webvid_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from timechat.datasets.datasets.base_dataset import BaseDataset 10 | from timechat.datasets.datasets.caption_datasets import CaptionDataset 11 | import pandas as pd 12 | import decord 13 | from decord import VideoReader 14 | import random 15 | import torch 16 | from torch.utils.data.dataloader import default_collate 17 | class WebvidDataset(BaseDataset): 18 | def __init__(self, vis_processor, text_processor, vis_root, ann_root): 19 | """ 20 | vis_root (string): Root directory of video (e.g. webvid_eval/video/) 21 | ann_root (string): Root directory of annotations (e.g.
webvid_eval/annotations/) 22 | split (string): val or test 23 | """ 24 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 25 | 26 | 27 | # Read all annotation csv files under the given directory 28 | 29 | ts_df = [] 30 | for file_name in os.listdir(ann_root): 31 | if file_name.endswith('.csv'): 32 | df = pd.read_csv(os.path.join(ann_root, file_name)) 33 | ts_df.append(df) 34 | 35 | merged_df = pd.concat(ts_df) 36 | self.annotation = merged_df 37 | self.vis_root = vis_root 38 | self.resize_size = 224 39 | self.num_frm = 8 40 | self.frm_sampling_strategy = 'headtail' 41 | 42 | def _get_video_path(self, sample): 43 | rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 44 | full_video_fp = os.path.join(self.vis_root, rel_video_fp) 45 | return full_video_fp 46 | 47 | def __getitem__(self, index): 48 | num_retries = 10 # skip error videos 49 | for _ in range(num_retries): 50 | sample = self.annotation.iloc[index] 51 | sample_dict = sample.to_dict() 52 | video_id = sample_dict['videoid'] 53 | 54 | if 'name' in sample_dict.keys(): 55 | text = sample_dict['name'].strip() 56 | else: 57 | raise NotImplementedError("Un-supported text annotation format.") 58 | 59 | # fetch video 60 | video_path = self._get_video_path(sample_dict) 61 | # if os.path.exists(video_path): 62 | try: 63 | video = self.vis_processor(video_path) 64 | except Exception: 65 | print(f"Failed to load examples with video: {video_path}. " 66 | f"Will randomly sample an example as a replacement.") 67 | index = random.randint(0, len(self) - 1) 68 | continue 69 | caption = self.text_processor(text) 70 | 71 | # print(video.size()) 72 | if video is None or caption is None \ 73 | or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]): 74 | print(f"Failed to load examples with video: {video_path}. " 75 | f"Will randomly sample an example as a replacement.") 76 | index = random.randint(0, len(self) - 1) 77 | continue 78 | else: 79 | break 80 | else: 81 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.") 82 | # "image_id" is kept to stay compatible with the COCO evaluation format 83 | return { 84 | "image": video, 85 | "text_input": caption, 86 | "type":'video', 87 | } 88 | 89 | def __len__(self): 90 | return len(self.annotation) 91 | 92 | # def collater(self, samples): 93 | # new_result = {} 94 | # new_result['image'] = default_collate( [sample["image"] for sample in samples]) 95 | # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples]) 96 | # return new_result 97 | 98 | class WebvidDatasetEvalDataset(BaseDataset): 99 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 100 | """ 101 | vis_root (string): Root directory of images (e.g.
coco/images/) 102 | ann_root (string): directory to store the annotation file 103 | split (string): val or test 104 | """ 105 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 106 | 107 | def __getitem__(self, index): 108 | 109 | ann = self.annotation[index] 110 | 111 | vname = ann["video"] 112 | video_path = os.path.join(self.vis_root, vname) 113 | 114 | video = self.vis_processor(video_path) 115 | 116 | return { 117 | "video": video, 118 | "image_id": ann["image_id"], 119 | "instance_id": ann["instance_id"], 120 | } 121 | 122 | 123 | -------------------------------------------------------------------------------- /timechat/processors/transforms_video.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | 10 | import numbers 11 | import random 12 | 13 | from torchvision.transforms import ( 14 | RandomCrop, 15 | RandomResizedCrop, 16 | ) 17 | 18 | import timechat.processors.functional_video as F 19 | 20 | 21 | __all__ = [ 22 | "RandomCropVideo", 23 | "RandomResizedCropVideo", 24 | "CenterCropVideo", 25 | "NormalizeVideo", 26 | "ToTensorVideo", 27 | "RandomHorizontalFlipVideo", 28 | ] 29 | 30 | 31 | class RandomCropVideo(RandomCrop): 32 | def __init__(self, size): 33 | if isinstance(size, numbers.Number): 34 | self.size = (int(size), int(size)) 35 | else: 36 | self.size = size 37 | 38 | def __call__(self, clip): 39 | """ 40 | Args: 41 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 42 | Returns: 43 | torch.tensor: randomly cropped/resized video clip. 44 | size is (C, T, OH, OW) 45 | """ 46 | i, j, h, w = self.get_params(clip, self.size) 47 | return F.crop(clip, i, j, h, w) 48 | 49 | def __repr__(self) -> str: 50 | return f"{self.__class__.__name__}(size={self.size})" 51 | 52 | 53 | class RandomResizedCropVideo(RandomResizedCrop): 54 | def __init__( 55 | self, 56 | size, 57 | scale=(0.08, 1.0), 58 | ratio=(3.0 / 4.0, 4.0 / 3.0), 59 | interpolation_mode="bilinear", 60 | ): 61 | if isinstance(size, tuple): 62 | if len(size) != 2: 63 | raise ValueError( 64 | f"size should be tuple (height, width), instead got {size}" 65 | ) 66 | self.size = size 67 | else: 68 | self.size = (size, size) 69 | 70 | self.interpolation_mode = interpolation_mode 71 | self.scale = scale 72 | self.ratio = ratio 73 | 74 | def __call__(self, clip): 75 | """ 76 | Args: 77 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 78 | Returns: 79 | torch.tensor: randomly cropped/resized video clip. 80 | size is (C, T, H, W) 81 | """ 82 | i, j, h, w = self.get_params(clip, self.scale, self.ratio) 83 | return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) 84 | 85 | def __repr__(self) -> str: 86 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})" 87 | 88 | 89 | class CenterCropVideo: 90 | def __init__(self, crop_size): 91 | if isinstance(crop_size, numbers.Number): 92 | self.crop_size = (int(crop_size), int(crop_size)) 93 | else: 94 | self.crop_size = crop_size 95 | 96 | def __call__(self, clip): 97 | """ 98 | Args: 99 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 100 | Returns: 101 | torch.tensor: central cropping of video clip. 
Size is 102 | (C, T, crop_size, crop_size) 103 | """ 104 | return F.center_crop(clip, self.crop_size) 105 | 106 | def __repr__(self) -> str: 107 | return f"{self.__class__.__name__}(crop_size={self.crop_size})" 108 | 109 | 110 | class NormalizeVideo: 111 | """ 112 | Normalize the video clip by mean subtraction and division by standard deviation 113 | Args: 114 | mean (3-tuple): pixel RGB mean 115 | std (3-tuple): pixel RGB standard deviation 116 | inplace (boolean): whether do in-place normalization 117 | """ 118 | 119 | def __init__(self, mean, std, inplace=False): 120 | self.mean = mean 121 | self.std = std 122 | self.inplace = inplace 123 | 124 | def __call__(self, clip): 125 | """ 126 | Args: 127 | clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W) 128 | """ 129 | return F.normalize(clip, self.mean, self.std, self.inplace) 130 | 131 | def __repr__(self) -> str: 132 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" 133 | 134 | 135 | class ToTensorVideo: 136 | """ 137 | Convert tensor data type from uint8 to float, divide value by 255.0 and 138 | permute the dimensions of clip tensor 139 | """ 140 | 141 | def __init__(self): 142 | pass 143 | 144 | def __call__(self, clip): 145 | """ 146 | Args: 147 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 148 | Return: 149 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 150 | """ 151 | return F.to_tensor(clip) 152 | 153 | def __repr__(self) -> str: 154 | return self.__class__.__name__ 155 | 156 | 157 | class RandomHorizontalFlipVideo: 158 | """ 159 | Flip the video clip along the horizonal direction with a given probability 160 | Args: 161 | p (float): probability of the clip being flipped. Default value is 0.5 162 | """ 163 | 164 | def __init__(self, p=0.5): 165 | self.p = p 166 | 167 | def __call__(self, clip): 168 | """ 169 | Args: 170 | clip (torch.tensor): Size is (C, T, H, W) 171 | Return: 172 | clip (torch.tensor): Size is (C, T, H, W) 173 | """ 174 | if random.random() < self.p: 175 | clip = F.hflip(clip) 176 | return clip 177 | 178 | def __repr__(self) -> str: 179 | return f"{self.__class__.__name__}(p={self.p})" 180 | -------------------------------------------------------------------------------- /timechat/datasets/datasets/dataloader_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import time 9 | import random 10 | import torch 11 | from timechat.datasets.data_utils import move_to_cuda 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | class MultiIterLoader: 16 | """ 17 | A simple wrapper for iterating over multiple iterators. 18 | 19 | Args: 20 | loaders (List[Loader]): List of Iterator loaders. 21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. 
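    Illustrative example (a sketch; `webvid_loader` and `llava_loader` stand in for any iterators that already yield batches):

        mixed = MultiIterLoader([webvid_loader, llava_loader], ratios=[2, 1])
        batch = next(mixed)  # on average, two thirds of the batches come from webvid_loader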
22 | """ 23 | 24 | def __init__(self, loaders, ratios=None): 25 | # assert all loaders has __next__ method 26 | for loader in loaders: 27 | assert hasattr( 28 | loader, "__next__" 29 | ), "Loader {} has no __next__ method.".format(loader) 30 | 31 | if ratios is None: 32 | ratios = [1.0] * len(loaders) 33 | else: 34 | assert len(ratios) == len(loaders) 35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios] 36 | 37 | self.loaders = loaders 38 | self.ratios = ratios 39 | 40 | def __next__(self): 41 | # random sample from each loader by ratio 42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] 43 | return next(self.loaders[loader_idx]) 44 | 45 | 46 | class PrefetchLoader(object): 47 | """ 48 | Modified from https://github.com/ChenRocks/UNITER. 49 | 50 | overlap compute and cuda data transfer 51 | (copied and then modified from nvidia apex) 52 | """ 53 | 54 | def __init__(self, loader): 55 | self.loader = loader 56 | self.stream = torch.cuda.Stream() 57 | 58 | def __iter__(self): 59 | loader_it = iter(self.loader) 60 | self.preload(loader_it) 61 | batch = self.next(loader_it) 62 | while batch is not None: 63 | is_tuple = isinstance(batch, tuple) 64 | if is_tuple: 65 | task, batch = batch 66 | 67 | if is_tuple: 68 | yield task, batch 69 | else: 70 | yield batch 71 | batch = self.next(loader_it) 72 | 73 | def __len__(self): 74 | return len(self.loader) 75 | 76 | def preload(self, it): 77 | try: 78 | self.batch = next(it) 79 | except StopIteration: 80 | self.batch = None 81 | return 82 | # if record_stream() doesn't work, another option is to make sure 83 | # device inputs are created on the main stream. 84 | # self.next_input_gpu = torch.empty_like(self.next_input, 85 | # device='cuda') 86 | # self.next_target_gpu = torch.empty_like(self.next_target, 87 | # device='cuda') 88 | # Need to make sure the memory allocated for next_* is not still in use 89 | # by the main stream at the time we start copying to next_*: 90 | # self.stream.wait_stream(torch.cuda.current_stream()) 91 | with torch.cuda.stream(self.stream): 92 | self.batch = move_to_cuda(self.batch) 93 | # more code for the alternative if record_stream() doesn't work: 94 | # copy_ will record the use of the pinned source tensor in this 95 | # side stream. 96 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 97 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 98 | # self.next_input = self.next_input_gpu 99 | # self.next_target = self.next_target_gpu 100 | 101 | def next(self, it): 102 | torch.cuda.current_stream().wait_stream(self.stream) 103 | batch = self.batch 104 | if batch is not None: 105 | record_cuda_stream(batch) 106 | self.preload(it) 107 | return batch 108 | 109 | def __getattr__(self, name): 110 | method = self.loader.__getattribute__(name) 111 | return method 112 | 113 | 114 | def record_cuda_stream(batch): 115 | if isinstance(batch, torch.Tensor): 116 | batch.record_stream(torch.cuda.current_stream()) 117 | elif isinstance(batch, list) or isinstance(batch, tuple): 118 | for t in batch: 119 | record_cuda_stream(t) 120 | elif isinstance(batch, dict): 121 | for t in batch.values(): 122 | record_cuda_stream(t) 123 | else: 124 | pass 125 | 126 | 127 | class IterLoader: 128 | """ 129 | A wrapper to convert DataLoader as an infinite iterator. 
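    Illustrative usage (a sketch; `train_loader` is a hypothetical torch DataLoader):

        loader = IterLoader(train_loader, use_distributed=True)
        batch = next(loader)  # re-creates the iterator (and bumps the epoch) whenever the DataLoader is exhausted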
130 | 131 | Modified from: 132 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py 133 | """ 134 | 135 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False): 136 | self._dataloader = dataloader 137 | self.iter_loader = iter(self._dataloader) 138 | self._use_distributed = use_distributed 139 | self._epoch = 0 140 | 141 | @property 142 | def epoch(self) -> int: 143 | return self._epoch 144 | 145 | def __next__(self): 146 | try: 147 | data = next(self.iter_loader) 148 | except StopIteration: 149 | self._epoch += 1 150 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: 151 | self._dataloader.sampler.set_epoch(self._epoch) 152 | time.sleep(2) # Prevent possible deadlock during epoch transition 153 | self.iter_loader = iter(self._dataloader) 154 | data = next(self.iter_loader) 155 | 156 | return data 157 | 158 | def __iter__(self): 159 | return self 160 | 161 | def __len__(self): 162 | return len(self._dataloader) 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On the Consistency of Video Large Language Models in Temporal Comprehension 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2411.12951-b31b1b.svg)](https://arxiv.org/abs/2411.12951) 4 | 5 | 6 | 7 | ## News 8 | - [x] **[2025.03.25]** Evaluation code has been released. 9 | - [x] **[2025.02.27]** Our paper has been accepted by CVPR 2025! 🎉 10 | - [x] **[2025.01.15]** We are excited to share that our evaluation datasets, Charades-CON and ActivityNet-CON, are now available on Hugging Face! 🎉 Additionally, the training annotations for VTune have also been released. 11 | - [x] **[2025.01.14]** We have released our four checkpoints using VTune: [VideoLLaMA-7B-Charades-VTune](https://huggingface.co/mjjung/VideoLLaMA-7B-Charades-VTune), [VideoLLaMA-7B-ActivityNet-VTune](https://huggingface.co/mjjung/VideoLLaMA-7B-ActivityNet-VTune), [TimeChat-7B-Charades-VTune](https://huggingface.co/mjjung/TimeChat-7B-Charades-VTune), [TimeChat-7B-ActivityNet-VTune](https://huggingface.co/mjjung/TimeChat-7B-ActivityNet-VTune). Additionally, checkpoints with naive fine-tuning: [VideoLLaMA-7B-Charades-FT](https://huggingface.co/mjjung/VideoLLAMA-7B-Charades-FT), [VideoLLaMA-7B-ActivityNet-FT](https://huggingface.co/mjjung/VideoLLaMA-7B-ActivityNet-FT), [TimeChat-7B-ActivityNet-FT](https://huggingface.co/mjjung/TimeChat-7B-ActivityNet-FT) have been released. 12 | - [x] **[2024.11.20]** Our paper has been released on arXiv. 13 | 14 | ## Introduction 15 | ![image](https://github.com/user-attachments/assets/cc7ba1a6-a7b5-4c87-88b5-471632fabbd1) 16 | - We study the model’s consistency in temporal comprehension by assessing whether its responses align with the initial grounding, using dedicated probes and datasets. We specifically focus on video temporal grounding, where the task involves identifying timestamps in a video that correspond to language queries. 17 | 18 | ## Download 19 | You can download the complete annotations for consistency evaluation from [Hugging Face](https://huggingface.co/datasets/mjjung/Consistency-Evaluation-for-Video-LLMs).
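For example, the annotation files can be fetched with `huggingface_hub` (a minimal sketch; the target directory is an assumption, so adjust `local_dir` to wherever your data root lives):

```python
from huggingface_hub import snapshot_download

# Download the Charades-CON / ActivityNet-CON annotation files from the dataset repo.
snapshot_download(
    repo_id="mjjung/Consistency-Evaluation-for-Video-LLMs",
    repo_type="dataset",
    local_dir="data",
)
```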
The source videos are available via the following links: 20 | 21 | - [Charades-STA](https://prior.allenai.org/projects/charades) 22 | - [ActivityNet-Captions](https://cs.stanford.edu/people/ranjaykrishna/densevid/) 23 | 24 | ## Evaluation 25 | Before starting the evaluation, make sure you have prepared the annotations and videos. You should also check the configuration of the Video-LLMs. Install the necessary dependencies using conda and pip for your model. Additionally, you may run `utils/shift_video.py` with the right paths to prepare shifted videos. 26 | Here, we provide an example with the model `TimeChat`. We will include additional baseline models in the future. 27 | 28 | To run the evaluation, use the following command: 29 | 30 | ``` 31 | python run.py --model_type TimeChat --dset_name activitynet --task consistency 32 | ``` 33 | `dset_name` refers to the test dataset, which can be either `charades` or `activitynet`. `task` refers to the evaluation task: either `consistency` or `grounding`. If set to `grounding`, the evaluation will be performed on the original test set. 34 | You can also use the `--debug` flag before performing the actual evaluation to verify your configuration settings. 35 | 36 | Once the evaluation is complete, the performance will be reported in `consistency_eval_results.json`, and you can check the model's output in `consistency_predictions.jsonl`. 37 | 38 | 39 | ## Training 40 | 41 | For training, please download the training annotations for each dataset from Hugging Face. 42 | 43 | ### ⚠️ Important Note 44 | 45 | The previously uploaded **VTune** dataset file `charades_for_VTune.json` partially includes videos from the Charades-CON test split. 46 | The updated files `charades_train_v2.json` and `charades_for_VTune_v2.json` exclude these overlapping videos. 47 | The corresponding hyperparameters should follow the table below. Note that neither dataset includes test videos from Charades-STA (the original one). 48 | We apologize for any inconvenience caused. 49 | 50 | | Dataset Name | iters_per_epochs | warmup_steps | 51 | |--------------------------|------------------|--------------| 52 | | `charades_for_VTune` | 24,811 | 14,916 | 53 | | `charades_for_VTune_v2` | 22,311 | 13,386 | 54 | 55 | The performance of TimeChat trained with `charades_for_VTune_v2`: 56 | 57 | | **Method** | **Ground** | **R-Ground** | **S-Ground** | **H-Verify** | **C-Verify** | 58 | |:-----------|:----------:|:------------:|:------------:|:------------:|:------------:| 59 | | SFT | 47.2 | 43.4 (91.8) | 15.0 (31.9) | 24.3 (51.5) | 24.0 (50.9) | 60 | | VTune | 52.0 | 47.4 (91.2) | 23.5 (45.2) | 31.5 (60.5) | 27.5 (52.9) | 61 | 62 | 63 | ## Checkpoints 64 | 65 | We provide the checkpoints for each dataset using the links below: 66 | - [Charades-STA](https://huggingface.co/mjjung/TimeChat-7B-Charades-VTune) 67 | - [ActivityNet-Captions](https://huggingface.co/mjjung/TimeChat-7B-ActivityNet-VTune) 68 | 69 | Then, use the following command: 70 | ``` 71 | python run.py --model_type TimeChat --dset_name activitynet --fine_tuned --task consistency 72 | ``` 73 | In the above example command, including the `--fine_tuned` option will automatically switch the checkpoint path `ckpt` to `activitynet_ckpt` in `timechat/eval_configs/timechat.yaml`. 74 | 75 | ## Citation 76 | If you find our paper useful, please consider citing our paper.
77 | ```BibTeX 78 | @inproceedings{jung2025consistency, 79 | title={On the consistency of video large language models in temporal comprehension}, 80 | author={Jung, Minjoon and Xiao, Junbin and Zhang, Byoung-Tak and Yao, Angela}, 81 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, 82 | pages={13713--13722}, 83 | year={2025} 84 | } 85 | ``` 86 | ## Acknowledgement 87 | We appreciate the following awesome Video-LLMs: 88 | - [TimeChat](https://github.com/RenShuhuai-Andy/TimeChat) 89 | - [VTimeLLM](https://github.com/huangb23/VTimeLLM) 90 | - [VTG-LLM](https://github.com/gyxxyg/VTG-LLM) 91 | - [Video-LLaMA](https://github.com/DAMO-NLP-SG/Video-LLaMA) 92 | - [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2) 93 | - [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA) 94 | - [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT) 95 | - [VideoChat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) 96 | -------------------------------------------------------------------------------- /timechat/datasets/builders/instruct_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from timechat.common.registry import registry 6 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from timechat.datasets.datasets.laion_dataset import LaionDataset 8 | from timechat.datasets.datasets.llava_instruct_dataset import Instruct_Dataset 9 | from timechat.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset 10 | 11 | 12 | @registry.register_builder("image_instruct") 13 | class Image_Instruct_Builder(BaseDatasetBuilder): 14 | train_dataset_cls = Instruct_Dataset 15 | 16 | DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"} 17 | 18 | def _download_ann(self): 19 | pass 20 | 21 | def _download_vis(self): 22 | pass 23 | 24 | def build(self): 25 | self.build_processors() 26 | datasets = dict() 27 | split = "train" 28 | 29 | build_info = self.config.build_info 30 | dataset_cls = self.train_dataset_cls 31 | if self.config.num_video_query_token: 32 | num_video_query_token = self.config.num_video_query_token 33 | else: 34 | num_video_query_token = 32 35 | 36 | if self.config.tokenizer_name: 37 | tokenizer_name = self.config.tokenizer_name 38 | else: 39 | tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/' 40 | 41 | model_type = self.config.model_type if self.config.model_type else 'vicuna' 42 | 43 | datasets[split] = dataset_cls( 44 | vis_processor=self.vis_processors[split], 45 | text_processor=self.text_processors[split], 46 | vis_root=build_info.videos_dir, 47 | ann_root=build_info.anno_dir, 48 | num_video_query_token=num_video_query_token, 49 | tokenizer_name=tokenizer_name, 50 | data_type=self.config.data_type, 51 | model_type=model_type, 52 | ) 53 | 54 | return datasets 55 | 56 | 57 | @registry.register_builder("video_instruct") 58 | class Video_Instruct_Builder(BaseDatasetBuilder): 59 | train_dataset_cls = Video_Instruct_Dataset 60 | 61 | DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"} 62 | 63 | def _download_ann(self): 64 | pass 65 | 66 | def _download_vis(self): 67 | pass 68 | 69 | def build(self): 70 | self.build_processors() 71 | datasets = dict() 72 | split = "train" 73 | 74 | build_info = self.config.build_info 75 | dataset_cls = self.train_dataset_cls 76 | if self.config.num_video_query_token: 77 | num_video_query_token =
self.config.num_video_query_token 78 | else: 79 | num_video_query_token = 32 80 | 81 | if self.config.tokenizer_name: 82 | tokenizer_name = self.config.tokenizer_name 83 | else: 84 | tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/' 85 | 86 | model_type = self.config.model_type if self.config.model_type else 'vicuna' 87 | num_frm = self.config.num_frm if self.config.num_frm else 8 88 | sample_type = self.config.sample_type if self.config.sample_type else 'uniform' 89 | max_txt_len = self.config.max_txt_len if self.config.max_txt_len else 512 90 | stride = self.config.stride if self.config.stride else 0 91 | 92 | datasets[split] = dataset_cls( 93 | vis_processor=self.vis_processors[split], 94 | text_processor=self.text_processors[split], 95 | vis_root=build_info.videos_dir, 96 | ann_root=build_info.anno_dir, 97 | num_video_query_token=num_video_query_token, 98 | tokenizer_name=tokenizer_name, 99 | data_type=self.config.data_type, 100 | model_type=model_type, 101 | num_frm=num_frm, 102 | sample_type=sample_type, 103 | max_txt_len=max_txt_len, 104 | stride=stride 105 | ) 106 | 107 | return datasets 108 | 109 | 110 | @registry.register_builder("webvid_instruct") 111 | class WebvidInstruct_Builder(Video_Instruct_Builder): 112 | train_dataset_cls = Video_Instruct_Dataset 113 | 114 | DATASET_CONFIG_DICT = { 115 | "default": "configs/datasets/instruct/webvid_instruct.yaml", 116 | } 117 | 118 | 119 | @registry.register_builder("webvid_instruct_zh") 120 | class WebvidInstruct_zh_Builder(Video_Instruct_Builder): 121 | train_dataset_cls = Video_Instruct_Dataset 122 | 123 | DATASET_CONFIG_DICT = { 124 | "default": "configs/datasets/instruct/webvid_instruct.yaml", 125 | } 126 | 127 | 128 | @registry.register_builder("llava_instruct") 129 | class LlavaInstruct_Builder(Image_Instruct_Builder): 130 | train_dataset_cls = Instruct_Dataset 131 | 132 | DATASET_CONFIG_DICT = { 133 | "default": "configs/datasets/instruct/llava_instruct.yaml", 134 | } 135 | 136 | 137 | @registry.register_builder("youcook2_instruct") 138 | class Youcook2Instruct_Builder(Video_Instruct_Builder): 139 | train_dataset_cls = Video_Instruct_Dataset 140 | 141 | DATASET_CONFIG_DICT = { 142 | "default": "configs/datasets/instruct/youcook2_instruct.yaml", 143 | } 144 | 145 | 146 | @registry.register_builder("time_instruct") 147 | class TimeInstruct_Builder(Video_Instruct_Builder): 148 | train_dataset_cls = Video_Instruct_Dataset 149 | 150 | DATASET_CONFIG_DICT = { 151 | "default": "configs/datasets/instruct/time_instruct.yaml", 152 | } 153 | 154 | 155 | @registry.register_builder("valley72k_instruct") 156 | class Valley72kInstruct_Builder(Video_Instruct_Builder): 157 | train_dataset_cls = Video_Instruct_Dataset 158 | 159 | DATASET_CONFIG_DICT = { 160 | "default": "configs/datasets/instruct/valley72k_instruct.yaml", 161 | } 162 | 163 | 164 | @registry.register_builder("qvhighlights_instruct") 165 | class QVhighlightsInstruct_Builder(Video_Instruct_Builder): 166 | train_dataset_cls = Video_Instruct_Dataset 167 | 168 | DATASET_CONFIG_DICT = { 169 | "default": "configs/datasets/instruct/qvhighlights_instruct.yaml", 170 | } 171 | 172 | 173 | @registry.register_builder("charades_instruct") 174 | class CharadesInstruct_Builder(Video_Instruct_Builder): 175 | train_dataset_cls = Video_Instruct_Dataset 176 | 177 | DATASET_CONFIG_DICT = { 178 | "default": "configs/datasets/instruct/charades_instruct.yaml", 179 | } 180 | -------------------------------------------------------------------------------- /timechat/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Adapted from salesforce@LAVIS Vision-CAIR@MiniGPT-4. Below is the original copyright: 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import torch 11 | from omegaconf import OmegaConf 12 | 13 | from timechat.common.registry import registry 14 | from timechat.models.base_model import BaseModel 15 | from timechat.models.blip2 import Blip2Base 16 | from timechat.models.timechat import TimeChat 17 | from timechat.processors.base_processor import BaseProcessor 18 | 19 | 20 | __all__ = [ 21 | "load_model", 22 | "BaseModel", 23 | "Blip2Base", 24 | "TimeChat" 25 | ] 26 | 27 | 28 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): 29 | """ 30 | Load supported models. 31 | 32 | To list all available models and types in registry: 33 | >>> from timechat.models import model_zoo 34 | >>> print(model_zoo) 35 | 36 | Args: 37 | name (str): name of the model. 38 | model_type (str): type of the model. 39 | is_eval (bool): whether the model is in eval mode. Default: False. 40 | device (str): device to use. Default: "cpu". 41 | checkpoint (str): path or to checkpoint. Default: None. 42 | Note that expecting the checkpoint to have the same keys in state_dict as the model. 43 | 44 | Returns: 45 | model (torch.nn.Module): model. 46 | """ 47 | 48 | model = registry.get_model_class(name).from_pretrained(model_type=model_type) 49 | 50 | if checkpoint is not None: 51 | model.load_checkpoint(checkpoint) 52 | 53 | if is_eval: 54 | model.eval() 55 | 56 | if device == "cpu": 57 | model = model.float() 58 | 59 | return model.to(device) 60 | 61 | 62 | def load_preprocess(config): 63 | """ 64 | Load preprocessor configs and construct preprocessors. 65 | 66 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. 67 | 68 | Args: 69 | config (dict): preprocessor configs. 70 | 71 | Returns: 72 | vis_processors (dict): preprocessors for visual inputs. 73 | txt_processors (dict): preprocessors for text inputs. 74 | 75 | Key is "train" or "eval" for processors used in training and evaluation respectively. 
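    Illustrative usage (a sketch; it assumes `cfg` was loaded from a model config that contains a `preprocess` section, as done in `load_model_and_preprocess` below):

        cfg = OmegaConf.load(model_cls.default_config_path(model_type))
        vis_processors, txt_processors = load_preprocess(cfg.preprocess)
        # vis_processors["eval"] and txt_processors["eval"] are then applied to raw video/text inputs.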
76 | """ 77 | 78 | def _build_proc_from_cfg(cfg): 79 | return ( 80 | registry.get_processor_class(cfg.name).from_config(cfg) 81 | if cfg is not None 82 | else BaseProcessor() 83 | ) 84 | 85 | vis_processors = dict() 86 | txt_processors = dict() 87 | 88 | vis_proc_cfg = config.get("vis_processor") 89 | txt_proc_cfg = config.get("text_processor") 90 | 91 | if vis_proc_cfg is not None: 92 | vis_train_cfg = vis_proc_cfg.get("train") 93 | vis_eval_cfg = vis_proc_cfg.get("eval") 94 | else: 95 | vis_train_cfg = None 96 | vis_eval_cfg = None 97 | 98 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) 99 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) 100 | 101 | if txt_proc_cfg is not None: 102 | txt_train_cfg = txt_proc_cfg.get("train") 103 | txt_eval_cfg = txt_proc_cfg.get("eval") 104 | else: 105 | txt_train_cfg = None 106 | txt_eval_cfg = None 107 | 108 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) 109 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) 110 | 111 | return vis_processors, txt_processors 112 | 113 | 114 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): 115 | """ 116 | Load model and its related preprocessors. 117 | 118 | List all available models and types in registry: 119 | >>> from timechat.models import model_zoo 120 | >>> print(model_zoo) 121 | 122 | Args: 123 | name (str): name of the model. 124 | model_type (str): type of the model. 125 | is_eval (bool): whether the model is in eval mode. Default: False. 126 | device (str): device to use. Default: "cpu". 127 | 128 | Returns: 129 | model (torch.nn.Module): model. 130 | vis_processors (dict): preprocessors for visual inputs. 131 | txt_processors (dict): preprocessors for text inputs. 132 | """ 133 | model_cls = registry.get_model_class(name) 134 | 135 | # load model 136 | model = model_cls.from_pretrained(model_type=model_type) 137 | 138 | if is_eval: 139 | model.eval() 140 | 141 | # load preprocess 142 | cfg = OmegaConf.load(model_cls.default_config_path(model_type)) 143 | if cfg is not None: 144 | preprocess_cfg = cfg.preprocess 145 | 146 | vis_processors, txt_processors = load_preprocess(preprocess_cfg) 147 | else: 148 | vis_processors, txt_processors = None, None 149 | logging.info( 150 | f"""No default preprocess for model {name} ({model_type}). 151 | This can happen if the model is not finetuned on downstream datasets, 152 | or it is not intended for direct use without finetuning. 153 | """ 154 | ) 155 | 156 | if device == "cpu" or device == torch.device("cpu"): 157 | model = model.float() 158 | 159 | return model.to(device), vis_processors, txt_processors 160 | 161 | 162 | class ModelZoo: 163 | """ 164 | A utility class to create string representation of available model architectures and types. 
165 | 166 | >>> from timechat.models import model_zoo 167 | >>> # list all available models 168 | >>> print(model_zoo) 169 | >>> # show total number of models 170 | >>> print(len(model_zoo)) 171 | """ 172 | 173 | def __init__(self) -> None: 174 | self.model_zoo = { 175 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) 176 | for k, v in registry.mapping["model_name_mapping"].items() 177 | } 178 | 179 | def __str__(self) -> str: 180 | return ( 181 | "=" * 50 182 | + "\n" 183 | + f"{'Architectures':<30} {'Types'}\n" 184 | + "=" * 50 185 | + "\n" 186 | + "\n".join( 187 | [ 188 | f"{name:<30} {', '.join(types)}" 189 | for name, types in self.model_zoo.items() 190 | ] 191 | ) 192 | ) 193 | 194 | def __iter__(self): 195 | return iter(self.model_zoo.items()) 196 | 197 | def __len__(self): 198 | return sum([len(v) for v in self.model_zoo.values()]) 199 | 200 | 201 | model_zoo = ModelZoo() 202 | -------------------------------------------------------------------------------- /timechat/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import time 11 | from collections import defaultdict, deque 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | from timechat.common import dist_utils 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 
40 | """ 41 | if not dist_utils.is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value, 79 | ) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError( 100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 101 | ) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def global_avg(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) 113 | return self.delimiter.join(loss_str) 114 | 115 | def synchronize_between_processes(self): 116 | for meter in self.meters.values(): 117 | meter.synchronize_between_processes() 118 | 119 | def add_meter(self, name, meter): 120 | self.meters[name] = meter 121 | 122 | def log_every(self, iterable, print_freq, header=None): 123 | i = 0 124 | if not header: 125 | header = "" 126 | start_time = time.time() 127 | end = time.time() 128 | iter_time = SmoothedValue(fmt="{avg:.4f}") 129 | data_time = SmoothedValue(fmt="{avg:.4f}") 130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 131 | log_msg = [ 132 | header, 133 | "[{0" + space_fmt + "}/{1}]", 134 | "eta: {eta}", 135 | "{meters}", 136 | "time: {time}", 137 | "data: {data}", 138 | ] 139 | if torch.cuda.is_available(): 140 | log_msg.append("max mem: {memory:.0f}") 141 | log_msg = self.delimiter.join(log_msg) 142 | MB = 1024.0 * 1024.0 143 | for obj in iterable: 144 | data_time.update(time.time() - end) 145 | yield obj 146 | iter_time.update(time.time() - end) 147 | if i % print_freq == 0 or i == len(iterable) - 1: 148 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 150 | if torch.cuda.is_available(): 151 | print( 152 | log_msg.format( 153 | i, 154 | len(iterable), 155 | eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), 158 | data=str(data_time), 159 | memory=torch.cuda.max_memory_allocated() / MB, 160 | ) 161 | ) 162 | else: 163 | print( 164 | log_msg.format( 165 | i, 166 | len(iterable), 167 | eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), 
170 | data=str(data_time), 171 | ) 172 | ) 173 | i += 1 174 | end = time.time() 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print( 178 | "{} Total time: {} ({:.4f} s / it)".format( 179 | header, total_time_str, total_time / len(iterable) 180 | ) 181 | ) 182 | 183 | 184 | class AttrDict(dict): 185 | def __init__(self, *args, **kwargs): 186 | super(AttrDict, self).__init__(*args, **kwargs) 187 | self.__dict__ = self 188 | 189 | 190 | def setup_logger(): 191 | logging.basicConfig( 192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN, 193 | format="%(asctime)s [%(levelname)s] %(message)s", 194 | handlers=[logging.StreamHandler()], 195 | ) 196 | -------------------------------------------------------------------------------- /timechat/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import gzip 9 | import logging 10 | import os 11 | import random as rnd 12 | import tarfile 13 | import zipfile 14 | import random 15 | from typing import List 16 | from tqdm import tqdm 17 | 18 | import decord 19 | from decord import VideoReader 20 | import webdataset as wds 21 | import numpy as np 22 | import torch 23 | from torch.utils.data.dataset import IterableDataset 24 | 25 | from timechat.common.registry import registry 26 | from timechat.datasets.datasets.base_dataset import ConcatDataset 27 | 28 | 29 | decord.bridge.set_bridge("torch") 30 | MAX_INT = registry.get("MAX_INT") 31 | 32 | 33 | class ChainDataset(wds.DataPipeline): 34 | r"""Dataset for chaining multiple :class:`DataPipeline` s. 35 | 36 | This class is useful to assemble different existing dataset streams. The 37 | chaining operation is done on-the-fly, so concatenating large-scale 38 | datasets with this class will be efficient. 
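    Illustrative sketch (``pipeline_a`` and ``pipeline_b`` are placeholder
    DataPipeline instances, optionally carrying ``name`` and ``sample_ratio``
    attributes used for weighted sampling):

        >>> chained = ChainDataset([pipeline_a, pipeline_b])
        >>> sample = next(iter(chained))  # drawn from one pipeline, weighted by sample_ratio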
39 | 40 | Args: 41 | datasets (iterable of IterableDataset): datasets to be chained together 42 | """ 43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None: 44 | super().__init__() 45 | self.datasets = datasets 46 | self.prob = [] 47 | self.names = [] 48 | for dataset in self.datasets: 49 | if hasattr(dataset, 'name'): 50 | self.names.append(dataset.name) 51 | else: 52 | self.names.append('Unknown') 53 | if hasattr(dataset, 'sample_ratio'): 54 | self.prob.append(dataset.sample_ratio) 55 | else: 56 | self.prob.append(1) 57 | logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") 58 | 59 | def __iter__(self): 60 | datastreams = [iter(dataset) for dataset in self.datasets] 61 | while True: 62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] 63 | yield next(select_datastream) 64 | 65 | 66 | def apply_to_sample(f, sample): 67 | if len(sample) == 0: 68 | return {} 69 | 70 | def _apply(x): 71 | if torch.is_tensor(x): 72 | return f(x) 73 | elif isinstance(x, dict): 74 | return {key: _apply(value) for key, value in x.items()} 75 | elif isinstance(x, list): 76 | return [_apply(x) for x in x] 77 | else: 78 | return x 79 | 80 | return _apply(sample) 81 | 82 | 83 | def move_to_cuda(sample): 84 | def _move_to_cuda(tensor): 85 | return tensor.cuda() 86 | 87 | return apply_to_sample(_move_to_cuda, sample) 88 | 89 | 90 | def prepare_sample(samples, cuda_enabled=True): 91 | if cuda_enabled: 92 | samples = move_to_cuda(samples) 93 | 94 | # TODO fp16 support 95 | 96 | return samples 97 | 98 | 99 | def reorg_datasets_by_split(datasets): 100 | """ 101 | Organizes datasets by split. 102 | 103 | Args: 104 | datasets: dict of torch.utils.data.Dataset objects by name. 105 | 106 | Returns: 107 | Dict of datasets by split {split_name: List[Datasets]}. 108 | """ 109 | # if len(datasets) == 1: 110 | # return datasets[list(datasets.keys())[0]] 111 | # else: 112 | reorg_datasets = dict() 113 | 114 | # reorganize by split 115 | for _, dataset in datasets.items(): 116 | for split_name, dataset_split in dataset.items(): 117 | if split_name not in reorg_datasets: 118 | reorg_datasets[split_name] = [dataset_split] 119 | else: 120 | reorg_datasets[split_name].append(dataset_split) 121 | 122 | return reorg_datasets 123 | 124 | 125 | def concat_datasets(datasets): 126 | """ 127 | Concatenates multiple datasets into a single dataset. 128 | 129 | It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support 130 | generic IterableDataset because it requires creating separate samplers. 131 | 132 | Now only supports conctenating training datasets and assuming validation and testing 133 | have only a single dataset. This is because metrics should not be computed on the concatenated 134 | datasets. 135 | 136 | Args: 137 | datasets: dict of torch.utils.data.Dataset objects by split. 138 | 139 | Returns: 140 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, 141 | "val" and "test" remain the same. 142 | 143 | If the input training datasets contain both map-style and DataPipeline datasets, returns 144 | a tuple, where the first element is a concatenated map-style dataset and the second 145 | element is a chained DataPipeline dataset. 
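    Illustrative sketch (``map_a`` and ``map_b`` are placeholder map-style training
    datasets and ``val_ds`` a single validation dataset):

        >>> out = concat_datasets({"train": [map_a, map_b], "val": [val_ds]})
        >>> # out["train"] is a ConcatDataset; out["val"] is val_ds unchanged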
146 | 147 | """ 148 | # concatenate datasets in the same split 149 | for split_name in datasets: 150 | if split_name != "train": 151 | assert ( 152 | len(datasets[split_name]) == 1 153 | ), "Do not support multiple {} datasets.".format(split_name) 154 | datasets[split_name] = datasets[split_name][0] 155 | else: 156 | iterable_datasets, map_datasets = [], [] 157 | for dataset in datasets[split_name]: 158 | if isinstance(dataset, wds.DataPipeline): 159 | logging.info( 160 | "Dataset {} is IterableDataset, can't be concatenated.".format( 161 | dataset 162 | ) 163 | ) 164 | iterable_datasets.append(dataset) 165 | elif isinstance(dataset, IterableDataset): 166 | raise NotImplementedError( 167 | "Do not support concatenation of generic IterableDataset." 168 | ) 169 | else: 170 | map_datasets.append(dataset) 171 | 172 | # if len(iterable_datasets) > 0: 173 | # concatenate map-style datasets and iterable-style datasets separately 174 | if len(iterable_datasets) > 1: 175 | chained_datasets = ( 176 | ChainDataset(iterable_datasets) 177 | ) 178 | elif len(iterable_datasets) == 1: 179 | chained_datasets = iterable_datasets[0] 180 | else: 181 | chained_datasets = None 182 | 183 | concat_datasets = ( 184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None 185 | ) 186 | 187 | train_datasets = concat_datasets, chained_datasets 188 | train_datasets = tuple([x for x in train_datasets if x is not None]) 189 | train_datasets = ( 190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets 191 | ) 192 | 193 | datasets[split_name] = train_datasets 194 | 195 | return datasets 196 | 197 | -------------------------------------------------------------------------------- /timechat/models/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from salesforce@LAVIS. Below is the original copyright: 3 | Copyright (c) 2023, salesforce.com, inc. 4 | All rights reserved. 
5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | import contextlib 9 | import logging 10 | import os 11 | import time 12 | import datetime 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.distributed as dist 17 | import torch.nn.functional as F 18 | 19 | import timechat.common.dist_utils as dist_utils 20 | from timechat.common.dist_utils import download_cached_file 21 | from timechat.common.utils import is_url 22 | from timechat.common.logger import MetricLogger 23 | from timechat.models.base_model import BaseModel 24 | from timechat.models.Qformer import BertConfig, BertLMHeadModel 25 | from timechat.models.eva_vit import create_eva_vit_g 26 | from transformers import BertTokenizer 27 | 28 | 29 | class Blip2Base(BaseModel): 30 | @classmethod 31 | def init_tokenizer(cls): 32 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 33 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 34 | return tokenizer 35 | 36 | def maybe_autocast(self, dtype=torch.float16): 37 | # if on cpu, don't use autocast 38 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 39 | enable_autocast = self.device != torch.device("cpu") 40 | 41 | if enable_autocast: 42 | return torch.cuda.amp.autocast(dtype=dtype) 43 | else: 44 | return contextlib.nullcontext() 45 | 46 | @classmethod 47 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): 48 | encoder_config = BertConfig.from_pretrained("bert-base-uncased") 49 | encoder_config.encoder_width = vision_width 50 | # insert cross-attention layer every other block 51 | encoder_config.add_cross_attention = True 52 | encoder_config.cross_attention_freq = cross_attention_freq 53 | encoder_config.query_length = num_query_token 54 | Qformer = BertLMHeadModel(config=encoder_config) 55 | query_tokens = nn.Parameter( 56 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 57 | ) 58 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 59 | return Qformer, query_tokens 60 | 61 | @classmethod 62 | def init_vision_encoder( 63 | cls, url_or_filename, img_size, drop_path_rate, use_grad_checkpoint, precision 64 | ): 65 | assert "eva_vit_g" in url_or_filename, "vit model must be eva_vit_g for current version of MiniGPT-4" 66 | visual_encoder = create_eva_vit_g( 67 | url_or_filename, img_size, drop_path_rate, use_grad_checkpoint, precision 68 | ) 69 | 70 | ln_vision = LayerNorm(visual_encoder.num_features) 71 | return visual_encoder, ln_vision 72 | 73 | def load_from_pretrained(self, url_or_filename): 74 | if is_url(url_or_filename): 75 | cached_file = download_cached_file( 76 | url_or_filename, check_hash=False, progress=True 77 | ) 78 | checkpoint = torch.load(cached_file, map_location="cpu") 79 | elif os.path.isfile(url_or_filename): 80 | checkpoint = torch.load(url_or_filename, map_location="cpu") 81 | else: 82 | raise RuntimeError("checkpoint url or path is invalid") 83 | 84 | state_dict = checkpoint["model"] 85 | 86 | msg = self.load_state_dict(state_dict, strict=False) 87 | 88 | # logging.info("Missing keys {}".format(msg.missing_keys)) 89 | logging.info("load checkpoint from %s" % url_or_filename) 90 | 91 | return msg 92 | 93 | 94 | def disabled_train(self, mode=True): 95 | """Overwrite model.train with this function to make sure train/eval mode 96 | does not change anymore.""" 97 | return self 98 | 99 | 100 | class LayerNorm(nn.LayerNorm): 101 
| """Subclass torch's LayerNorm to handle fp16.""" 102 | 103 | def forward(self, x: torch.Tensor): 104 | orig_type = x.dtype 105 | ret = super().forward(x.type(torch.float32)) 106 | return ret.type(orig_type) 107 | 108 | 109 | def compute_sim_matrix(model, data_loader, **kwargs): 110 | k_test = kwargs.pop("k_test") 111 | 112 | metric_logger = MetricLogger(delimiter=" ") 113 | header = "Evaluation:" 114 | 115 | logging.info("Computing features for evaluation...") 116 | start_time = time.time() 117 | 118 | texts = data_loader.dataset.text 119 | num_text = len(texts) 120 | text_bs = 256 121 | text_ids = [] 122 | text_embeds = [] 123 | text_atts = [] 124 | for i in range(0, num_text, text_bs): 125 | text = texts[i : min(num_text, i + text_bs)] 126 | text_input = model.tokenizer( 127 | text, 128 | padding="max_length", 129 | truncation=True, 130 | max_length=35, 131 | return_tensors="pt", 132 | ).to(model.device) 133 | text_feat = model.forward_text(text_input) 134 | text_embed = F.normalize(model.text_proj(text_feat)) 135 | text_embeds.append(text_embed) 136 | text_ids.append(text_input.input_ids) 137 | text_atts.append(text_input.attention_mask) 138 | 139 | text_embeds = torch.cat(text_embeds, dim=0) 140 | text_ids = torch.cat(text_ids, dim=0) 141 | text_atts = torch.cat(text_atts, dim=0) 142 | 143 | vit_feats = [] 144 | image_embeds = [] 145 | for samples in data_loader: 146 | image = samples["image"] 147 | 148 | image = image.to(model.device) 149 | image_feat, vit_feat = model.forward_image(image) 150 | image_embed = model.vision_proj(image_feat) 151 | image_embed = F.normalize(image_embed, dim=-1) 152 | 153 | vit_feats.append(vit_feat.cpu()) 154 | image_embeds.append(image_embed) 155 | 156 | vit_feats = torch.cat(vit_feats, dim=0) 157 | image_embeds = torch.cat(image_embeds, dim=0) 158 | 159 | sims_matrix = [] 160 | for image_embed in image_embeds: 161 | sim_q2t = image_embed @ text_embeds.t() 162 | sim_i2t, _ = sim_q2t.max(0) 163 | sims_matrix.append(sim_i2t) 164 | sims_matrix = torch.stack(sims_matrix, dim=0) 165 | 166 | score_matrix_i2t = torch.full( 167 | (len(data_loader.dataset.image), len(texts)), -100.0 168 | ).to(model.device) 169 | 170 | num_tasks = dist_utils.get_world_size() 171 | rank = dist_utils.get_rank() 172 | step = sims_matrix.size(0) // num_tasks + 1 173 | start = rank * step 174 | end = min(sims_matrix.size(0), start + step) 175 | 176 | for i, sims in enumerate( 177 | metric_logger.log_every(sims_matrix[start:end], 50, header) 178 | ): 179 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 180 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) 181 | score = model.compute_itm( 182 | image_inputs=image_inputs, 183 | text_ids=text_ids[topk_idx], 184 | text_atts=text_atts[topk_idx], 185 | ).float() 186 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim 187 | 188 | sims_matrix = sims_matrix.t() 189 | score_matrix_t2i = torch.full( 190 | (len(texts), len(data_loader.dataset.image)), -100.0 191 | ).to(model.device) 192 | 193 | step = sims_matrix.size(0) // num_tasks + 1 194 | start = rank * step 195 | end = min(sims_matrix.size(0), start + step) 196 | 197 | for i, sims in enumerate( 198 | metric_logger.log_every(sims_matrix[start:end], 50, header) 199 | ): 200 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 201 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device) 202 | score = model.compute_itm( 203 | image_inputs=image_inputs, 204 | text_ids=text_ids[start + i].repeat(k_test, 1), 205 | text_atts=text_atts[start + i].repeat(k_test, 1), 
206 | ).float() 207 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim 208 | 209 | if dist_utils.is_dist_avail_and_initialized(): 210 | dist.barrier() 211 | torch.distributed.all_reduce( 212 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM 213 | ) 214 | torch.distributed.all_reduce( 215 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM 216 | ) 217 | 218 | total_time = time.time() - start_time 219 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 220 | logging.info("Evaluation time {}".format(total_time_str)) 221 | 222 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 223 | -------------------------------------------------------------------------------- /timechat/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from salesforce@LAVIS. Below is the original copyright: 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from timechat.common.dist_utils import download_cached_file, is_dist_avail_and_initialized 16 | from timechat.common.utils import get_abs_path, is_url 17 | from omegaconf import OmegaConf 18 | 19 | 20 | class BaseModel(nn.Module): 21 | """Base class for models.""" 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | @property 27 | def device(self): 28 | return list(self.parameters())[0].device 29 | 30 | def load_checkpoint(self, url_or_filename): 31 | """ 32 | Load from a finetuned checkpoint. 33 | 34 | This should expect no mismatch in the model keys and the checkpoint keys. 35 | """ 36 | 37 | if is_url(url_or_filename): 38 | cached_file = download_cached_file( 39 | url_or_filename, check_hash=False, progress=True 40 | ) 41 | checkpoint = torch.load(cached_file, map_location="cpu") 42 | elif os.path.isfile(url_or_filename): 43 | checkpoint = torch.load(url_or_filename, map_location="cpu") 44 | else: 45 | raise RuntimeError("checkpoint url or path is invalid") 46 | 47 | if "model" in checkpoint.keys(): 48 | state_dict = checkpoint["model"] 49 | else: 50 | state_dict = checkpoint 51 | 52 | msg = self.load_state_dict(state_dict, strict=False) 53 | 54 | logging.info("Missing keys {}".format(msg.missing_keys)) 55 | logging.info("load checkpoint from %s" % url_or_filename) 56 | 57 | return msg 58 | 59 | @classmethod 60 | def from_pretrained(cls, model_type): 61 | """ 62 | Build a pretrained model from default configuration file, specified by model_type. 63 | 64 | Args: 65 | - model_type (str): model type, specifying architecture and checkpoints. 66 | 67 | Returns: 68 | - model (nn.Module): pretrained or finetuned model, depending on the configuration. 69 | """ 70 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model 71 | model = cls.from_config(model_cfg) 72 | 73 | return model 74 | 75 | @classmethod 76 | def default_config_path(cls, model_type): 77 | assert ( 78 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT 79 | ), "Unknown model type {}".format(model_type) 80 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) 81 | 82 | def load_checkpoint_from_config(self, cfg, **kwargs): 83 | """ 84 | Load checkpoint as specified in the config file. 
85 | 86 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. 87 | When loading the pretrained model, each task-specific architecture may define their 88 | own load_from_pretrained() method. 89 | """ 90 | load_finetuned = cfg.get("load_finetuned", True) 91 | if load_finetuned: 92 | finetune_path = cfg.get("finetuned", None) 93 | assert ( 94 | finetune_path is not None 95 | ), "Found load_finetuned is True, but finetune_path is None." 96 | self.load_checkpoint(url_or_filename=finetune_path) 97 | else: 98 | # load pre-trained weights 99 | pretrain_path = cfg.get("pretrained", None) 100 | assert "Found load_finetuned is False, but pretrain_path is None." 101 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) 102 | 103 | def before_evaluation(self, **kwargs): 104 | pass 105 | 106 | def show_n_params(self, return_str=True): 107 | tot = 0 108 | for p in self.parameters(): 109 | w = 1 110 | for x in p.shape: 111 | w *= x 112 | tot += w 113 | if return_str: 114 | if tot >= 1e6: 115 | return "{:.1f}M".format(tot / 1e6) 116 | else: 117 | return "{:.1f}K".format(tot / 1e3) 118 | else: 119 | return tot 120 | 121 | 122 | class BaseEncoder(nn.Module): 123 | """ 124 | Base class for primitive encoders, such as ViT, TimeSformer, etc. 125 | """ 126 | 127 | def __init__(self): 128 | super().__init__() 129 | 130 | def forward_features(self, samples, **kwargs): 131 | raise NotImplementedError 132 | 133 | @property 134 | def device(self): 135 | return list(self.parameters())[0].device 136 | 137 | 138 | class SharedQueueMixin: 139 | @torch.no_grad() 140 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): 141 | # gather keys before updating queue 142 | image_feats = concat_all_gather(image_feat) 143 | text_feats = concat_all_gather(text_feat) 144 | 145 | batch_size = image_feats.shape[0] 146 | 147 | ptr = int(self.queue_ptr) 148 | assert self.queue_size % batch_size == 0 # for simplicity 149 | 150 | # replace the keys at ptr (dequeue and enqueue) 151 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T 152 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T 153 | 154 | if idxs is not None: 155 | idxs = concat_all_gather(idxs) 156 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T 157 | 158 | ptr = (ptr + batch_size) % self.queue_size # move pointer 159 | self.queue_ptr[0] = ptr 160 | 161 | 162 | class MomentumDistilationMixin: 163 | @torch.no_grad() 164 | def copy_params(self): 165 | for model_pair in self.model_pairs: 166 | for param, param_m in zip( 167 | model_pair[0].parameters(), model_pair[1].parameters() 168 | ): 169 | param_m.data.copy_(param.data) # initialize 170 | param_m.requires_grad = False # not update by gradient 171 | 172 | @torch.no_grad() 173 | def _momentum_update(self): 174 | for model_pair in self.model_pairs: 175 | for param, param_m in zip( 176 | model_pair[0].parameters(), model_pair[1].parameters() 177 | ): 178 | param_m.data = param_m.data * self.momentum + param.data * ( 179 | 1.0 - self.momentum 180 | ) 181 | 182 | 183 | class GatherLayer(torch.autograd.Function): 184 | """ 185 | Gather tensors from all workers with support for backward propagation: 186 | This implementation does not cut the gradients as torch.distributed.all_gather does. 
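    Illustrative use (``local_feats`` is a placeholder tensor; in this codebase the
    layer is normally reached through ``all_gather_with_grad`` defined below):

        >>> gathered = torch.cat(GatherLayer.apply(local_feats), dim=0)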
187 | """ 188 | 189 | @staticmethod 190 | def forward(ctx, x): 191 | output = [ 192 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 193 | ] 194 | torch.distributed.all_gather(output, x) 195 | return tuple(output) 196 | 197 | @staticmethod 198 | def backward(ctx, *grads): 199 | all_gradients = torch.stack(grads) 200 | torch.distributed.all_reduce(all_gradients) 201 | return all_gradients[torch.distributed.get_rank()] 202 | 203 | 204 | def all_gather_with_grad(tensors): 205 | """ 206 | Performs all_gather operation on the provided tensors. 207 | Graph remains connected for backward grad computation. 208 | """ 209 | # Queue the gathered tensors 210 | world_size = torch.distributed.get_world_size() 211 | # There is no need for reduction in the single-proc case 212 | if world_size == 1: 213 | return tensors 214 | 215 | # tensor_all = GatherLayer.apply(tensors) 216 | tensor_all = GatherLayer.apply(tensors) 217 | 218 | return torch.cat(tensor_all, dim=0) 219 | 220 | 221 | @torch.no_grad() 222 | def concat_all_gather(tensor): 223 | """ 224 | Performs all_gather operation on the provided tensors. 225 | *** Warning ***: torch.distributed.all_gather has no gradient. 226 | """ 227 | # if use distributed training 228 | if not is_dist_avail_and_initialized(): 229 | return tensor 230 | 231 | tensors_gather = [ 232 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 233 | ] 234 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 235 | 236 | output = torch.cat(tensors_gather, dim=0) 237 | return output 238 | 239 | 240 | def tile(x, dim, n_tile): 241 | init_dim = x.size(dim) 242 | repeat_idx = [1] * x.dim() 243 | repeat_idx[dim] = n_tile 244 | x = x.repeat(*(repeat_idx)) 245 | order_index = torch.LongTensor( 246 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 247 | ) 248 | return torch.index_select(x, dim, order_index.to(x.device)) 249 | -------------------------------------------------------------------------------- /timechat/datasets/builders/base_dataset_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is from 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | import shutil 12 | import warnings 13 | 14 | from omegaconf import OmegaConf 15 | import torch.distributed as dist 16 | from torchvision.datasets.utils import download_url 17 | 18 | import timechat.common.utils as utils 19 | from timechat.common.dist_utils import is_dist_avail_and_initialized, is_main_process 20 | from timechat.common.registry import registry 21 | from timechat.processors.base_processor import BaseProcessor 22 | 23 | 24 | 25 | class BaseDatasetBuilder: 26 | train_dataset_cls, eval_dataset_cls = None, None 27 | 28 | def __init__(self, cfg=None): 29 | super().__init__() 30 | 31 | if cfg is None: 32 | # help to create datasets from default config. 
33 | self.config = load_dataset_config(self.default_config_path()) 34 | elif isinstance(cfg, str): 35 | self.config = load_dataset_config(cfg) 36 | else: 37 | # when called from task.build_dataset() 38 | self.config = cfg 39 | 40 | self.data_type = self.config.data_type 41 | 42 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 43 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 44 | 45 | def build_datasets(self): 46 | # download, split, etc... 47 | # only called on 1 GPU/TPU in distributed 48 | 49 | if is_main_process(): 50 | self._download_data() 51 | 52 | if is_dist_avail_and_initialized(): 53 | dist.barrier() 54 | 55 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 56 | logging.info("Building datasets...") 57 | datasets = self.build() # dataset['train'/'val'/'test'] 58 | 59 | return datasets 60 | 61 | def build_processors(self): 62 | vis_proc_cfg = self.config.get("vis_processor") 63 | txt_proc_cfg = self.config.get("text_processor") 64 | 65 | if vis_proc_cfg is not None: 66 | vis_train_cfg = vis_proc_cfg.get("train") 67 | vis_eval_cfg = vis_proc_cfg.get("eval") 68 | 69 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) 70 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) 71 | 72 | if txt_proc_cfg is not None: 73 | txt_train_cfg = txt_proc_cfg.get("train") 74 | txt_eval_cfg = txt_proc_cfg.get("eval") 75 | 76 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) 77 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) 78 | 79 | @staticmethod 80 | def _build_proc_from_cfg(cfg): 81 | return ( 82 | registry.get_processor_class(cfg.name).from_config(cfg) 83 | if cfg is not None 84 | else None 85 | ) 86 | 87 | @classmethod 88 | def default_config_path(cls, type="default"): 89 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) 90 | 91 | def _download_data(self): 92 | self._download_ann() 93 | self._download_vis() 94 | 95 | def _download_ann(self): 96 | """ 97 | Download annotation files if necessary. 98 | All the vision-language datasets should have annotations of unified format. 99 | 100 | storage_path can be: 101 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. 102 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. 103 | 104 | Local annotation paths should be relative. 105 | """ 106 | anns = self.config.build_info.annotations 107 | 108 | splits = anns.keys() 109 | 110 | cache_root = registry.get_path("cache_root") 111 | 112 | for split in splits: 113 | info = anns[split] 114 | 115 | urls, storage_paths = info.get("url", None), info.storage 116 | 117 | if isinstance(urls, str): 118 | urls = [urls] 119 | if isinstance(storage_paths, str): 120 | storage_paths = [storage_paths] 121 | 122 | assert len(urls) == len(storage_paths) 123 | 124 | for url_or_filename, storage_path in zip(urls, storage_paths): 125 | # if storage_path is relative, make it full by prefixing with cache_root. 
126 | if not os.path.isabs(storage_path): 127 | storage_path = os.path.join(cache_root, storage_path) 128 | 129 | dirname = os.path.dirname(storage_path) 130 | if not os.path.exists(dirname): 131 | os.makedirs(dirname) 132 | 133 | if os.path.isfile(url_or_filename): 134 | src, dst = url_or_filename, storage_path 135 | if not os.path.exists(dst): 136 | shutil.copyfile(src=src, dst=dst) 137 | else: 138 | logging.info("Using existing file {}.".format(dst)) 139 | else: 140 | if os.path.isdir(storage_path): 141 | # if only dirname is provided, suffix with basename of URL. 142 | raise ValueError( 143 | "Expecting storage_path to be a file path, got directory {}".format( 144 | storage_path 145 | ) 146 | ) 147 | else: 148 | filename = os.path.basename(storage_path) 149 | 150 | download_url(url=url_or_filename, root=dirname, filename=filename) 151 | 152 | def _download_vis(self): 153 | 154 | storage_path = self.config.build_info.get(self.data_type).storage 155 | storage_path = utils.get_cache_path(storage_path) 156 | 157 | if not os.path.exists(storage_path): 158 | warnings.warn( 159 | f""" 160 | The specified path {storage_path} for visual inputs does not exist. 161 | Please provide a correct path to the visual inputs or 162 | refer to datasets/download_scripts/README.md for downloading instructions. 163 | """ 164 | ) 165 | 166 | def build(self): 167 | """ 168 | Create by split datasets inheriting torch.utils.data.Datasets. 169 | 170 | # build() can be dataset-specific. Overwrite to customize. 171 | """ 172 | self.build_processors() 173 | 174 | build_info = self.config.build_info 175 | 176 | ann_info = build_info.annotations 177 | vis_info = build_info.get(self.data_type) 178 | 179 | datasets = dict() 180 | for split in ann_info.keys(): 181 | if split not in ["train", "val", "test"]: 182 | continue 183 | 184 | is_train = split == "train" 185 | 186 | # processors 187 | vis_processor = ( 188 | self.vis_processors["train"] 189 | if is_train 190 | else self.vis_processors["eval"] 191 | ) 192 | text_processor = ( 193 | self.text_processors["train"] 194 | if is_train 195 | else self.text_processors["eval"] 196 | ) 197 | 198 | # annotation path 199 | ann_paths = ann_info.get(split).storage 200 | if isinstance(ann_paths, str): 201 | ann_paths = [ann_paths] 202 | 203 | abs_ann_paths = [] 204 | for ann_path in ann_paths: 205 | if not os.path.isabs(ann_path): 206 | ann_path = utils.get_cache_path(ann_path) 207 | abs_ann_paths.append(ann_path) 208 | ann_paths = abs_ann_paths 209 | 210 | # visual data storage path 211 | vis_path = os.path.join(vis_info.storage, split) 212 | 213 | if not os.path.isabs(vis_path): 214 | # vis_path = os.path.join(utils.get_cache_path(), vis_path) 215 | vis_path = utils.get_cache_path(vis_path) 216 | 217 | if not os.path.exists(vis_path): 218 | warnings.warn("storage path {} does not exist.".format(vis_path)) 219 | 220 | # create datasets 221 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls 222 | datasets[split] = dataset_cls( 223 | vis_processor=vis_processor, 224 | text_processor=text_processor, 225 | ann_paths=ann_paths, 226 | vis_root=vis_path, 227 | ) 228 | 229 | return datasets 230 | 231 | 232 | def load_dataset_config(cfg_path): 233 | cfg = OmegaConf.load(cfg_path).datasets 234 | cfg = cfg[list(cfg.keys())[0]] 235 | 236 | return cfg 237 | -------------------------------------------------------------------------------- /timechat/processors/video_processor.py: -------------------------------------------------------------------------------- 1 | """ 
2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import torch 9 | from timechat.common.registry import registry 10 | from decord import VideoReader 11 | import decord 12 | import numpy as np 13 | from timechat.processors import transforms_video 14 | from timechat.processors.base_processor import BaseProcessor 15 | from timechat.processors.randaugment import VideoRandomAugment 16 | from timechat.processors import functional_video as F 17 | from omegaconf import OmegaConf 18 | from torchvision import transforms 19 | import random as rnd 20 | 21 | 22 | MAX_INT = registry.get("MAX_INT") 23 | decord.bridge.set_bridge("torch") 24 | 25 | 26 | def interpolate_frame_pos_embed(frame_pos_embed_ckpt, new_n_frm=96): 27 | # interpolate frame position embedding 28 | # frame_pos_embed_ckpt: (old_n_frm, dim) 29 | frame_pos_embed_ckpt = frame_pos_embed_ckpt.unsqueeze(0).transpose(1, 2) # (1, dim, old_n_frm) 30 | new_frame_pos_embed_ckpt = torch.nn.functional.interpolate(frame_pos_embed_ckpt, size=(new_n_frm), 31 | mode='nearest') # (1, dim, new_n_frm) 32 | new_frame_pos_embed_ckpt = new_frame_pos_embed_ckpt.transpose(1, 2).squeeze(0) # (new_n_frm, dim) 33 | return new_frame_pos_embed_ckpt 34 | 35 | 36 | def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform", return_msg = False): 37 | decord.bridge.set_bridge("torch") 38 | vr = VideoReader(uri=video_path, height=height, width=width) 39 | 40 | vlen = len(vr) 41 | start, end = 0, vlen 42 | acc_samples = min(n_frms, vlen) 43 | n_frms = min(n_frms, vlen) 44 | 45 | if sampling == "uniform": 46 | indices = np.arange(start, end, vlen / n_frms).astype(int).tolist() 47 | elif sampling == "headtail": 48 | indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2)) 49 | indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2)) 50 | indices = indices_h + indices_t 51 | elif sampling == 'rand': 52 | # split the video into `acc_samples` intervals, and sample from each interval. 53 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 54 | ranges = [] 55 | for idx, interv in enumerate(intervals[:-1]): 56 | ranges.append((interv, intervals[idx + 1] - 1)) 57 | try: 58 | indices = [rnd.choice(range(x[0], x[1])) for x in ranges] 59 | except: 60 | indices = np.random.permutation(vlen)[:acc_samples] 61 | indices.sort() 62 | indices = list(indices) 63 | else: 64 | raise NotImplementedError 65 | 66 | # get_batch -> T, H, W, C 67 | temp_frms = vr.get_batch(indices) 68 | # print(type(temp_frms)) 69 | tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms 70 | frms = tensor_frms.permute(3, 0, 1, 2).float() # (C, T, H, W) 71 | 72 | if not return_msg: 73 | return frms 74 | 75 | fps = float(vr.get_avg_fps()) 76 | sec = ", ".join([str(round(f / fps, 1)) for f in indices]) 77 | # " " should be added in the start and end 78 | msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. 
" 79 | return frms, msg 80 | 81 | 82 | class AlproVideoBaseProcessor(BaseProcessor): 83 | def __init__(self, mean=None, std=None, n_frms=MAX_INT): 84 | if mean is None: 85 | mean = (0.48145466, 0.4578275, 0.40821073) 86 | if std is None: 87 | std = (0.26862954, 0.26130258, 0.27577711) 88 | 89 | self.normalize = transforms_video.NormalizeVideo(mean, std) 90 | 91 | self.n_frms = n_frms 92 | 93 | 94 | class ToUint8(object): 95 | def __init__(self): 96 | pass 97 | 98 | def __call__(self, tensor): 99 | return tensor.to(torch.uint8) 100 | 101 | def __repr__(self): 102 | return self.__class__.__name__ 103 | 104 | 105 | class ToTHWC(object): 106 | """ 107 | Args: 108 | clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W) 109 | Return: 110 | clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C) 111 | """ 112 | 113 | def __init__(self): 114 | pass 115 | 116 | def __call__(self, tensor): 117 | return tensor.permute(1, 2, 3, 0) 118 | 119 | def __repr__(self): 120 | return self.__class__.__name__ 121 | 122 | 123 | class ResizeVideo(object): 124 | def __init__(self, target_size, interpolation_mode="bilinear"): 125 | self.target_size = target_size 126 | self.interpolation_mode = interpolation_mode 127 | 128 | def __call__(self, clip): 129 | """ 130 | Args: 131 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 132 | Returns: 133 | torch.tensor: central cropping of video clip. Size is 134 | (C, T, crop_size, crop_size) 135 | """ 136 | return F.resize(clip, self.target_size, self.interpolation_mode) 137 | 138 | def __repr__(self): 139 | return self.__class__.__name__ + "(resize_size={0})".format(self.target_size) 140 | 141 | 142 | @registry.register_processor("alpro_video_train") 143 | class AlproVideoTrainProcessor(AlproVideoBaseProcessor): 144 | def __init__( 145 | self, 146 | image_size=384, 147 | mean=None, 148 | std=None, 149 | min_scale=0.5, 150 | max_scale=1.0, 151 | n_frms=MAX_INT, 152 | ): 153 | super().__init__(mean=mean, std=std, n_frms=n_frms) 154 | 155 | self.image_size = image_size 156 | 157 | self.transform = transforms.Compose( 158 | [ 159 | # Video size is (C, T, H, W) 160 | transforms_video.RandomResizedCropVideo( 161 | image_size, 162 | scale=(min_scale, max_scale), 163 | interpolation_mode="bicubic", 164 | ), 165 | ToTHWC(), # C, T, H, W -> T, H, W, C 166 | ToUint8(), 167 | transforms_video.ToTensorVideo(), # T, H, W, C -> C, T, H, W 168 | self.normalize, 169 | ] 170 | ) 171 | 172 | def __call__(self, vpath): 173 | """ 174 | Args: 175 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 176 | Returns: 177 | torch.tensor: video clip after transforms. Size is (C, T, size, size). 
178 | """ 179 | clip = load_video( 180 | video_path=vpath, 181 | n_frms=self.n_frms, 182 | height=self.image_size, 183 | width=self.image_size, 184 | sampling="headtail", 185 | ) 186 | 187 | return self.transform(clip) 188 | 189 | @classmethod 190 | def from_config(cls, cfg=None): 191 | if cfg is None: 192 | cfg = OmegaConf.create() 193 | 194 | image_size = cfg.get("image_size", 256) 195 | 196 | mean = cfg.get("mean", None) 197 | std = cfg.get("std", None) 198 | 199 | min_scale = cfg.get("min_scale", 0.5) 200 | max_scale = cfg.get("max_scale", 1.0) 201 | 202 | n_frms = cfg.get("n_frms", MAX_INT) 203 | 204 | return cls( 205 | image_size=image_size, 206 | mean=mean, 207 | std=std, 208 | min_scale=min_scale, 209 | max_scale=max_scale, 210 | n_frms=n_frms, 211 | ) 212 | 213 | 214 | @registry.register_processor("alpro_video_eval") 215 | class AlproVideoEvalProcessor(AlproVideoBaseProcessor): 216 | def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT): 217 | super().__init__(mean=mean, std=std, n_frms=n_frms) 218 | 219 | self.image_size = image_size 220 | 221 | # Input video size is (C, T, H, W) 222 | self.transform = transforms.Compose( 223 | [ 224 | # frames will be resized during decord loading. 225 | ToUint8(), # C, T, H, W 226 | ToTHWC(), # T, H, W, C 227 | transforms_video.ToTensorVideo(), # C, T, H, W 228 | self.normalize, # C, T, H, W 229 | ] 230 | ) 231 | 232 | def __call__(self, vpath): 233 | """ 234 | Args: 235 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 236 | Returns: 237 | torch.tensor: video clip after transforms. Size is (C, T, size, size). 238 | """ 239 | clip = load_video( 240 | video_path=vpath, 241 | n_frms=self.n_frms, 242 | height=self.image_size, 243 | width=self.image_size, 244 | ) 245 | 246 | return self.transform(clip) 247 | 248 | @classmethod 249 | def from_config(cls, cfg=None): 250 | if cfg is None: 251 | cfg = OmegaConf.create() 252 | 253 | image_size = cfg.get("image_size", 256) 254 | 255 | mean = cfg.get("mean", None) 256 | std = cfg.get("std", None) 257 | 258 | n_frms = cfg.get("n_frms", MAX_INT) 259 | 260 | return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms) 261 | -------------------------------------------------------------------------------- /utils/cons_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import re 4 | import numpy as np 5 | import argparse 6 | import pandas as pd 7 | import random 8 | import torch 9 | import os 10 | import logging 11 | 12 | 13 | # ANSI escape codes for colors 14 | class Formatter(logging.Formatter): 15 | COLOR_CODES = { 16 | 'DEBUG': '\033[94m', # Blue 17 | 'INFO': '\033[92m', # Green 18 | 'WARNING': '\033[91m', # Red 19 | 'ERROR': '\033[93m', # Yellow 20 | 'CRITICAL': '\033[95m', # Magenta 21 | } 22 | RESET_CODE = '\033[0m' # Reset color 23 | 24 | def format(self, record): 25 | log_color = self.COLOR_CODES.get(record.levelname, self.RESET_CODE) 26 | message = super().format(record) 27 | return f"{log_color}{message}{self.RESET_CODE}" 28 | 29 | 30 | def load_logger(name): 31 | # Custom logger setup 32 | logger = logging.getLogger(name) 33 | logger.setLevel(logging.DEBUG) 34 | 35 | # Console handler with color formatter 36 | console_handler = logging.StreamHandler() 37 | console_handler.setLevel(logging.DEBUG) 38 | 39 | # Define formatter with color 40 | formatter = Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 41 | console_handler.setFormatter(formatter) 42 | 43 | # Add 
handler to logger 44 | logger.addHandler(console_handler) 45 | return logger 46 | 47 | 48 | class BaseOptions(object): 49 | def __init__(self): 50 | self.parser = None 51 | self.initialized = False 52 | self.opt = None 53 | 54 | def initialize(self): 55 | self.initialized = True 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--model_type", type=str, default='TimeChat', 58 | choices=['Video-ChatGPT', 'Video-LLaMA', 'Video-LLaMA2', 'Video-LLaVA', 'VTimeLLM', 'TimeChat', 'VTG-LLM', 'VideoChat2', 'GPT4', 'Gemini'], 59 | help="A list of Video-LLMs.") 60 | parser.add_argument("--dset_name", type=str, default="charades", choices=['activitynet', 'charades'], 61 | help="Dataset name.") 62 | parser.add_argument("--task", type=str, default="consistency", choices=['grounding', 'consistency'], 63 | help="Type of task.") 64 | parser.add_argument("--grounding_prompt", type=int, default=None) 65 | parser.add_argument('--description', action="store_true", 66 | help="Prompt the model to generate a video description before performing target tasks.") 67 | parser.add_argument('--CoT', action="store_true", help="Utilizes Chain-of-Thought Reasoning.") 68 | parser.add_argument('--fine_tuned', action="store_true") 69 | parser.add_argument('--iou_thd', type=float, default=0.5) 70 | parser.add_argument('--no_skip', action="store_true", 71 | help="Test the probes even if the initial prediction is in accurate.") 72 | parser.add_argument('--overwrite', action="store_true") 73 | parser.add_argument("--video_root", type=str, default="/data/video_datasets/", help="path to video files") 74 | parser.add_argument("--output_dir", type=str, default=None, help="path to output files") 75 | parser.add_argument("--exp_id", type=str, default=None, help="ID of this run.") 76 | parser.add_argument("--seed", type=int, default=1000) 77 | parser.add_argument('--debug', action="store_true", help="Debug mode.") 78 | self.parser = parser 79 | 80 | def parse(self): 81 | if not self.initialized: 82 | self.initialize() 83 | 84 | opt = self.parser.parse_args() 85 | 86 | if not opt.video_root: 87 | opt.video_root = f"/data/video_datasets/{opt.dset_name}" 88 | else: 89 | opt.video_root = os.path.join(opt.video_root, opt.dset_name) 90 | 91 | opt.test_path = f"data/{opt.dset_name}_consistency_test.json" 92 | 93 | self.opt = opt 94 | return opt 95 | 96 | 97 | def generate_question(task, prompt, query, duration, st=None, ed=None): 98 | choice = random.choice(["pos", "neg"]) 99 | if st and ed: 100 | st, ed = min(st, duration), min(ed, duration) 101 | 102 | add_detail = prompt["add_detail"] 103 | if task in ["grounding"]: 104 | question = prompt[task].format(event=query) 105 | add_detail = None 106 | 107 | elif task in ["description"]: 108 | question = prompt[task] 109 | add_detail = None 110 | 111 | elif task in ["occurrence"]: 112 | question = random.choice(prompt[choice]).format(event=query, st=st, ed=ed) 113 | 114 | elif task in ["compositional"]: 115 | query = query.replace("?", "") 116 | question = prompt[task].format(question=query, st=st, ed=ed) 117 | 118 | else: 119 | raise NotImplementedError(f"Not implemented task: {task}") 120 | 121 | return question, add_detail, choice 122 | 123 | 124 | def load_jsonl(filename): 125 | with open(filename, "r") as f: 126 | return [json.loads(l.replace("'","").strip("\n")) for l in f.readlines()] 127 | 128 | 129 | def save_jsonl(data, filename): 130 | """data is a list""" 131 | with open(filename, "w") as f: 132 | f.write("\n".join([json.dumps(e) for e in data])) 133 | 134 | 135 | def 
save_json(data, filename, save_pretty=False, sort_keys=False): 136 | with open(filename, "w") as f: 137 | if save_pretty: 138 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) 139 | else: 140 | json.dump(data, f) 141 | 142 | 143 | def load_json(filename): 144 | with open(filename, "r") as f: 145 | return json.load(f) 146 | 147 | 148 | def get_iou(A, B): 149 | try: 150 | max0 = max((A[0]), (B[0])) 151 | min0 = min((A[0]), (B[0])) 152 | max1 = max((A[1]), (B[1])) 153 | min1 = min((A[1]), (B[1])) 154 | 155 | return round(max(min1 - max0, 0) / (max1 - min0), 2) 156 | 157 | except: 158 | return 0 159 | 160 | 161 | def shifting_video_moment(video_features, org_timestamp, new_timestamp, duration): 162 | """ 163 | Shifts frames of a video between the original timestamp and new timestamp. 164 | 165 | video_features: The input video features (either list or torch.Tensor) 166 | org_timestamp: The original start and end time (in seconds) for the part of the video to be moved. 167 | new_timestamp: The new start and end time (in seconds) where the original frames should be shifted. 168 | duration: Total duration of the video (in seconds). 169 | 170 | The format of video_features in Video-LLaVA and TimeChat is list, containing tensor features. 171 | """ 172 | if isinstance(video_features, list): 173 | # Handle list of frames 174 | n_frames = len(video_features) 175 | if not isinstance(video_features[0], torch.Tensor): 176 | _img_embeds = copy.deepcopy(video_features) 177 | org_frame = second_to_frame(n_frames, org_timestamp, duration) 178 | new_frame = second_to_frame(n_frames, new_timestamp, duration) 179 | 180 | # Perform the shift 181 | _img_embeds[org_frame[0]: org_frame[1] + 1] = video_features[new_frame[0]: new_frame[1] + 1] 182 | _img_embeds[new_frame[0]: new_frame[1] + 1] = video_features[org_frame[0]: org_frame[1] + 1] 183 | # print("to", video_features[0].shape, len(video_features[0])) 184 | return _img_embeds 185 | else: 186 | img_embes = video_features[0] 187 | _img_embeds = img_embes.clone() 188 | 189 | org_frame = second_to_frame(n_frames, org_timestamp, duration) 190 | new_frame = second_to_frame(n_frames, new_timestamp, duration) 191 | 192 | # Perform the shift 193 | _img_embeds[org_frame[0]: org_frame[1]+1] = img_embes[new_frame[0]: new_frame[1]+1] 194 | _img_embeds[new_frame[0]: new_frame[1]+1] = img_embes[org_frame[0]: org_frame[1]+1] 195 | # print("to", video_features[0].shape, len(video_features[0])) 196 | return [_img_embeds] 197 | 198 | elif isinstance(video_features, torch.Tensor): 199 | n_frames = video_features.shape[0] 200 | org_frame = second_to_frame(n_frames, org_timestamp, duration) 201 | new_frame = second_to_frame(n_frames, new_timestamp, duration) 202 | 203 | # Calculate the number of frames in each range 204 | org_length = org_frame[1] - org_frame[0] 205 | new_length = new_frame[1] - new_frame[0] 206 | 207 | # Find the minimum length to avoid shape mismatch 208 | min_length = min(org_length, new_length) 209 | 210 | # Extract the original and new frame segments with adjusted lengths 211 | org_frame_feat = video_features[org_frame[0]: org_frame[0] + min_length, :] 212 | new_frame_feat = video_features[new_frame[0]: new_frame[0] + min_length, :] 213 | 214 | # Clone the tensor to avoid in-place overwriting 215 | shifted_video_features = video_features.clone() 216 | 217 | # Perform the swap, making sure both ranges are equal in length 218 | shifted_video_features[org_frame[0]: org_frame[0] + min_length, :] = new_frame_feat 219 | shifted_video_features[new_frame[0]: 
new_frame[0] + min_length, :] = org_frame_feat 220 | 221 | return shifted_video_features 222 | 223 | 224 | def dict_to_markdown(d, max_str_len=120): 225 | # convert list into its str representation 226 | d = {k: v.__repr__() if isinstance(v, list) else v for k, v in d.items()} 227 | # truncate string that is longer than max_str_len 228 | if max_str_len is not None: 229 | d = {k: v[-max_str_len:] if isinstance(v, str) else v for k, v in d.items()} 230 | return pd.DataFrame(d, index=[0]).transpose().to_markdown() 231 | 232 | 233 | def display(args): 234 | opt = args if isinstance(args, dict) else vars(args) 235 | # opt = vars(args) 236 | print(dict_to_markdown(opt, max_str_len=120)) 237 | 238 | 239 | def second_to_frame(n_frame, seconds, duration): 240 | return [int(seconds[0] / duration * n_frame), int(seconds[1] / duration * n_frame)] -------------------------------------------------------------------------------- /timechat/tasks/base_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | import wandb 11 | import torch 12 | import torch.distributed as dist 13 | from timechat.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized 14 | from timechat.common.logger import MetricLogger, SmoothedValue 15 | from timechat.common.registry import registry 16 | from timechat.datasets.data_utils import prepare_sample 17 | 18 | 19 | class BaseTask: 20 | def __init__(self, **kwargs): 21 | super().__init__() 22 | 23 | self.inst_id_key = "instance_id" 24 | 25 | @classmethod 26 | def setup_task(cls, **kwargs): 27 | return cls() 28 | 29 | def build_model(self, cfg): 30 | model_config = cfg.model_cfg 31 | 32 | model_cls = registry.get_model_class(model_config.arch) 33 | return model_cls.from_config(model_config) 34 | 35 | def build_datasets(self, cfg): 36 | """ 37 | Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. 38 | Download dataset and annotations automatically if not exist. 39 | 40 | Args: 41 | cfg (common.config.Config): _description_ 42 | 43 | Returns: 44 | dict: Dictionary of torch.utils.data.Dataset objects by split. 45 | """ 46 | 47 | datasets = dict() 48 | 49 | datasets_config = cfg.datasets_cfg 50 | 51 | assert len(datasets_config) > 0, "At least one dataset has to be specified." 
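        # For each dataset named in the config: look up its registered builder,
        # build the split dict (train/val/test), and tag the training split with
        # its name and optional sample_ratio so weighted sampling (e.g. ChainDataset)
        # can use them later.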
52 | 53 | for name in datasets_config: 54 | dataset_config = datasets_config[name] 55 | 56 | builder = registry.get_builder_class(name)(dataset_config) 57 | dataset = builder.build_datasets() 58 | 59 | dataset['train'].name = name 60 | if 'sample_ratio' in dataset_config: 61 | dataset['train'].sample_ratio = dataset_config.sample_ratio 62 | 63 | datasets[name] = dataset 64 | 65 | return datasets 66 | 67 | def train_step(self, model, samples): 68 | loss = model(samples)["loss"] 69 | return loss 70 | 71 | def valid_step(self, model, samples): 72 | raise NotImplementedError 73 | 74 | def before_evaluation(self, model, dataset, **kwargs): 75 | model.before_evaluation(dataset=dataset, task_type=type(self)) 76 | 77 | def after_evaluation(self, **kwargs): 78 | pass 79 | 80 | def inference_step(self): 81 | raise NotImplementedError 82 | 83 | def evaluation(self, model, data_loader, cuda_enabled=True): 84 | metric_logger = MetricLogger(delimiter=" ") 85 | header = "Evaluation" 86 | # TODO make it configurable 87 | print_freq = 10 88 | 89 | results = [] 90 | 91 | for samples in metric_logger.log_every(data_loader, print_freq, header): 92 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 93 | 94 | eval_output = self.valid_step(model=model, samples=samples) 95 | results.extend(eval_output) 96 | 97 | if is_dist_avail_and_initialized(): 98 | dist.barrier() 99 | 100 | return results 101 | 102 | def train_epoch( 103 | self, 104 | epoch, 105 | model, 106 | data_loader, 107 | optimizer, 108 | lr_scheduler, 109 | scaler=None, 110 | cuda_enabled=False, 111 | log_freq=50, 112 | accum_grad_iters=1, 113 | ): 114 | return self._train_inner_loop( 115 | epoch=epoch, 116 | iters_per_epoch=lr_scheduler.iters_per_epoch, 117 | model=model, 118 | data_loader=data_loader, 119 | optimizer=optimizer, 120 | scaler=scaler, 121 | lr_scheduler=lr_scheduler, 122 | log_freq=log_freq, 123 | cuda_enabled=cuda_enabled, 124 | accum_grad_iters=accum_grad_iters, 125 | ) 126 | 127 | def train_iters( 128 | self, 129 | epoch, 130 | start_iters, 131 | iters_per_inner_epoch, 132 | model, 133 | data_loader, 134 | optimizer, 135 | lr_scheduler, 136 | scaler=None, 137 | cuda_enabled=False, 138 | log_freq=50, 139 | accum_grad_iters=1, 140 | ): 141 | return self._train_inner_loop( 142 | epoch=epoch, 143 | start_iters=start_iters, 144 | iters_per_epoch=iters_per_inner_epoch, 145 | model=model, 146 | data_loader=data_loader, 147 | optimizer=optimizer, 148 | scaler=scaler, 149 | lr_scheduler=lr_scheduler, 150 | log_freq=log_freq, 151 | cuda_enabled=cuda_enabled, 152 | accum_grad_iters=accum_grad_iters, 153 | ) 154 | 155 | def _train_inner_loop( 156 | self, 157 | epoch, 158 | iters_per_epoch, 159 | model, 160 | data_loader, 161 | optimizer, 162 | lr_scheduler, 163 | scaler=None, 164 | start_iters=None, 165 | log_freq=50, 166 | cuda_enabled=False, 167 | accum_grad_iters=1, 168 | ): 169 | """ 170 | An inner training loop compatible with both epoch-based and iter-based training. 171 | 172 | When using epoch-based, training stops after one epoch; when using iter-based, 173 | training stops after #iters_per_epoch iterations. 
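        In the iter-based case, ``start_iters`` determines the inner epoch used
        for learning-rate scheduling.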
174 | """ 175 | use_amp = scaler is not None 176 | 177 | if not hasattr(data_loader, "__next__"): 178 | # convert to iterator if not already 179 | data_loader = iter(data_loader) 180 | 181 | metric_logger = MetricLogger(delimiter=" ") 182 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 183 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) 184 | 185 | # if iter-based runner, schedule lr based on inner epoch. 186 | logging.info( 187 | "Start training epoch {}, {} iters per inner epoch.".format( 188 | epoch, iters_per_epoch 189 | ) 190 | ) 191 | header = "Train: data epoch: [{}]".format(epoch) 192 | if start_iters is None: 193 | # epoch-based runner 194 | inner_epoch = epoch 195 | else: 196 | # In iter-based runner, we schedule the learning rate based on iterations. 197 | inner_epoch = start_iters // iters_per_epoch 198 | header = header + "; inner epoch [{}]".format(inner_epoch) 199 | 200 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): 201 | # if using iter-based runner, we stop after iters_per_epoch iterations. 202 | if i >= iters_per_epoch: 203 | break 204 | 205 | samples = next(data_loader) 206 | 207 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 208 | samples.update( 209 | { 210 | "epoch": inner_epoch, 211 | "num_iters_per_epoch": iters_per_epoch, 212 | "iters": i, 213 | } 214 | ) 215 | 216 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) 217 | 218 | with torch.cuda.amp.autocast(enabled=use_amp): 219 | loss = self.train_step(model=model, samples=samples) 220 | 221 | # after_train_step() 222 | if use_amp: 223 | scaler.scale(loss).backward() 224 | else: 225 | loss.backward() 226 | 227 | # update gradients every accum_grad_iters iterations 228 | if (i + 1) % accum_grad_iters == 0: 229 | if use_amp: 230 | scaler.step(optimizer) 231 | scaler.update() 232 | else: 233 | optimizer.step() 234 | optimizer.zero_grad() 235 | 236 | metric_logger.update(loss=loss.item()) 237 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 238 | 239 | if is_main_process() and wandb.run is not None: 240 | wandb.log({'train/loss': loss.item(), 241 | 'train/lr': optimizer.param_groups[0]["lr"]}, 242 | step=epoch * iters_per_epoch + i) 243 | 244 | # after train_epoch() 245 | # gather the stats from all processes 246 | metric_logger.synchronize_between_processes() 247 | logging.info("Averaged stats: " + str(metric_logger.global_avg())) 248 | return { 249 | k: "{:.3f}".format(meter.global_avg) 250 | for k, meter in metric_logger.meters.items() 251 | } 252 | 253 | @staticmethod 254 | def save_result(result, result_dir, filename, remove_duplicate=""): 255 | import json 256 | 257 | result_file = os.path.join( 258 | result_dir, "%s_rank%d.json" % (filename, get_rank()) 259 | ) 260 | final_result_file = os.path.join(result_dir, "%s.json" % filename) 261 | 262 | json.dump(result, open(result_file, "w")) 263 | 264 | if is_dist_avail_and_initialized(): 265 | dist.barrier() 266 | 267 | if is_main_process(): 268 | logging.warning("rank %d starts merging results." 
% get_rank()) 269 | # combine results from all processes 270 | result = [] 271 | 272 | for rank in range(get_world_size()): 273 | result_file = os.path.join( 274 | result_dir, "%s_rank%d.json" % (filename, rank) 275 | ) 276 | res = json.load(open(result_file, "r")) 277 | result += res 278 | 279 | if remove_duplicate: 280 | result_new = [] 281 | id_list = [] 282 | for res in result: 283 | if res[remove_duplicate] not in id_list: 284 | id_list.append(res[remove_duplicate]) 285 | result_new.append(res) 286 | result = result_new 287 | 288 | json.dump(result, open(final_result_file, "w")) 289 | print("result file saved to %s" % final_result_file) 290 | 291 | return final_result_file 292 | -------------------------------------------------------------------------------- /timechat/common/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | class Registry: 10 | mapping = { 11 | "builder_name_mapping": {}, 12 | "task_name_mapping": {}, 13 | "processor_name_mapping": {}, 14 | "model_name_mapping": {}, 15 | "lr_scheduler_name_mapping": {}, 16 | "runner_name_mapping": {}, 17 | "state": {}, 18 | "paths": {}, 19 | } 20 | 21 | @classmethod 22 | def register_builder(cls, name): 23 | r"""Register a dataset builder to registry with key 'name' 24 | 25 | Args: 26 | name: Key with which the builder will be registered. 27 | 28 | Usage: 29 | 30 | from timechat.common.registry import registry 31 | from timechat.datasets.base_dataset_builder import BaseDatasetBuilder 32 | """ 33 | 34 | def wrap(builder_cls): 35 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder 36 | 37 | assert issubclass( 38 | builder_cls, BaseDatasetBuilder 39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 40 | builder_cls 41 | ) 42 | if name in cls.mapping["builder_name_mapping"]: 43 | raise KeyError( 44 | "Name '{}' already registered for {}.".format( 45 | name, cls.mapping["builder_name_mapping"][name] 46 | ) 47 | ) 48 | cls.mapping["builder_name_mapping"][name] = builder_cls 49 | return builder_cls 50 | 51 | return wrap 52 | 53 | @classmethod 54 | def register_task(cls, name): 55 | r"""Register a task to registry with key 'name' 56 | 57 | Args: 58 | name: Key with which the task will be registered. 59 | 60 | Usage: 61 | 62 | from timechat.common.registry import registry 63 | """ 64 | 65 | def wrap(task_cls): 66 | from timechat.tasks.base_task import BaseTask 67 | 68 | assert issubclass( 69 | task_cls, BaseTask 70 | ), "All tasks must inherit BaseTask class" 71 | if name in cls.mapping["task_name_mapping"]: 72 | raise KeyError( 73 | "Name '{}' already registered for {}.".format( 74 | name, cls.mapping["task_name_mapping"][name] 75 | ) 76 | ) 77 | cls.mapping["task_name_mapping"][name] = task_cls 78 | return task_cls 79 | 80 | return wrap 81 | 82 | @classmethod 83 | def register_model(cls, name): 84 | r"""Register a task to registry with key 'name' 85 | 86 | Args: 87 | name: Key with which the task will be registered. 
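        Example (illustrative; ``"my_model"`` and ``MyModel`` are placeholder names):

            >>> @registry.register_model("my_model")
            ... class MyModel(BaseModel):
            ...     PRETRAINED_MODEL_CONFIG_DICT = {}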
88 | 89 | Usage: 90 | 91 | from timechat.common.registry import registry 92 | """ 93 | 94 | def wrap(model_cls): 95 | from timechat.models import BaseModel 96 | 97 | assert issubclass( 98 | model_cls, BaseModel 99 | ), "All models must inherit BaseModel class" 100 | if name in cls.mapping["model_name_mapping"]: 101 | raise KeyError( 102 | "Name '{}' already registered for {}.".format( 103 | name, cls.mapping["model_name_mapping"][name] 104 | ) 105 | ) 106 | cls.mapping["model_name_mapping"][name] = model_cls 107 | return model_cls 108 | 109 | return wrap 110 | 111 | @classmethod 112 | def register_processor(cls, name): 113 | r"""Register a processor to registry with key 'name' 114 | 115 | Args: 116 | name: Key with which the task will be registered. 117 | 118 | Usage: 119 | 120 | from timechat.common.registry import registry 121 | """ 122 | 123 | def wrap(processor_cls): 124 | from timechat.processors import BaseProcessor 125 | 126 | assert issubclass( 127 | processor_cls, BaseProcessor 128 | ), "All processors must inherit BaseProcessor class" 129 | if name in cls.mapping["processor_name_mapping"]: 130 | raise KeyError( 131 | "Name '{}' already registered for {}.".format( 132 | name, cls.mapping["processor_name_mapping"][name] 133 | ) 134 | ) 135 | cls.mapping["processor_name_mapping"][name] = processor_cls 136 | return processor_cls 137 | 138 | return wrap 139 | 140 | @classmethod 141 | def register_lr_scheduler(cls, name): 142 | r"""Register a model to registry with key 'name' 143 | 144 | Args: 145 | name: Key with which the task will be registered. 146 | 147 | Usage: 148 | 149 | from timechat.common.registry import registry 150 | """ 151 | 152 | def wrap(lr_sched_cls): 153 | if name in cls.mapping["lr_scheduler_name_mapping"]: 154 | raise KeyError( 155 | "Name '{}' already registered for {}.".format( 156 | name, cls.mapping["lr_scheduler_name_mapping"][name] 157 | ) 158 | ) 159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls 160 | return lr_sched_cls 161 | 162 | return wrap 163 | 164 | @classmethod 165 | def register_runner(cls, name): 166 | r"""Register a model to registry with key 'name' 167 | 168 | Args: 169 | name: Key with which the task will be registered. 170 | 171 | Usage: 172 | 173 | from timechat.common.registry import registry 174 | """ 175 | 176 | def wrap(runner_cls): 177 | if name in cls.mapping["runner_name_mapping"]: 178 | raise KeyError( 179 | "Name '{}' already registered for {}.".format( 180 | name, cls.mapping["runner_name_mapping"][name] 181 | ) 182 | ) 183 | cls.mapping["runner_name_mapping"][name] = runner_cls 184 | return runner_cls 185 | 186 | return wrap 187 | 188 | @classmethod 189 | def register_path(cls, name, path): 190 | r"""Register a path to registry with key 'name' 191 | 192 | Args: 193 | name: Key with which the path will be registered. 194 | 195 | Usage: 196 | 197 | from timechat.common.registry import registry 198 | """ 199 | assert isinstance(path, str), "All path must be str." 200 | if name in cls.mapping["paths"]: 201 | raise KeyError("Name '{}' already registered.".format(name)) 202 | cls.mapping["paths"][name] = path 203 | 204 | @classmethod 205 | def register(cls, name, obj): 206 | r"""Register an item to registry with key 'name' 207 | 208 | Args: 209 | name: Key with which the item will be registered. 
210 | 211 | Usage:: 212 | 213 | from timechat.common.registry import registry 214 | 215 | registry.register("config", {}) 216 | """ 217 | path = name.split(".") 218 | current = cls.mapping["state"] 219 | 220 | for part in path[:-1]: 221 | if part not in current: 222 | current[part] = {} 223 | current = current[part] 224 | 225 | current[path[-1]] = obj 226 | 227 | # @classmethod 228 | # def get_trainer_class(cls, name): 229 | # return cls.mapping["trainer_name_mapping"].get(name, None) 230 | 231 | @classmethod 232 | def get_builder_class(cls, name): 233 | return cls.mapping["builder_name_mapping"].get(name, None) 234 | 235 | @classmethod 236 | def get_model_class(cls, name): 237 | return cls.mapping["model_name_mapping"].get(name, None) 238 | 239 | @classmethod 240 | def get_task_class(cls, name): 241 | return cls.mapping["task_name_mapping"].get(name, None) 242 | 243 | @classmethod 244 | def get_processor_class(cls, name): 245 | return cls.mapping["processor_name_mapping"].get(name, None) 246 | 247 | @classmethod 248 | def get_lr_scheduler_class(cls, name): 249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None) 250 | 251 | @classmethod 252 | def get_runner_class(cls, name): 253 | return cls.mapping["runner_name_mapping"].get(name, None) 254 | 255 | @classmethod 256 | def list_runners(cls): 257 | return sorted(cls.mapping["runner_name_mapping"].keys()) 258 | 259 | @classmethod 260 | def list_models(cls): 261 | return sorted(cls.mapping["model_name_mapping"].keys()) 262 | 263 | @classmethod 264 | def list_tasks(cls): 265 | return sorted(cls.mapping["task_name_mapping"].keys()) 266 | 267 | @classmethod 268 | def list_processors(cls): 269 | return sorted(cls.mapping["processor_name_mapping"].keys()) 270 | 271 | @classmethod 272 | def list_lr_schedulers(cls): 273 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) 274 | 275 | @classmethod 276 | def list_datasets(cls): 277 | return sorted(cls.mapping["builder_name_mapping"].keys()) 278 | 279 | @classmethod 280 | def get_path(cls, name): 281 | return cls.mapping["paths"].get(name, None) 282 | 283 | @classmethod 284 | def get(cls, name, default=None, no_warning=False): 285 | r"""Get an item from registry with key 'name' 286 | 287 | Args: 288 | name (string): Key whose value needs to be retrieved. 289 | default: If passed and key is not in registry, default value will 290 | be returned with a warning. Default: None 291 | no_warning (bool): If passed as True, warning when key doesn't exist 292 | will not be generated. Useful for MMF's 293 | internal operations. Default: False 294 | """ 295 | original_name = name 296 | name = name.split(".") 297 | value = cls.mapping["state"] 298 | for subname in name: 299 | value = value.get(subname, default) 300 | if value is default: 301 | break 302 | 303 | if ( 304 | "writer" in cls.mapping["state"] 305 | and value == default 306 | and no_warning is False 307 | ): 308 | cls.mapping["state"]["writer"].warning( 309 | "Key {} is not present in registry, returning default value " 310 | "of {}".format(original_name, default) 311 | ) 312 | return value 313 | 314 | @classmethod 315 | def unregister(cls, name): 316 | r"""Remove an item from registry with key 'name' 317 | 318 | Args: 319 | name: Key which needs to be removed. 
320 | Usage:: 321 | 322 | from mmf.common.registry import registry 323 | 324 | config = registry.unregister("config") 325 | """ 326 | return cls.mapping["state"].pop(name, None) 327 | 328 | 329 | registry = Registry() 330 | -------------------------------------------------------------------------------- /timechat/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import json 5 | import numpy as np 6 | import torch 7 | import decord 8 | import time 9 | import subprocess 10 | import re 11 | from utils.cons_utils import BaseOptions, generate_question, load_logger 12 | from utils.prompts import prompt, cot 13 | 14 | from timechat.common.config import Config 15 | from timechat.common.registry import registry 16 | from timechat.conversation.conversation_video import Chat, conv_llava_llama_2 17 | from timechat.processors.video_processor import load_video 18 | decord.bridge.set_bridge('torch') 19 | logger = load_logger("TimeChat") 20 | 21 | 22 | class TimeChat_Options(BaseOptions): 23 | def initialize(self): 24 | BaseOptions.initialize(self) 25 | self.parser.add_argument("--cfg-path", default='timechat/eval_configs/timechat.yaml', help="path to configuration file.") 26 | self.parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") 27 | self.parser.add_argument("--options", nargs="+", 28 | help="override some settings in the used config, " 29 | "the key-value pair in xxx=yyy format will be merged into config file (deprecate), " 30 | "change to --cfg-options instead.", 31 | ) 32 | 33 | 34 | class TimeChat: 35 | def __init__(self, args): 36 | """ 37 | Follow the official grounding prompt at https://github.com/RenShuhuai-Andy/TimeChat 38 | """ 39 | 40 | cfg = Config(args) 41 | if args.fine_tuned: 42 | prompt["grounding"] = "Localize the visual content described by the given textual query '{event}' in the video, and output the start and end timestamps in seconds." 43 | cfg.model_cfg.ckpt = cfg.model_cfg.charades_ckpt if args.dset_name == "charades" else cfg.model_cfg.activitynet_ckpt 44 | if not os.path.exists(cfg.model_cfg.ckpt): 45 | raise FileNotFoundError(f"Check the checkpoint path: {cfg.model_cfg.ckpt}") 46 | else: 47 | prompt["grounding"] = "Please find the visual event described by a sentence in the video, determining its starting and ending times. The format should be: 'The event happens in the start time - end time'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seconds. Now I will give you the textual sentence: {event}. Please return its start time and end time." 
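        # Example end-to-end usage of this wrapper (illustrative sketch only: it assumes
        # BaseOptions exposes a parse() method, and the video path and query below are
        # hypothetical placeholders):
        #
        #   args = TimeChat_Options().parse()
        #   tc = TimeChat(args)
        #   feats, msg = tc.load_video_features("data/Charades/videos/XXXXX.mp4")
        #   out = tc.run("grounding", feats, query="person turns on a light",
        #                duration=30.0, msg=msg)
        #   print(out["t"])  # -> [start_second, end_second]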
48 |         logger.info(f"Load checkpoints from '{cfg.model_cfg.ckpt}'")
49 |
50 |         self.model_config = cfg.model_cfg
51 |         self.gpu_id = args.gpu_id
52 |         self.model, self.vis_processor = self.load_model(cfg)
53 |         self.chat = Chat(self.model, self.vis_processor, device='cuda:{}'.format(args.gpu_id))
54 |         self.debug = args.debug
55 |         self.n_frames = 96
56 |         self.CoT = args.CoT
57 |
58 |     def load_model(self, cfg):
59 |         model_config = cfg.model_cfg
60 |         model_config.device_8bit = self.gpu_id
61 |         model_cls = registry.get_model_class(model_config.arch)
62 |         model = model_cls.from_config(model_config).to('cuda:{}'.format(self.gpu_id))
63 |         model.eval()
64 |
65 |         vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
66 |         vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
67 |
68 |         return model, vis_processor
69 |
70 |     def load_video_features(self, video_path):
71 |         video_features = []
72 |         video, msg = load_video(
73 |             video_path=video_path,
74 |             n_frms=self.n_frames,
75 |             height=224,
76 |             width=224,
77 |             sampling="uniform",
78 |             return_msg=True
79 |         )
80 |         video = self.vis_processor.transform(video)
81 |         video = video.unsqueeze(0).to(self.gpu_id)
82 |
83 |         if self.model.qformer_text_input:
84 |             # timestamp
85 |             timestamps = msg.split('at')[1].replace('seconds.', '').strip().split(',')  # extract timestamps from msg
86 |             timestamps = [f'This frame is sampled at {t.strip()} second.' for t in timestamps]
87 |             timestamps = self.model.tokenizer(
88 |                 timestamps,
89 |                 return_tensors="pt",
90 |                 padding="longest",
91 |                 max_length=32,
92 |                 truncation=True,
93 |             )
94 |
95 |         if self.model.qformer_text_input:
96 |             image_emb, _ = self.model.encode_videoQformer_visual(video, timestamp=timestamps)
97 |         else:
98 |             image_emb, _ = self.model.encode_videoQformer_visual(video)
99 |         video_features.append(image_emb)
100 |
101 |         return video_features, msg
102 |
103 |     def inference(self, chat, chat_state, video_features):
104 |         llm_message = chat.answer(conv=chat_state,
105 |                                   img_list=video_features,
106 |                                   num_beams=1,
107 |                                   do_sample=False,
108 |                                   temperature=0.05,
109 |                                   max_new_tokens=300,
110 |                                   max_length=2000)[0]
111 |         chat_state.messages[-1][-1] = llm_message
112 |
113 |         return llm_message
114 |
115 |     def extract_time(self, paragraph):
116 |         prompt = 'A specific example is : 20.8 - 30.0 seconds'.lower()
117 |         paragraph = paragraph.lower()
118 |         paragraph = paragraph.replace(prompt, '')
119 |         # Split the text into sentences based on common delimiters
120 |         sentences = re.split(r'[!?\n]', paragraph)
121 |
122 |         # Keywords that might indicate the presence of time information
123 |         keywords = ["starts", "ends", "happens in", "start time", "end time", "start", "end", "happen"]
124 |         # Filter sentences by keywords
125 |         candidates = []
126 |         for sentence in sentences:
127 |             # If the sentence contains one of the keywords
128 |             if any(keyword in sentence for keyword in keywords):
129 |                 candidates.append(sentence)
130 |
131 |         timestamps = []
132 |         # Check for "The given query happens in m - n (seconds)"
133 |         patterns = [
134 |             r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)"
135 |         ]
136 |
137 |         for time_pattern in patterns:
138 |             time_matches = re.findall(time_pattern, paragraph)
139 |             if time_matches:
140 |                 timestamps = [[float(start), float(end)] for start, end in time_matches]
141 |
142 |         if len(sentences) == 0:
143 |             return []
144 |         # Check for other formats, e.g.:
145 |         # 1. Starting time: 0.8 seconds
146 |         #    Ending time: 1.1 seconds
147 |         # 2. The start time for this event is 0 seconds, and the end time is 12 seconds.
148 |         if len(timestamps) == 0:
149 |             times = []
150 |             time_regex = re.compile(r'\b(\d+\.\d+\b|\b\d+)\b')  # time formats (e.g., 18, 18.5)
151 |             for sentence in candidates:
152 |                 time = re.findall(time_regex, sentence)
153 |                 if time:
154 |                     time_in_sec = float(time[0])
155 |                     times.append(time_in_sec)
156 |             times = times[:len(times) // 2 * 2]
157 |             timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)]
158 |         # Check for examples like:
159 |         # 3. The event 'person flipped the light switch near the door' starts at 00:00:18 and ends at 00:00:23.
160 |         if len(timestamps) == 0:
161 |             times = []
162 |             time_regex = re.compile(r'\b(\d{1,2}:\d{2}(?::\d{2})?)\b')  # time formats (e.g., 00:18, 00:00:18)
163 |             for sentence in candidates:
164 |                 time = re.findall(time_regex, sentence)
165 |                 if time:
166 |                     t = time[0]
167 |                 else:
168 |                     continue
169 |                 # If the time is in HH:MM:SS format, convert it to seconds
170 |                 if t.count(':') == 2:
171 |                     h, m, s = map(int, t.split(':'))
172 |                     time_in_sec = h * 3600 + m * 60 + s
173 |                 elif t.count(':') == 1:
174 |                     m, s = map(int, t.split(':'))
175 |                     time_in_sec = m * 60 + s
176 |                 times.append(time_in_sec)
177 |             times = times[:len(times) // 2 * 2]
178 |             timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)]
179 |
180 |         results = []
181 |         for (start, end) in timestamps:
182 |             if end > start:
183 |                 results.append([start, end])
184 |             else:
185 |                 results.append([end, start])
186 |
187 |         if len(results) == 0:
188 |             return [0, 0]
189 |         else:
190 |             return results[0]
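    # Illustrative behaviour of extract_time on typical (hypothetical) model outputs:
    #   extract_time("The event happens in the 24.3 - 30.4 seconds.")        -> [24.3, 30.4]
    #   extract_time("Starting time: 0.8 seconds\nEnding time: 1.1 seconds") -> [0.8, 1.1]
    #   extract_time("Sorry, I cannot locate the moment.")                   -> [0, 0]
    # When [0, 0] is returned, run() falls back to extract_time2 below.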
191 |
192 |     def extract_time2(self, paragraph):
193 |         pattern = r'(?:(?:from\s+|in\s+|between\s+|at\s+|—|to\s+|and\s+|–)\s*)?(\d+(?:\.\d+)?)\s*(?:to|and|–|—)\s*(?:(?:from\s+|in\s+|between\s+|at\s+|—|to\s+|and\s+|–)\s*)?(\d+(?:\.\d+)?)?\s*seconds?'
194 |         match = re.search(pattern, paragraph)
195 |
196 |         if match:
197 |             start_timestamp = float(match.group(1))
198 |             end_timestamp = match.group(2)
199 |
200 |             # If only the start timestamp is found
201 |             if end_timestamp is None:
202 |                 return [0, start_timestamp]
203 |             else:
204 |                 return [start_timestamp, float(end_timestamp)]
205 |
206 |         else:
207 |             pattern = r'\d+\.\d+|\d+'
208 |
209 |             # Find all matches in the input string
210 |             matches = re.findall(pattern, paragraph)
211 |
212 |             # Convert matched strings to float
213 |             float_numbers = [float(match) for match in matches]
214 |             if len(float_numbers) == 0:
215 |                 return [0, 0]
216 |             elif len(float_numbers) == 1:
217 |                 return [0, float_numbers[0]]
218 |             else:
219 |                 return float_numbers[:2]
220 |
221 |     def initialize_chat(self, task, msg, add_detail=None):
222 |         chat_state = conv_llava_llama_2.copy()
223 |         chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
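        # Note: this base system prompt may be extended with task-specific detail below,
        # or replaced by the chain-of-thought prompt cot[task] when self.CoT is set.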
224 |
225 |         if add_detail:
226 |             chat_state.system += (" " + add_detail)
227 |         if self.CoT:
228 |             chat_state.system = cot[task] if task in cot else chat_state.system
229 |
230 |         chat_state.append_message(chat_state.roles[0], " " + msg)
231 |
232 |         return chat_state
233 |
234 |     def run(self, task, video_features, query, duration, chat_state=None, st=None, ed=None, msg=None, return_chat_state=False):
235 |         question, add_detail, choice = generate_question(task, prompt, query, duration, st, ed)
236 |         if not chat_state:
237 |             chat_state = self.initialize_chat(task, msg, add_detail=add_detail)
238 |         else:
239 |             chat_state = chat_state.copy()
240 |
241 |         question = " ".join([question, add_detail]) if add_detail else question
242 |         self.chat.ask(question, chat_state)
243 |         answer = self.inference(self.chat, chat_state, video_features)
244 |
245 |         if self.debug:
246 |             print(f"[{task}] Question: {question}")
247 |             print(f"Answer: {answer}\n")
248 |
249 |         if task in ["grounding"]:
250 |             timestamps = self.extract_time(answer)
251 |             if timestamps == [0, 0]:
252 |                 timestamps = self.extract_time2(answer)
253 |             return {"q": question, "a": answer, "t": timestamps}
254 |
255 |         elif task in ["occurrence", "compositional"]:
256 |             if task == "compositional":
257 |                 choice = "pos"
258 |             return {"c": choice, "q": question, "a": answer}
259 |
260 |         elif task in ["description"]:
261 |             if return_chat_state:
262 |                 return answer, chat_state
263 |             return answer
264 |         else:
265 |             raise NotImplementedError(f"Task {task} is not yet implemented.")
266 |
--------------------------------------------------------------------------------
/timechat/datasets/datasets/llava_instruct_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from timechat.datasets.datasets.base_dataset import BaseDataset
3 | from timechat.datasets.datasets.caption_datasets import CaptionDataset
4 | import pandas as pd
5 | import decord
6 | from decord import VideoReader
7 | import random
8 | import torch
9 | from torch.utils.data.dataloader import default_collate
10 | from PIL import Image
11 | from typing import Dict, Optional, Sequence
12 | import transformers
13 | import pathlib
14 | import json
15 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
16 | from timechat.conversation.conversation_video import Conversation, SeparatorStyle
17 | DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
18 | DEFAULT_IMAGE_TOKEN = "<image>"
19 | import copy
20 | from timechat.processors import transforms_video, AlproVideoTrainProcessor
21 | IGNORE_INDEX = -100
22 | image_conversation = Conversation(
23 |     system="",
24 |     roles=("Human", "Assistant"),
25 |     messages=[],
26 |     offset=0,
27 |     sep_style=SeparatorStyle.SINGLE,
28 |     sep="###",
29 | )
30 | llama_v2_image_conversation = Conversation(
31 |     system=" ",
32 |     roles=("USER", "ASSISTANT"),
33 |     messages=(),
34 |     offset=0,
35 |     sep_style=SeparatorStyle.LLAMA_2,
36 |     sep="<s>",
37 |     sep2="</s>",
38 | )
39 | IGNORE_INDEX = -100
40 |
41 | class Instruct_Dataset(BaseDataset):
42 |     def __init__(self, vis_processor, text_processor, vis_root, ann_root, num_video_query_token=32, tokenizer_name='/mnt/workspace/ckpt/vicuna-13b/', data_type='image', model_type='vicuna'):
43 |         """
44 |         vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
45 |         ann_root (string): Root directory of video (e.g. 
webvid_eval/annotations/) 46 | split (string): val or test 47 | """ 48 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 49 | 50 | data_path = pathlib.Path(ann_root) 51 | with data_path.open(encoding='utf-8') as f: 52 | self.annotation = json.load(f) 53 | 54 | self.vis_root = vis_root 55 | self.resize_size = 224 56 | self.num_frm = 8 57 | self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False) 58 | self.tokenizer.pad_token = self.tokenizer.unk_token 59 | self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 60 | self.num_video_query_token = num_video_query_token 61 | self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN] 62 | 63 | self.transform = AlproVideoTrainProcessor( 64 | image_size=self.resize_size, n_frms = self.num_frm 65 | ).transform 66 | self.data_type = data_type 67 | self.model_type = model_type 68 | 69 | def _get_image_path(self, sample): 70 | rel_video_fp ='COCO_train2014_' + sample['image'] 71 | full_video_fp = os.path.join(self.vis_root, rel_video_fp) 72 | return full_video_fp 73 | 74 | def __getitem__(self, index): 75 | num_retries = 10 # skip error videos 76 | for _ in range(num_retries): 77 | try: 78 | sample = self.annotation[index] 79 | 80 | image_path = self._get_image_path(sample) 81 | conversation_list = sample['conversations'] 82 | image = Image.open(image_path).convert("RGB") 83 | 84 | image = self.vis_processor(image) 85 | # text = self.text_processor(text) 86 | sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token) 87 | if self.model_type =='vicuna': 88 | data_dict = preprocess( 89 | sources, 90 | self.tokenizer) 91 | elif self.model_type =='llama_v2': 92 | data_dict = preprocess_for_llama_v2( 93 | sources, 94 | self.tokenizer) 95 | else: 96 | print('not support') 97 | raise('not support') 98 | data_dict = dict(input_ids=data_dict["input_ids"][0], 99 | labels=data_dict["labels"][0]) 100 | 101 | # image exist in the data 102 | data_dict['image'] = image 103 | except: 104 | print(f"Failed to load examples with image: {image_path}. 
" 105 | f"Will randomly sample an example as a replacement.") 106 | index = random.randint(0, len(self) - 1) 107 | continue 108 | break 109 | else: 110 | raise RuntimeError(f"Failed to fetch image after {num_retries} retries.") 111 | # "image_id" is kept to stay compatible with the COCO evaluation format 112 | return { 113 | "image": image, 114 | "text_input": data_dict["input_ids"], 115 | "labels": data_dict["labels"], 116 | "type":'image', 117 | } 118 | 119 | def __len__(self): 120 | return len(self.annotation) 121 | 122 | def collater(self, instances): 123 | input_ids, labels = tuple([instance[key] for instance in instances] 124 | for key in ("text_input", "labels")) 125 | input_ids = torch.nn.utils.rnn.pad_sequence( 126 | input_ids, 127 | batch_first=True, 128 | padding_value=self.tokenizer.pad_token_id) 129 | labels = torch.nn.utils.rnn.pad_sequence(labels, 130 | batch_first=True, 131 | padding_value=IGNORE_INDEX) 132 | batch = dict( 133 | input_ids=input_ids, 134 | labels=labels, 135 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 136 | ) 137 | 138 | if 'image' in instances[0]: 139 | images = [instance['image'] for instance in instances] 140 | if all(x is not None and x.shape == images[0].shape for x in images): 141 | batch['images'] = torch.stack(images) 142 | else: 143 | batch['images'] = images 144 | batch['conv_type'] = 'multi' 145 | return batch 146 | 147 | 148 | def preprocess_multimodal( 149 | conversation_list: Sequence[str], 150 | multimodal_cfg: dict, 151 | cur_token_len: int, 152 | ) -> Dict: 153 | # 将conversational list中 154 | is_multimodal = True 155 | # image_token_len = multimodal_cfg['image_token_len'] 156 | image_token_len = cur_token_len 157 | 158 | for sentence in conversation_list: 159 | replace_token = ''+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len+'' 160 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) 161 | 162 | return [conversation_list] 163 | 164 | def _add_speaker_and_signal(header, source, get_conversation=True): 165 | """Add speaker and start/end signal on each round.""" 166 | BEGIN_SIGNAL = "###" 167 | END_SIGNAL = "\n" 168 | conversation = header 169 | for sentence in source: 170 | from_str = sentence["from"] 171 | if from_str.lower() == "human": 172 | from_str = image_conversation.roles[0] 173 | elif from_str.lower() == "gpt": 174 | from_str = image_conversation.roles[1] 175 | else: 176 | from_str = 'unknown' 177 | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + 178 | sentence["value"] + END_SIGNAL) 179 | if get_conversation: 180 | conversation += sentence["value"] 181 | conversation += BEGIN_SIGNAL 182 | return conversation 183 | 184 | def _tokenize_fn(strings: Sequence[str], 185 | tokenizer: transformers.PreTrainedTokenizer) -> Dict: 186 | """Tokenize a list of strings.""" 187 | tokenized_list = [ 188 | tokenizer( 189 | text, 190 | return_tensors="pt", 191 | padding="longest", 192 | max_length=512, 193 | truncation=True, 194 | ) for text in strings 195 | ] 196 | input_ids = labels = [ 197 | tokenized.input_ids[0] for tokenized in tokenized_list 198 | ] 199 | input_ids_lens = labels_lens = [ 200 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() 201 | for tokenized in tokenized_list 202 | ] 203 | return dict( 204 | input_ids=input_ids, 205 | labels=labels, 206 | input_ids_lens=input_ids_lens, 207 | labels_lens=labels_lens, 208 | ) 209 | 210 | def preprocess( 211 | sources: Sequence[str], 212 | tokenizer: transformers.PreTrainedTokenizer, 213 | ) -> Dict: 214 | """ 215 | Given a list of 
210 | def preprocess(
211 |     sources: Sequence[str],
212 |     tokenizer: transformers.PreTrainedTokenizer,
213 | ) -> Dict:
214 |     """
215 |     Given a list of sources, each of which is a conversation list, this transform:
216 |     1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
217 |     2. Concatenates the conversations together;
218 |     3. Tokenizes the concatenated conversation;
219 |     4. Makes a deepcopy as the target and masks the human words with IGNORE_INDEX.
220 |     """
221 |     # add end signal and concatenate together
222 |     conversations = []
223 |     for source in sources:
224 |         header = f"{image_conversation.system}\n\n"
225 |         conversation = _add_speaker_and_signal(header, source)
226 |         conversations.append(conversation)
227 |     # tokenize conversations
228 |     conversations_tokenized = _tokenize_fn(conversations, tokenizer)
229 |     input_ids = conversations_tokenized["input_ids"]
230 |     targets = copy.deepcopy(input_ids)
231 |     for target, source in zip(targets, sources):
232 |         tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
233 |                                       tokenizer)["input_ids_lens"]
234 |         speakers = [sentence["from"] for sentence in source]
235 |         _mask_targets(target, tokenized_lens, speakers)
236 |
237 |     return dict(input_ids=input_ids, labels=targets)
238 |
239 | def preprocess_for_llama_v2(
240 |     sources: Sequence[str],
241 |     tokenizer: transformers.PreTrainedTokenizer,
242 | ) -> Dict:
243 |     """
244 |     Given a list of sources, each of which is a conversation list, this transform:
245 |     1. Formats each conversation with the LLaMA-2 chat template ('[INST] ... [/INST]');
246 |     2. Tokenizes the formatted conversations;
247 |     3. Makes a deepcopy as the target;
248 |     4. Masks everything except the assistant responses with IGNORE_INDEX.
249 |     """
250 |     # add end signal and concatenate together
251 |     conversations = []
252 |     conv = copy.deepcopy(llama_v2_image_conversation.copy())
253 |     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
254 |     for source in sources:
255 |         # [INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n
256 |         header = f"[INST] <<SYS>>\n{conv.system}\n<</SYS>>\n\n"
257 |
258 |         if roles[source[0]["from"]] != conv.roles[0]:
259 |             # Skip the first one if it is not from human
260 |             source = source[1:]
261 |         conv.messages = []
262 |         for j, sentence in enumerate(source):
263 |             role = roles[sentence["from"]]
264 |             assert role == conv.roles[j % 2]
265 |             conv.append_message(role, sentence["value"])
266 |         conversations.append(conv.get_prompt())
267 |
268 |     input_ids = tokenizer(
269 |         conversations,
270 |         return_tensors="pt",
271 |         padding="longest",
272 |         max_length=512,
273 |         truncation=True,
274 |     ).input_ids
275 |     targets = copy.deepcopy(input_ids)
276 |
277 |
278 |     sep = "[/INST] "
279 |     for conversation, target in zip(conversations, targets):
280 |         # total_len = int(target.ne(tokenizer.pad_token_id).sum())
281 |         rounds = conversation.split(conv.sep2)
282 |         cur_len = 1
283 |         target[:cur_len] = IGNORE_INDEX
284 |         for i, rou in enumerate(rounds):
285 |             if rou == "":
286 |                 break
287 |
288 |             parts = rou.split(sep)
289 |             if len(parts) != 2:
290 |                 break
291 |             parts[0] += sep
292 |
293 |
294 |             round_len = len(tokenizer(rou).input_ids)
295 |             instruction_len = len(tokenizer(parts[0]).input_ids) - 2  # subtract 2 to account for the special tokens
296 |
297 |             target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
298 |
299 |             cur_len += round_len
300 |         target[cur_len:] = IGNORE_INDEX
301 |
302 |     return dict(input_ids=input_ids, labels=targets)
303 |
304 | def _mask_targets(target, tokenized_lens, speakers):
305 |     # cur_idx = 0
306 |     cur_idx = tokenized_lens[0]
307 |     tokenized_lens = tokenized_lens[1:]
308 |     target[:cur_idx] = IGNORE_INDEX
309 |     for tokenized_len, speaker in zip(tokenized_lens, speakers):
310 | 
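        # Only human turns are masked here; assistant turns keep their labels so the loss
        # is computed on them. The +2 offset below appears to leave the leading turn-separator
        # tokens unmasked, a convention inherited from the upstream LLaVA-style preprocessing.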
if speaker == "human": 311 | target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX 312 | cur_idx += tokenized_len 313 | --------------------------------------------------------------------------------