├── video_llama ├── runners │ ├── test.py │ └── __init__.py ├── common │ ├── __init__.py │ ├── gradcam.py │ ├── optims.py │ ├── dist_utils.py │ ├── logger.py │ └── registry.py ├── datasets │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── laion_dataset.py │ │ ├── cc_sbu_dataset.py │ │ ├── base_dataset.py │ │ ├── caption_datasets.py │ │ ├── webvid_datasets.py │ │ ├── dataloader_utils.py │ │ └── llava_instruct_dataset.py │ ├── builders │ │ ├── video_caption_builder.py │ │ ├── __init__.py │ │ ├── instruct_builder.py │ │ ├── image_text_pair_builder.py │ │ └── base_dataset_builder.py │ └── data_utils.py ├── conversation │ └── __init__.py ├── models │ ├── ImageBind │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── helpers.py │ │ │ └── transformer.py │ │ ├── .assets │ │ │ ├── car_audio.wav │ │ │ ├── car_image.jpg │ │ │ ├── dog_audio.wav │ │ │ ├── dog_image.jpg │ │ │ ├── bird_audio.wav │ │ │ └── bird_image.jpg │ │ ├── bpe │ │ │ └── bpe_simple_vocab_16e6.txt.gz │ │ ├── requirements.txt │ │ ├── CONTRIBUTING.md │ │ ├── CODE_OF_CONDUCT.md │ │ ├── model_card.md │ │ ├── README.md │ │ └── data.py │ ├── blip2_outputs.py │ ├── __init__.py │ ├── blip2.py │ └── base_model.py ├── configs │ ├── datasets │ │ ├── cc_sbu │ │ │ ├── align.yaml │ │ │ └── defaults.yaml │ │ ├── laion │ │ │ └── defaults.yaml │ │ ├── instruct │ │ │ ├── llava_instruct.yaml │ │ │ └── webvid_instruct.yaml │ │ └── webvid │ │ │ └── defaults.yaml │ ├── default.yaml │ └── models │ │ ├── minigpt4.yaml │ │ └── video_llama.yaml ├── tasks │ ├── image_text_pretrain.py │ ├── video_text_pretrain.py │ ├── __init__.py │ └── base_task.py ├── processors │ ├── base_processor.py │ ├── __init__.py │ ├── functional_video.py │ ├── blip_processors.py │ ├── transforms_video.py │ ├── video_processor.py │ └── .ipynb_checkpoints │ │ └── video_processor-checkpoint.py └── __init__.py ├── 2stage1.png ├── 2stage2.png ├── 2stage_audio.wav ├── 2stage_video.mp4 ├── 2stage_videollama.jpeg ├── README.md └── infer_batch.py /video_llama/runners/test.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video_llama/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video_llama/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video_llama/conversation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /2stage1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/2stage1.png -------------------------------------------------------------------------------- /2stage2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/2stage2.png -------------------------------------------------------------------------------- /2stage_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/2stage_audio.wav -------------------------------------------------------------------------------- /2stage_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/2stage_video.mp4 -------------------------------------------------------------------------------- /2stage_videollama.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/2stage_videollama.jpeg -------------------------------------------------------------------------------- /video_llama/models/ImageBind/.assets/car_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/video_llama/models/ImageBind/.assets/car_audio.wav -------------------------------------------------------------------------------- /video_llama/models/ImageBind/.assets/car_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/video_llama/models/ImageBind/.assets/car_image.jpg -------------------------------------------------------------------------------- /video_llama/models/ImageBind/.assets/dog_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/video_llama/models/ImageBind/.assets/dog_audio.wav -------------------------------------------------------------------------------- /video_llama/models/ImageBind/.assets/dog_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/video_llama/models/ImageBind/.assets/dog_image.jpg -------------------------------------------------------------------------------- /video_llama/models/ImageBind/.assets/bird_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/video_llama/models/ImageBind/.assets/bird_audio.wav -------------------------------------------------------------------------------- /video_llama/models/ImageBind/.assets/bird_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/video_llama/models/ImageBind/.assets/bird_image.jpg -------------------------------------------------------------------------------- /video_llama/configs/datasets/cc_sbu/align.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu_align: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/cc_sbu_align_dataset 6 | -------------------------------------------------------------------------------- /video_llama/configs/datasets/cc_sbu/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu: 3 | data_type: 
images 4 | build_info: 5 | storage: /path/to/cc_sbu_dataset/{00000..00001}.tar 6 | -------------------------------------------------------------------------------- /video_llama/configs/datasets/laion/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | laion: 3 | data_type: images 4 | build_info: 5 | storage: path/laion/laion_dataset/{00000..00001}.tar 6 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-anonymous-bs/av-SALMONN/HEAD/video_llama/models/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /video_llama/configs/default.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | # For default users 3 | # cache_root: "cache" 4 | # For internal use with persistent storage 5 | cache_root: "/export/home/.cache/minigpt4" 6 | -------------------------------------------------------------------------------- /video_llama/configs/datasets/instruct/llava_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | llava_instruct: 3 | data_type: image 4 | build_info: 5 | anno_dir: /path/llava_instruct_150k.json 6 | videos_dir: /path/train2014/train2014/ 7 | -------------------------------------------------------------------------------- /video_llama/configs/datasets/webvid/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | webvid: 3 | data_type: video 4 | build_info: 5 | anno_dir: path/webvid/webvid_tain_data/annotations/ 6 | videos_dir: path//webvid/webvid_tain_data/videos/ 7 | -------------------------------------------------------------------------------- /video_llama/configs/datasets/instruct/webvid_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | webvid_instruct: 3 | data_type: image 4 | build_info: 5 | anno_dir: /path/webvid_align/videochat_instruct_11k.json 6 | videos_dir: /path/webvid_align/videos/ 7 | -------------------------------------------------------------------------------- /video_llama/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from video_llama.runners.runner_base import RunnerBase 9 | 10 | __all__ = ["RunnerBase"] 11 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torch==1.13.0 3 | torchvision==0.14.0 4 | torchaudio==0.13.0 5 | pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d 6 | timm==0.6.7 7 | ftfy 8 | regex 9 | einops 10 | fvcore 11 | decord==0.6.0 12 | iopath 13 | numpy 14 | matplotlib 15 | types-regex 16 | mayavi 17 | cartopy 18 | -------------------------------------------------------------------------------- /video_llama/tasks/image_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from video_llama.common.registry import registry 9 | from video_llama.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("image_text_pretrain") 13 | class ImageTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /video_llama/tasks/video_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from video_llama.common.registry import registry 9 | from video_llama.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("video_text_pretrain") 13 | class VideoTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /video_llama/processors/base_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /video_llama/configs/models/minigpt4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: mini_gpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | # Q-Former 13 | num_query_token: 32 14 | 15 | # Vicuna 16 | llama_model: "ckpt/vicuna-13b/" 17 | 18 | # generation configs 19 | prompt: "" 20 | 21 | preprocess: 22 | vis_processor: 23 | train: 24 | name: "blip2_image_train" 25 | image_size: 224 26 | eval: 27 | name: "blip2_image_eval" 28 | image_size: 224 29 | text_processor: 30 | train: 31 | name: "blip_caption" 32 | eval: 33 | name: "blip_caption" 34 | -------------------------------------------------------------------------------- /video_llama/configs/models/video_llama.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: video_llama 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | # Q-Former 13 | num_query_token: 32 14 | 15 | # Vicuna 16 | llama_model: "ckpt/vicuna-7b/" 17 | 18 | # generation configs 19 | prompt: "" 20 | 21 | preprocess: 22 | vis_processor: 23 | train: 24 | name: "alpro_video_train" 25 | image_size: 224 26 | n_frms: 8 27 | eval: 28 | name: "alpro_video_eval" 29 | image_size: 224 30 | n_frms: 8 31 | text_processor: 32 | train: 33 | name: "blip_caption" 34 | eval: 35 | name: "blip_caption" 36 | -------------------------------------------------------------------------------- /video_llama/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /video_llama/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from video_llama.common.registry import registry 9 | from video_llama.tasks.base_task import BaseTask 10 | from video_llama.tasks.image_text_pretrain import ImageTextPretrainTask 11 | from video_llama.tasks.video_text_pretrain import VideoTextPretrainTask 12 | 13 | 14 | def setup_task(cfg): 15 | assert "task" in cfg.run_cfg, "Task name must be provided." 16 | 17 | task_name = cfg.run_cfg.task 18 | task = registry.get_task_class(task_name).setup_task(cfg=cfg) 19 | assert task is not None, "Task {} not properly registered.".format(task_name) 20 | 21 | return task 22 | 23 | 24 | __all__ = [ 25 | "BaseTask", 26 | "ImageTextPretrainTask", 27 | "VideoTextPretrainTask" 28 | ] 29 | -------------------------------------------------------------------------------- /video_llama/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | 11 | from omegaconf import OmegaConf 12 | 13 | from video_llama.common.registry import registry 14 | 15 | from video_llama.datasets.builders import * 16 | from video_llama.models import * 17 | from video_llama.processors import * 18 | from video_llama.tasks import * 19 | 20 | 21 | root_dir = os.path.dirname(os.path.abspath(__file__)) 22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) 23 | 24 | registry.register_path("library_root", root_dir) 25 | repo_root = os.path.join(root_dir, "..") 26 | registry.register_path("repo_root", repo_root) 27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root) 28 | registry.register_path("cache_root", cache_root) 29 | 30 | registry.register("MAX_INT", sys.maxsize) 31 | registry.register("SPLIT_NAMES", ["train", "val", "test"]) 32 | -------------------------------------------------------------------------------- /video_llama/datasets/builders/video_caption_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from video_llama.common.registry import registry 6 | from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from video_llama.datasets.datasets.webvid_datasets import WebvidDataset 8 | 9 | @registry.register_builder("webvid") 10 | class WebvidBuilder(BaseDatasetBuilder): 11 | train_dataset_cls = WebvidDataset 12 | DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"} 13 | 14 | def _download_ann(self): 15 | pass 16 | 17 | def _download_vis(self): 18 | pass 19 | 20 | def build(self): 21 | self.build_processors() 22 | datasets = dict() 23 | split = "train" 24 | 25 | build_info = self.config.build_info 26 | dataset_cls = self.train_dataset_cls 27 | datasets[split] = dataset_cls( 28 | vis_processor=self.vis_processors[split], 29 | text_processor=self.text_processors[split], 30 | vis_root=build_info.videos_dir, 31 | ann_root=build_info.anno_dir 32 | ) 33 | 34 | return datasets -------------------------------------------------------------------------------- /video_llama/processors/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from video_llama.processors.base_processor import BaseProcessor 9 | from video_llama.processors.blip_processors import ( 10 | Blip2ImageTrainProcessor, 11 | Blip2ImageEvalProcessor, 12 | BlipCaptionProcessor, 13 | ) 14 | from video_llama.processors.video_processor import ( 15 | AlproVideoTrainProcessor, 16 | AlproVideoEvalProcessor 17 | ) 18 | from video_llama.common.registry import registry 19 | 20 | __all__ = [ 21 | "BaseProcessor", 22 | "Blip2ImageTrainProcessor", 23 | "Blip2ImageEvalProcessor", 24 | "BlipCaptionProcessor", 25 | "AlproVideoTrainProcessor", 26 | "AlproVideoEvalProcessor", 27 | ] 28 | 29 | 30 | def load_processor(name, cfg=None): 31 | """ 32 | Example 33 | 34 | >>> processor = load_processor("alpro_video_train", cfg=None) 35 | """ 36 | processor = registry.get_processor_class(name).from_config(cfg) 37 | 38 | return processor 39 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/laion_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import webdataset as wds 9 | from video_llama.datasets.datasets.base_dataset import BaseDataset 10 | 11 | 12 | class LaionDataset(BaseDataset): 13 | def __init__(self, vis_processor, text_processor, location): 14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 15 | 16 | self.inner_dataset = wds.DataPipeline( 17 | wds.ResampledShards(location), 18 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 19 | wds.shuffle(1000, handler=wds.warn_and_continue), 20 | wds.decode("pilrgb", handler=wds.warn_and_continue), 21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 23 | wds.map(self.to_dict, handler=wds.warn_and_continue), 24 | ) 25 | 26 | def to_dict(self, sample): 27 | return { 28 | "image": sample[0], 29 | "text_input": self.text_processor(sample[1]["caption"]), 30 | } 31 | 32 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ImageBind 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 
18 | 19 | Complete your CLA here: https://code.facebook.com/cla 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to Omnivore, you agree that your contributions will be licensed 31 | under the [LICENSE](LICENSE) file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/cc_sbu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import webdataset as wds 4 | from video_llama.datasets.datasets.base_dataset import BaseDataset 5 | from video_llama.datasets.datasets.caption_datasets import CaptionDataset 6 | 7 | 8 | class CCSBUDataset(BaseDataset): 9 | def __init__(self, vis_processor, text_processor, location): 10 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 11 | 12 | self.inner_dataset = wds.DataPipeline( 13 | wds.ResampledShards(location), 14 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 15 | wds.shuffle(1000, handler=wds.warn_and_continue), 16 | wds.decode("pilrgb", handler=wds.warn_and_continue), 17 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 18 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 19 | wds.map(self.to_dict, handler=wds.warn_and_continue), 20 | ) 21 | 22 | def to_dict(self, sample): 23 | return { 24 | "image": sample[0], 25 | "text_input": self.text_processor(sample[1]["caption"]), 26 | "type":'image', 27 | } 28 | 29 | 30 | class CCSBUAlignDataset(CaptionDataset): 31 | 32 | def __getitem__(self, index): 33 | 34 | # TODO this assumes image input, not general enough 35 | ann = self.annotation[index] 36 | 37 | img_file = '{}.jpg'.format(ann["image_id"]) 38 | image_path = os.path.join(self.vis_root, img_file) 39 | image = Image.open(image_path).convert("RGB") 40 | 41 | image = self.vis_processor(image) 42 | caption = ann["caption"] 43 | 44 | return { 45 | "image": image, 46 | "text_input": caption, 47 | "image_id": self.img_ids[ann["image_id"]], 48 | "type":'image', 49 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # av-SALMONN 2 | av-SALMONN: Speech-Enhanced Audio-Visual Large Language Models 3 | 4 | 5 | 6 | Button Specifications: 7 | 8 | `Clear All`: clears the chat history as well as all modality inputs. **Please always click `Clear All` before uploading or updating any image, audio or video.** 9 | 10 | `Clear history`: clears only the chat history. The modality inputs remain unchanged unless you click `Clear All`. 11 | 12 | `Submit`: submits the text in the text box to get a response. 13 | 14 | `Resubmit`: clears the previous conversation turn and then submits the text in the text box. 15 | 16 | `maximum length`, `top p` and `temperature` are the usual generation settings, controlling the maximum output length, the nucleus-sampling threshold and the sampling temperature respectively. 17 | 18 | Examples mentioned in the paper are provided. Please feel free to start with those. 19 | 20 | 21 | We provide the script for evaluating speech (LibriSpeech) and audio (AudioCaps) as single-modal tasks using Video-LLaMA. 
Please find the code in `infer_batch.sh` and `video_llama/`. 22 | We provide the generated results for LibriSpeech (`librispeech.json`, plus `librispeech_finetuned.json` after finetuning for 50k steps on LibriSpeech) and for AudioCaps (`audiocaps.json`). 23 | 24 | 25 | ## Demo comparison between av-SALMONN and 2-stage systems 26 | We perform a case study on the following video: 27 | Video file: `2stage_video.mp4` 28 | 29 | In the following three examples, the question requires the model to associate the speech with the correct speaker in order to answer it. As a result, only av-SALMONN can answer it correctly, whereas the two 2-stage systems (av-SALMONN without audio input but with an ASR transcription added, and Video-LLaMA with an ASR transcription) cannot. 30 | - av-SALMONN 31 | ![avsalmonn](2stage1.png) 32 | 33 | - 2-stage av-SALMONN without audio input + ASR transcription 34 | ![avsalmonn](2stage2.png) 35 | 36 | - 2-stage Video-LLaMA + ASR transcription 37 | ![avsalmonn](2stage_videollama.jpeg) 38 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class BaseDataset(Dataset): 16 | def __init__( 17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] 18 | ): 19 | """ 20 | vis_root (string): Root directory of images (e.g. 
coco/images/) 21 | ann_root (string): directory to store the annotation file 22 | """ 23 | self.vis_root = vis_root 24 | 25 | self.annotation = [] 26 | for ann_path in ann_paths: 27 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) 28 | 29 | self.vis_processor = vis_processor 30 | self.text_processor = text_processor 31 | 32 | self._add_instance_ids() 33 | 34 | def __len__(self): 35 | return len(self.annotation) 36 | 37 | def collater(self, samples): 38 | return default_collate(samples) 39 | 40 | def set_processors(self, vis_processor, text_processor): 41 | self.vis_processor = vis_processor 42 | self.text_processor = text_processor 43 | 44 | def _add_instance_ids(self, key="instance_id"): 45 | for idx, ann in enumerate(self.annotation): 46 | ann[key] = str(idx) 47 | 48 | 49 | class ConcatDataset(ConcatDataset): 50 | def __init__(self, datasets: Iterable[Dataset]) -> None: 51 | super().__init__(datasets) 52 | 53 | def collater(self, samples): 54 | # TODO For now only supports datasets with same underlying collater implementations 55 | 56 | all_keys = set() 57 | for s in samples: 58 | all_keys.update(s) 59 | 60 | shared_keys = all_keys 61 | for s in samples: 62 | shared_keys = shared_keys & set(s.keys()) 63 | 64 | samples_shared_keys = [] 65 | for s in samples: 66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 67 | 68 | return self.datasets[0].collater(samples_shared_keys) 69 | -------------------------------------------------------------------------------- /video_llama/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from video_llama.datasets.builders.base_dataset_builder import load_dataset_config 9 | from video_llama.datasets.builders.image_text_pair_builder import ( 10 | CCSBUBuilder, 11 | LaionBuilder, 12 | CCSBUAlignBuilder 13 | ) 14 | from video_llama.datasets.builders.video_caption_builder import WebvidBuilder 15 | from video_llama.common.registry import registry 16 | from video_llama.datasets.builders.instruct_builder import WebvidInstruct_Builder,LlavaInstruct_Builder 17 | __all__ = [ 18 | "CCSBUBuilder", 19 | "LaionBuilder", 20 | "CCSBUAlignBuilder", 21 | "WebvidBuilder", 22 | "LlavaInstruct_Builder", 23 | "WebvidInstruct_Builder" 24 | 25 | ] 26 | 27 | 28 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 29 | """ 30 | Example 31 | 32 | >>> dataset = load_dataset("coco_caption", cfg=None) 33 | >>> splits = dataset.keys() 34 | >>> print([len(dataset[split]) for split in splits]) 35 | 36 | """ 37 | if cfg_path is None: 38 | cfg = None 39 | else: 40 | cfg = load_dataset_config(cfg_path) 41 | 42 | try: 43 | builder = registry.get_builder_class(name)(cfg) 44 | except TypeError: 45 | print( 46 | f"Dataset {name} not found. Available datasets:\n" 47 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 48 | ) 49 | exit(1) 50 | 51 | if vis_path is not None: 52 | if data_type is None: 53 | # use default data type in the config 54 | data_type = builder.config.data_type 55 | 56 | assert ( 57 | data_type in builder.config.build_info 58 | ), f"Invalid data_type {data_type} for {name}." 
59 | 60 | builder.config.build_info.get(data_type).storage = vis_path 61 | 62 | dataset = builder.build_datasets() 63 | return dataset 64 | 65 | 66 | class DatasetZoo: 67 | def __init__(self) -> None: 68 | self.dataset_zoo = { 69 | k: list(v.DATASET_CONFIG_DICT.keys()) 70 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 71 | } 72 | 73 | def get_names(self): 74 | return list(self.dataset_zoo.keys()) 75 | 76 | 77 | dataset_zoo = DatasetZoo() 78 | -------------------------------------------------------------------------------- /video_llama/datasets/builders/instruct_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from video_llama.common.registry import registry 6 | from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from video_llama.datasets.datasets.laion_dataset import LaionDataset 8 | from video_llama.datasets.datasets.llava_instruct_dataset import Instruct_Dataset 9 | from video_llama.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset 10 | 11 | @registry.register_builder("instruct") 12 | class Instruct_Builder(BaseDatasetBuilder): 13 | train_dataset_cls = Instruct_Dataset 14 | 15 | DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"} 16 | 17 | def _download_ann(self): 18 | pass 19 | 20 | def _download_vis(self): 21 | pass 22 | 23 | def build(self): 24 | self.build_processors() 25 | datasets = dict() 26 | split = "train" 27 | 28 | build_info = self.config.build_info 29 | dataset_cls = self.train_dataset_cls 30 | if self.config.num_video_query_token: 31 | num_video_query_token = self.config.num_video_query_token 32 | else: 33 | num_video_query_token = 32 34 | 35 | if self.config.tokenizer_name: 36 | tokenizer_name = self.config.tokenizer_name 37 | else: 38 | tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/' 39 | 40 | 41 | datasets[split] = dataset_cls( 42 | vis_processor=self.vis_processors[split], 43 | text_processor=self.text_processors[split], 44 | vis_root=build_info.videos_dir, 45 | ann_root=build_info.anno_dir, 46 | num_video_query_token = num_video_query_token, 47 | tokenizer_name = tokenizer_name, 48 | data_type = self.config.data_type, 49 | model_type = self.config.model_type 50 | ) 51 | 52 | return datasets 53 | 54 | @registry.register_builder("webvid_instruct") 55 | class WebvidInstruct_Builder(Instruct_Builder): 56 | train_dataset_cls = Video_Instruct_Dataset 57 | 58 | DATASET_CONFIG_DICT = { 59 | "default": "configs/datasets/instruct/webvid_instruct.yaml", 60 | } 61 | 62 | @registry.register_builder("webvid_instruct_zh") 63 | class WebvidInstruct_zh_Builder(Instruct_Builder): 64 | train_dataset_cls = Video_Instruct_Dataset 65 | 66 | DATASET_CONFIG_DICT = { 67 | "default": "configs/datasets/instruct/webvid_instruct.yaml", 68 | } 69 | 70 | 71 | 72 | @registry.register_builder("llava_instruct") 73 | class LlavaInstruct_Builder(Instruct_Builder): 74 | train_dataset_cls = Instruct_Dataset 75 | 76 | DATASET_CONFIG_DICT = { 77 | "default": "configs/datasets/instruct/llava_instruct.yaml", 78 | } 79 | 80 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from collections import OrderedDict 10 | 11 | from video_llama.datasets.datasets.base_dataset import BaseDataset 12 | from PIL import Image 13 | 14 | 15 | class __DisplMixin: 16 | def displ_item(self, index): 17 | sample, ann = self.__getitem__(index), self.annotation[index] 18 | 19 | return OrderedDict( 20 | { 21 | "file": ann["image"], 22 | "caption": ann["caption"], 23 | "image": sample["image"], 24 | } 25 | ) 26 | 27 | 28 | class CaptionDataset(BaseDataset, __DisplMixin): 29 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 30 | """ 31 | vis_root (string): Root directory of images (e.g. coco/images/) 32 | ann_root (string): directory to store the annotation file 33 | """ 34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 35 | 36 | self.img_ids = {} 37 | n = 0 38 | for ann in self.annotation: 39 | img_id = ann["image_id"] 40 | if img_id not in self.img_ids.keys(): 41 | self.img_ids[img_id] = n 42 | n += 1 43 | 44 | def __getitem__(self, index): 45 | 46 | # TODO this assumes image input, not general enough 47 | ann = self.annotation[index] 48 | 49 | img_file = '{:0>12}.jpg'.format(ann["image_id"]) 50 | image_path = os.path.join(self.vis_root, img_file) 51 | image = Image.open(image_path).convert("RGB") 52 | 53 | image = self.vis_processor(image) 54 | caption = self.text_processor(ann["caption"]) 55 | 56 | return { 57 | "image": image, 58 | "text_input": caption, 59 | "image_id": self.img_ids[ann["image_id"]], 60 | } 61 | 62 | 63 | class CaptionEvalDataset(BaseDataset, __DisplMixin): 64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 65 | """ 66 | vis_root (string): Root directory of images (e.g. 
coco/images/) 67 | ann_root (string): directory to store the annotation file 68 | split (string): val or test 69 | """ 70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.vis_root, ann["image"]) 77 | image = Image.open(image_path).convert("RGB") 78 | 79 | image = self.vis_processor(image) 80 | 81 | return { 82 | "image": image, 83 | "image_id": ann["image_id"], 84 | "instance_id": ann["instance_id"], 85 | } 86 | -------------------------------------------------------------------------------- /video_llama/datasets/builders/image_text_pair_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from video_llama.common.registry import registry 6 | from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from video_llama.datasets.datasets.laion_dataset import LaionDataset 8 | from video_llama.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset 9 | 10 | 11 | @registry.register_builder("cc_sbu") 12 | class CCSBUBuilder(BaseDatasetBuilder): 13 | train_dataset_cls = CCSBUDataset 14 | 15 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"} 16 | 17 | def _download_ann(self): 18 | pass 19 | 20 | def _download_vis(self): 21 | pass 22 | 23 | def build(self): 24 | self.build_processors() 25 | 26 | build_info = self.config.build_info 27 | 28 | datasets = dict() 29 | split = "train" 30 | 31 | # create datasets 32 | # [NOTE] return inner_datasets (wds.DataPipeline) 33 | dataset_cls = self.train_dataset_cls 34 | datasets[split] = dataset_cls( 35 | vis_processor=self.vis_processors[split], 36 | text_processor=self.text_processors[split], 37 | location=build_info.storage, 38 | ).inner_dataset 39 | 40 | return datasets 41 | 42 | 43 | @registry.register_builder("laion") 44 | class LaionBuilder(BaseDatasetBuilder): 45 | train_dataset_cls = LaionDataset 46 | 47 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} 48 | 49 | def _download_ann(self): 50 | pass 51 | 52 | def _download_vis(self): 53 | pass 54 | 55 | def build(self): 56 | self.build_processors() 57 | 58 | build_info = self.config.build_info 59 | 60 | datasets = dict() 61 | split = "train" 62 | 63 | # create datasets 64 | # [NOTE] return inner_datasets (wds.DataPipeline) 65 | dataset_cls = self.train_dataset_cls 66 | datasets[split] = dataset_cls( 67 | vis_processor=self.vis_processors[split], 68 | text_processor=self.text_processors[split], 69 | location=build_info.storage, 70 | ).inner_dataset 71 | 72 | return datasets 73 | 74 | 75 | @registry.register_builder("cc_sbu_align") 76 | class CCSBUAlignBuilder(BaseDatasetBuilder): 77 | train_dataset_cls = CCSBUAlignDataset 78 | 79 | DATASET_CONFIG_DICT = { 80 | "default": "configs/datasets/cc_sbu/align.yaml", 81 | } 82 | 83 | def build_datasets(self): 84 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
85 | logging.info("Building datasets...") 86 | self.build_processors() 87 | 88 | build_info = self.config.build_info 89 | storage_path = build_info.storage 90 | 91 | datasets = dict() 92 | 93 | if not os.path.exists(storage_path): 94 | warnings.warn("storage path {} does not exist.".format(storage_path)) 95 | 96 | # create datasets 97 | dataset_cls = self.train_dataset_cls 98 | datasets['train'] = dataset_cls( 99 | vis_processor=self.vis_processors["train"], 100 | text_processor=self.text_processors["train"], 101 | ann_paths=[os.path.join(storage_path, 'filter_cap.json')], 102 | vis_root=os.path.join(storage_path, 'image'), 103 | ) 104 | 105 | return datasets 106 | 107 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /video_llama/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from video_llama.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | **kwargs 25 | ): 26 | self.optimizer = optimizer 27 | 28 | self.max_epoch = max_epoch 29 | self.min_lr = min_lr 30 | 31 | self.decay_rate = decay_rate 32 | 33 | self.init_lr = init_lr 34 | self.warmup_steps = warmup_steps 35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 36 | 37 | def step(self, cur_epoch, cur_step): 38 | if cur_epoch == 0: 39 | warmup_lr_schedule( 40 | step=cur_step, 41 | optimizer=self.optimizer, 42 | max_step=self.warmup_steps, 43 | init_lr=self.warmup_start_lr, 44 | max_lr=self.init_lr, 45 | ) 46 | else: 47 | step_lr_schedule( 48 | epoch=cur_epoch, 49 | optimizer=self.optimizer, 50 | init_lr=self.init_lr, 51 | min_lr=self.min_lr, 52 | decay_rate=self.decay_rate, 53 | ) 54 | 55 | 56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 57 | class LinearWarmupCosineLRScheduler: 58 | def __init__( 59 | self, 60 | optimizer, 61 | max_epoch, 62 | iters_per_epoch, 63 | min_lr, 64 | init_lr, 65 | warmup_steps=0, 66 | warmup_start_lr=-1, 67 | **kwargs 68 | ): 69 | self.optimizer = optimizer 70 | 71 | self.max_epoch = max_epoch 72 | self.iters_per_epoch = iters_per_epoch 73 | self.min_lr = min_lr 74 | 75 | self.init_lr = init_lr 76 | self.warmup_steps = warmup_steps 77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 78 | 79 | def step(self, cur_epoch, cur_step): 80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 81 | if total_cur_step < 
self.warmup_steps: 82 | warmup_lr_schedule( 83 | step=cur_step, 84 | optimizer=self.optimizer, 85 | max_step=self.warmup_steps, 86 | init_lr=self.warmup_start_lr, 87 | max_lr=self.init_lr, 88 | ) 89 | else: 90 | cosine_lr_schedule( 91 | epoch=total_cur_step, 92 | optimizer=self.optimizer, 93 | max_epoch=self.max_epoch * self.iters_per_epoch, 94 | init_lr=self.init_lr, 95 | min_lr=self.min_lr, 96 | ) 97 | 98 | 99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 100 | """Decay the learning rate""" 101 | lr = (init_lr - min_lr) * 0.5 * ( 102 | 1.0 + math.cos(math.pi * epoch / max_epoch) 103 | ) + min_lr 104 | for param_group in optimizer.param_groups: 105 | param_group["lr"] = lr 106 | 107 | 108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 109 | """Warmup the learning rate""" 110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 111 | for param_group in optimizer.param_groups: 112 | param_group["lr"] = lr 113 | 114 | 115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 116 | """Decay the learning rate""" 117 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 118 | for param_group in optimizer.param_groups: 119 | param_group["lr"] = lr 120 | -------------------------------------------------------------------------------- /video_llama/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | torch.distributed.init_process_group( 81 | backend=args.dist_backend, 
82 | init_method=args.dist_url, 83 | world_size=args.world_size, 84 | rank=args.rank, 85 | timeout=datetime.timedelta( 86 | days=365 87 | ), # allow auto-downloading and de-compressing 88 | ) 89 | torch.distributed.barrier() 90 | setup_for_distributed(args.rank == 0) 91 | 92 | 93 | def get_dist_info(): 94 | if torch.__version__ < "1.0": 95 | initialized = dist._initialized 96 | else: 97 | initialized = dist.is_initialized() 98 | if initialized: 99 | rank = dist.get_rank() 100 | world_size = dist.get_world_size() 101 | else: # non-distributed training 102 | rank = 0 103 | world_size = 1 104 | return rank, world_size 105 | 106 | 107 | def main_process(func): 108 | @functools.wraps(func) 109 | def wrapper(*args, **kwargs): 110 | rank, _ = get_dist_info() 111 | if rank == 0: 112 | return func(*args, **kwargs) 113 | 114 | return wrapper 115 | 116 | 117 | def download_cached_file(url, check_hash=True, progress=False): 118 | """ 119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 121 | """ 122 | 123 | def get_cached_file_path(): 124 | # a hack to sync the file path across processes 125 | parts = torch.hub.urlparse(url) 126 | filename = os.path.basename(parts.path) 127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 128 | 129 | return cached_file 130 | 131 | if is_main_process(): 132 | timm_hub.download_cached_file(url, check_hash, progress) 133 | 134 | if is_dist_avail_and_initialized(): 135 | dist.barrier() 136 | 137 | return get_cached_file_path() 138 | -------------------------------------------------------------------------------- /video_llama/processors/functional_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import warnings 9 | 10 | import torch 11 | 12 | 13 | def _is_tensor_video_clip(clip): 14 | if not torch.is_tensor(clip): 15 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 16 | 17 | if not clip.ndimension() == 4: 18 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 19 | 20 | return True 21 | 22 | 23 | def crop(clip, i, j, h, w): 24 | """ 25 | Args: 26 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 27 | """ 28 | if len(clip.size()) != 4: 29 | raise ValueError("clip should be a 4D tensor") 30 | return clip[..., i : i + h, j : j + w] 31 | 32 | 33 | def resize(clip, target_size, interpolation_mode): 34 | if len(target_size) != 2: 35 | raise ValueError( 36 | f"target size should be tuple (height, width), instead got {target_size}" 37 | ) 38 | return torch.nn.functional.interpolate( 39 | clip, size=target_size, mode=interpolation_mode, align_corners=False 40 | ) 41 | 42 | 43 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 44 | """ 45 | Do spatial cropping and resizing to the video clip 46 | Args: 47 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 48 | i (int): i in (i,j) i.e coordinates of the upper left corner. 49 | j (int): j in (i,j) i.e coordinates of the upper left corner. 50 | h (int): Height of the cropped region. 51 | w (int): Width of the cropped region. 
52 | size (tuple(int, int)): height and width of resized clip 53 | Returns: 54 | clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W) 55 | """ 56 | if not _is_tensor_video_clip(clip): 57 | raise ValueError("clip should be a 4D torch.tensor") 58 | clip = crop(clip, i, j, h, w) 59 | clip = resize(clip, size, interpolation_mode) 60 | return clip 61 | 62 | 63 | def center_crop(clip, crop_size): 64 | if not _is_tensor_video_clip(clip): 65 | raise ValueError("clip should be a 4D torch.tensor") 66 | h, w = clip.size(-2), clip.size(-1) 67 | th, tw = crop_size 68 | if h < th or w < tw: 69 | raise ValueError("height and width must be no smaller than crop_size") 70 | 71 | i = int(round((h - th) / 2.0)) 72 | j = int(round((w - tw) / 2.0)) 73 | return crop(clip, i, j, th, tw) 74 | 75 | 76 | def to_tensor(clip): 77 | """ 78 | Convert tensor data type from uint8 to float, divide value by 255.0 and 79 | permute the dimensions of clip tensor 80 | Args: 81 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 82 | Return: 83 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 84 | """ 85 | _is_tensor_video_clip(clip) 86 | if not clip.dtype == torch.uint8: 87 | raise TypeError( 88 | "clip tensor should have data type uint8. Got %s" % str(clip.dtype) 89 | ) 90 | return clip.float().permute(3, 0, 1, 2) / 255.0 91 | 92 | 93 | def normalize(clip, mean, std, inplace=False): 94 | """ 95 | Args: 96 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) 97 | mean (tuple): pixel RGB mean. Size is (3) 98 | std (tuple): pixel standard deviation. Size is (3) 99 | Returns: 100 | normalized clip (torch.tensor): Size is (C, T, H, W) 101 | """ 102 | if not _is_tensor_video_clip(clip): 103 | raise ValueError("clip should be a 4D torch.tensor") 104 | if not inplace: 105 | clip = clip.clone() 106 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 107 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 108 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 109 | return clip 110 | 111 | 112 | def hflip(clip): 113 | """ 114 | Args: 115 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) 116 | Returns: 117 | flipped clip (torch.tensor): Size is (C, T, H, W) 118 | """ 119 | if not _is_tensor_video_clip(clip): 120 | raise ValueError("clip should be a 4D torch.tensor") 121 | return clip.flip(-1) 122 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/models/helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Portions Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | import einops 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class Normalize(nn.Module): 16 | def __init__(self, dim: int) -> None: 17 | super().__init__() 18 | self.dim = dim 19 | 20 | def forward(self, x): 21 | return torch.nn.functional.normalize(x, dim=self.dim, p=2) 22 | 23 | 24 | class LearnableLogitScaling(nn.Module): 25 | def __init__( 26 | self, 27 | logit_scale_init: float = 1 / 0.07, 28 | learnable: bool = True, 29 | max_logit_scale: float = 100, 30 | ) -> None: 31 | super().__init__() 32 | self.max_logit_scale = max_logit_scale 33 | self.logit_scale_init = logit_scale_init 34 | self.learnable = learnable 35 | log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) 36 | if learnable: 37 | self.log_logit_scale = nn.Parameter(log_logit_scale) 38 | else: 39 | self.register_buffer("log_logit_scale", log_logit_scale) 40 | 41 | def forward(self, x): 42 | return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x 43 | 44 | def extra_repr(self): 45 | st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}," \ 46 | f" max_logit_scale={self.max_logit_scale}" 47 | return st 48 | 49 | 50 | class EinOpsRearrange(nn.Module): 51 | def __init__(self, rearrange_expr: str, **kwargs) -> None: 52 | super().__init__() 53 | self.rearrange_expr = rearrange_expr 54 | self.kwargs = kwargs 55 | 56 | def forward(self, x): 57 | assert isinstance(x, torch.Tensor) 58 | return einops.rearrange(x, self.rearrange_expr, **self.kwargs) 59 | 60 | 61 | class VerboseNNModule(nn.Module): 62 | """ 63 | Wrapper around nn.Module that prints registered buffers and parameter names. 64 | """ 65 | 66 | @staticmethod 67 | def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: 68 | st = ( 69 | "(" 70 | + name 71 | + "): " 72 | + "tensor(" 73 | + str(tuple(tensor[1].shape)) 74 | + ", requires_grad=" 75 | + str(tensor[1].requires_grad) 76 | + ")\n" 77 | ) 78 | return st 79 | 80 | def extra_repr(self) -> str: 81 | named_modules = set() 82 | for p in self.named_modules(): 83 | named_modules.update([p[0]]) 84 | named_modules = list(named_modules) 85 | 86 | string_repr = "" 87 | for p in self.named_parameters(): 88 | name = p[0].split(".")[0] 89 | if name not in named_modules: 90 | string_repr += self.get_readable_tensor_repr(name, p) 91 | 92 | for p in self.named_buffers(): 93 | name = p[0].split(".")[0] 94 | string_repr += self.get_readable_tensor_repr(name, p) 95 | 96 | return string_repr 97 | 98 | 99 | def cast_if_src_dtype( 100 | tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype 101 | ): 102 | updated = False 103 | if tensor.dtype == src_dtype: 104 | tensor = tensor.to(dtype=tgt_dtype) 105 | updated = True 106 | return tensor, updated 107 | 108 | 109 | class QuickGELU(nn.Module): 110 | # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 111 | def forward(self, x: torch.Tensor): 112 | return x * torch.sigmoid(1.702 * x) 113 | 114 | 115 | class SelectElement(nn.Module): 116 | def __init__(self, index) -> None: 117 | super().__init__() 118 | self.index = index 119 | 120 | def forward(self, x): 121 | assert x.ndim >= 3 122 | return x[:, self.index, ...] 
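

# Illustrative sketch (not part of the upstream ImageBind code): the modules above are
# typically composed into a modality head that projects features, L2-normalizes them,
# and applies a learnable temperature before cosine-similarity logits are computed.
# The embedding width used here is a hypothetical example value.
def _example_projection_head(embed_dim: int = 1024) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(embed_dim, embed_dim, bias=False),  # hypothetical linear projection
        Normalize(dim=-1),                            # unit-norm feature vectors
        LearnableLogitScaling(learnable=True),        # learnable temperature, init 1/0.07
    )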
123 | 124 | 125 | class SelectEOSAndProject(nn.Module): 126 | """ 127 | Text Pooling used in OpenCLIP 128 | """ 129 | 130 | def __init__(self, proj: nn.Module) -> None: 131 | super().__init__() 132 | self.proj = proj 133 | 134 | def forward(self, x, seq_len): 135 | assert x.ndim == 3 136 | # x is of shape B x L x D 137 | # take features from the eot embedding (eot_token is the highest number in each sequence) 138 | x = x[torch.arange(x.shape[0]), seq_len] 139 | x = self.proj(x) 140 | return x 141 | -------------------------------------------------------------------------------- /video_llama/processors/blip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import re 9 | 10 | from video_llama.common.registry import registry 11 | from video_llama.processors.base_processor import BaseProcessor 12 | from video_llama.processors.randaugment import RandomAugment 13 | from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | class BlipImageBaseProcessor(BaseProcessor): 19 | def __init__(self, mean=None, std=None): 20 | if mean is None: 21 | mean = (0.48145466, 0.4578275, 0.40821073) 22 | if std is None: 23 | std = (0.26862954, 0.26130258, 0.27577711) 24 | 25 | self.normalize = transforms.Normalize(mean, std) 26 | 27 | 28 | @registry.register_processor("blip_caption") 29 | class BlipCaptionProcessor(BaseProcessor): 30 | def __init__(self, prompt="", max_words=50): 31 | self.prompt = prompt 32 | self.max_words = max_words 33 | 34 | def __call__(self, caption): 35 | caption = self.prompt + self.pre_caption(caption) 36 | 37 | return caption 38 | 39 | @classmethod 40 | def from_config(cls, cfg=None): 41 | if cfg is None: 42 | cfg = OmegaConf.create() 43 | 44 | prompt = cfg.get("prompt", "") 45 | max_words = cfg.get("max_words", 50) 46 | 47 | return cls(prompt=prompt, max_words=max_words) 48 | 49 | def pre_caption(self, caption): 50 | caption = re.sub( 51 | r"([.!\"()*#:;~])", 52 | " ", 53 | caption.lower(), 54 | ) 55 | caption = re.sub( 56 | r"\s{2,}", 57 | " ", 58 | caption, 59 | ) 60 | caption = caption.rstrip("\n") 61 | caption = caption.strip(" ") 62 | 63 | # truncate caption 64 | caption_words = caption.split(" ") 65 | if len(caption_words) > self.max_words: 66 | caption = " ".join(caption_words[: self.max_words]) 67 | 68 | return caption 69 | 70 | 71 | @registry.register_processor("blip2_image_train") 72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor): 73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 74 | super().__init__(mean=mean, std=std) 75 | 76 | self.transform = transforms.Compose( 77 | [ 78 | transforms.RandomResizedCrop( 79 | image_size, 80 | scale=(min_scale, max_scale), 81 | interpolation=InterpolationMode.BICUBIC, 82 | ), 83 | transforms.ToTensor(), 84 | self.normalize, 85 | ] 86 | ) 87 | 88 | def __call__(self, item): 89 | return self.transform(item) 90 | 91 | @classmethod 92 | def from_config(cls, cfg=None): 93 | if cfg is None: 94 | cfg = OmegaConf.create() 95 | 96 | image_size = cfg.get("image_size", 224) 97 | 98 | mean = cfg.get("mean", None) 99 | std = cfg.get("std", None) 100 | 101 | min_scale = cfg.get("min_scale", 0.5) 102 | max_scale 
= cfg.get("max_scale", 1.0) 103 | 104 | return cls( 105 | image_size=image_size, 106 | mean=mean, 107 | std=std, 108 | min_scale=min_scale, 109 | max_scale=max_scale, 110 | ) 111 | 112 | 113 | @registry.register_processor("blip2_image_eval") 114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor): 115 | def __init__(self, image_size=224, mean=None, std=None): 116 | super().__init__(mean=mean, std=std) 117 | 118 | self.transform = transforms.Compose( 119 | [ 120 | transforms.Resize( 121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 122 | ), 123 | transforms.ToTensor(), 124 | self.normalize, 125 | ] 126 | ) 127 | 128 | def __call__(self, item): 129 | return self.transform(item) 130 | 131 | @classmethod 132 | def from_config(cls, cfg=None): 133 | if cfg is None: 134 | cfg = OmegaConf.create() 135 | 136 | image_size = cfg.get("image_size", 224) 137 | 138 | mean = cfg.get("mean", None) 139 | std = cfg.get("std", None) 140 | 141 | return cls(image_size=image_size, mean=mean, std=std) 142 | 143 | -------------------------------------------------------------------------------- /video_llama/models/blip2_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from salesforce@LAVIS. Below is the original copyright: 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | from dataclasses import dataclass 10 | from typing import Optional 11 | 12 | import torch 13 | from transformers.modeling_outputs import ( 14 | ModelOutput, 15 | BaseModelOutputWithPoolingAndCrossAttentions, 16 | CausalLMOutputWithCrossAttentions, 17 | ) 18 | 19 | 20 | @dataclass 21 | class BlipSimilarity(ModelOutput): 22 | sim_i2t: torch.FloatTensor = None 23 | sim_t2i: torch.FloatTensor = None 24 | 25 | sim_i2t_m: Optional[torch.FloatTensor] = None 26 | sim_t2i_m: Optional[torch.FloatTensor] = None 27 | 28 | sim_i2t_targets: Optional[torch.FloatTensor] = None 29 | sim_t2i_targets: Optional[torch.FloatTensor] = None 30 | 31 | 32 | @dataclass 33 | class BlipIntermediateOutput(ModelOutput): 34 | """ 35 | Data class for intermediate outputs of BLIP models. 36 | 37 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). 38 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). 39 | 40 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). 41 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). 42 | 43 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. 44 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. 45 | 46 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. 47 | decoder_labels (torch.LongTensor): labels for the captioning loss. 48 | 49 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). 
50 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) 51 | 52 | """ 53 | 54 | # uni-modal features 55 | image_embeds: torch.FloatTensor = None 56 | text_embeds: Optional[torch.FloatTensor] = None 57 | 58 | image_embeds_m: Optional[torch.FloatTensor] = None 59 | text_embeds_m: Optional[torch.FloatTensor] = None 60 | 61 | # intermediate outputs of multimodal encoder 62 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 63 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 64 | 65 | itm_logits: Optional[torch.FloatTensor] = None 66 | itm_labels: Optional[torch.LongTensor] = None 67 | 68 | # intermediate outputs of multimodal decoder 69 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None 70 | decoder_labels: Optional[torch.LongTensor] = None 71 | 72 | 73 | @dataclass 74 | class BlipOutput(ModelOutput): 75 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. 76 | sims: Optional[BlipSimilarity] = None 77 | 78 | intermediate_output: BlipIntermediateOutput = None 79 | 80 | loss: Optional[torch.FloatTensor] = None 81 | 82 | loss_itc: Optional[torch.FloatTensor] = None 83 | 84 | loss_itm: Optional[torch.FloatTensor] = None 85 | 86 | loss_lm: Optional[torch.FloatTensor] = None 87 | 88 | 89 | @dataclass 90 | class BlipOutputFeatures(ModelOutput): 91 | """ 92 | Data class of features from BlipFeatureExtractor. 93 | 94 | Args: 95 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional 96 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional 97 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional 98 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional 99 | 100 | The first embedding or feature is for the [CLS] token. 101 | 102 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. 103 | """ 104 | 105 | image_embeds: Optional[torch.FloatTensor] = None 106 | image_embeds_proj: Optional[torch.FloatTensor] = None 107 | 108 | text_embeds: Optional[torch.FloatTensor] = None 109 | text_embeds_proj: Optional[torch.FloatTensor] = None 110 | 111 | multimodal_embeds: Optional[torch.FloatTensor] = None 112 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/model_card.md: -------------------------------------------------------------------------------- 1 | # Model Card for ImageBind 2 | 3 | Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images. 4 | Input any of the six modalities and get the same sized embedding that can be used for cross-modal and multimodal tasks. 5 | 6 | # Model Details 7 | 8 | ## Model Description 9 | 10 | 11 | Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images 12 | 13 | - **Developed by:** Meta AI 14 | - **Model type:** Multimodal model 15 | - **Language(s) (NLP):** en 16 | - **License:** CC BY-NC-SA 4.0 17 | - **Resources for more information:** 18 | - [GitHub Repo](https://github.com/facebookresearch/ImageBind) 19 | 20 | 21 | # Uses 22 | 23 | 24 | This model is intended only for research purposes. It provides a joint embedding space for different modalities -- image/video, text, audio, depth, IMU and thermal images. 
25 | We hope that these joint embeddings can be used for a variety of different cross-modal research, e.g., cross-modal retrieval and combining embeddings from different modalities. 26 | 27 | ## Out-of-Scope Use 28 | 29 | 30 | 31 | 32 | This model is *NOT* intended to be used in any real world application -- commercial or otherwise. 33 | It may produce harmful associations with different inputs. 34 | The model needs to be investigated and likely re-trained on specific data for any such application. 35 | The model is expected to work better on web-based visual data since it was trained on such data. 36 | The text encoder is likely to work only on English language text because of the underlying training datasets. 37 | 38 | # Bias, Risks, and Limitations 39 | 40 | 41 | Open-domain joint embedding models are prone to producing specific biases, e.g., study from [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md#bias-and-fairness). 42 | Since our model uses such models as initialization, it will exhibit such biases too. 43 | Moreover, for learning joint embeddings for other modalities such as audio, thermal, depth, and IMU we leverage datasets that are relatively small. These joint embeddings are thus limited to the concepts present in the datasets. For example, the thermal datasets we used are limited to outdoor street scenes, while the depth datasets are limited to indoor scenes. 44 | 45 | 46 | 47 | # Training Details 48 | 49 | ## Training Data 50 | 51 | 52 | 53 | ImageBind uses image-paired data for training -- (image, X) where X is one of text, audio, depth, IMU or thermal data. 54 | In particular, we initialize and freeze the image and text encoders using an OpenCLIP ViT-H encoder. 55 | We train audio embeddings using Audioset, depth embeddings using the SUN RGB-D dataset, IMU using the Ego4D dataset and thermal embeddings using the LLVIP dataset. 56 | We provide the exact training data details in the paper. 57 | 58 | 59 | ## Training Procedure 60 | 61 | 62 | Please refer to the research paper and github repo for exact details on this. 63 | 64 | # Evaluation 65 | 66 | ## Testing Data, Factors & Metrics 67 | 68 | We evaluate the model on a variety of different classification benchmarks for each modality. 69 | The evaluation details are presented in the paper. 70 | The models performance is measured using standard classification metrics such as accuracy and mAP. 71 | 72 | # Citation 73 | 74 | 75 | 76 | **BibTeX:** 77 | ``` 78 | @inproceedings{girdhar2023imagebind, 79 | title={ImageBind: One Embedding Space To Bind Them All}, 80 | author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang 81 | and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan}, 82 | booktitle={CVPR}, 83 | year={2023} 84 | } 85 | ``` 86 | 87 | 88 | # Model Card Contact 89 | 90 | Please reach out to the authors at: rgirdhar@meta.com imisra@meta.com alaaelnouby@gmail.com 91 | 92 | # How to Get Started with the Model 93 | 94 | Our github repo provides a simple example to extract embeddings from images, audio etc. 95 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/webvid_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from video_llama.datasets.datasets.base_dataset import BaseDataset 10 | from video_llama.datasets.datasets.caption_datasets import CaptionDataset 11 | import pandas as pd 12 | import decord 13 | from decord import VideoReader 14 | import random 15 | import torch 16 | from torch.utils.data.dataloader import default_collate 17 | class WebvidDataset(BaseDataset): 18 | def __init__(self, vis_processor, text_processor, vis_root, ann_root): 19 | """ 20 | vis_root (string): Root directory of video (e.g. webvid_eval/video/) 21 | ann_root (string): Root directory of video (e.g. webvid_eval/annotations/) 22 | split (string): val or test 23 | """ 24 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 25 | 26 | 27 | # 读取一个路径下所有的 28 | 29 | ts_df = [] 30 | for file_name in os.listdir(ann_root): 31 | if file_name.endswith('.csv'): 32 | df = pd.read_csv(os.path.join(ann_root, file_name)) 33 | ts_df.append(df) 34 | 35 | merged_df = pd.concat(ts_df) 36 | self.annotation = merged_df 37 | self.vis_root = vis_root 38 | self.resize_size = 224 39 | self.num_frm = 8 40 | self.frm_sampling_strategy = 'headtail' 41 | 42 | def _get_video_path(self, sample): 43 | rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 44 | full_video_fp = os.path.join(self.vis_root, rel_video_fp) 45 | return full_video_fp 46 | 47 | def __getitem__(self, index): 48 | num_retries = 10 # skip error videos 49 | for _ in range(num_retries): 50 | sample = self.annotation.iloc[index] 51 | sample_dict = sample.to_dict() 52 | video_id = sample_dict['videoid'] 53 | 54 | if 'name' in sample_dict.keys(): 55 | text = sample_dict['name'].strip() 56 | else: 57 | raise NotImplementedError("Un-supported text annotation format.") 58 | 59 | # fetch video 60 | video_path = self._get_video_path(sample_dict) 61 | # if os.path.exists(video_path): 62 | try: 63 | video = self.vis_processor(video_path) 64 | except: 65 | print(f"Failed to load examples with video: {video_path}. " 66 | f"Will randomly sample an example as a replacement.") 67 | index = random.randint(0, len(self) - 1) 68 | continue 69 | caption = self.text_processor(text) 70 | 71 | # print(video.size()) 72 | if video is None or caption is None \ 73 | or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]): 74 | print(f"Failed to load examples with video: {video_path}. " 75 | f"Will randomly sample an example as a replacement.") 76 | index = random.randint(0, len(self) - 1) 77 | continue 78 | else: 79 | break 80 | else: 81 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.") 82 | # "image_id" is kept to stay compatible with the COCO evaluation format 83 | return { 84 | "image": video, 85 | "text_input": caption, 86 | "type":'video', 87 | } 88 | 89 | def __len__(self): 90 | return len(self.annotation) 91 | 92 | # def collater(self, samples): 93 | # new_result = {} 94 | # new_result['image'] = default_collate( [sample["image"] for sample in samples]) 95 | # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples]) 96 | # return new_result 97 | 98 | class WebvidDatasetEvalDataset(BaseDataset): 99 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 100 | """ 101 | vis_root (string): Root directory of images (e.g. 
coco/images/) 102 | ann_root (string): directory to store the annotation file 103 | split (string): val or test 104 | """ 105 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 106 | 107 | def __getitem__(self, index): 108 | 109 | ann = self.annotation[index] 110 | 111 | vname = ann["video"] 112 | video_path = os.path.join(self.vis_root, vname) 113 | 114 | video = self.vis_processor(video_path) 115 | 116 | return { 117 | "video": video, 118 | "image_id": ann["image_id"], 119 | "instance_id": ann["instance_id"], 120 | } 121 | 122 | 123 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/README.md: -------------------------------------------------------------------------------- 1 | # ImageBind: One Embedding Space To Bind Them All 2 | 3 | **[FAIR, Meta AI](https://ai.facebook.com/research/)** 4 | 5 | Rohit Girdhar*, 6 | Alaaeldin El-Nouby*, 7 | Zhuang Liu, 8 | Mannat Singh, 9 | Kalyan Vasudev Alwala, 10 | Armand Joulin, 11 | Ishan Misra* 12 | 13 | To appear at CVPR 2023 (*Highlighted paper*) 14 | 15 | [[`Paper`](https://facebookresearch.github.io/ImageBind/paper)] [[`Blog`](https://ai.facebook.com/blog/imagebind-six-modalities-binding-ai/)] [[`Demo`](https://imagebind.metademolab.com/)] [[`Supplementary Video`](https://dl.fbaipublicfiles.com/imagebind/imagebind_video.mp4)] [[`BibTex`](#citing-imagebind)] 16 | 17 | PyTorch implementation and pretrained models for ImageBind. For details, see the paper: **[ImageBind: One Embedding Space To Bind Them All](https://facebookresearch.github.io/ImageBind/paper)**. 18 | 19 | ImageBind learns a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. It enables novel emergent applications ‘out-of-the-box’ including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation. 20 | 21 | 22 | 23 | ![ImageBind](https://user-images.githubusercontent.com/8495451/236859695-ffa13364-3e39-4d99-a8da-fbfab17f9a6b.gif) 24 | 25 | ## ImageBind model 26 | 27 | Emergent zero-shot classification performance. 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 |
| Model          | IN1k | K400 | NYU-D | ESC  | LLVIP | Ego4D | download   |
|----------------|------|------|-------|------|-------|-------|------------|
| imagebind_huge | 77.7 | 50.0 | 54.0  | 66.9 | 63.4  | 25.0  | checkpoint |
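
These zero-shot results follow directly from the joint embedding space: candidate class names are embedded with the text encoder, the test sample with its own modality encoder, and the prediction is the label with the highest similarity. A minimal sketch of that recipe (the label set and image path are placeholders; it reuses the `data` and `imagebind_model` helpers from the Usage section below):

```python
import torch

import data
from models import imagebind_model
from models.imagebind_model import ModalityType

labels = ["a dog", "a car", "a bird"]        # placeholder label set
image_paths = [".assets/dog_image.jpg"]      # placeholder sample to classify

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)

with torch.no_grad():
    embeddings = model({
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    })

# Cosine similarities between the image and each label, turned into probabilities.
probs = torch.softmax(
    embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
)
print(labels[probs.argmax(dim=-1).item()])
```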
52 | 53 | ## Usage 54 | 55 | Install pytorch 1.13+ and other 3rd party dependencies. 56 | 57 | ```shell 58 | conda create --name imagebind python=3.8 -y 59 | conda activate imagebind 60 | 61 | pip install -r requirements.txt 62 | ``` 63 | 64 | For windows users, you might need to install `soundfile` for reading/writing audio files. (Thanks @congyue1977) 65 | 66 | ``` 67 | pip install soundfile 68 | ``` 69 | 70 | 71 | Extract and compare features across modalities (e.g. Image, Text and Audio). 72 | 73 | ```python 74 | import data 75 | import torch 76 | from models import imagebind_model 77 | from models.imagebind_model import ModalityType 78 | 79 | text_list=["A dog.", "A car", "A bird"] 80 | image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"] 81 | audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"] 82 | 83 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 84 | 85 | # Instantiate model 86 | model = imagebind_model.imagebind_huge(pretrained=True) 87 | model.eval() 88 | model.to(device) 89 | 90 | # Load data 91 | inputs = { 92 | ModalityType.TEXT: data.load_and_transform_text(text_list, device), 93 | ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), 94 | ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), 95 | } 96 | 97 | with torch.no_grad(): 98 | embeddings = model(inputs) 99 | 100 | print( 101 | "Vision x Text: ", 102 | torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1), 103 | ) 104 | print( 105 | "Audio x Text: ", 106 | torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1), 107 | ) 108 | print( 109 | "Vision x Audio: ", 110 | torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1), 111 | ) 112 | 113 | # Expected output: 114 | # 115 | # Vision x Text: 116 | # tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05], 117 | # [3.3836e-05, 9.9994e-01, 2.4118e-05], 118 | # [4.7997e-05, 1.3496e-02, 9.8646e-01]]) 119 | # 120 | # Audio x Text: 121 | # tensor([[1., 0., 0.], 122 | # [0., 1., 0.], 123 | # [0., 0., 1.]]) 124 | # 125 | # Vision x Audio: 126 | # tensor([[0.8070, 0.1088, 0.0842], 127 | # [0.1036, 0.7884, 0.1079], 128 | # [0.0018, 0.0022, 0.9960]]) 129 | 130 | ``` 131 | 132 | ## Model card 133 | Please see the [model card](model_card.md) for details. 134 | 135 | ## License 136 | 137 | ImageBind code and model weights are released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details. 138 | 139 | ## Contributing 140 | 141 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 142 | 143 | ## Citing ImageBind 144 | 145 | If you find this repository useful, please consider giving a star :star: and citation 146 | 147 | ``` 148 | @inproceedings{girdhar2023imagebind, 149 | title={ImageBind: One Embedding Space To Bind Them All}, 150 | author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang 151 | and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan}, 152 | booktitle={CVPR}, 153 | year={2023} 154 | } 155 | ``` 156 | -------------------------------------------------------------------------------- /video_llama/processors/transforms_video.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 
5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | 10 | import numbers 11 | import random 12 | 13 | from torchvision.transforms import ( 14 | RandomCrop, 15 | RandomResizedCrop, 16 | ) 17 | 18 | import video_llama.processors.functional_video as F 19 | 20 | 21 | __all__ = [ 22 | "RandomCropVideo", 23 | "RandomResizedCropVideo", 24 | "CenterCropVideo", 25 | "NormalizeVideo", 26 | "ToTensorVideo", 27 | "RandomHorizontalFlipVideo", 28 | ] 29 | 30 | 31 | class RandomCropVideo(RandomCrop): 32 | def __init__(self, size): 33 | if isinstance(size, numbers.Number): 34 | self.size = (int(size), int(size)) 35 | else: 36 | self.size = size 37 | 38 | def __call__(self, clip): 39 | """ 40 | Args: 41 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 42 | Returns: 43 | torch.tensor: randomly cropped/resized video clip. 44 | size is (C, T, OH, OW) 45 | """ 46 | i, j, h, w = self.get_params(clip, self.size) 47 | return F.crop(clip, i, j, h, w) 48 | 49 | def __repr__(self) -> str: 50 | return f"{self.__class__.__name__}(size={self.size})" 51 | 52 | 53 | class RandomResizedCropVideo(RandomResizedCrop): 54 | def __init__( 55 | self, 56 | size, 57 | scale=(0.08, 1.0), 58 | ratio=(3.0 / 4.0, 4.0 / 3.0), 59 | interpolation_mode="bilinear", 60 | ): 61 | if isinstance(size, tuple): 62 | if len(size) != 2: 63 | raise ValueError( 64 | f"size should be tuple (height, width), instead got {size}" 65 | ) 66 | self.size = size 67 | else: 68 | self.size = (size, size) 69 | 70 | self.interpolation_mode = interpolation_mode 71 | self.scale = scale 72 | self.ratio = ratio 73 | 74 | def __call__(self, clip): 75 | """ 76 | Args: 77 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 78 | Returns: 79 | torch.tensor: randomly cropped/resized video clip. 80 | size is (C, T, H, W) 81 | """ 82 | i, j, h, w = self.get_params(clip, self.scale, self.ratio) 83 | return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) 84 | 85 | def __repr__(self) -> str: 86 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})" 87 | 88 | 89 | class CenterCropVideo: 90 | def __init__(self, crop_size): 91 | if isinstance(crop_size, numbers.Number): 92 | self.crop_size = (int(crop_size), int(crop_size)) 93 | else: 94 | self.crop_size = crop_size 95 | 96 | def __call__(self, clip): 97 | """ 98 | Args: 99 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 100 | Returns: 101 | torch.tensor: central cropping of video clip. Size is 102 | (C, T, crop_size, crop_size) 103 | """ 104 | return F.center_crop(clip, self.crop_size) 105 | 106 | def __repr__(self) -> str: 107 | return f"{self.__class__.__name__}(crop_size={self.crop_size})" 108 | 109 | 110 | class NormalizeVideo: 111 | """ 112 | Normalize the video clip by mean subtraction and division by standard deviation 113 | Args: 114 | mean (3-tuple): pixel RGB mean 115 | std (3-tuple): pixel RGB standard deviation 116 | inplace (boolean): whether do in-place normalization 117 | """ 118 | 119 | def __init__(self, mean, std, inplace=False): 120 | self.mean = mean 121 | self.std = std 122 | self.inplace = inplace 123 | 124 | def __call__(self, clip): 125 | """ 126 | Args: 127 | clip (torch.tensor): video clip to be normalized. 
Size is (C, T, H, W) 128 | """ 129 | return F.normalize(clip, self.mean, self.std, self.inplace) 130 | 131 | def __repr__(self) -> str: 132 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" 133 | 134 | 135 | class ToTensorVideo: 136 | """ 137 | Convert tensor data type from uint8 to float, divide value by 255.0 and 138 | permute the dimensions of clip tensor 139 | """ 140 | 141 | def __init__(self): 142 | pass 143 | 144 | def __call__(self, clip): 145 | """ 146 | Args: 147 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 148 | Return: 149 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 150 | """ 151 | return F.to_tensor(clip) 152 | 153 | def __repr__(self) -> str: 154 | return self.__class__.__name__ 155 | 156 | 157 | class RandomHorizontalFlipVideo: 158 | """ 159 | Flip the video clip along the horizonal direction with a given probability 160 | Args: 161 | p (float): probability of the clip being flipped. Default value is 0.5 162 | """ 163 | 164 | def __init__(self, p=0.5): 165 | self.p = p 166 | 167 | def __call__(self, clip): 168 | """ 169 | Args: 170 | clip (torch.tensor): Size is (C, T, H, W) 171 | Return: 172 | clip (torch.tensor): Size is (C, T, H, W) 173 | """ 174 | if random.random() < self.p: 175 | clip = F.hflip(clip) 176 | return clip 177 | 178 | def __repr__(self) -> str: 179 | return f"{self.__class__.__name__}(p={self.p})" 180 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/dataloader_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import time 9 | import random 10 | import torch 11 | from video_llama.datasets.data_utils import move_to_cuda 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | class MultiIterLoader: 16 | """ 17 | A simple wrapper for iterating over multiple iterators. 18 | 19 | Args: 20 | loaders (List[Loader]): List of Iterator loaders. 21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. 22 | """ 23 | 24 | def __init__(self, loaders, ratios=None): 25 | # assert all loaders has __next__ method 26 | for loader in loaders: 27 | assert hasattr( 28 | loader, "__next__" 29 | ), "Loader {} has no __next__ method.".format(loader) 30 | 31 | if ratios is None: 32 | ratios = [1.0] * len(loaders) 33 | else: 34 | assert len(ratios) == len(loaders) 35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios] 36 | 37 | self.loaders = loaders 38 | self.ratios = ratios 39 | 40 | def __next__(self): 41 | # random sample from each loader by ratio 42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] 43 | return next(self.loaders[loader_idx]) 44 | 45 | 46 | class PrefetchLoader(object): 47 | """ 48 | Modified from https://github.com/ChenRocks/UNITER. 
49 | 50 | overlap compute and cuda data transfer 51 | (copied and then modified from nvidia apex) 52 | """ 53 | 54 | def __init__(self, loader): 55 | self.loader = loader 56 | self.stream = torch.cuda.Stream() 57 | 58 | def __iter__(self): 59 | loader_it = iter(self.loader) 60 | self.preload(loader_it) 61 | batch = self.next(loader_it) 62 | while batch is not None: 63 | is_tuple = isinstance(batch, tuple) 64 | if is_tuple: 65 | task, batch = batch 66 | 67 | if is_tuple: 68 | yield task, batch 69 | else: 70 | yield batch 71 | batch = self.next(loader_it) 72 | 73 | def __len__(self): 74 | return len(self.loader) 75 | 76 | def preload(self, it): 77 | try: 78 | self.batch = next(it) 79 | except StopIteration: 80 | self.batch = None 81 | return 82 | # if record_stream() doesn't work, another option is to make sure 83 | # device inputs are created on the main stream. 84 | # self.next_input_gpu = torch.empty_like(self.next_input, 85 | # device='cuda') 86 | # self.next_target_gpu = torch.empty_like(self.next_target, 87 | # device='cuda') 88 | # Need to make sure the memory allocated for next_* is not still in use 89 | # by the main stream at the time we start copying to next_*: 90 | # self.stream.wait_stream(torch.cuda.current_stream()) 91 | with torch.cuda.stream(self.stream): 92 | self.batch = move_to_cuda(self.batch) 93 | # more code for the alternative if record_stream() doesn't work: 94 | # copy_ will record the use of the pinned source tensor in this 95 | # side stream. 96 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 97 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 98 | # self.next_input = self.next_input_gpu 99 | # self.next_target = self.next_target_gpu 100 | 101 | def next(self, it): 102 | torch.cuda.current_stream().wait_stream(self.stream) 103 | batch = self.batch 104 | if batch is not None: 105 | record_cuda_stream(batch) 106 | self.preload(it) 107 | return batch 108 | 109 | def __getattr__(self, name): 110 | method = self.loader.__getattribute__(name) 111 | return method 112 | 113 | 114 | def record_cuda_stream(batch): 115 | if isinstance(batch, torch.Tensor): 116 | batch.record_stream(torch.cuda.current_stream()) 117 | elif isinstance(batch, list) or isinstance(batch, tuple): 118 | for t in batch: 119 | record_cuda_stream(t) 120 | elif isinstance(batch, dict): 121 | for t in batch.values(): 122 | record_cuda_stream(t) 123 | else: 124 | pass 125 | 126 | 127 | class IterLoader: 128 | """ 129 | A wrapper to convert DataLoader as an infinite iterator. 
130 | 131 | Modified from: 132 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py 133 | """ 134 | 135 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False): 136 | self._dataloader = dataloader 137 | self.iter_loader = iter(self._dataloader) 138 | self._use_distributed = use_distributed 139 | self._epoch = 0 140 | 141 | @property 142 | def epoch(self) -> int: 143 | return self._epoch 144 | 145 | def __next__(self): 146 | try: 147 | data = next(self.iter_loader) 148 | except StopIteration: 149 | self._epoch += 1 150 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: 151 | self._dataloader.sampler.set_epoch(self._epoch) 152 | time.sleep(2) # Prevent possible deadlock during epoch transition 153 | self.iter_loader = iter(self._dataloader) 154 | data = next(self.iter_loader) 155 | 156 | return data 157 | 158 | def __iter__(self): 159 | return self 160 | 161 | def __len__(self): 162 | return len(self._dataloader) 163 | -------------------------------------------------------------------------------- /infer_batch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py 3 | """ 4 | import argparse 5 | import time 6 | import os 7 | import json 8 | 9 | from video_llama.common.config import Config 10 | from video_llama.common.registry import registry 11 | from video_llama.conversation.conversation_video import Chat, default_conversation, conv_llava_llama_2 12 | import decord 13 | decord.bridge.set_bridge('torch') 14 | from tqdm import tqdm 15 | 16 | from video_llama.datasets.builders import * 17 | from video_llama.models import * 18 | from video_llama.processors import * 19 | from video_llama.runners import * 20 | from video_llama.tasks import * 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Demo") 25 | parser.add_argument("--cfg-path", default='eval_configs/video_llama_eval_withaudio.yaml', help="path to configuration file.") 26 | parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") 27 | parser.add_argument("--model_type", type=str, default='vicuna', help="The type of LLM") 28 | parser.add_argument( 29 | "--options", 30 | nargs="+", 31 | help="override some settings in the used config, the key-value pair " 32 | "in xxx=yyy format will be merged into config file (deprecate), " 33 | "change to --cfg-options instead.", 34 | ) 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | class ChatBot: 40 | 41 | def __init__(self, args): 42 | self.chat = self._init_model(args) 43 | if args.model_type == 'vicuna': 44 | self.chat_state = default_conversation.copy() 45 | else: 46 | self.chat_state = conv_llava_llama_2.copy() 47 | self.img_list = list() 48 | self.set_para() 49 | 50 | def _init_model(self, args): 51 | print('Initializing Chat') 52 | cfg = Config(args) 53 | model_config = cfg.model_cfg 54 | model_config.device_8bit = args.gpu_id 55 | model_cls = registry.get_model_class(model_config.arch) 56 | model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) 57 | model.eval() 58 | vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train 59 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) 60 | chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) 61 | print('Initialization Finished') 62 | return chat 63 | 64 | def set_para(self, 
num_beams=1, temperature=1.0): 65 | self.num_beams = num_beams 66 | self.temperature = temperature 67 | print('set num_beams: {}'.format(num_beams)) 68 | print('set temperature: {}'.format(temperature)) 69 | 70 | def upload(self, up_img=False, up_video=False, audio_flag=False): 71 | if up_img and not up_video: 72 | self.chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail." 73 | self.chat.upload_img(up_img, self.chat_state, self.img_list) 74 | elif not up_img and up_video: 75 | self.chat_state.system = "" 76 | if audio_flag: 77 | self.chat.upload_video(up_video, self.chat_state, self.img_list) 78 | else: 79 | self.chat.upload_video_without_audio(up_video, self.chat_state, self.img_list) 80 | 81 | def ask_answer(self, user_message): 82 | self.chat.ask(user_message, self.chat_state) 83 | llm_message = self.chat.answer(conv=self.chat_state, 84 | img_list=self.img_list, 85 | num_beams=self.num_beams, 86 | temperature=self.temperature, 87 | max_new_tokens=256, 88 | max_length=2000)[0] 89 | 90 | return llm_message 91 | 92 | def reset(self): 93 | if self.chat_state is not None: 94 | self.chat_state.messages = list() 95 | if self.img_list is not None: 96 | self.img_list = list() 97 | self.set_para() 98 | 99 | 100 | if __name__ == "__main__": 101 | 102 | args = parse_args() 103 | chatbot = ChatBot(args) 104 | # file_path = "/home/gs534/rds/rds-t2-cs164-KQ4S3rlDzm8/gs534/opensource/favor/data/VALOR32k/valor32ktest.json" 105 | file_path = "/home/gs534/rds/rds-t2-cs164-KQ4S3rlDzm8/gs534/opensource/favor/data/audio/librispeech_test_clean.json" 106 | # file_path = "/home/gs534/rds/rds-t2-cs164-KQ4S3rlDzm8/gs534/opensource/favor/data/audio/audiocaps_test.json" 107 | 108 | # while True: 109 | # try: 110 | # file_path = input('Input file path: ') 111 | # except: 112 | # print('Input error, try again.') 113 | # continue 114 | # else: 115 | # if file_path == 'exit': 116 | # print('Goodbye!') 117 | # break 118 | # if not os.path.exists(file_path): 119 | # print('{} not exist, try again.'.format(file_path)) 120 | # continue 121 | 122 | out_path = "output/librispeech_finetuned.json" # input('Output file path: ') 123 | 124 | num_beams = 1 # int(input('Input new num_beams:(1-10) ')) 125 | temperature = 1.0 # float(input('Input new temperature:(0.1-2.0) ')) 126 | chatbot.set_para(num_beams=num_beams, temperature=temperature) 127 | 128 | with open(file_path, "r") as f: 129 | data = json.load(f) 130 | 131 | res = [] 132 | for item in tqdm(data): 133 | chatbot.set_para(num_beams=num_beams, temperature=temperature) 134 | 135 | path = item["image_name"] 136 | user_message = item["conversation"][0]["value"] 137 | # user_message += "\n###Assistant:" 138 | try: 139 | chatbot.upload(up_video=path, audio_flag=True) 140 | llm_message = chatbot.ask_answer(user_message=user_message) 141 | llm_message = llm_message.split("\n###")[0].replace("###Assistant:", "").replace("### Assistant:", "") 142 | print(llm_message) 143 | res.append( 144 | { 145 | "answer": item["conversation"][1]["value"], 146 | "gen_answer": llm_message 147 | } 148 | ) 149 | 150 | chatbot.reset() 151 | except: 152 | time.sleep(1) 153 | pass 154 | 155 | with open(out_path, "w") as f: 156 | json.dump(res, f, indent=4, ensure_ascii=False) 157 | -------------------------------------------------------------------------------- /video_llama/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted 
from salesforce@LAVIS Vision-CAIR@MiniGPT-4. Below is the original copyright: 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import torch 11 | from omegaconf import OmegaConf 12 | 13 | from video_llama.common.registry import registry 14 | from video_llama.models.base_model import BaseModel 15 | from video_llama.models.blip2 import Blip2Base 16 | from video_llama.models.video_llama import VideoLLAMA 17 | from video_llama.processors.base_processor import BaseProcessor 18 | 19 | 20 | __all__ = [ 21 | "load_model", 22 | "BaseModel", 23 | "Blip2Base", 24 | "VideoLLAMA" 25 | ] 26 | 27 | 28 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): 29 | """ 30 | Load supported models. 31 | 32 | To list all available models and types in registry: 33 | >>> from video_llama.models import model_zoo 34 | >>> print(model_zoo) 35 | 36 | Args: 37 | name (str): name of the model. 38 | model_type (str): type of the model. 39 | is_eval (bool): whether the model is in eval mode. Default: False. 40 | device (str): device to use. Default: "cpu". 41 | checkpoint (str): path or to checkpoint. Default: None. 42 | Note that expecting the checkpoint to have the same keys in state_dict as the model. 43 | 44 | Returns: 45 | model (torch.nn.Module): model. 46 | """ 47 | 48 | model = registry.get_model_class(name).from_pretrained(model_type=model_type) 49 | 50 | if checkpoint is not None: 51 | model.load_checkpoint(checkpoint) 52 | 53 | if is_eval: 54 | model.eval() 55 | 56 | if device == "cpu": 57 | model = model.float() 58 | 59 | return model.to(device) 60 | 61 | 62 | def load_preprocess(config): 63 | """ 64 | Load preprocessor configs and construct preprocessors. 65 | 66 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. 67 | 68 | Args: 69 | config (dict): preprocessor configs. 70 | 71 | Returns: 72 | vis_processors (dict): preprocessors for visual inputs. 73 | txt_processors (dict): preprocessors for text inputs. 74 | 75 | Key is "train" or "eval" for processors used in training and evaluation respectively. 76 | """ 77 | 78 | def _build_proc_from_cfg(cfg): 79 | return ( 80 | registry.get_processor_class(cfg.name).from_config(cfg) 81 | if cfg is not None 82 | else BaseProcessor() 83 | ) 84 | 85 | vis_processors = dict() 86 | txt_processors = dict() 87 | 88 | vis_proc_cfg = config.get("vis_processor") 89 | txt_proc_cfg = config.get("text_processor") 90 | 91 | if vis_proc_cfg is not None: 92 | vis_train_cfg = vis_proc_cfg.get("train") 93 | vis_eval_cfg = vis_proc_cfg.get("eval") 94 | else: 95 | vis_train_cfg = None 96 | vis_eval_cfg = None 97 | 98 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) 99 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) 100 | 101 | if txt_proc_cfg is not None: 102 | txt_train_cfg = txt_proc_cfg.get("train") 103 | txt_eval_cfg = txt_proc_cfg.get("eval") 104 | else: 105 | txt_train_cfg = None 106 | txt_eval_cfg = None 107 | 108 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) 109 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) 110 | 111 | return vis_processors, txt_processors 112 | 113 | 114 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): 115 | """ 116 | Load model and its related preprocessors. 
117 | 118 | List all available models and types in registry: 119 | >>> from video_llama.models import model_zoo 120 | >>> print(model_zoo) 121 | 122 | Args: 123 | name (str): name of the model. 124 | model_type (str): type of the model. 125 | is_eval (bool): whether the model is in eval mode. Default: False. 126 | device (str): device to use. Default: "cpu". 127 | 128 | Returns: 129 | model (torch.nn.Module): model. 130 | vis_processors (dict): preprocessors for visual inputs. 131 | txt_processors (dict): preprocessors for text inputs. 132 | """ 133 | model_cls = registry.get_model_class(name) 134 | 135 | # load model 136 | model = model_cls.from_pretrained(model_type=model_type) 137 | 138 | if is_eval: 139 | model.eval() 140 | 141 | # load preprocess 142 | cfg = OmegaConf.load(model_cls.default_config_path(model_type)) 143 | if cfg is not None: 144 | preprocess_cfg = cfg.preprocess 145 | 146 | vis_processors, txt_processors = load_preprocess(preprocess_cfg) 147 | else: 148 | vis_processors, txt_processors = None, None 149 | logging.info( 150 | f"""No default preprocess for model {name} ({model_type}). 151 | This can happen if the model is not finetuned on downstream datasets, 152 | or it is not intended for direct use without finetuning. 153 | """ 154 | ) 155 | 156 | if device == "cpu" or device == torch.device("cpu"): 157 | model = model.float() 158 | 159 | return model.to(device), vis_processors, txt_processors 160 | 161 | 162 | class ModelZoo: 163 | """ 164 | A utility class to create string representation of available model architectures and types. 165 | 166 | >>> from video_llama.models import model_zoo 167 | >>> # list all available models 168 | >>> print(model_zoo) 169 | >>> # show total number of models 170 | >>> print(len(model_zoo)) 171 | """ 172 | 173 | def __init__(self) -> None: 174 | self.model_zoo = { 175 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) 176 | for k, v in registry.mapping["model_name_mapping"].items() 177 | } 178 | 179 | def __str__(self) -> str: 180 | return ( 181 | "=" * 50 182 | + "\n" 183 | + f"{'Architectures':<30} {'Types'}\n" 184 | + "=" * 50 185 | + "\n" 186 | + "\n".join( 187 | [ 188 | f"{name:<30} {', '.join(types)}" 189 | for name, types in self.model_zoo.items() 190 | ] 191 | ) 192 | ) 193 | 194 | def __iter__(self): 195 | return iter(self.model_zoo.items()) 196 | 197 | def __len__(self): 198 | return sum([len(v) for v in self.model_zoo.values()]) 199 | 200 | 201 | model_zoo = ModelZoo() 202 | -------------------------------------------------------------------------------- /video_llama/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import time 11 | from collections import defaultdict, deque 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | from video_llama.common import dist_utils 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 
22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 40 | """ 41 | if not dist_utils.is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value, 79 | ) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError( 100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 101 | ) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def global_avg(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) 113 | return self.delimiter.join(loss_str) 114 | 115 | def synchronize_between_processes(self): 116 | for meter in self.meters.values(): 117 | meter.synchronize_between_processes() 118 | 119 | def add_meter(self, name, meter): 120 | self.meters[name] = meter 121 | 122 | def log_every(self, iterable, print_freq, header=None): 123 | i = 0 124 | if not header: 125 | header = "" 126 | start_time = time.time() 127 | end = time.time() 128 | iter_time = SmoothedValue(fmt="{avg:.4f}") 129 | data_time = SmoothedValue(fmt="{avg:.4f}") 130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 131 | log_msg = [ 132 | header, 133 | "[{0" + space_fmt + "}/{1}]", 134 | "eta: {eta}", 135 | "{meters}", 136 | "time: {time}", 137 | "data: {data}", 138 | ] 139 | if torch.cuda.is_available(): 140 | log_msg.append("max mem: {memory:.0f}") 141 | log_msg = self.delimiter.join(log_msg) 142 | MB = 1024.0 * 1024.0 143 | for obj in iterable: 144 | data_time.update(time.time() - end) 145 | yield obj 146 | iter_time.update(time.time() - end) 147 | if i % print_freq == 0 or i == len(iterable) - 1: 148 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 149 | eta_string = 
str(datetime.timedelta(seconds=int(eta_seconds))) 150 | if torch.cuda.is_available(): 151 | print( 152 | log_msg.format( 153 | i, 154 | len(iterable), 155 | eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), 158 | data=str(data_time), 159 | memory=torch.cuda.max_memory_allocated() / MB, 160 | ) 161 | ) 162 | else: 163 | print( 164 | log_msg.format( 165 | i, 166 | len(iterable), 167 | eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), 170 | data=str(data_time), 171 | ) 172 | ) 173 | i += 1 174 | end = time.time() 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print( 178 | "{} Total time: {} ({:.4f} s / it)".format( 179 | header, total_time_str, total_time / len(iterable) 180 | ) 181 | ) 182 | 183 | 184 | class AttrDict(dict): 185 | def __init__(self, *args, **kwargs): 186 | super(AttrDict, self).__init__(*args, **kwargs) 187 | self.__dict__ = self 188 | 189 | 190 | def setup_logger(): 191 | logging.basicConfig( 192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN, 193 | format="%(asctime)s [%(levelname)s] %(message)s", 194 | handlers=[logging.StreamHandler()], 195 | ) 196 | -------------------------------------------------------------------------------- /video_llama/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import gzip 9 | import logging 10 | import os 11 | import random as rnd 12 | import tarfile 13 | import zipfile 14 | import random 15 | from typing import List 16 | from tqdm import tqdm 17 | 18 | import decord 19 | from decord import VideoReader 20 | import webdataset as wds 21 | import numpy as np 22 | import torch 23 | from torch.utils.data.dataset import IterableDataset 24 | 25 | from video_llama.common.registry import registry 26 | from video_llama.datasets.datasets.base_dataset import ConcatDataset 27 | 28 | 29 | decord.bridge.set_bridge("torch") 30 | MAX_INT = registry.get("MAX_INT") 31 | 32 | 33 | class ChainDataset(wds.DataPipeline): 34 | r"""Dataset for chaining multiple :class:`DataPipeline` s. 35 | 36 | This class is useful to assemble different existing dataset streams. The 37 | chaining operation is done on-the-fly, so concatenating large-scale 38 | datasets with this class will be efficient. 
39 | 40 | Args: 41 | datasets (iterable of IterableDataset): datasets to be chained together 42 | """ 43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None: 44 | super().__init__() 45 | self.datasets = datasets 46 | self.prob = [] 47 | self.names = [] 48 | for dataset in self.datasets: 49 | if hasattr(dataset, 'name'): 50 | self.names.append(dataset.name) 51 | else: 52 | self.names.append('Unknown') 53 | if hasattr(dataset, 'sample_ratio'): 54 | self.prob.append(dataset.sample_ratio) 55 | else: 56 | self.prob.append(1) 57 | logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") 58 | 59 | def __iter__(self): 60 | datastreams = [iter(dataset) for dataset in self.datasets] 61 | while True: 62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] 63 | yield next(select_datastream) 64 | 65 | 66 | def apply_to_sample(f, sample): 67 | if len(sample) == 0: 68 | return {} 69 | 70 | def _apply(x): 71 | if torch.is_tensor(x): 72 | return f(x) 73 | elif isinstance(x, dict): 74 | return {key: _apply(value) for key, value in x.items()} 75 | elif isinstance(x, list): 76 | return [_apply(x) for x in x] 77 | else: 78 | return x 79 | 80 | return _apply(sample) 81 | 82 | 83 | def move_to_cuda(sample): 84 | def _move_to_cuda(tensor): 85 | return tensor.cuda() 86 | 87 | return apply_to_sample(_move_to_cuda, sample) 88 | 89 | 90 | def prepare_sample(samples, cuda_enabled=True): 91 | if cuda_enabled: 92 | samples = move_to_cuda(samples) 93 | 94 | # TODO fp16 support 95 | 96 | return samples 97 | 98 | 99 | def reorg_datasets_by_split(datasets): 100 | """ 101 | Organizes datasets by split. 102 | 103 | Args: 104 | datasets: dict of torch.utils.data.Dataset objects by name. 105 | 106 | Returns: 107 | Dict of datasets by split {split_name: List[Datasets]}. 108 | """ 109 | # if len(datasets) == 1: 110 | # return datasets[list(datasets.keys())[0]] 111 | # else: 112 | reorg_datasets = dict() 113 | 114 | # reorganize by split 115 | for _, dataset in datasets.items(): 116 | for split_name, dataset_split in dataset.items(): 117 | if split_name not in reorg_datasets: 118 | reorg_datasets[split_name] = [dataset_split] 119 | else: 120 | reorg_datasets[split_name].append(dataset_split) 121 | 122 | return reorg_datasets 123 | 124 | 125 | def concat_datasets(datasets): 126 | """ 127 | Concatenates multiple datasets into a single dataset. 128 | 129 | It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support 130 | generic IterableDataset because it requires creating separate samplers. 131 | 132 | Now only supports conctenating training datasets and assuming validation and testing 133 | have only a single dataset. This is because metrics should not be computed on the concatenated 134 | datasets. 135 | 136 | Args: 137 | datasets: dict of torch.utils.data.Dataset objects by split. 138 | 139 | Returns: 140 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, 141 | "val" and "test" remain the same. 142 | 143 | If the input training datasets contain both map-style and DataPipeline datasets, returns 144 | a tuple, where the first element is a concatenated map-style dataset and the second 145 | element is a chained DataPipeline dataset. 
146 | 147 | """ 148 | # concatenate datasets in the same split 149 | for split_name in datasets: 150 | if split_name != "train": 151 | assert ( 152 | len(datasets[split_name]) == 1 153 | ), "Do not support multiple {} datasets.".format(split_name) 154 | datasets[split_name] = datasets[split_name][0] 155 | else: 156 | iterable_datasets, map_datasets = [], [] 157 | for dataset in datasets[split_name]: 158 | if isinstance(dataset, wds.DataPipeline): 159 | logging.info( 160 | "Dataset {} is IterableDataset, can't be concatenated.".format( 161 | dataset 162 | ) 163 | ) 164 | iterable_datasets.append(dataset) 165 | elif isinstance(dataset, IterableDataset): 166 | raise NotImplementedError( 167 | "Do not support concatenation of generic IterableDataset." 168 | ) 169 | else: 170 | map_datasets.append(dataset) 171 | 172 | # if len(iterable_datasets) > 0: 173 | # concatenate map-style datasets and iterable-style datasets separately 174 | if len(iterable_datasets) > 1: 175 | chained_datasets = ( 176 | ChainDataset(iterable_datasets) 177 | ) 178 | elif len(iterable_datasets) == 1: 179 | chained_datasets = iterable_datasets[0] 180 | else: 181 | chained_datasets = None 182 | 183 | concat_datasets = ( 184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None 185 | ) 186 | 187 | train_datasets = concat_datasets, chained_datasets 188 | train_datasets = tuple([x for x in train_datasets if x is not None]) 189 | train_datasets = ( 190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets 191 | ) 192 | 193 | datasets[split_name] = train_datasets 194 | 195 | return datasets 196 | 197 | -------------------------------------------------------------------------------- /video_llama/processors/video_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import torch 9 | from video_llama.common.registry import registry 10 | from decord import VideoReader 11 | import decord 12 | import numpy as np 13 | from video_llama.processors import transforms_video 14 | from video_llama.processors.base_processor import BaseProcessor 15 | from video_llama.processors.randaugment import VideoRandomAugment 16 | from video_llama.processors import functional_video as F 17 | from omegaconf import OmegaConf 18 | from torchvision import transforms 19 | import random as rnd 20 | 21 | 22 | MAX_INT = registry.get("MAX_INT") 23 | decord.bridge.set_bridge("torch") 24 | 25 | def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform", return_msg = False): 26 | decord.bridge.set_bridge("torch") 27 | vr = VideoReader(uri=video_path, height=height, width=width) 28 | 29 | vlen = len(vr) 30 | start, end = 0, vlen 31 | 32 | n_frms = min(n_frms, vlen) 33 | 34 | if sampling == "uniform": 35 | indices = np.arange(start, end, vlen / n_frms).astype(int).tolist() 36 | elif sampling == "headtail": 37 | indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2)) 38 | indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2)) 39 | indices = indices_h + indices_t 40 | else: 41 | raise NotImplementedError 42 | 43 | # get_batch -> T, H, W, C 44 | temp_frms = vr.get_batch(indices) 45 | # print(type(temp_frms)) 46 | tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms 47 | frms = tensor_frms.permute(3, 0, 1, 2).float() # (C, T, H, W) 48 | 49 | if not return_msg: 50 | return frms 51 | 52 | fps = float(vr.get_avg_fps()) 53 | sec = ", ".join([str(round(f / fps, 1)) for f in indices]) 54 | # " " should be added in the start and end 55 | msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. " 56 | return frms, msg 57 | 58 | 59 | class AlproVideoBaseProcessor(BaseProcessor): 60 | def __init__(self, mean=None, std=None, n_frms=MAX_INT): 61 | if mean is None: 62 | mean = (0.48145466, 0.4578275, 0.40821073) 63 | if std is None: 64 | std = (0.26862954, 0.26130258, 0.27577711) 65 | 66 | self.normalize = transforms_video.NormalizeVideo(mean, std) 67 | 68 | self.n_frms = n_frms 69 | 70 | 71 | class ToUint8(object): 72 | def __init__(self): 73 | pass 74 | 75 | def __call__(self, tensor): 76 | return tensor.to(torch.uint8) 77 | 78 | def __repr__(self): 79 | return self.__class__.__name__ 80 | 81 | 82 | class ToTHWC(object): 83 | """ 84 | Args: 85 | clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W) 86 | Return: 87 | clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C) 88 | """ 89 | 90 | def __init__(self): 91 | pass 92 | 93 | def __call__(self, tensor): 94 | return tensor.permute(1, 2, 3, 0) 95 | 96 | def __repr__(self): 97 | return self.__class__.__name__ 98 | 99 | 100 | class ResizeVideo(object): 101 | def __init__(self, target_size, interpolation_mode="bilinear"): 102 | self.target_size = target_size 103 | self.interpolation_mode = interpolation_mode 104 | 105 | def __call__(self, clip): 106 | """ 107 | Args: 108 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 109 | Returns: 110 | torch.tensor: central cropping of video clip. 
Size is 111 | (C, T, crop_size, crop_size) 112 | """ 113 | return F.resize(clip, self.target_size, self.interpolation_mode) 114 | 115 | def __repr__(self): 116 | return self.__class__.__name__ + "(resize_size={0})".format(self.target_size) 117 | 118 | 119 | @registry.register_processor("alpro_video_train") 120 | class AlproVideoTrainProcessor(AlproVideoBaseProcessor): 121 | def __init__( 122 | self, 123 | image_size=384, 124 | mean=None, 125 | std=None, 126 | min_scale=0.5, 127 | max_scale=1.0, 128 | n_frms=MAX_INT, 129 | ): 130 | super().__init__(mean=mean, std=std, n_frms=n_frms) 131 | 132 | self.image_size = image_size 133 | 134 | self.transform = transforms.Compose( 135 | [ 136 | # Video size is (C, T, H, W) 137 | transforms_video.RandomResizedCropVideo( 138 | image_size, 139 | scale=(min_scale, max_scale), 140 | interpolation_mode="bicubic", 141 | ), 142 | ToTHWC(), # C, T, H, W -> T, H, W, C 143 | ToUint8(), 144 | transforms_video.ToTensorVideo(), # T, H, W, C -> C, T, H, W 145 | self.normalize, 146 | ] 147 | ) 148 | 149 | def __call__(self, vpath): 150 | """ 151 | Args: 152 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 153 | Returns: 154 | torch.tensor: video clip after transforms. Size is (C, T, size, size). 155 | """ 156 | clip = load_video( 157 | video_path=vpath, 158 | n_frms=self.n_frms, 159 | height=self.image_size, 160 | width=self.image_size, 161 | sampling="headtail", 162 | ) 163 | 164 | return self.transform(clip) 165 | 166 | @classmethod 167 | def from_config(cls, cfg=None): 168 | if cfg is None: 169 | cfg = OmegaConf.create() 170 | 171 | image_size = cfg.get("image_size", 256) 172 | 173 | mean = cfg.get("mean", None) 174 | std = cfg.get("std", None) 175 | 176 | min_scale = cfg.get("min_scale", 0.5) 177 | max_scale = cfg.get("max_scale", 1.0) 178 | 179 | n_frms = cfg.get("n_frms", MAX_INT) 180 | 181 | return cls( 182 | image_size=image_size, 183 | mean=mean, 184 | std=std, 185 | min_scale=min_scale, 186 | max_scale=max_scale, 187 | n_frms=n_frms, 188 | ) 189 | 190 | 191 | @registry.register_processor("alpro_video_eval") 192 | class AlproVideoEvalProcessor(AlproVideoBaseProcessor): 193 | def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT): 194 | super().__init__(mean=mean, std=std, n_frms=n_frms) 195 | 196 | self.image_size = image_size 197 | 198 | # Input video size is (C, T, H, W) 199 | self.transform = transforms.Compose( 200 | [ 201 | # frames will be resized during decord loading. 202 | ToUint8(), # C, T, H, W 203 | ToTHWC(), # T, H, W, C 204 | transforms_video.ToTensorVideo(), # C, T, H, W 205 | self.normalize, # C, T, H, W 206 | ] 207 | ) 208 | 209 | def __call__(self, vpath): 210 | """ 211 | Args: 212 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 213 | Returns: 214 | torch.tensor: video clip after transforms. Size is (C, T, size, size). 
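        Example (minimal sketch; the config values and video path are illustrative)::

            from omegaconf import OmegaConf

            proc = AlproVideoEvalProcessor.from_config(
                OmegaConf.create({"image_size": 224, "n_frms": 8})
            )
            clip = proc("/path/to/video.mp4")  # tensor of shape (C, 8, 224, 224)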
215 | """ 216 | clip = load_video( 217 | video_path=vpath, 218 | n_frms=self.n_frms, 219 | height=self.image_size, 220 | width=self.image_size, 221 | ) 222 | 223 | return self.transform(clip) 224 | 225 | @classmethod 226 | def from_config(cls, cfg=None): 227 | if cfg is None: 228 | cfg = OmegaConf.create() 229 | 230 | image_size = cfg.get("image_size", 256) 231 | 232 | mean = cfg.get("mean", None) 233 | std = cfg.get("std", None) 234 | 235 | n_frms = cfg.get("n_frms", MAX_INT) 236 | 237 | return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms) 238 | -------------------------------------------------------------------------------- /video_llama/processors/.ipynb_checkpoints/video_processor-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import torch 9 | from video_llama.common.registry import registry 10 | from decord import VideoReader 11 | import decord 12 | import numpy as np 13 | from video_llama.processors import transforms_video 14 | from video_llama.processors.base_processor import BaseProcessor 15 | from video_llama.processors.randaugment import VideoRandomAugment 16 | from video_llama.processors import functional_video as F 17 | from omegaconf import OmegaConf 18 | from torchvision import transforms 19 | import random as rnd 20 | MAX_INT = registry.get("MAX_INT") 21 | 22 | def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform"): 23 | vr = VideoReader(uri=video_path, height=height, width=width) 24 | 25 | vlen = len(vr) 26 | start, end = 0, vlen 27 | 28 | n_frms = min(n_frms, vlen) 29 | 30 | if sampling == "uniform": 31 | indices = np.arange(start, end, vlen / n_frms).astype(int).tolist() 32 | elif sampling == "headtail": 33 | indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2)) 34 | indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2)) 35 | indices = indices_h + indices_t 36 | else: 37 | raise NotImplementedError 38 | 39 | # get_batch -> T, H, W, C 40 | print(video_path) 41 | print(indices) 42 | print(vr.get_batch(indices)) 43 | 44 | frms = vr.get_batch(indices).permute(3, 0, 1, 2).float() # (C, T, H, W) 45 | # print(111) 46 | return frms 47 | 48 | class AlproVideoBaseProcessor(BaseProcessor): 49 | def __init__(self, mean=None, std=None, n_frms=MAX_INT): 50 | if mean is None: 51 | mean = (0.48145466, 0.4578275, 0.40821073) 52 | if std is None: 53 | std = (0.26862954, 0.26130258, 0.27577711) 54 | 55 | self.normalize = transforms_video.NormalizeVideo(mean, std) 56 | 57 | self.n_frms = n_frms 58 | 59 | 60 | class ToUint8(object): 61 | def __init__(self): 62 | pass 63 | 64 | def __call__(self, tensor): 65 | return tensor.to(torch.uint8) 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ 69 | 70 | 71 | class ToTHWC(object): 72 | """ 73 | Args: 74 | clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W) 75 | Return: 76 | clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C) 77 | """ 78 | 79 | def __init__(self): 80 | pass 81 | 82 | def __call__(self, tensor): 83 | return tensor.permute(1, 2, 3, 0) 84 | 85 | def __repr__(self): 86 | return self.__class__.__name__ 87 | 88 | 89 | class ResizeVideo(object): 90 | def __init__(self, target_size, interpolation_mode="bilinear"): 91 | self.target_size = target_size 92 
| self.interpolation_mode = interpolation_mode 93 | 94 | def __call__(self, clip): 95 | """ 96 | Args: 97 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 98 | Returns: 99 | torch.tensor: central cropping of video clip. Size is 100 | (C, T, crop_size, crop_size) 101 | """ 102 | return F.resize(clip, self.target_size, self.interpolation_mode) 103 | 104 | def __repr__(self): 105 | return self.__class__.__name__ + "(resize_size={0})".format(self.target_size) 106 | 107 | 108 | @registry.register_processor("alpro_video_train") 109 | class AlproVideoTrainProcessor(AlproVideoBaseProcessor): 110 | def __init__( 111 | self, 112 | image_size=384, 113 | mean=None, 114 | std=None, 115 | min_scale=0.5, 116 | max_scale=1.0, 117 | n_frms=MAX_INT, 118 | ): 119 | super().__init__(mean=mean, std=std, n_frms=n_frms) 120 | 121 | self.image_size = image_size 122 | 123 | self.transform = transforms.Compose( 124 | [ 125 | # Video size is (C, T, H, W) 126 | transforms_video.RandomResizedCropVideo( 127 | image_size, 128 | scale=(min_scale, max_scale), 129 | interpolation_mode="bicubic", 130 | ), 131 | transforms_video.RandomHorizontalFlipVideo(), 132 | ToTHWC(), # C, T, H, W -> T, H, W, C 133 | VideoRandomAugment( 134 | 2, 135 | 5, 136 | augs=[ 137 | "Identity", 138 | "AutoContrast", 139 | "Brightness", 140 | "Sharpness", 141 | "Equalize", 142 | "ShearX", 143 | "ShearY", 144 | "TranslateX", 145 | "TranslateY", 146 | "Rotate", 147 | ], 148 | ), 149 | ToUint8(), 150 | transforms_video.ToTensorVideo(), # T, H, W, C -> C, T, H, W 151 | self.normalize, 152 | ] 153 | ) 154 | 155 | def __call__(self, vpath): 156 | """ 157 | Args: 158 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 159 | Returns: 160 | torch.tensor: video clip after transforms. Size is (C, T, size, size). 161 | """ 162 | clip = load_video( 163 | video_path=vpath, 164 | n_frms=self.n_frms, 165 | height=self.image_size, 166 | width=self.image_size, 167 | sampling="headtail", 168 | ) 169 | 170 | return self.transform(clip) 171 | 172 | @classmethod 173 | def from_config(cls, cfg=None): 174 | if cfg is None: 175 | cfg = OmegaConf.create() 176 | 177 | image_size = cfg.get("image_size", 256) 178 | 179 | mean = cfg.get("mean", None) 180 | std = cfg.get("std", None) 181 | 182 | min_scale = cfg.get("min_scale", 0.5) 183 | max_scale = cfg.get("max_scale", 1.0) 184 | 185 | n_frms = cfg.get("n_frms", MAX_INT) 186 | 187 | return cls( 188 | image_size=image_size, 189 | mean=mean, 190 | std=std, 191 | min_scale=min_scale, 192 | max_scale=max_scale, 193 | n_frms=n_frms, 194 | ) 195 | 196 | 197 | @registry.register_processor("alpro_video_eval") 198 | class AlproVideoEvalProcessor(AlproVideoBaseProcessor): 199 | def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT): 200 | super().__init__(mean=mean, std=std, n_frms=n_frms) 201 | 202 | self.image_size = image_size 203 | 204 | # Input video size is (C, T, H, W) 205 | self.transform = transforms.Compose( 206 | [ 207 | # frames will be resized during decord loading. 208 | ToUint8(), # C, T, H, W 209 | ToTHWC(), # T, H, W, C 210 | transforms_video.ToTensorVideo(), # C, T, H, W 211 | self.normalize, # C, T, H, W 212 | ] 213 | ) 214 | 215 | def __call__(self, vpath): 216 | """ 217 | Args: 218 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 219 | Returns: 220 | torch.tensor: video clip after transforms. Size is (C, T, size, size). 
221 | """ 222 | clip = load_video( 223 | video_path=vpath, 224 | n_frms=self.n_frms, 225 | height=self.image_size, 226 | width=self.image_size, 227 | ) 228 | 229 | return self.transform(clip) 230 | 231 | @classmethod 232 | def from_config(cls, cfg=None): 233 | if cfg is None: 234 | cfg = OmegaConf.create() 235 | 236 | image_size = cfg.get("image_size", 256) 237 | 238 | mean = cfg.get("mean", None) 239 | std = cfg.get("std", None) 240 | 241 | n_frms = cfg.get("n_frms", MAX_INT) 242 | 243 | return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms) 244 | -------------------------------------------------------------------------------- /video_llama/models/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from salesforce@LAVIS. Below is the original copyright: 3 | Copyright (c) 2023, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | import contextlib 9 | import logging 10 | import os 11 | import time 12 | import datetime 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.distributed as dist 17 | import torch.nn.functional as F 18 | 19 | import video_llama.common.dist_utils as dist_utils 20 | from video_llama.common.dist_utils import download_cached_file 21 | from video_llama.common.utils import is_url 22 | from video_llama.common.logger import MetricLogger 23 | from video_llama.models.base_model import BaseModel 24 | from video_llama.models.Qformer import BertConfig, BertLMHeadModel 25 | from video_llama.models.eva_vit import create_eva_vit_g 26 | from transformers import BertTokenizer 27 | 28 | 29 | class Blip2Base(BaseModel): 30 | @classmethod 31 | def init_tokenizer(cls): 32 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", cache_dir="/home/gs534/rds/rds-t2-cs164-KQ4S3rlDzm8/gs534/opensource/favor/pretrained_ckpt/cache") 33 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 34 | return tokenizer 35 | 36 | def maybe_autocast(self, dtype=torch.float16): 37 | # if on cpu, don't use autocast 38 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 39 | enable_autocast = self.device != torch.device("cpu") 40 | 41 | if enable_autocast: 42 | return torch.cuda.amp.autocast(dtype=dtype) 43 | else: 44 | return contextlib.nullcontext() 45 | 46 | @classmethod 47 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): 48 | encoder_config = BertConfig.from_pretrained("bert-base-uncased", cache_dir="/home/gs534/rds/rds-t2-cs164-KQ4S3rlDzm8/gs534/opensource/favor/pretrained_ckpt/cache") 49 | encoder_config.encoder_width = vision_width 50 | # insert cross-attention layer every other block 51 | encoder_config.add_cross_attention = True 52 | encoder_config.cross_attention_freq = cross_attention_freq 53 | encoder_config.query_length = num_query_token 54 | Qformer = BertLMHeadModel(config=encoder_config) 55 | query_tokens = nn.Parameter( 56 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 57 | ) 58 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 59 | return Qformer, query_tokens 60 | 61 | @classmethod 62 | def init_vision_encoder( 63 | cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision 64 | ): 65 | assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" 66 | visual_encoder = 
create_eva_vit_g( 67 | img_size, drop_path_rate, use_grad_checkpoint, precision 68 | ) 69 | 70 | ln_vision = LayerNorm(visual_encoder.num_features) 71 | return visual_encoder, ln_vision 72 | 73 | def load_from_pretrained(self, url_or_filename): 74 | if is_url(url_or_filename): 75 | cached_file = download_cached_file( 76 | url_or_filename, check_hash=False, progress=True 77 | ) 78 | checkpoint = torch.load(cached_file, map_location="cpu") 79 | elif os.path.isfile(url_or_filename): 80 | checkpoint = torch.load(url_or_filename, map_location="cpu") 81 | else: 82 | raise RuntimeError("checkpoint url or path is invalid") 83 | 84 | state_dict = checkpoint["model"] 85 | 86 | msg = self.load_state_dict(state_dict, strict=False) 87 | 88 | # logging.info("Missing keys {}".format(msg.missing_keys)) 89 | logging.info("load checkpoint from %s" % url_or_filename) 90 | 91 | return msg 92 | 93 | 94 | def disabled_train(self, mode=True): 95 | """Overwrite model.train with this function to make sure train/eval mode 96 | does not change anymore.""" 97 | return self 98 | 99 | 100 | class LayerNorm(nn.LayerNorm): 101 | """Subclass torch's LayerNorm to handle fp16.""" 102 | 103 | def forward(self, x: torch.Tensor): 104 | orig_type = x.dtype 105 | ret = super().forward(x.type(torch.float32)) 106 | return ret.type(orig_type) 107 | 108 | 109 | def compute_sim_matrix(model, data_loader, **kwargs): 110 | k_test = kwargs.pop("k_test") 111 | 112 | metric_logger = MetricLogger(delimiter=" ") 113 | header = "Evaluation:" 114 | 115 | logging.info("Computing features for evaluation...") 116 | start_time = time.time() 117 | 118 | texts = data_loader.dataset.text 119 | num_text = len(texts) 120 | text_bs = 256 121 | text_ids = [] 122 | text_embeds = [] 123 | text_atts = [] 124 | for i in range(0, num_text, text_bs): 125 | text = texts[i : min(num_text, i + text_bs)] 126 | text_input = model.tokenizer( 127 | text, 128 | padding="max_length", 129 | truncation=True, 130 | max_length=35, 131 | return_tensors="pt", 132 | ).to(model.device) 133 | text_feat = model.forward_text(text_input) 134 | text_embed = F.normalize(model.text_proj(text_feat)) 135 | text_embeds.append(text_embed) 136 | text_ids.append(text_input.input_ids) 137 | text_atts.append(text_input.attention_mask) 138 | 139 | text_embeds = torch.cat(text_embeds, dim=0) 140 | text_ids = torch.cat(text_ids, dim=0) 141 | text_atts = torch.cat(text_atts, dim=0) 142 | 143 | vit_feats = [] 144 | image_embeds = [] 145 | for samples in data_loader: 146 | image = samples["image"] 147 | 148 | image = image.to(model.device) 149 | image_feat, vit_feat = model.forward_image(image) 150 | image_embed = model.vision_proj(image_feat) 151 | image_embed = F.normalize(image_embed, dim=-1) 152 | 153 | vit_feats.append(vit_feat.cpu()) 154 | image_embeds.append(image_embed) 155 | 156 | vit_feats = torch.cat(vit_feats, dim=0) 157 | image_embeds = torch.cat(image_embeds, dim=0) 158 | 159 | sims_matrix = [] 160 | for image_embed in image_embeds: 161 | sim_q2t = image_embed @ text_embeds.t() 162 | sim_i2t, _ = sim_q2t.max(0) 163 | sims_matrix.append(sim_i2t) 164 | sims_matrix = torch.stack(sims_matrix, dim=0) 165 | 166 | score_matrix_i2t = torch.full( 167 | (len(data_loader.dataset.image), len(texts)), -100.0 168 | ).to(model.device) 169 | 170 | num_tasks = dist_utils.get_world_size() 171 | rank = dist_utils.get_rank() 172 | step = sims_matrix.size(0) // num_tasks + 1 173 | start = rank * step 174 | end = min(sims_matrix.size(0), start + step) 175 | 176 | for i, sims in enumerate( 177 | 
metric_logger.log_every(sims_matrix[start:end], 50, header) 178 | ): 179 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 180 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) 181 | score = model.compute_itm( 182 | image_inputs=image_inputs, 183 | text_ids=text_ids[topk_idx], 184 | text_atts=text_atts[topk_idx], 185 | ).float() 186 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim 187 | 188 | sims_matrix = sims_matrix.t() 189 | score_matrix_t2i = torch.full( 190 | (len(texts), len(data_loader.dataset.image)), -100.0 191 | ).to(model.device) 192 | 193 | step = sims_matrix.size(0) // num_tasks + 1 194 | start = rank * step 195 | end = min(sims_matrix.size(0), start + step) 196 | 197 | for i, sims in enumerate( 198 | metric_logger.log_every(sims_matrix[start:end], 50, header) 199 | ): 200 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 201 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device) 202 | score = model.compute_itm( 203 | image_inputs=image_inputs, 204 | text_ids=text_ids[start + i].repeat(k_test, 1), 205 | text_atts=text_atts[start + i].repeat(k_test, 1), 206 | ).float() 207 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim 208 | 209 | if dist_utils.is_dist_avail_and_initialized(): 210 | dist.barrier() 211 | torch.distributed.all_reduce( 212 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM 213 | ) 214 | torch.distributed.all_reduce( 215 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM 216 | ) 217 | 218 | total_time = time.time() - start_time 219 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 220 | logging.info("Evaluation time {}".format(total_time_str)) 221 | 222 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 223 | -------------------------------------------------------------------------------- /video_llama/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from salesforce@LAVIS. Below is the original copyright: 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from video_llama.common.dist_utils import download_cached_file, is_dist_avail_and_initialized 16 | from video_llama.common.utils import get_abs_path, is_url 17 | from omegaconf import OmegaConf 18 | 19 | 20 | class BaseModel(nn.Module): 21 | """Base class for models.""" 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | @property 27 | def device(self): 28 | return list(self.parameters())[0].device 29 | 30 | def load_checkpoint(self, url_or_filename): 31 | """ 32 | Load from a finetuned checkpoint. 33 | 34 | This should expect no mismatch in the model keys and the checkpoint keys. 
35 | """ 36 | 37 | if is_url(url_or_filename): 38 | cached_file = download_cached_file( 39 | url_or_filename, check_hash=False, progress=True 40 | ) 41 | checkpoint = torch.load(cached_file, map_location="cpu") 42 | elif os.path.isfile(url_or_filename): 43 | checkpoint = torch.load(url_or_filename, map_location="cpu") 44 | else: 45 | raise RuntimeError("checkpoint url or path is invalid") 46 | 47 | if "model" in checkpoint.keys(): 48 | state_dict = checkpoint["model"] 49 | else: 50 | state_dict = checkpoint 51 | 52 | msg = self.load_state_dict(state_dict, strict=False) 53 | 54 | logging.info("Missing keys {}".format(msg.missing_keys)) 55 | logging.info("load checkpoint from %s" % url_or_filename) 56 | 57 | return msg 58 | 59 | @classmethod 60 | def from_pretrained(cls, model_type): 61 | """ 62 | Build a pretrained model from default configuration file, specified by model_type. 63 | 64 | Args: 65 | - model_type (str): model type, specifying architecture and checkpoints. 66 | 67 | Returns: 68 | - model (nn.Module): pretrained or finetuned model, depending on the configuration. 69 | """ 70 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model 71 | model = cls.from_config(model_cfg) 72 | 73 | return model 74 | 75 | @classmethod 76 | def default_config_path(cls, model_type): 77 | assert ( 78 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT 79 | ), "Unknown model type {}".format(model_type) 80 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) 81 | 82 | def load_checkpoint_from_config(self, cfg, **kwargs): 83 | """ 84 | Load checkpoint as specified in the config file. 85 | 86 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. 87 | When loading the pretrained model, each task-specific architecture may define their 88 | own load_from_pretrained() method. 89 | """ 90 | load_finetuned = cfg.get("load_finetuned", True) 91 | if load_finetuned: 92 | finetune_path = cfg.get("finetuned", None) 93 | assert ( 94 | finetune_path is not None 95 | ), "Found load_finetuned is True, but finetune_path is None." 96 | self.load_checkpoint(url_or_filename=finetune_path) 97 | else: 98 | # load pre-trained weights 99 | pretrain_path = cfg.get("pretrained", None) 100 | assert "Found load_finetuned is False, but pretrain_path is None." 101 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) 102 | 103 | def before_evaluation(self, **kwargs): 104 | pass 105 | 106 | def show_n_params(self, return_str=True): 107 | tot = 0 108 | for p in self.parameters(): 109 | w = 1 110 | for x in p.shape: 111 | w *= x 112 | tot += w 113 | if return_str: 114 | if tot >= 1e6: 115 | return "{:.1f}M".format(tot / 1e6) 116 | else: 117 | return "{:.1f}K".format(tot / 1e3) 118 | else: 119 | return tot 120 | 121 | 122 | class BaseEncoder(nn.Module): 123 | """ 124 | Base class for primitive encoders, such as ViT, TimeSformer, etc. 
125 | """ 126 | 127 | def __init__(self): 128 | super().__init__() 129 | 130 | def forward_features(self, samples, **kwargs): 131 | raise NotImplementedError 132 | 133 | @property 134 | def device(self): 135 | return list(self.parameters())[0].device 136 | 137 | 138 | class SharedQueueMixin: 139 | @torch.no_grad() 140 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): 141 | # gather keys before updating queue 142 | image_feats = concat_all_gather(image_feat) 143 | text_feats = concat_all_gather(text_feat) 144 | 145 | batch_size = image_feats.shape[0] 146 | 147 | ptr = int(self.queue_ptr) 148 | assert self.queue_size % batch_size == 0 # for simplicity 149 | 150 | # replace the keys at ptr (dequeue and enqueue) 151 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T 152 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T 153 | 154 | if idxs is not None: 155 | idxs = concat_all_gather(idxs) 156 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T 157 | 158 | ptr = (ptr + batch_size) % self.queue_size # move pointer 159 | self.queue_ptr[0] = ptr 160 | 161 | 162 | class MomentumDistilationMixin: 163 | @torch.no_grad() 164 | def copy_params(self): 165 | for model_pair in self.model_pairs: 166 | for param, param_m in zip( 167 | model_pair[0].parameters(), model_pair[1].parameters() 168 | ): 169 | param_m.data.copy_(param.data) # initialize 170 | param_m.requires_grad = False # not update by gradient 171 | 172 | @torch.no_grad() 173 | def _momentum_update(self): 174 | for model_pair in self.model_pairs: 175 | for param, param_m in zip( 176 | model_pair[0].parameters(), model_pair[1].parameters() 177 | ): 178 | param_m.data = param_m.data * self.momentum + param.data * ( 179 | 1.0 - self.momentum 180 | ) 181 | 182 | 183 | class GatherLayer(torch.autograd.Function): 184 | """ 185 | Gather tensors from all workers with support for backward propagation: 186 | This implementation does not cut the gradients as torch.distributed.all_gather does. 187 | """ 188 | 189 | @staticmethod 190 | def forward(ctx, x): 191 | output = [ 192 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 193 | ] 194 | torch.distributed.all_gather(output, x) 195 | return tuple(output) 196 | 197 | @staticmethod 198 | def backward(ctx, *grads): 199 | all_gradients = torch.stack(grads) 200 | torch.distributed.all_reduce(all_gradients) 201 | return all_gradients[torch.distributed.get_rank()] 202 | 203 | 204 | def all_gather_with_grad(tensors): 205 | """ 206 | Performs all_gather operation on the provided tensors. 207 | Graph remains connected for backward grad computation. 208 | """ 209 | # Queue the gathered tensors 210 | world_size = torch.distributed.get_world_size() 211 | # There is no need for reduction in the single-proc case 212 | if world_size == 1: 213 | return tensors 214 | 215 | # tensor_all = GatherLayer.apply(tensors) 216 | tensor_all = GatherLayer.apply(tensors) 217 | 218 | return torch.cat(tensor_all, dim=0) 219 | 220 | 221 | @torch.no_grad() 222 | def concat_all_gather(tensor): 223 | """ 224 | Performs all_gather operation on the provided tensors. 225 | *** Warning ***: torch.distributed.all_gather has no gradient. 
226 | """ 227 | # if use distributed training 228 | if not is_dist_avail_and_initialized(): 229 | return tensor 230 | 231 | tensors_gather = [ 232 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 233 | ] 234 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 235 | 236 | output = torch.cat(tensors_gather, dim=0) 237 | return output 238 | 239 | 240 | def tile(x, dim, n_tile): 241 | init_dim = x.size(dim) 242 | repeat_idx = [1] * x.dim() 243 | repeat_idx[dim] = n_tile 244 | x = x.repeat(*(repeat_idx)) 245 | order_index = torch.LongTensor( 246 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 247 | ) 248 | return torch.index_select(x, dim, order_index.to(x.device)) 249 | -------------------------------------------------------------------------------- /video_llama/datasets/builders/base_dataset_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is from 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | import shutil 12 | import warnings 13 | 14 | from omegaconf import OmegaConf 15 | import torch.distributed as dist 16 | from torchvision.datasets.utils import download_url 17 | 18 | import video_llama.common.utils as utils 19 | from video_llama.common.dist_utils import is_dist_avail_and_initialized, is_main_process 20 | from video_llama.common.registry import registry 21 | from video_llama.processors.base_processor import BaseProcessor 22 | 23 | 24 | 25 | class BaseDatasetBuilder: 26 | train_dataset_cls, eval_dataset_cls = None, None 27 | 28 | def __init__(self, cfg=None): 29 | super().__init__() 30 | 31 | if cfg is None: 32 | # help to create datasets from default config. 33 | self.config = load_dataset_config(self.default_config_path()) 34 | elif isinstance(cfg, str): 35 | self.config = load_dataset_config(cfg) 36 | else: 37 | # when called from task.build_dataset() 38 | self.config = cfg 39 | 40 | self.data_type = self.config.data_type 41 | 42 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 43 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 44 | 45 | def build_datasets(self): 46 | # download, split, etc... 47 | # only called on 1 GPU/TPU in distributed 48 | 49 | if is_main_process(): 50 | self._download_data() 51 | 52 | if is_dist_avail_and_initialized(): 53 | dist.barrier() 54 | 55 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
56 | logging.info("Building datasets...") 57 | datasets = self.build() # dataset['train'/'val'/'test'] 58 | 59 | return datasets 60 | 61 | def build_processors(self): 62 | vis_proc_cfg = self.config.get("vis_processor") 63 | txt_proc_cfg = self.config.get("text_processor") 64 | 65 | if vis_proc_cfg is not None: 66 | vis_train_cfg = vis_proc_cfg.get("train") 67 | vis_eval_cfg = vis_proc_cfg.get("eval") 68 | 69 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) 70 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) 71 | 72 | if txt_proc_cfg is not None: 73 | txt_train_cfg = txt_proc_cfg.get("train") 74 | txt_eval_cfg = txt_proc_cfg.get("eval") 75 | 76 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) 77 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) 78 | 79 | @staticmethod 80 | def _build_proc_from_cfg(cfg): 81 | return ( 82 | registry.get_processor_class(cfg.name).from_config(cfg) 83 | if cfg is not None 84 | else None 85 | ) 86 | 87 | @classmethod 88 | def default_config_path(cls, type="default"): 89 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) 90 | 91 | def _download_data(self): 92 | self._download_ann() 93 | self._download_vis() 94 | 95 | def _download_ann(self): 96 | """ 97 | Download annotation files if necessary. 98 | All the vision-language datasets should have annotations of unified format. 99 | 100 | storage_path can be: 101 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. 102 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. 103 | 104 | Local annotation paths should be relative. 105 | """ 106 | anns = self.config.build_info.annotations 107 | 108 | splits = anns.keys() 109 | 110 | cache_root = registry.get_path("cache_root") 111 | 112 | for split in splits: 113 | info = anns[split] 114 | 115 | urls, storage_paths = info.get("url", None), info.storage 116 | 117 | if isinstance(urls, str): 118 | urls = [urls] 119 | if isinstance(storage_paths, str): 120 | storage_paths = [storage_paths] 121 | 122 | assert len(urls) == len(storage_paths) 123 | 124 | for url_or_filename, storage_path in zip(urls, storage_paths): 125 | # if storage_path is relative, make it full by prefixing with cache_root. 126 | if not os.path.isabs(storage_path): 127 | storage_path = os.path.join(cache_root, storage_path) 128 | 129 | dirname = os.path.dirname(storage_path) 130 | if not os.path.exists(dirname): 131 | os.makedirs(dirname) 132 | 133 | if os.path.isfile(url_or_filename): 134 | src, dst = url_or_filename, storage_path 135 | if not os.path.exists(dst): 136 | shutil.copyfile(src=src, dst=dst) 137 | else: 138 | logging.info("Using existing file {}.".format(dst)) 139 | else: 140 | if os.path.isdir(storage_path): 141 | # if only dirname is provided, suffix with basename of URL. 142 | raise ValueError( 143 | "Expecting storage_path to be a file path, got directory {}".format( 144 | storage_path 145 | ) 146 | ) 147 | else: 148 | filename = os.path.basename(storage_path) 149 | 150 | download_url(url=url_or_filename, root=dirname, filename=filename) 151 | 152 | def _download_vis(self): 153 | 154 | storage_path = self.config.build_info.get(self.data_type).storage 155 | storage_path = utils.get_cache_path(storage_path) 156 | 157 | if not os.path.exists(storage_path): 158 | warnings.warn( 159 | f""" 160 | The specified path {storage_path} for visual inputs does not exist. 
161 | Please provide a correct path to the visual inputs or 162 | refer to datasets/download_scripts/README.md for downloading instructions. 163 | """ 164 | ) 165 | 166 | def build(self): 167 | """ 168 | Create by split datasets inheriting torch.utils.data.Datasets. 169 | 170 | # build() can be dataset-specific. Overwrite to customize. 171 | """ 172 | self.build_processors() 173 | 174 | build_info = self.config.build_info 175 | 176 | ann_info = build_info.annotations 177 | vis_info = build_info.get(self.data_type) 178 | 179 | datasets = dict() 180 | for split in ann_info.keys(): 181 | if split not in ["train", "val", "test"]: 182 | continue 183 | 184 | is_train = split == "train" 185 | 186 | # processors 187 | vis_processor = ( 188 | self.vis_processors["train"] 189 | if is_train 190 | else self.vis_processors["eval"] 191 | ) 192 | text_processor = ( 193 | self.text_processors["train"] 194 | if is_train 195 | else self.text_processors["eval"] 196 | ) 197 | 198 | # annotation path 199 | ann_paths = ann_info.get(split).storage 200 | if isinstance(ann_paths, str): 201 | ann_paths = [ann_paths] 202 | 203 | abs_ann_paths = [] 204 | for ann_path in ann_paths: 205 | if not os.path.isabs(ann_path): 206 | ann_path = utils.get_cache_path(ann_path) 207 | abs_ann_paths.append(ann_path) 208 | ann_paths = abs_ann_paths 209 | 210 | # visual data storage path 211 | vis_path = os.path.join(vis_info.storage, split) 212 | 213 | if not os.path.isabs(vis_path): 214 | # vis_path = os.path.join(utils.get_cache_path(), vis_path) 215 | vis_path = utils.get_cache_path(vis_path) 216 | 217 | if not os.path.exists(vis_path): 218 | warnings.warn("storage path {} does not exist.".format(vis_path)) 219 | 220 | # create datasets 221 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls 222 | datasets[split] = dataset_cls( 223 | vis_processor=vis_processor, 224 | text_processor=text_processor, 225 | ann_paths=ann_paths, 226 | vis_root=vis_path, 227 | ) 228 | 229 | return datasets 230 | 231 | 232 | def load_dataset_config(cfg_path): 233 | cfg = OmegaConf.load(cfg_path).datasets 234 | cfg = cfg[list(cfg.keys())[0]] 235 | 236 | return cfg 237 | -------------------------------------------------------------------------------- /video_llama/tasks/base_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | import torch 12 | import torch.distributed as dist 13 | from video_llama.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized 14 | from video_llama.common.logger import MetricLogger, SmoothedValue 15 | from video_llama.common.registry import registry 16 | from video_llama.datasets.data_utils import prepare_sample 17 | 18 | 19 | class BaseTask: 20 | def __init__(self, **kwargs): 21 | super().__init__() 22 | 23 | self.inst_id_key = "instance_id" 24 | 25 | @classmethod 26 | def setup_task(cls, **kwargs): 27 | return cls() 28 | 29 | def build_model(self, cfg): 30 | model_config = cfg.model_cfg 31 | 32 | model_cls = registry.get_model_class(model_config.arch) 33 | return model_cls.from_config(model_config) 34 | 35 | def build_datasets(self, cfg): 36 | """ 37 | Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. 
38 | Download dataset and annotations automatically if not exist. 39 | 40 | Args: 41 | cfg (common.config.Config): _description_ 42 | 43 | Returns: 44 | dict: Dictionary of torch.utils.data.Dataset objects by split. 45 | """ 46 | 47 | datasets = dict() 48 | 49 | datasets_config = cfg.datasets_cfg 50 | 51 | assert len(datasets_config) > 0, "At least one dataset has to be specified." 52 | 53 | for name in datasets_config: 54 | dataset_config = datasets_config[name] 55 | 56 | builder = registry.get_builder_class(name)(dataset_config) 57 | dataset = builder.build_datasets() 58 | 59 | dataset['train'].name = name 60 | if 'sample_ratio' in dataset_config: 61 | dataset['train'].sample_ratio = dataset_config.sample_ratio 62 | 63 | datasets[name] = dataset 64 | 65 | return datasets 66 | 67 | def train_step(self, model, samples): 68 | loss = model(samples)["loss"] 69 | return loss 70 | 71 | def valid_step(self, model, samples): 72 | raise NotImplementedError 73 | 74 | def before_evaluation(self, model, dataset, **kwargs): 75 | model.before_evaluation(dataset=dataset, task_type=type(self)) 76 | 77 | def after_evaluation(self, **kwargs): 78 | pass 79 | 80 | def inference_step(self): 81 | raise NotImplementedError 82 | 83 | def evaluation(self, model, data_loader, cuda_enabled=True): 84 | metric_logger = MetricLogger(delimiter=" ") 85 | header = "Evaluation" 86 | # TODO make it configurable 87 | print_freq = 10 88 | 89 | results = [] 90 | 91 | for samples in metric_logger.log_every(data_loader, print_freq, header): 92 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 93 | 94 | eval_output = self.valid_step(model=model, samples=samples) 95 | results.extend(eval_output) 96 | 97 | if is_dist_avail_and_initialized(): 98 | dist.barrier() 99 | 100 | return results 101 | 102 | def train_epoch( 103 | self, 104 | epoch, 105 | model, 106 | data_loader, 107 | optimizer, 108 | lr_scheduler, 109 | scaler=None, 110 | cuda_enabled=False, 111 | log_freq=50, 112 | accum_grad_iters=1, 113 | ): 114 | return self._train_inner_loop( 115 | epoch=epoch, 116 | iters_per_epoch=lr_scheduler.iters_per_epoch, 117 | model=model, 118 | data_loader=data_loader, 119 | optimizer=optimizer, 120 | scaler=scaler, 121 | lr_scheduler=lr_scheduler, 122 | log_freq=log_freq, 123 | cuda_enabled=cuda_enabled, 124 | accum_grad_iters=accum_grad_iters, 125 | ) 126 | 127 | def train_iters( 128 | self, 129 | epoch, 130 | start_iters, 131 | iters_per_inner_epoch, 132 | model, 133 | data_loader, 134 | optimizer, 135 | lr_scheduler, 136 | scaler=None, 137 | cuda_enabled=False, 138 | log_freq=50, 139 | accum_grad_iters=1, 140 | ): 141 | return self._train_inner_loop( 142 | epoch=epoch, 143 | start_iters=start_iters, 144 | iters_per_epoch=iters_per_inner_epoch, 145 | model=model, 146 | data_loader=data_loader, 147 | optimizer=optimizer, 148 | scaler=scaler, 149 | lr_scheduler=lr_scheduler, 150 | log_freq=log_freq, 151 | cuda_enabled=cuda_enabled, 152 | accum_grad_iters=accum_grad_iters, 153 | ) 154 | 155 | def _train_inner_loop( 156 | self, 157 | epoch, 158 | iters_per_epoch, 159 | model, 160 | data_loader, 161 | optimizer, 162 | lr_scheduler, 163 | scaler=None, 164 | start_iters=None, 165 | log_freq=50, 166 | cuda_enabled=False, 167 | accum_grad_iters=1, 168 | ): 169 | """ 170 | An inner training loop compatible with both epoch-based and iter-based training. 171 | 172 | When using epoch-based, training stops after one epoch; when using iter-based, 173 | training stops after #iters_per_epoch iterations. 
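        Note that the optimizer is only stepped every accum_grad_iters iterations,
        so the effective batch size is roughly batch_size * accum_grad_iters * world_size
        under distributed training.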
174 | """ 175 | use_amp = scaler is not None 176 | 177 | if not hasattr(data_loader, "__next__"): 178 | # convert to iterator if not already 179 | data_loader = iter(data_loader) 180 | 181 | metric_logger = MetricLogger(delimiter=" ") 182 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 183 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) 184 | 185 | # if iter-based runner, schedule lr based on inner epoch. 186 | logging.info( 187 | "Start training epoch {}, {} iters per inner epoch.".format( 188 | epoch, iters_per_epoch 189 | ) 190 | ) 191 | header = "Train: data epoch: [{}]".format(epoch) 192 | if start_iters is None: 193 | # epoch-based runner 194 | inner_epoch = epoch 195 | else: 196 | # In iter-based runner, we schedule the learning rate based on iterations. 197 | inner_epoch = start_iters // iters_per_epoch 198 | header = header + "; inner epoch [{}]".format(inner_epoch) 199 | 200 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): 201 | # if using iter-based runner, we stop after iters_per_epoch iterations. 202 | if i >= iters_per_epoch: 203 | break 204 | 205 | samples = next(data_loader) 206 | 207 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 208 | samples.update( 209 | { 210 | "epoch": inner_epoch, 211 | "num_iters_per_epoch": iters_per_epoch, 212 | "iters": i, 213 | } 214 | ) 215 | 216 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) 217 | 218 | with torch.cuda.amp.autocast(enabled=use_amp): 219 | loss = self.train_step(model=model, samples=samples) 220 | 221 | # after_train_step() 222 | if use_amp: 223 | scaler.scale(loss).backward() 224 | else: 225 | loss.backward() 226 | 227 | # update gradients every accum_grad_iters iterations 228 | if (i + 1) % accum_grad_iters == 0: 229 | if use_amp: 230 | scaler.step(optimizer) 231 | scaler.update() 232 | else: 233 | optimizer.step() 234 | optimizer.zero_grad() 235 | 236 | metric_logger.update(loss=loss.item()) 237 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 238 | 239 | # after train_epoch() 240 | # gather the stats from all processes 241 | metric_logger.synchronize_between_processes() 242 | logging.info("Averaged stats: " + str(metric_logger.global_avg())) 243 | return { 244 | k: "{:.3f}".format(meter.global_avg) 245 | for k, meter in metric_logger.meters.items() 246 | } 247 | 248 | @staticmethod 249 | def save_result(result, result_dir, filename, remove_duplicate=""): 250 | import json 251 | 252 | result_file = os.path.join( 253 | result_dir, "%s_rank%d.json" % (filename, get_rank()) 254 | ) 255 | final_result_file = os.path.join(result_dir, "%s.json" % filename) 256 | 257 | json.dump(result, open(result_file, "w")) 258 | 259 | if is_dist_avail_and_initialized(): 260 | dist.barrier() 261 | 262 | if is_main_process(): 263 | logging.warning("rank %d starts merging results." 
% get_rank()) 264 | # combine results from all processes 265 | result = [] 266 | 267 | for rank in range(get_world_size()): 268 | result_file = os.path.join( 269 | result_dir, "%s_rank%d.json" % (filename, rank) 270 | ) 271 | res = json.load(open(result_file, "r")) 272 | result += res 273 | 274 | if remove_duplicate: 275 | result_new = [] 276 | id_list = [] 277 | for res in result: 278 | if res[remove_duplicate] not in id_list: 279 | id_list.append(res[remove_duplicate]) 280 | result_new.append(res) 281 | result = result_new 282 | 283 | json.dump(result, open(final_result_file, "w")) 284 | print("result file saved to %s" % final_result_file) 285 | 286 | return final_result_file 287 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/models/transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Portions Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Code modified from 9 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ; 10 | # https://github.com/facebookresearch/deit/blob/main/models.py 11 | # and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py 12 | 13 | 14 | from functools import partial 15 | from typing import Callable, List, Optional 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.checkpoint as checkpoint 20 | from timm.models.layers import DropPath, trunc_normal_ 21 | 22 | 23 | class Attention(nn.Module): 24 | def __init__( 25 | self, 26 | dim, 27 | num_heads=8, 28 | qkv_bias=False, 29 | qk_scale=None, 30 | attn_drop=0.0, 31 | proj_drop=0.0, 32 | ): 33 | super().__init__() 34 | self.num_heads = num_heads 35 | head_dim = dim // num_heads 36 | # NOTE scale factor was wrong in my original version, 37 | # can set manually to be compat with prev weights 38 | self.scale = qk_scale or head_dim**-0.5 39 | 40 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 41 | self.attn_drop = nn.Dropout(attn_drop) 42 | self.proj = nn.Linear(dim, dim) 43 | self.proj_drop = nn.Dropout(proj_drop) 44 | 45 | def forward(self, x): 46 | B, N, C = x.shape 47 | qkv = ( 48 | self.qkv(x) 49 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 50 | .permute(2, 0, 3, 1, 4) 51 | ) 52 | q, k, v = ( 53 | qkv[0], 54 | qkv[1], 55 | qkv[2], 56 | ) # make torchscript happy (cannot use tensor as tuple) 57 | 58 | attn = (q @ k.transpose(-2, -1)) * self.scale 59 | attn = attn.softmax(dim=-1) 60 | attn = self.attn_drop(attn) 61 | 62 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 63 | x = self.proj(x) 64 | x = self.proj_drop(x) 65 | return x 66 | 67 | 68 | class Mlp(nn.Module): 69 | def __init__( 70 | self, 71 | in_features, 72 | hidden_features=None, 73 | out_features=None, 74 | act_layer=nn.GELU, 75 | drop=0.0, 76 | ): 77 | super().__init__() 78 | out_features = out_features or in_features 79 | hidden_features = hidden_features or in_features 80 | self.fc1 = nn.Linear(in_features, hidden_features) 81 | self.act = act_layer() 82 | self.fc2 = nn.Linear(hidden_features, out_features) 83 | self.drop = nn.Dropout(drop) 84 | 85 | def forward(self, x): 86 | x = self.fc1(x) 87 | x = self.act(x) 88 | x = self.drop(x) 89 | x = self.fc2(x) 90 | x = self.drop(x) 91 | return x 92 | 93 | 94 | class 
MultiheadAttention(nn.MultiheadAttention): 95 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 96 | return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 97 | 98 | 99 | class ViTAttention(Attention): 100 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 101 | assert attn_mask is None 102 | return super().forward(x) 103 | 104 | 105 | class BlockWithMasking(nn.Module): 106 | def __init__( 107 | self, 108 | dim: int, 109 | attn_target: Callable, 110 | mlp_ratio: int = 4, 111 | act_layer: Callable = nn.GELU, 112 | norm_layer: Callable = nn.LayerNorm, 113 | ffn_dropout_rate: float = 0.0, 114 | drop_path: float = 0.0, 115 | layer_scale_type: Optional[str] = None, 116 | layer_scale_init_value: float = 1e-4, 117 | ): 118 | super().__init__() 119 | 120 | assert not isinstance( 121 | attn_target, nn.Module 122 | ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" 123 | self.attn = attn_target() 124 | if drop_path > 0.0: 125 | self.drop_path = DropPath(drop_path) 126 | else: 127 | self.drop_path = nn.Identity() 128 | self.norm_1 = norm_layer(dim) 129 | mlp_hidden_dim = int(mlp_ratio * dim) 130 | self.mlp = Mlp( 131 | in_features=dim, 132 | hidden_features=mlp_hidden_dim, 133 | act_layer=act_layer, 134 | drop=ffn_dropout_rate, 135 | ) 136 | self.norm_2 = norm_layer(dim) 137 | self.layer_scale_type = layer_scale_type 138 | if self.layer_scale_type is not None: 139 | assert self.layer_scale_type in [ 140 | "per_channel", 141 | "scalar", 142 | ], f"Found Layer scale type {self.layer_scale_type}" 143 | if self.layer_scale_type == "per_channel": 144 | # one gamma value per channel 145 | gamma_shape = [1, 1, dim] 146 | elif self.layer_scale_type == "scalar": 147 | # single gamma value for all channels 148 | gamma_shape = [1, 1, 1] 149 | # two gammas: for each part of the fwd in the encoder 150 | self.layer_scale_gamma1 = nn.Parameter( 151 | torch.ones(size=gamma_shape) * layer_scale_init_value, 152 | requires_grad=True, 153 | ) 154 | self.layer_scale_gamma2 = nn.Parameter( 155 | torch.ones(size=gamma_shape) * layer_scale_init_value, 156 | requires_grad=True, 157 | ) 158 | 159 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 160 | if self.layer_scale_type is None: 161 | x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) 162 | x = x + self.drop_path(self.mlp(self.norm_2(x))) 163 | else: 164 | x = ( 165 | x 166 | + self.drop_path(self.attn(self.norm_1(x), attn_mask)) 167 | * self.layer_scale_gamma1 168 | ) 169 | x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 170 | return x 171 | 172 | 173 | _LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) 174 | 175 | 176 | class SimpleTransformer(nn.Module): 177 | def __init__( 178 | self, 179 | attn_target: Callable, 180 | embed_dim: int, 181 | num_blocks: int, 182 | block: Callable = BlockWithMasking, 183 | pre_transformer_layer: Optional[Callable] = None, 184 | post_transformer_layer: Optional[Callable] = None, 185 | drop_path_rate: float = 0.0, 186 | drop_path_type: str = "progressive", 187 | norm_layer: Callable = _LAYER_NORM, 188 | mlp_ratio: int = 4, 189 | ffn_dropout_rate: float = 0.0, 190 | layer_scale_type: Optional[str] = None, # from cait; possible values are None, "per_channel", "scalar" 191 | layer_scale_init_value: float = 1e-4, # from cait; float 192 | weight_init_style: str = "jax", # possible values jax or pytorch 193 | ): 194 | """ 195 | Simple Transformer with the following features 196 | 1. Supports masked attention 197 | 2. 
Supports DropPath 198 | 3. Supports LayerScale 199 | 4. Supports Dropout in Attention and FFN 200 | 5. Makes few assumptions about the input except that it is a Tensor 201 | """ 202 | super().__init__() 203 | self.pre_transformer_layer = pre_transformer_layer 204 | if drop_path_type == "progressive": 205 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] 206 | elif drop_path_type == "uniform": 207 | dpr = [drop_path_rate for i in range(num_blocks)] 208 | else: 209 | raise ValueError(f"Unknown drop_path_type: {drop_path_type}") 210 | 211 | self.blocks = nn.Sequential( 212 | *[ 213 | block( 214 | dim=embed_dim, 215 | attn_target=attn_target, 216 | mlp_ratio=mlp_ratio, 217 | ffn_dropout_rate=ffn_dropout_rate, 218 | drop_path=dpr[i], 219 | norm_layer=norm_layer, 220 | layer_scale_type=layer_scale_type, 221 | layer_scale_init_value=layer_scale_init_value, 222 | ) 223 | for i in range(num_blocks) 224 | ] 225 | ) 226 | self.post_transformer_layer = post_transformer_layer 227 | self.weight_init_style = weight_init_style 228 | self.apply(self._init_weights) 229 | 230 | def _init_weights(self, m): 231 | if isinstance(m, nn.Linear): 232 | if self.weight_init_style == "jax": 233 | # Based on MAE and official Jax ViT implementation 234 | torch.nn.init.xavier_uniform_(m.weight) 235 | elif self.weight_init_style == "pytorch": 236 | # PyTorch ViT uses trunc_normal_ 237 | trunc_normal_(m.weight, std=0.02) 238 | 239 | if m.bias is not None: 240 | nn.init.constant_(m.bias, 0) 241 | elif isinstance(m, (nn.LayerNorm)): 242 | nn.init.constant_(m.bias, 0) 243 | nn.init.constant_(m.weight, 1.0) 244 | 245 | def forward( 246 | self, 247 | tokens: torch.Tensor, 248 | attn_mask: torch.Tensor = None, 249 | use_checkpoint: bool = False, 250 | checkpoint_every_n: int = 1, 251 | checkpoint_blk_ids: Optional[List[int]] = None, 252 | ): 253 | """ 254 | Inputs 255 | - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) 256 | - attn: mask of shape L x L 257 | 258 | Output 259 | - x: data of shape N x L x D (or L x N x D depending on the attention implementation) 260 | """ 261 | if self.pre_transformer_layer: 262 | tokens = self.pre_transformer_layer(tokens) 263 | if use_checkpoint and checkpoint_blk_ids is None: 264 | checkpoint_blk_ids = [ 265 | blk_id 266 | for blk_id in range(len(self.blocks)) 267 | if blk_id % checkpoint_every_n == 0 268 | ] 269 | if checkpoint_blk_ids: 270 | checkpoint_blk_ids = set(checkpoint_blk_ids) 271 | for blk_id, blk in enumerate(self.blocks): 272 | if use_checkpoint and blk_id in checkpoint_blk_ids: 273 | tokens = checkpoint.checkpoint( 274 | blk, tokens, attn_mask, use_reentrant=False 275 | ) 276 | else: 277 | tokens = blk(tokens, attn_mask=attn_mask) 278 | if self.post_transformer_layer: 279 | tokens = self.post_transformer_layer(tokens) 280 | return tokens 281 | -------------------------------------------------------------------------------- /video_llama/common/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | class Registry: 10 | mapping = { 11 | "builder_name_mapping": {}, 12 | "task_name_mapping": {}, 13 | "processor_name_mapping": {}, 14 | "model_name_mapping": {}, 15 | "lr_scheduler_name_mapping": {}, 16 | "runner_name_mapping": {}, 17 | "state": {}, 18 | "paths": {}, 19 | } 20 | 21 | @classmethod 22 | def register_builder(cls, name): 23 | r"""Register a dataset builder to registry with key 'name' 24 | 25 | Args: 26 | name: Key with which the builder will be registered. 27 | 28 | Usage: 29 | 30 | from video_llama.common.registry import registry 31 | from video_llama.datasets.base_dataset_builder import BaseDatasetBuilder 32 | """ 33 | 34 | def wrap(builder_cls): 35 | from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder 36 | 37 | assert issubclass( 38 | builder_cls, BaseDatasetBuilder 39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 40 | builder_cls 41 | ) 42 | if name in cls.mapping["builder_name_mapping"]: 43 | raise KeyError( 44 | "Name '{}' already registered for {}.".format( 45 | name, cls.mapping["builder_name_mapping"][name] 46 | ) 47 | ) 48 | cls.mapping["builder_name_mapping"][name] = builder_cls 49 | return builder_cls 50 | 51 | return wrap 52 | 53 | @classmethod 54 | def register_task(cls, name): 55 | r"""Register a task to registry with key 'name' 56 | 57 | Args: 58 | name: Key with which the task will be registered. 59 | 60 | Usage: 61 | 62 | from video_llama.common.registry import registry 63 | """ 64 | 65 | def wrap(task_cls): 66 | from video_llama.tasks.base_task import BaseTask 67 | 68 | assert issubclass( 69 | task_cls, BaseTask 70 | ), "All tasks must inherit BaseTask class" 71 | if name in cls.mapping["task_name_mapping"]: 72 | raise KeyError( 73 | "Name '{}' already registered for {}.".format( 74 | name, cls.mapping["task_name_mapping"][name] 75 | ) 76 | ) 77 | cls.mapping["task_name_mapping"][name] = task_cls 78 | return task_cls 79 | 80 | return wrap 81 | 82 | @classmethod 83 | def register_model(cls, name): 84 | r"""Register a task to registry with key 'name' 85 | 86 | Args: 87 | name: Key with which the task will be registered. 88 | 89 | Usage: 90 | 91 | from video_llama.common.registry import registry 92 | """ 93 | 94 | def wrap(model_cls): 95 | from video_llama.models import BaseModel 96 | 97 | assert issubclass( 98 | model_cls, BaseModel 99 | ), "All models must inherit BaseModel class" 100 | if name in cls.mapping["model_name_mapping"]: 101 | raise KeyError( 102 | "Name '{}' already registered for {}.".format( 103 | name, cls.mapping["model_name_mapping"][name] 104 | ) 105 | ) 106 | cls.mapping["model_name_mapping"][name] = model_cls 107 | return model_cls 108 | 109 | return wrap 110 | 111 | @classmethod 112 | def register_processor(cls, name): 113 | r"""Register a processor to registry with key 'name' 114 | 115 | Args: 116 | name: Key with which the task will be registered. 
117 | 118 | Usage: 119 | 120 | from video_llama.common.registry import registry 121 | """ 122 | 123 | def wrap(processor_cls): 124 | from video_llama.processors import BaseProcessor 125 | 126 | assert issubclass( 127 | processor_cls, BaseProcessor 128 | ), "All processors must inherit BaseProcessor class" 129 | if name in cls.mapping["processor_name_mapping"]: 130 | raise KeyError( 131 | "Name '{}' already registered for {}.".format( 132 | name, cls.mapping["processor_name_mapping"][name] 133 | ) 134 | ) 135 | cls.mapping["processor_name_mapping"][name] = processor_cls 136 | return processor_cls 137 | 138 | return wrap 139 | 140 | @classmethod 141 | def register_lr_scheduler(cls, name): 142 | r"""Register a model to registry with key 'name' 143 | 144 | Args: 145 | name: Key with which the task will be registered. 146 | 147 | Usage: 148 | 149 | from video_llama.common.registry import registry 150 | """ 151 | 152 | def wrap(lr_sched_cls): 153 | if name in cls.mapping["lr_scheduler_name_mapping"]: 154 | raise KeyError( 155 | "Name '{}' already registered for {}.".format( 156 | name, cls.mapping["lr_scheduler_name_mapping"][name] 157 | ) 158 | ) 159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls 160 | return lr_sched_cls 161 | 162 | return wrap 163 | 164 | @classmethod 165 | def register_runner(cls, name): 166 | r"""Register a model to registry with key 'name' 167 | 168 | Args: 169 | name: Key with which the task will be registered. 170 | 171 | Usage: 172 | 173 | from video_llama.common.registry import registry 174 | """ 175 | 176 | def wrap(runner_cls): 177 | if name in cls.mapping["runner_name_mapping"]: 178 | raise KeyError( 179 | "Name '{}' already registered for {}.".format( 180 | name, cls.mapping["runner_name_mapping"][name] 181 | ) 182 | ) 183 | cls.mapping["runner_name_mapping"][name] = runner_cls 184 | return runner_cls 185 | 186 | return wrap 187 | 188 | @classmethod 189 | def register_path(cls, name, path): 190 | r"""Register a path to registry with key 'name' 191 | 192 | Args: 193 | name: Key with which the path will be registered. 194 | 195 | Usage: 196 | 197 | from video_llama.common.registry import registry 198 | """ 199 | assert isinstance(path, str), "All path must be str." 200 | if name in cls.mapping["paths"]: 201 | raise KeyError("Name '{}' already registered.".format(name)) 202 | cls.mapping["paths"][name] = path 203 | 204 | @classmethod 205 | def register(cls, name, obj): 206 | r"""Register an item to registry with key 'name' 207 | 208 | Args: 209 | name: Key with which the item will be registered. 
210 | 211 | Usage:: 212 | 213 | from video_llama.common.registry import registry 214 | 215 | registry.register("config", {}) 216 | """ 217 | path = name.split(".") 218 | current = cls.mapping["state"] 219 | 220 | for part in path[:-1]: 221 | if part not in current: 222 | current[part] = {} 223 | current = current[part] 224 | 225 | current[path[-1]] = obj 226 | 227 | # @classmethod 228 | # def get_trainer_class(cls, name): 229 | # return cls.mapping["trainer_name_mapping"].get(name, None) 230 | 231 | @classmethod 232 | def get_builder_class(cls, name): 233 | return cls.mapping["builder_name_mapping"].get(name, None) 234 | 235 | @classmethod 236 | def get_model_class(cls, name): 237 | return cls.mapping["model_name_mapping"].get(name, None) 238 | 239 | @classmethod 240 | def get_task_class(cls, name): 241 | return cls.mapping["task_name_mapping"].get(name, None) 242 | 243 | @classmethod 244 | def get_processor_class(cls, name): 245 | return cls.mapping["processor_name_mapping"].get(name, None) 246 | 247 | @classmethod 248 | def get_lr_scheduler_class(cls, name): 249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None) 250 | 251 | @classmethod 252 | def get_runner_class(cls, name): 253 | return cls.mapping["runner_name_mapping"].get(name, None) 254 | 255 | @classmethod 256 | def list_runners(cls): 257 | return sorted(cls.mapping["runner_name_mapping"].keys()) 258 | 259 | @classmethod 260 | def list_models(cls): 261 | return sorted(cls.mapping["model_name_mapping"].keys()) 262 | 263 | @classmethod 264 | def list_tasks(cls): 265 | return sorted(cls.mapping["task_name_mapping"].keys()) 266 | 267 | @classmethod 268 | def list_processors(cls): 269 | return sorted(cls.mapping["processor_name_mapping"].keys()) 270 | 271 | @classmethod 272 | def list_lr_schedulers(cls): 273 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) 274 | 275 | @classmethod 276 | def list_datasets(cls): 277 | return sorted(cls.mapping["builder_name_mapping"].keys()) 278 | 279 | @classmethod 280 | def get_path(cls, name): 281 | return cls.mapping["paths"].get(name, None) 282 | 283 | @classmethod 284 | def get(cls, name, default=None, no_warning=False): 285 | r"""Get an item from registry with key 'name' 286 | 287 | Args: 288 | name (string): Key whose value needs to be retrieved. 289 | default: If passed and key is not in registry, default value will 290 | be returned with a warning. Default: None 291 | no_warning (bool): If passed as True, warning when key doesn't exist 292 | will not be generated. Useful for MMF's 293 | internal operations. Default: False 294 | """ 295 | original_name = name 296 | name = name.split(".") 297 | value = cls.mapping["state"] 298 | for subname in name: 299 | value = value.get(subname, default) 300 | if value is default: 301 | break 302 | 303 | if ( 304 | "writer" in cls.mapping["state"] 305 | and value == default 306 | and no_warning is False 307 | ): 308 | cls.mapping["state"]["writer"].warning( 309 | "Key {} is not present in registry, returning default value " 310 | "of {}".format(original_name, default) 311 | ) 312 | return value 313 | 314 | @classmethod 315 | def unregister(cls, name): 316 | r"""Remove an item from registry with key 'name' 317 | 318 | Args: 319 | name: Key which needs to be removed. 
320 | Usage:: 321 | 322 | from mmf.common.registry import registry 323 | 324 | config = registry.unregister("config") 325 | """ 326 | return cls.mapping["state"].pop(name, None) 327 | 328 | 329 | registry = Registry() 330 | -------------------------------------------------------------------------------- /video_llama/models/ImageBind/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Portions Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchaudio 14 | from PIL import Image 15 | from pytorchvideo import transforms as pv_transforms 16 | from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler 17 | from pytorchvideo.data.encoded_video import EncodedVideo 18 | from torchvision import transforms 19 | from torchvision.transforms._transforms_video import NormalizeVideo 20 | 21 | from .models.multimodal_preprocessors import SimpleTokenizer 22 | 23 | DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds 24 | 25 | BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz" 26 | 27 | 28 | def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length): 29 | # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102 30 | waveform -= waveform.mean() 31 | fbank = torchaudio.compliance.kaldi.fbank( 32 | waveform, 33 | htk_compat=True, 34 | sample_frequency=sample_rate, 35 | use_energy=False, 36 | window_type="hanning", 37 | num_mel_bins=num_mel_bins, 38 | dither=0.0, 39 | frame_length=25, 40 | frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, 41 | ) 42 | # Convert to [mel_bins, num_frames] shape 43 | fbank = fbank.transpose(0, 1) 44 | # Pad to target_length 45 | n_frames = fbank.size(1) 46 | p = target_length - n_frames 47 | # if p is too large (say >20%), flash a warning 48 | if abs(p) / n_frames > 0.2: 49 | logging.warning( 50 | "Large gap between audio n_frames(%d) and " 51 | "target_length (%d). 
Is the audio_target_length " 52 | "setting correct?", 53 | n_frames, 54 | target_length, 55 | ) 56 | # cut and pad 57 | if p > 0: 58 | fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0) 59 | elif p < 0: 60 | fbank = fbank[:, 0:target_length] 61 | # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 62 | # channel image 63 | fbank = fbank.unsqueeze(0) 64 | return fbank 65 | 66 | 67 | def get_clip_timepoints(clip_sampler, duration): 68 | # Read out all clips in this video 69 | all_clips_timepoints = [] 70 | is_last_clip = False 71 | end = 0.0 72 | while not is_last_clip: 73 | start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) 74 | all_clips_timepoints.append((start, end)) 75 | return all_clips_timepoints 76 | 77 | 78 | def load_and_transform_vision_data(image_paths, device): 79 | if image_paths is None: 80 | return None 81 | 82 | image_ouputs = [] 83 | for image_path in image_paths: 84 | data_transform = transforms.Compose( 85 | [ 86 | transforms.Resize( 87 | 224, interpolation=transforms.InterpolationMode.BICUBIC 88 | ), 89 | transforms.CenterCrop(224), 90 | transforms.ToTensor(), 91 | transforms.Normalize( 92 | mean=(0.48145466, 0.4578275, 0.40821073), 93 | std=(0.26862954, 0.26130258, 0.27577711), 94 | ), 95 | ] 96 | ) 97 | with open(image_path, "rb") as fopen: 98 | image = Image.open(fopen).convert("RGB") 99 | 100 | image = data_transform(image).to(device) 101 | image_ouputs.append(image) 102 | return torch.stack(image_ouputs, dim=0) 103 | 104 | 105 | def load_and_transform_text(text, device): 106 | if text is None: 107 | return None 108 | tokenizer = SimpleTokenizer(bpe_path=BPE_PATH) 109 | tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text] 110 | tokens = torch.cat(tokens, dim=0) 111 | return tokens 112 | 113 | 114 | def load_and_transform_audio_data( 115 | audio_paths, 116 | device, 117 | num_mel_bins=128, 118 | target_length=204, 119 | sample_rate=16000, 120 | clip_duration=2, 121 | clips_per_video=3, 122 | mean=-4.268, 123 | std=9.138, 124 | ): 125 | if audio_paths is None: 126 | return None 127 | 128 | audio_outputs = [] 129 | clip_sampler = ConstantClipsPerVideoSampler( 130 | clip_duration=clip_duration, clips_per_video=clips_per_video 131 | ) 132 | 133 | for audio_path in audio_paths: 134 | waveform, sr = torchaudio.load(audio_path) 135 | if sample_rate != sr: 136 | waveform = torchaudio.functional.resample( 137 | waveform, orig_freq=sr, new_freq=sample_rate 138 | ) 139 | all_clips_timepoints = get_clip_timepoints( 140 | clip_sampler, waveform.size(1) / sample_rate 141 | ) 142 | all_clips = [] 143 | for clip_timepoints in all_clips_timepoints: 144 | waveform_clip = waveform[ 145 | :, 146 | int(clip_timepoints[0] * sample_rate) : int( 147 | clip_timepoints[1] * sample_rate 148 | ), 149 | ] 150 | waveform_melspec = waveform2melspec( 151 | waveform_clip, sample_rate, num_mel_bins, target_length 152 | ) 153 | all_clips.append(waveform_melspec) 154 | 155 | normalize = transforms.Normalize(mean=mean, std=std) 156 | all_clips = [normalize(ac).to(device) for ac in all_clips] 157 | 158 | all_clips = torch.stack(all_clips, dim=0) 159 | audio_outputs.append(all_clips) 160 | 161 | return torch.stack(audio_outputs, dim=0) 162 | 163 | 164 | def crop_boxes(boxes, x_offset, y_offset): 165 | """ 166 | Perform crop on the bounding boxes given the offsets. 167 | Args: 168 | boxes (ndarray or None): bounding boxes to perform crop. The dimension 169 | is `num boxes` x 4. 170 | x_offset (int): cropping offset in the x axis. 
171 | y_offset (int): cropping offset in the y axis. 172 | Returns: 173 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 174 | `num boxes` x 4. 175 | """ 176 | cropped_boxes = boxes.copy() 177 | cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset 178 | cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset 179 | 180 | return cropped_boxes 181 | 182 | 183 | def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): 184 | """ 185 | Perform uniform spatial sampling on the images and corresponding boxes. 186 | Args: 187 | images (tensor): images to perform uniform crop. The dimension is 188 | `num frames` x `channel` x `height` x `width`. 189 | size (int): size of height and weight to crop the images. 190 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width 191 | is larger than height. Or 0, 1, or 2 for top, center, and bottom 192 | crop if height is larger than width. 193 | boxes (ndarray or None): optional. Corresponding boxes to images. 194 | Dimension is `num boxes` x 4. 195 | scale_size (int): optinal. If not None, resize the images to scale_size before 196 | performing any crop. 197 | Returns: 198 | cropped (tensor): images with dimension of 199 | `num frames` x `channel` x `size` x `size`. 200 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 201 | `num boxes` x 4. 202 | """ 203 | assert spatial_idx in [0, 1, 2] 204 | ndim = len(images.shape) 205 | if ndim == 3: 206 | images = images.unsqueeze(0) 207 | height = images.shape[2] 208 | width = images.shape[3] 209 | 210 | if scale_size is not None: 211 | if width <= height: 212 | width, height = scale_size, int(height / width * scale_size) 213 | else: 214 | width, height = int(width / height * scale_size), scale_size 215 | images = torch.nn.functional.interpolate( 216 | images, 217 | size=(height, width), 218 | mode="bilinear", 219 | align_corners=False, 220 | ) 221 | 222 | y_offset = int(math.ceil((height - size) / 2)) 223 | x_offset = int(math.ceil((width - size) / 2)) 224 | 225 | if height > width: 226 | if spatial_idx == 0: 227 | y_offset = 0 228 | elif spatial_idx == 2: 229 | y_offset = height - size 230 | else: 231 | if spatial_idx == 0: 232 | x_offset = 0 233 | elif spatial_idx == 2: 234 | x_offset = width - size 235 | cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] 236 | cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None 237 | if ndim == 3: 238 | cropped = cropped.squeeze(0) 239 | return cropped, cropped_boxes 240 | 241 | 242 | class SpatialCrop(nn.Module): 243 | """ 244 | Convert the video into 3 smaller clips spatially. Must be used after the 245 | temporal crops to get spatial crops, and should be used with 246 | -2 in the spatial crop at the slowfast augmentation stage (so full 247 | frames are passed in here). Will return a larger list with the 248 | 3x spatial crops as well. 249 | """ 250 | 251 | def __init__(self, crop_size: int = 224, num_crops: int = 3): 252 | super().__init__() 253 | self.crop_size = crop_size 254 | if num_crops == 3: 255 | self.crops_to_ext = [0, 1, 2] 256 | self.flipped_crops_to_ext = [] 257 | elif num_crops == 1: 258 | self.crops_to_ext = [1] 259 | self.flipped_crops_to_ext = [] 260 | else: 261 | raise NotImplementedError("Nothing else supported yet") 262 | 263 | def forward(self, videos): 264 | """ 265 | Args: 266 | videos: A list of C, T, H, W videos. 267 | Returns: 268 | videos: A list with 3x the number of elements. 
Each video converted 269 | to C, T, H', W' by spatial cropping. 270 | """ 271 | assert isinstance(videos, list), "Must be a list of videos after temporal crops" 272 | assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" 273 | res = [] 274 | for video in videos: 275 | for spatial_idx in self.crops_to_ext: 276 | res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) 277 | if not self.flipped_crops_to_ext: 278 | continue 279 | flipped_video = transforms.functional.hflip(video) 280 | for spatial_idx in self.flipped_crops_to_ext: 281 | res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) 282 | return res 283 | 284 | 285 | def load_and_transform_video_data( 286 | video_paths, 287 | device, 288 | clip_duration=2, 289 | clips_per_video=5, 290 | sample_rate=16000, 291 | ): 292 | if video_paths is None: 293 | return None 294 | 295 | video_outputs = [] 296 | video_transform = transforms.Compose( 297 | [ 298 | pv_transforms.ShortSideScale(224), 299 | NormalizeVideo( 300 | mean=(0.48145466, 0.4578275, 0.40821073), 301 | std=(0.26862954, 0.26130258, 0.27577711), 302 | ), 303 | ] 304 | ) 305 | 306 | clip_sampler = ConstantClipsPerVideoSampler( 307 | clip_duration=clip_duration, clips_per_video=clips_per_video 308 | ) 309 | frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) 310 | 311 | for video_path in video_paths: 312 | video = EncodedVideo.from_path( 313 | video_path, 314 | decoder="decord", 315 | decode_audio=False, 316 | **{"sample_rate": sample_rate}, 317 | ) 318 | 319 | all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) 320 | 321 | all_video = [] 322 | for clip_timepoints in all_clips_timepoints: 323 | # Read the clip, get frames 324 | clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) 325 | if clip is None: 326 | raise ValueError("No clip found") 327 | video_clip = frame_sampler(clip["video"]) 328 | video_clip = video_clip / 255.0 # since this is float, need 0-1 329 | 330 | all_video.append(video_clip) 331 | 332 | all_video = [video_transform(clip) for clip in all_video] 333 | all_video = SpatialCrop(224, num_crops=3)(all_video) 334 | 335 | all_video = torch.stack(all_video, dim=0) 336 | video_outputs.append(all_video) 337 | 338 | return torch.stack(video_outputs, dim=0).to(device) 339 | -------------------------------------------------------------------------------- /video_llama/datasets/datasets/llava_instruct_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from video_llama.datasets.datasets.base_dataset import BaseDataset 3 | from video_llama.datasets.datasets.caption_datasets import CaptionDataset 4 | import pandas as pd 5 | import decord 6 | from decord import VideoReader 7 | import random 8 | import torch 9 | from torch.utils.data.dataloader import default_collate 10 | from PIL import Image 11 | from typing import Dict, Optional, Sequence 12 | import transformers 13 | import pathlib 14 | import json 15 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer 16 | from video_llama.conversation.conversation_video import Conversation,SeparatorStyle 17 | DEFAULT_IMAGE_PATCH_TOKEN = '' 18 | DEFAULT_IMAGE_TOKEN = "" 19 | import copy 20 | from video_llama.processors import transforms_video,AlproVideoTrainProcessor 21 | IGNORE_INDEX = -100 22 | image_conversation = Conversation( 23 | system="", 24 | roles=("Human", "Assistant"), 25 | messages=[], 26 | offset=0, 27 | sep_style=SeparatorStyle.SINGLE, 28 | 
sep="###", 29 | ) 30 | llama_v2_image_conversation = Conversation( 31 | system=" ", 32 | roles=("USER", "ASSISTANT"), 33 | messages=(), 34 | offset=0, 35 | sep_style=SeparatorStyle.LLAMA_2, 36 | sep="", 37 | sep2="", 38 | ) 39 | IGNORE_INDEX = -100 40 | 41 | class Instruct_Dataset(BaseDataset): 42 | def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'image', model_type='vicuna'): 43 | """ 44 | vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/) 45 | ann_root (string): Root directory of video (e.g. webvid_eval/annotations/) 46 | split (string): val or test 47 | """ 48 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 49 | 50 | data_path = pathlib.Path(ann_root) 51 | with data_path.open(encoding='utf-8') as f: 52 | self.annotation = json.load(f) 53 | 54 | self.vis_root = vis_root 55 | self.resize_size = 224 56 | self.num_frm = 8 57 | self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False) 58 | self.tokenizer.pad_token = self.tokenizer.unk_token 59 | self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 60 | self.num_video_query_token = num_video_query_token 61 | self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN] 62 | 63 | self.transform = AlproVideoTrainProcessor( 64 | image_size=self.resize_size, n_frms = self.num_frm 65 | ).transform 66 | self.data_type = data_type 67 | self.model_type = model_type 68 | 69 | def _get_image_path(self, sample): 70 | # rel_video_fp ='COCO_train2014_' + sample['image'] 71 | # full_video_fp = os.path.join(self.vis_root, rel_video_fp) 72 | full_video_fp = sample['image_name'] 73 | return full_video_fp 74 | 75 | def __getitem__(self, index): 76 | num_retries = 10 # skip error videos 77 | for _ in range(num_retries): 78 | try: 79 | sample = self.annotation[index] 80 | 81 | image_path = self._get_image_path(sample) 82 | conversation_list = sample['conversations'] 83 | image = Image.open(image_path).convert("RGB") 84 | 85 | image = self.vis_processor(image) 86 | # text = self.text_processor(text) 87 | sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token) 88 | if self.model_type =='vicuna': 89 | data_dict = preprocess( 90 | sources, 91 | self.tokenizer) 92 | elif self.model_type =='llama_v2': 93 | data_dict = preprocess_for_llama_v2( 94 | sources, 95 | self.tokenizer) 96 | else: 97 | print('not support') 98 | raise('not support') 99 | data_dict = dict(input_ids=data_dict["input_ids"][0], 100 | labels=data_dict["labels"][0]) 101 | 102 | # image exist in the data 103 | data_dict['image'] = image 104 | except: 105 | print(f"Failed to load examples with image: {image_path}. 
" 106 | f"Will randomly sample an example as a replacement.") 107 | index = random.randint(0, len(self) - 1) 108 | continue 109 | break 110 | else: 111 | raise RuntimeError(f"Failed to fetch image after {num_retries} retries.") 112 | # "image_id" is kept to stay compatible with the COCO evaluation format 113 | return { 114 | "image": image, 115 | "text_input": data_dict["input_ids"], 116 | "labels": data_dict["labels"], 117 | "type":'image', 118 | } 119 | 120 | def __len__(self): 121 | return len(self.annotation) 122 | 123 | def collater(self, instances): 124 | input_ids, labels = tuple([instance[key] for instance in instances] 125 | for key in ("text_input", "labels")) 126 | input_ids = torch.nn.utils.rnn.pad_sequence( 127 | input_ids, 128 | batch_first=True, 129 | padding_value=self.tokenizer.pad_token_id) 130 | labels = torch.nn.utils.rnn.pad_sequence(labels, 131 | batch_first=True, 132 | padding_value=IGNORE_INDEX) 133 | batch = dict( 134 | input_ids=input_ids, 135 | labels=labels, 136 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 137 | ) 138 | 139 | if 'image' in instances[0]: 140 | images = [instance['image'] for instance in instances] 141 | if all(x is not None and x.shape == images[0].shape for x in images): 142 | batch['images'] = torch.stack(images) 143 | else: 144 | batch['images'] = images 145 | batch['conv_type'] = 'multi' 146 | return batch 147 | 148 | 149 | def preprocess_multimodal( 150 | conversation_list: Sequence[str], 151 | multimodal_cfg: dict, 152 | cur_token_len: int, 153 | ) -> Dict: 154 | # 将conversational list中 155 | is_multimodal = True 156 | # image_token_len = multimodal_cfg['image_token_len'] 157 | image_token_len = cur_token_len 158 | 159 | for sentence in conversation_list: 160 | replace_token = ''+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len+'' 161 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) 162 | 163 | return [conversation_list] 164 | 165 | def _add_speaker_and_signal(header, source, get_conversation=True): 166 | """Add speaker and start/end signal on each round.""" 167 | BEGIN_SIGNAL = "###" 168 | END_SIGNAL = "\n" 169 | conversation = header 170 | for sentence in source: 171 | from_str = sentence["from"] 172 | if from_str.lower() == "human": 173 | from_str = image_conversation.roles[0] 174 | elif from_str.lower() == "gpt": 175 | from_str = image_conversation.roles[1] 176 | else: 177 | from_str = 'unknown' 178 | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + 179 | sentence["value"] + END_SIGNAL) 180 | if get_conversation: 181 | conversation += sentence["value"] 182 | conversation += BEGIN_SIGNAL 183 | return conversation 184 | 185 | def _tokenize_fn(strings: Sequence[str], 186 | tokenizer: transformers.PreTrainedTokenizer) -> Dict: 187 | """Tokenize a list of strings.""" 188 | tokenized_list = [ 189 | tokenizer( 190 | text, 191 | return_tensors="pt", 192 | padding="longest", 193 | max_length=512, 194 | truncation=True, 195 | ) for text in strings 196 | ] 197 | input_ids = labels = [ 198 | tokenized.input_ids[0] for tokenized in tokenized_list 199 | ] 200 | input_ids_lens = labels_lens = [ 201 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() 202 | for tokenized in tokenized_list 203 | ] 204 | return dict( 205 | input_ids=input_ids, 206 | labels=labels, 207 | input_ids_lens=input_ids_lens, 208 | labels_lens=labels_lens, 209 | ) 210 | 211 | def preprocess( 212 | sources: Sequence[str], 213 | tokenizer: transformers.PreTrainedTokenizer, 214 | ) -> Dict: 215 | """ 216 | Given a list of 
sources, each is a conversation list. This transform: 217 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 218 | 2. Concatenate conversations together; 219 | 3. Tokenize the concatenated conversation; 220 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 221 | """ 222 | # add end signal and concatenate together 223 | conversations = [] 224 | for source in sources: 225 | header = f"{image_conversation.system}\n\n" 226 | conversation = _add_speaker_and_signal(header, source) 227 | conversations.append(conversation) 228 | # tokenize conversations 229 | conversations_tokenized = _tokenize_fn(conversations, tokenizer) 230 | input_ids = conversations_tokenized["input_ids"] 231 | targets = copy.deepcopy(input_ids) 232 | for target, source in zip(targets, sources): 233 | tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], 234 | tokenizer)["input_ids_lens"] 235 | speakers = [sentence["from"] for sentence in source] 236 | _mask_targets(target, tokenized_lens, speakers) 237 | 238 | return dict(input_ids=input_ids, labels=targets) 239 | 240 | def preprocess_for_llama_v2( 241 | sources: Sequence[str], 242 | tokenizer: transformers.PreTrainedTokenizer, 243 | ) -> Dict: 244 | """ 245 | Given a list of sources, each is a conversation list. This transform: 246 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 247 | 2. Concatenate conversations together; 248 | 3. Tokenize the concatenated conversation; 249 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 250 | """ 251 | # add end signal and concatenate together 252 | conversations = [] 253 | conv = copy.deepcopy(llama_v2_image_conversation.copy()) 254 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 255 | for source in sources: 256 | # [INST] <>\n{system_prompt}\n<>\n\n 257 | header = f"[INST] <>\n{conv.system}\n>\n\n" 258 | 259 | if roles[source[0]["from"]] != conv.roles[0]: 260 | # Skip the first one if it is not from human 261 | source = source[1:] 262 | conv.messages = [] 263 | for j, sentence in enumerate(source): 264 | role = roles[sentence["from"]] 265 | assert role == conv.roles[j % 2] 266 | conv.append_message(role, sentence["value"]) 267 | conversations.append(conv.get_prompt()) 268 | 269 | input_ids = tokenizer( 270 | conversations, 271 | return_tensors="pt", 272 | padding="longest", 273 | max_length=512, 274 | truncation=True, 275 | ).input_ids 276 | targets = copy.deepcopy(input_ids) 277 | 278 | 279 | sep = "[/INST] " 280 | for conversation, target in zip(conversations, targets): 281 | # total_len = int(target.ne(tokenizer.pad_token_id).sum()) 282 | rounds = conversation.split(conv.sep2) 283 | cur_len = 1 284 | target[:cur_len] = IGNORE_INDEX 285 | for i, rou in enumerate(rounds): 286 | if rou == "": 287 | break 288 | 289 | parts = rou.split(sep) 290 | if len(parts) != 2: 291 | break 292 | parts[0] += sep 293 | 294 | 295 | round_len = len(tokenizer(rou).input_ids) 296 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # 为什么减去2,speical token 的数目 297 | 298 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 299 | 300 | cur_len += round_len 301 | target[cur_len:] = IGNORE_INDEX 302 | 303 | return dict(input_ids=input_ids, labels=targets) 304 | 305 | def _mask_targets(target, tokenized_lens, speakers): 306 | # cur_idx = 0 307 | cur_idx = tokenized_lens[0] 308 | tokenized_lens = tokenized_lens[1:] 309 | target[:cur_idx] = IGNORE_INDEX 310 | for tokenized_len, speaker in zip(tokenized_lens, speakers): 311 | 
if speaker == "human": 312 | target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX 313 | cur_idx += tokenized_len 314 | --------------------------------------------------------------------------------
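The two usage sketches below are editorial additions, not files from this repository. They illustrate how two of the components shown above are meant to be driven: the SimpleTransformer stack from video_llama/models/ImageBind/models/transformer.py and the Registry from video_llama/common/registry.py. All sizes, keys, and class names introduced here (e.g. ToyConstantLR, "toy_constant_lr") are hypothetical and chosen only for illustration.

# Example 1: building a SimpleTransformer with masked attention (illustrative sizes).
from functools import partial

import torch

from video_llama.models.ImageBind.models.transformer import (
    MultiheadAttention,
    SimpleTransformer,
)

embed_dim, num_heads, seq_len, batch_size = 512, 8, 16, 4

# attn_target must be a zero-argument factory, not an nn.Module instance;
# BlockWithMasking asserts this so attention weights are not shared across blocks.
attn_target = partial(MultiheadAttention, embed_dim=embed_dim, num_heads=num_heads)

model = SimpleTransformer(
    attn_target=attn_target,
    embed_dim=embed_dim,
    num_blocks=2,
    drop_path_rate=0.1,             # "progressive" schedule: 0 -> 0.1 across blocks
    layer_scale_type="per_channel",
).eval()

# Default nn.MultiheadAttention is not batch_first, so tokens are (L, N, D).
tokens = torch.randn(seq_len, batch_size, embed_dim)
# Additive causal mask: 0 on and below the diagonal, -inf strictly above it.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
out = model(tokens, attn_mask=causal_mask)
print(out.shape)  # torch.Size([16, 4, 512])

A second sketch, under the same caveats, shows the registry pattern: a decorator maps a string key to a class so that YAML configs can reference components by name, and the dotted-path state store doubles as a lightweight config holder.

# Example 2: registering and resolving a component by name through the Registry.
from video_llama.common.registry import registry


@registry.register_lr_scheduler("toy_constant_lr")  # "toy_constant_lr" is a made-up key
class ToyConstantLR:
    """Hypothetical scheduler used only to illustrate the decorator; not in the repo."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, cur_epoch, cur_step):
        pass  # keep the learning rate unchanged


assert registry.get_lr_scheduler_class("toy_constant_lr") is ToyConstantLR
print(registry.list_lr_schedulers())

# The dotted-path state store resolves nested keys registered via registry.register().
registry.register("config", {"run": {"seed": 42}})
print(registry.get("config.run.seed"))  # 42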