├── utils
│   ├── __init__.py
│   ├── timestamp_extraction.py
│   ├── shift_video.py
│   ├── prompts.py
│   └── cons_utils.py
├── timechat
│   ├── common
│   │   ├── __init__.py
│   │   ├── gradcam.py
│   │   ├── optims.py
│   │   ├── dist_utils.py
│   │   ├── logger.py
│   │   └── registry.py
│   ├── runners
│   │   ├── test.py
│   │   └── __init__.py
│   ├── conversation
│   │   └── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── laion_dataset.py
│   │   │   ├── cc_sbu_dataset.py
│   │   │   ├── base_dataset.py
│   │   │   ├── caption_datasets.py
│   │   │   ├── webvid_datasets.py
│   │   │   ├── dataloader_utils.py
│   │   │   └── llava_instruct_dataset.py
│   │   ├── builders
│   │   │   ├── video_caption_builder.py
│   │   │   ├── __init__.py
│   │   │   ├── image_text_pair_builder.py
│   │   │   ├── instruct_builder.py
│   │   │   └── base_dataset_builder.py
│   │   └── data_utils.py
│   ├── configs
│   │   ├── datasets
│   │   │   ├── cc_sbu
│   │   │   │   ├── align.yaml
│   │   │   │   └── defaults.yaml
│   │   │   ├── laion
│   │   │   │   └── defaults.yaml
│   │   │   ├── instruct
│   │   │   │   ├── valley72k_instruct.yaml
│   │   │   │   ├── llava_instruct.yaml
│   │   │   │   ├── time_instruct.yaml
│   │   │   │   ├── charades_instruct.yaml
│   │   │   │   ├── qvhighlights_instruct.yaml
│   │   │   │   ├── webvid_instruct.yaml
│   │   │   │   └── youcook2_instruct.yaml
│   │   │   └── webvid
│   │   │       └── defaults.yaml
│   │   ├── default.yaml
│   │   └── models
│   │       ├── minigpt4.yaml
│   │       └── timechat.yaml
│   ├── tasks
│   │   ├── image_text_pretrain.py
│   │   ├── video_text_pretrain.py
│   │   ├── __init__.py
│   │   └── base_task.py
│   ├── processors
│   │   ├── base_processor.py
│   │   ├── __init__.py
│   │   ├── functional_video.py
│   │   ├── blip_processors.py
│   │   ├── transforms_video.py
│   │   └── video_processor.py
│   ├── __init__.py
│   ├── eval_configs
│   │   └── timechat.yaml
│   ├── train_configs
│   │   ├── stage2_finetune_charades.yaml
│   │   └── stage2_finetune_activitynet.yaml
│   ├── models
│   │   ├── blip2_outputs.py
│   │   ├── __init__.py
│   │   ├── blip2.py
│   │   └── base_model.py
│   └── utils.py
├── .gitignore
├── scripts
│   ├── test_predictions.sh
│   └── run_eval.sh
├── eval
│   └── run_eval.sh
├── run.py
├── task
│   └── grounding.py
└── README.md
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/timechat/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/timechat/runners/test.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/timechat/conversation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/timechat/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.pth
3 |
4 |
--------------------------------------------------------------------------------
/timechat/datasets/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/cc_sbu/align.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu_align:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/cc_sbu_align_dataset
6 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/cc_sbu/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/cc_sbu_dataset/{00000..00001}.tar
6 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/laion/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | laion:
3 | data_type: images
4 | build_info:
5 | storage: path/laion/laion_dataset/{00000..00001}.tar
6 |
--------------------------------------------------------------------------------
/timechat/configs/default.yaml:
--------------------------------------------------------------------------------
1 | env:
2 | # For default users
3 | # cache_root: "cache"
4 | # For internal use with persistent storage
5 | cache_root: "/export/home/.cache/minigpt4"
6 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/instruct/valley72k_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | valley72k_instruct:
3 | data_type: video
4 | build_info:
5 | anno_dir: "data/valley/instruct_valley_72k.json"
6 | videos_dir: "data/"
7 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/instruct/llava_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | llava_instruct:
3 | data_type: image
4 | build_info:
5 | anno_dir: /path/llava_instruct_150k.json
6 | videos_dir: /path/train2014/train2014/
7 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/instruct/time_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | time_instruct:
3 | data_type: video
4 | build_info:
5 | anno_dir: "data/instructions/instruct_time-sensitive_73k.json"
6 | videos_dir: "data/"
7 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/instruct/charades_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | charades_instruct:
3 | data_type: video
4 | build_info:
5 | anno_dir: "data/Charades/instruct_tvg_12.4k_charades.json"
6 | videos_dir: "data/"
7 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/webvid/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | webvid:
3 | data_type: video
4 | build_info:
5 |     anno_dir: path/webvid/webvid_train_data/annotations/
6 |     videos_dir: path/webvid/webvid_train_data/videos/
7 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/instruct/qvhighlights_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | qvhighlights_instruct:
3 | data_type: video
4 | build_info:
5 | anno_dir: "data/QVhighlights/instruct_vhd_6.9k_qvhighlights.json"
6 | videos_dir: "data/"
7 |
--------------------------------------------------------------------------------
/timechat/configs/datasets/instruct/webvid_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | webvid_instruct:
3 | data_type: image
4 | build_info:
5 | anno_dir: /path/webvid_align/videochat_instruct_11k.json
6 | videos_dir: /path/webvid_align/videos/
7 |
--------------------------------------------------------------------------------
/scripts/test_predictions.sh:
--------------------------------------------------------------------------------
1 | which_python=$(which python)
2 | echo "which python: ${which_python}"
3 | export PYTHONPATH=${PYTHONPATH}:${which_python}
4 | export PYTHONPATH=${PYTHONPATH}:.
5 | echo "PYTHONPATH: ${PYTHONPATH}"
6 |
7 | python eval/eval.py \
8 | ${@:1}
--------------------------------------------------------------------------------
/timechat/configs/datasets/instruct/youcook2_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | youcook2_instruct:
3 | data_type: video
4 | build_info:
5 | anno_dir: "data/YouCook2-BB/YouCook2_asr_denseCap/instruct_dvc_1.2k_youcook2.json"
6 | videos_dir: "data/YouCook2-BB/"
7 |
--------------------------------------------------------------------------------
/eval/run_eval.sh:
--------------------------------------------------------------------------------
1 | # bash eval/run_eval.sh --test_path {dir} --task consistency
2 |
3 | which_python=$(which python)
4 | echo "which python: ${which_python}"
5 | export PYTHONPATH=${PYTHONPATH}:${which_python}
6 | export PYTHONPATH=${PYTHONPATH}:.
7 | echo "PYTHONPATH: ${PYTHONPATH}"
8 |
9 | python eval/eval.py \
10 | ${@:1}
--------------------------------------------------------------------------------
/timechat/runners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from timechat.runners.runner_base import RunnerBase
9 |
10 | __all__ = ["RunnerBase"]
11 |
--------------------------------------------------------------------------------
/timechat/tasks/image_text_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from timechat.common.registry import registry
9 | from timechat.tasks.base_task import BaseTask
10 |
11 |
12 | @registry.register_task("image_text_pretrain")
13 | class ImageTextPretrainTask(BaseTask):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def evaluation(self, model, data_loader, cuda_enabled=True):
18 | pass
19 |
--------------------------------------------------------------------------------
/timechat/tasks/video_text_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from timechat.common.registry import registry
9 | from timechat.tasks.base_task import BaseTask
10 |
11 |
12 | @registry.register_task("video_text_pretrain")
13 | class VideoTextPretrainTask(BaseTask):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def evaluation(self, model, data_loader, cuda_enabled=True):
18 | pass
19 |
--------------------------------------------------------------------------------
/timechat/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from omegaconf import OmegaConf
9 |
10 |
11 | class BaseProcessor:
12 | def __init__(self):
13 | self.transform = lambda x: x
14 | return
15 |
16 | def __call__(self, item):
17 | return self.transform(item)
18 |
19 | @classmethod
20 | def from_config(cls, cfg=None):
21 | return cls()
22 |
23 | def build(self, **kwargs):
24 | cfg = OmegaConf.create(kwargs)
25 |
26 | return self.from_config(cfg)
27 |
--------------------------------------------------------------------------------
/timechat/configs/models/minigpt4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # Vicuna
16 | llama_model: "ckpt/vicuna-13b/"
17 |
18 | # generation configs
19 | prompt: ""
20 |
21 | preprocess:
22 | vis_processor:
23 | train:
24 | name: "blip2_image_train"
25 | image_size: 224
26 | eval:
27 | name: "blip2_image_eval"
28 | image_size: 224
29 | text_processor:
30 | train:
31 | name: "blip_caption"
32 | eval:
33 | name: "blip_caption"
34 |
--------------------------------------------------------------------------------
/timechat/configs/models/timechat.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: timechat
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # llama-2
16 | llama_model: "ckpt/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
17 |
18 | # generation configs
19 | prompt: ""
20 |
21 | preprocess:
22 | vis_processor:
23 | train:
24 | name: "alpro_video_train"
25 | image_size: 224
26 | n_frms: 8
27 | eval:
28 | name: "alpro_video_eval"
29 | image_size: 224
30 | n_frms: 8
31 | text_processor:
32 | train:
33 | name: "blip_caption"
34 | eval:
35 | name: "blip_caption"
36 |
--------------------------------------------------------------------------------
/timechat/common/gradcam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from scipy.ndimage import filters
4 | from skimage import transform as skimage_transform
5 |
6 |
7 | def getAttMap(img, attMap, blur=True, overlap=True):
8 | attMap -= attMap.min()
9 | if attMap.max() > 0:
10 | attMap /= attMap.max()
11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12 | if blur:
13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14 | attMap -= attMap.min()
15 | attMap /= attMap.max()
16 | cmap = plt.get_cmap("jet")
17 | attMapV = cmap(attMap)
18 | attMapV = np.delete(attMapV, 3, 2)
19 | if overlap:
20 | attMap = (
21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23 | )
24 | return attMap
25 |
--------------------------------------------------------------------------------
/timechat/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from timechat.common.registry import registry
9 | from timechat.tasks.base_task import BaseTask
10 | from timechat.tasks.image_text_pretrain import ImageTextPretrainTask
11 | from timechat.tasks.video_text_pretrain import VideoTextPretrainTask
12 |
13 |
14 | def setup_task(cfg):
15 | assert "task" in cfg.run_cfg, "Task name must be provided."
16 |
17 | task_name = cfg.run_cfg.task
18 | task = registry.get_task_class(task_name).setup_task(cfg=cfg)
19 | assert task is not None, "Task {} not properly registered.".format(task_name)
20 |
21 | return task
22 |
23 |
24 | __all__ = [
25 | "BaseTask",
26 | "ImageTextPretrainTask",
27 | "VideoTextPretrainTask"
28 | ]
29 |
--------------------------------------------------------------------------------
/scripts/run_eval.sh:
--------------------------------------------------------------------------------
1 | which_python=$(which python)
2 | echo "which python: ${which_python}"
3 | export PYTHONPATH=${PYTHONPATH}:${which_python}
4 | export PYTHONPATH=${PYTHONPATH}:.
5 | echo "PYTHONPATH: ${PYTHONPATH}"
6 |
7 | model_type=$1
8 | dset_name=$2
9 | video_root="/data/video_datasets"
10 |
11 | case ${dset_name} in
12 | tvr)
13 | if [[ ${model_type} == "Gemini" ]]; then
14 | extra_args+=(--skip)
15 | elif [[ ${model_type} == "GPT-4o" ]]; then
16 | n_frames=10
17 | extra_args+=(--skip)
18 |
19 | elif [[ ${model_type} == "Video-LLaMA" ]]; then
20 | n_frames=8
21 |
22 | elif [[ ${model_type} == "Video-ChatGPT" ]]; then
23 | n_frames=100
24 | elif [[ ${model_type} == "TimeChat" ]]; then
25 | n_frames=96
26 | elif [[ ${model_type} == "VTimeLLM" ]]; then
27 | n_frames=100
28 | fi
29 | esac
30 |
31 | python run.py \
32 | --model_type=${model_type} \
33 | --dset_name=${dset_name} \
34 | --video_root=${video_root} \
35 | --n_frames=${n_frames} \
36 | ${extra_args[@]} \
37 | ${@:3}
--------------------------------------------------------------------------------
/timechat/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | import sys
10 |
11 | from omegaconf import OmegaConf
12 |
13 | from timechat.common.registry import registry
14 |
15 | from timechat.datasets.builders import *
16 | from timechat.models import *
17 | from timechat.processors import *
18 | from timechat.tasks import *
19 |
20 |
21 | root_dir = os.path.dirname(os.path.abspath(__file__))
22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23 |
24 | registry.register_path("library_root", root_dir)
25 | repo_root = os.path.join(root_dir, "..")
26 | registry.register_path("repo_root", repo_root)
27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28 | registry.register_path("cache_root", cache_root)
29 |
30 | registry.register("MAX_INT", sys.maxsize)
31 | registry.register("SPLIT_NAMES", ["train", "val", "test"])
32 |
--------------------------------------------------------------------------------
/timechat/datasets/builders/video_caption_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 |
5 | from timechat.common.registry import registry
6 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7 | from timechat.datasets.datasets.webvid_datasets import WebvidDataset
8 |
9 | @registry.register_builder("webvid")
10 | class WebvidBuilder(BaseDatasetBuilder):
11 | train_dataset_cls = WebvidDataset
12 | DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"}
13 |
14 | def _download_ann(self):
15 | pass
16 |
17 | def _download_vis(self):
18 | pass
19 |
20 | def build(self):
21 | self.build_processors()
22 | datasets = dict()
23 | split = "train"
24 |
25 | build_info = self.config.build_info
26 | dataset_cls = self.train_dataset_cls
27 | datasets[split] = dataset_cls(
28 | vis_processor=self.vis_processors[split],
29 | text_processor=self.text_processors[split],
30 | vis_root=build_info.videos_dir,
31 | ann_root=build_info.anno_dir
32 | )
33 |
34 | return datasets
--------------------------------------------------------------------------------
/timechat/processors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from timechat.processors.base_processor import BaseProcessor
9 | from timechat.processors.blip_processors import (
10 | Blip2ImageTrainProcessor,
11 | Blip2ImageEvalProcessor,
12 | BlipCaptionProcessor,
13 | )
14 | from timechat.processors.video_processor import (
15 | AlproVideoTrainProcessor,
16 | AlproVideoEvalProcessor
17 | )
18 | from timechat.common.registry import registry
19 |
20 | __all__ = [
21 | "BaseProcessor",
22 | "Blip2ImageTrainProcessor",
23 | "Blip2ImageEvalProcessor",
24 | "BlipCaptionProcessor",
25 | "AlproVideoTrainProcessor",
26 | "AlproVideoEvalProcessor",
27 | ]
28 |
29 |
30 | def load_processor(name, cfg=None):
31 | """
32 | Example
33 |
34 | >>> processor = load_processor("alpro_video_train", cfg=None)
35 | """
36 | processor = registry.get_processor_class(name).from_config(cfg)
37 |
38 | return processor
39 |
--------------------------------------------------------------------------------
/timechat/datasets/datasets/laion_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import webdataset as wds
9 | from timechat.datasets.datasets.base_dataset import BaseDataset
10 |
11 |
12 | class LaionDataset(BaseDataset):
13 | def __init__(self, vis_processor, text_processor, location):
14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15 |
16 | self.inner_dataset = wds.DataPipeline(
17 | wds.ResampledShards(location),
18 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
19 | wds.shuffle(1000, handler=wds.warn_and_continue),
20 | wds.decode("pilrgb", handler=wds.warn_and_continue),
21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23 | wds.map(self.to_dict, handler=wds.warn_and_continue),
24 | )
25 |
26 | def to_dict(self, sample):
27 | return {
28 | "image": sample[0],
29 | "text_input": self.text_processor(sample[1]["caption"]),
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | """
2 | Troubleshooting:
3 | $ pip install pydantic==1.10.8
4 |
5 | """
6 |
7 | import torch
8 | import random
9 | import numpy as np
10 | from utils.cons_utils import BaseOptions
11 | from task.grounding import run_grounding
12 | from task.consistency import run_consistency
13 |
14 | eval_func = {
15 | "grounding": run_grounding,
16 | "consistency": run_consistency,
17 | }
18 |
19 | OPT_FILE_NAME = "opt.json"
20 | PREDICTION_FILE_NAME = "predictions.jsonl"
21 | EVALUATION_FILE_NAME = "eval_results.json"
22 |
23 |
24 | def set_seed(seed, use_cuda=True):
25 | random.seed(seed)
26 | np.random.seed(seed)
27 | torch.manual_seed(seed)
28 | if use_cuda:
29 | torch.cuda.manual_seed_all(seed)
30 |
31 | print("Set seed ", seed)
32 |
33 |
34 | if __name__ == "__main__":
35 | """
36 | Usage: python run.py --model_type TimeChat --dset_name charades --task grounding --debug
37 | """
38 | base_options = BaseOptions().parse()
39 | set_seed(base_options.seed)
40 |
41 | if base_options.model_type == "TimeChat":
42 | from timechat.utils import TimeChat, TimeChat_Options
43 | args = TimeChat_Options().parse()
44 | model = TimeChat(args)
45 | args.ckpt = model.model_config["ckpt"]
46 |
47 | # TODO: Add Custom Models
48 |
49 | else:
50 | raise NotImplementedError
51 |
52 | eval_func[args.task](model=model, args=args)
53 |
54 |
--------------------------------------------------------------------------------
/timechat/eval_configs/timechat.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: timechat
3 | model_type: pretrain_llama_v2
4 | freeze_vit: True
5 | freeze_qformer: True
6 | max_txt_len: 2048
7 |   end_sym: "</s>"
8 | low_resource: False
9 |
10 | frozen_llama_proj: True
11 | frozen_video_Qformer: True
12 |
13 | vit_model: "ckpt/timechat/eva_vit_g.pth"
14 | llama_model: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
15 | q_former_model: "ckpt/timechat/instruct_blip_vicuna7b_trimmed.pth"
16 | ckpt: "ckpt/timechat/timechat_7b_paper.pth"
17 | charades_ckpt: "ckpt/timechat/TimeChat-7B-Charades-VTune.pth"
18 | activitynet_ckpt: "ckpt/timechat/TimeChat-7B-ActivityNet-VTune.pth"
19 | fusion_head_layers: 2
20 | max_frame_pos: 96
21 | fusion_header_type: "seqTransf"
22 |
23 | use_grad_checkpoint: True
24 | lora: True
25 | lora_inference_mode: True
26 | qformer_text_input: True
27 | window_size: 32
28 | stride: 32
29 |
30 | datasets:
31 | webvid:
32 | vis_processor:
33 | train:
34 | name: "alpro_video_eval"
35 | n_frms: 96
36 | image_size: 224
37 | text_processor:
38 | train:
39 | name: "blip_caption"
40 | num_video_query_token: 32
41 | tokenizer_name: "ckpt/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
42 | model_type: "llama_v2"
43 | num_frm: 96
44 | sample_type: 'uniform'
45 | max_txt_len: 2048
46 | stride: 32
47 |
48 | run:
49 | task: video_text_pretrain
50 |
--------------------------------------------------------------------------------
/timechat/datasets/datasets/cc_sbu_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import webdataset as wds
4 | from timechat.datasets.datasets.base_dataset import BaseDataset
5 | from timechat.datasets.datasets.caption_datasets import CaptionDataset
6 |
7 |
8 | class CCSBUDataset(BaseDataset):
9 | def __init__(self, vis_processor, text_processor, location):
10 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
11 |
12 | self.inner_dataset = wds.DataPipeline(
13 | wds.ResampledShards(location),
14 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
15 | wds.shuffle(1000, handler=wds.warn_and_continue),
16 | wds.decode("pilrgb", handler=wds.warn_and_continue),
17 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
18 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
19 | wds.map(self.to_dict, handler=wds.warn_and_continue),
20 | )
21 |
22 | def to_dict(self, sample):
23 | return {
24 | "image": sample[0],
25 | "text_input": self.text_processor(sample[1]["caption"]),
26 | "type":'image',
27 | }
28 |
29 |
30 | class CCSBUAlignDataset(CaptionDataset):
31 |
32 | def __getitem__(self, index):
33 |
34 | # TODO this assumes image input, not general enough
35 | ann = self.annotation[index]
36 |
37 | img_file = '{}.jpg'.format(ann["image_id"])
38 | image_path = os.path.join(self.vis_root, img_file)
39 | image = Image.open(image_path).convert("RGB")
40 |
41 | image = self.vis_processor(image)
42 | caption = ann["caption"]
43 |
44 | return {
45 | "image": image,
46 | "text_input": caption,
47 | "image_id": self.img_ids[ann["image_id"]],
48 | "type":'image',
49 | }
--------------------------------------------------------------------------------
/timechat/train_configs/stage2_finetune_charades.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: timechat
3 | model_type: pretrain_llama_v2
4 | freeze_vit: True
5 | freeze_qformer: False
6 |
7 | # Q-Former
8 | num_query_token: 32
9 |
10 | vit_model: "ckpt/timechat/eva_vit_g.pth"
11 | llama_model: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
12 | q_former_model: "ckpt/timechat/instruct_blip_vicuna7b_trimmed.pth"
13 | ckpt: "ckpt/timechat/timechat_7b.pth" # continue fine-tuning from TimeChat-7B ckpt
14 |
15 | # only train vision branch
16 | frozen_llama_proj: False
17 | frozen_video_Qformer: False
18 |
19 | fusion_head_layers: 2
20 | max_frame_pos: 96
21 | fusion_header_type: "seqTransf"
22 |
23 | max_txt_len: 2048
24 |
25 | # for llama_2_chat:
26 |   end_sym: "</s>"
27 |   prompt_path: ""
28 |   prompt_template: '[INST] <<SYS>>\n \n<</SYS>>\n\n{} [/INST] '
29 |
30 | use_grad_checkpoint: True
31 | lora: True
32 | lora_inference_mode: False
33 | qformer_text_input: True
34 | window_size: 32
35 | stride: 32
36 |
37 | datasets:
38 | charades_instruct:
39 | data_type: video
40 | build_info:
41 | anno_dir: "data/charades_for_VTune.json"
42 | videos_dir: "data/"
43 | vis_processor:
44 | train:
45 | name: "alpro_video_train"
46 | n_frms: 96
47 | image_size: 224
48 | text_processor:
49 | train:
50 | name: "blip_caption"
51 | num_video_query_token: 32
52 | tokenizer_name: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
53 | model_type: "llama_v2"
54 | num_frm: 96
55 | sample_type: 'rand'
56 | max_txt_len: 2048
57 | stride: 32
58 |
59 | run:
60 | task: video_text_pretrain
61 | # optimizer
62 | lr_sched: "linear_warmup_cosine_lr"
63 | init_lr: 3e-5
64 | min_lr: 1e-5
65 | warmup_lr: 1e-6
66 |
67 | weight_decay: 0.05
68 | max_epoch: 3
69 | iters_per_epoch: 24811
70 | batch_size_train: 1
71 | batch_size_eval: 4
72 | num_workers: 4
73 | warmup_steps: 14916
74 | accum_grad_iters: 8
75 |
76 | seed: 42
77 | output_dir: "ckpt/timechat/train_stage2_charades"
78 |
79 | amp: True
80 | resume_ckpt_path: null
81 |
82 | evaluate: False
83 | train_splits: ["train"]
84 |
85 | device: "cuda"
86 | world_size: 1
87 | dist_url: "env://"
88 | distributed: True
--------------------------------------------------------------------------------
/timechat/datasets/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import json
9 | from typing import Iterable
10 |
11 | from torch.utils.data import Dataset, ConcatDataset
12 | from torch.utils.data.dataloader import default_collate
13 |
14 |
15 | class BaseDataset(Dataset):
16 | def __init__(
17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
18 | ):
19 | """
20 | vis_root (string): Root directory of images (e.g. coco/images/)
21 | ann_root (string): directory to store the annotation file
22 | """
23 | self.vis_root = vis_root
24 |
25 | self.annotation = []
26 | for ann_path in ann_paths:
27 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28 |
29 | self.vis_processor = vis_processor
30 | self.text_processor = text_processor
31 |
32 | self._add_instance_ids()
33 |
34 | def __len__(self):
35 | return len(self.annotation)
36 |
37 | def collater(self, samples):
38 | return default_collate(samples)
39 |
40 | def set_processors(self, vis_processor, text_processor):
41 | self.vis_processor = vis_processor
42 | self.text_processor = text_processor
43 |
44 | def _add_instance_ids(self, key="instance_id"):
45 | for idx, ann in enumerate(self.annotation):
46 | ann[key] = str(idx)
47 |
48 |
49 | class ConcatDataset(ConcatDataset):
50 | def __init__(self, datasets: Iterable[Dataset]) -> None:
51 | super().__init__(datasets)
52 |
53 | def collater(self, samples):
54 | # TODO For now only supports datasets with same underlying collater implementations
55 |
56 | all_keys = set()
57 | for s in samples:
58 | all_keys.update(s)
59 |
60 | shared_keys = all_keys
61 | for s in samples:
62 | shared_keys = shared_keys & set(s.keys())
63 |
64 | samples_shared_keys = []
65 | for s in samples:
66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67 |
68 | return self.datasets[0].collater(samples_shared_keys)
69 |
--------------------------------------------------------------------------------
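A minimal sketch (illustrative, not part of the repository) of what ConcatDataset.collater above does: fields missing from any sample are dropped, and the surviving shared keys are batched by the first dataset's collater (default_collate stands in for it here).

    >>> import torch
    >>> from torch.utils.data.dataloader import default_collate
    >>> samples = [
    ...     {"image": torch.zeros(2), "text_input": "a cat", "type": "image"},
    ...     {"image": torch.zeros(2), "text_input": "a dog"},  # no "type" key
    ... ]
    >>> shared = set(samples[0]) & set(samples[1])
    >>> sorted(shared)
    ['image', 'text_input']
    >>> batch = default_collate([{k: s[k] for k in shared} for s in samples])
    >>> sorted(batch)
    ['image', 'text_input']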
/timechat/train_configs/stage2_finetune_activitynet.yaml:
--------------------------------------------------------------------------------
1 | # Usage: torchrun --nproc_per_node=4 train.py --cfg-path train_configs/stage2_finetune_activitynet.yaml
2 | model:
3 | arch: timechat
4 | model_type: pretrain_llama_v2
5 | freeze_vit: True
6 | freeze_qformer: False
7 |
8 |
9 | # Q-Former
10 | num_query_token: 32
11 |
12 | vit_model: "ckpt/timechat/eva_vit_g.pth"
13 | llama_model: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
14 | q_former_model: "ckpt/timechat/instruct_blip_vicuna7b_trimmed.pth"
15 | ckpt: "ckpt/timechat_7b.pth" # continue fine-tuning from TimeChat-7B ckpt
16 |
17 | # only train vision branch
18 | frozen_llama_proj: False
19 | frozen_video_Qformer: False
20 |
21 | fusion_head_layers: 2
22 | max_frame_pos: 96
23 | fusion_header_type: "seqTransf"
24 |
25 | max_txt_len: 2048
26 |
27 | # for llama_2_chat:
28 |   end_sym: "</s>"
29 |   prompt_path: ""
30 |   prompt_template: '[INST] <<SYS>>\n \n<</SYS>>\n\n{} [/INST] '
31 |
32 | use_grad_checkpoint: True
33 | lora: True
34 | lora_inference_mode: False
35 | qformer_text_input: True
36 | window_size: 32
37 | stride: 32
38 |
39 | datasets:
40 | charades_instruct:
41 | data_type: video
42 | build_info:
43 | anno_dir: "data/activitynet_for_VTune.json"
44 | videos_dir: "data/"
45 | vis_processor:
46 | train:
47 | name: "alpro_video_train"
48 | n_frms: 96
49 | image_size: 224
50 | text_processor:
51 | train:
52 | name: "blip_caption"
53 | num_video_query_token: 32
54 | tokenizer_name: "ckpt/Video-LLaMA/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
55 | model_type: "llama_v2"
56 | num_frm: 96
57 | sample_type: 'rand'
58 | max_txt_len: 2048
59 | stride: 32
60 |
61 | run:
62 | task: video_text_pretrain
63 | # optimizer
64 | lr_sched: "linear_warmup_cosine_lr"
65 | init_lr: 3e-5
66 | min_lr: 1e-5
67 | warmup_lr: 1e-6
68 |
69 | weight_decay: 0.05
70 | max_epoch: 3
71 | iters_per_epoch: 51377
72 | batch_size_train: 1
73 | batch_size_eval: 4
74 | num_workers: 4
75 | warmup_steps: 25688
76 | accum_grad_iters: 4
77 |
78 | seed: 42
79 | output_dir: "ckpt/timechat/activitynet_vtune"
80 |
81 | amp: True
82 | resume_ckpt_path: null
83 |
84 | evaluate: False
85 | train_splits: ["train"]
86 |
87 | device: "cuda"
88 | world_size: 1
89 | dist_url: "env://"
90 | distributed: True
91 |
--------------------------------------------------------------------------------
/timechat/datasets/builders/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from timechat.datasets.builders.base_dataset_builder import load_dataset_config
9 | from timechat.datasets.builders.image_text_pair_builder import (
10 | CCSBUBuilder,
11 | LaionBuilder,
12 | CCSBUAlignBuilder
13 | )
14 | from timechat.datasets.builders.video_caption_builder import WebvidBuilder
15 | from timechat.common.registry import registry
16 | from timechat.datasets.builders.instruct_builder import WebvidInstruct_Builder, LlavaInstruct_Builder, \
17 | Youcook2Instruct_Builder, TimeInstruct_Builder, Valley72kInstruct_Builder, QVhighlightsInstruct_Builder, \
18 | CharadesInstruct_Builder
19 | __all__ = [
20 | "CCSBUBuilder",
21 | "LaionBuilder",
22 | "CCSBUAlignBuilder",
23 | "WebvidBuilder",
24 | "LlavaInstruct_Builder",
25 | "WebvidInstruct_Builder",
26 | "Youcook2Instruct_Builder",
27 | "TimeInstruct_Builder",
28 | "Valley72kInstruct_Builder",
29 | "QVhighlightsInstruct_Builder",
30 | "CharadesInstruct_Builder",
31 | ]
32 |
33 |
34 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
35 | """
36 | Example
37 |
38 | >>> dataset = load_dataset("coco_caption", cfg=None)
39 | >>> splits = dataset.keys()
40 | >>> print([len(dataset[split]) for split in splits])
41 |
42 | """
43 | if cfg_path is None:
44 | cfg = None
45 | else:
46 | cfg = load_dataset_config(cfg_path)
47 |
48 | try:
49 | builder = registry.get_builder_class(name)(cfg)
50 | except TypeError:
51 | print(
52 | f"Dataset {name} not found. Available datasets:\n"
53 | + ", ".join([str(k) for k in dataset_zoo.get_names()])
54 | )
55 | exit(1)
56 |
57 | if vis_path is not None:
58 | if data_type is None:
59 | # use default data type in the config
60 | data_type = builder.config.data_type
61 |
62 | assert (
63 | data_type in builder.config.build_info
64 | ), f"Invalid data_type {data_type} for {name}."
65 |
66 | builder.config.build_info.get(data_type).storage = vis_path
67 |
68 | dataset = builder.build_datasets()
69 | return dataset
70 |
71 |
72 | class DatasetZoo:
73 | def __init__(self) -> None:
74 | self.dataset_zoo = {
75 | k: list(v.DATASET_CONFIG_DICT.keys())
76 | for k, v in sorted(registry.mapping["builder_name_mapping"].items())
77 | }
78 |
79 | def get_names(self):
80 | return list(self.dataset_zoo.keys())
81 |
82 |
83 | dataset_zoo = DatasetZoo()
84 |
--------------------------------------------------------------------------------
/timechat/datasets/datasets/caption_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | from collections import OrderedDict
10 |
11 | from timechat.datasets.datasets.base_dataset import BaseDataset
12 | from PIL import Image
13 |
14 |
15 | class __DisplMixin:
16 | def displ_item(self, index):
17 | sample, ann = self.__getitem__(index), self.annotation[index]
18 |
19 | return OrderedDict(
20 | {
21 | "file": ann["image"],
22 | "caption": ann["caption"],
23 | "image": sample["image"],
24 | }
25 | )
26 |
27 |
28 | class CaptionDataset(BaseDataset, __DisplMixin):
29 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
30 | """
31 | vis_root (string): Root directory of images (e.g. coco/images/)
32 | ann_root (string): directory to store the annotation file
33 | """
34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
35 |
36 | self.img_ids = {}
37 | n = 0
38 | for ann in self.annotation:
39 | img_id = ann["image_id"]
40 | if img_id not in self.img_ids.keys():
41 | self.img_ids[img_id] = n
42 | n += 1
43 |
44 | def __getitem__(self, index):
45 |
46 | # TODO this assumes image input, not general enough
47 | ann = self.annotation[index]
48 |
49 | img_file = '{:0>12}.jpg'.format(ann["image_id"])
50 | image_path = os.path.join(self.vis_root, img_file)
51 | image = Image.open(image_path).convert("RGB")
52 |
53 | image = self.vis_processor(image)
54 | caption = self.text_processor(ann["caption"])
55 |
56 | return {
57 | "image": image,
58 | "text_input": caption,
59 | "image_id": self.img_ids[ann["image_id"]],
60 | }
61 |
62 |
63 | class CaptionEvalDataset(BaseDataset, __DisplMixin):
64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
65 | """
66 | vis_root (string): Root directory of images (e.g. coco/images/)
67 | ann_root (string): directory to store the annotation file
68 | split (string): val or test
69 | """
70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
71 |
72 | def __getitem__(self, index):
73 |
74 | ann = self.annotation[index]
75 |
76 | image_path = os.path.join(self.vis_root, ann["image"])
77 | image = Image.open(image_path).convert("RGB")
78 |
79 | image = self.vis_processor(image)
80 |
81 | return {
82 | "image": image,
83 | "image_id": ann["image_id"],
84 | "instance_id": ann["instance_id"],
85 | }
86 |
--------------------------------------------------------------------------------
/utils/timestamp_extraction.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def extract_time(paragraph):
5 | prompt = 'A specific example is : 20.8 - 30.0 seconds'.lower()
6 | paragraph = paragraph.lower()
7 |     paragraph = paragraph.replace(prompt, '')  # str.replace returns a new string; assign it back
8 | # Split text into sentences based on common delimiters
9 | sentences = re.split(r'[!?\n]', paragraph)
10 |
11 | # Keywords that might indicate the presence of time information
12 | keywords = ["starts", "ends", "happens in", "start time", "end time", "start", "end", "happen"]
13 | # filter sentences by keywords
14 | candidates = []
15 | for sentence in sentences:
16 | # If sentence contains one of the keywords
17 | if any(keyword in sentence for keyword in keywords):
18 | candidates.append(sentence)
19 |
20 | timestamps = []
21 | # Check for The given query happens in m - n (seconds)
22 | patterns = [
23 | r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)"
24 | ]
25 |
26 | for time_pattern in patterns:
27 | time_matches = re.findall(time_pattern, paragraph)
28 | if time_matches:
29 | timestamps = [[float(start), float(end)] for start, end in time_matches]
30 |
31 | if len(sentences) == 0:
32 | return []
33 | # check for other formats e.g.:
34 | # 1 .Starting time: 0.8 seconds
35 | # Ending time: 1.1 seconds
36 | # 2. The start time for this event is 0 seconds, and the end time is 12 seconds.
37 | if len(timestamps) == 0:
38 | times = []
39 | time_regex = re.compile(r'\b(\d+\.\d+\b|\b\d+)\b') # time formats (e.g., 18, 18.5)
40 | for sentence in candidates:
41 | time = re.findall(time_regex, sentence)
42 | if time:
43 | time_in_sec = float(time[0])
44 | times.append(time_in_sec)
45 | times = times[:len(times) // 2 * 2]
46 | timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)]
47 | # Check for examples like:
48 | # 3. The event 'person flipped the light switch near the door' starts at 00:00:18 and ends at 00:00:23.
49 | if len(timestamps) == 0:
50 | times = []
51 |         time_regex = re.compile(r'\b(\d{1,2}:\d{2}(?::\d{2})?)\b')  # time formats (e.g., 18:00, 00:18:05); single group so findall returns strings
52 | for sentence in candidates:
53 | time = re.findall(time_regex, sentence)
54 | if time:
55 | t = time[0]
56 | else:
57 | continue
58 | # If time is in HH:MM:SS format, convert to seconds
59 | if t.count(':') == 2:
60 | h, m, s = map(int, t.split(':'))
61 | time_in_sec = h * 3600 + m * 60 + s
62 | elif t.count(':') == 1:
63 | m, s = map(int, t.split(':'))
64 | time_in_sec = m * 60 + s
65 | times.append(time_in_sec)
66 | times = times[:len(times) // 2 * 2]
67 | timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)]
68 |
69 | results = []
70 | for (start, end) in timestamps:
71 | if end > start:
72 | results.append([start, end])
73 | else:
74 | results.append([end, start])
75 |
76 | if len(results) == 0:
77 | return [0, 0]
78 | else:
79 | return results[0]
--------------------------------------------------------------------------------
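A minimal usage sketch for the extract_time() helper above (illustrative only, not part of the repository mirror; assumes the repo root is on PYTHONPATH): it returns the first [start, end] span it can recover from a model answer, in seconds, falling back to [0, 0].

    >>> from utils.timestamp_extraction import extract_time
    >>> extract_time("The given query happens in 20.8 - 30.0 seconds.")
    [20.8, 30.0]
    >>> extract_time("No temporal information in this answer.")
    [0, 0]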
/timechat/datasets/builders/image_text_pair_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 |
5 | from timechat.common.registry import registry
6 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7 | from timechat.datasets.datasets.laion_dataset import LaionDataset
8 | from timechat.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
9 |
10 |
11 | @registry.register_builder("cc_sbu")
12 | class CCSBUBuilder(BaseDatasetBuilder):
13 | train_dataset_cls = CCSBUDataset
14 |
15 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
16 |
17 | def _download_ann(self):
18 | pass
19 |
20 | def _download_vis(self):
21 | pass
22 |
23 | def build(self):
24 | self.build_processors()
25 |
26 | build_info = self.config.build_info
27 |
28 | datasets = dict()
29 | split = "train"
30 |
31 | # create datasets
32 | # [NOTE] return inner_datasets (wds.DataPipeline)
33 | dataset_cls = self.train_dataset_cls
34 | datasets[split] = dataset_cls(
35 | vis_processor=self.vis_processors[split],
36 | text_processor=self.text_processors[split],
37 | location=build_info.storage,
38 | ).inner_dataset
39 |
40 | return datasets
41 |
42 |
43 | @registry.register_builder("laion")
44 | class LaionBuilder(BaseDatasetBuilder):
45 | train_dataset_cls = LaionDataset
46 |
47 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
48 |
49 | def _download_ann(self):
50 | pass
51 |
52 | def _download_vis(self):
53 | pass
54 |
55 | def build(self):
56 | self.build_processors()
57 |
58 | build_info = self.config.build_info
59 |
60 | datasets = dict()
61 | split = "train"
62 |
63 | # create datasets
64 | # [NOTE] return inner_datasets (wds.DataPipeline)
65 | dataset_cls = self.train_dataset_cls
66 | datasets[split] = dataset_cls(
67 | vis_processor=self.vis_processors[split],
68 | text_processor=self.text_processors[split],
69 | location=build_info.storage,
70 | ).inner_dataset
71 |
72 | return datasets
73 |
74 |
75 | @registry.register_builder("cc_sbu_align")
76 | class CCSBUAlignBuilder(BaseDatasetBuilder):
77 | train_dataset_cls = CCSBUAlignDataset
78 |
79 | DATASET_CONFIG_DICT = {
80 | "default": "configs/datasets/cc_sbu/align.yaml",
81 | }
82 |
83 | def build_datasets(self):
84 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
85 | logging.info("Building datasets...")
86 | self.build_processors()
87 |
88 | build_info = self.config.build_info
89 | storage_path = build_info.storage
90 |
91 | datasets = dict()
92 |
93 | if not os.path.exists(storage_path):
94 | warnings.warn("storage path {} does not exist.".format(storage_path))
95 |
96 | # create datasets
97 | dataset_cls = self.train_dataset_cls
98 | datasets['train'] = dataset_cls(
99 | vis_processor=self.vis_processors["train"],
100 | text_processor=self.text_processors["train"],
101 | ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
102 | vis_root=os.path.join(storage_path, 'image'),
103 | )
104 |
105 | return datasets
106 |
107 |
--------------------------------------------------------------------------------
/timechat/common/optims.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import math
9 |
10 | from timechat.common.registry import registry
11 |
12 |
13 | @registry.register_lr_scheduler("linear_warmup_step_lr")
14 | class LinearWarmupStepLRScheduler:
15 | def __init__(
16 | self,
17 | optimizer,
18 | max_epoch,
19 | min_lr,
20 | init_lr,
21 | decay_rate=1,
22 | warmup_start_lr=-1,
23 | warmup_steps=0,
24 | **kwargs
25 | ):
26 | self.optimizer = optimizer
27 |
28 | self.max_epoch = max_epoch
29 | self.min_lr = min_lr
30 |
31 | self.decay_rate = decay_rate
32 |
33 | self.init_lr = init_lr
34 | self.warmup_steps = warmup_steps
35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36 |
37 | def step(self, cur_epoch, cur_step):
38 | if cur_epoch == 0:
39 | warmup_lr_schedule(
40 | step=cur_step,
41 | optimizer=self.optimizer,
42 | max_step=self.warmup_steps,
43 | init_lr=self.warmup_start_lr,
44 | max_lr=self.init_lr,
45 | )
46 | else:
47 | step_lr_schedule(
48 | epoch=cur_epoch,
49 | optimizer=self.optimizer,
50 | init_lr=self.init_lr,
51 | min_lr=self.min_lr,
52 | decay_rate=self.decay_rate,
53 | )
54 |
55 |
56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57 | class LinearWarmupCosineLRScheduler:
58 | def __init__(
59 | self,
60 | optimizer,
61 | max_epoch,
62 | iters_per_epoch,
63 | min_lr,
64 | init_lr,
65 | warmup_steps=0,
66 | warmup_start_lr=-1,
67 | **kwargs
68 | ):
69 | self.optimizer = optimizer
70 |
71 | self.max_epoch = max_epoch
72 | self.iters_per_epoch = iters_per_epoch
73 | self.min_lr = min_lr
74 |
75 | self.init_lr = init_lr
76 | self.warmup_steps = warmup_steps
77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78 |
79 | def step(self, cur_epoch, cur_step):
80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81 | if total_cur_step < self.warmup_steps:
82 | warmup_lr_schedule(
83 | step=cur_step,
84 | optimizer=self.optimizer,
85 | max_step=self.warmup_steps,
86 | init_lr=self.warmup_start_lr,
87 | max_lr=self.init_lr,
88 | )
89 | else:
90 | cosine_lr_schedule(
91 | epoch=total_cur_step,
92 | optimizer=self.optimizer,
93 | max_epoch=self.max_epoch * self.iters_per_epoch,
94 | init_lr=self.init_lr,
95 | min_lr=self.min_lr,
96 | )
97 |
98 |
99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100 | """Decay the learning rate"""
101 | lr = (init_lr - min_lr) * 0.5 * (
102 | 1.0 + math.cos(math.pi * epoch / max_epoch)
103 | ) + min_lr
104 | for param_group in optimizer.param_groups:
105 | param_group["lr"] = lr
106 |
107 |
108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109 | """Warmup the learning rate"""
110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111 | for param_group in optimizer.param_groups:
112 | param_group["lr"] = lr
113 |
114 |
115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116 | """Decay the learning rate"""
117 | lr = max(min_lr, init_lr * (decay_rate**epoch))
118 | for param_group in optimizer.param_groups:
119 | param_group["lr"] = lr
120 |
--------------------------------------------------------------------------------
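For reference, a small numeric sketch (illustrative only) of the cosine decay implemented by cosine_lr_schedule above, plugging in the init_lr=3e-5 and min_lr=1e-5 used by the train configs and an arbitrary max_step of 100 for readability: the learning rate starts at init_lr, passes through the midpoint halfway, and ends at min_lr.

    >>> import math
    >>> init_lr, min_lr, max_step = 3e-5, 1e-5, 100
    >>> lr = lambda step: (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * step / max_step)) + min_lr
    >>> round(lr(0), 6), round(lr(50), 6), round(lr(100), 6)
    (3e-05, 2e-05, 1e-05)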
/task/grounding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from easydict import EasyDict as edict
4 | import time
5 | import os
6 | import random
7 | import numpy as np
8 | from utils.cons_utils import (save_jsonl, save_json, load_json, get_iou,
9 | BaseOptions, load_logger, display)
10 | from eval.eval import evaluate_grounding
11 | logger = load_logger("[Grounding Evaluation]")
12 |
13 | OPT_FILE_NAME = "opt.json"
14 | PREDICTION_FILE_NAME = "predictions.jsonl"
15 | EVALUATION_FILE_NAME = "eval_results.json"
16 |
17 |
18 | def set_seed(seed, use_cuda=True):
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | torch.manual_seed(seed)
22 | if use_cuda:
23 | torch.cuda.manual_seed_all(seed)
24 |
25 | print("Set seed ", seed)
26 |
27 |
28 | def main(args, model):
29 | results = []
30 |
31 | test_data = load_json(args.test_path)
32 | target_vid_list = [file.split(".")[0] for file in os.listdir(args.video_root)]
33 | target_vid_list = [vid for vid in target_vid_list if vid in list(test_data.keys())]
34 | print(f"Total {len(target_vid_list)} videos in {args.video_root}")
35 |
36 | path_to_predictions = f"{args.task}_{PREDICTION_FILE_NAME}"
37 | path_to_eval_results = f"{args.task}_{EVALUATION_FILE_NAME}"
38 |
39 | for n_data, (vid, data) in tqdm(enumerate(test_data.items()), total=len(target_vid_list), desc="Evaluating.."):
40 | duration = data['duration']
41 | video_path = os.path.join(args.video_root, f"{vid}.mp4")
42 |
43 | # Load video frame features
44 | if os.path.exists(video_path):
45 | video_features, msg = model.load_video_features(video_path)
46 | else:
47 | print(f"Video {vid} not found")
48 | continue
49 |
50 | for i, (query, gt_moment) in enumerate(zip(data['sentences'], data['timestamps'])):
51 | gt_moment = [min(gt_moment[0], duration), min(gt_moment[1], duration)]
52 | pred_moment = model.run(task="grounding", video_features=video_features, query=query,
53 | duration=duration, msg=msg)
54 |
55 | # Save the results.
56 | result = edict(
57 | meta=edict(
58 | vid=vid,
59 | sentence=data['sentences'][i],
60 | timestamp=data['timestamps'][i],
61 | duration=data['duration']
62 | ),
63 | prediction=edict(
64 | qa=pred_moment,
65 | iou=get_iou(gt_moment, pred_moment["t"]),
66 | ),
67 | )
68 | results.append(result)
69 |
70 | if args.debug and n_data == 1:
71 | break
72 |
73 | if n_data % 50 == 0:
74 | logger.info(f"{len(results)} results are saved")
75 | save_jsonl(results, os.path.join(args.output_dir, f"{path_to_predictions}"))
76 |
77 | logger.info(f"{len(results)} predictions will be saved at {path_to_predictions}")
78 | save_jsonl(results, os.path.join(args.output_dir, path_to_predictions))
79 |
80 | logger.info("============ Save Performance ============")
81 | grounding_results = evaluate_grounding(results, verbos=True)
82 | save_json(grounding_results, os.path.join(args.output_dir, path_to_eval_results), save_pretty=True)
83 | logger.info("Done.")
84 |
85 |
86 | def run_grounding(model, args):
87 | logger.info("Measuring Grounding")
88 | cur_time = time.strftime("%Y_%m_%d_%H_%M_%S")
89 | if args.exp_id is None:
90 | args.exp_id = args.model_type
91 |
92 | if args.debug:
93 | output_dir = f"debug_results/{args.model_type}/{args.exp_id}_{args.dset_name}_{cur_time}"
94 | else:
95 | if args.fine_tuned:
96 | output_dir = f"grounding_results/{args.model_type}/fine_tuned_{args.exp_id}_{args.dset_name}_{cur_time}"
97 | else:
98 | output_dir = f"grounding_results/{args.model_type}/{args.exp_id}_{args.dset_name}_{cur_time}"
99 |
100 | os.makedirs(output_dir, exist_ok=True)
101 | args.output_dir = output_dir
102 | args.test_path = f"data/{args.dset_name}_test.json"
103 |
104 | # display and save results
105 | display(args)
106 | save_json(vars(args), os.path.join(args.output_dir, OPT_FILE_NAME), save_pretty=True)
107 |
108 | main(args, model)
109 |
110 |
--------------------------------------------------------------------------------
/timechat/common/dist_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import functools
10 | import os
11 |
12 | import torch
13 | import torch.distributed as dist
14 | import timm.models.hub as timm_hub
15 |
16 |
17 | def setup_for_distributed(is_master):
18 | """
19 | This function disables printing when not in master process
20 | """
21 | import builtins as __builtin__
22 |
23 | builtin_print = __builtin__.print
24 |
25 | def print(*args, **kwargs):
26 | force = kwargs.pop("force", False)
27 | if is_master or force:
28 | builtin_print(*args, **kwargs)
29 |
30 | __builtin__.print = print
31 |
32 |
33 | def is_dist_avail_and_initialized():
34 | if not dist.is_available():
35 | return False
36 | if not dist.is_initialized():
37 | return False
38 | return True
39 |
40 |
41 | def get_world_size():
42 | if not is_dist_avail_and_initialized():
43 | return 1
44 | return dist.get_world_size()
45 |
46 |
47 | def get_rank():
48 | if not is_dist_avail_and_initialized():
49 | return 0
50 | return dist.get_rank()
51 |
52 |
53 | def is_main_process():
54 | return get_rank() == 0
55 |
56 |
57 | def init_distributed_mode(args):
58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59 | args.rank = int(os.environ["RANK"])
60 | args.world_size = int(os.environ["WORLD_SIZE"])
61 | args.gpu = int(os.environ["LOCAL_RANK"])
62 | elif "SLURM_PROCID" in os.environ:
63 | args.rank = int(os.environ["SLURM_PROCID"])
64 | args.gpu = args.rank % torch.cuda.device_count()
65 | else:
66 | print("Not using distributed mode")
67 | args.distributed = False
68 | return
69 |
70 | args.distributed = True
71 |
72 | torch.cuda.set_device(args.gpu)
73 | args.dist_backend = "nccl"
74 | print(
75 | "| distributed init (rank {}, world {}): {}".format(
76 | args.rank, args.world_size, args.dist_url
77 | ),
78 | flush=True,
79 | )
80 | torch.distributed.init_process_group(
81 | backend=args.dist_backend,
82 | init_method=args.dist_url,
83 | world_size=args.world_size,
84 | rank=args.rank,
85 | timeout=datetime.timedelta(
86 | days=365
87 | ), # allow auto-downloading and de-compressing
88 | )
89 | torch.distributed.barrier()
90 | setup_for_distributed(args.rank == 0)
91 |
92 |
93 | def get_dist_info():
94 | if torch.__version__ < "1.0":
95 | initialized = dist._initialized
96 | else:
97 | initialized = dist.is_initialized()
98 | if initialized:
99 | rank = dist.get_rank()
100 | world_size = dist.get_world_size()
101 | else: # non-distributed training
102 | rank = 0
103 | world_size = 1
104 | return rank, world_size
105 |
106 |
107 | def main_process(func):
108 | @functools.wraps(func)
109 | def wrapper(*args, **kwargs):
110 | rank, _ = get_dist_info()
111 | if rank == 0:
112 | return func(*args, **kwargs)
113 |
114 | return wrapper
115 |
116 |
117 | def download_cached_file(url, check_hash=True, progress=False):
118 | """
119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121 | """
122 |
123 | def get_cached_file_path():
124 | # a hack to sync the file path across processes
125 | parts = torch.hub.urlparse(url)
126 | filename = os.path.basename(parts.path)
127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128 |
129 | return cached_file
130 |
131 | if is_main_process():
132 | timm_hub.download_cached_file(url, check_hash, progress)
133 |
134 | if is_dist_avail_and_initialized():
135 | dist.barrier()
136 |
137 | return get_cached_file_path()
138 |
--------------------------------------------------------------------------------
/timechat/processors/functional_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import warnings
9 |
10 | import torch
11 |
12 |
13 | def _is_tensor_video_clip(clip):
14 | if not torch.is_tensor(clip):
15 | raise TypeError("clip should be Tensor. Got %s" % type(clip))
16 |
17 | if not clip.ndimension() == 4:
18 | raise ValueError("clip should be 4D. Got %dD" % clip.dim())
19 |
20 | return True
21 |
22 |
23 | def crop(clip, i, j, h, w):
24 | """
25 | Args:
26 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
27 | """
28 | if len(clip.size()) != 4:
29 | raise ValueError("clip should be a 4D tensor")
30 | return clip[..., i : i + h, j : j + w]
31 |
32 |
33 | def resize(clip, target_size, interpolation_mode):
34 | if len(target_size) != 2:
35 | raise ValueError(
36 | f"target size should be tuple (height, width), instead got {target_size}"
37 | )
38 | return torch.nn.functional.interpolate(
39 | clip, size=target_size, mode=interpolation_mode, align_corners=False
40 | )
41 |
42 |
43 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
44 | """
45 | Do spatial cropping and resizing to the video clip
46 | Args:
47 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
48 | i (int): i in (i,j) i.e coordinates of the upper left corner.
49 | j (int): j in (i,j) i.e coordinates of the upper left corner.
50 | h (int): Height of the cropped region.
51 | w (int): Width of the cropped region.
52 | size (tuple(int, int)): height and width of resized clip
53 | Returns:
54 | clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
55 | """
56 | if not _is_tensor_video_clip(clip):
57 | raise ValueError("clip should be a 4D torch.tensor")
58 | clip = crop(clip, i, j, h, w)
59 | clip = resize(clip, size, interpolation_mode)
60 | return clip
61 |
62 |
63 | def center_crop(clip, crop_size):
64 | if not _is_tensor_video_clip(clip):
65 | raise ValueError("clip should be a 4D torch.tensor")
66 | h, w = clip.size(-2), clip.size(-1)
67 | th, tw = crop_size
68 | if h < th or w < tw:
69 | raise ValueError("height and width must be no smaller than crop_size")
70 |
71 | i = int(round((h - th) / 2.0))
72 | j = int(round((w - tw) / 2.0))
73 | return crop(clip, i, j, th, tw)
74 |
75 |
76 | def to_tensor(clip):
77 | """
78 | Convert tensor data type from uint8 to float, divide value by 255.0 and
79 | permute the dimensions of clip tensor
80 | Args:
81 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
82 | Return:
83 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
84 | """
85 | _is_tensor_video_clip(clip)
86 | if not clip.dtype == torch.uint8:
87 | raise TypeError(
88 | "clip tensor should have data type uint8. Got %s" % str(clip.dtype)
89 | )
90 | return clip.float().permute(3, 0, 1, 2) / 255.0
91 |
92 |
93 | def normalize(clip, mean, std, inplace=False):
94 | """
95 | Args:
96 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
97 | mean (tuple): pixel RGB mean. Size is (3)
98 | std (tuple): pixel standard deviation. Size is (3)
99 | Returns:
100 | normalized clip (torch.tensor): Size is (C, T, H, W)
101 | """
102 | if not _is_tensor_video_clip(clip):
103 | raise ValueError("clip should be a 4D torch.tensor")
104 | if not inplace:
105 | clip = clip.clone()
106 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
107 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
108 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
109 | return clip
110 |
111 |
112 | def hflip(clip):
113 | """
114 | Args:
115 | clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
116 | Returns:
117 | flipped clip (torch.tensor): Size is (C, T, H, W)
118 | """
119 | if not _is_tensor_video_clip(clip):
120 | raise ValueError("clip should be a 4D torch.tensor")
121 | return clip.flip(-1)
122 |
--------------------------------------------------------------------------------
/utils/shift_video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 | from utils.cons_utils import load_json
6 | import copy
7 |
8 |
9 | def save_video_with_fps(frames, output_path, frame_size, original_fps, target_fps=None):
10 | """Save frames into a video file at the target frame rate (defaults to the original FPS)."""
11 | if target_fps is None:
12 | target_fps = original_fps
13 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # switch to 'XVID' if 'mp4v' causes compatibility issues
14 | out = cv2.VideoWriter(output_path, fourcc, target_fps, frame_size)
15 |
16 | # Calculate the step to downsample frames from original FPS to target FPS
17 | frame_step = max(1, int(round(original_fps / target_fps))) # Safeguard against rounding issues
18 |
19 | # Write only frames sampled at the correct step interval
20 | for i in range(0, len(frames), frame_step):
21 | out.write(np.uint8(frames[i])) # Ensure frames are of type np.uint8
22 |
23 | out.release()
24 |
25 |
26 | def swap_frames(video_path, timestamps, shifted_timestamps, output_dir, vid):
27 | """Swap frames between original and shifted timestamps and save both original and swapped videos at 1 FPS."""
28 |
29 | # Open the video
30 | cap = cv2.VideoCapture(video_path)
31 | fps = cap.get(cv2.CAP_PROP_FPS)
32 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
33 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
34 | frame_size = (frame_width, frame_height)
35 |
36 | # Buffer to store frames
37 | original_frames = []
38 | while cap.isOpened():
39 | ret, frame = cap.read()
40 | if not ret:
41 | break
42 |
43 | # Store the frame in the buffer
44 | original_frames.append(frame)
45 |
46 | # # Save the original video (before swapping) with 1 FPS
47 | # original_output_path = os.path.join(output_dir, f"{vid}.mp4")
48 | # save_video_with_fps(original_frames, original_output_path, frame_size, original_fps=fps)
49 |
50 | # Swap frames as per timestamps
51 | for i, (timestamp, shifted_timestamp) in enumerate(zip(timestamps, shifted_timestamps)):
52 | swapped_output_path = os.path.join(output_dir, f"{vid}_{i}.mp4")
53 | frames = copy.deepcopy(original_frames)
54 | start1, end1 = [int(timestamp[0] * fps), int(timestamp[1] * fps)]
55 | start2, end2 = [int(shifted_timestamp[0] * fps), int(shifted_timestamp[1] * fps)]
56 |
57 | # Ensure the ranges have the same length by extending or truncating
58 | length1 = end1 - start1
59 | length2 = end2 - start2
60 | if length1 > length2:
61 | # Extend second range by repeating the last frame
62 | extra_frames = [frames[end2 - 1]] * (length1 - length2)
63 | frames2 = np.concatenate([frames[start2:end2], extra_frames], axis=0)
64 | elif length2 > length1:
65 | # Truncate second range to match the first range's length
66 | frames2 = frames[start2:start2 + length1]
67 | else:
68 | frames2 = frames[start2:end2]
69 |
70 | # Swap frames between the two ranges
71 | temp = np.copy(frames[start1:end1]) # Copy frames from first range
72 | frames[start1:end1] = frames2 # Place frames from second range to first
73 | frames[start2:start2 + len(temp)] = temp # Place copied frames into second range
74 |
75 | # Save the swapped video at the original frame rate
76 | save_video_with_fps(frames, swapped_output_path, frame_size, original_fps=fps)
77 | cap.release()  # release the capture after all swapped clips have been written
78 |
79 |
80 | if __name__ == "__main__":
81 | # Load and process each video from annotations
82 | dset_name = "activitynet" # or charades
83 | video_root_path = "" # set the right path to your video files
84 | annotations = load_json(f"data/{dset_name}_consistency_test.json")
85 | output_dir = f'{video_root_path}/{dset_name}/shifted_videos/'
86 |
87 | if not os.path.exists(output_dir):
88 | os.makedirs(output_dir)
89 |
90 | print("Output dir:", output_dir)
91 |
92 | for vid, details in tqdm(annotations.items()):
93 | video_path = f"{video_root_path}/{dset_name}/{vid}.mp4" # Assuming the videos are in a "data" directory
94 | if os.path.exists(video_path):
95 | original_timestamps = details["timestamps"]
96 | shifted_timestamps = details["shifted_timestamps"]
97 | swap_frames(video_path, original_timestamps, shifted_timestamps, output_dir, vid)
98 |
99 | else:
100 | print(f"Video file {video_path} not found.")
101 |
--------------------------------------------------------------------------------
/timechat/models/blip2_outputs.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from salesforce@LAVIS. Below is the original copyright:
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 | from dataclasses import dataclass
10 | from typing import Optional
11 |
12 | import torch
13 | from transformers.modeling_outputs import (
14 | ModelOutput,
15 | BaseModelOutputWithPoolingAndCrossAttentions,
16 | CausalLMOutputWithCrossAttentions,
17 | )
18 |
19 |
20 | @dataclass
21 | class BlipSimilarity(ModelOutput):
22 | sim_i2t: torch.FloatTensor = None
23 | sim_t2i: torch.FloatTensor = None
24 |
25 | sim_i2t_m: Optional[torch.FloatTensor] = None
26 | sim_t2i_m: Optional[torch.FloatTensor] = None
27 |
28 | sim_i2t_targets: Optional[torch.FloatTensor] = None
29 | sim_t2i_targets: Optional[torch.FloatTensor] = None
30 |
31 |
32 | @dataclass
33 | class BlipIntermediateOutput(ModelOutput):
34 | """
35 | Data class for intermediate outputs of BLIP models.
36 |
37 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
38 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
39 |
40 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
41 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
42 |
43 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
44 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
45 |
46 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
47 | decoder_labels (torch.LongTensor): labels for the captioning loss.
48 |
49 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
50 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
51 |
52 | """
53 |
54 | # uni-modal features
55 | image_embeds: torch.FloatTensor = None
56 | text_embeds: Optional[torch.FloatTensor] = None
57 |
58 | image_embeds_m: Optional[torch.FloatTensor] = None
59 | text_embeds_m: Optional[torch.FloatTensor] = None
60 |
61 | # intermediate outputs of multimodal encoder
62 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
64 |
65 | itm_logits: Optional[torch.FloatTensor] = None
66 | itm_labels: Optional[torch.LongTensor] = None
67 |
68 | # intermediate outputs of multimodal decoder
69 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
70 | decoder_labels: Optional[torch.LongTensor] = None
71 |
72 |
73 | @dataclass
74 | class BlipOutput(ModelOutput):
75 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
76 | sims: Optional[BlipSimilarity] = None
77 |
78 | intermediate_output: BlipIntermediateOutput = None
79 |
80 | loss: Optional[torch.FloatTensor] = None
81 |
82 | loss_itc: Optional[torch.FloatTensor] = None
83 |
84 | loss_itm: Optional[torch.FloatTensor] = None
85 |
86 | loss_lm: Optional[torch.FloatTensor] = None
87 |
88 |
89 | @dataclass
90 | class BlipOutputFeatures(ModelOutput):
91 | """
92 | Data class of features from BlipFeatureExtractor.
93 |
94 | Args:
95 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
96 | image_embeds_proj: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
97 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
98 | text_embeds_proj: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
99 |
100 | The first embedding or feature is for the [CLS] token.
101 |
102 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
103 | """
104 |
105 | image_embeds: Optional[torch.FloatTensor] = None
106 | image_embeds_proj: Optional[torch.FloatTensor] = None
107 |
108 | text_embeds: Optional[torch.FloatTensor] = None
109 | text_embeds_proj: Optional[torch.FloatTensor] = None
110 |
111 | multimodal_embeds: Optional[torch.FloatTensor] = None
112 |
--------------------------------------------------------------------------------
/timechat/processors/blip_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import re
9 |
10 | from timechat.common.registry import registry
11 | from timechat.processors.base_processor import BaseProcessor
12 | # from timechat.processors.randaugment import RandomAugment  # unused import
13 | from omegaconf import OmegaConf
14 | from torchvision import transforms
15 | from torchvision.transforms.functional import InterpolationMode
16 |
17 |
18 | class BlipImageBaseProcessor(BaseProcessor):
19 | def __init__(self, mean=None, std=None):
20 | if mean is None:
21 | mean = (0.48145466, 0.4578275, 0.40821073)
22 | if std is None:
23 | std = (0.26862954, 0.26130258, 0.27577711)
24 |
25 | self.normalize = transforms.Normalize(mean, std)
26 |
27 |
28 | @registry.register_processor("blip_caption")
29 | class BlipCaptionProcessor(BaseProcessor):
30 | def __init__(self, prompt="", max_words=50):
31 | self.prompt = prompt
32 | self.max_words = max_words
33 |
34 | def __call__(self, caption):
35 | caption = self.prompt + self.pre_caption(caption)
36 |
37 | return caption
38 |
39 | @classmethod
40 | def from_config(cls, cfg=None):
41 | if cfg is None:
42 | cfg = OmegaConf.create()
43 |
44 | prompt = cfg.get("prompt", "")
45 | max_words = cfg.get("max_words", 50)
46 |
47 | return cls(prompt=prompt, max_words=max_words)
48 |
49 | def pre_caption(self, caption):
50 | caption = re.sub(
51 | r"([.!\"()*#:;~])",
52 | " ",
53 | caption.lower(),
54 | )
55 | caption = re.sub(
56 | r"\s{2,}",
57 | " ",
58 | caption,
59 | )
60 | caption = caption.rstrip("\n")
61 | caption = caption.strip(" ")
62 |
63 | # truncate caption
64 | caption_words = caption.split(" ")
65 | if len(caption_words) > self.max_words:
66 | caption = " ".join(caption_words[: self.max_words])
67 |
68 | return caption
69 |
70 |
71 | @registry.register_processor("blip2_image_train")
72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
74 | super().__init__(mean=mean, std=std)
75 |
76 | self.transform = transforms.Compose(
77 | [
78 | transforms.RandomResizedCrop(
79 | image_size,
80 | scale=(min_scale, max_scale),
81 | interpolation=InterpolationMode.BICUBIC,
82 | ),
83 | transforms.ToTensor(),
84 | self.normalize,
85 | ]
86 | )
87 |
88 | def __call__(self, item):
89 | return self.transform(item)
90 |
91 | @classmethod
92 | def from_config(cls, cfg=None):
93 | if cfg is None:
94 | cfg = OmegaConf.create()
95 |
96 | image_size = cfg.get("image_size", 224)
97 |
98 | mean = cfg.get("mean", None)
99 | std = cfg.get("std", None)
100 |
101 | min_scale = cfg.get("min_scale", 0.5)
102 | max_scale = cfg.get("max_scale", 1.0)
103 |
104 | return cls(
105 | image_size=image_size,
106 | mean=mean,
107 | std=std,
108 | min_scale=min_scale,
109 | max_scale=max_scale,
110 | )
111 |
112 |
113 | @registry.register_processor("blip2_image_eval")
114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
115 | def __init__(self, image_size=224, mean=None, std=None):
116 | super().__init__(mean=mean, std=std)
117 |
118 | self.transform = transforms.Compose(
119 | [
120 | transforms.Resize(
121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
122 | ),
123 | transforms.ToTensor(),
124 | self.normalize,
125 | ]
126 | )
127 |
128 | def __call__(self, item):
129 | return self.transform(item)
130 |
131 | @classmethod
132 | def from_config(cls, cfg=None):
133 | if cfg is None:
134 | cfg = OmegaConf.create()
135 |
136 | image_size = cfg.get("image_size", 224)
137 |
138 | mean = cfg.get("mean", None)
139 | std = cfg.get("std", None)
140 |
141 | return cls(image_size=image_size, mean=mean, std=std)
142 |
143 |
--------------------------------------------------------------------------------
/utils/prompts.py:
--------------------------------------------------------------------------------
1 | neg_occurrence = [
2 | "Is the event '{event}' absent from {st} to {ed} seconds in the video?",
3 | "Is the event '{event}' not present from {st} to {ed} seconds in the video?",
4 | "Does the event '{event}' not happen from {st} to {ed} seconds in the video?",
5 | "Is the event '{event}' missing from {st} to {ed} seconds in the video?"
6 | ]
7 |
8 | pos_occurrence = [
9 | "Is the event '{event}' present from {st} to {ed} seconds in the video?",
10 | "Is the event '{event}' occurring from {st} to {ed} seconds in the video?",
11 | "Does the event '{event}' happen from {st} to {ed} seconds in the video?",
12 | "Is the event '{event}' included from {st} to {ed} seconds in the video?"
13 | ]
14 |
15 | grounding_prompts = [
16 | "When does the event '{event}' happen in the video? Please only return its start time and end time.",
17 | "Please find the visual contents in the video described by a given event, determining its starting and ending times. Now I will give you the event: '{event}'. Please only return its start time and end time.",
18 | "Please answer when the event '{event}' occurs in the video. The output format should be: start - end seconds’. Please return its start time and end time."
19 | ]
20 |
21 | # The default grounding prompt is the third one in the list 'grounding_prompts'.
22 | prompt = {
23 | "grounding": "Please answer when the event '{event}' occurs in the video. The output format should be: 'start - end seconds'. Please return its start time and end time.",
24 | "pos": pos_occurrence,
25 | "neg": neg_occurrence,
26 | "add_detail": "Please answer with 'Yes' or 'No'.",
27 | "description": "Please describe the given video in detail.",
28 | "compositional": "{question} from {st} to {ed} seconds in the video?",
29 | }
30 |
31 | cot = {
32 | "grounding": """Your task is to predict the start and end times of an action or event described by a query sentence based on the visual content of the video. Use Chain-of-Thought reasoning to break down the query, analyze key moments, and accurately identify the time range where the action occurs.
33 | ### Chain-of-Thought Reasoning:
34 | 1. **Step 1: Parse the Query**: Break down the query sentence to understand the key action or event that you need to locate.
35 | 2. **Step 2: Analyze the Video Features**: Examine the sequence of video frames to detect patterns that match the key action described in the query.
36 | 3. **Step 3: Identify the Temporal Boundaries**: Use temporal reasoning to find the start and end frames of the action based on the video features.
37 | 4. **Step 4: Predict Start and End Times**: Map the identified frames to timestamps in the video, making sure the start and end times align with the query.
38 | 5. **Step 5: Verify the Answer**: Check if the predicted time range accurately captures the action described in the query.
39 | """,
40 | "occurrence": """You are a model designed to predict when specific events occur in a video based on a query sentence. Your task is to verify whether the event described in the query occurs in the given moment of the video.
41 | ### Chain-of-Thought Reasoning:
42 | 1. **Step 1: Verify the Event in the Predicted Time Range**: Analyze the video features from the predicted start time to the end time. Determine if the event described in the query occurs within this time range.
43 | - Example: For the query "The person is cooking," check for visual patterns such as a stove or kitchen utensils during the predicted moment.
44 | 2. **Step 2: Answer the Verification Question**: Respond to the question:
45 | - **"Is the event '{event}' present from {start_time} to {end_time} seconds in the video?"**
46 | - Example: "Is the event 'The person is cooking' present from 30.0 to 40.0 seconds in the video?"
47 | - If you find the event in the given moment, your answer should be "Yes."; if it does not happen in the given moment, your answer should be "No.".
48 | """,
49 | "compositional": """You are a model designed to analyze the compositional elements of an event in a video. Your task is to verify whether each compositional element occurs during the given moment in the video based on the specific question you receive. Instead of analyzing the entire event at once, you will answer questions about individual components of the scene.
50 | ### Chain-of-Thought Reasoning:
51 | 1. **Step 1: Analyze the Video Features for the Specific Element**: Analyze the video features from the start time to the end time. Look for visual evidence of the specific compositional element described in the question.
52 | 2. **Step 2: Answer the Compositional Question**: Respond to the question:
53 | - Example: "Is there a young girl from 0.0 to 5.0 seconds in the video?"
54 | - If you find a young girl in the given video moment, your answer should be "Yes."; if she is not present, your answer should be "No.".
55 | """,
56 | }
--------------------------------------------------------------------------------
/timechat/datasets/datasets/webvid_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | from timechat.datasets.datasets.base_dataset import BaseDataset
10 | from timechat.datasets.datasets.caption_datasets import CaptionDataset
11 | import pandas as pd
12 | import decord
13 | from decord import VideoReader
14 | import random
15 | import torch
16 | from torch.utils.data.dataloader import default_collate
17 | class WebvidDataset(BaseDataset):
18 | def __init__(self, vis_processor, text_processor, vis_root, ann_root):
19 | """
20 | vis_root (string): Root directory of video (e.g. webvid_eval/video/)
21 | ann_root (string): Root directory of annotations (e.g. webvid_eval/annotations/)
22 | split (string): val or test
23 | """
24 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
25 |
26 |
27 | # Read all CSV annotation files under the annotation directory
28 |
29 | ts_df = []
30 | for file_name in os.listdir(ann_root):
31 | if file_name.endswith('.csv'):
32 | df = pd.read_csv(os.path.join(ann_root, file_name))
33 | ts_df.append(df)
34 |
35 | merged_df = pd.concat(ts_df)
36 | self.annotation = merged_df
37 | self.vis_root = vis_root
38 | self.resize_size = 224
39 | self.num_frm = 8
40 | self.frm_sampling_strategy = 'headtail'
41 |
42 | def _get_video_path(self, sample):
43 | rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
44 | full_video_fp = os.path.join(self.vis_root, rel_video_fp)
45 | return full_video_fp
46 |
47 | def __getitem__(self, index):
48 | num_retries = 10 # skip error videos
49 | for _ in range(num_retries):
50 | sample = self.annotation.iloc[index]
51 | sample_dict = sample.to_dict()
52 | video_id = sample_dict['videoid']
53 |
54 | if 'name' in sample_dict.keys():
55 | text = sample_dict['name'].strip()
56 | else:
57 | raise NotImplementedError("Un-supported text annotation format.")
58 |
59 | # fetch video
60 | video_path = self._get_video_path(sample_dict)
61 | # if os.path.exists(video_path):
62 | try:
63 | video = self.vis_processor(video_path)
64 | except Exception:
65 | print(f"Failed to load examples with video: {video_path}. "
66 | f"Will randomly sample an example as a replacement.")
67 | index = random.randint(0, len(self) - 1)
68 | continue
69 | caption = self.text_processor(text)
70 |
71 | # print(video.size())
72 | if video is None or caption is None \
73 | or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]):
74 | print(f"Failed to load examples with video: {video_path}. "
75 | f"Will randomly sample an example as a replacement.")
76 | index = random.randint(0, len(self) - 1)
77 | continue
78 | else:
79 | break
80 | else:
81 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
82 | # "image_id" is kept to stay compatible with the COCO evaluation format
83 | return {
84 | "image": video,
85 | "text_input": caption,
86 | "type":'video',
87 | }
88 |
89 | def __len__(self):
90 | return len(self.annotation)
91 |
92 | # def collater(self, samples):
93 | # new_result = {}
94 | # new_result['image'] = default_collate( [sample["image"] for sample in samples])
95 | # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples])
96 | # return new_result
97 |
98 | class WebvidDatasetEvalDataset(BaseDataset):
99 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
100 | """
101 | vis_root (string): Root directory of videos (e.g. webvid_eval/video/)
102 | ann_paths (list): paths to the annotation files
103 | split (string): val or test
104 | """
105 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
106 |
107 | def __getitem__(self, index):
108 |
109 | ann = self.annotation[index]
110 |
111 | vname = ann["video"]
112 | video_path = os.path.join(self.vis_root, vname)
113 |
114 | video = self.vis_processor(video_path)
115 |
116 | return {
117 | "video": video,
118 | "image_id": ann["image_id"],
119 | "instance_id": ann["instance_id"],
120 | }
121 |
122 |
123 |
--------------------------------------------------------------------------------
/timechat/processors/transforms_video.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 |
10 | import numbers
11 | import random
12 |
13 | from torchvision.transforms import (
14 | RandomCrop,
15 | RandomResizedCrop,
16 | )
17 |
18 | import timechat.processors.functional_video as F
19 |
20 |
21 | __all__ = [
22 | "RandomCropVideo",
23 | "RandomResizedCropVideo",
24 | "CenterCropVideo",
25 | "NormalizeVideo",
26 | "ToTensorVideo",
27 | "RandomHorizontalFlipVideo",
28 | ]
29 |
30 |
31 | class RandomCropVideo(RandomCrop):
32 | def __init__(self, size):
33 | if isinstance(size, numbers.Number):
34 | self.size = (int(size), int(size))
35 | else:
36 | self.size = size
37 |
38 | def __call__(self, clip):
39 | """
40 | Args:
41 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
42 | Returns:
43 | torch.tensor: randomly cropped/resized video clip.
44 | size is (C, T, OH, OW)
45 | """
46 | i, j, h, w = self.get_params(clip, self.size)
47 | return F.crop(clip, i, j, h, w)
48 |
49 | def __repr__(self) -> str:
50 | return f"{self.__class__.__name__}(size={self.size})"
51 |
52 |
53 | class RandomResizedCropVideo(RandomResizedCrop):
54 | def __init__(
55 | self,
56 | size,
57 | scale=(0.08, 1.0),
58 | ratio=(3.0 / 4.0, 4.0 / 3.0),
59 | interpolation_mode="bilinear",
60 | ):
61 | if isinstance(size, tuple):
62 | if len(size) != 2:
63 | raise ValueError(
64 | f"size should be tuple (height, width), instead got {size}"
65 | )
66 | self.size = size
67 | else:
68 | self.size = (size, size)
69 |
70 | self.interpolation_mode = interpolation_mode
71 | self.scale = scale
72 | self.ratio = ratio
73 |
74 | def __call__(self, clip):
75 | """
76 | Args:
77 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
78 | Returns:
79 | torch.tensor: randomly cropped/resized video clip.
80 | size is (C, T, H, W)
81 | """
82 | i, j, h, w = self.get_params(clip, self.scale, self.ratio)
83 | return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
84 |
85 | def __repr__(self) -> str:
86 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
87 |
88 |
89 | class CenterCropVideo:
90 | def __init__(self, crop_size):
91 | if isinstance(crop_size, numbers.Number):
92 | self.crop_size = (int(crop_size), int(crop_size))
93 | else:
94 | self.crop_size = crop_size
95 |
96 | def __call__(self, clip):
97 | """
98 | Args:
99 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
100 | Returns:
101 | torch.tensor: central cropping of video clip. Size is
102 | (C, T, crop_size, crop_size)
103 | """
104 | return F.center_crop(clip, self.crop_size)
105 |
106 | def __repr__(self) -> str:
107 | return f"{self.__class__.__name__}(crop_size={self.crop_size})"
108 |
109 |
110 | class NormalizeVideo:
111 | """
112 | Normalize the video clip by mean subtraction and division by standard deviation
113 | Args:
114 | mean (3-tuple): pixel RGB mean
115 | std (3-tuple): pixel RGB standard deviation
116 | inplace (boolean): whether do in-place normalization
117 | """
118 |
119 | def __init__(self, mean, std, inplace=False):
120 | self.mean = mean
121 | self.std = std
122 | self.inplace = inplace
123 |
124 | def __call__(self, clip):
125 | """
126 | Args:
127 | clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
128 | """
129 | return F.normalize(clip, self.mean, self.std, self.inplace)
130 |
131 | def __repr__(self) -> str:
132 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
133 |
134 |
135 | class ToTensorVideo:
136 | """
137 | Convert tensor data type from uint8 to float, divide value by 255.0 and
138 | permute the dimensions of clip tensor
139 | """
140 |
141 | def __init__(self):
142 | pass
143 |
144 | def __call__(self, clip):
145 | """
146 | Args:
147 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
148 | Return:
149 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
150 | """
151 | return F.to_tensor(clip)
152 |
153 | def __repr__(self) -> str:
154 | return self.__class__.__name__
155 |
156 |
157 | class RandomHorizontalFlipVideo:
158 | """
159 | Flip the video clip along the horizontal direction with a given probability
160 | Args:
161 | p (float): probability of the clip being flipped. Default value is 0.5
162 | """
163 |
164 | def __init__(self, p=0.5):
165 | self.p = p
166 |
167 | def __call__(self, clip):
168 | """
169 | Args:
170 | clip (torch.tensor): Size is (C, T, H, W)
171 | Return:
172 | clip (torch.tensor): Size is (C, T, H, W)
173 | """
174 | if random.random() < self.p:
175 | clip = F.hflip(clip)
176 | return clip
177 |
178 | def __repr__(self) -> str:
179 | return f"{self.__class__.__name__}(p={self.p})"
180 |
--------------------------------------------------------------------------------
/timechat/datasets/datasets/dataloader_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import time
9 | import random
10 | import torch
11 | from timechat.datasets.data_utils import move_to_cuda
12 | from torch.utils.data import DataLoader
13 |
14 |
15 | class MultiIterLoader:
16 | """
17 | A simple wrapper for iterating over multiple iterators.
18 |
19 | Args:
20 | loaders (List[Loader]): List of Iterator loaders.
21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22 | """
23 |
24 | def __init__(self, loaders, ratios=None):
25 | # assert all loaders has __next__ method
26 | for loader in loaders:
27 | assert hasattr(
28 | loader, "__next__"
29 | ), "Loader {} has no __next__ method.".format(loader)
30 |
31 | if ratios is None:
32 | ratios = [1.0] * len(loaders)
33 | else:
34 | assert len(ratios) == len(loaders)
35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36 |
37 | self.loaders = loaders
38 | self.ratios = ratios
39 |
40 | def __next__(self):
41 | # random sample from each loader by ratio
42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43 | return next(self.loaders[loader_idx])
44 |
45 |
46 | class PrefetchLoader(object):
47 | """
48 | Modified from https://github.com/ChenRocks/UNITER.
49 |
50 | overlap compute and cuda data transfer
51 | (copied and then modified from nvidia apex)
52 | """
53 |
54 | def __init__(self, loader):
55 | self.loader = loader
56 | self.stream = torch.cuda.Stream()
57 |
58 | def __iter__(self):
59 | loader_it = iter(self.loader)
60 | self.preload(loader_it)
61 | batch = self.next(loader_it)
62 | while batch is not None:
63 | is_tuple = isinstance(batch, tuple)
64 | if is_tuple:
65 | task, batch = batch
66 |
67 | if is_tuple:
68 | yield task, batch
69 | else:
70 | yield batch
71 | batch = self.next(loader_it)
72 |
73 | def __len__(self):
74 | return len(self.loader)
75 |
76 | def preload(self, it):
77 | try:
78 | self.batch = next(it)
79 | except StopIteration:
80 | self.batch = None
81 | return
82 | # if record_stream() doesn't work, another option is to make sure
83 | # device inputs are created on the main stream.
84 | # self.next_input_gpu = torch.empty_like(self.next_input,
85 | # device='cuda')
86 | # self.next_target_gpu = torch.empty_like(self.next_target,
87 | # device='cuda')
88 | # Need to make sure the memory allocated for next_* is not still in use
89 | # by the main stream at the time we start copying to next_*:
90 | # self.stream.wait_stream(torch.cuda.current_stream())
91 | with torch.cuda.stream(self.stream):
92 | self.batch = move_to_cuda(self.batch)
93 | # more code for the alternative if record_stream() doesn't work:
94 | # copy_ will record the use of the pinned source tensor in this
95 | # side stream.
96 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98 | # self.next_input = self.next_input_gpu
99 | # self.next_target = self.next_target_gpu
100 |
101 | def next(self, it):
102 | torch.cuda.current_stream().wait_stream(self.stream)
103 | batch = self.batch
104 | if batch is not None:
105 | record_cuda_stream(batch)
106 | self.preload(it)
107 | return batch
108 |
109 | def __getattr__(self, name):
110 | method = self.loader.__getattribute__(name)
111 | return method
112 |
113 |
114 | def record_cuda_stream(batch):
115 | if isinstance(batch, torch.Tensor):
116 | batch.record_stream(torch.cuda.current_stream())
117 | elif isinstance(batch, list) or isinstance(batch, tuple):
118 | for t in batch:
119 | record_cuda_stream(t)
120 | elif isinstance(batch, dict):
121 | for t in batch.values():
122 | record_cuda_stream(t)
123 | else:
124 | pass
125 |
126 |
127 | class IterLoader:
128 | """
129 | A wrapper to convert DataLoader as an infinite iterator.
130 |
131 | Modified from:
132 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133 | """
134 |
135 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136 | self._dataloader = dataloader
137 | self.iter_loader = iter(self._dataloader)
138 | self._use_distributed = use_distributed
139 | self._epoch = 0
140 |
141 | @property
142 | def epoch(self) -> int:
143 | return self._epoch
144 |
145 | def __next__(self):
146 | try:
147 | data = next(self.iter_loader)
148 | except StopIteration:
149 | self._epoch += 1
150 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151 | self._dataloader.sampler.set_epoch(self._epoch)
152 | time.sleep(2) # Prevent possible deadlock during epoch transition
153 | self.iter_loader = iter(self._dataloader)
154 | data = next(self.iter_loader)
155 |
156 | return data
157 |
158 | def __iter__(self):
159 | return self
160 |
161 | def __len__(self):
162 | return len(self._dataloader)
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On the Consistency of Video Large Language Models in Temporal Comprehension
2 |
3 | [](https://arxiv.org/abs/2411.12951)
4 |
5 |
6 |
7 | ## News
8 | - [x] **[2025.03.25]** Evaluation Codes have been released.
9 | - [x] **[2025.02.27]** Our paper has been accepted by CVPR 2025! 🎉
10 | - [x] **[2025.01.15]** We are excited to share that our evaluation datasets, Charades-CON and ActivityNet-CON, are now available on Hugging Face! 🎉 Additionally, the training annotations for VTune have also been released.
11 | - [x] **[2025.01.14]** We have released our four checkpoints using VTune: [VideoLLaMA-7B-Charades-VTune](https://huggingface.co/mjjung/VideoLLaMA-7B-Charades-VTune), [VideoLLaMA-7B-ActivityNet-VTune](https://huggingface.co/mjjung/VideoLLaMA-7B-ActivityNet-VTune), [TimeChat-7B-Charades-VTune](https://huggingface.co/mjjung/TimeChat-7B-Charades-VTune), [TimeChat-7B-ActivityNet-VTune](https://huggingface.co/mjjung/TimeChat-7B-ActivityNet-VTune). Additionally, checkpoints with naive fine-tuning have been released: [VideoLLaMA-7B-Charades-FT](https://huggingface.co/mjjung/VideoLLAMA-7B-Charades-FT), [VideoLLaMA-7B-ActivityNet-FT](https://huggingface.co/mjjung/VideoLLaMA-7B-ActivityNet-FT), and [TimeChat-7B-ActivityNet-FT](https://huggingface.co/mjjung/TimeChat-7B-ActivityNet-FT).
12 | - [x] **[2024.11.20]** Our paper has been released on arXiv.
13 |
14 | ## Introduction
15 | 
16 | - We study the model’s consistency in temporal comprehension by assessing whether its responses align with the initial grounding, using dedicated probes and datasets. We specifically focus on video temporal grounding, where the task involves identifying timestamps in a video that correspond to language queries.
17 |
18 | ## Download
19 | You can download the complete annotations for consistency evaluation from [Hugging Face](https://huggingface.co/datasets/mjjung/Consistency-Evaluation-for-Video-LLMs). The source videos are available via the following links:
20 |
21 | - [Charades-STA](https://prior.allenai.org/projects/charades)
22 | - [ActivityNet-Captions](https://cs.stanford.edu/people/ranjaykrishna/densevid/)
23 |
24 | ## Evaluation
25 | Before starting the evaluation, make sure you have prepared the annotations and videos, and check the configuration of the Video-LLM you plan to evaluate. Install the necessary dependencies (via conda and pip) for your model. Additionally, you may run `utils/shift_video.py` with the right paths to prepare the shifted videos.
26 | Here, we provide an example with the model `TimeChat`. We will include additional baseline models in the future.
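To prepare the shifted videos, note that the dataset name and video root are hard-coded in the `__main__` block of `utils/shift_video.py` (`dset_name` and `video_root_path`); set them first, then run the script, e.g. as a module from the repository root (the exact invocation below is only a suggested sketch):

```
python -m utils.shift_video
```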
27 |
28 | To run the evaluation, use the following command:
29 |
30 | ```
31 | python run.py --model_type TimeChat --dset_name activitynet --task consistency
32 | ```
33 | `dset_name` refers to the test dataset, which can be either `charades` or `activitynet`. `task` refers to the evaluation task: either `consistency` or `grounding`. If set to `grounding`, the evaluation will be performed on the original test set.
34 | You can also use the `--debug` flag before performing the actual evaluation to verify your configuration settings.
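For instance, a quick sanity check of your setup on Charades grounding could combine the flags above as follows (this exact combination is just an illustration):

```
python run.py --model_type TimeChat --dset_name charades --task grounding --debug
```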
35 |
36 | Once the evaluation is complete, the performance will be reported in `consistency_eval_results.json`, and you can check the model's output in `consistency_predictions.jsonl`.
37 |
38 |
39 | ## Training
40 |
41 | For training, please download the training annotations for each dataset from Hugging Face.
42 |
43 | ### ⚠️ Important Note
44 |
45 | The previously uploaded **VTune** dataset file `charades_for_VTune.json` partially includes videos from the Charades-CON test split.
46 | The updated files `charades_train_v2.json` and `charades_for_VTune_v2.json` exclude these overlapping videos.
47 | The corresponding hyperparameters should follow the table below. Note that neither dataset includes test videos from the original Charades-STA benchmark.
48 | We apologize for any inconvenience caused.
49 |
50 | | Dataset Name | iters_per_epoch | warmup_steps |
51 | |--------------------------|------------------|--------------|
52 | | `charades_for_VTune` | 24,811 | 14,916 |
53 | | `charades_for_VTune2` | 22,311 | 13,386 |
54 |
55 | The performance of TimeChat trained with `charades_for_VTune2`:
56 |
57 | | **Method** | **Ground** | **R-Ground** | **S-Ground** | **H-Verify** | **C-Verify** |
58 | |:-----------|:----------:|:------------:|:------------:|:------------:|:------------:|
59 | | SFT | 47.2 | 43.4 (91.8) | 15.0 (31.9) | 24.3 (51.5) | 24.0 (50.9) |
60 | | VTune | 52.0 | 47.4 (91.2) | 23.5 (45.2) | 31.5 (60.5) | 27.5 (52.9) |
61 |
62 |
63 | ## Checkpoints
64 |
65 | We provide the checkpoints for each dataset using the links below:
66 | - [Charades-STA](https://huggingface.co/mjjung/TimeChat-7B-Charades-VTune)
67 | - [ActivityNet-Captions](https://huggingface.co/mjjung/TimeChat-7B-ActivityNet-VTune)
68 |
69 | Then, use the following command:
70 | ```
71 | python run.py --model_type TimeChat --dset_name activitynet --fine_tuned --task consistency
72 | ```
73 | In the above example command, including the `--fine_tuned` option will automatically switch the checkpoint path from `ckpt` to `activitynet_ckpt` in `timechat/eval_configs/timechat.yaml`.
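Analogously, assuming the Charades checkpoint entry is configured in the same YAML file, evaluating the Charades fine-tuned weights could look like:

```
python run.py --model_type TimeChat --dset_name charades --fine_tuned --task consistency
```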
74 |
75 | ## Citation
76 | If you find our paper useful, please consider citing it.
77 | ```BibTeX
78 | @inproceedings{jung2025consistency,
79 | title={On the consistency of video large language models in temporal comprehension},
80 | author={Jung, Minjoon and Xiao, Junbin and Zhang, Byoung-Tak and Yao, Angela},
81 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
82 | pages={13713--13722},
83 | year={2025}
84 | }
85 | ```
86 | ## Acknowledgement
87 | We appreciate the following awesome Video-LLMs:
88 | - [TimeChat](https://github.com/RenShuhuai-Andy/TimeChat)
89 | - [VTimeLLM](https://github.com/huangb23/VTimeLLM)
90 | - [VTG-LLM](https://github.com/gyxxyg/VTG-LLM)
91 | - [Video-LLaMA](https://github.com/DAMO-NLP-SG/Video-LLaMA)
92 | - [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2)
93 | - [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA)
94 | - [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT)
95 | - [VideoChat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)
96 |
--------------------------------------------------------------------------------
/timechat/datasets/builders/instruct_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 |
5 | from timechat.common.registry import registry
6 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7 | from timechat.datasets.datasets.laion_dataset import LaionDataset
8 | from timechat.datasets.datasets.llava_instruct_dataset import Instruct_Dataset
9 | from timechat.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset
10 |
11 |
12 | @registry.register_builder("image_instruct")
13 | class Image_Instruct_Builder(BaseDatasetBuilder):
14 | train_dataset_cls = Instruct_Dataset
15 |
16 | DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
17 |
18 | def _download_ann(self):
19 | pass
20 |
21 | def _download_vis(self):
22 | pass
23 |
24 | def build(self):
25 | self.build_processors()
26 | datasets = dict()
27 | split = "train"
28 |
29 | build_info = self.config.build_info
30 | dataset_cls = self.train_dataset_cls
31 | if self.config.num_video_query_token:
32 | num_video_query_token = self.config.num_video_query_token
33 | else:
34 | num_video_query_token = 32
35 |
36 | if self.config.tokenizer_name:
37 | tokenizer_name = self.config.tokenizer_name
38 | else:
39 | tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
40 |
41 | model_type = self.config.model_type if self.config.model_type else 'vicuna'
42 |
43 | datasets[split] = dataset_cls(
44 | vis_processor=self.vis_processors[split],
45 | text_processor=self.text_processors[split],
46 | vis_root=build_info.videos_dir,
47 | ann_root=build_info.anno_dir,
48 | num_video_query_token=num_video_query_token,
49 | tokenizer_name=tokenizer_name,
50 | data_type=self.config.data_type,
51 | model_type=model_type,
52 | )
53 |
54 | return datasets
55 |
56 |
57 | @registry.register_builder("video_instruct")
58 | class Video_Instruct_Builder(BaseDatasetBuilder):
59 | train_dataset_cls = Video_Instruct_Dataset
60 |
61 | DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
62 |
63 | def _download_ann(self):
64 | pass
65 |
66 | def _download_vis(self):
67 | pass
68 |
69 | def build(self):
70 | self.build_processors()
71 | datasets = dict()
72 | split = "train"
73 |
74 | build_info = self.config.build_info
75 | dataset_cls = self.train_dataset_cls
76 | if self.config.num_video_query_token:
77 | num_video_query_token = self.config.num_video_query_token
78 | else:
79 | num_video_query_token = 32
80 |
81 | if self.config.tokenizer_name:
82 | tokenizer_name = self.config.tokenizer_name
83 | else:
84 | tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
85 |
86 | model_type = self.config.model_type if self.config.model_type else 'vicuna'
87 | num_frm = self.config.num_frm if self.config.num_frm else 8
88 | sample_type = self.config.sample_type if self.config.sample_type else 'uniform'
89 | max_txt_len = self.config.max_txt_len if self.config.max_txt_len else 512
90 | stride = self.config.stride if self.config.stride else 0
91 |
92 | datasets[split] = dataset_cls(
93 | vis_processor=self.vis_processors[split],
94 | text_processor=self.text_processors[split],
95 | vis_root=build_info.videos_dir,
96 | ann_root=build_info.anno_dir,
97 | num_video_query_token=num_video_query_token,
98 | tokenizer_name=tokenizer_name,
99 | data_type=self.config.data_type,
100 | model_type=model_type,
101 | num_frm=num_frm,
102 | sample_type=sample_type,
103 | max_txt_len=max_txt_len,
104 | stride=stride
105 | )
106 |
107 | return datasets
108 |
109 |
110 | @registry.register_builder("webvid_instruct")
111 | class WebvidInstruct_Builder(Video_Instruct_Builder):
112 | train_dataset_cls = Video_Instruct_Dataset
113 |
114 | DATASET_CONFIG_DICT = {
115 | "default": "configs/datasets/instruct/webvid_instruct.yaml",
116 | }
117 |
118 |
119 | @registry.register_builder("webvid_instruct_zh")
120 | class WebvidInstruct_zh_Builder(Video_Instruct_Builder):
121 | train_dataset_cls = Video_Instruct_Dataset
122 |
123 | DATASET_CONFIG_DICT = {
124 | "default": "configs/datasets/instruct/webvid_instruct.yaml",
125 | }
126 |
127 |
128 | @registry.register_builder("llava_instruct")
129 | class LlavaInstruct_Builder(Image_Instruct_Builder):
130 | train_dataset_cls = Instruct_Dataset
131 |
132 | DATASET_CONFIG_DICT = {
133 | "default": "configs/datasets/instruct/llava_instruct.yaml",
134 | }
135 |
136 |
137 | @registry.register_builder("youcook2_instruct")
138 | class Youcook2Instruct_Builder(Video_Instruct_Builder):
139 | train_dataset_cls = Video_Instruct_Dataset
140 |
141 | DATASET_CONFIG_DICT = {
142 | "default": "configs/datasets/instruct/youcook2_instruct.yaml",
143 | }
144 |
145 |
146 | @registry.register_builder("time_instruct")
147 | class TimeInstruct_Builder(Video_Instruct_Builder):
148 | train_dataset_cls = Video_Instruct_Dataset
149 |
150 | DATASET_CONFIG_DICT = {
151 | "default": "configs/datasets/instruct/time_instruct.yaml",
152 | }
153 |
154 |
155 | @registry.register_builder("valley72k_instruct")
156 | class Valley72kInstruct_Builder(Video_Instruct_Builder):
157 | train_dataset_cls = Video_Instruct_Dataset
158 |
159 | DATASET_CONFIG_DICT = {
160 | "default": "configs/datasets/instruct/valley72k_instruct.yaml",
161 | }
162 |
163 |
164 | @registry.register_builder("qvhighlights_instruct")
165 | class QVhighlightsInstruct_Builder(Video_Instruct_Builder):
166 | train_dataset_cls = Video_Instruct_Dataset
167 |
168 | DATASET_CONFIG_DICT = {
169 | "default": "configs/datasets/instruct/qvhighlights_instruct.yaml",
170 | }
171 |
172 |
173 | @registry.register_builder("charades_instruct")
174 | class CharadesInstruct_Builder(Video_Instruct_Builder):
175 | train_dataset_cls = Video_Instruct_Dataset
176 |
177 | DATASET_CONFIG_DICT = {
178 | "default": "configs/datasets/instruct/charades_instruct.yaml",
179 | }
180 |
--------------------------------------------------------------------------------
/timechat/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from salesforce@LAVIS Vision-CAIR@MiniGPT-4. Below is the original copyright:
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 | import logging
10 | import torch
11 | from omegaconf import OmegaConf
12 |
13 | from timechat.common.registry import registry
14 | from timechat.models.base_model import BaseModel
15 | from timechat.models.blip2 import Blip2Base
16 | from timechat.models.timechat import TimeChat
17 | from timechat.processors.base_processor import BaseProcessor
18 |
19 |
20 | __all__ = [
21 | "load_model",
22 | "BaseModel",
23 | "Blip2Base",
24 | "TimeChat"
25 | ]
26 |
27 |
28 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
29 | """
30 | Load supported models.
31 |
32 | To list all available models and types in registry:
33 | >>> from timechat.models import model_zoo
34 | >>> print(model_zoo)
35 |
36 | Args:
37 | name (str): name of the model.
38 | model_type (str): type of the model.
39 | is_eval (bool): whether the model is in eval mode. Default: False.
40 | device (str): device to use. Default: "cpu".
41 | checkpoint (str): path or URL to a checkpoint. Default: None.
42 | Note that the checkpoint is expected to have the same state_dict keys as the model.
43 |
44 | Returns:
45 | model (torch.nn.Module): model.
46 | """
47 |
48 | model = registry.get_model_class(name).from_pretrained(model_type=model_type)
49 |
50 | if checkpoint is not None:
51 | model.load_checkpoint(checkpoint)
52 |
53 | if is_eval:
54 | model.eval()
55 |
56 | if device == "cpu":
57 | model = model.float()
58 |
59 | return model.to(device)
60 |
61 |
62 | def load_preprocess(config):
63 | """
64 | Load preprocessor configs and construct preprocessors.
65 |
66 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
67 |
68 | Args:
69 | config (dict): preprocessor configs.
70 |
71 | Returns:
72 | vis_processors (dict): preprocessors for visual inputs.
73 | txt_processors (dict): preprocessors for text inputs.
74 |
75 | Key is "train" or "eval" for processors used in training and evaluation respectively.
76 | """
77 |
78 | def _build_proc_from_cfg(cfg):
79 | return (
80 | registry.get_processor_class(cfg.name).from_config(cfg)
81 | if cfg is not None
82 | else BaseProcessor()
83 | )
84 |
85 | vis_processors = dict()
86 | txt_processors = dict()
87 |
88 | vis_proc_cfg = config.get("vis_processor")
89 | txt_proc_cfg = config.get("text_processor")
90 |
91 | if vis_proc_cfg is not None:
92 | vis_train_cfg = vis_proc_cfg.get("train")
93 | vis_eval_cfg = vis_proc_cfg.get("eval")
94 | else:
95 | vis_train_cfg = None
96 | vis_eval_cfg = None
97 |
98 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
99 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
100 |
101 | if txt_proc_cfg is not None:
102 | txt_train_cfg = txt_proc_cfg.get("train")
103 | txt_eval_cfg = txt_proc_cfg.get("eval")
104 | else:
105 | txt_train_cfg = None
106 | txt_eval_cfg = None
107 |
108 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
109 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
110 |
111 | return vis_processors, txt_processors
112 |
113 |
114 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
115 | """
116 | Load model and its related preprocessors.
117 |
118 | List all available models and types in registry:
119 | >>> from timechat.models import model_zoo
120 | >>> print(model_zoo)
121 |
122 | Args:
123 | name (str): name of the model.
124 | model_type (str): type of the model.
125 | is_eval (bool): whether the model is in eval mode. Default: False.
126 | device (str): device to use. Default: "cpu".
127 |
128 | Returns:
129 | model (torch.nn.Module): model.
130 | vis_processors (dict): preprocessors for visual inputs.
131 | txt_processors (dict): preprocessors for text inputs.
132 | """
133 | model_cls = registry.get_model_class(name)
134 |
135 | # load model
136 | model = model_cls.from_pretrained(model_type=model_type)
137 |
138 | if is_eval:
139 | model.eval()
140 |
141 | # load preprocess
142 | cfg = OmegaConf.load(model_cls.default_config_path(model_type))
143 | if cfg is not None:
144 | preprocess_cfg = cfg.preprocess
145 |
146 | vis_processors, txt_processors = load_preprocess(preprocess_cfg)
147 | else:
148 | vis_processors, txt_processors = None, None
149 | logging.info(
150 | f"""No default preprocess for model {name} ({model_type}).
151 | This can happen if the model is not finetuned on downstream datasets,
152 | or it is not intended for direct use without finetuning.
153 | """
154 | )
155 |
156 | if device == "cpu" or device == torch.device("cpu"):
157 | model = model.float()
158 |
159 | return model.to(device), vis_processors, txt_processors
160 |
161 |
162 | class ModelZoo:
163 | """
164 | A utility class to create string representation of available model architectures and types.
165 |
166 | >>> from timechat.models import model_zoo
167 | >>> # list all available models
168 | >>> print(model_zoo)
169 | >>> # show total number of models
170 | >>> print(len(model_zoo))
171 | """
172 |
173 | def __init__(self) -> None:
174 | self.model_zoo = {
175 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
176 | for k, v in registry.mapping["model_name_mapping"].items()
177 | }
178 |
179 | def __str__(self) -> str:
180 | return (
181 | "=" * 50
182 | + "\n"
183 | + f"{'Architectures':<30} {'Types'}\n"
184 | + "=" * 50
185 | + "\n"
186 | + "\n".join(
187 | [
188 | f"{name:<30} {', '.join(types)}"
189 | for name, types in self.model_zoo.items()
190 | ]
191 | )
192 | )
193 |
194 | def __iter__(self):
195 | return iter(self.model_zoo.items())
196 |
197 | def __len__(self):
198 | return sum([len(v) for v in self.model_zoo.values()])
199 |
200 |
201 | model_zoo = ModelZoo()
202 |
--------------------------------------------------------------------------------
/timechat/common/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import logging
10 | import time
11 | from collections import defaultdict, deque
12 |
13 | import torch
14 | import torch.distributed as dist
15 |
16 | from timechat.common import dist_utils
17 |
18 |
19 | class SmoothedValue(object):
20 | """Track a series of values and provide access to smoothed values over a
21 | window or the global series average.
22 | """
23 |
24 | def __init__(self, window_size=20, fmt=None):
25 | if fmt is None:
26 | fmt = "{median:.4f} ({global_avg:.4f})"
27 | self.deque = deque(maxlen=window_size)
28 | self.total = 0.0
29 | self.count = 0
30 | self.fmt = fmt
31 |
32 | def update(self, value, n=1):
33 | self.deque.append(value)
34 | self.count += n
35 | self.total += value * n
36 |
37 | def synchronize_between_processes(self):
38 | """
39 | Warning: does not synchronize the deque!
40 | """
41 | if not dist_utils.is_dist_avail_and_initialized():
42 | return
43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median,
75 | avg=self.avg,
76 | global_avg=self.global_avg,
77 | max=self.max,
78 | value=self.value,
79 | )
80 |
81 |
82 | class MetricLogger(object):
83 | def __init__(self, delimiter="\t"):
84 | self.meters = defaultdict(SmoothedValue)
85 | self.delimiter = delimiter
86 |
87 | def update(self, **kwargs):
88 | for k, v in kwargs.items():
89 | if isinstance(v, torch.Tensor):
90 | v = v.item()
91 | assert isinstance(v, (float, int))
92 | self.meters[k].update(v)
93 |
94 | def __getattr__(self, attr):
95 | if attr in self.meters:
96 | return self.meters[attr]
97 | if attr in self.__dict__:
98 | return self.__dict__[attr]
99 | raise AttributeError(
100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101 | )
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append("{}: {}".format(name, str(meter)))
107 | return self.delimiter.join(loss_str)
108 |
109 | def global_avg(self):
110 | loss_str = []
111 | for name, meter in self.meters.items():
112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113 | return self.delimiter.join(loss_str)
114 |
115 | def synchronize_between_processes(self):
116 | for meter in self.meters.values():
117 | meter.synchronize_between_processes()
118 |
119 | def add_meter(self, name, meter):
120 | self.meters[name] = meter
121 |
122 | def log_every(self, iterable, print_freq, header=None):
123 | i = 0
124 | if not header:
125 | header = ""
126 | start_time = time.time()
127 | end = time.time()
128 | iter_time = SmoothedValue(fmt="{avg:.4f}")
129 | data_time = SmoothedValue(fmt="{avg:.4f}")
130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131 | log_msg = [
132 | header,
133 | "[{0" + space_fmt + "}/{1}]",
134 | "eta: {eta}",
135 | "{meters}",
136 | "time: {time}",
137 | "data: {data}",
138 | ]
139 | if torch.cuda.is_available():
140 | log_msg.append("max mem: {memory:.0f}")
141 | log_msg = self.delimiter.join(log_msg)
142 | MB = 1024.0 * 1024.0
143 | for obj in iterable:
144 | data_time.update(time.time() - end)
145 | yield obj
146 | iter_time.update(time.time() - end)
147 | if i % print_freq == 0 or i == len(iterable) - 1:
148 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150 | if torch.cuda.is_available():
151 | print(
152 | log_msg.format(
153 | i,
154 | len(iterable),
155 | eta=eta_string,
156 | meters=str(self),
157 | time=str(iter_time),
158 | data=str(data_time),
159 | memory=torch.cuda.max_memory_allocated() / MB,
160 | )
161 | )
162 | else:
163 | print(
164 | log_msg.format(
165 | i,
166 | len(iterable),
167 | eta=eta_string,
168 | meters=str(self),
169 | time=str(iter_time),
170 | data=str(data_time),
171 | )
172 | )
173 | i += 1
174 | end = time.time()
175 | total_time = time.time() - start_time
176 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177 | print(
178 | "{} Total time: {} ({:.4f} s / it)".format(
179 | header, total_time_str, total_time / len(iterable)
180 | )
181 | )
182 |
183 |
184 | class AttrDict(dict):
185 | def __init__(self, *args, **kwargs):
186 | super(AttrDict, self).__init__(*args, **kwargs)
187 | self.__dict__ = self
188 |
189 |
190 | def setup_logger():
191 | logging.basicConfig(
192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193 | format="%(asctime)s [%(levelname)s] %(message)s",
194 | handlers=[logging.StreamHandler()],
195 | )
196 |
--------------------------------------------------------------------------------
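A short sketch of how `SmoothedValue` and `MetricLogger` above are typically driven; the loop below uses a dummy loss purely for illustration:

import torch
from timechat.common.logger import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step in metric_logger.log_every(range(100), print_freq=10, header="Train:"):
    loss = torch.tensor(1.0 / (step + 1))      # dummy loss tensor
    metric_logger.update(loss=loss, lr=1e-4)   # tensors are reduced via .item()

print("Averaged stats:", metric_logger.global_avg())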
/timechat/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import gzip
9 | import logging
10 | import os
11 | import random as rnd
12 | import tarfile
13 | import zipfile
14 | import random
15 | from typing import List
16 | from tqdm import tqdm
17 |
18 | import decord
19 | from decord import VideoReader
20 | import webdataset as wds
21 | import numpy as np
22 | import torch
23 | from torch.utils.data.dataset import IterableDataset
24 |
25 | from timechat.common.registry import registry
26 | from timechat.datasets.datasets.base_dataset import ConcatDataset
27 |
28 |
29 | decord.bridge.set_bridge("torch")
30 | MAX_INT = registry.get("MAX_INT")
31 |
32 |
33 | class ChainDataset(wds.DataPipeline):
34 | r"""Dataset for chaining multiple :class:`DataPipeline` s.
35 |
36 | This class is useful to assemble different existing dataset streams. The
37 | chaining operation is done on-the-fly, so concatenating large-scale
38 | datasets with this class will be efficient.
39 |
40 | Args:
41 | datasets (iterable of IterableDataset): datasets to be chained together
42 | """
43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44 | super().__init__()
45 | self.datasets = datasets
46 | self.prob = []
47 | self.names = []
48 | for dataset in self.datasets:
49 | if hasattr(dataset, 'name'):
50 | self.names.append(dataset.name)
51 | else:
52 | self.names.append('Unknown')
53 | if hasattr(dataset, 'sample_ratio'):
54 | self.prob.append(dataset.sample_ratio)
55 | else:
56 | self.prob.append(1)
57 | logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
58 |
59 | def __iter__(self):
60 | datastreams = [iter(dataset) for dataset in self.datasets]
61 | while True:
62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63 | yield next(select_datastream)
64 |
65 |
66 | def apply_to_sample(f, sample):
67 | if len(sample) == 0:
68 | return {}
69 |
70 | def _apply(x):
71 | if torch.is_tensor(x):
72 | return f(x)
73 | elif isinstance(x, dict):
74 | return {key: _apply(value) for key, value in x.items()}
75 | elif isinstance(x, list):
76 | return [_apply(v) for v in x]
77 | else:
78 | return x
79 |
80 | return _apply(sample)
81 |
82 |
83 | def move_to_cuda(sample):
84 | def _move_to_cuda(tensor):
85 | return tensor.cuda()
86 |
87 | return apply_to_sample(_move_to_cuda, sample)
88 |
89 |
90 | def prepare_sample(samples, cuda_enabled=True):
91 | if cuda_enabled:
92 | samples = move_to_cuda(samples)
93 |
94 | # TODO fp16 support
95 |
96 | return samples
97 |
98 |
99 | def reorg_datasets_by_split(datasets):
100 | """
101 | Organizes datasets by split.
102 |
103 | Args:
104 | datasets: dict of torch.utils.data.Dataset objects by name.
105 |
106 | Returns:
107 | Dict of datasets by split {split_name: List[Datasets]}.
108 | """
109 | # if len(datasets) == 1:
110 | # return datasets[list(datasets.keys())[0]]
111 | # else:
112 | reorg_datasets = dict()
113 |
114 | # reorganize by split
115 | for _, dataset in datasets.items():
116 | for split_name, dataset_split in dataset.items():
117 | if split_name not in reorg_datasets:
118 | reorg_datasets[split_name] = [dataset_split]
119 | else:
120 | reorg_datasets[split_name].append(dataset_split)
121 |
122 | return reorg_datasets
123 |
124 |
125 | def concat_datasets(datasets):
126 | """
127 | Concatenates multiple datasets into a single dataset.
128 |
129 | It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not support
130 | generic IterableDataset because it requires creating separate samplers.
131 |
132 | Currently it only supports concatenating training datasets, assuming validation and testing
133 | have only a single dataset. This is because metrics should not be computed on the concatenated
134 | datasets.
135 |
136 | Args:
137 | datasets: dict of torch.utils.data.Dataset objects by split.
138 |
139 | Returns:
140 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141 | "val" and "test" remain the same.
142 |
143 | If the input training datasets contain both map-style and DataPipeline datasets, returns
144 | a tuple, where the first element is a concatenated map-style dataset and the second
145 | element is a chained DataPipeline dataset.
146 |
147 | """
148 | # concatenate datasets in the same split
149 | for split_name in datasets:
150 | if split_name != "train":
151 | assert (
152 | len(datasets[split_name]) == 1
153 | ), "Do not support multiple {} datasets.".format(split_name)
154 | datasets[split_name] = datasets[split_name][0]
155 | else:
156 | iterable_datasets, map_datasets = [], []
157 | for dataset in datasets[split_name]:
158 | if isinstance(dataset, wds.DataPipeline):
159 | logging.info(
160 | "Dataset {} is IterableDataset, can't be concatenated.".format(
161 | dataset
162 | )
163 | )
164 | iterable_datasets.append(dataset)
165 | elif isinstance(dataset, IterableDataset):
166 | raise NotImplementedError(
167 | "Do not support concatenation of generic IterableDataset."
168 | )
169 | else:
170 | map_datasets.append(dataset)
171 |
172 | # if len(iterable_datasets) > 0:
173 | # concatenate map-style datasets and iterable-style datasets separately
174 | if len(iterable_datasets) > 1:
175 | chained_datasets = (
176 | ChainDataset(iterable_datasets)
177 | )
178 | elif len(iterable_datasets) == 1:
179 | chained_datasets = iterable_datasets[0]
180 | else:
181 | chained_datasets = None
182 |
183 | concat_datasets = (
184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185 | )
186 |
187 | train_datasets = concat_datasets, chained_datasets
188 | train_datasets = tuple([x for x in train_datasets if x is not None])
189 | train_datasets = (
190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets
191 | )
192 |
193 | datasets[split_name] = train_datasets
194 |
195 | return datasets
196 |
197 |
--------------------------------------------------------------------------------
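A small sketch of `reorg_datasets_by_split` and `concat_datasets` above, using toy map-style datasets in place of real builder outputs (it assumes the repo's `ConcatDataset` behaves like torch's, so `len` is the sum of the parts):

import torch
from torch.utils.data import TensorDataset
from timechat.datasets.data_utils import reorg_datasets_by_split, concat_datasets

# Two toy "builder outputs", keyed by split.
ds_a = {"train": TensorDataset(torch.zeros(4, 3)), "val": TensorDataset(torch.ones(2, 3))}
ds_b = {"train": TensorDataset(torch.ones(6, 3))}

by_split = reorg_datasets_by_split({"a": ds_a, "b": ds_b})
# -> {"train": [a_train, b_train], "val": [a_val]}

merged = concat_datasets(by_split)
print(len(merged["train"]))  # 10: the two train sets are concatenated
print(len(merged["val"]))    # 2: non-train splits must hold exactly one dataset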
/timechat/models/blip2.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from salesforce@LAVIS. Below is the original copyright:
3 | Copyright (c) 2023, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 | import contextlib
9 | import logging
10 | import os
11 | import time
12 | import datetime
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.distributed as dist
17 | import torch.nn.functional as F
18 |
19 | import timechat.common.dist_utils as dist_utils
20 | from timechat.common.dist_utils import download_cached_file
21 | from timechat.common.utils import is_url
22 | from timechat.common.logger import MetricLogger
23 | from timechat.models.base_model import BaseModel
24 | from timechat.models.Qformer import BertConfig, BertLMHeadModel
25 | from timechat.models.eva_vit import create_eva_vit_g
26 | from transformers import BertTokenizer
27 |
28 |
29 | class Blip2Base(BaseModel):
30 | @classmethod
31 | def init_tokenizer(cls):
32 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
33 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
34 | return tokenizer
35 |
36 | def maybe_autocast(self, dtype=torch.float16):
37 | # if on cpu, don't use autocast
38 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
39 | enable_autocast = self.device != torch.device("cpu")
40 |
41 | if enable_autocast:
42 | return torch.cuda.amp.autocast(dtype=dtype)
43 | else:
44 | return contextlib.nullcontext()
45 |
46 | @classmethod
47 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
48 | encoder_config = BertConfig.from_pretrained("bert-base-uncased")
49 | encoder_config.encoder_width = vision_width
50 | # insert cross-attention layer every other block
51 | encoder_config.add_cross_attention = True
52 | encoder_config.cross_attention_freq = cross_attention_freq
53 | encoder_config.query_length = num_query_token
54 | Qformer = BertLMHeadModel(config=encoder_config)
55 | query_tokens = nn.Parameter(
56 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
57 | )
58 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
59 | return Qformer, query_tokens
60 |
61 | @classmethod
62 | def init_vision_encoder(
63 | cls, url_or_filename, img_size, drop_path_rate, use_grad_checkpoint, precision
64 | ):
65 | assert "eva_vit_g" in url_or_filename, "vit model must be eva_vit_g for current version of MiniGPT-4"
66 | visual_encoder = create_eva_vit_g(
67 | url_or_filename, img_size, drop_path_rate, use_grad_checkpoint, precision
68 | )
69 |
70 | ln_vision = LayerNorm(visual_encoder.num_features)
71 | return visual_encoder, ln_vision
72 |
73 | def load_from_pretrained(self, url_or_filename):
74 | if is_url(url_or_filename):
75 | cached_file = download_cached_file(
76 | url_or_filename, check_hash=False, progress=True
77 | )
78 | checkpoint = torch.load(cached_file, map_location="cpu")
79 | elif os.path.isfile(url_or_filename):
80 | checkpoint = torch.load(url_or_filename, map_location="cpu")
81 | else:
82 | raise RuntimeError("checkpoint url or path is invalid")
83 |
84 | state_dict = checkpoint["model"]
85 |
86 | msg = self.load_state_dict(state_dict, strict=False)
87 |
88 | # logging.info("Missing keys {}".format(msg.missing_keys))
89 | logging.info("load checkpoint from %s" % url_or_filename)
90 |
91 | return msg
92 |
93 |
94 | def disabled_train(self, mode=True):
95 | """Overwrite model.train with this function to make sure train/eval mode
96 | does not change anymore."""
97 | return self
98 |
99 |
100 | class LayerNorm(nn.LayerNorm):
101 | """Subclass torch's LayerNorm to handle fp16."""
102 |
103 | def forward(self, x: torch.Tensor):
104 | orig_type = x.dtype
105 | ret = super().forward(x.type(torch.float32))
106 | return ret.type(orig_type)
107 |
108 |
109 | def compute_sim_matrix(model, data_loader, **kwargs):
110 | k_test = kwargs.pop("k_test")
111 |
112 | metric_logger = MetricLogger(delimiter=" ")
113 | header = "Evaluation:"
114 |
115 | logging.info("Computing features for evaluation...")
116 | start_time = time.time()
117 |
118 | texts = data_loader.dataset.text
119 | num_text = len(texts)
120 | text_bs = 256
121 | text_ids = []
122 | text_embeds = []
123 | text_atts = []
124 | for i in range(0, num_text, text_bs):
125 | text = texts[i : min(num_text, i + text_bs)]
126 | text_input = model.tokenizer(
127 | text,
128 | padding="max_length",
129 | truncation=True,
130 | max_length=35,
131 | return_tensors="pt",
132 | ).to(model.device)
133 | text_feat = model.forward_text(text_input)
134 | text_embed = F.normalize(model.text_proj(text_feat))
135 | text_embeds.append(text_embed)
136 | text_ids.append(text_input.input_ids)
137 | text_atts.append(text_input.attention_mask)
138 |
139 | text_embeds = torch.cat(text_embeds, dim=0)
140 | text_ids = torch.cat(text_ids, dim=0)
141 | text_atts = torch.cat(text_atts, dim=0)
142 |
143 | vit_feats = []
144 | image_embeds = []
145 | for samples in data_loader:
146 | image = samples["image"]
147 |
148 | image = image.to(model.device)
149 | image_feat, vit_feat = model.forward_image(image)
150 | image_embed = model.vision_proj(image_feat)
151 | image_embed = F.normalize(image_embed, dim=-1)
152 |
153 | vit_feats.append(vit_feat.cpu())
154 | image_embeds.append(image_embed)
155 |
156 | vit_feats = torch.cat(vit_feats, dim=0)
157 | image_embeds = torch.cat(image_embeds, dim=0)
158 |
159 | sims_matrix = []
160 | for image_embed in image_embeds:
161 | sim_q2t = image_embed @ text_embeds.t()
162 | sim_i2t, _ = sim_q2t.max(0)
163 | sims_matrix.append(sim_i2t)
164 | sims_matrix = torch.stack(sims_matrix, dim=0)
165 |
166 | score_matrix_i2t = torch.full(
167 | (len(data_loader.dataset.image), len(texts)), -100.0
168 | ).to(model.device)
169 |
170 | num_tasks = dist_utils.get_world_size()
171 | rank = dist_utils.get_rank()
172 | step = sims_matrix.size(0) // num_tasks + 1
173 | start = rank * step
174 | end = min(sims_matrix.size(0), start + step)
175 |
176 | for i, sims in enumerate(
177 | metric_logger.log_every(sims_matrix[start:end], 50, header)
178 | ):
179 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
180 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
181 | score = model.compute_itm(
182 | image_inputs=image_inputs,
183 | text_ids=text_ids[topk_idx],
184 | text_atts=text_atts[topk_idx],
185 | ).float()
186 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim
187 |
188 | sims_matrix = sims_matrix.t()
189 | score_matrix_t2i = torch.full(
190 | (len(texts), len(data_loader.dataset.image)), -100.0
191 | ).to(model.device)
192 |
193 | step = sims_matrix.size(0) // num_tasks + 1
194 | start = rank * step
195 | end = min(sims_matrix.size(0), start + step)
196 |
197 | for i, sims in enumerate(
198 | metric_logger.log_every(sims_matrix[start:end], 50, header)
199 | ):
200 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
201 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
202 | score = model.compute_itm(
203 | image_inputs=image_inputs,
204 | text_ids=text_ids[start + i].repeat(k_test, 1),
205 | text_atts=text_atts[start + i].repeat(k_test, 1),
206 | ).float()
207 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim
208 |
209 | if dist_utils.is_dist_avail_and_initialized():
210 | dist.barrier()
211 | torch.distributed.all_reduce(
212 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
213 | )
214 | torch.distributed.all_reduce(
215 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
216 | )
217 |
218 | total_time = time.time() - start_time
219 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
220 | logging.info("Evaluation time {}".format(total_time_str))
221 |
222 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
223 |
--------------------------------------------------------------------------------
/timechat/models/base_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from salesforce@LAVIS. Below is the original copyright:
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 | import logging
10 | import os
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from timechat.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
16 | from timechat.common.utils import get_abs_path, is_url
17 | from omegaconf import OmegaConf
18 |
19 |
20 | class BaseModel(nn.Module):
21 | """Base class for models."""
22 |
23 | def __init__(self):
24 | super().__init__()
25 |
26 | @property
27 | def device(self):
28 | return list(self.parameters())[0].device
29 |
30 | def load_checkpoint(self, url_or_filename):
31 | """
32 | Load from a finetuned checkpoint.
33 |
34 | This should expect no mismatch in the model keys and the checkpoint keys.
35 | """
36 |
37 | if is_url(url_or_filename):
38 | cached_file = download_cached_file(
39 | url_or_filename, check_hash=False, progress=True
40 | )
41 | checkpoint = torch.load(cached_file, map_location="cpu")
42 | elif os.path.isfile(url_or_filename):
43 | checkpoint = torch.load(url_or_filename, map_location="cpu")
44 | else:
45 | raise RuntimeError("checkpoint url or path is invalid")
46 |
47 | if "model" in checkpoint.keys():
48 | state_dict = checkpoint["model"]
49 | else:
50 | state_dict = checkpoint
51 |
52 | msg = self.load_state_dict(state_dict, strict=False)
53 |
54 | logging.info("Missing keys {}".format(msg.missing_keys))
55 | logging.info("load checkpoint from %s" % url_or_filename)
56 |
57 | return msg
58 |
59 | @classmethod
60 | def from_pretrained(cls, model_type):
61 | """
62 | Build a pretrained model from default configuration file, specified by model_type.
63 |
64 | Args:
65 | - model_type (str): model type, specifying architecture and checkpoints.
66 |
67 | Returns:
68 | - model (nn.Module): pretrained or finetuned model, depending on the configuration.
69 | """
70 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
71 | model = cls.from_config(model_cfg)
72 |
73 | return model
74 |
75 | @classmethod
76 | def default_config_path(cls, model_type):
77 | assert (
78 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
79 | ), "Unknown model type {}".format(model_type)
80 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
81 |
82 | def load_checkpoint_from_config(self, cfg, **kwargs):
83 | """
84 | Load checkpoint as specified in the config file.
85 |
86 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
87 | When loading the pretrained model, each task-specific architecture may define their
88 | own load_from_pretrained() method.
89 | """
90 | load_finetuned = cfg.get("load_finetuned", True)
91 | if load_finetuned:
92 | finetune_path = cfg.get("finetuned", None)
93 | assert (
94 | finetune_path is not None
95 | ), "Found load_finetuned is True, but finetune_path is None."
96 | self.load_checkpoint(url_or_filename=finetune_path)
97 | else:
98 | # load pre-trained weights
99 | pretrain_path = cfg.get("pretrained", None)
100 | assert "Found load_finetuned is False, but pretrain_path is None."
101 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
102 |
103 | def before_evaluation(self, **kwargs):
104 | pass
105 |
106 | def show_n_params(self, return_str=True):
107 | tot = 0
108 | for p in self.parameters():
109 | w = 1
110 | for x in p.shape:
111 | w *= x
112 | tot += w
113 | if return_str:
114 | if tot >= 1e6:
115 | return "{:.1f}M".format(tot / 1e6)
116 | else:
117 | return "{:.1f}K".format(tot / 1e3)
118 | else:
119 | return tot
120 |
121 |
122 | class BaseEncoder(nn.Module):
123 | """
124 | Base class for primitive encoders, such as ViT, TimeSformer, etc.
125 | """
126 |
127 | def __init__(self):
128 | super().__init__()
129 |
130 | def forward_features(self, samples, **kwargs):
131 | raise NotImplementedError
132 |
133 | @property
134 | def device(self):
135 | return list(self.parameters())[0].device
136 |
137 |
138 | class SharedQueueMixin:
139 | @torch.no_grad()
140 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
141 | # gather keys before updating queue
142 | image_feats = concat_all_gather(image_feat)
143 | text_feats = concat_all_gather(text_feat)
144 |
145 | batch_size = image_feats.shape[0]
146 |
147 | ptr = int(self.queue_ptr)
148 | assert self.queue_size % batch_size == 0 # for simplicity
149 |
150 | # replace the keys at ptr (dequeue and enqueue)
151 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
152 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
153 |
154 | if idxs is not None:
155 | idxs = concat_all_gather(idxs)
156 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
157 |
158 | ptr = (ptr + batch_size) % self.queue_size # move pointer
159 | self.queue_ptr[0] = ptr
160 |
161 |
162 | class MomentumDistilationMixin:
163 | @torch.no_grad()
164 | def copy_params(self):
165 | for model_pair in self.model_pairs:
166 | for param, param_m in zip(
167 | model_pair[0].parameters(), model_pair[1].parameters()
168 | ):
169 | param_m.data.copy_(param.data) # initialize
170 | param_m.requires_grad = False # not update by gradient
171 |
172 | @torch.no_grad()
173 | def _momentum_update(self):
174 | for model_pair in self.model_pairs:
175 | for param, param_m in zip(
176 | model_pair[0].parameters(), model_pair[1].parameters()
177 | ):
178 | param_m.data = param_m.data * self.momentum + param.data * (
179 | 1.0 - self.momentum
180 | )
181 |
182 |
183 | class GatherLayer(torch.autograd.Function):
184 | """
185 | Gather tensors from all workers with support for backward propagation:
186 | This implementation does not cut the gradients as torch.distributed.all_gather does.
187 | """
188 |
189 | @staticmethod
190 | def forward(ctx, x):
191 | output = [
192 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
193 | ]
194 | torch.distributed.all_gather(output, x)
195 | return tuple(output)
196 |
197 | @staticmethod
198 | def backward(ctx, *grads):
199 | all_gradients = torch.stack(grads)
200 | torch.distributed.all_reduce(all_gradients)
201 | return all_gradients[torch.distributed.get_rank()]
202 |
203 |
204 | def all_gather_with_grad(tensors):
205 | """
206 | Performs all_gather operation on the provided tensors.
207 | Graph remains connected for backward grad computation.
208 | """
209 | # Queue the gathered tensors
210 | world_size = torch.distributed.get_world_size()
211 | # There is no need for reduction in the single-proc case
212 | if world_size == 1:
213 | return tensors
214 |
215 | # tensor_all = GatherLayer.apply(tensors)
216 | tensor_all = GatherLayer.apply(tensors)
217 |
218 | return torch.cat(tensor_all, dim=0)
219 |
220 |
221 | @torch.no_grad()
222 | def concat_all_gather(tensor):
223 | """
224 | Performs all_gather operation on the provided tensors.
225 | *** Warning ***: torch.distributed.all_gather has no gradient.
226 | """
227 | # if not using distributed training, there is nothing to gather
228 | if not is_dist_avail_and_initialized():
229 | return tensor
230 |
231 | tensors_gather = [
232 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
233 | ]
234 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
235 |
236 | output = torch.cat(tensors_gather, dim=0)
237 | return output
238 |
239 |
240 | def tile(x, dim, n_tile):
241 | init_dim = x.size(dim)
242 | repeat_idx = [1] * x.dim()
243 | repeat_idx[dim] = n_tile
244 | x = x.repeat(*(repeat_idx))
245 | order_index = torch.LongTensor(
246 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
247 | )
248 | return torch.index_select(x, dim, order_index.to(x.device))
249 |
--------------------------------------------------------------------------------
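A worked example of the `tile` helper above, which repeats entries along a dimension while keeping copies of the same row grouped together:

import torch
from timechat.models.base_model import tile

x = torch.tensor([[1, 2], [3, 4]])
y = tile(x, dim=0, n_tile=3)
print(y.shape)  # torch.Size([6, 2])
print(y)        # rows: [1,2],[1,2],[1,2],[3,4],[3,4],[3,4]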
/timechat/datasets/builders/base_dataset_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 | import logging
10 | import os
11 | import shutil
12 | import warnings
13 |
14 | from omegaconf import OmegaConf
15 | import torch.distributed as dist
16 | from torchvision.datasets.utils import download_url
17 |
18 | import timechat.common.utils as utils
19 | from timechat.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20 | from timechat.common.registry import registry
21 | from timechat.processors.base_processor import BaseProcessor
22 |
23 |
24 |
25 | class BaseDatasetBuilder:
26 | train_dataset_cls, eval_dataset_cls = None, None
27 |
28 | def __init__(self, cfg=None):
29 | super().__init__()
30 |
31 | if cfg is None:
32 | # help to create datasets from default config.
33 | self.config = load_dataset_config(self.default_config_path())
34 | elif isinstance(cfg, str):
35 | self.config = load_dataset_config(cfg)
36 | else:
37 | # when called from task.build_dataset()
38 | self.config = cfg
39 |
40 | self.data_type = self.config.data_type
41 |
42 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
44 |
45 | def build_datasets(self):
46 | # download, split, etc...
47 | # only called on 1 GPU/TPU in distributed
48 |
49 | if is_main_process():
50 | self._download_data()
51 |
52 | if is_dist_avail_and_initialized():
53 | dist.barrier()
54 |
55 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
56 | logging.info("Building datasets...")
57 | datasets = self.build() # dataset['train'/'val'/'test']
58 |
59 | return datasets
60 |
61 | def build_processors(self):
62 | vis_proc_cfg = self.config.get("vis_processor")
63 | txt_proc_cfg = self.config.get("text_processor")
64 |
65 | if vis_proc_cfg is not None:
66 | vis_train_cfg = vis_proc_cfg.get("train")
67 | vis_eval_cfg = vis_proc_cfg.get("eval")
68 |
69 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
70 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
71 |
72 | if txt_proc_cfg is not None:
73 | txt_train_cfg = txt_proc_cfg.get("train")
74 | txt_eval_cfg = txt_proc_cfg.get("eval")
75 |
76 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
77 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
78 |
79 | @staticmethod
80 | def _build_proc_from_cfg(cfg):
81 | return (
82 | registry.get_processor_class(cfg.name).from_config(cfg)
83 | if cfg is not None
84 | else None
85 | )
86 |
87 | @classmethod
88 | def default_config_path(cls, type="default"):
89 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
90 |
91 | def _download_data(self):
92 | self._download_ann()
93 | self._download_vis()
94 |
95 | def _download_ann(self):
96 | """
97 | Download annotation files if necessary.
98 | All the vision-language datasets should have annotations of unified format.
99 |
100 | storage_path can be:
101 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
102 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
103 |
104 | Local annotation paths should be relative.
105 | """
106 | anns = self.config.build_info.annotations
107 |
108 | splits = anns.keys()
109 |
110 | cache_root = registry.get_path("cache_root")
111 |
112 | for split in splits:
113 | info = anns[split]
114 |
115 | urls, storage_paths = info.get("url", None), info.storage
116 |
117 | if isinstance(urls, str):
118 | urls = [urls]
119 | if isinstance(storage_paths, str):
120 | storage_paths = [storage_paths]
121 |
122 | assert len(urls) == len(storage_paths)
123 |
124 | for url_or_filename, storage_path in zip(urls, storage_paths):
125 | # if storage_path is relative, make it full by prefixing with cache_root.
126 | if not os.path.isabs(storage_path):
127 | storage_path = os.path.join(cache_root, storage_path)
128 |
129 | dirname = os.path.dirname(storage_path)
130 | if not os.path.exists(dirname):
131 | os.makedirs(dirname)
132 |
133 | if os.path.isfile(url_or_filename):
134 | src, dst = url_or_filename, storage_path
135 | if not os.path.exists(dst):
136 | shutil.copyfile(src=src, dst=dst)
137 | else:
138 | logging.info("Using existing file {}.".format(dst))
139 | else:
140 | if os.path.isdir(storage_path):
141 | # if only dirname is provided, suffix with basename of URL.
142 | raise ValueError(
143 | "Expecting storage_path to be a file path, got directory {}".format(
144 | storage_path
145 | )
146 | )
147 | else:
148 | filename = os.path.basename(storage_path)
149 |
150 | download_url(url=url_or_filename, root=dirname, filename=filename)
151 |
152 | def _download_vis(self):
153 |
154 | storage_path = self.config.build_info.get(self.data_type).storage
155 | storage_path = utils.get_cache_path(storage_path)
156 |
157 | if not os.path.exists(storage_path):
158 | warnings.warn(
159 | f"""
160 | The specified path {storage_path} for visual inputs does not exist.
161 | Please provide a correct path to the visual inputs or
162 | refer to datasets/download_scripts/README.md for downloading instructions.
163 | """
164 | )
165 |
166 | def build(self):
167 | """
168 | Create datasets by split, each inheriting from torch.utils.data.Dataset.
169 |
170 | # build() can be dataset-specific. Overwrite to customize.
171 | """
172 | self.build_processors()
173 |
174 | build_info = self.config.build_info
175 |
176 | ann_info = build_info.annotations
177 | vis_info = build_info.get(self.data_type)
178 |
179 | datasets = dict()
180 | for split in ann_info.keys():
181 | if split not in ["train", "val", "test"]:
182 | continue
183 |
184 | is_train = split == "train"
185 |
186 | # processors
187 | vis_processor = (
188 | self.vis_processors["train"]
189 | if is_train
190 | else self.vis_processors["eval"]
191 | )
192 | text_processor = (
193 | self.text_processors["train"]
194 | if is_train
195 | else self.text_processors["eval"]
196 | )
197 |
198 | # annotation path
199 | ann_paths = ann_info.get(split).storage
200 | if isinstance(ann_paths, str):
201 | ann_paths = [ann_paths]
202 |
203 | abs_ann_paths = []
204 | for ann_path in ann_paths:
205 | if not os.path.isabs(ann_path):
206 | ann_path = utils.get_cache_path(ann_path)
207 | abs_ann_paths.append(ann_path)
208 | ann_paths = abs_ann_paths
209 |
210 | # visual data storage path
211 | vis_path = os.path.join(vis_info.storage, split)
212 |
213 | if not os.path.isabs(vis_path):
214 | # vis_path = os.path.join(utils.get_cache_path(), vis_path)
215 | vis_path = utils.get_cache_path(vis_path)
216 |
217 | if not os.path.exists(vis_path):
218 | warnings.warn("storage path {} does not exist.".format(vis_path))
219 |
220 | # create datasets
221 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
222 | datasets[split] = dataset_cls(
223 | vis_processor=vis_processor,
224 | text_processor=text_processor,
225 | ann_paths=ann_paths,
226 | vis_root=vis_path,
227 | )
228 |
229 | return datasets
230 |
231 |
232 | def load_dataset_config(cfg_path):
233 | cfg = OmegaConf.load(cfg_path).datasets
234 | cfg = cfg[list(cfg.keys())[0]]
235 |
236 | return cfg
237 |
--------------------------------------------------------------------------------
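A minimal sketch of how a concrete builder might subclass `BaseDatasetBuilder` above. The builder name, dataset class, and YAML path are illustrative assumptions, and it presumes the registry exposes a `register_builder` decorator analogous to the `register_processor` one used elsewhere in this repo:

from timechat.common.registry import registry
from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from timechat.datasets.datasets.base_dataset import BaseDataset  # assumed dataset class name

@registry.register_builder("my_video_instruct")   # hypothetical builder name
class MyVideoInstructBuilder(BaseDatasetBuilder):
    train_dataset_cls = BaseDataset
    eval_dataset_cls = BaseDataset
    DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/my_instruct.yaml"}

# A task can then build it by name, mirroring BaseTask.build_datasets():
#   builder = registry.get_builder_class("my_video_instruct")(dataset_config)
#   datasets = builder.build_datasets()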
/timechat/processors/video_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import torch
9 | from timechat.common.registry import registry
10 | from decord import VideoReader
11 | import decord
12 | import numpy as np
13 | from timechat.processors import transforms_video
14 | from timechat.processors.base_processor import BaseProcessor
15 | from timechat.processors.randaugment import VideoRandomAugment
16 | from timechat.processors import functional_video as F
17 | from omegaconf import OmegaConf
18 | from torchvision import transforms
19 | import random as rnd
20 |
21 |
22 | MAX_INT = registry.get("MAX_INT")
23 | decord.bridge.set_bridge("torch")
24 |
25 |
26 | def interpolate_frame_pos_embed(frame_pos_embed_ckpt, new_n_frm=96):
27 | # interpolate frame position embedding
28 | # frame_pos_embed_ckpt: (old_n_frm, dim)
29 | frame_pos_embed_ckpt = frame_pos_embed_ckpt.unsqueeze(0).transpose(1, 2) # (1, dim, old_n_frm)
30 | new_frame_pos_embed_ckpt = torch.nn.functional.interpolate(frame_pos_embed_ckpt, size=(new_n_frm),
31 | mode='nearest') # (1, dim, new_n_frm)
32 | new_frame_pos_embed_ckpt = new_frame_pos_embed_ckpt.transpose(1, 2).squeeze(0) # (new_n_frm, dim)
33 | return new_frame_pos_embed_ckpt
34 |
35 |
36 | def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform", return_msg=False):
37 | decord.bridge.set_bridge("torch")
38 | vr = VideoReader(uri=video_path, height=height, width=width)
39 |
40 | vlen = len(vr)
41 | start, end = 0, vlen
42 | acc_samples = min(n_frms, vlen)
43 | n_frms = min(n_frms, vlen)
44 |
45 | if sampling == "uniform":
46 | indices = np.arange(start, end, vlen / n_frms).astype(int).tolist()
47 | elif sampling == "headtail":
48 | indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))
49 | indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2))
50 | indices = indices_h + indices_t
51 | elif sampling == 'rand':
52 | # split the video into `acc_samples` intervals, and sample from each interval.
53 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
54 | ranges = []
55 | for idx, interv in enumerate(intervals[:-1]):
56 | ranges.append((interv, intervals[idx + 1] - 1))
57 | try:
58 | indices = [rnd.choice(range(x[0], x[1])) for x in ranges]
59 | except:
60 | indices = np.random.permutation(vlen)[:acc_samples]
61 | indices.sort()
62 | indices = list(indices)
63 | else:
64 | raise NotImplementedError
65 |
66 | # get_batch -> T, H, W, C
67 | temp_frms = vr.get_batch(indices)
68 | # print(type(temp_frms))
69 | tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
70 | frms = tensor_frms.permute(3, 0, 1, 2).float() # (C, T, H, W)
71 |
72 | if not return_msg:
73 | return frms
74 |
75 | fps = float(vr.get_avg_fps())
76 | sec = ", ".join([str(round(f / fps, 1)) for f in indices])
77 | # " " should be added in the start and end
78 | msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. "
79 | return frms, msg
80 |
81 |
82 | class AlproVideoBaseProcessor(BaseProcessor):
83 | def __init__(self, mean=None, std=None, n_frms=MAX_INT):
84 | if mean is None:
85 | mean = (0.48145466, 0.4578275, 0.40821073)
86 | if std is None:
87 | std = (0.26862954, 0.26130258, 0.27577711)
88 |
89 | self.normalize = transforms_video.NormalizeVideo(mean, std)
90 |
91 | self.n_frms = n_frms
92 |
93 |
94 | class ToUint8(object):
95 | def __init__(self):
96 | pass
97 |
98 | def __call__(self, tensor):
99 | return tensor.to(torch.uint8)
100 |
101 | def __repr__(self):
102 | return self.__class__.__name__
103 |
104 |
105 | class ToTHWC(object):
106 | """
107 | Args:
108 | clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W)
109 | Return:
110 | clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C)
111 | """
112 |
113 | def __init__(self):
114 | pass
115 |
116 | def __call__(self, tensor):
117 | return tensor.permute(1, 2, 3, 0)
118 |
119 | def __repr__(self):
120 | return self.__class__.__name__
121 |
122 |
123 | class ResizeVideo(object):
124 | def __init__(self, target_size, interpolation_mode="bilinear"):
125 | self.target_size = target_size
126 | self.interpolation_mode = interpolation_mode
127 |
128 | def __call__(self, clip):
129 | """
130 | Args:
131 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
132 | Returns:
133 | torch.tensor: central cropping of video clip. Size is
134 | (C, T, crop_size, crop_size)
135 | """
136 | return F.resize(clip, self.target_size, self.interpolation_mode)
137 |
138 | def __repr__(self):
139 | return self.__class__.__name__ + "(resize_size={0})".format(self.target_size)
140 |
141 |
142 | @registry.register_processor("alpro_video_train")
143 | class AlproVideoTrainProcessor(AlproVideoBaseProcessor):
144 | def __init__(
145 | self,
146 | image_size=384,
147 | mean=None,
148 | std=None,
149 | min_scale=0.5,
150 | max_scale=1.0,
151 | n_frms=MAX_INT,
152 | ):
153 | super().__init__(mean=mean, std=std, n_frms=n_frms)
154 |
155 | self.image_size = image_size
156 |
157 | self.transform = transforms.Compose(
158 | [
159 | # Video size is (C, T, H, W)
160 | transforms_video.RandomResizedCropVideo(
161 | image_size,
162 | scale=(min_scale, max_scale),
163 | interpolation_mode="bicubic",
164 | ),
165 | ToTHWC(), # C, T, H, W -> T, H, W, C
166 | ToUint8(),
167 | transforms_video.ToTensorVideo(), # T, H, W, C -> C, T, H, W
168 | self.normalize,
169 | ]
170 | )
171 |
172 | def __call__(self, vpath):
173 | """
174 | Args:
175 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
176 | Returns:
177 | torch.tensor: video clip after transforms. Size is (C, T, size, size).
178 | """
179 | clip = load_video(
180 | video_path=vpath,
181 | n_frms=self.n_frms,
182 | height=self.image_size,
183 | width=self.image_size,
184 | sampling="headtail",
185 | )
186 |
187 | return self.transform(clip)
188 |
189 | @classmethod
190 | def from_config(cls, cfg=None):
191 | if cfg is None:
192 | cfg = OmegaConf.create()
193 |
194 | image_size = cfg.get("image_size", 256)
195 |
196 | mean = cfg.get("mean", None)
197 | std = cfg.get("std", None)
198 |
199 | min_scale = cfg.get("min_scale", 0.5)
200 | max_scale = cfg.get("max_scale", 1.0)
201 |
202 | n_frms = cfg.get("n_frms", MAX_INT)
203 |
204 | return cls(
205 | image_size=image_size,
206 | mean=mean,
207 | std=std,
208 | min_scale=min_scale,
209 | max_scale=max_scale,
210 | n_frms=n_frms,
211 | )
212 |
213 |
214 | @registry.register_processor("alpro_video_eval")
215 | class AlproVideoEvalProcessor(AlproVideoBaseProcessor):
216 | def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT):
217 | super().__init__(mean=mean, std=std, n_frms=n_frms)
218 |
219 | self.image_size = image_size
220 |
221 | # Input video size is (C, T, H, W)
222 | self.transform = transforms.Compose(
223 | [
224 | # frames will be resized during decord loading.
225 | ToUint8(), # C, T, H, W
226 | ToTHWC(), # T, H, W, C
227 | transforms_video.ToTensorVideo(), # C, T, H, W
228 | self.normalize, # C, T, H, W
229 | ]
230 | )
231 |
232 | def __call__(self, vpath):
233 | """
234 | Args:
235 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
236 | Returns:
237 | torch.tensor: video clip after transforms. Size is (C, T, size, size).
238 | """
239 | clip = load_video(
240 | video_path=vpath,
241 | n_frms=self.n_frms,
242 | height=self.image_size,
243 | width=self.image_size,
244 | )
245 |
246 | return self.transform(clip)
247 |
248 | @classmethod
249 | def from_config(cls, cfg=None):
250 | if cfg is None:
251 | cfg = OmegaConf.create()
252 |
253 | image_size = cfg.get("image_size", 256)
254 |
255 | mean = cfg.get("mean", None)
256 | std = cfg.get("std", None)
257 |
258 | n_frms = cfg.get("n_frms", MAX_INT)
259 |
260 | return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)
261 |
--------------------------------------------------------------------------------
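A small usage sketch for the video processors above; the config values are arbitrary and "/path/to/video.mp4" is a placeholder:

from omegaconf import OmegaConf
from timechat.processors.video_processor import AlproVideoEvalProcessor

proc_cfg = OmegaConf.create({"image_size": 224, "n_frms": 8})
processor = AlproVideoEvalProcessor.from_config(proc_cfg)

# Frames are decoded with decord, uniformly sampled, and normalized.
clip = processor("/path/to/video.mp4")
print(clip.shape)  # (C, T, H, W); T equals n_frms up to rounding in the uniform sampler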
/utils/cons_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import re
4 | import numpy as np
5 | import argparse
6 | import pandas as pd
7 | import random
8 | import torch
9 | import os
10 | import logging
11 |
12 |
13 | # ANSI escape codes for colors
14 | class Formatter(logging.Formatter):
15 | COLOR_CODES = {
16 | 'DEBUG': '\033[94m', # Blue
17 | 'INFO': '\033[92m', # Green
18 | 'WARNING': '\033[91m', # Red
19 | 'ERROR': '\033[93m', # Yellow
20 | 'CRITICAL': '\033[95m', # Magenta
21 | }
22 | RESET_CODE = '\033[0m' # Reset color
23 |
24 | def format(self, record):
25 | log_color = self.COLOR_CODES.get(record.levelname, self.RESET_CODE)
26 | message = super().format(record)
27 | return f"{log_color}{message}{self.RESET_CODE}"
28 |
29 |
30 | def load_logger(name):
31 | # Custom logger setup
32 | logger = logging.getLogger(name)
33 | logger.setLevel(logging.DEBUG)
34 |
35 | # Console handler with color formatter
36 | console_handler = logging.StreamHandler()
37 | console_handler.setLevel(logging.DEBUG)
38 |
39 | # Define formatter with color
40 | formatter = Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
41 | console_handler.setFormatter(formatter)
42 |
43 | # Add handler to logger
44 | logger.addHandler(console_handler)
45 | return logger
46 |
47 |
48 | class BaseOptions(object):
49 | def __init__(self):
50 | self.parser = None
51 | self.initialized = False
52 | self.opt = None
53 |
54 | def initialize(self):
55 | self.initialized = True
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--model_type", type=str, default='TimeChat',
58 | choices=['Video-ChatGPT', 'Video-LLaMA', 'Video-LLaMA2', 'Video-LLaVA', 'VTimeLLM', 'TimeChat', 'VTG-LLM', 'VideoChat2', 'GPT4', 'Gemini'],
59 | help="A list of Video-LLMs.")
60 | parser.add_argument("--dset_name", type=str, default="charades", choices=['activitynet', 'charades'],
61 | help="Dataset name.")
62 | parser.add_argument("--task", type=str, default="consistency", choices=['grounding', 'consistency'],
63 | help="Type of task.")
64 | parser.add_argument("--grounding_prompt", type=int, default=None)
65 | parser.add_argument('--description', action="store_true",
66 | help="Prompt the model to generate a video description before performing target tasks.")
67 | parser.add_argument('--CoT', action="store_true", help="Utilizes Chain-of-Thought Reasoning.")
68 | parser.add_argument('--fine_tuned', action="store_true")
69 | parser.add_argument('--iou_thd', type=float, default=0.5)
70 | parser.add_argument('--no_skip', action="store_true",
71 | help="Test the probes even if the initial prediction is in accurate.")
72 | parser.add_argument('--overwrite', action="store_true")
73 | parser.add_argument("--video_root", type=str, default="/data/video_datasets/", help="path to video files")
74 | parser.add_argument("--output_dir", type=str, default=None, help="path to output files")
75 | parser.add_argument("--exp_id", type=str, default=None, help="ID of this run.")
76 | parser.add_argument("--seed", type=int, default=1000)
77 | parser.add_argument('--debug', action="store_true", help="Debug mode.")
78 | self.parser = parser
79 |
80 | def parse(self):
81 | if not self.initialized:
82 | self.initialize()
83 |
84 | opt = self.parser.parse_args()
85 |
86 | if not opt.video_root:
87 | opt.video_root = f"/data/video_datasets/{opt.dset_name}"
88 | else:
89 | opt.video_root = os.path.join(opt.video_root, opt.dset_name)
90 |
91 | opt.test_path = f"data/{opt.dset_name}_consistency_test.json"
92 |
93 | self.opt = opt
94 | return opt
95 |
96 |
97 | def generate_question(task, prompt, query, duration, st=None, ed=None):
98 | choice = random.choice(["pos", "neg"])
99 | if st and ed:
100 | st, ed = min(st, duration), min(ed, duration)
101 |
102 | add_detail = prompt["add_detail"]
103 | if task in ["grounding"]:
104 | question = prompt[task].format(event=query)
105 | add_detail = None
106 |
107 | elif task in ["description"]:
108 | question = prompt[task]
109 | add_detail = None
110 |
111 | elif task in ["occurrence"]:
112 | question = random.choice(prompt[choice]).format(event=query, st=st, ed=ed)
113 |
114 | elif task in ["compositional"]:
115 | query = query.replace("?", "")
116 | question = prompt[task].format(question=query, st=st, ed=ed)
117 |
118 | else:
119 | raise NotImplementedError(f"Not implemented task: {task}")
120 |
121 | return question, add_detail, choice
122 |
123 |
124 | def load_jsonl(filename):
125 | with open(filename, "r") as f:
126 | return [json.loads(l.replace("'","").strip("\n")) for l in f.readlines()]
127 |
128 |
129 | def save_jsonl(data, filename):
130 | """data is a list"""
131 | with open(filename, "w") as f:
132 | f.write("\n".join([json.dumps(e) for e in data]))
133 |
134 |
135 | def save_json(data, filename, save_pretty=False, sort_keys=False):
136 | with open(filename, "w") as f:
137 | if save_pretty:
138 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
139 | else:
140 | json.dump(data, f)
141 |
142 |
143 | def load_json(filename):
144 | with open(filename, "r") as f:
145 | return json.load(f)
146 |
147 |
148 | def get_iou(A, B):
149 | try:
150 | max0 = max((A[0]), (B[0]))
151 | min0 = min((A[0]), (B[0]))
152 | max1 = max((A[1]), (B[1]))
153 | min1 = min((A[1]), (B[1]))
154 |
155 | return round(max(min1 - max0, 0) / (max1 - min0), 2)
156 |
157 | except:
158 | return 0
159 |
160 |
161 | def shifting_video_moment(video_features, org_timestamp, new_timestamp, duration):
162 | """
163 | Shifts frames of a video between the original timestamp and new timestamp.
164 |
165 | video_features: The input video features (either list or torch.Tensor)
166 | org_timestamp: The original start and end time (in seconds) for the part of the video to be moved.
167 | new_timestamp: The new start and end time (in seconds) where the original frames should be shifted.
168 | duration: Total duration of the video (in seconds).
169 |
170 | In Video-LLaVA and TimeChat, video_features is a list containing tensor features.
171 | """
172 | if isinstance(video_features, list):
173 | # Handle list of frames
174 | n_frames = len(video_features)
175 | if not isinstance(video_features[0], torch.Tensor):
176 | _img_embeds = copy.deepcopy(video_features)
177 | org_frame = second_to_frame(n_frames, org_timestamp, duration)
178 | new_frame = second_to_frame(n_frames, new_timestamp, duration)
179 |
180 | # Perform the shift
181 | _img_embeds[org_frame[0]: org_frame[1] + 1] = video_features[new_frame[0]: new_frame[1] + 1]
182 | _img_embeds[new_frame[0]: new_frame[1] + 1] = video_features[org_frame[0]: org_frame[1] + 1]
183 | # print("to", video_features[0].shape, len(video_features[0]))
184 | return _img_embeds
185 | else:
186 | img_embeds = video_features[0]
187 | _img_embeds = img_embeds.clone()
188 |
189 | org_frame = second_to_frame(n_frames, org_timestamp, duration)
190 | new_frame = second_to_frame(n_frames, new_timestamp, duration)
191 |
192 | # Perform the shift
193 | _img_embeds[org_frame[0]: org_frame[1]+1] = img_embeds[new_frame[0]: new_frame[1]+1]
194 | _img_embeds[new_frame[0]: new_frame[1]+1] = img_embeds[org_frame[0]: org_frame[1]+1]
195 | # print("to", video_features[0].shape, len(video_features[0]))
196 | return [_img_embeds]
197 |
198 | elif isinstance(video_features, torch.Tensor):
199 | n_frames = video_features.shape[0]
200 | org_frame = second_to_frame(n_frames, org_timestamp, duration)
201 | new_frame = second_to_frame(n_frames, new_timestamp, duration)
202 |
203 | # Calculate the number of frames in each range
204 | org_length = org_frame[1] - org_frame[0]
205 | new_length = new_frame[1] - new_frame[0]
206 |
207 | # Find the minimum length to avoid shape mismatch
208 | min_length = min(org_length, new_length)
209 |
210 | # Extract the original and new frame segments with adjusted lengths
211 | org_frame_feat = video_features[org_frame[0]: org_frame[0] + min_length, :]
212 | new_frame_feat = video_features[new_frame[0]: new_frame[0] + min_length, :]
213 |
214 | # Clone the tensor to avoid in-place overwriting
215 | shifted_video_features = video_features.clone()
216 |
217 | # Perform the swap, making sure both ranges are equal in length
218 | shifted_video_features[org_frame[0]: org_frame[0] + min_length, :] = new_frame_feat
219 | shifted_video_features[new_frame[0]: new_frame[0] + min_length, :] = org_frame_feat
220 |
221 | return shifted_video_features
222 |
223 |
224 | def dict_to_markdown(d, max_str_len=120):
225 | # convert list into its str representation
226 | d = {k: v.__repr__() if isinstance(v, list) else v for k, v in d.items()}
227 | # truncate string that is longer than max_str_len
228 | if max_str_len is not None:
229 | d = {k: v[-max_str_len:] if isinstance(v, str) else v for k, v in d.items()}
230 | return pd.DataFrame(d, index=[0]).transpose().to_markdown()
231 |
232 |
233 | def display(args):
234 | opt = args if isinstance(args, dict) else vars(args)
235 | # opt = vars(args)
236 | print(dict_to_markdown(opt, max_str_len=120))
237 |
238 |
239 | def second_to_frame(n_frame, seconds, duration):
240 | return [int(seconds[0] / duration * n_frame), int(seconds[1] / duration * n_frame)]
--------------------------------------------------------------------------------
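Two quick worked examples for the helpers above: `get_iou` on a pair of (start, end) spans in seconds, and `second_to_frame` mapping a span onto feature-frame indices:

from utils.cons_utils import get_iou, second_to_frame

pred_span, gt_span = (10.0, 25.0), (15.0, 30.0)
print(get_iou(pred_span, gt_span))               # intersection 10 / union 20 -> 0.5

# A 15-30s moment in a 60s video, with 96 frame features:
print(second_to_frame(96, (15.0, 30.0), 60.0))   # [24, 48]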
/timechat/tasks/base_task.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import os
10 | import wandb
11 | import torch
12 | import torch.distributed as dist
13 | from timechat.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
14 | from timechat.common.logger import MetricLogger, SmoothedValue
15 | from timechat.common.registry import registry
16 | from timechat.datasets.data_utils import prepare_sample
17 |
18 |
19 | class BaseTask:
20 | def __init__(self, **kwargs):
21 | super().__init__()
22 |
23 | self.inst_id_key = "instance_id"
24 |
25 | @classmethod
26 | def setup_task(cls, **kwargs):
27 | return cls()
28 |
29 | def build_model(self, cfg):
30 | model_config = cfg.model_cfg
31 |
32 | model_cls = registry.get_model_class(model_config.arch)
33 | return model_cls.from_config(model_config)
34 |
35 | def build_datasets(self, cfg):
36 | """
37 | Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
38 | Download dataset and annotations automatically if not exist.
39 |
40 | Args:
41 | cfg (common.config.Config): _description_
42 |
43 | Returns:
44 | dict: Dictionary of torch.utils.data.Dataset objects by split.
45 | """
46 |
47 | datasets = dict()
48 |
49 | datasets_config = cfg.datasets_cfg
50 |
51 | assert len(datasets_config) > 0, "At least one dataset has to be specified."
52 |
53 | for name in datasets_config:
54 | dataset_config = datasets_config[name]
55 |
56 | builder = registry.get_builder_class(name)(dataset_config)
57 | dataset = builder.build_datasets()
58 |
59 | dataset['train'].name = name
60 | if 'sample_ratio' in dataset_config:
61 | dataset['train'].sample_ratio = dataset_config.sample_ratio
62 |
63 | datasets[name] = dataset
64 |
65 | return datasets
66 |
67 | def train_step(self, model, samples):
68 | loss = model(samples)["loss"]
69 | return loss
70 |
71 | def valid_step(self, model, samples):
72 | raise NotImplementedError
73 |
74 | def before_evaluation(self, model, dataset, **kwargs):
75 | model.before_evaluation(dataset=dataset, task_type=type(self))
76 |
77 | def after_evaluation(self, **kwargs):
78 | pass
79 |
80 | def inference_step(self):
81 | raise NotImplementedError
82 |
83 | def evaluation(self, model, data_loader, cuda_enabled=True):
84 | metric_logger = MetricLogger(delimiter=" ")
85 | header = "Evaluation"
86 | # TODO make it configurable
87 | print_freq = 10
88 |
89 | results = []
90 |
91 | for samples in metric_logger.log_every(data_loader, print_freq, header):
92 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
93 |
94 | eval_output = self.valid_step(model=model, samples=samples)
95 | results.extend(eval_output)
96 |
97 | if is_dist_avail_and_initialized():
98 | dist.barrier()
99 |
100 | return results
101 |
102 | def train_epoch(
103 | self,
104 | epoch,
105 | model,
106 | data_loader,
107 | optimizer,
108 | lr_scheduler,
109 | scaler=None,
110 | cuda_enabled=False,
111 | log_freq=50,
112 | accum_grad_iters=1,
113 | ):
114 | return self._train_inner_loop(
115 | epoch=epoch,
116 | iters_per_epoch=lr_scheduler.iters_per_epoch,
117 | model=model,
118 | data_loader=data_loader,
119 | optimizer=optimizer,
120 | scaler=scaler,
121 | lr_scheduler=lr_scheduler,
122 | log_freq=log_freq,
123 | cuda_enabled=cuda_enabled,
124 | accum_grad_iters=accum_grad_iters,
125 | )
126 |
127 | def train_iters(
128 | self,
129 | epoch,
130 | start_iters,
131 | iters_per_inner_epoch,
132 | model,
133 | data_loader,
134 | optimizer,
135 | lr_scheduler,
136 | scaler=None,
137 | cuda_enabled=False,
138 | log_freq=50,
139 | accum_grad_iters=1,
140 | ):
141 | return self._train_inner_loop(
142 | epoch=epoch,
143 | start_iters=start_iters,
144 | iters_per_epoch=iters_per_inner_epoch,
145 | model=model,
146 | data_loader=data_loader,
147 | optimizer=optimizer,
148 | scaler=scaler,
149 | lr_scheduler=lr_scheduler,
150 | log_freq=log_freq,
151 | cuda_enabled=cuda_enabled,
152 | accum_grad_iters=accum_grad_iters,
153 | )
154 |
155 | def _train_inner_loop(
156 | self,
157 | epoch,
158 | iters_per_epoch,
159 | model,
160 | data_loader,
161 | optimizer,
162 | lr_scheduler,
163 | scaler=None,
164 | start_iters=None,
165 | log_freq=50,
166 | cuda_enabled=False,
167 | accum_grad_iters=1,
168 | ):
169 | """
170 | An inner training loop compatible with both epoch-based and iter-based training.
171 |
172 | When using epoch-based, training stops after one epoch; when using iter-based,
173 | training stops after #iters_per_epoch iterations.
174 | """
175 | use_amp = scaler is not None
176 |
177 | if not hasattr(data_loader, "__next__"):
178 | # convert to iterator if not already
179 | data_loader = iter(data_loader)
180 |
181 | metric_logger = MetricLogger(delimiter=" ")
182 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
183 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
184 |
185 | # if iter-based runner, schedule lr based on inner epoch.
186 | logging.info(
187 | "Start training epoch {}, {} iters per inner epoch.".format(
188 | epoch, iters_per_epoch
189 | )
190 | )
191 | header = "Train: data epoch: [{}]".format(epoch)
192 | if start_iters is None:
193 | # epoch-based runner
194 | inner_epoch = epoch
195 | else:
196 | # In iter-based runner, we schedule the learning rate based on iterations.
197 | inner_epoch = start_iters // iters_per_epoch
198 | header = header + "; inner epoch [{}]".format(inner_epoch)
199 |
200 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
201 | # if using iter-based runner, we stop after iters_per_epoch iterations.
202 | if i >= iters_per_epoch:
203 | break
204 |
205 | samples = next(data_loader)
206 |
207 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
208 | samples.update(
209 | {
210 | "epoch": inner_epoch,
211 | "num_iters_per_epoch": iters_per_epoch,
212 | "iters": i,
213 | }
214 | )
215 |
216 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
217 |
218 | with torch.cuda.amp.autocast(enabled=use_amp):
219 | loss = self.train_step(model=model, samples=samples)
220 |
221 | # after_train_step()
222 | if use_amp:
223 | scaler.scale(loss).backward()
224 | else:
225 | loss.backward()
226 |
227 | # update gradients every accum_grad_iters iterations
228 | if (i + 1) % accum_grad_iters == 0:
229 | if use_amp:
230 | scaler.step(optimizer)
231 | scaler.update()
232 | else:
233 | optimizer.step()
234 | optimizer.zero_grad()
235 |
236 | metric_logger.update(loss=loss.item())
237 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
238 |
239 | if is_main_process() and wandb.run is not None:
240 | wandb.log({'train/loss': loss.item(),
241 | 'train/lr': optimizer.param_groups[0]["lr"]},
242 | step=epoch * iters_per_epoch + i)
243 |
244 | # after train_epoch()
245 | # gather the stats from all processes
246 | metric_logger.synchronize_between_processes()
247 | logging.info("Averaged stats: " + str(metric_logger.global_avg()))
248 | return {
249 | k: "{:.3f}".format(meter.global_avg)
250 | for k, meter in metric_logger.meters.items()
251 | }
252 |
253 | @staticmethod
254 | def save_result(result, result_dir, filename, remove_duplicate=""):
255 | import json
256 |
257 | result_file = os.path.join(
258 | result_dir, "%s_rank%d.json" % (filename, get_rank())
259 | )
260 | final_result_file = os.path.join(result_dir, "%s.json" % filename)
261 |
262 | json.dump(result, open(result_file, "w"))
263 |
264 | if is_dist_avail_and_initialized():
265 | dist.barrier()
266 |
267 | if is_main_process():
268 | logging.warning("rank %d starts merging results." % get_rank())
269 | # combine results from all processes
270 | result = []
271 |
272 | for rank in range(get_world_size()):
273 | result_file = os.path.join(
274 | result_dir, "%s_rank%d.json" % (filename, rank)
275 | )
276 | res = json.load(open(result_file, "r"))
277 | result += res
278 |
279 | if remove_duplicate:
280 | result_new = []
281 | id_list = []
282 | for res in result:
283 | if res[remove_duplicate] not in id_list:
284 | id_list.append(res[remove_duplicate])
285 | result_new.append(res)
286 | result = result_new
287 |
288 | json.dump(result, open(final_result_file, "w"))
289 | print("result file saved to %s" % final_result_file)
290 |
291 | return final_result_file
292 |
--------------------------------------------------------------------------------
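The core of `_train_inner_loop` above is the AMP + gradient-accumulation update: the (scaled) loss is back-propagated every iteration, but the optimizer only steps and zeroes its gradients every `accum_grad_iters` iterations. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the toy linear model and random batches are illustrative stand-ins, not part of this repository, and, like the original loop, the loss is not averaged over the accumulation window.

# Minimal sketch of the AMP + gradient-accumulation pattern from BaseTask._train_inner_loop.
# The toy model and synthetic batches are illustrative only.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
accum_grad_iters = 4
iters_per_epoch = 8

for i in range(iters_per_epoch):
    x = torch.randn(2, 16, device=device)
    y = torch.randn(2, 1, device=device)

    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y)

    # backward on every iteration (scaled when AMP is on) ...
    if use_amp:
        scaler.scale(loss).backward()
    else:
        loss.backward()

    # ... but only step / zero the optimizer every accum_grad_iters iterations
    if (i + 1) % accum_grad_iters == 0:
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()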
/timechat/common/registry.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 |
9 | class Registry:
10 | mapping = {
11 | "builder_name_mapping": {},
12 | "task_name_mapping": {},
13 | "processor_name_mapping": {},
14 | "model_name_mapping": {},
15 | "lr_scheduler_name_mapping": {},
16 | "runner_name_mapping": {},
17 | "state": {},
18 | "paths": {},
19 | }
20 |
21 | @classmethod
22 | def register_builder(cls, name):
23 | r"""Register a dataset builder to registry with key 'name'
24 |
25 | Args:
26 | name: Key with which the builder will be registered.
27 |
28 | Usage:
29 |
30 | from timechat.common.registry import registry
31 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder
32 | """
33 |
34 | def wrap(builder_cls):
35 | from timechat.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36 |
37 | assert issubclass(
38 | builder_cls, BaseDatasetBuilder
39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40 | builder_cls
41 | )
42 | if name in cls.mapping["builder_name_mapping"]:
43 | raise KeyError(
44 | "Name '{}' already registered for {}.".format(
45 | name, cls.mapping["builder_name_mapping"][name]
46 | )
47 | )
48 | cls.mapping["builder_name_mapping"][name] = builder_cls
49 | return builder_cls
50 |
51 | return wrap
52 |
53 | @classmethod
54 | def register_task(cls, name):
55 | r"""Register a task to registry with key 'name'
56 |
57 | Args:
58 | name: Key with which the task will be registered.
59 |
60 | Usage:
61 |
62 | from timechat.common.registry import registry
63 | """
64 |
65 | def wrap(task_cls):
66 | from timechat.tasks.base_task import BaseTask
67 |
68 | assert issubclass(
69 | task_cls, BaseTask
70 | ), "All tasks must inherit BaseTask class"
71 | if name in cls.mapping["task_name_mapping"]:
72 | raise KeyError(
73 | "Name '{}' already registered for {}.".format(
74 | name, cls.mapping["task_name_mapping"][name]
75 | )
76 | )
77 | cls.mapping["task_name_mapping"][name] = task_cls
78 | return task_cls
79 |
80 | return wrap
81 |
82 | @classmethod
83 | def register_model(cls, name):
84 | r"""Register a model to registry with key 'name'
85 |
86 | Args:
87 | name: Key with which the model will be registered.
88 |
89 | Usage:
90 |
91 | from timechat.common.registry import registry
92 | """
93 |
94 | def wrap(model_cls):
95 | from timechat.models import BaseModel
96 |
97 | assert issubclass(
98 | model_cls, BaseModel
99 | ), "All models must inherit BaseModel class"
100 | if name in cls.mapping["model_name_mapping"]:
101 | raise KeyError(
102 | "Name '{}' already registered for {}.".format(
103 | name, cls.mapping["model_name_mapping"][name]
104 | )
105 | )
106 | cls.mapping["model_name_mapping"][name] = model_cls
107 | return model_cls
108 |
109 | return wrap
110 |
111 | @classmethod
112 | def register_processor(cls, name):
113 | r"""Register a processor to registry with key 'name'
114 |
115 | Args:
116 | name: Key with which the processor will be registered.
117 |
118 | Usage:
119 |
120 | from timechat.common.registry import registry
121 | """
122 |
123 | def wrap(processor_cls):
124 | from timechat.processors import BaseProcessor
125 |
126 | assert issubclass(
127 | processor_cls, BaseProcessor
128 | ), "All processors must inherit BaseProcessor class"
129 | if name in cls.mapping["processor_name_mapping"]:
130 | raise KeyError(
131 | "Name '{}' already registered for {}.".format(
132 | name, cls.mapping["processor_name_mapping"][name]
133 | )
134 | )
135 | cls.mapping["processor_name_mapping"][name] = processor_cls
136 | return processor_cls
137 |
138 | return wrap
139 |
140 | @classmethod
141 | def register_lr_scheduler(cls, name):
142 | r"""Register a learning rate scheduler to registry with key 'name'
143 |
144 | Args:
145 | name: Key with which the lr scheduler will be registered.
146 |
147 | Usage:
148 |
149 | from timechat.common.registry import registry
150 | """
151 |
152 | def wrap(lr_sched_cls):
153 | if name in cls.mapping["lr_scheduler_name_mapping"]:
154 | raise KeyError(
155 | "Name '{}' already registered for {}.".format(
156 | name, cls.mapping["lr_scheduler_name_mapping"][name]
157 | )
158 | )
159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160 | return lr_sched_cls
161 |
162 | return wrap
163 |
164 | @classmethod
165 | def register_runner(cls, name):
166 | r"""Register a runner to registry with key 'name'
167 |
168 | Args:
169 | name: Key with which the runner will be registered.
170 |
171 | Usage:
172 |
173 | from timechat.common.registry import registry
174 | """
175 |
176 | def wrap(runner_cls):
177 | if name in cls.mapping["runner_name_mapping"]:
178 | raise KeyError(
179 | "Name '{}' already registered for {}.".format(
180 | name, cls.mapping["runner_name_mapping"][name]
181 | )
182 | )
183 | cls.mapping["runner_name_mapping"][name] = runner_cls
184 | return runner_cls
185 |
186 | return wrap
187 |
188 | @classmethod
189 | def register_path(cls, name, path):
190 | r"""Register a path to registry with key 'name'
191 |
192 | Args:
193 | name: Key with which the path will be registered.
194 |
195 | Usage:
196 |
197 | from timechat.common.registry import registry
198 | """
199 | assert isinstance(path, str), "All paths must be str."
200 | if name in cls.mapping["paths"]:
201 | raise KeyError("Name '{}' already registered.".format(name))
202 | cls.mapping["paths"][name] = path
203 |
204 | @classmethod
205 | def register(cls, name, obj):
206 | r"""Register an item to registry with key 'name'
207 |
208 | Args:
209 | name: Key with which the item will be registered.
210 |
211 | Usage::
212 |
213 | from timechat.common.registry import registry
214 |
215 | registry.register("config", {})
216 | """
217 | path = name.split(".")
218 | current = cls.mapping["state"]
219 |
220 | for part in path[:-1]:
221 | if part not in current:
222 | current[part] = {}
223 | current = current[part]
224 |
225 | current[path[-1]] = obj
226 |
227 | # @classmethod
228 | # def get_trainer_class(cls, name):
229 | # return cls.mapping["trainer_name_mapping"].get(name, None)
230 |
231 | @classmethod
232 | def get_builder_class(cls, name):
233 | return cls.mapping["builder_name_mapping"].get(name, None)
234 |
235 | @classmethod
236 | def get_model_class(cls, name):
237 | return cls.mapping["model_name_mapping"].get(name, None)
238 |
239 | @classmethod
240 | def get_task_class(cls, name):
241 | return cls.mapping["task_name_mapping"].get(name, None)
242 |
243 | @classmethod
244 | def get_processor_class(cls, name):
245 | return cls.mapping["processor_name_mapping"].get(name, None)
246 |
247 | @classmethod
248 | def get_lr_scheduler_class(cls, name):
249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250 |
251 | @classmethod
252 | def get_runner_class(cls, name):
253 | return cls.mapping["runner_name_mapping"].get(name, None)
254 |
255 | @classmethod
256 | def list_runners(cls):
257 | return sorted(cls.mapping["runner_name_mapping"].keys())
258 |
259 | @classmethod
260 | def list_models(cls):
261 | return sorted(cls.mapping["model_name_mapping"].keys())
262 |
263 | @classmethod
264 | def list_tasks(cls):
265 | return sorted(cls.mapping["task_name_mapping"].keys())
266 |
267 | @classmethod
268 | def list_processors(cls):
269 | return sorted(cls.mapping["processor_name_mapping"].keys())
270 |
271 | @classmethod
272 | def list_lr_schedulers(cls):
273 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274 |
275 | @classmethod
276 | def list_datasets(cls):
277 | return sorted(cls.mapping["builder_name_mapping"].keys())
278 |
279 | @classmethod
280 | def get_path(cls, name):
281 | return cls.mapping["paths"].get(name, None)
282 |
283 | @classmethod
284 | def get(cls, name, default=None, no_warning=False):
285 | r"""Get an item from registry with key 'name'
286 |
287 | Args:
288 | name (string): Key whose value needs to be retrieved.
289 | default: If passed and key is not in registry, default value will
290 | be returned with a warning. Default: None
291 | no_warning (bool): If passed as True, warning when key doesn't exist
292 | will not be generated. Useful for MMF's
293 | internal operations. Default: False
294 | """
295 | original_name = name
296 | name = name.split(".")
297 | value = cls.mapping["state"]
298 | for subname in name:
299 | value = value.get(subname, default)
300 | if value is default:
301 | break
302 |
303 | if (
304 | "writer" in cls.mapping["state"]
305 | and value == default
306 | and no_warning is False
307 | ):
308 | cls.mapping["state"]["writer"].warning(
309 | "Key {} is not present in registry, returning default value "
310 | "of {}".format(original_name, default)
311 | )
312 | return value
313 |
314 | @classmethod
315 | def unregister(cls, name):
316 | r"""Remove an item from registry with key 'name'
317 |
318 | Args:
319 | name: Key which needs to be removed.
320 | Usage::
321 |
322 | from timechat.common.registry import registry
323 |
324 | config = registry.unregister("config")
325 | """
326 | return cls.mapping["state"].pop(name, None)
327 |
328 |
329 | registry = Registry()
330 |
--------------------------------------------------------------------------------
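A hedged usage sketch of the `Registry` above: classes are registered under a string key with a decorator and later resolved by that key, which is exactly how `build_datasets` resolves builders via `registry.get_builder_class(name)`. The `DummyLRScheduler` below is an illustrative stand-in (lr schedulers are registered without a base-class check, which keeps the example self-contained); it assumes the `timechat` package is importable.

# Hedged usage sketch of the Registry; DummyLRScheduler is illustrative only.
from timechat.common.registry import registry


@registry.register_lr_scheduler("dummy_constant")
class DummyLRScheduler:
    def __init__(self, optimizer, lr):
        self.optimizer = optimizer
        self.lr = lr

    def step(self, cur_epoch, cur_step):
        pass  # a real scheduler would update optimizer.param_groups here


# Look up the class by name, as runners and tasks do internally.
sched_cls = registry.get_lr_scheduler_class("dummy_constant")
assert sched_cls is DummyLRScheduler
print(registry.list_lr_schedulers())  # includes 'dummy_constant'

# Arbitrary state can also be stored and fetched with dotted keys.
registry.register("config.run.seed", 42)
print(registry.get("config.run.seed"))  # 42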
/timechat/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import json
5 | import numpy as np
6 | import torch
7 | import decord
8 | import time
9 | import subprocess
10 | import re
11 | from utils.cons_utils import BaseOptions, generate_question, load_logger
12 | from utils.prompts import prompt, cot
13 |
14 | from timechat.common.config import Config
15 | from timechat.common.registry import registry
16 | from timechat.conversation.conversation_video import Chat, conv_llava_llama_2
17 | from timechat.processors.video_processor import load_video
18 | decord.bridge.set_bridge('torch')
19 | logger = load_logger("TimeChat")
20 |
21 |
22 | class TimeChat_Options(BaseOptions):
23 | def initialize(self):
24 | BaseOptions.initialize(self)
25 | self.parser.add_argument("--cfg-path", default='timechat/eval_configs/timechat.yaml', help="path to configuration file.")
26 | self.parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
27 | self.parser.add_argument("--options", nargs="+",
28 | help="override some settings in the used config, "
29 | "the key-value pair in xxx=yyy format will be merged into the config file "
30 | "(deprecated, use --cfg-options instead).",
31 | )
32 |
33 |
34 | class TimeChat:
35 | def __init__(self, args):
36 | """
37 | Follow the official grounding prompt at https://github.com/RenShuhuai-Andy/TimeChat
38 | """
39 |
40 | cfg = Config(args)
41 | if args.fine_tuned:
42 | prompt["grounding"] = "Localize the visual content described by the given textual query '{event}' in the video, and output the start and end timestamps in seconds."
43 | cfg.model_cfg.ckpt = cfg.model_cfg.charades_ckpt if args.dset_name == "charades" else cfg.model_cfg.activitynet_ckpt
44 | if not os.path.exists(cfg.model_cfg.ckpt):
45 | raise FileNotFoundError(f"Check the checkpoint path: {cfg.model_cfg.ckpt}")
46 | else:
47 | prompt["grounding"] = "Please find the visual event described by a sentence in the video, determining its starting and ending times. The format should be: 'The event happens in the start time - end time'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seconds. Now I will give you the textual sentence: {event}. Please return its start time and end time."
48 | logger.info(f"Load checkpoints from '{cfg.model_cfg.ckpt}'")
49 |
50 | self.model_config = cfg.model_cfg
51 | self.gpu_id = args.gpu_id
52 | self.model, self.vis_processor = self.load_model(cfg)
53 | self.chat = Chat(self.model, self.vis_processor, device='cuda:{}'.format(args.gpu_id))
54 | self.debug = args.debug
55 | self.n_frames = 96
56 | self.CoT = args.CoT
57 |
58 | def load_model(self, cfg):
59 | model_config = cfg.model_cfg
60 | model_config.device_8bit = self.gpu_id
61 | model_cls = registry.get_model_class(model_config.arch)
62 | model = model_cls.from_config(model_config).to('cuda:{}'.format(self.gpu_id))
63 | model.eval()
64 |
65 | vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
66 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
67 |
68 | return model, vis_processor
69 |
70 | def load_video_features(self, video_path):
71 | video_features = []
72 | video, msg = load_video(
73 | video_path=video_path,
74 | n_frms=self.n_frames,
75 | height=224,
76 | width=224,
77 | sampling="uniform",
78 | return_msg=True
79 | )
80 | video = self.vis_processor.transform(video)
81 | video = video.unsqueeze(0).to(self.gpu_id)
82 |
83 | if self.model.qformer_text_input:
84 | # timestamp
85 | timestamps = msg.split('at')[1].replace('seconds.', '').strip().split(',') # extract timestamps from msg
86 | timestamps = [f'This frame is sampled at {t.strip()} second.' for t in timestamps]
87 | timestamps = self.model.tokenizer(
88 | timestamps,
89 | return_tensors="pt",
90 | padding="longest",
91 | max_length=32,
92 | truncation=True,
93 | )
94 |
95 | if self.model.qformer_text_input:
96 | image_emb, _ = self.model.encode_videoQformer_visual(video, timestamp=timestamps)
97 | else:
98 | image_emb, _ = self.model.encode_videoQformer_visual(video)
99 | video_features.append(image_emb)
100 |
101 | return video_features, msg
102 |
103 | def inference(self, chat, chat_state, video_features):
104 | llm_message = chat.answer(conv=chat_state,
105 | img_list=video_features,
106 | num_beams=1,
107 | do_sample=False,
108 | temperature=0.05,
109 | max_new_tokens=300,
110 | max_length=2000)[0]
111 | chat_state.messages[-1][-1] = llm_message
112 |
113 | return llm_message
114 |
115 | def extract_time(self, paragraph):
116 | prompt = 'A specific example is : 20.8 - 30.0 seconds'.lower()
117 | paragraph = paragraph.lower()
118 | paragraph = paragraph.replace(prompt, '')  # str.replace returns a new string; reassign to actually drop the example prompt
119 | # Split text into sentences based on common delimiters
120 | sentences = re.split(r'[!?\n]', paragraph)
121 |
122 | # Keywords that might indicate the presence of time information
123 | keywords = ["starts", "ends", "happens in", "start time", "end time", "start", "end", "happen"]
124 | # filter sentences by keywords
125 | candidates = []
126 | for sentence in sentences:
127 | # If sentence contains one of the keywords
128 | if any(keyword in sentence for keyword in keywords):
129 | candidates.append(sentence)
130 |
131 | timestamps = []
132 | # Check for The given query happens in m - n (seconds)
133 | patterns = [
134 | r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)"
135 | ]
136 |
137 | for time_pattern in patterns:
138 | time_matches = re.findall(time_pattern, paragraph)
139 | if time_matches:
140 | timestamps = [[float(start), float(end)] for start, end in time_matches]
141 |
142 | if len(sentences) == 0:
143 | return []
144 | # check for other formats e.g.:
145 | # 1 .Starting time: 0.8 seconds
146 | # Ending time: 1.1 seconds
147 | # 2. The start time for this event is 0 seconds, and the end time is 12 seconds.
148 | if len(timestamps) == 0:
149 | times = []
150 | time_regex = re.compile(r'\b(\d+\.\d+\b|\b\d+)\b') # time formats (e.g., 18, 18.5)
151 | for sentence in candidates:
152 | time = re.findall(time_regex, sentence)
153 | if time:
154 | time_in_sec = float(time[0])
155 | times.append(time_in_sec)
156 | times = times[:len(times) // 2 * 2]
157 | timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)]
158 | # Check for examples like:
159 | # 3. The event 'person flipped the light switch near the door' starts at 00:00:18 and ends at 00:00:23.
160 | if len(timestamps) == 0:
161 | times = []
162 | time_regex = re.compile(r'\b(\d{1,2}:\d{2}(?::\d{2})?)\b')  # time formats (e.g., 0:18, 00:18:05); single group so findall returns strings
163 | for sentence in candidates:
164 | time = re.findall(time_regex, sentence)
165 | if time:
166 | t = time[0]
167 | else:
168 | continue
169 | # If time is in HH:MM:SS format, convert to seconds
170 | if t.count(':') == 2:
171 | h, m, s = map(int, t.split(':'))
172 | time_in_sec = h * 3600 + m * 60 + s
173 | elif t.count(':') == 1:
174 | m, s = map(int, t.split(':'))
175 | time_in_sec = m * 60 + s
176 | times.append(time_in_sec)
177 | times = times[:len(times) // 2 * 2]
178 | timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)]
179 |
180 | results = []
181 | for (start, end) in timestamps:
182 | if end > start:
183 | results.append([start, end])
184 | else:
185 | results.append([end, start])
186 |
187 | if len(results) == 0:
188 | return [0, 0]
189 | else:
190 | return results[0]
191 |
192 | def extract_time2(self, paragraph):
193 | pattern = r'(?:(?:from\s+|in\s+|between\s+|at\s+|—|to\s+|and\s+|–)\s*)?(\d+(?:\.\d+)?)\s*(?:to|and|–|—)\s*(?:(?:from\s+|in\s+|between\s+|at\s+|—|to\s+|and\s+|–)\s*)?(\d+(?:\.\d+)?)?\s*seconds?'
194 | match = re.search(pattern, paragraph)
195 |
196 | if match:
197 | start_timestamp = float(match.group(1))
198 | end_timestamp = match.group(2)  # the second group is optional and may be None
199 |
200 | # If only the start timestamp is found
201 | if end_timestamp is None:
202 | return [0, start_timestamp]
203 | else:
204 | return [start_timestamp, float(end_timestamp)]
205 |
206 | else:
207 | pattern = r'\d+\.\d+|\d+'
208 |
209 | # Find all matches in the input string
210 | matches = re.findall(pattern, paragraph)
211 |
212 | # Convert matched strings to float
213 | float_numbers = [float(match) for match in matches]
214 | if len(float_numbers) == 0:
215 | return [0, 0]
216 | elif len(float_numbers) == 1:
217 | return [0, float_numbers[0]]
218 | else:
219 | return float_numbers[:2]
220 |
221 | def initialize_chat(self, task, msg, add_detail=None):
222 | chat_state = conv_llava_llama_2.copy()
223 | chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
224 |
225 | if add_detail:
226 | chat_state.system += (" " + add_detail)
227 | if self.CoT:
228 | chat_state.system = cot[task] if task in cot else chat_state.system
229 |
230 | chat_state.append_message(chat_state.roles[0], " " + msg)
231 |
232 | return chat_state
233 |
234 | def run(self, task, video_features, query, duration, chat_state=None, st=None, ed=None, msg=None, return_chat_state=False):
235 | question, add_detail, choice = generate_question(task, prompt, query, duration, st, ed)
236 | if not chat_state:
237 | chat_state = self.initialize_chat(task, msg, add_detail=add_detail)
238 | else:
239 | chat_state = chat_state.copy()
240 |
241 | question = " ".join([question, add_detail]) if add_detail else question
242 | self.chat.ask(question, chat_state)
243 | answer = self.inference(self.chat, chat_state, video_features)
244 |
245 | if self.debug:
246 | print(f"[{task}] Question: {question}")
247 | print(f"Answer: {answer}\n")
248 |
249 | if task in ["grounding"]:
250 | timestamps = self.extract_time(answer)
251 | if timestamps == [0,0]:
252 | timestamps = self.extract_time2(answer)
253 | return {"q": question, "a": answer, "t": timestamps}
254 |
255 | elif task in ["occurrence", "compositional"]:
256 | if task == "compositional":
257 | choice = "pos"
258 | return {"c": choice, "q": question, "a": answer}
259 |
260 | elif task in ["description"]:
261 | if return_chat_state:
262 | return answer, chat_state
263 | return answer
264 | else:
265 | raise NotImplementedError(f"Task {task} is not yet implemented.")
266 |
--------------------------------------------------------------------------------
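The timestamp parsing in `extract_time` and `extract_time2` is regex-heavy. The sketch below isolates the two main patterns, a "start - end" range in seconds and a clock-style MM:SS / HH:MM:SS time, and shows what they return on made-up model answers; the example strings are assumptions used only for illustration.

# Standalone illustration of the two timestamp patterns used in TimeChat.extract_time.
# The answer strings below are made-up examples, not real model outputs.
import re

range_pattern = r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)"   # e.g. "24.3 - 30.4 seconds"
clock_pattern = r"\b(\d{1,2}:\d{2}(?::\d{2})?)\b"  # e.g. "0:18" or "00:00:18"


def clock_to_seconds(clock):
    parts = [int(p) for p in clock.split(":")]
    if len(parts) == 3:
        h, m, s = parts
        return h * 3600 + m * 60 + s
    m, s = parts
    return m * 60 + s


answer1 = "The event happens in the 24.3 - 30.4 seconds."
answer2 = "The event starts at 00:00:18 and ends at 00:00:23."

print(re.findall(range_pattern, answer1))                                 # [('24.3', '30.4')]
print([clock_to_seconds(t) for t in re.findall(clock_pattern, answer2)])  # [18, 23]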
/timechat/datasets/datasets/llava_instruct_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from timechat.datasets.datasets.base_dataset import BaseDataset
3 | from timechat.datasets.datasets.caption_datasets import CaptionDataset
4 | import pandas as pd
5 | import decord
6 | from decord import VideoReader
7 | import random
8 | import torch
9 | from torch.utils.data.dataloader import default_collate
10 | from PIL import Image
11 | from typing import Dict, Optional, Sequence
12 | import transformers
13 | import pathlib
14 | import json
15 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
16 | from timechat.conversation.conversation_video import Conversation,SeparatorStyle
17 | DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
18 | DEFAULT_IMAGE_TOKEN = "<image>"
19 | import copy
20 | from timechat.processors import transforms_video,AlproVideoTrainProcessor
21 | IGNORE_INDEX = -100
22 | image_conversation = Conversation(
23 | system="",
24 | roles=("Human", "Assistant"),
25 | messages=[],
26 | offset=0,
27 | sep_style=SeparatorStyle.SINGLE,
28 | sep="###",
29 | )
30 | llama_v2_image_conversation = Conversation(
31 | system=" ",
32 | roles=("USER", "ASSISTANT"),
33 | messages=(),
34 | offset=0,
35 | sep_style=SeparatorStyle.LLAMA_2,
36 | sep="<s>",
37 | sep2="</s>",
38 | )
39 | IGNORE_INDEX = -100
40 |
41 | class Instruct_Dataset(BaseDataset):
42 | def __init__(self, vis_processor, text_processor, vis_root, ann_root, num_video_query_token=32, tokenizer_name='/mnt/workspace/ckpt/vicuna-13b/', data_type='image', model_type='vicuna'):
43 | """
44 | vis_root (string): Root directory of the LLaVA (COCO train2014) images
45 | ann_root (string): Path to the LLaVA instruction annotation json file
46 | split (string): val or test
47 | """
48 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
49 |
50 | data_path = pathlib.Path(ann_root)
51 | with data_path.open(encoding='utf-8') as f:
52 | self.annotation = json.load(f)
53 |
54 | self.vis_root = vis_root
55 | self.resize_size = 224
56 | self.num_frm = 8
57 | self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
58 | self.tokenizer.pad_token = self.tokenizer.unk_token
59 | self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
60 | self.num_video_query_token = num_video_query_token
61 | self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
62 |
63 | self.transform = AlproVideoTrainProcessor(
64 | image_size=self.resize_size, n_frms = self.num_frm
65 | ).transform
66 | self.data_type = data_type
67 | self.model_type = model_type
68 |
69 | def _get_image_path(self, sample):
70 | rel_video_fp ='COCO_train2014_' + sample['image']
71 | full_video_fp = os.path.join(self.vis_root, rel_video_fp)
72 | return full_video_fp
73 |
74 | def __getitem__(self, index):
75 | num_retries = 10 # skip error videos
76 | for _ in range(num_retries):
77 | try:
78 | sample = self.annotation[index]
79 |
80 | image_path = self._get_image_path(sample)
81 | conversation_list = sample['conversations']
82 | image = Image.open(image_path).convert("RGB")
83 |
84 | image = self.vis_processor(image)
85 | # text = self.text_processor(text)
86 | sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token)
87 | if self.model_type =='vicuna':
88 | data_dict = preprocess(
89 | sources,
90 | self.tokenizer)
91 | elif self.model_type =='llama_v2':
92 | data_dict = preprocess_for_llama_v2(
93 | sources,
94 | self.tokenizer)
95 | else:
96 | print('model_type not supported')
97 | raise NotImplementedError(f"model_type '{self.model_type}' is not supported")
98 | data_dict = dict(input_ids=data_dict["input_ids"][0],
99 | labels=data_dict["labels"][0])
100 |
101 | # image exist in the data
102 | data_dict['image'] = image
103 | except:
104 | print(f"Failed to load examples with image: {image_path}. "
105 | f"Will randomly sample an example as a replacement.")
106 | index = random.randint(0, len(self) - 1)
107 | continue
108 | break
109 | else:
110 | raise RuntimeError(f"Failed to fetch image after {num_retries} retries.")
111 | # "image_id" is kept to stay compatible with the COCO evaluation format
112 | return {
113 | "image": image,
114 | "text_input": data_dict["input_ids"],
115 | "labels": data_dict["labels"],
116 | "type":'image',
117 | }
118 |
119 | def __len__(self):
120 | return len(self.annotation)
121 |
122 | def collater(self, instances):
123 | input_ids, labels = tuple([instance[key] for instance in instances]
124 | for key in ("text_input", "labels"))
125 | input_ids = torch.nn.utils.rnn.pad_sequence(
126 | input_ids,
127 | batch_first=True,
128 | padding_value=self.tokenizer.pad_token_id)
129 | labels = torch.nn.utils.rnn.pad_sequence(labels,
130 | batch_first=True,
131 | padding_value=IGNORE_INDEX)
132 | batch = dict(
133 | input_ids=input_ids,
134 | labels=labels,
135 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
136 | )
137 |
138 | if 'image' in instances[0]:
139 | images = [instance['image'] for instance in instances]
140 | if all(x is not None and x.shape == images[0].shape for x in images):
141 | batch['images'] = torch.stack(images)
142 | else:
143 | batch['images'] = images
144 | batch['conv_type'] = 'multi'
145 | return batch
146 |
147 |
148 | def preprocess_multimodal(
149 | conversation_list: Sequence[str],
150 | multimodal_cfg: dict,
151 | cur_token_len: int,
152 | ) -> Dict:
153 | # Replace the image placeholder in every turn of the conversation list with the expanded patch tokens
154 | is_multimodal = True
155 | # image_token_len = multimodal_cfg['image_token_len']
156 | image_token_len = cur_token_len
157 |
158 | for sentence in conversation_list:
159 | replace_token = '<Image>' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + '</Image>'
160 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
161 |
162 | return [conversation_list]
163 |
164 | def _add_speaker_and_signal(header, source, get_conversation=True):
165 | """Add speaker and start/end signal on each round."""
166 | BEGIN_SIGNAL = "###"
167 | END_SIGNAL = "\n"
168 | conversation = header
169 | for sentence in source:
170 | from_str = sentence["from"]
171 | if from_str.lower() == "human":
172 | from_str = image_conversation.roles[0]
173 | elif from_str.lower() == "gpt":
174 | from_str = image_conversation.roles[1]
175 | else:
176 | from_str = 'unknown'
177 | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
178 | sentence["value"] + END_SIGNAL)
179 | if get_conversation:
180 | conversation += sentence["value"]
181 | conversation += BEGIN_SIGNAL
182 | return conversation
183 |
184 | def _tokenize_fn(strings: Sequence[str],
185 | tokenizer: transformers.PreTrainedTokenizer) -> Dict:
186 | """Tokenize a list of strings."""
187 | tokenized_list = [
188 | tokenizer(
189 | text,
190 | return_tensors="pt",
191 | padding="longest",
192 | max_length=512,
193 | truncation=True,
194 | ) for text in strings
195 | ]
196 | input_ids = labels = [
197 | tokenized.input_ids[0] for tokenized in tokenized_list
198 | ]
199 | input_ids_lens = labels_lens = [
200 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
201 | for tokenized in tokenized_list
202 | ]
203 | return dict(
204 | input_ids=input_ids,
205 | labels=labels,
206 | input_ids_lens=input_ids_lens,
207 | labels_lens=labels_lens,
208 | )
209 |
210 | def preprocess(
211 | sources: Sequence[str],
212 | tokenizer: transformers.PreTrainedTokenizer,
213 | ) -> Dict:
214 | """
215 | Given a list of sources, each is a conversation list. This transform:
216 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
217 | 2. Concatenate conversations together;
218 | 3. Tokenize the concatenated conversation;
219 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
220 | """
221 | # add end signal and concatenate together
222 | conversations = []
223 | for source in sources:
224 | header = f"{image_conversation.system}\n\n"
225 | conversation = _add_speaker_and_signal(header, source)
226 | conversations.append(conversation)
227 | # tokenize conversations
228 | conversations_tokenized = _tokenize_fn(conversations, tokenizer)
229 | input_ids = conversations_tokenized["input_ids"]
230 | targets = copy.deepcopy(input_ids)
231 | for target, source in zip(targets, sources):
232 | tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
233 | tokenizer)["input_ids_lens"]
234 | speakers = [sentence["from"] for sentence in source]
235 | _mask_targets(target, tokenized_lens, speakers)
236 |
237 | return dict(input_ids=input_ids, labels=targets)
238 |
239 | def preprocess_for_llama_v2(
240 | sources: Sequence[str],
241 | tokenizer: transformers.PreTrainedTokenizer,
242 | ) -> Dict:
243 | """
244 | Given a list of sources, each is a conversation list. This transform:
245 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
246 | 2. Concatenate conversations together;
247 | 3. Tokenize the concatenated conversation;
248 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
249 | """
250 | # add end signal and concatenate together
251 | conversations = []
252 | conv = copy.deepcopy(llama_v2_image_conversation.copy())
253 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
254 | for source in sources:
255 | # [INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n
256 | header = f"[INST] <<SYS>>\n{conv.system}\n<</SYS>>\n\n"
257 |
258 | if roles[source[0]["from"]] != conv.roles[0]:
259 | # Skip the first one if it is not from human
260 | source = source[1:]
261 | conv.messages = []
262 | for j, sentence in enumerate(source):
263 | role = roles[sentence["from"]]
264 | assert role == conv.roles[j % 2]
265 | conv.append_message(role, sentence["value"])
266 | conversations.append(conv.get_prompt())
267 |
268 | input_ids = tokenizer(
269 | conversations,
270 | return_tensors="pt",
271 | padding="longest",
272 | max_length=512,
273 | truncation=True,
274 | ).input_ids
275 | targets = copy.deepcopy(input_ids)
276 |
277 |
278 | sep = "[/INST] "
279 | for conversation, target in zip(conversations, targets):
280 | # total_len = int(target.ne(tokenizer.pad_token_id).sum())
281 | rounds = conversation.split(conv.sep2)
282 | cur_len = 1
283 | target[:cur_len] = IGNORE_INDEX
284 | for i, rou in enumerate(rounds):
285 | if rou == "":
286 | break
287 |
288 | parts = rou.split(sep)
289 | if len(parts) != 2:
290 | break
291 | parts[0] += sep
292 |
293 |
294 | round_len = len(tokenizer(rou).input_ids)
295 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2  # subtract 2 to account for the special tokens added by the tokenizer
296 |
297 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
298 |
299 | cur_len += round_len
300 | target[cur_len:] = IGNORE_INDEX
301 |
302 | return dict(input_ids=input_ids, labels=targets)
303 |
304 | def _mask_targets(target, tokenized_lens, speakers):
305 | # cur_idx = 0
306 | cur_idx = tokenized_lens[0]
307 | tokenized_lens = tokenized_lens[1:]
308 | target[:cur_idx] = IGNORE_INDEX
309 | for tokenized_len, speaker in zip(tokenized_lens, speakers):
310 | if speaker == "human":
311 | target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
312 | cur_idx += tokenized_len
313 |
--------------------------------------------------------------------------------
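The label construction in `preprocess` and `_mask_targets` masks the system header and every human turn with IGNORE_INDEX, so the language-model loss is computed only on assistant tokens. The sketch below reproduces that masking rule on hand-picked token lengths; the ids are made up and no tokenizer is involved.

# Hedged sketch of the IGNORE_INDEX masking rule from preprocess/_mask_targets.
# Token ids and lengths are made up for illustration.
import torch

IGNORE_INDEX = -100

# pretend tokenization: header (3 tokens), human turn (4 tokens), gpt turn (5 tokens)
input_ids = torch.arange(3 + 4 + 5)
tokenized_lens = [3, 4, 5]        # header first, then one entry per turn, as in preprocess()
speakers = ["human", "gpt"]

target = input_ids.clone()
cur_idx = tokenized_lens[0]       # mask the header
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens[1:], speakers):
    if speaker == "human":
        # +2 mirrors _mask_targets, which leaves the first two tokens of each human turn unmasked
        target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
    cur_idx += tokenized_len

print(target.tolist())
# [-100, -100, -100, 3, 4, -100, -100, 7, 8, 9, 10, 11]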