├── lavis
│   ├── models
│   │   ├── blip2_models
│   │   │   ├── __init__.py
│   │   │   └── blip2_image_text_matching.py
│   │   ├── clip_models
│   │   │   ├── pics
│   │   │   │   └── CLIP.png
│   │   │   ├── __init__.py
│   │   │   ├── clip_outputs.py
│   │   │   ├── utils.py
│   │   │   ├── transform.py
│   │   │   └── loss.py
│   │   ├── blip_models
│   │   │   ├── blip.py
│   │   │   ├── __init__.py
│   │   │   └── blip_outputs.py
│   │   └── albef_models
│   │       └── albef_outputs.py
│   ├── common
│   │   ├── vqa_tools
│   │   │   └── __init__.py
│   │   ├── gradcam.py
│   │   ├── optims.py
│   │   └── dist_utils.py
│   ├── configs
│   │   ├── default.yaml
│   │   ├── datasets
│   │   │   ├── moviechat
│   │   │   │   └── defaults_qa.yaml
│   │   │   ├── moviecore
│   │   │   │   └── defaults_qa.yaml
│   │   │   ├── lvu
│   │   │   │   └── defaults_cls.yaml
│   │   │   ├── coin
│   │   │   │   └── defaults_cls.yaml
│   │   │   └── breakfast
│   │   │       └── defaults_cls.yaml
│   │   └── models
│   │       └── blip2
│   │           ├── blip2_pretrain.yaml
│   │           ├── blip2_pretrain_vitL.yaml
│   │           ├── blip2_coco.yaml
│   │           ├── blip2_pretrain_flant5xl.yaml
│   │           ├── blip2_pretrain_opt2.7b.yaml
│   │           ├── blip2_pretrain_opt6.7b.yaml
│   │           ├── blip2_pretrain_flant5xxl.yaml
│   │           ├── blip2_pretrain_llama7b.yaml
│   │           ├── blip2_instruct_flant5xxl.yaml
│   │           ├── blip2_pretrain_flant5xl_vitL.yaml
│   │           ├── blip2_instruct_vicuna13b.yaml
│   │           ├── blip2_instruct_vicuna7b.yaml
│   │           ├── blip2_instruct_flant5xl.yaml
│   │           ├── blip2_caption_opt2.7b.yaml
│   │           ├── blip2_caption_opt6.7b.yaml
│   │           └── blip2_caption_flant5xl.yaml
│   ├── runners
│   │   └── __init__.py
│   ├── tasks
│   │   ├── image_text_pretrain.py
│   │   ├── __init__.py
│   │   ├── multimodal_classification.py
│   │   ├── retrieval.py
│   │   ├── dialogue.py
│   │   ├── moviechat_gpt_eval.py
│   │   └── classification.py
│   ├── processors
│   │   ├── base_processor.py
│   │   ├── __init__.py
│   │   ├── clip_processors.py
│   │   └── functional_video.py
│   ├── datasets
│   │   ├── datasets
│   │   │   ├── multimodal_classification_datasets.py
│   │   │   ├── vg_vqa_datasets.py
│   │   │   ├── vqa_datasets.py
│   │   │   ├── image_text_pair_datasets.py
│   │   │   ├── imagefolder_dataset.py
│   │   │   ├── snli_ve_datasets.py
│   │   │   ├── video_vqa_datasets.py
│   │   │   ├── video_caption_datasets.py
│   │   │   ├── laion_dataset.py
│   │   │   ├── base_dataset.py
│   │   │   ├── coco_caption_datasets.py
│   │   │   ├── caption_datasets.py
│   │   │   ├── nlvr_datasets.py
│   │   │   ├── gqa_datasets.py
│   │   │   ├── coco_vqa_datasets.py
│   │   │   ├── dialogue_datasets.py
│   │   │   ├── msvd_caption_datasets.py
│   │   │   ├── msrvtt_caption_datasets.py
│   │   │   ├── msvd_vqa_datasets.py
│   │   │   ├── msrvtt_vqa_datasets.py
│   │   │   ├── activitynet_vqa_datasets.py
│   │   │   └── aok_vqa_datasets.py
│   │   └── builders
│   │       ├── dialogue_builder.py
│   │       ├── retrieval_builder.py
│   │       ├── vqa_builder.py
│   │       ├── image_text_pair_builder.py
│   │       └── __init__.py
│   ├── __init__.py
│   ├── experimental
│   │   └── bipartite_downsampling.py
│   └── projects
│       └── hermes
│           ├── cls_coin.yaml
│           ├── cls_breakfast.yaml
│           ├── cls_lvu.yaml
│           ├── qa_moviechat.yaml
│           └── qa_moviecore.yaml
├── .gitignore
├── figs
│   ├── hermes.png
│   ├── hermes_banner.png
│   ├── hermes_method.png
│   ├── plug_and_play.png
│   └── results_main.png
├── requirements.txt
├── run_scripts
│   ├── coin
│   │   ├── train.sh
│   │   └── test.sh
│   ├── breakfast
│   │   ├── train.sh
│   │   └── test.sh
│   ├── moviechat
│   │   ├── train.sh
│   │   └── test.sh
│   ├── lvu
│   │   ├── test.sh
│   │   └── train.sh
│   └── moviecore
│       ├── train.sh
│       └── test.sh
├── LICENCE
├── setup.py
├── LICENCE_lavis.txt
└── README.md

/lavis/models/blip2_models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | output
3 | __pycache__/
4 | frames
--------------------------------------------------------------------------------
/figs/hermes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joslefaure/HERMES/HEAD/figs/hermes.png -------------------------------------------------------------------------------- /figs/hermes_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joslefaure/HERMES/HEAD/figs/hermes_banner.png -------------------------------------------------------------------------------- /figs/hermes_method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joslefaure/HERMES/HEAD/figs/hermes_method.png -------------------------------------------------------------------------------- /figs/plug_and_play.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joslefaure/HERMES/HEAD/figs/plug_and_play.png -------------------------------------------------------------------------------- /figs/results_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joslefaure/HERMES/HEAD/figs/results_main.png -------------------------------------------------------------------------------- /lavis/models/clip_models/pics/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joslefaure/HERMES/HEAD/lavis/models/clip_models/pics/CLIP.png -------------------------------------------------------------------------------- /lavis/common/vqa_tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | __author__ = "aagrawal" 9 | -------------------------------------------------------------------------------- /lavis/configs/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | env: 7 | # For default users 8 | # cache_root: "cache" 9 | # For internal use with persistent storage 10 | cache_root: "data" 11 | -------------------------------------------------------------------------------- /lavis/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.runners.runner_base import RunnerBase 9 | from lavis.runners.runner_iter import RunnerIter 10 | 11 | __all__ = ["RunnerBase", "RunnerIter"] 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | contexttimer 2 | decord 3 | einops>=0.4.1 4 | fairscale==0.4.4 5 | ftfy 6 | iopath 7 | ipython 8 | omegaconf 9 | opencv-python-headless==4.5.5.64 10 | opendatasets 11 | packaging 12 | pandas 13 | plotly 14 | pre-commit 15 | pycocoevalcap 16 | pycocotools 17 | python-magic 18 | scikit-image 19 | sentencepiece 20 | spacy 21 | streamlit 22 | timm==0.4.12 23 | torch>=1.10.0 24 | torchvision 25 | tqdm 26 | transformers>=4.28.0 27 | webdataset 28 | wheel 29 | openai==0.28.0 30 | einops==0.8.0 31 | -------------------------------------------------------------------------------- /lavis/models/clip_models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Based on https://github.com/mlfoundations/open_clip 8 | """ 9 | 10 | """ OpenAI pretrained model functions 11 | Adapted from https://github.com/mlfoundations/open_clip and https://github.com/openai/CLIP. 12 | 13 | Originally MIT License, Copyright (c) 2021 OpenAI. 14 | """ 15 | -------------------------------------------------------------------------------- /lavis/tasks/image_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.common.registry import registry 9 | from lavis.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("image_text_pretrain") 13 | class ImageTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /lavis/processors/base_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/multimodal_classification_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from abc import abstractmethod 9 | 10 | from lavis.datasets.datasets.base_dataset import BaseDataset 11 | 12 | 13 | class MultimodalClassificationDataset(BaseDataset): 14 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 15 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 16 | 17 | self.class_labels = None 18 | 19 | @abstractmethod 20 | def _build_class_labels(self): 21 | pass 22 | -------------------------------------------------------------------------------- /lavis/configs/datasets/moviechat/defaults_qa.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | moviechat_qa: # name of the dataset builder 3 | # data_dir: ${env.data_dir}/datasets 4 | data_type: videos # [images|videos|features] 5 | 6 | build_info: 7 | # Be careful not to append minus sign (-) before split to avoid itemizing 8 | annotations: 9 | train: 10 | url: moviechat/annotation/train.json 11 | storage: moviechat/annotation/train.json 12 | val: 13 | url: moviechat/annotation/test.json 14 | storage: moviechat/annotation/test.json 15 | test: 16 | url: moviechat/annotation/test.json 17 | storage: moviechat/annotation/test.json 18 | videos: 19 | storage: moviechat/frames 20 | 21 | instance_id_key: question_id 22 | -------------------------------------------------------------------------------- /lavis/datasets/builders/dialogue_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.common.registry import registry 9 | from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder 10 | from lavis.datasets.datasets.avsd_dialogue_datasets import ( 11 | AVSDDialDataset, 12 | AVSDDialEvalDataset, 13 | ) 14 | 15 | 16 | @registry.register_builder("avsd_dialogue") 17 | class AVSDDialBuilder(BaseDatasetBuilder): 18 | train_dataset_cls = AVSDDialDataset 19 | eval_dataset_cls = AVSDDialEvalDataset 20 | 21 | DATASET_CONFIG_DICT = {"default": "configs/datasets/avsd/defaults_dial.yaml"} 22 | -------------------------------------------------------------------------------- /lavis/configs/datasets/moviecore/defaults_qa.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | moviecore_qa: # name of the dataset builder 3 | # data_dir: ${env.data_dir}/datasets 4 | data_type: videos # [images|videos|features] 5 | 6 | build_info: 7 | # Be careful not to append minus sign (-) before split to avoid itemizing 8 | annotations: 9 | train: 10 | url: moviecore/annotation/moviecore_train.json 11 | storage: moviecore/annotation/moviecore_train.json 12 | val: 13 | url: moviecore/annotation/moviecore_test.json 14 | storage: moviecore/annotation/moviecore_test.json 15 | test: 16 | url: moviecore/annotation/moviecore_test.json 17 | storage: moviecore/annotation/moviecore_test.json 18 | videos: 19 | storage: moviecore/frames 20 | 21 | instance_id_key: question_id 22 | -------------------------------------------------------------------------------- /lavis/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /run_scripts/coin/train.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 \ 2 | --master_port=34651 \ 3 | train.py \ 4 | --cfg-path lavis/projects/hermes/cls_coin.yaml \ 5 | --options \ 6 | model.arch blip2_vicuna_instruct \ 7 | model.model_type vicuna7b \ 8 | model.load_finetuned False \ 9 | model.load_pretrained True \ 10 | model.num_query_token 32 \ 11 | model.vit_precision fp16 \ 12 | model.freeze_vit True \ 13 | model.memory_bank_length 20 \ 14 | model.num_frames 100 \ 15 | model.num_frames_global 20 \ 16 | model.window_size 10 \ 17 | run.init_lr 1e-4 \ 18 | run.max_epoch 20 \ 19 | run.num_beams 5 \ 20 | run.batch_size_train 16 \ 21 | run.batch_size_eval 16 \ 22 | run.accum_grad_iters 1 \ 23 | run.num_workers 12 \ 24 | run.seed 42 \ 25 | run.evaluate False \ 26 
| run.report_metric True \ 27 | run.prefix train 28 | # run.resume_ckpt_path 29 | 30 | -------------------------------------------------------------------------------- /run_scripts/breakfast/train.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 \ 2 | --master_port=34650 \ 3 | train.py \ 4 | --cfg-path lavis/projects/hermes/cls_breakfast.yaml \ 5 | --options \ 6 | model.arch blip2_vicuna_instruct \ 7 | model.model_type vicuna7b \ 8 | model.load_finetuned False \ 9 | model.load_pretrained True \ 10 | model.num_query_token 32 \ 11 | model.vit_precision fp16 \ 12 | model.freeze_vit True \ 13 | model.memory_bank_length 20 \ 14 | model.num_frames 100 \ 15 | model.window_size 10 \ 16 | model.num_frames_global 20 \ 17 | run.init_lr 1e-4 \ 18 | run.max_epoch 20 \ 19 | run.num_beams 5 \ 20 | run.batch_size_train 4 \ 21 | run.batch_size_eval 4 \ 22 | run.accum_grad_iters 1 \ 23 | run.num_workers 12 \ 24 | run.seed 42 \ 25 | run.evaluate False \ 26 | run.report_metric True \ 27 | run.prefix train_refactor 28 | # run.resume_ckpt_path 29 | -------------------------------------------------------------------------------- /lavis/configs/datasets/lvu/defaults_cls.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | lvu_cls: # name of the dataset builder 8 | # data_dir: ${env.data_dir}/datasets 9 | data_type: videos # [images|videos|features] 10 | 11 | build_info: 12 | # Be careful not to append minus sign (-) before split to avoid itemizing 13 | annotations: 14 | train: 15 | url: lvu/annotation/train.json 16 | storage: lvu/annotation/train.json 17 | val: 18 | url: lvu/annotation/test.json 19 | storage: lvu/annotation/test.json 20 | test: 21 | url: lvu/annotation/test.json 22 | storage: lvu/annotation/test.json 23 | videos: 24 | storage: lvu/frames 25 | -------------------------------------------------------------------------------- /lavis/configs/datasets/coin/defaults_cls.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | coin_cls: # name of the dataset builder 8 | # data_dir: ${env.data_dir}/datasets 9 | data_type: videos # [images|videos|features] 10 | 11 | build_info: 12 | # Be careful not to append minus sign (-) before split to avoid itemizing 13 | annotations: 14 | train: 15 | url: coin/annotation/COIN.json 16 | storage: coin/annotation/COIN.json 17 | val: 18 | url: coin/annotation/COIN.json 19 | storage: coin/annotation/COIN.json 20 | test: 21 | url: coin/annotation/COIN.json 22 | storage: coin/annotation/COIN.json 23 | videos: 24 | storage: coin/frames 25 | -------------------------------------------------------------------------------- /run_scripts/coin/test.sh: -------------------------------------------------------------------------------- 1 | 2 | checkpoint_path=$1 3 | torchrun --nproc_per_node=4 \ 4 | --master_port=34651 \ 5 | train.py \ 6 | --cfg-path lavis/projects/hermes/cls_coin.yaml \ 7 | --options \ 8 | model.arch blip2_vicuna_instruct \ 9 | model.model_type vicuna7b \ 10 | model.load_finetuned False \ 11 | model.load_pretrained True \ 12 | model.num_query_token 32 \ 13 | model.vit_precision fp16 \ 14 | model.freeze_vit True \ 15 | model.memory_bank_length 20 \ 16 | model.num_frames 100 \ 17 | model.num_frames_global 20 \ 18 | model.window_size 10 \ 19 | run.init_lr 1e-4 \ 20 | run.max_epoch 20 \ 21 | run.num_beams 5 \ 22 | run.batch_size_train 16 \ 23 | run.batch_size_eval 16 \ 24 | run.accum_grad_iters 1 \ 25 | run.num_workers 12 \ 26 | run.seed 42 \ 27 | run.evaluate True \ 28 | run.report_metric True \ 29 | run.prefix test \ 30 | run.resume_ckpt_path ${checkpoint_path} 31 | 32 | -------------------------------------------------------------------------------- /run_scripts/breakfast/test.sh: -------------------------------------------------------------------------------- 1 | checkpoint_path=$1 2 | torchrun --nproc_per_node=4 \ 3 | --master_port=34650 \ 4 | train.py \ 5 | --cfg-path lavis/projects/hermes/cls_breakfast.yaml \ 6 | --options \ 7 | model.arch blip2_vicuna_instruct \ 8 | model.model_type vicuna7b \ 9 | model.load_finetuned False \ 10 | model.load_pretrained True \ 11 | model.num_query_token 32 \ 12 | model.vit_precision fp16 \ 13 | model.freeze_vit True \ 14 | model.memory_bank_length 20 \ 15 | model.num_frames 100 \ 16 | model.window_size 10 \ 17 | model.num_frames_global 20 \ 18 | run.init_lr 1e-4 \ 19 | run.max_epoch 20 \ 20 | run.num_beams 5 \ 21 | run.batch_size_train 8 \ 22 | run.batch_size_eval 8 \ 23 | run.accum_grad_iters 1 \ 24 | run.num_workers 12 \ 25 | run.seed 42 \ 26 | run.evaluate True \ 27 | run.report_metric True \ 28 | run.prefix test \ 29 | run.resume_ckpt_path ${checkpoint_path} 30 | 31 | -------------------------------------------------------------------------------- /run_scripts/moviechat/train.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 \ 2 | --master_port=34651 \ 3 | train.py \ 4 | --cfg-path lavis/projects/hermes/qa_moviechat.yaml \ 5 | --options \ 6 | model.arch blip2_vicuna_instruct \ 7 | model.model_type vicuna7b \ 8 | model.load_finetuned False \ 9 | model.load_pretrained True \ 10 | model.num_query_token 32 \ 11 | model.vit_precision fp16 \ 12 | model.freeze_vit True \ 13 | model.memory_bank_length 20 \ 14 | model.num_frames 100 \ 15 | model.num_frames_global 20 \ 16 | model.window_size 10 \ 17 | 
model.trail_percentage 0.02 \ 18 | run.init_lr 1e-4 \ 19 | run.max_epoch 5 \ 20 | run.num_beams 5 \ 21 | run.batch_size_train 4 \ 22 | run.batch_size_eval 4 \ 23 | run.accum_grad_iters 1 \ 24 | run.num_workers 4 \ 25 | run.seed 42 \ 26 | run.evaluate False \ 27 | run.report_metric True \ 28 | run.prefix train 29 | # run.resume_ckpt_path null 30 | 31 | -------------------------------------------------------------------------------- /lavis/configs/datasets/breakfast/defaults_cls.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | breakfast_cls: # name of the dataset builder 8 | # data_dir: ${env.data_dir}/datasets 9 | data_type: videos # [images|videos|features] 10 | 11 | build_info: 12 | # Be careful not to append minus sign (-) before split to avoid itemizing 13 | annotations: 14 | train: 15 | url: breakfast/annotation/train.json 16 | storage: breakfast/annotation/train.json 17 | val: 18 | url: breakfast/annotation/val.json 19 | storage: breakfast/annotation/val.json 20 | test: 21 | url: breakfast/annotation/val.json 22 | storage: breakfast/annotation/val.json 23 | videos: 24 | storage: breakfast/frames 25 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: pretrain 8 | load_finetuned: False 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" 11 | finetuned: "" 12 | 13 | # vit encoder 14 | image_size: 224 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp16" 18 | freeze_vit: True 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | 24 | preprocess: 25 | vis_processor: 26 | train: 27 | name: "blip_image_train" 28 | image_size: 224 29 | eval: 30 | name: "blip_image_eval" 31 | image_size: 224 32 | text_processor: 33 | train: 34 | name: "blip_caption" 35 | eval: 36 | name: "blip_caption" 37 | -------------------------------------------------------------------------------- /lavis/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | 11 | from omegaconf import OmegaConf 12 | 13 | from lavis.common.registry import registry 14 | 15 | from lavis.datasets.builders import * 16 | from lavis.models import * 17 | from lavis.processors import * 18 | from lavis.tasks import * 19 | 20 | 21 | root_dir = os.path.dirname(os.path.abspath(__file__)) 22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) 23 | 24 | registry.register_path("library_root", root_dir) 25 | repo_root = os.path.join(root_dir, "..") 26 | registry.register_path("repo_root", repo_root) 27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root) 28 | registry.register_path("cache_root", cache_root) 29 | 30 | registry.register("MAX_INT", sys.maxsize) 31 | registry.register("SPLIT_NAMES", ["train", "val", "test"]) 32 | -------------------------------------------------------------------------------- /run_scripts/moviechat/test.sh: -------------------------------------------------------------------------------- 1 | checkpoint_path=$1 2 | torchrun --nproc_per_node=4 \ 3 | --master_port=34652 \ 4 | train.py \ 5 | --cfg-path lavis/projects/hermes/qa_moviechat.yaml \ 6 | --options \ 7 | model.arch blip2_vicuna_instruct \ 8 | model.model_type vicuna7b \ 9 | model.load_finetuned False \ 10 | model.load_pretrained True \ 11 | model.num_query_token 32 \ 12 | model.vit_precision fp16 \ 13 | model.freeze_vit True \ 14 | model.memory_bank_length 20 \ 15 | model.num_frames 100 \ 16 | model.window_size 10 \ 17 | model.num_frames_global 20 \ 18 | model.trail_percentage 0.02 \ 19 | model.is_zero_shot True \ 20 | run.init_lr 1e-4 \ 21 | run.max_epoch 5 \ 22 | run.num_beams 5 \ 23 | run.batch_size_train 8 \ 24 | run.batch_size_eval 8 \ 25 | run.accum_grad_iters 1 \ 26 | run.num_workers 12 \ 27 | run.seed 42 \ 28 | run.evaluate True \ 29 | run.valid_splits "['test']" \ 30 | run.report_metric True \ 31 | run.prefix test \ 32 | run.resume_ckpt_path ${checkpoint_path} 33 | 34 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain_vitL.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | model:
7 |   arch: pretrain
8 |   load_finetuned: False
9 |
10 |   pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_vitL.pth"
11 |   finetuned: ""
12 |
13 |   # vit encoder
14 |   vit_model: "clip_L"
15 |   image_size: 224
16 |   drop_path_rate: 0
17 |   use_grad_checkpoint: False
18 |   vit_precision: "fp16"
19 |   freeze_vit: True
20 |
21 |   # Q-Former
22 |   num_query_token: 32
23 |
24 |
25 | preprocess:
26 |   vis_processor:
27 |     train:
28 |       name: "blip_image_train"
29 |       image_size: 224
30 |     eval:
31 |       name: "blip_image_eval"
32 |       image_size: 224
33 |   text_processor:
34 |     train:
35 |       name: "blip_caption"
36 |     eval:
37 |       name: "blip_caption"
38 |
--------------------------------------------------------------------------------
/run_scripts/lvu/test.sh:
--------------------------------------------------------------------------------
1 | checkpoint_path=$1
2 | task=$2  # LVU task, e.g. relationship, director, genre, way_speaking, writer, year or scene (see train.sh)
3 | torchrun --nproc_per_node=4 \
4 |     --master_port=34653 \
5 |     train.py \
6 |     --cfg-path lavis/projects/hermes/cls_lvu.yaml \
7 |     --options \
8 |     model.arch blip2_vicuna_instruct \
9 |     model.model_type vicuna7b \
10 |     model.load_finetuned False \
11 |     model.load_pretrained True \
12 |     model.num_query_token 32 \
13 |     model.vit_precision fp16 \
14 |     model.freeze_vit True \
15 |     model.memory_bank_length 20 \
16 |     model.num_frames 100 \
17 |     model.window_size 10 \
18 |     model.num_frames_global 20 \
19 |     datasets.lvu_cls.history 300 \
20 |     datasets.lvu_cls.task $task \
21 |     datasets.lvu_cls.stride 20 \
22 |     run.init_lr 1e-4 \
23 |     run.max_epoch 20 \
24 |     run.num_beams 5 \
25 |     run.batch_size_train 8 \
26 |     run.batch_size_eval 8 \
27 |     run.accum_grad_iters 1 \
28 |     run.num_workers 12 \
29 |     run.seed 42 \
30 |     run.evaluate True \
31 |     run.valid_splits "['test']" \
32 |     run.report_metric True \
33 |     run.prefix test \
34 |     run.resume_ckpt_path ${checkpoint_path}
35 |
36 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Gueter Josmy Faure
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
-------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_coco.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: coco 8 | load_finetuned: True 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" 11 | finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_finetune_coco.pth" 12 | 13 | # vit encoder 14 | image_size: 364 15 | drop_path_rate: 0 16 | use_grad_checkpoint: True 17 | vit_precision: "fp32" 18 | freeze_vit: False 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | 24 | preprocess: 25 | vis_processor: 26 | train: 27 | name: "blip_image_train" 28 | image_size: 364 29 | eval: 30 | name: "blip_image_eval" 31 | image_size: 364 32 | text_processor: 33 | train: 34 | name: "blip_caption" 35 | eval: 36 | name: "blip_caption" 37 | -------------------------------------------------------------------------------- /run_scripts/moviecore/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # MovieCORE Training Script for HERMES 4 | # Usage: bash run_scripts/moviecore/train.sh 5 | export CUDA_VISIBLE_DEVICES=5 6 | 7 | torchrun --nproc_per_node=1 \ 8 | --master_port=34651 \ 9 | train.py \ 10 | --cfg-path lavis/projects/hermes/qa_moviecore.yaml \ 11 | --options \ 12 | model.arch blip2_vicuna_instruct \ 13 | model.model_type vicuna7b \ 14 | model.load_finetuned False \ 15 | model.load_pretrained True \ 16 | model.num_query_token 32 \ 17 | model.vit_precision fp16 \ 18 | model.freeze_vit True \ 19 | model.memory_bank_length 20 \ 20 | model.num_frames 100 \ 21 | model.num_frames_global 20 \ 22 | model.window_size 10 \ 23 | model.trail_percentage 0.02 \ 24 | model.max_txt_len 512 \ 25 | model.max_output_txt_len 512 \ 26 | run.init_lr 1e-4 \ 27 | run.max_epoch 1 \ 28 | run.num_beams 1 \ 29 | run.batch_size_train 2 \ 30 | run.batch_size_eval 4 \ 31 | run.accum_grad_iters 1 \ 32 | run.num_workers 12 \ 33 | run.seed 42 \ 34 | run.evaluate False \ 35 | run.report_metric True \ 36 | run.prefix moviecore_train 37 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: pretrain_flant5xl 8 | load_finetuned: False 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl.pth" 11 | finetuned: "" 12 | 13 | # vit encoder 14 | image_size: 224 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp16" 18 | freeze_vit: True 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # T5 24 | t5_model: "google/flan-t5-xl" 25 | 26 | # generation configs 27 | prompt: "" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip_image_train" 34 | image_size: 224 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: pretrain_opt2.7b 8 | load_finetuned: False 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt2.7b.pth" 11 | finetuned: "" 12 | 13 | # vit encoder 14 | image_size: 224 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp16" 18 | freeze_vit: True 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # OPT 24 | opt_model: "facebook/opt-2.7b" 25 | 26 | # generation configs 27 | prompt: "" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip_image_train" 34 | image_size: 224 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: pretrain_opt6.7b 8 | load_finetuned: False 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt6.7b.pth" 11 | finetuned: "" 12 | 13 | # vit encoder 14 | image_size: 224 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp16" 18 | freeze_vit: True 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # OPT 24 | opt_model: "facebook/opt-6.7b" 25 | 26 | # generation configs 27 | prompt: "" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip_image_train" 34 | image_size: 224 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: pretrain_flant5xxl 8 | load_finetuned: False 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth" 11 | finetuned: "" 12 | 13 | # vit encoder 14 | image_size: 224 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp16" 18 | freeze_vit: True 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # T5 24 | t5_model: "google/flan-t5-xxl" 25 | 26 | # generation configs 27 | prompt: "" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip_image_train" 34 | image_size: 224 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain_llama7b.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: blip2_llama 8 | load_finetuned: False 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" 11 | finetuned: "" 12 | 13 | # vit encoder 14 | image_size: 224 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp16" 18 | freeze_vit: True 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # LLM 24 | llm_model: "/export/home/project/stanford_alpaca/llama_7B" 25 | 26 | # generation configs 27 | prompt: "" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip2_image_train" 34 | image_size: 224 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: flant5xxl 8 | load_finetuned: False 9 | load_pretrained: True 10 | 11 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_flanxxl_trimmed.pth" 12 | finetuned: "" 13 | 14 | # vit encoder 15 | image_size: 224 16 | drop_path_rate: 0 17 | use_grad_checkpoint: False 18 | vit_precision: "fp16" 19 | freeze_vit: True 20 | 21 | # Q-Former 22 | num_query_token: 32 23 | 24 | # T5 25 | t5_model: "google/flan-t5-xxl" 26 | 27 | # generation configs 28 | prompt: "" 29 | 30 | 31 | preprocess: 32 | vis_processor: 33 | train: 34 | name: "blip_image_train" 35 | image_size: 224 36 | eval: 37 | name: "blip_image_eval" 38 | image_size: 224 39 | text_processor: 40 | train: 41 | name: "blip_caption" 42 | eval: 43 | name: "blip_caption" 44 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: pretrain_flant5xl 8 | load_finetuned: False 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl_vitL.pth" 11 | finetuned: "" 12 | 13 | # vit encoder 14 | vit_model: "clip_L" 15 | image_size: 224 16 | drop_path_rate: 0 17 | use_grad_checkpoint: False 18 | vit_precision: "fp16" 19 | freeze_vit: True 20 | 21 | # Q-Former 22 | num_query_token: 32 23 | 24 | # T5 25 | t5_model: "google/flan-t5-xl" 26 | 27 | # generation configs 28 | prompt: "" 29 | 30 | 31 | preprocess: 32 | vis_processor: 33 | train: 34 | name: "blip_image_train" 35 | image_size: 224 36 | eval: 37 | name: "blip_image_eval" 38 | image_size: 224 39 | text_processor: 40 | train: 41 | name: "blip_caption" 42 | eval: 43 | name: "blip_caption" 44 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/vg_vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | 10 | from PIL import Image 11 | 12 | from lavis.datasets.datasets.vqa_datasets import VQADataset 13 | 14 | 15 | class VGVQADataset(VQADataset): 16 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 17 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 18 | 19 | def __getitem__(self, index): 20 | ann = self.annotation[index] 21 | 22 | image_path = os.path.join(self.vis_root, ann["image"]) 23 | image = Image.open(image_path).convert("RGB") 24 | 25 | image = self.vis_processor(image) 26 | question = self.text_processor(ann["question"]) 27 | 28 | answers = [ann["answer"]] 29 | # TODO this should be configured better 30 | weights = [0.2] 31 | 32 | return { 33 | "image": image, 34 | "text_input": question, 35 | "answers": answers, 36 | "weights": weights, 37 | } 38 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | model:
7 |   arch: instruct_vicuna13b
8 |   load_finetuned: False
9 |   load_pretrained: True
10 |
11 |   pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna13b_trimmed.pth"
12 |   finetuned: ""
13 |
14 |   # vit encoder
15 |   image_size: 224
16 |   drop_path_rate: 0
17 |   use_grad_checkpoint: False
18 |   vit_precision: "fp16"
19 |   freeze_vit: True
20 |
21 |   # Q-Former
22 |   num_query_token: 32
23 |
24 |   # path to Vicuna checkpoint (13B variant to match this config)
25 |   llm_model: "lmsys/vicuna-13b-v1.1"
26 |
27 |   # generation configs
28 |   prompt: ""
29 |
30 |
31 | preprocess:
32 |   vis_processor:
33 |     train:
34 |       name: "blip2_image_train"
35 |       image_size: 224
36 |     eval:
37 |       name: "blip_image_eval"
38 |       image_size: 224
39 |   text_processor:
40 |     train:
41 |       name: "blip_caption"
42 |     eval:
43 |       name: "blip_caption"
44 |
--------------------------------------------------------------------------------
/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | model:
7 |   arch: instruct_vicuna7b
8 |   load_finetuned: False
9 |   load_pretrained: True
10 |
11 |   pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
12 |   finetuned: ""
13 |
14 |   # vit encoder
15 |   image_size: 224
16 |   drop_path_rate: 0
17 |   use_grad_checkpoint: False
18 |   vit_precision: "fp16"
19 |   freeze_vit: True
20 |
21 |   # Q-Former
22 |   num_query_token: 32
23 |
24 |   # path to Vicuna checkpoint
25 |   llm_model: "lmsys/vicuna-7b-v1.1"
26 |
27 |   # generation configs
28 |   prompt: ""
29 |
30 |
31 | preprocess:
32 |   vis_processor:
33 |     train:
34 |       name: "blip2_image_train"
35 |       image_size: 224
36 |     eval:
37 |       name: "blip_image_eval"
38 |       image_size: 224
39 |   text_processor:
40 |     train:
41 |       name: "blip_caption"
42 |     eval:
43 |       name: "blip_caption"
44 |
--------------------------------------------------------------------------------
/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: flant5xl 8 | load_finetuned: False 9 | load_pretrained: True 10 | 11 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_flanxl_trimmed.pth" 12 | finetuned: "" 13 | 14 | # vit encoder 15 | image_size: 224 16 | drop_path_rate: 0 17 | use_grad_checkpoint: False 18 | vit_precision: "fp16" 19 | freeze_vit: True 20 | 21 | # Q-Former 22 | num_query_token: 32 23 | 24 | # T5 25 | t5_model: "google/flan-t5-xl" 26 | # t5_model: "/fsx/boheumd/LAVIS/llm/flant5xl" 27 | 28 | # generation configs 29 | prompt: "" 30 | 31 | 32 | preprocess: 33 | vis_processor: 34 | train: 35 | name: "blip_image_train" 36 | image_size: 224 37 | eval: 38 | name: "blip_image_eval" 39 | image_size: 224 40 | text_processor: 41 | train: 42 | name: "blip_caption" 43 | eval: 44 | name: "blip_caption" 45 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: caption_coco_opt2.7b 8 | load_finetuned: True 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt2.7b.pth" 11 | finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_opt2.7b.pth" 12 | 13 | # vit encoder 14 | image_size: 364 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp32" 18 | freeze_vit: False 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # OPT 24 | opt_model: "facebook/opt-2.7b" 25 | 26 | # generation configs 27 | prompt: "a photo of" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip_image_train" 34 | image_size: 364 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 364 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: caption_coco_opt6.7b 8 | load_finetuned: True 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt6.7b.pth" 11 | finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_opt6.7b.pth" 12 | 13 | # vit encoder 14 | image_size: 364 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp32" 18 | freeze_vit: False 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # OPT 24 | opt_model: "facebook/opt-6.7b" 25 | 26 | # generation configs 27 | prompt: "a photo of" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip_image_train" 34 | image_size: 364 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 364 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /run_scripts/lvu/train.sh: -------------------------------------------------------------------------------- 1 | task_list=('relationship' 'director' 'genre' 'way_speaking' 'writer' 'year' 'scene') 2 | for task in "${task_list[@]}"; do 3 | torchrun --nproc_per_node=8 \ 4 | --master_port=34653 \ 5 | train.py \ 6 | --cfg-path lavis/projects/hermes/cls_lvu.yaml \ 7 | --options \ 8 | model.arch blip2_vicuna_instruct \ 9 | model.model_type vicuna7b \ 10 | model.load_finetuned False \ 11 | model.load_pretrained True \ 12 | model.num_query_token 32 \ 13 | model.vit_precision fp16 \ 14 | model.freeze_vit True \ 15 | model.memory_bank_length 20 \ 16 | model.num_frames 100 \ 17 | model.window_size 10 \ 18 | model.num_frames_global 20 \ 19 | datasets.lvu_cls.history 300 \ 20 | datasets.lvu_cls.task $task \ 21 | datasets.lvu_cls.stride 20 \ 22 | run.init_lr 1e-4 \ 23 | run.max_epoch 20 \ 24 | run.num_beams 5 \ 25 | run.batch_size_train 4 \ 26 | run.batch_size_eval 4 \ 27 | run.accum_grad_iters 1 \ 28 | run.num_workers 12 \ 29 | run.seed 42 \ 30 | run.evaluate False \ 31 | run.report_metric True \ 32 | run.prefix train 33 | # run.resume_ckpt_path null 34 | done 35 | -------------------------------------------------------------------------------- /run_scripts/moviecore/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # MovieCORE Test Script for HERMES 4 | # Usage: bash run_scripts/moviecore/test.sh [checkpoint_path] 5 | 6 | checkpoint_path=$1 7 | 8 | torchrun --nproc_per_node=1 \ 9 | --master_port=34659 \ 10 | train.py \ 11 | --cfg-path lavis/projects/hermes/qa_moviecore.yaml \ 12 | --options \ 13 | model.arch blip2_vicuna_instruct \ 14 | model.model_type vicuna7b \ 15 | model.load_finetuned False \ 16 | model.load_pretrained True \ 17 | model.num_query_token 32 \ 18 | model.vit_precision fp16 \ 19 | model.freeze_vit True \ 20 | model.memory_bank_length 20 \ 21 | model.num_frames 100 \ 22 | model.window_size 10 \ 23 | model.num_frames_global 20 \ 24 | model.trail_percentage 0.02 \ 25 | model.max_txt_len 512 \ 26 | model.max_output_txt_len 512 \ 27 | model.is_zero_shot False \ 28 | run.init_lr 1e-4 \ 29 | run.max_epoch 5 \ 30 | run.num_beams 1 \ 31 | run.batch_size_train 3 \ 32 | run.batch_size_eval 1 \ 33 | run.accum_grad_iters 1 \ 34 | run.num_workers 12 \ 35 | run.seed 42 \ 36 | run.evaluate True \ 37 | 
run.valid_splits "['test']" \ 38 | run.report_metric True \ 39 | run.prefix moviecore \ 40 | run.resume_ckpt_path ${checkpoint_path} 41 | -------------------------------------------------------------------------------- /lavis/configs/models/blip2/blip2_caption_flant5xl.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: caption_coco_flant5xl 8 | load_finetuned: True 9 | 10 | pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl.pth" 11 | finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_flant5xl.pth" 12 | 13 | # vit encoder 14 | image_size: 364 15 | drop_path_rate: 0 16 | use_grad_checkpoint: False 17 | vit_precision: "fp32" 18 | freeze_vit: False 19 | 20 | # Q-Former 21 | num_query_token: 32 22 | 23 | # T5 24 | t5_model: "google/flan-t5-xl" 25 | 26 | # generation configs 27 | prompt: "a photo of" 28 | 29 | 30 | preprocess: 31 | vis_processor: 32 | train: 33 | name: "blip_image_train" 34 | image_size: 364 35 | eval: 36 | name: "blip_image_eval" 37 | image_size: 364 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | eval: 42 | name: "blip_caption" 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from setuptools import setup, find_namespace_packages 9 | import platform 10 | 11 | DEPENDENCY_LINKS = [] 12 | if platform.system() == "Windows": 13 | DEPENDENCY_LINKS.append("https://download.pytorch.org/whl/torch_stable.html") 14 | 15 | 16 | def fetch_requirements(filename): 17 | with open(filename) as f: 18 | return [ln.strip() for ln in f.read().split("\n")] 19 | 20 | 21 | setup( 22 | name="salesforce-lavis", 23 | version="1.0.1", 24 | author="Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven C.H. Hoi", 25 | description="LAVIS - A One-stop Library for Language-Vision Intelligence", 26 | long_description=open("README.md", "r", encoding="utf-8").read(), 27 | long_description_content_type="text/markdown", 28 | keywords="Vision-Language, Multimodal, Image Captioning, Generative AI, Deep Learning, Library, PyTorch", 29 | license="3-Clause BSD", 30 | packages=find_namespace_packages(include="lavis.*"), 31 | install_requires=fetch_requirements("requirements.txt"), 32 | python_requires=">=3.7.0", 33 | include_package_data=True, 34 | dependency_links=DEPENDENCY_LINKS, 35 | zip_safe=False, 36 | ) 37 | -------------------------------------------------------------------------------- /LICENCE_lavis.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022 Salesforce, Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | 8 | 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /lavis/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.common.registry import registry 9 | from lavis.tasks.base_task import BaseTask 10 | from lavis.tasks.captioning import CaptionTask 11 | from lavis.tasks.classification import ClassificationTask 12 | from lavis.tasks.image_text_pretrain import ImageTextPretrainTask 13 | from lavis.tasks.multimodal_classification import ( 14 | MultimodalClassificationTask, 15 | ) 16 | from lavis.tasks.retrieval import RetrievalTask 17 | from lavis.tasks.vqa import VQATask, GQATask, AOKVQATask 18 | from lavis.tasks.vqa_reading_comprehension import VQARCTask, GQARCTask 19 | from lavis.tasks.dialogue import DialogueTask 20 | 21 | 22 | def setup_task(cfg): 23 | assert "task" in cfg.run_cfg, "Task name must be provided." 24 | 25 | task_name = cfg.run_cfg.task 26 | task = registry.get_task_class(task_name).setup_task(cfg=cfg) 27 | assert task is not None, "Task {} not properly registered.".format(task_name) 28 | 29 | return task 30 | 31 | 32 | __all__ = [ 33 | "BaseTask", 34 | "AOKVQATask", 35 | "RetrievalTask", 36 | "CaptionTask", 37 | "ClassificationTask", 38 | "VQATask", 39 | "GQATask", 40 | "VQARCTask", 41 | "GQARCTask", 42 | "MultimodalClassificationTask", 43 | "ImageTextPretrainTask", 44 | "DialogueTask", 45 | ] 46 | -------------------------------------------------------------------------------- /lavis/models/clip_models/clip_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Based on https://github.com/mlfoundations/open_clip 8 | """ 9 | 10 | from dataclasses import dataclass 11 | 12 | from typing import Optional 13 | 14 | import torch 15 | from transformers.modeling_outputs import ModelOutput 16 | 17 | 18 | @dataclass 19 | class ClipOutputFeatures(ModelOutput): 20 | """ 21 | Data class of features from AlbefFeatureExtractor. 22 | 23 | Args: 24 | image_embeds: `torch.FloatTensor` of shape `(batch_size, 1, embed_dim)`, `optional` 25 | image_features: `torch.FloatTensor` of shape `(batch_size, 1, feature_dim)`, `optional` 26 | text_embeds: `torch.FloatTensor` of shape `(batch_size, 1, embed_dim)`, `optional` 27 | text_features: `torch.FloatTensor` of shape `(batch_size, 1, feature_dim)`, `optional` 28 | """ 29 | 30 | image_embeds: Optional[torch.FloatTensor] = None 31 | image_embeds_proj: Optional[torch.FloatTensor] = None 32 | 33 | text_embeds: Optional[torch.FloatTensor] = None 34 | text_embeds_proj: Optional[torch.FloatTensor] = None 35 | 36 | 37 | @dataclass 38 | class ClipOutput(ModelOutput): 39 | intermediate_output: Optional[ClipOutputFeatures] = None 40 | 41 | logit_scale_exp: Optional[torch.FloatTensor] = None 42 | 43 | loss: Optional[torch.FloatTensor] = None 44 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import torch 9 | 10 | from lavis.datasets.datasets.base_dataset import BaseDataset 11 | 12 | 13 | class VQADataset(BaseDataset): 14 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 15 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 16 | 17 | def collater(self, samples): 18 | image_list, question_list, answer_list, weight_list = [], [], [], [] 19 | 20 | num_answers = [] 21 | 22 | for sample in samples: 23 | image_list.append(sample["image"]) 24 | question_list.append(sample["text_input"]) 25 | 26 | weight_list.extend(sample["weights"]) 27 | 28 | answers = sample["answers"] 29 | 30 | answer_list.extend(answers) 31 | num_answers.append(len(answers)) 32 | 33 | return { 34 | "image": torch.stack(image_list, dim=0), 35 | "text_input": question_list, 36 | "answer": answer_list, 37 | "weight": torch.Tensor(weight_list), 38 | "n_answers": torch.LongTensor(num_answers), 39 | } 40 | 41 | 42 | class VQAEvalDataset(BaseDataset): 43 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 44 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 45 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/image_text_pair_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from collections import OrderedDict 10 | 11 | from PIL import Image 12 | 13 | from lavis.datasets.datasets.base_dataset import BaseDataset 14 | 15 | 16 | class __DisplMixin: 17 | def displ_item(self, index): 18 | sample, ann = self.__getitem__(index), self.annotation[index] 19 | 20 | return OrderedDict( 21 | { 22 | "file": os.path.basename(ann["image"]), 23 | "caption": ann["caption"], 24 | "image": sample["image"], 25 | } 26 | ) 27 | 28 | 29 | class ImageTextPairDataset(BaseDataset, __DisplMixin): 30 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 31 | """ 32 | vis_root (string): Root directory of images (e.g. coco/images/) 33 | ann_root (string): directory to store the annotation file 34 | """ 35 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 36 | 37 | def __getitem__(self, index): 38 | 39 | # TODO this assumes image input, not general enough 40 | ann = self.annotation[index] 41 | 42 | image_path = os.path.join(self.vis_root, ann["image"]) 43 | image = Image.open(image_path).convert("RGB") 44 | 45 | image = self.vis_processor(image) 46 | caption = self.text_processor(ann["caption"]) 47 | 48 | return {"image": image, "text_input": caption} 49 | -------------------------------------------------------------------------------- /lavis/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.processors.base_processor import BaseProcessor 9 | 10 | from lavis.processors.alpro_processors import ( 11 | AlproVideoTrainProcessor, 12 | AlproVideoEvalProcessor, 13 | ) 14 | from lavis.processors.blip_processors import ( 15 | BlipImageTrainProcessor, 16 | Blip2ImageTrainProcessor, 17 | BlipImageEvalProcessor, 18 | BlipCaptionProcessor, 19 | Blip2VideoTrainProcessor, 20 | Blip2VideoEvalProcessor, 21 | ) 22 | from lavis.processors.gpt_processors import ( 23 | GPTVideoFeatureProcessor, 24 | GPTDialogueProcessor, 25 | ) 26 | from lavis.processors.clip_processors import ClipImageTrainProcessor 27 | 28 | from lavis.common.registry import registry 29 | 30 | __all__ = [ 31 | "BaseProcessor", 32 | # ALPRO 33 | "AlproVideoTrainProcessor", 34 | "AlproVideoEvalProcessor", 35 | # BLIP 36 | "BlipImageTrainProcessor", 37 | "Blip2ImageTrainProcessor", 38 | "BlipImageEvalProcessor", 39 | "BlipCaptionProcessor", 40 | "ClipImageTrainProcessor", 41 | "Blip2VideoTrainProcessor", 42 | "Blip2VideoEvalProcessor", 43 | # GPT 44 | "GPTVideoFeatureProcessor", 45 | "GPTDialogueProcessor", 46 | ] 47 | 48 | 49 | def load_processor(name, cfg=None): 50 | """ 51 | Example 52 | 53 | >>> processor = load_processor("alpro_video_train", cfg=None) 54 | """ 55 | processor = registry.get_processor_class(name).from_config(cfg) 56 | 57 | return processor 58 | -------------------------------------------------------------------------------- /lavis/datasets/builders/retrieval_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.common.registry import registry 9 | from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder 10 | from lavis.datasets.datasets.retrieval_datasets import ( 11 | RetrievalDataset, 12 | RetrievalEvalDataset, 13 | VideoRetrievalDataset, 14 | VideoRetrievalEvalDataset, 15 | ) 16 | 17 | 18 | @registry.register_builder("msrvtt_retrieval") 19 | class MSRVTTRetrievalBuilder(BaseDatasetBuilder): 20 | train_dataset_cls = VideoRetrievalDataset 21 | eval_dataset_cls = VideoRetrievalEvalDataset 22 | 23 | DATASET_CONFIG_DICT = {"default": "configs/datasets/msrvtt/defaults_ret.yaml"} 24 | 25 | 26 | @registry.register_builder("didemo_retrieval") 27 | class DiDeMoRetrievalBuilder(BaseDatasetBuilder): 28 | train_dataset_cls = VideoRetrievalDataset 29 | eval_dataset_cls = VideoRetrievalEvalDataset 30 | 31 | DATASET_CONFIG_DICT = {"default": "configs/datasets/didemo/defaults_ret.yaml"} 32 | 33 | 34 | @registry.register_builder("coco_retrieval") 35 | class COCORetrievalBuilder(BaseDatasetBuilder): 36 | train_dataset_cls = RetrievalDataset 37 | eval_dataset_cls = RetrievalEvalDataset 38 | 39 | DATASET_CONFIG_DICT = {"default": "configs/datasets/coco/defaults_ret.yaml"} 40 | 41 | 42 | @registry.register_builder("flickr30k") 43 | class Flickr30kBuilder(BaseDatasetBuilder): 44 | train_dataset_cls = RetrievalDataset 45 | eval_dataset_cls = RetrievalEvalDataset 46 | 47 | DATASET_CONFIG_DICT = {"default": "configs/datasets/flickr30k/defaults.yaml"} 48 | -------------------------------------------------------------------------------- /lavis/experimental/bipartite_downsampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is adapted from the paper "ToMe: Token Merging, your ViT but faster" by the authors. 3 | """ 4 | 5 | import math 6 | from typing import Callable, Tuple 7 | 8 | import torch 9 | 10 | 11 | def do_nothing(x: torch.Tensor) -> torch.Tensor: 12 | return x 13 | 14 | 15 | def semantics_retriever(metric: torch.Tensor, k: int) -> Tuple[Callable, Callable]: 16 | """ 17 | Merge frames with the two sets as (every kth element, the rest).
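    Concretely, `metric` is expected to have shape (batch, frames, dim). Every k-th frame
    acts as an anchor ("dst") and the remaining k-1 frames of each window are candidates
    ("src"); each candidate is matched to its most similar anchor by cosine similarity.
    The returned `merge` callable pools each candidate into its anchor (mean by default),
    and `unmerge` expands the anchors back to the (truncated) original frame count.

    Illustrative example (shapes assume k divides the frame count):

        >>> feats = torch.randn(2, 12, 256)          # (batch, frames, dim)
        >>> merge, unmerge = semantics_retriever(feats, k=4)
        >>> merge(feats).shape                       # anchors only
        torch.Size([2, 3, 256])
        >>> unmerge(merge(feats)).shape              # back to the original frame count
        torch.Size([2, 12, 256])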
18 | """ 19 | if k <= 1: 20 | return do_nothing, do_nothing 21 | 22 | def split(x): 23 | t_rnd = (x.shape[1] // k) * k 24 | x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2]) 25 | a, b = ( 26 | x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]), 27 | x[:, :, (k - 1), :], 28 | ) 29 | return a, b 30 | 31 | with torch.no_grad(): 32 | metric = metric / metric.norm(dim=-1, keepdim=True) 33 | a, b = split(metric) 34 | r = a.shape[1] 35 | scores = a @ b.transpose(-1, -2) 36 | 37 | _, dst_idx = scores.max(dim=-1) 38 | dst_idx = dst_idx[..., None] 39 | 40 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: 41 | with torch.no_grad(): 42 | src, dst = split(x) 43 | n, _, c = src.shape 44 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) 45 | 46 | return dst 47 | 48 | def unmerge(x: torch.Tensor) -> torch.Tensor: 49 | n, _, c = x.shape 50 | dst = x 51 | 52 | src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)).to(x.dtype) 53 | 54 | src = src.view(n, -1, (k - 1), c) 55 | dst = dst.view(n, -1, 1, c) 56 | 57 | out = torch.cat([src, dst], dim=-2) 58 | out = out.contiguous().view(n, -1, c) 59 | 60 | return out 61 | 62 | return merge, unmerge 63 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/imagefolder_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from collections import OrderedDict 10 | 11 | from PIL import Image 12 | from torchvision import datasets 13 | 14 | from lavis.datasets.datasets.base_dataset import BaseDataset 15 | 16 | 17 | class ImageFolderDataset(BaseDataset): 18 | def __init__(self, vis_processor, vis_root, classnames=[], **kwargs): 19 | super().__init__(vis_processor=vis_processor, vis_root=vis_root) 20 | 21 | self.inner_dataset = datasets.ImageFolder(vis_root) 22 | 23 | self.annotation = [ 24 | {"image": elem[0], "label": elem[1], "image_id": elem[0]} 25 | for elem in self.inner_dataset.imgs 26 | ] 27 | 28 | self.classnames = classnames 29 | 30 | self._add_instance_ids() 31 | 32 | def __len__(self): 33 | return len(self.inner_dataset) 34 | 35 | def __getitem__(self, index): 36 | ann = self.annotation[index] 37 | 38 | img_fn = ann["image"] 39 | image_path = os.path.join(self.vis_root, img_fn) 40 | image = Image.open(image_path).convert("RGB") 41 | 42 | image = self.vis_processor(image) 43 | 44 | return { 45 | "image": image, 46 | "label": ann["label"], 47 | "image_id": ann["image_id"], 48 | "instance_id": ann["instance_id"], 49 | } 50 | 51 | def displ_item(self, index): 52 | sample, ann = self.__getitem__(index), self.annotation[index] 53 | 54 | return OrderedDict( 55 | { 56 | "file": ann["image"], 57 | "label": self.classnames[ann["label"]], 58 | "image": sample["image"], 59 | } 60 | ) 61 | -------------------------------------------------------------------------------- /lavis/projects/hermes/cls_coin.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: blip2_vicuna_instruct 8 | model_type: vicuna7b 9 | load_finetuned: False 10 | load_pretrained: True 11 | 12 | # vit encoder 13 | image_size: 224 14 | drop_path_rate: 0 15 | use_grad_checkpoint: False 16 | vit_precision: "fp32" 17 | freeze_vit: False 18 | 19 | # Q-Former 20 | num_query_token: 32 21 | 22 | # path to Vicuna checkpoint 23 | llm_model: "lmsys/vicuna-7b-v1.1" 24 | 25 | # generation configs 26 | prompt: "" 27 | max_txt_len: 30 28 | 29 | datasets: 30 | coin_cls: # name of the dataset builder 31 | vis_processor: 32 | train: 33 | name: "blip2_video_train" 34 | image_size: 224 35 | eval: 36 | name: "blip2_video_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | prompt: "" 42 | eval: 43 | name: "blip_caption" 44 | prompt: "" 45 | num_frames: 10 46 | 47 | run: 48 | task: classification 49 | # optimizer 50 | lr_sched: "linear_warmup_cosine_lr" 51 | init_lr: 1e-5 52 | min_lr: 0 53 | warmup_lr: 1e-8 54 | warmup_steps: 1000 55 | weight_decay: 0.05 56 | max_epoch: 10 57 | batch_size_train: 16 58 | batch_size_eval: 16 59 | num_workers: 12 60 | accum_grad_iters: 1 61 | 62 | max_len: 20 63 | min_len: 1 64 | num_beams: 5 65 | 66 | seed: 42 67 | output_dir: "output" 68 | 69 | amp: True 70 | resume_ckpt_path: null 71 | 72 | evaluate: False 73 | train_splits: ["train"] 74 | valid_splits: ["val"] 75 | test_splits: ["test"] 76 | 77 | device: "cuda" 78 | world_size: 1 79 | dist_url: "env://" 80 | distributed: True 81 | report_metric: True 82 | suffix : null 83 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/snli_ve_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from collections import OrderedDict 10 | 11 | from PIL import Image 12 | 13 | from lavis.datasets.datasets.multimodal_classification_datasets import ( 14 | MultimodalClassificationDataset, 15 | ) 16 | 17 | 18 | class __DisplMixin: 19 | def displ_item(self, index): 20 | sample, ann = self.__getitem__(index), self.annotation[index] 21 | 22 | return OrderedDict( 23 | { 24 | "file": os.path.basename(ann["image"]), 25 | "sentence": ann["sentence"], 26 | "label": ann["label"], 27 | "image": sample["image"], 28 | } 29 | ) 30 | 31 | 32 | class SNLIVisualEntialmentDataset(MultimodalClassificationDataset, __DisplMixin): 33 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 35 | 36 | self.class_labels = self._build_class_labels() 37 | 38 | def _build_class_labels(self): 39 | return {"contradiction": 0, "neutral": 1, "entailment": 2} 40 | 41 | def __getitem__(self, index): 42 | ann = self.annotation[index] 43 | 44 | image_id = ann["image"] 45 | image_path = os.path.join(self.vis_root, "%s.jpg" % image_id) 46 | image = Image.open(image_path).convert("RGB") 47 | 48 | image = self.vis_processor(image) 49 | sentence = self.text_processor(ann["sentence"]) 50 | 51 | return { 52 | "image": image, 53 | "text_input": sentence, 54 | "label": self.class_labels[ann["label"]], 55 | "image_id": image_id, 56 | "instance_id": ann["instance_id"], 57 | } 58 | -------------------------------------------------------------------------------- /lavis/projects/hermes/cls_breakfast.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: blip2_vicuna_instruct 8 | model_type: vicuna7b 9 | load_finetuned: False 10 | load_pretrained: True 11 | 12 | # vit encoder 13 | image_size: 224 14 | drop_path_rate: 0 15 | use_grad_checkpoint: False 16 | vit_precision: "fp32" 17 | freeze_vit: False 18 | 19 | # Q-Former 20 | num_query_token: 32 21 | 22 | # path to Vicuna checkpoint 23 | llm_model: "lmsys/vicuna-7b-v1.1" 24 | 25 | # generation configs 26 | prompt: "" 27 | max_txt_len: 30 28 | 29 | datasets: 30 | breakfast_cls: # name of the dataset builder 31 | vis_processor: 32 | train: 33 | name: "blip2_video_train" 34 | image_size: 224 35 | eval: 36 | name: "blip2_video_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | prompt: "" 42 | eval: 43 | name: "blip_caption" 44 | prompt: "" 45 | num_frames: 10 46 | 47 | run: 48 | task: classification 49 | # optimizer 50 | lr_sched: "linear_warmup_cosine_lr" 51 | init_lr: 1e-5 52 | min_lr: 0 53 | warmup_lr: 1e-8 54 | warmup_steps: 1000 55 | weight_decay: 0.05 56 | max_epoch: 10 57 | batch_size_train: 16 58 | batch_size_eval: 16 59 | num_workers: 12 60 | accum_grad_iters: 1 61 | 62 | max_len: 20 63 | min_len: 1 64 | num_beams: 5 65 | 66 | seed: 42 67 | output_dir: "output" 68 | 69 | amp: True 70 | resume_ckpt_path: null 71 | 72 | evaluate: False 73 | train_splits: ["train"] 74 | valid_splits: ["val"] 75 | test_splits: ["test"] 76 | 77 | device: "cuda" 78 | world_size: 1 79 | dist_url: "env://" 80 | distributed: True 81 | report_metric: True 82 | suffix : null 83 | -------------------------------------------------------------------------------- /lavis/projects/hermes/cls_lvu.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: blip2_vicuna_instruct 8 | model_type: vicuna7b 9 | load_finetuned: False 10 | load_pretrained: True 11 | 12 | # vit encoder 13 | image_size: 224 14 | drop_path_rate: 0 15 | use_grad_checkpoint: False 16 | vit_precision: "fp32" 17 | freeze_vit: False 18 | 19 | # Q-Former 20 | num_query_token: 32 21 | 22 | # path to Vicuna checkpoint 23 | llm_model: "lmsys/vicuna-7b-v1.1" 24 | 25 | # generation configs 26 | prompt: "" 27 | max_txt_len: 30 28 | 29 | datasets: 30 | lvu_cls: # name of the dataset builder 31 | vis_processor: 32 | train: 33 | name: "blip2_video_train" 34 | image_size: 224 35 | eval: 36 | name: "blip2_video_eval" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | prompt: "" 42 | eval: 43 | name: "blip_caption" 44 | prompt: "" 45 | history: 10 46 | num_frames: 10 47 | task: director 48 | stride: 10 49 | 50 | run: 51 | task: classification 52 | # optimizer 53 | lr_sched: "linear_warmup_cosine_lr" 54 | init_lr: 1e-5 55 | min_lr: 0 56 | warmup_lr: 1e-8 57 | warmup_steps: 1000 58 | weight_decay: 0.05 59 | max_epoch: 10 60 | batch_size_train: 16 61 | batch_size_eval: 16 62 | num_workers: 12 63 | accum_grad_iters: 1 64 | 65 | max_len: 5 66 | min_len: 1 67 | num_beams: 5 68 | 69 | seed: 42 70 | output_dir: "output" 71 | 72 | amp: True 73 | resume_ckpt_path: null 74 | 75 | evaluate: False 76 | train_splits: ["train"] 77 | valid_splits: ["val"] 78 | test_splits: ["test"] 79 | 80 | device: "cuda" 81 | world_size: 1 82 | dist_url: "env://" 83 | distributed: True 84 | report_metric: True 85 | suffix : null 86 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/video_vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | from collections import OrderedDict 11 | 12 | from lavis.datasets.datasets.multimodal_classification_datasets import ( 13 | MultimodalClassificationDataset, 14 | ) 15 | 16 | 17 | class __DisplMixin: 18 | def displ_item(self, index): 19 | ann = self.annotation[index] 20 | 21 | vname = ann["video"] 22 | vpath = os.path.join(self.vis_root, vname) 23 | 24 | return OrderedDict( 25 | {"file": vpath, "question": ann["question"], "answer": ann["answer"]} 26 | ) 27 | 28 | 29 | class VideoQADataset(MultimodalClassificationDataset, __DisplMixin): 30 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 31 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 32 | 33 | def _build_class_labels(self, ans_path): 34 | ans2label = json.load(open(ans_path)) 35 | 36 | self.class_labels = ans2label 37 | 38 | def _get_answer_label(self, answer): 39 | if answer in self.class_labels: 40 | return self.class_labels[answer] 41 | else: 42 | return len(self.class_labels) 43 | 44 | def __getitem__(self, index): 45 | assert ( 46 | self.class_labels 47 | ), f"class_labels of {__class__.__name__} is not built yet." 
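        # NOTE: class_labels is populated by _build_class_labels(ans_path); answers missing
        # from that vocabulary are mapped to len(self.class_labels) by _get_answer_label().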
48 | 49 | ann = self.annotation[index] 50 | 51 | vname = ann["video"] 52 | vpath = os.path.join(self.vis_root, vname) 53 | 54 | frms = self.vis_processor(vpath) 55 | question = self.text_processor(ann["question"]) 56 | 57 | return { 58 | "video": frms, 59 | "text_output": question, 60 | "answers": self._get_answer_label(ann["answer"]), 61 | "question_id": ann["question_id"], 62 | "instance_id": ann["instance_id"], 63 | } 64 | -------------------------------------------------------------------------------- /lavis/projects/hermes/qa_moviechat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: blip2_vicuna_instruct 8 | model_type: vicuna7b 9 | load_finetuned: False 10 | load_pretrained: True 11 | 12 | # vit encoder 13 | image_size: 224 14 | drop_path_rate: 0 15 | use_grad_checkpoint: False 16 | vit_precision: "fp32" 17 | freeze_vit: False 18 | 19 | # Q-Former 20 | num_query_token: 32 21 | 22 | # path to Vicuna checkpoint 23 | llm_model: "lmsys/vicuna-7b-v1.1" 24 | 25 | # generation configs 26 | prompt: "" 27 | 28 | datasets: 29 | moviechat_qa: # name of the dataset builder 30 | vis_processor: 31 | train: 32 | name: "blip2_video_train" 33 | image_size: 224 34 | eval: 35 | name: "blip2_video_eval" 36 | image_size: 224 37 | text_processor: 38 | train: 39 | name: "blip_question" 40 | prompt: "Question: {} Short answer:" 41 | eval: 42 | name: "blip_question" 43 | prompt: "" 44 | num_frames: 10 45 | trail_percentage: 0.01 46 | # build_info: 47 | # images: 48 | # storage: '/export/share/datasets/vision/coco/images/' 49 | 50 | run: 51 | task: vqa 52 | # optimizer 53 | lr_sched: "linear_warmup_cosine_lr" 54 | init_lr: 1e-5 55 | min_lr: 0 56 | warmup_lr: 1e-8 57 | warmup_steps: 1000 58 | weight_decay: 0.05 59 | max_epoch: 10 60 | batch_size_train: 16 61 | batch_size_eval: 16 62 | num_workers: 12 63 | accum_grad_iters: 1 64 | 65 | max_len: 10 66 | min_len: 1 67 | num_beams: 5 68 | inference_method: "generate" 69 | prompt: "Question: {} Short answer:" 70 | 71 | seed: 42 72 | output_dir: "output" 73 | 74 | amp: True 75 | resume_ckpt_path: null 76 | 77 | evaluate: False 78 | train_splits: ["train"] 79 | valid_splits: ["val"] 80 | test_splits: [] 81 | 82 | device: "cuda" 83 | world_size: 1 84 | dist_url: "env://" 85 | distributed: True 86 | report_metric: True 87 | suffix : null 88 | -------------------------------------------------------------------------------- /lavis/datasets/builders/vqa_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.common.registry import registry 9 | from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder 10 | from lavis.datasets.datasets.aok_vqa_datasets import AOKVQADataset, AOKVQAEvalDataset 11 | from lavis.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQAEvalDataset 12 | from lavis.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset 13 | from lavis.datasets.datasets.vg_vqa_datasets import VGVQADataset 14 | 15 | 16 | @registry.register_builder("coco_vqa") 17 | class COCOVQABuilder(BaseDatasetBuilder): 18 | train_dataset_cls = COCOVQADataset 19 | eval_dataset_cls = COCOVQAEvalDataset 20 | 21 | DATASET_CONFIG_DICT = { 22 | "default": "configs/datasets/coco/defaults_vqa.yaml", 23 | "eval": "configs/datasets/coco/eval_vqa.yaml", 24 | } 25 | 26 | 27 | @registry.register_builder("vg_vqa") 28 | class VGVQABuilder(BaseDatasetBuilder): 29 | train_dataset_cls = VGVQADataset 30 | DATASET_CONFIG_DICT = {"default": "configs/datasets/vg/defaults_vqa.yaml"} 31 | 32 | 33 | @registry.register_builder("ok_vqa") 34 | class OKVQABuilder(COCOVQABuilder): 35 | DATASET_CONFIG_DICT = { 36 | "default": "configs/datasets/okvqa/defaults.yaml", 37 | } 38 | 39 | 40 | @registry.register_builder("aok_vqa") 41 | class AOKVQABuilder(BaseDatasetBuilder): 42 | train_dataset_cls = AOKVQADataset 43 | eval_dataset_cls = AOKVQAEvalDataset 44 | 45 | DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"} 46 | 47 | 48 | @registry.register_builder("gqa") 49 | class GQABuilder(BaseDatasetBuilder): 50 | train_dataset_cls = GQADataset 51 | eval_dataset_cls = GQAEvalDataset 52 | 53 | DATASET_CONFIG_DICT = { 54 | "default": "configs/datasets/gqa/defaults.yaml", 55 | "balanced_val": "configs/datasets/gqa/balanced_val.yaml", 56 | "balanced_testdev": "configs/datasets/gqa/balanced_testdev.yaml", 57 | } 58 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/video_caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | 10 | from lavis.datasets.datasets.base_dataset import BaseDataset 11 | from lavis.datasets.datasets.caption_datasets import CaptionDataset 12 | 13 | 14 | class VideoCaptionDataset(CaptionDataset): 15 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 16 | """ 17 | vis_root (string): Root directory of images (e.g. 
coco/images/) 18 | ann_root (string): directory to store the annotation file 19 | split (string): val or test 20 | """ 21 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 22 | 23 | def __getitem__(self, index): 24 | 25 | ann = self.annotation[index] 26 | 27 | vname = ann["video"] 28 | video_path = os.path.join(self.vis_root, vname) 29 | 30 | video = self.vis_processor(video_path) 31 | caption = self.text_processor(ann["caption"]) 32 | 33 | # "image_id" is kept to stay compatible with the COCO evaluation format 34 | return { 35 | "video": video, 36 | "text_output": caption, 37 | "image_id": self.img_ids[ann["image_id"]], 38 | } 39 | 40 | 41 | class VideoCaptionEvalDataset(BaseDataset): 42 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 43 | """ 44 | vis_root (string): Root directory of images (e.g. coco/images/) 45 | ann_root (string): directory to store the annotation file 46 | split (string): val or test 47 | """ 48 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 49 | 50 | def __getitem__(self, index): 51 | 52 | ann = self.annotation[index] 53 | 54 | vname = ann["video"] 55 | video_path = os.path.join(self.vis_root, vname) 56 | 57 | video = self.vis_processor(video_path) 58 | 59 | return { 60 | "video": video, 61 | "image_id": ann["image_id"], 62 | "instance_id": ann["instance_id"], 63 | } 64 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/laion_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import webdataset as wds 9 | 10 | from lavis.datasets.datasets.base_dataset import BaseDataset 11 | 12 | 13 | class LaionDataset(BaseDataset): 14 | def __init__(self, vis_processor, text_processor, location): 15 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 16 | 17 | self.inner_dataset = wds.DataPipeline( 18 | wds.ResampledShards(location), 19 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 20 | wds.shuffle(1000, handler=wds.warn_and_continue), 21 | wds.decode("pilrgb", handler=wds.warn_and_continue), 22 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 23 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 24 | wds.map(self.to_dict, handler=wds.warn_and_continue), 25 | ) 26 | 27 | def to_dict(self, sample): 28 | return { 29 | "image": sample[0], 30 | "text_input": self.text_processor(sample[1]["caption"]), 31 | } 32 | 33 | 34 | if __name__ == "__main__": 35 | from torchvision import transforms 36 | 37 | def to_image_text_pair(sample): 38 | return sample[0], sample[1]["caption"] 39 | 40 | normalize = transforms.Normalize( 41 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 42 | ) 43 | 44 | transform_train = transforms.Compose( 45 | [ 46 | transforms.RandomResizedCrop(256, scale=(0.2, 1.0)), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | normalize, 50 | ] 51 | ) 52 | 53 | dataset = LaionDataset( 54 | vis_processor=transform_train, 55 | text_processor=lambda x: x, 56 | location="/export/laion/laion2B-multi/part-00000/{00000..01743}.tar", 57 | ) 58 | 59 | import torch 60 | 61 | loader = torch.utils.data.DataLoader(dataset.inner_dataset, batch_size=2) 62 | 63 | 
print(next(iter(loader))["text_input"]) 64 | -------------------------------------------------------------------------------- /lavis/projects/hermes/qa_moviecore.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: blip2_vicuna_instruct 8 | model_type: vicuna7b 9 | load_finetuned: False 10 | load_pretrained: True 11 | 12 | # vit encoder 13 | image_size: 224 14 | drop_path_rate: 0 15 | use_grad_checkpoint: False 16 | vit_precision: "fp32" 17 | freeze_vit: False 18 | 19 | # Q-Former 20 | num_query_token: 32 21 | 22 | # path to Vicuna checkpoint 23 | llm_model: "lmsys/vicuna-7b-v1.1" 24 | 25 | # generation configs 26 | prompt: "" 27 | 28 | datasets: 29 | moviecore_qa: # name of the dataset builder 30 | vis_processor: 31 | train: 32 | name: "blip2_video_train" 33 | image_size: 224 34 | eval: 35 | name: "blip2_video_eval" 36 | image_size: 224 37 | text_processor: 38 | train: 39 | name: "blip_question" 40 | prompt: "Question: {} Answer:" 41 | max_words: &max_words 100 42 | eval: 43 | name: "blip_question" 44 | prompt: "" 45 | max_words: *max_words 46 | num_frames: 10 47 | trail_percentage: 0.01 48 | # build_info: 49 | # images: 50 | # storage: '/export/share/datasets/vision/coco/images/' 51 | 52 | run: 53 | task: vqa 54 | # optimizer 55 | lr_sched: "linear_warmup_cosine_lr" 56 | init_lr: 1e-5 57 | min_lr: 0 58 | warmup_lr: 1e-8 59 | warmup_steps: 1000 60 | weight_decay: 0.05 61 | max_epoch: 10 62 | batch_size_train: 16 63 | batch_size_eval: 16 64 | num_workers: 12 65 | accum_grad_iters: 1 66 | 67 | max_len: 512 68 | min_len: 10 69 | num_beams: 5 70 | inference_method: "generate" 71 | prompt: "Question: {} Answer:" 72 | max_words: *max_words 73 | 74 | seed: 42 75 | output_dir: "output" 76 | 77 | amp: True 78 | resume_ckpt_path: null 79 | 80 | evaluate: False 81 | train_splits: ["train"] 82 | valid_splits: ["val"] 83 | test_splits: [] 84 | 85 | device: "cuda" 86 | world_size: 1 87 | dist_url: "env://" 88 | distributed: True 89 | report_metric: True 90 | suffix : null 91 | -------------------------------------------------------------------------------- /lavis/models/clip_models/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Based on https://github.com/mlfoundations/open_clip 8 | """ 9 | 10 | from torch import nn as nn 11 | from torchvision.ops.misc import FrozenBatchNorm2d 12 | 13 | 14 | def freeze_batch_norm_2d(module, module_match={}, name=""): 15 | """ 16 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 17 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 18 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 19 | Args: 20 | module (torch.nn.Module): Any PyTorch module. 
21 | module_match (dict): Dictionary of full module names to freeze (all if empty) 22 | name (str): Full module name (prefix) 23 | Returns: 24 | torch.nn.Module: Resulting module 25 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 26 | """ 27 | res = module 28 | is_match = True 29 | if module_match: 30 | is_match = name in module_match 31 | if is_match and isinstance( 32 | module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) 33 | ): 34 | res = FrozenBatchNorm2d(module.num_features) 35 | res.num_features = module.num_features 36 | res.affine = module.affine 37 | if module.affine: 38 | res.weight.data = module.weight.data.clone().detach() 39 | res.bias.data = module.bias.data.clone().detach() 40 | res.running_mean.data = module.running_mean.data 41 | res.running_var.data = module.running_var.data 42 | res.eps = module.eps 43 | else: 44 | for child_name, child in module.named_children(): 45 | full_child_name = ".".join([name, child_name]) if name else child_name 46 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 47 | if new_child is not child: 48 | res.add_module(child_name, new_child) 49 | return res 50 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import ConcatDataset, Dataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class BaseDataset(Dataset): 16 | def __init__( 17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] 18 | ): 19 | """ 20 | vis_root (string): Root directory of images (e.g. 
coco/images/) 21 | ann_root (string): directory to store the annotation file 22 | """ 23 | self.vis_root = vis_root 24 | 25 | self.annotation = [] 26 | for ann_path in ann_paths: 27 | self.annotation.extend(json.load(open(ann_path, "r"))) 28 | 29 | self.vis_processor = vis_processor 30 | self.text_processor = text_processor 31 | 32 | self._add_instance_ids() 33 | 34 | def __len__(self): 35 | return len(self.annotation) 36 | 37 | def collater(self, samples): 38 | return default_collate(samples) 39 | 40 | def set_processors(self, vis_processor, text_processor): 41 | self.vis_processor = vis_processor 42 | self.text_processor = text_processor 43 | 44 | def _add_instance_ids(self, key="instance_id"): 45 | for idx, ann in enumerate(self.annotation): 46 | ann[key] = str(idx) 47 | 48 | 49 | class ConcatDataset(ConcatDataset): 50 | def __init__(self, datasets: Iterable[Dataset]) -> None: 51 | super().__init__(datasets) 52 | 53 | def collater(self, samples): 54 | # TODO For now only supports datasets with same underlying collater implementations 55 | 56 | all_keys = set() 57 | for s in samples: 58 | all_keys.update(s) 59 | 60 | shared_keys = all_keys 61 | for s in samples: 62 | shared_keys = shared_keys & set(s.keys()) 63 | 64 | samples_shared_keys = [] 65 | for s in samples: 66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 67 | 68 | return self.datasets[0].collater(samples_shared_keys) 69 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/coco_caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | 11 | from PIL import Image, ImageFile 12 | 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | 15 | from lavis.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset 16 | 17 | COCOCapDataset = CaptionDataset 18 | 19 | 20 | class COCOCapEvalDataset(CaptionEvalDataset): 21 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 22 | """ 23 | vis_root (string): Root directory of images (e.g. coco/images/) 24 | ann_root (string): directory to store the annotation file 25 | split (string): val or test 26 | """ 27 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 28 | 29 | def __getitem__(self, index): 30 | ann = self.annotation[index] 31 | 32 | image_path = os.path.join(self.vis_root, ann["image"]) 33 | image = Image.open(image_path).convert("RGB") 34 | 35 | image = self.vis_processor(image) 36 | 37 | img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1] 38 | 39 | return { 40 | "image": image, 41 | "image_id": img_id, 42 | "instance_id": ann["instance_id"], 43 | } 44 | 45 | 46 | class NoCapsEvalDataset(CaptionEvalDataset): 47 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 48 | """ 49 | vis_root (string): Root directory of images (e.g. 
coco/images/) 50 | ann_root (string): directory to store the annotation file 51 | split (string): val or test 52 | """ 53 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 54 | 55 | def __getitem__(self, index): 56 | ann = self.annotation[index] 57 | 58 | image_path = os.path.join(self.vis_root, ann["image"]) 59 | image = Image.open(image_path).convert("RGB") 60 | 61 | image = self.vis_processor(image) 62 | 63 | img_id = ann["img_id"] 64 | 65 | return { 66 | "image": image, 67 | "image_id": img_id, 68 | "instance_id": ann["instance_id"], 69 | } 70 | -------------------------------------------------------------------------------- /lavis/datasets/builders/image_text_pair_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | 10 | from lavis.common.registry import registry 11 | from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder 12 | from lavis.datasets.datasets.image_text_pair_datasets import ImageTextPairDataset 13 | from lavis.datasets.datasets.laion_dataset import LaionDataset 14 | 15 | 16 | @registry.register_builder("conceptual_caption_3m") 17 | class ConceptualCaption3MBuilder(BaseDatasetBuilder): 18 | train_dataset_cls = ImageTextPairDataset 19 | 20 | DATASET_CONFIG_DICT = { 21 | "default": "configs/datasets/conceptual_caption/defaults_3m.yaml" 22 | } 23 | 24 | 25 | @registry.register_builder("conceptual_caption_12m") 26 | class ConceptualCaption12MBuilder(BaseDatasetBuilder): 27 | train_dataset_cls = ImageTextPairDataset 28 | 29 | DATASET_CONFIG_DICT = { 30 | "default": "configs/datasets/conceptual_caption/defaults_12m.yaml" 31 | } 32 | 33 | 34 | @registry.register_builder("sbu_caption") 35 | class SBUCaptionBuilder(BaseDatasetBuilder): 36 | train_dataset_cls = ImageTextPairDataset 37 | 38 | DATASET_CONFIG_DICT = {"default": "configs/datasets/sbu_caption/defaults.yaml"} 39 | 40 | 41 | @registry.register_builder("vg_caption") 42 | class VGCaptionBuilder(BaseDatasetBuilder): 43 | train_dataset_cls = ImageTextPairDataset 44 | 45 | DATASET_CONFIG_DICT = {"default": "configs/datasets/vg/defaults_caption.yaml"} 46 | 47 | 48 | @registry.register_builder("laion2B_multi") 49 | class Laion2BMultiBuilder(BaseDatasetBuilder): 50 | train_dataset_cls = LaionDataset 51 | 52 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults_2B_multi.yaml"} 53 | 54 | def _download_ann(self): 55 | pass 56 | 57 | def _download_vis(self): 58 | pass 59 | 60 | def build(self): 61 | self.build_processors() 62 | 63 | build_info = self.config.build_info 64 | 65 | datasets = dict() 66 | split = "train" # laion dataset only has train split 67 | 68 | # create datasets 69 | # [NOTE] return inner_datasets (wds.DataPipeline) 70 | dataset_cls = self.train_dataset_cls 71 | datasets[split] = dataset_cls( 72 | vis_processor=self.vis_processors[split], 73 | text_processor=self.text_processors[split], 74 | location=build_info.storage, 75 | ).inner_dataset 76 | 77 | return datasets 78 | -------------------------------------------------------------------------------- /lavis/tasks/multimodal_classification.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | import logging 11 | 12 | import numpy as np 13 | import torch 14 | from lavis.common.dist_utils import main_process 15 | from lavis.common.registry import registry 16 | from lavis.tasks.base_task import BaseTask 17 | 18 | 19 | @registry.register_task("multimodal_classification") 20 | class MultimodalClassificationTask(BaseTask): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def valid_step(self, model, samples): 25 | results = [] 26 | 27 | outputs = model.predict(samples) 28 | 29 | predictions = outputs["predictions"] 30 | targets = outputs["targets"] 31 | 32 | predictions = predictions.max(1)[1].cpu().numpy() 33 | targets = targets.cpu().numpy() 34 | 35 | indices = samples[self.inst_id_key] 36 | 37 | for pred, tgt, index in zip(predictions, targets, indices): 38 | if isinstance(index, torch.Tensor): 39 | index = index.item() 40 | 41 | results.append( 42 | { 43 | self.inst_id_key: index, 44 | "prediction": pred.item(), 45 | "target": tgt.item(), 46 | } 47 | ) 48 | 49 | return results 50 | 51 | def after_evaluation(self, val_result, split_name, epoch, **kwargs): 52 | eval_result_file = self.save_result( 53 | result=val_result, 54 | result_dir=registry.get_path("result_dir"), 55 | filename="{}_epoch{}".format(split_name, epoch), 56 | remove_duplicate=self.inst_id_key, 57 | ) 58 | 59 | metrics = self._report_metrics( 60 | eval_result_file=eval_result_file, split_name=split_name 61 | ) 62 | 63 | return metrics 64 | 65 | @main_process 66 | def _report_metrics(self, eval_result_file, split_name): 67 | results = json.load(open(eval_result_file)) 68 | 69 | predictions = np.array([res["prediction"] for res in results]) 70 | targets = np.array([res["target"] for res in results]) 71 | 72 | accuracy = (targets == predictions).sum() / targets.shape[0] 73 | metrics = {"agg_metrics": accuracy, "acc": accuracy} 74 | 75 | log_stats = {split_name: {k: v for k, v in metrics.items()}} 76 | 77 | with open( 78 | os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" 79 | ) as f: 80 | f.write(json.dumps(log_stats) + "\n") 81 | 82 | logging.info(metrics) 83 | return metrics 84 | -------------------------------------------------------------------------------- /lavis/models/blip_models/blip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | from packaging import version 11 | 12 | import torch 13 | from lavis.common.dist_utils import download_cached_file 14 | from lavis.common.utils import is_url 15 | from lavis.models.base_model import BaseModel 16 | from lavis.models.vit import interpolate_pos_embed 17 | from transformers import BertTokenizer 18 | import transformers 19 | 20 | class BlipBase(BaseModel): 21 | def __init__(self): 22 | super().__init__() 23 | transformers_version = version.parse(transformers.__version__) 24 | assert transformers_version < version.parse("4.27"), "BLIP models are not compatible with transformers>=4.27, run pip install transformers==4.25 to downgrade" 25 | 26 | @classmethod 27 | def init_tokenizer(cls): 28 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 29 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 30 | tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]}) 31 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 32 | return tokenizer 33 | 34 | def load_from_pretrained(self, url_or_filename): 35 | if is_url(url_or_filename): 36 | cached_file = download_cached_file( 37 | url_or_filename, check_hash=False, progress=True 38 | ) 39 | checkpoint = torch.load(cached_file, map_location="cpu") 40 | elif os.path.isfile(url_or_filename): 41 | checkpoint = torch.load(url_or_filename, map_location="cpu") 42 | else: 43 | raise RuntimeError("checkpoint url or path is invalid") 44 | 45 | state_dict = checkpoint["model"] 46 | 47 | state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( 48 | state_dict["visual_encoder.pos_embed"], self.visual_encoder 49 | ) 50 | if "visual_encoder_m.pos_embed" in self.state_dict().keys(): 51 | state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed( 52 | state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m 53 | ) 54 | 55 | for key in self.state_dict().keys(): 56 | if key in state_dict.keys(): 57 | if state_dict[key].shape != self.state_dict()[key].shape: 58 | del state_dict[key] 59 | 60 | msg = self.load_state_dict(state_dict, strict=False) 61 | 62 | logging.info("Missing keys {}".format(msg.missing_keys)) 63 | logging.info("load checkpoint from %s" % url_or_filename) 64 | 65 | return msg 66 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from collections import OrderedDict 10 | 11 | from PIL import Image 12 | 13 | from lavis.datasets.datasets.base_dataset import BaseDataset 14 | 15 | 16 | class __DisplMixin: 17 | def displ_item(self, index): 18 | sample, ann = self.__getitem__(index), self.annotation[index] 19 | 20 | return OrderedDict( 21 | { 22 | "file": ann["image"], 23 | "caption": ann["caption"], 24 | "image": sample["image"], 25 | } 26 | ) 27 | 28 | 29 | class CaptionDataset(BaseDataset, __DisplMixin): 30 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 31 | """ 32 | vis_root (string): Root directory of images (e.g. 
coco/images/) 33 | ann_root (string): directory to store the annotation file 34 | """ 35 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 36 | 37 | self.img_ids = {} 38 | n = 0 39 | for ann in self.annotation: 40 | img_id = ann["image_id"] 41 | if img_id not in self.img_ids.keys(): 42 | self.img_ids[img_id] = n 43 | n += 1 44 | 45 | def __getitem__(self, index): 46 | 47 | # TODO this assumes image input, not general enough 48 | ann = self.annotation[index] 49 | 50 | image_path = os.path.join(self.vis_root, ann["image"]) 51 | image = Image.open(image_path).convert("RGB") 52 | 53 | image = self.vis_processor(image) 54 | caption = self.text_processor(ann["caption"]) 55 | 56 | return { 57 | "image": image, 58 | "text_input": caption, 59 | "image_id": self.img_ids[ann["image_id"]], 60 | } 61 | 62 | 63 | class CaptionEvalDataset(BaseDataset, __DisplMixin): 64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 65 | """ 66 | vis_root (string): Root directory of images (e.g. coco/images/) 67 | ann_root (string): directory to store the annotation file 68 | split (string): val or test 69 | """ 70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.vis_root, ann["image"]) 77 | image = Image.open(image_path).convert("RGB") 78 | 79 | image = self.vis_processor(image) 80 | 81 | return { 82 | "image": image, 83 | "image_id": ann["image_id"], 84 | "instance_id": ann["instance_id"], 85 | } 86 | -------------------------------------------------------------------------------- /lavis/processors/clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.common.registry import registry 9 | from lavis.processors.blip_processors import BlipImageBaseProcessor 10 | from omegaconf import OmegaConf 11 | from torchvision import transforms 12 | from torchvision.transforms.functional import InterpolationMode 13 | 14 | 15 | def _convert_to_rgb(image): 16 | return image.convert("RGB") 17 | 18 | 19 | @registry.register_processor("clip_image_train") 20 | class ClipImageTrainProcessor(BlipImageBaseProcessor): 21 | def __init__( 22 | self, image_size=224, mean=None, std=None, min_scale=0.9, max_scale=1.0 23 | ): 24 | 25 | super().__init__(mean=mean, std=std) 26 | 27 | self.transform = transforms.Compose( 28 | [ 29 | transforms.RandomResizedCrop( 30 | image_size, 31 | scale=(min_scale, max_scale), 32 | interpolation=InterpolationMode.BICUBIC, 33 | ), 34 | _convert_to_rgb, 35 | transforms.ToTensor(), 36 | self.normalize, 37 | ] 38 | ) 39 | 40 | @classmethod 41 | def from_config(cls, cfg=None): 42 | if cfg is None: 43 | cfg = OmegaConf.create() 44 | 45 | image_size = cfg.get("image_size", 224) 46 | 47 | mean = cfg.get("mean", None) 48 | std = cfg.get("std", None) 49 | 50 | min_scale = cfg.get("min_scale", 0.9) 51 | max_scale = cfg.get("max_scale", 1.0) 52 | 53 | return cls( 54 | image_size=image_size, 55 | mean=mean, 56 | std=std, 57 | min_scale=min_scale, 58 | max_scale=max_scale, 59 | ) 60 | 61 | 62 | @registry.register_processor("clip_image_eval") 63 | class ClipImageEvalProcessor(BlipImageBaseProcessor): 64 | def __init__(self, image_size=224, mean=None, std=None): 65 | 66 | super().__init__(mean=mean, std=std) 67 | 68 | self.transform = transforms.Compose( 69 | [ 70 | transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC), 71 | transforms.CenterCrop(image_size), 72 | _convert_to_rgb, 73 | transforms.ToTensor(), 74 | self.normalize, 75 | ] 76 | ) 77 | 78 | @classmethod 79 | def from_config(cls, cfg=None): 80 | if cfg is None: 81 | cfg = OmegaConf.create() 82 | 83 | image_size = cfg.get("image_size", 224) 84 | 85 | mean = cfg.get("mean", None) 86 | std = cfg.get("std", None) 87 | 88 | return cls( 89 | image_size=image_size, 90 | mean=mean, 91 | std=std, 92 | ) 93 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/nlvr_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import random 10 | from collections import OrderedDict 11 | 12 | from PIL import Image 13 | 14 | from lavis.datasets.datasets.multimodal_classification_datasets import ( 15 | MultimodalClassificationDataset, 16 | ) 17 | 18 | 19 | class __DisplMixin: 20 | def displ_item(self, index): 21 | sample, ann = self.__getitem__(index), self.annotation[index] 22 | 23 | return OrderedDict( 24 | { 25 | "file_L": ann["images"][0], 26 | "file_R": ann["images"][1], 27 | "sentence": ann["sentence"], 28 | "label": ann["label"], 29 | "image": [sample["image0"], sample["image1"]], 30 | } 31 | ) 32 | 33 | 34 | class NLVRDataset(MultimodalClassificationDataset, __DisplMixin): 35 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 36 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 37 | 38 | self.class_labels = self._build_class_labels() 39 | 40 | def _build_class_labels(self): 41 | return {"False": 0, "True": 1} 42 | 43 | @staticmethod 44 | def _flip(samples): 45 | sentence = samples["text_input"] 46 | image0, image1 = samples["image0"], samples["image1"] 47 | 48 | if "left" not in sentence and "right" not in sentence: 49 | if random.random() < 0.5: 50 | image0, image1 = image1, image0 51 | else: 52 | if random.random() < 0.5: 53 | sentence = sentence.replace("left", "[TEMP_TOKEN]") 54 | sentence = sentence.replace("right", "left") 55 | sentence = sentence.replace("[TEMP_TOKEN]", "right") 56 | 57 | image0, image1 = image1, image0 58 | 59 | samples["text_input"] = sentence 60 | samples["image0"] = image0 61 | samples["image1"] = image1 62 | 63 | return samples 64 | 65 | def __getitem__(self, index): 66 | ann = self.annotation[index] 67 | 68 | image0_path = os.path.join(self.vis_root, ann["images"][0]) 69 | image0 = Image.open(image0_path).convert("RGB") 70 | image0 = self.vis_processor(image0) 71 | 72 | image1_path = os.path.join(self.vis_root, ann["images"][1]) 73 | image1 = Image.open(image1_path).convert("RGB") 74 | image1 = self.vis_processor(image1) 75 | 76 | sentence = self.text_processor(ann["sentence"]) 77 | label = self.class_labels[ann["label"]] 78 | 79 | return self._flip( 80 | { 81 | "image0": image0, 82 | "image1": image1, 83 | "text_input": sentence, 84 | "label": label, 85 | # "image_id": ann["image_id"], 86 | "instance_id": ann["instance_id"], 87 | } 88 | ) 89 | 90 | 91 | class NLVREvalDataset(NLVRDataset): 92 | @staticmethod 93 | def _flip(samples): 94 | return samples 95 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/gqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | from collections import OrderedDict 11 | 12 | from PIL import Image 13 | 14 | from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset 15 | 16 | 17 | class __DisplMixin: 18 | def displ_item(self, index): 19 | sample, ann = self.__getitem__(index), self.annotation[index] 20 | 21 | return OrderedDict( 22 | { 23 | "file": ann["image"], 24 | "question": ann["question"], 25 | "question_id": ann["question_id"], 26 | "answers": "; ".join(ann["answer"]), 27 | "image": sample["image"], 28 | } 29 | ) 30 | 31 | 32 | class GQADataset(VQADataset, __DisplMixin): 33 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 35 | 36 | def __getitem__(self, index): 37 | ann = self.annotation[index] 38 | 39 | image_path = os.path.join(self.vis_root, ann["image"]) 40 | image = Image.open(image_path).convert("RGB") 41 | 42 | image = self.vis_processor(image) 43 | question = self.text_processor(ann["question"]) 44 | 45 | answers = [ann["answer"]] 46 | weights = [1] 47 | 48 | return { 49 | "image": image, 50 | "text_input": question, 51 | "answers": answers, 52 | "weights": weights, 53 | } 54 | 55 | 56 | class GQAEvalDataset(VQAEvalDataset, __DisplMixin): 57 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 58 | """ 59 | vis_root (string): Root directory of images (e.g. gqa/images/) 60 | ann_root (string): directory to store the annotation file 61 | """ 62 | 63 | self.vis_root = vis_root 64 | 65 | self.annotation = json.load(open(ann_paths[0])) 66 | 67 | ## TODO: support inference method == 'ranking' 68 | answer_list_path = ann_paths[1] if len(ann_paths) > 1 else "" 69 | if os.path.exists(answer_list_path): 70 | self.answer_list = json.load(open(answer_list_path)) 71 | else: 72 | self.answer_list = None 73 | 74 | self.vis_processor = vis_processor 75 | self.text_processor = text_processor 76 | 77 | self._add_instance_ids() 78 | 79 | def __getitem__(self, index): 80 | ann = self.annotation[index] 81 | 82 | image_path = os.path.join(self.vis_root, ann["image"]) 83 | image = Image.open(image_path).convert("RGB") 84 | 85 | image = self.vis_processor(image) 86 | question = self.text_processor(ann["question"]) 87 | 88 | if "answer" in ann: 89 | # answer is a string 90 | answer = ann["answer"] 91 | else: 92 | answer = None 93 | 94 | return { 95 | "image": image, 96 | "text_input": question, 97 | "answer": answer, 98 | "question_id": ann["question_id"], 99 | "instance_id": ann["instance_id"], 100 | } 101 | -------------------------------------------------------------------------------- /lavis/models/albef_models/albef_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Optional 10 | 11 | import torch 12 | from transformers.modeling_outputs import ( 13 | BaseModelOutputWithPoolingAndCrossAttentions, 14 | CausalLMOutputWithCrossAttentions, 15 | ModelOutput, 16 | ) 17 | 18 | 19 | @dataclass 20 | class AlbefSimilarity(ModelOutput): 21 | sim_i2t: torch.FloatTensor = None 22 | sim_t2i: torch.FloatTensor = None 23 | 24 | sim_i2t_m: Optional[torch.FloatTensor] = None 25 | sim_t2i_m: Optional[torch.FloatTensor] = None 26 | 27 | sim_i2t_targets: Optional[torch.FloatTensor] = None 28 | sim_t2i_targets: Optional[torch.FloatTensor] = None 29 | 30 | 31 | @dataclass 32 | class AlbefIntermediateOutput(ModelOutput): 33 | # uni-modal features 34 | image_embeds: torch.FloatTensor = None 35 | text_embeds: Optional[torch.FloatTensor] = None 36 | 37 | image_embeds_m: Optional[torch.FloatTensor] = None 38 | text_embeds_m: Optional[torch.FloatTensor] = None 39 | 40 | # intermediate outputs of multimodal encoder 41 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 42 | encoder_output_m: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 43 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 44 | 45 | itm_logits: Optional[torch.FloatTensor] = None 46 | itm_labels: Optional[torch.LongTensor] = None 47 | 48 | # intermediate outputs of multimodal decoder 49 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None 50 | decoder_labels: Optional[torch.LongTensor] = None 51 | 52 | 53 | @dataclass 54 | class AlbefOutput(ModelOutput): 55 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. 56 | sims: Optional[AlbefSimilarity] = None 57 | 58 | intermediate_output: AlbefIntermediateOutput = None 59 | 60 | loss: Optional[torch.FloatTensor] = None 61 | 62 | loss_itc: Optional[torch.FloatTensor] = None 63 | 64 | loss_itm: Optional[torch.FloatTensor] = None 65 | 66 | loss_mlm: Optional[torch.FloatTensor] = None 67 | 68 | 69 | @dataclass 70 | class AlbefOutputWithLogits(AlbefOutput): 71 | logits: torch.FloatTensor = None 72 | logits_m: torch.FloatTensor = None 73 | 74 | 75 | @dataclass 76 | class AlbefOutputFeatures(ModelOutput): 77 | """ 78 | Data class of features from AlbefFeatureExtractor. 79 | 80 | Args: 81 | image_embeds: `torch.FloatTensor` of shape `(batch_size, num_patches+1, embed_dim)`, `optional` 82 | image_features: `torch.FloatTensor` of shape `(batch_size, num_patches+1, feature_dim)`, `optional` 83 | text_embeds: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, embed_dim)`, `optional` 84 | text_features: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, feature_dim)`, `optional` 85 | 86 | The first embedding or feature is for the [CLS] token. 87 | 88 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. 
89 | """ 90 | 91 | image_embeds: Optional[torch.FloatTensor] = None 92 | image_embeds_proj: Optional[torch.FloatTensor] = None 93 | 94 | text_embeds: Optional[torch.FloatTensor] = None 95 | text_embeds_proj: Optional[torch.FloatTensor] = None 96 | 97 | multimodal_embeds: Optional[torch.FloatTensor] = None 98 | -------------------------------------------------------------------------------- /lavis/tasks/retrieval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | 12 | import numpy as np 13 | import torch 14 | from lavis.common.dist_utils import is_main_process 15 | from lavis.common.registry import registry 16 | from lavis.tasks.base_task import BaseTask 17 | 18 | 19 | @registry.register_task("retrieval") 20 | class RetrievalTask(BaseTask): 21 | def __init__(self, cfg): 22 | super().__init__() 23 | 24 | self.cfg = cfg 25 | 26 | @classmethod 27 | def setup_task(cls, cfg): 28 | run_cfg = cfg.run_cfg 29 | 30 | return cls(cfg=run_cfg) 31 | 32 | def evaluation(self, model, data_loader, **kwargs): 33 | # score_i2t, score_t2i = model.compute_sim_matrix(model, data_loader) 34 | score_i2t, score_t2i = model.compute_sim_matrix(data_loader, task_cfg=self.cfg) 35 | 36 | if is_main_process(): 37 | eval_result = self._report_metrics( 38 | score_i2t, 39 | score_t2i, 40 | data_loader.dataset.txt2img, 41 | data_loader.dataset.img2txt, 42 | ) 43 | logging.info(eval_result) 44 | else: 45 | eval_result = None 46 | 47 | return eval_result 48 | 49 | def after_evaluation(self, val_result, **kwargs): 50 | return val_result 51 | 52 | @staticmethod 53 | @torch.no_grad() 54 | def _report_metrics(scores_i2t, scores_t2i, txt2img, img2txt): 55 | 56 | # Images->Text 57 | ranks = np.zeros(scores_i2t.shape[0]) 58 | for index, score in enumerate(scores_i2t): 59 | inds = np.argsort(score)[::-1] 60 | # Score 61 | rank = 1e20 62 | for i in img2txt[index]: 63 | tmp = np.where(inds == i)[0][0] 64 | if tmp < rank: 65 | rank = tmp 66 | ranks[index] = rank 67 | 68 | # Compute metrics 69 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 70 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 71 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 72 | 73 | # Text->Images 74 | ranks = np.zeros(scores_t2i.shape[0]) 75 | 76 | for index, score in enumerate(scores_t2i): 77 | inds = np.argsort(score)[::-1] 78 | ranks[index] = np.where(inds == txt2img[index])[0][0] 79 | 80 | # Compute metrics 81 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 82 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 83 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 84 | 85 | tr_mean = (tr1 + tr5 + tr10) / 3 86 | ir_mean = (ir1 + ir5 + ir10) / 3 87 | r_mean = (tr_mean + ir_mean) / 2 88 | 89 | agg_metrics = (tr1 + tr5 + tr10) / 3 90 | 91 | eval_result = { 92 | "txt_r1": tr1, 93 | "txt_r5": tr5, 94 | "txt_r10": tr10, 95 | "txt_r_mean": tr_mean, 96 | "img_r1": ir1, 97 | "img_r5": ir5, 98 | "img_r10": ir10, 99 | "img_r_mean": ir_mean, 100 | "r_mean": r_mean, 101 | "agg_metrics": agg_metrics, 102 | } 103 | with open( 104 | os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" 105 | ) as f: 106 | f.write(json.dumps(eval_result) + "\n") 107 | return eval_result 108 | 
-------------------------------------------------------------------------------- /lavis/datasets/datasets/coco_vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | from collections import OrderedDict 11 | 12 | from PIL import Image 13 | 14 | from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset 15 | 16 | 17 | class __DisplMixin: 18 | def displ_item(self, index): 19 | sample, ann = self.__getitem__(index), self.annotation[index] 20 | 21 | return OrderedDict( 22 | { 23 | "file": ann["image"], 24 | "question": ann["question"], 25 | "question_id": ann["question_id"], 26 | "answers": "; ".join(ann["answer"]), 27 | "image": sample["image"], 28 | } 29 | ) 30 | 31 | 32 | class COCOVQADataset(VQADataset, __DisplMixin): 33 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 35 | 36 | def __getitem__(self, index): 37 | ann = self.annotation[index] 38 | 39 | image_path = os.path.join(self.vis_root, ann["image"]) 40 | image = Image.open(image_path).convert("RGB") 41 | 42 | image = self.vis_processor(image) 43 | question = self.text_processor(ann["question"]) 44 | 45 | answer_weight = {} 46 | for answer in ann["answer"]: 47 | if answer in answer_weight.keys(): 48 | answer_weight[answer] += 1 / len(ann["answer"]) 49 | else: 50 | answer_weight[answer] = 1 / len(ann["answer"]) 51 | 52 | answers = list(answer_weight.keys()) 53 | weights = list(answer_weight.values()) 54 | 55 | return { 56 | "image": image, 57 | "text_output": question, 58 | "answers": answers, 59 | "weights": weights, 60 | } 61 | 62 | 63 | class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin): 64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 65 | """ 66 | vis_root (string): Root directory of images (e.g. coco/images/) 67 | ann_root (string): directory to store the annotation file 68 | """ 69 | 70 | self.vis_root = vis_root 71 | 72 | self.annotation = json.load(open(ann_paths[0])) 73 | 74 | answer_list_path = ann_paths[1] 75 | if os.path.exists(answer_list_path): 76 | self.answer_list = json.load(open(answer_list_path)) 77 | else: 78 | self.answer_list = None 79 | 80 | try: 81 | self.coco_fmt_qust_file = ann_paths[2] 82 | self.coco_fmt_anno_file = ann_paths[3] 83 | except IndexError: 84 | self.coco_fmt_qust_file = None 85 | self.coco_fmt_anno_file = None 86 | 87 | self.vis_processor = vis_processor 88 | self.text_processor = text_processor 89 | 90 | self._add_instance_ids() 91 | 92 | def __getitem__(self, index): 93 | ann = self.annotation[index] 94 | 95 | image_path = os.path.join(self.vis_root, ann["image"]) 96 | image = Image.open(image_path).convert("RGB") 97 | 98 | image = self.vis_processor(image) 99 | question = self.text_processor(ann["question"]) 100 | 101 | return { 102 | "image": image, 103 | "text_output": question, 104 | "question_id": ann["question_id"], 105 | "instance_id": ann["instance_id"], 106 | } 107 | -------------------------------------------------------------------------------- /lavis/models/clip_models/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 
3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Based on https://github.com/mlfoundations/open_clip 8 | """ 9 | 10 | from typing import Optional, Sequence, Tuple 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torchvision.transforms.functional as F 15 | 16 | 17 | from torchvision.transforms import ( 18 | Normalize, 19 | Compose, 20 | RandomResizedCrop, 21 | InterpolationMode, 22 | ToTensor, 23 | Resize, 24 | CenterCrop, 25 | ) 26 | 27 | 28 | class ResizeMaxSize(nn.Module): 29 | def __init__( 30 | self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0 31 | ): 32 | super().__init__() 33 | if not isinstance(max_size, int): 34 | raise TypeError(f"Size should be int. Got {type(max_size)}") 35 | self.max_size = max_size 36 | self.interpolation = interpolation 37 | self.fn = min if fn == "min" else min 38 | self.fill = fill 39 | 40 | def forward(self, img): 41 | if isinstance(img, torch.Tensor): 42 | height, width = img.shape[:2] 43 | else: 44 | width, height = img.size 45 | scale = self.max_size / float(max(height, width)) 46 | if scale != 1.0: 47 | new_size = tuple(round(dim * scale) for dim in (height, width)) 48 | img = F.resize(img, new_size, self.interpolation) 49 | pad_h = self.max_size - new_size[0] 50 | pad_w = self.max_size - new_size[1] 51 | img = F.pad( 52 | img, 53 | padding=[ 54 | pad_w // 2, 55 | pad_h // 2, 56 | pad_w - pad_w // 2, 57 | pad_h - pad_h // 2, 58 | ], 59 | fill=self.fill, 60 | ) 61 | return img 62 | 63 | 64 | def _convert_to_rgb(image): 65 | return image.convert("RGB") 66 | 67 | 68 | def image_transform( 69 | image_size: int, 70 | is_train: bool, 71 | mean: Optional[Tuple[float, ...]] = None, 72 | std: Optional[Tuple[float, ...]] = None, 73 | resize_longest_max: bool = False, 74 | fill_color: int = 0, 75 | ): 76 | mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean 77 | std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std 78 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 79 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 80 | image_size = image_size[0] 81 | 82 | normalize = Normalize(mean=mean, std=std) 83 | if is_train: 84 | return Compose( 85 | [ 86 | RandomResizedCrop( 87 | image_size, 88 | scale=(0.9, 1.0), 89 | interpolation=InterpolationMode.BICUBIC, 90 | ), 91 | _convert_to_rgb, 92 | ToTensor(), 93 | normalize, 94 | ] 95 | ) 96 | else: 97 | if resize_longest_max: 98 | transforms = [ResizeMaxSize(image_size, fill=fill_color)] 99 | else: 100 | transforms = [ 101 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 102 | CenterCrop(image_size), 103 | ] 104 | transforms.extend( 105 | [ 106 | _convert_to_rgb, 107 | ToTensor(), 108 | normalize, 109 | ] 110 | ) 111 | return Compose(transforms) 112 | -------------------------------------------------------------------------------- /lavis/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from lavis.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | **kwargs 25 | ): 26 | self.optimizer = optimizer 27 | 28 | self.max_epoch = max_epoch 29 | self.min_lr = min_lr 30 | 31 | self.decay_rate = decay_rate 32 | 33 | self.init_lr = init_lr 34 | self.warmup_steps = warmup_steps 35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 36 | 37 | def step(self, cur_epoch, cur_step): 38 | if cur_epoch == 0: 39 | warmup_lr_schedule( 40 | step=cur_step, 41 | optimizer=self.optimizer, 42 | max_step=self.warmup_steps, 43 | init_lr=self.warmup_start_lr, 44 | max_lr=self.init_lr, 45 | ) 46 | else: 47 | step_lr_schedule( 48 | epoch=cur_epoch, 49 | optimizer=self.optimizer, 50 | init_lr=self.init_lr, 51 | min_lr=self.min_lr, 52 | decay_rate=self.decay_rate, 53 | ) 54 | 55 | 56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 57 | class LinearWarmupCosineLRScheduler: 58 | def __init__( 59 | self, 60 | optimizer, 61 | max_epoch, 62 | min_lr, 63 | init_lr, 64 | warmup_steps=0, 65 | warmup_start_lr=-1, 66 | **kwargs 67 | ): 68 | self.optimizer = optimizer 69 | 70 | self.max_epoch = max_epoch 71 | self.min_lr = min_lr 72 | 73 | self.init_lr = init_lr 74 | self.warmup_steps = warmup_steps 75 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 76 | 77 | def step(self, cur_epoch, cur_step): 78 | # assuming the warmup iters less than one epoch 79 | if cur_epoch == 0: 80 | warmup_lr_schedule( 81 | step=cur_step, 82 | optimizer=self.optimizer, 83 | max_step=self.warmup_steps, 84 | init_lr=self.warmup_start_lr, 85 | max_lr=self.init_lr, 86 | ) 87 | else: 88 | cosine_lr_schedule( 89 | epoch=cur_epoch, 90 | optimizer=self.optimizer, 91 | max_epoch=self.max_epoch, 92 | init_lr=self.init_lr, 93 | min_lr=self.min_lr, 94 | ) 95 | 96 | 97 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 98 | """Decay the learning rate""" 99 | lr = (init_lr - min_lr) * 0.5 * ( 100 | 1.0 + math.cos(math.pi * epoch / max_epoch) 101 | ) + min_lr 102 | for param_group in optimizer.param_groups: 103 | param_group["lr"] = lr 104 | 105 | 106 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 107 | """Warmup the learning rate""" 108 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 109 | for param_group in optimizer.param_groups: 110 | param_group["lr"] = lr 111 | 112 | 113 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 114 | """Decay the learning rate""" 115 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 116 | for param_group in optimizer.param_groups: 117 | param_group["lr"] = lr 118 | -------------------------------------------------------------------------------- /lavis/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from lavis.common.registry import registry 9 | from lavis.datasets.builders.base_dataset_builder import load_dataset_config 10 | from lavis.datasets.builders.caption_builder import ( 11 | COCOCapBuilder, 12 | MSRVTTCapBuilder, 13 | MSVDCapBuilder, 14 | VATEXCapBuilder, 15 | ) 16 | from lavis.datasets.builders.classification_builder import ( 17 | BreakfastCLSBuilder, 18 | COINCLSBuilder, 19 | LVUCLSBuilder, 20 | NLVRBuilder, 21 | SNLIVisualEntailmentBuilder, 22 | ) 23 | from lavis.datasets.builders.dialogue_builder import AVSDDialBuilder 24 | from lavis.datasets.builders.image_text_pair_builder import ( 25 | ConceptualCaption3MBuilder, 26 | ConceptualCaption12MBuilder, 27 | SBUCaptionBuilder, 28 | VGCaptionBuilder, 29 | ) 30 | from lavis.datasets.builders.imagefolder_builder import ImageNetBuilder 31 | from lavis.datasets.builders.retrieval_builder import ( 32 | COCORetrievalBuilder, 33 | DiDeMoRetrievalBuilder, 34 | Flickr30kBuilder, 35 | MSRVTTRetrievalBuilder, 36 | ) 37 | from lavis.datasets.builders.video_qa_builder import ( 38 | ActivityNetQABuilder, 39 | MSRVTTQABuilder, 40 | MSVDQABuilder, 41 | ) 42 | from lavis.datasets.builders.vqa_builder import ( 43 | COCOVQABuilder, 44 | GQABuilder, 45 | OKVQABuilder, 46 | VGVQABuilder, 47 | ) 48 | 49 | __all__ = [ 50 | "COCOCapBuilder", 51 | "COCORetrievalBuilder", 52 | "COCOVQABuilder", 53 | "ConceptualCaption12MBuilder", 54 | "ConceptualCaption3MBuilder", 55 | "DiDeMoRetrievalBuilder", 56 | "Flickr30kBuilder", 57 | "GQABuilder", 58 | "ImageNetBuilder", 59 | "MSRVTTCapBuilder", 60 | "MSRVTTQABuilder", 61 | "MSRVTTRetrievalBuilder", 62 | "MSVDCapBuilder", 63 | "MSVDQABuilder", 64 | "ActivityNetQABuilder", 65 | "NLVRBuilder", 66 | "OKVQABuilder", 67 | "SBUCaptionBuilder", 68 | "SNLIVisualEntailmentBuilder", 69 | "VATEXCapBuilder", 70 | "VGCaptionBuilder", 71 | "VGVQABuilder", 72 | "AVSDDialBuilder", 73 | "LVUCLSBuilder", 74 | "COINCLSBuilder", 75 | "BreakfastCLSBuilder", 76 | ] 77 | 78 | 79 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 80 | """ 81 | Example 82 | 83 | >>> dataset = load_dataset("coco_caption", cfg=None) 84 | >>> splits = dataset.keys() 85 | >>> print([len(dataset[split]) for split in splits]) 86 | 87 | """ 88 | if cfg_path is None: 89 | cfg = None 90 | else: 91 | cfg = load_dataset_config(cfg_path) 92 | 93 | try: 94 | builder = registry.get_builder_class(name)(cfg) 95 | except TypeError: 96 | print( 97 | f"Dataset {name} not found. Available datasets:\n" 98 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 99 | ) 100 | exit(1) 101 | 102 | if vis_path is not None: 103 | if data_type is None: 104 | # use default data type in the config 105 | data_type = builder.config.data_type 106 | 107 | assert ( 108 | data_type in builder.config.build_info 109 | ), f"Invalid data_type {data_type} for {name}." 
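        # Point the builder's storage for this data type at the user-supplied media path (vis_path).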
110 | 111 | builder.config.build_info.get(data_type).storage = vis_path 112 | 113 | dataset = builder.build_datasets() 114 | return dataset 115 | 116 | 117 | class DatasetZoo: 118 | def __init__(self) -> None: 119 | self.dataset_zoo = { 120 | k: list(v.DATASET_CONFIG_DICT.keys()) 121 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 122 | } 123 | 124 | def get_names(self): 125 | return list(self.dataset_zoo.keys()) 126 | 127 | 128 | dataset_zoo = DatasetZoo() 129 | -------------------------------------------------------------------------------- /lavis/models/blip_models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | from typing import List 10 | 11 | from torch import nn 12 | 13 | 14 | def tie_encoder_decoder_weights( 15 | encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str 16 | ): 17 | uninitialized_encoder_weights: List[str] = [] 18 | if decoder.__class__ != encoder.__class__: 19 | logging.info( 20 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 21 | ) 22 | 23 | def tie_encoder_to_decoder_recursively( 24 | decoder_pointer: nn.Module, 25 | encoder_pointer: nn.Module, 26 | module_name: str, 27 | uninitialized_encoder_weights: List[str], 28 | skip_key: str, 29 | depth=0, 30 | ): 31 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 32 | encoder_pointer, nn.Module 33 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 34 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 35 | assert hasattr(encoder_pointer, "weight") 36 | encoder_pointer.weight = decoder_pointer.weight 37 | if hasattr(decoder_pointer, "bias"): 38 | assert hasattr(encoder_pointer, "bias") 39 | encoder_pointer.bias = decoder_pointer.bias 40 | print(module_name + " is tied") 41 | return 42 | 43 | encoder_modules = encoder_pointer._modules 44 | decoder_modules = decoder_pointer._modules 45 | if len(decoder_modules) > 0: 46 | assert ( 47 | len(encoder_modules) > 0 48 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 49 | 50 | all_encoder_weights = set( 51 | [module_name + "/" + sub_name for sub_name in encoder_modules.keys()] 52 | ) 53 | encoder_layer_pos = 0 54 | for name, module in decoder_modules.items(): 55 | if name.isdigit(): 56 | encoder_name = str(int(name) + encoder_layer_pos) 57 | decoder_name = name 58 | if not isinstance( 59 | decoder_modules[decoder_name], 60 | type(encoder_modules[encoder_name]), 61 | ) and len(encoder_modules) != len(decoder_modules): 62 | # this can happen if the name corresponds to the position in a list module list of layers 63 | # in this case the decoder has added a cross-attention that the encoder does not have 64 | # thus skip this step and subtract one layer pos from encoder 65 | encoder_layer_pos -= 1 66 | continue 67 | elif name not in encoder_modules: 68 | continue 69 | elif depth > 500: 70 | raise ValueError( 71 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 
72 | ) 73 | else: 74 | decoder_name = encoder_name = name 75 | tie_encoder_to_decoder_recursively( 76 | decoder_modules[decoder_name], 77 | encoder_modules[encoder_name], 78 | module_name + "/" + name, 79 | uninitialized_encoder_weights, 80 | skip_key, 81 | depth=depth + 1, 82 | ) 83 | all_encoder_weights.remove(module_name + "/" + encoder_name) 84 | 85 | uninitialized_encoder_weights += list(all_encoder_weights) 86 | 87 | # tie weights recursively 88 | tie_encoder_to_decoder_recursively( 89 | decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key 90 | ) 91 | -------------------------------------------------------------------------------- /lavis/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import timm.models.hub as timm_hub 13 | import torch 14 | import torch.distributed as dist 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | torch.distributed.init_process_group( 81 | backend=args.dist_backend, 82 | init_method=args.dist_url, 83 | world_size=args.world_size, 84 | rank=args.rank, 85 | timeout=datetime.timedelta( 86 | days=365 87 | ), # allow auto-downloading and de-compressing 88 | ) 89 | torch.distributed.barrier() 90 | setup_for_distributed(args.rank == 0) 91 | 92 | 93 | def get_dist_info(): 94 | if torch.__version__ < "1.0": 95 | initialized = dist._initialized 96 | else: 97 | initialized = dist.is_initialized() 98 | if initialized: 99 | rank = dist.get_rank() 100 | world_size = dist.get_world_size() 101 | else: # non-distributed training 102 | rank = 0 103 | world_size = 1 104 | return rank, world_size 105 | 106 | 107 | def main_process(func): 108 | 
@functools.wraps(func) 109 | def wrapper(*args, **kwargs): 110 | rank, _ = get_dist_info() 111 | if rank == 0: 112 | return func(*args, **kwargs) 113 | 114 | return wrapper 115 | 116 | 117 | def download_cached_file(url, check_hash=True, progress=False): 118 | """ 119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 121 | """ 122 | 123 | def get_cached_file_path(): 124 | # a hack to sync the file path across processes 125 | parts = torch.hub.urlparse(url) 126 | filename = os.path.basename(parts.path) 127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 128 | 129 | return cached_file 130 | 131 | if is_main_process(): 132 | timm_hub.download_cached_file(url, check_hash, progress) 133 | 134 | if is_dist_avail_and_initialized(): 135 | dist.barrier() 136 | 137 | return get_cached_file_path() 138 | -------------------------------------------------------------------------------- /lavis/processors/functional_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import warnings 9 | 10 | import torch 11 | 12 | 13 | def _is_tensor_video_clip(clip): 14 | if not torch.is_tensor(clip): 15 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 16 | 17 | if not clip.ndimension() == 4: 18 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 19 | 20 | return True 21 | 22 | 23 | def crop(clip, i, j, h, w): 24 | """ 25 | Args: 26 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 27 | """ 28 | if len(clip.size()) != 4: 29 | raise ValueError("clip should be a 4D tensor") 30 | return clip[..., i : i + h, j : j + w] 31 | 32 | 33 | def resize(clip, target_size, interpolation_mode): 34 | if len(target_size) != 2: 35 | raise ValueError( 36 | f"target size should be tuple (height, width), instead got {target_size}" 37 | ) 38 | return torch.nn.functional.interpolate( 39 | clip, size=target_size, mode=interpolation_mode, align_corners=False 40 | ) 41 | 42 | 43 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 44 | """ 45 | Do spatial cropping and resizing to the video clip 46 | Args: 47 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 48 | i (int): i in (i,j) i.e coordinates of the upper left corner. 49 | j (int): j in (i,j) i.e coordinates of the upper left corner. 50 | h (int): Height of the cropped region. 51 | w (int): Width of the cropped region. 52 | size (tuple(int, int)): height and width of resized clip 53 | Returns: 54 | clip (torch.tensor): Resized and cropped clip. 
Size is (C, T, H, W) 55 | """ 56 | if not _is_tensor_video_clip(clip): 57 | raise ValueError("clip should be a 4D torch.tensor") 58 | clip = crop(clip, i, j, h, w) 59 | clip = resize(clip, size, interpolation_mode) 60 | return clip 61 | 62 | 63 | def center_crop(clip, crop_size): 64 | if not _is_tensor_video_clip(clip): 65 | raise ValueError("clip should be a 4D torch.tensor") 66 | h, w = clip.size(-2), clip.size(-1) 67 | th, tw = crop_size 68 | if h < th or w < tw: 69 | raise ValueError("height and width must be no smaller than crop_size") 70 | 71 | i = int(round((h - th) / 2.0)) 72 | j = int(round((w - tw) / 2.0)) 73 | return crop(clip, i, j, th, tw) 74 | 75 | 76 | def to_tensor(clip): 77 | """ 78 | Convert tensor data type from uint8 to float, divide value by 255.0 and 79 | permute the dimensions of clip tensor 80 | Args: 81 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 82 | Return: 83 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 84 | """ 85 | _is_tensor_video_clip(clip) 86 | if not clip.dtype == torch.uint8: 87 | raise TypeError( 88 | "clip tensor should have data type uint8. Got %s" % str(clip.dtype) 89 | ) 90 | return clip.float().permute(3, 0, 1, 2) / 255.0 91 | 92 | 93 | def normalize(clip, mean, std, inplace=False): 94 | """ 95 | Args: 96 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) 97 | mean (tuple): pixel RGB mean. Size is (3) 98 | std (tuple): pixel standard deviation. Size is (3) 99 | Returns: 100 | normalized clip (torch.tensor): Size is (C, T, H, W) 101 | """ 102 | if not _is_tensor_video_clip(clip): 103 | raise ValueError("clip should be a 4D torch.tensor") 104 | if not inplace: 105 | clip = clip.clone() 106 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 107 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 108 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 109 | return clip 110 | 111 | 112 | def hflip(clip): 113 | """ 114 | Args: 115 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) 116 | Returns: 117 | flipped clip (torch.tensor): Size is (C, T, H, W) 118 | """ 119 | if not _is_tensor_video_clip(clip): 120 | raise ValueError("clip should be a 4D torch.tensor") 121 | return clip.flip(-1) 122 | -------------------------------------------------------------------------------- /lavis/models/blip2_models/blip2_image_text_matching.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from lavis.common.registry import registry 11 | from lavis.models.blip2_models.blip2_qformer import Blip2Qformer 12 | 13 | 14 | @registry.register_model("blip2_image_text_matching") 15 | class Blip2ITM(Blip2Qformer): 16 | """ 17 | BLIP Image-Text Matching (ITM) model. 
18 | Supported model types: 19 | - pretrained: pretrained model 20 | - coco: fintuned model on coco 21 | Usage: 22 | >>> from lavis.models import load_model 23 | >>> model = load_model("blip2_image_text_matching", "pretrained") 24 | >>> model = load_model("blip2_image_text_matching", "coco") 25 | """ 26 | 27 | def __init__( 28 | self, 29 | vit_model="eva_clip_g", 30 | img_size=224, 31 | drop_path_rate=0, 32 | use_grad_checkpoint=False, 33 | vit_precision="fp16", 34 | freeze_vit=True, 35 | num_query_token=32, 36 | cross_attention_freq=2, 37 | embed_dim=256, 38 | max_txt_len=32, 39 | ): 40 | super().__init__( 41 | vit_model=vit_model, 42 | img_size=img_size, 43 | drop_path_rate=drop_path_rate, 44 | use_grad_checkpoint=use_grad_checkpoint, 45 | vit_precision=vit_precision, 46 | freeze_vit=freeze_vit, 47 | num_query_token=num_query_token, 48 | cross_attention_freq=cross_attention_freq, 49 | embed_dim=embed_dim, 50 | max_txt_len=max_txt_len, 51 | ) 52 | 53 | def forward(self, samples, match_head="itm"): 54 | image = samples["image"] 55 | caption = samples["text_input"] 56 | 57 | with self.maybe_autocast(): 58 | image_embeds = self.ln_vision(self.visual_encoder(image)) 59 | image_embeds = image_embeds.float() 60 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 61 | image.device 62 | ) 63 | 64 | text = self.tokenizer( 65 | caption, 66 | truncation=True, 67 | max_length=self.max_txt_len, 68 | return_tensors="pt", 69 | ).to(image.device) 70 | 71 | if match_head == "itm": 72 | query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) 73 | query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( 74 | image.device 75 | ) 76 | attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) 77 | output_itm = self.Qformer.bert( 78 | text.input_ids, 79 | query_embeds=query_tokens, 80 | attention_mask=attention_mask, 81 | encoder_hidden_states=image_embeds, 82 | encoder_attention_mask=image_atts, 83 | return_dict=True, 84 | ) 85 | itm_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :] 86 | itm_logit = self.itm_head(itm_embeddings) 87 | itm_logit = itm_logit.mean(dim=1) 88 | 89 | return itm_logit 90 | 91 | elif match_head == "itc": 92 | query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) 93 | 94 | query_output = self.Qformer.bert( 95 | query_embeds=query_tokens, 96 | encoder_hidden_states=image_embeds, 97 | encoder_attention_mask=image_atts, 98 | return_dict=True, 99 | ) 100 | image_feats = F.normalize( 101 | self.vision_proj(query_output.last_hidden_state), dim=-1 102 | ) 103 | 104 | text_output = self.Qformer.bert( 105 | text.input_ids, 106 | attention_mask=text.attention_mask, 107 | return_dict=True, 108 | ) 109 | text_feat = F.normalize( 110 | self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 111 | ) 112 | 113 | sims = torch.bmm(image_feats, text_feat.unsqueeze(-1)) 114 | sim, _ = torch.max(sims, dim=1) 115 | 116 | return sim 117 | -------------------------------------------------------------------------------- /lavis/tasks/dialogue.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | 11 | from lavis.common.dist_utils import main_process 12 | from lavis.common.logger import MetricLogger 13 | from lavis.common.registry import registry 14 | from lavis.tasks.base_task import BaseTask 15 | from lavis.datasets.data_utils import prepare_sample 16 | 17 | import numpy as np 18 | 19 | 20 | @registry.register_task("dialogue") 21 | class DialogueTask(BaseTask): 22 | def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True): 23 | super().__init__() 24 | 25 | self.num_beams = num_beams 26 | self.max_len = max_len 27 | self.min_len = min_len 28 | self.evaluate = evaluate 29 | 30 | self.report_metric = report_metric 31 | 32 | @classmethod 33 | def setup_task(cls, cfg): 34 | run_cfg = cfg.run_cfg 35 | 36 | num_beams = run_cfg.num_beams 37 | max_len = run_cfg.max_len 38 | min_len = run_cfg.min_len 39 | evaluate = run_cfg.evaluate 40 | 41 | report_metric = run_cfg.get("report_metric", True) 42 | 43 | return cls( 44 | num_beams=num_beams, 45 | max_len=max_len, 46 | min_len=min_len, 47 | evaluate=evaluate, 48 | report_metric=report_metric, 49 | ) 50 | 51 | def valid_step(self, model, samples): 52 | results = [] 53 | loss = model(samples)["loss"].item() 54 | 55 | return [loss] 56 | 57 | def after_evaluation(self, val_result, split_name, epoch, **kwargs): 58 | 59 | if self.report_metric: 60 | avg_loss = np.mean(val_result) 61 | metrics = {"agg_metrics": avg_loss} 62 | else: 63 | metrics = {"agg_metrics": 0.0} 64 | 65 | return metrics 66 | 67 | @main_process 68 | def _report_metrics(self, eval_result_file, split_name): 69 | # TODO better way to define this 70 | coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt") 71 | coco_val = coco_dialogue_eval(coco_gt_root, eval_result_file, split_name) 72 | 73 | agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"] 74 | log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}} 75 | 76 | with open( 77 | os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" 78 | ) as f: 79 | f.write(json.dumps(log_stats) + "\n") 80 | 81 | coco_res = {k: v for k, v in coco_val.eval.items()} 82 | coco_res["agg_metrics"] = agg_metrics 83 | 84 | return coco_res 85 | 86 | 87 | # TODO better structure for this. 
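# Editor's note: the caption-metric dependencies below (pycocoevalcap, pycocotools,
# torchvision's download_url) are used only by coco_dialogue_eval further down,
# which is presumably why they are imported here rather than at the top of the module.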
88 | from pycocoevalcap.eval import COCOEvalCap 89 | from pycocotools.coco import COCO 90 | from torchvision.datasets.utils import download_url 91 | 92 | 93 | def coco_dialogue_eval(coco_gt_root, results_file, split): 94 | 95 | urls = { 96 | "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json", 97 | "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json", 98 | } 99 | filenames = { 100 | "val": "coco_karpathy_val_gt.json", 101 | "test": "coco_karpathy_test_gt.json", 102 | } 103 | 104 | download_url(urls[split], coco_gt_root) 105 | annotation_file = os.path.join(coco_gt_root, filenames[split]) 106 | 107 | # create coco object and coco_result object 108 | coco = COCO(annotation_file) 109 | coco_result = coco.loadRes(results_file) 110 | 111 | # create coco_eval object by taking coco and coco_result 112 | coco_eval = COCOEvalCap(coco, coco_result) 113 | 114 | # evaluate on a subset of images by setting 115 | # coco_eval.params['image_id'] = coco_result.getImgIds() 116 | # please remove this line when evaluating the full validation set 117 | # coco_eval.params['image_id'] = coco_result.getImgIds() 118 | 119 | # evaluate results 120 | # SPICE will take a few minutes the first time, but speeds up due to caching 121 | coco_eval.evaluate() 122 | 123 | # print output evaluation scores 124 | for metric, score in coco_eval.eval.items(): 125 | print(f"{metric}: {score:.3f}") 126 | 127 | return coco_eval 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | 6 | # HERMES: temporal-coHERent long-forM understanding with Episodes and Semantics 7 | 8 | ### [Project Page](https://joslefaure.github.io/assets/html/hermes.html) | [Paper](https://arxiv.org/abs/2408.17443) 9 | 10 | 11 | ## :fire: News 12 | * **[2025.06.26]** Our paper **HERMES: temporal-coHERent long-forM understanding with Episodes and Semantics** has been accepted by ***ICCV'2025*** 🚀. 13 | * **[2024.08.24]** :keyboard: Our short paper **BREASE: Bridging Episodes and Semantics, A Novel Framework for Long-Form Video Understanding** has been accepted by the EVAL-FoMo workshop at ***ECCV'24***. 14 | 15 |

16 | teaser 17 |

18 | 19 | 20 | ## Model Overview 21 |

22 | model 23 |

24 | 25 | ## Results 26 | main results 27 | 28 | ### Plug-and-Play Experiments 29 | plug and play 30 | 31 | 32 | 33 | 34 | ## Requirements 35 | 36 | You can install HERMES and its dependencies by running: 37 | ```bash 38 | git clone https://github.com/joslefaure/HERMES.git 39 | cd HERMES 40 | pip install -e . 41 | ``` 42 | 43 | ## Supported Datasets 44 | - [MovieCORE](https://huggingface.co/datasets/MovieCORE/MovieCORE) 45 | - [LVU](https://github.com/chaoyuaw/lvu) 46 | - [Breakfast](https://serre-lab.clps.brown.edu/resource/breakfast-actions-dataset/) 47 | - [COIN](https://coin-dataset.github.io/) 48 | - [MovieChat-1k](https://github.com/rese1f/MovieChat) 49 | 50 | **Additionally, our modules can be plugged into other VLMs for faster inference and improved memory management.** 51 | 52 | ### Prepare MovieCORE and/or MovieChat-1k 53 | 1. Download the train data (if you want to finetune HERMES) from [here](https://huggingface.co/datasets/Enxin/MovieChat-1K_train) and the test data from [here](https://huggingface.co/datasets/Enxin/MovieChat-1K-test/tree/main) 54 | 55 | 2. Extract the frames at 10 FPS and organize them as follows: 56 | ``` 57 | ├── data 58 | └── moviecore 59 | ├── annotation 60 | ├── frames 61 | └── {video_id} 62 | ├── frame000001.jpg 63 | ├── ... 64 | ``` 65 | 66 | ### Pretrained Checkpoints 67 | 68 | | Dataset | Download Link | 69 | |---------|---------------| 70 | | MovieCORE | [GDrive](https://drive.google.com/file/d/16GWbIQ5CpD6un_LJYn04WYf9D0cojWNi/view?usp=sharing) / [HuggingFace](https://huggingface.co/Joslefaure/HERMES/blob/main/moviecore_checkpoint.pth) | 71 | | MovieChat-1k | [GDrive](https://drive.google.com/file/d/15E5f2DyzkA4sjgNk7d7EGoafU8KMDMl8/view?usp=drive_link) / [HuggingFace](https://huggingface.co/Joslefaure/HERMES/blob/main/moviechat_checkpoint.pth) | 72 | | LVU | [GDrive](link) (Coming soon) | 73 | | Breakfast | [GDrive](link) (Coming soon) | 74 | | COIN | [GDrive](link) (Coming soon) | 75 | 76 | 77 | 78 | ### Inference 79 | We run inference on 4 V100 GPUs (32GB); a single GPU also works. 80 | 81 | First, add your OpenAI API key to the environment with `export OPENAI_API_KEY='sk-*****'` (only needed for the MovieChat dataset), as we use GPT-3.5 for scoring. For the other datasets, we report top-1 accuracy. 82 | 83 | 84 | ```bash 85 | # Zero-shot 86 | bash run_scripts/moviecore/test.sh 87 | 88 | # Fully-supervised 89 | bash run_scripts/moviecore/test.sh path/to/your/model.pth 90 | ``` 91 | The same applies to the other datasets; all the scripts are included in `run_scripts`. 92 | 93 | ### Train 94 | We train the model on 8 V100 GPUs (32GB). 95 | 96 | ```bash 97 | bash run_scripts/{dataset}/train.sh 98 | ``` 99 | 100 | ## Citation 101 | If you find our code or our paper useful for your research, please **[★star]** this repo and **[cite]** the following paper: 102 | 103 | ```latex 104 | @misc{faure2024hermestemporalcoherentlongformunderstanding, 105 | title={HERMES: temporal-coHERent long-forM understanding with Episodes and Semantics}, 106 | author={Gueter Josmy Faure and Jia-Fong Yeh and Min-Hung Chen and Hung-Ting Su and Shang-Hong Lai and Winston H. Hsu}, 107 | year={2024}, 108 | eprint={2408.17443}, 109 | archivePrefix={arXiv}, 110 | primaryClass={cs.CV}, 111 | url={https://arxiv.org/abs/2408.17443}, 112 | } 113 | ``` 114 | 115 | 116 | ## Acknowledgement 117 | We thank the authors of the following repositories for open-sourcing their code. 
118 | - [LAVIS](https://github.com/salesforce/LAVIS) 119 | - [MA-LMM](https://github.com/boheumd/MA-LMM) 120 | 121 | *Icon made by Freepik from www.flaticon.com* 122 | 123 | -------------------------------------------------------------------------------- /lavis/models/blip_models/blip_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Optional 10 | 11 | import torch 12 | from transformers.modeling_outputs import ( 13 | ModelOutput, 14 | BaseModelOutputWithPoolingAndCrossAttentions, 15 | CausalLMOutputWithCrossAttentions, 16 | ) 17 | 18 | 19 | @dataclass 20 | class BlipSimilarity(ModelOutput): 21 | sim_i2t: torch.FloatTensor = None 22 | sim_t2i: torch.FloatTensor = None 23 | 24 | sim_i2t_m: Optional[torch.FloatTensor] = None 25 | sim_t2i_m: Optional[torch.FloatTensor] = None 26 | 27 | sim_i2t_targets: Optional[torch.FloatTensor] = None 28 | sim_t2i_targets: Optional[torch.FloatTensor] = None 29 | 30 | 31 | @dataclass 32 | class BlipIntermediateOutput(ModelOutput): 33 | """ 34 | Data class for intermediate outputs of BLIP models. 35 | 36 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). 37 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). 38 | 39 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). 40 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). 41 | 42 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. 43 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. 44 | 45 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. 46 | decoder_labels (torch.LongTensor): labels for the captioning loss. 47 | 48 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). 49 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) 50 | 51 | """ 52 | 53 | # uni-modal features 54 | image_embeds: torch.FloatTensor = None 55 | text_embeds: Optional[torch.FloatTensor] = None 56 | 57 | image_embeds_m: Optional[torch.FloatTensor] = None 58 | text_embeds_m: Optional[torch.FloatTensor] = None 59 | 60 | # intermediate outputs of multimodal encoder 61 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 62 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 63 | 64 | itm_logits: Optional[torch.FloatTensor] = None 65 | itm_labels: Optional[torch.LongTensor] = None 66 | 67 | # intermediate outputs of multimodal decoder 68 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None 69 | decoder_labels: Optional[torch.LongTensor] = None 70 | 71 | 72 | @dataclass 73 | class BlipOutput(ModelOutput): 74 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. 
75 | sims: Optional[BlipSimilarity] = None 76 | 77 | intermediate_output: BlipIntermediateOutput = None 78 | 79 | loss: Optional[torch.FloatTensor] = None 80 | 81 | loss_itc: Optional[torch.FloatTensor] = None 82 | 83 | loss_itm: Optional[torch.FloatTensor] = None 84 | 85 | loss_lm: Optional[torch.FloatTensor] = None 86 | 87 | 88 | @dataclass 89 | class BlipOutputWithLogits(BlipOutput): 90 | logits: torch.FloatTensor = None 91 | logits_m: torch.FloatTensor = None 92 | 93 | 94 | @dataclass 95 | class BlipOutputFeatures(ModelOutput): 96 | """ 97 | Data class of features from BlipFeatureExtractor. 98 | 99 | Args: 100 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional 101 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional 102 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional 103 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional 104 | 105 | The first embedding or feature is for the [CLS] token. 106 | 107 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. 108 | """ 109 | 110 | image_embeds: Optional[torch.FloatTensor] = None 111 | image_embeds_proj: Optional[torch.FloatTensor] = None 112 | 113 | text_embeds: Optional[torch.FloatTensor] = None 114 | text_embeds_proj: Optional[torch.FloatTensor] = None 115 | 116 | multimodal_embeds: Optional[torch.FloatTensor] = None 117 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/dialogue_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import copy 9 | import json 10 | import os 11 | from collections import OrderedDict 12 | 13 | from PIL import Image 14 | 15 | from lavis.datasets.datasets.base_dataset import BaseDataset 16 | 17 | 18 | class __DisplMixin: 19 | def displ_item(self, index): 20 | sample, ann = self.__getitem__(index), self.annotation[index] 21 | 22 | return OrderedDict( 23 | { 24 | "file": ann["image"], 25 | "dialogue": ann["dialogue"], 26 | "image": sample["image"], 27 | } 28 | ) 29 | 30 | 31 | class DialogueDataset(BaseDataset, __DisplMixin): 32 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 33 | """ 34 | vis_root (string): Root directory of images (e.g. 
coco/images/) 35 | ann_root (string): directory to store the annotation file 36 | """ 37 | 38 | self.vis_root = vis_root 39 | 40 | self.annotation = [] 41 | for ann_path in ann_paths: 42 | dialogs = json.load(open(ann_path, "r"))["dialogs"] 43 | for dialog in dialogs: 44 | all_turns = dialog["dialog"] 45 | dialogue_context = [] 46 | for turn in all_turns: 47 | dialog_instance = copy.deepcopy(dialog) 48 | question = turn["question"] 49 | answer = turn["answer"] 50 | 51 | dialog_instance["dialog"] = copy.deepcopy(dialogue_context) 52 | dialog_instance["question"] = question 53 | dialog_instance["answer"] = answer 54 | self.annotation.append(dialog_instance) 55 | dialogue_context.append(turn) 56 | 57 | self.vis_processor = vis_processor 58 | self.text_processor = text_processor 59 | 60 | self._add_instance_ids() 61 | 62 | self.img_ids = {} 63 | n = 0 64 | for ann in self.annotation: 65 | img_id = ann["image_id"] 66 | if img_id not in self.img_ids.keys(): 67 | self.img_ids[img_id] = n 68 | n += 1 69 | 70 | def __getitem__(self, index): 71 | 72 | ann = self.annotation[index] 73 | 74 | image_path = os.path.join(self.vis_root, ann["image"]) 75 | image = Image.open(image_path).convert("RGB") 76 | 77 | image = self.vis_processor(image) 78 | caption = self.text_processor(ann["caption"]) 79 | 80 | return { 81 | "image": image, 82 | "text_input": caption, 83 | "image_id": self.img_ids[ann["image_id"]], 84 | } 85 | 86 | 87 | class DialogueEvalDataset(BaseDataset, __DisplMixin): 88 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 89 | """ 90 | vis_root (string): Root directory of images (e.g. coco/images/) 91 | ann_root (string): directory to store the annotation file 92 | split (string): val or test 93 | """ 94 | 95 | self.vis_root = vis_root 96 | 97 | self.annotation = [] 98 | for ann_path in ann_paths: 99 | dialogs = json.load(open(ann_path, "r"))["dialogs"] 100 | for dialog in dialogs: 101 | all_turns = dialog["dialog"] 102 | dialogue_context = all_turns[:-1] 103 | last_turn = all_turns[-1] 104 | 105 | question = last_turn["question"] 106 | answer = last_turn["answer"] 107 | 108 | dialog["dialog"] = dialogue_context 109 | dialog["question"] = question 110 | dialog["answer"] = answer 111 | 112 | self.annotation.append(dialog) 113 | 114 | self.vis_processor = vis_processor 115 | self.text_processor = text_processor 116 | 117 | self._add_instance_ids() 118 | 119 | self.img_ids = {} 120 | n = 0 121 | for ann in self.annotation: 122 | img_id = ann["image_id"] 123 | if img_id not in self.img_ids.keys(): 124 | self.img_ids[img_id] = n 125 | n += 1 126 | 127 | def __getitem__(self, index): 128 | 129 | ann = self.annotation[index] 130 | 131 | image_path = os.path.join(self.vis_root, ann["image"]) 132 | image = Image.open(image_path).convert("RGB") 133 | 134 | image = self.vis_processor(image) 135 | 136 | return { 137 | "image": image, 138 | "image_id": ann["image_id"], 139 | "instance_id": ann["instance_id"], 140 | } 141 | -------------------------------------------------------------------------------- /lavis/tasks/moviechat_gpt_eval.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import multiprocessing 4 | import os 5 | import time 6 | 7 | import openai 8 | from openai.error import APIError, RateLimitError, Timeout 9 | from tqdm import tqdm 10 | 11 | api_key = os.getenv("OPENAI_API_KEY") 12 | 13 | 14 | def evaluate_qa_pair(data, retries=3, delay=10): 15 | question, answer, pred = data["question"], 
data["answer"], data["pred"] 16 | prompt = ( 17 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " 18 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" 19 | "------" 20 | "##INSTRUCTIONS: " 21 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n" 22 | "- Consider synonyms or paraphrases as valid matches.\n" 23 | "- Evaluate the correctness of the prediction compared to the answer." 24 | ) 25 | user_input = ( 26 | f"Please evaluate the following video-based question-answer pair:\n\n" 27 | f"Question: {question}\n" 28 | f"Correct Answer: {answer}\n" 29 | f"Predicted Answer: {pred}\n\n" 30 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " 31 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." 32 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 33 | "Your response should look strictly follow this format: {'pred': 'yes', 'score': 4}." 34 | ) 35 | 36 | for attempt in range(retries): 37 | try: 38 | response = openai.ChatCompletion.create( 39 | model="gpt-3.5-turbo", 40 | messages=[ 41 | {"role": "system", "content": prompt}, 42 | {"role": "user", "content": user_input}, 43 | ], 44 | timeout=30, # Set a timeout for the API call 45 | ) 46 | response_message = response["choices"][0]["message"]["content"] 47 | return ast.literal_eval(response_message) 48 | except (RateLimitError, APIError) as e: 49 | print( 50 | f"Rate limit or API error encountered: {e}. Retrying in {delay} seconds..." 51 | ) 52 | time.sleep(delay) 53 | except Timeout as e: 54 | print(f"Request timed out: {e}. 
Retrying in {delay} seconds...") 55 | time.sleep(delay) 56 | except Exception as e: 57 | print(f"Unexpected error: {e}") 58 | break 59 | return None 60 | # The rest of the function remains the same as before, just use the variables directly 61 | 62 | 63 | def worker(data): 64 | return evaluate_qa_pair(data) 65 | 66 | 67 | def main(pred_path): 68 | # with open(pred_path, 'r') as pred_file: 69 | pred_contents = pred_path 70 | 71 | data_for_workers = [] 72 | for sample in pred_contents: 73 | id = sample.split(".")[0] 74 | qa_list = pred_contents[sample] 75 | for qa_pair in qa_list: 76 | question = qa_pair["question"] 77 | answer = qa_pair["answer"] 78 | pred = qa_pair["pred"].replace("", "") 79 | data_for_workers.append( 80 | {"question": question, "answer": answer, "pred": pred} 81 | ) 82 | 83 | total_items = len(data_for_workers) 84 | 85 | pool = multiprocessing.Pool(processes=8) 86 | 87 | with tqdm(total=total_items, desc="Processing") as pbar: 88 | results = [] 89 | for result in pool.imap_unordered(worker, data_for_workers): 90 | results.append(result) 91 | pbar.update() 92 | 93 | pool.close() 94 | pool.join() 95 | 96 | yes_count = 0 97 | no_count = 0 98 | score_sum = 0 99 | 100 | for result in results: 101 | if result: 102 | try: 103 | if result["pred"].lower() == "yes": 104 | yes_count += 1 105 | else: 106 | no_count += 1 107 | score_sum += result["score"] 108 | except: 109 | print("Error in result: ", result) 110 | continue 111 | 112 | average_score = score_sum / (yes_count + no_count) if (yes_count + no_count) else 0 113 | accuracy = yes_count / (yes_count + no_count) if (yes_count + no_count) else 0 114 | return {"accuracy": accuracy * 100.0, "average_score": average_score} 115 | 116 | 117 | if __name__ == "__main__": 118 | pred_path = "moviechat_gpt_eval.json" 119 | res = main(pred_path) 120 | print(res) 121 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/msvd_caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | import pdb 11 | import re 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | from PIL import Image 17 | from torchvision.transforms.functional import pil_to_tensor 18 | 19 | from lavis.datasets.data_utils import load_video 20 | from lavis.datasets.datasets.video_caption_datasets import VideoCaptionDataset 21 | 22 | 23 | class MSVDCapDataset(VideoCaptionDataset): 24 | def __init__( 25 | self, 26 | vis_processor, 27 | text_processor, 28 | vis_root, 29 | ann_paths, 30 | num_frames, 31 | prompt="", 32 | split="train", 33 | ): 34 | self.vis_root = vis_root 35 | 36 | self.annotation = {} 37 | for ann_path in ann_paths: 38 | self.annotation.update(json.load(open(ann_path))) 39 | self.video_id_list = list(self.annotation.keys()) 40 | self.video_id_list.sort() 41 | self.fps = 10 42 | 43 | self.num_frames = num_frames 44 | self.vis_processor = vis_processor 45 | self.text_processor = text_processor 46 | self.prompt = prompt 47 | # self._add_instance_ids() 48 | # pdb.set_trace() 49 | 50 | def __getitem__(self, index): 51 | video_id = self.video_id_list[index] 52 | ann = self.annotation[video_id] 53 | 54 | # Divide the range into num_frames segments and select a random index from each segment 55 | segment_list = np.linspace( 56 | 0, ann["frame_length"], self.num_frames + 1, dtype=int 57 | ) 58 | segment_start_list = segment_list[:-1] 59 | segment_end_list = segment_list[1:] 60 | selected_frame_index = [] 61 | for start, end in zip(segment_start_list, segment_end_list): 62 | if start == end: 63 | selected_frame_index.append(start) 64 | else: 65 | selected_frame_index.append(np.random.randint(start, end)) 66 | 67 | frame_list = [] 68 | for frame_index in selected_frame_index: 69 | frame = Image.open( 70 | os.path.join( 71 | self.vis_root, 72 | ann["video"], 73 | "frame{:06d}.jpg".format(frame_index + 1), 74 | ) 75 | ).convert("RGB") 76 | frame = pil_to_tensor(frame).to(torch.float32) 77 | frame_list.append(frame) 78 | video = torch.stack(frame_list, dim=1) 79 | video = self.vis_processor(video) 80 | # print(selected_frame_index, video.shape) 81 | 82 | text_input = self.prompt 83 | caption = self.text_processor.pre_caption(ann["caption"]) 84 | 85 | return { 86 | "image": video, 87 | "text_input": text_input, 88 | "text_output": caption, 89 | "prompt": self.prompt, 90 | "image_id": ann["video"], 91 | } 92 | 93 | def __len__(self): 94 | return len(self.video_id_list) 95 | 96 | 97 | class MSVDCapEvalDataset(MSVDCapDataset): 98 | def __init__( 99 | self, 100 | vis_processor, 101 | text_processor, 102 | vis_root, 103 | ann_paths, 104 | num_frames, 105 | prompt, 106 | split="val", 107 | ): 108 | super().__init__( 109 | vis_processor, 110 | text_processor, 111 | vis_root, 112 | ann_paths, 113 | num_frames, 114 | prompt, 115 | split="val", 116 | ) 117 | 118 | def __getitem__(self, index): 119 | video_id = self.video_id_list[index] 120 | ann = self.annotation[video_id] 121 | 122 | selected_frame_index = ( 123 | np.rint(np.linspace(0, ann["frame_length"] - 1, self.num_frames)) 124 | .astype(int) 125 | .tolist() 126 | ) 127 | frame_list = [] 128 | for frame_index in selected_frame_index: 129 | frame = Image.open( 130 | os.path.join( 131 | self.vis_root, 132 | ann["video"], 133 | "frame{:06d}.jpg".format(frame_index + 1), 134 | ) 135 | ).convert("RGB") 136 | frame = 
pil_to_tensor(frame).to(torch.float32) 137 | frame_list.append(frame) 138 | video = torch.stack(frame_list, dim=1) 139 | video = self.vis_processor(video) 140 | # print(selected_frame_index, video.shape) 141 | 142 | text_input = self.prompt 143 | caption = self.text_processor.pre_caption(ann["caption"]) 144 | 145 | return { 146 | "image": video, 147 | "text_input": text_input, 148 | "text_output": caption, 149 | "prompt": self.prompt, 150 | "image_id": ann["video"], 151 | } 152 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/msrvtt_caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | import pdb 11 | import re 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | from PIL import Image 17 | from torchvision.transforms.functional import pil_to_tensor 18 | 19 | from lavis.datasets.data_utils import load_video 20 | from lavis.datasets.datasets.video_caption_datasets import VideoCaptionDataset 21 | 22 | 23 | class MSRVTTCapDataset(VideoCaptionDataset): 24 | def __init__( 25 | self, 26 | vis_processor, 27 | text_processor, 28 | vis_root, 29 | ann_paths, 30 | num_frames, 31 | prompt="", 32 | split="train", 33 | ): 34 | self.vis_root = vis_root 35 | 36 | self.annotation = {} 37 | for ann_path in ann_paths: 38 | self.annotation.update(json.load(open(ann_path))) 39 | self.video_id_list = list(self.annotation.keys()) 40 | self.video_id_list.sort() 41 | self.fps = 10 42 | 43 | self.num_frames = num_frames 44 | self.vis_processor = vis_processor 45 | self.text_processor = text_processor 46 | self.prompt = prompt 47 | # self._add_instance_ids() 48 | # pdb.set_trace() 49 | 50 | def __getitem__(self, index): 51 | video_id = self.video_id_list[index] 52 | ann = self.annotation[video_id] 53 | 54 | # Divide the range into num_frames segments and select a random index from each segment 55 | segment_list = np.linspace( 56 | 0, ann["frame_length"], self.num_frames + 1, dtype=int 57 | ) 58 | segment_start_list = segment_list[:-1] 59 | segment_end_list = segment_list[1:] 60 | selected_frame_index = [] 61 | for start, end in zip(segment_start_list, segment_end_list): 62 | if start == end: 63 | selected_frame_index.append(start) 64 | else: 65 | selected_frame_index.append(np.random.randint(start, end)) 66 | 67 | frame_list = [] 68 | for frame_index in selected_frame_index: 69 | frame = Image.open( 70 | os.path.join( 71 | self.vis_root, 72 | ann["video"], 73 | "frame{:06d}.jpg".format(frame_index + 1), 74 | ) 75 | ).convert("RGB") 76 | frame = pil_to_tensor(frame).to(torch.float32) 77 | frame_list.append(frame) 78 | video = torch.stack(frame_list, dim=1) 79 | video = self.vis_processor(video) 80 | # print(selected_frame_index, video.shape) 81 | 82 | text_input = self.prompt 83 | caption = self.text_processor.pre_caption(ann["caption"]) 84 | 85 | return { 86 | "image": video, 87 | "text_input": text_input, 88 | "text_output": caption, 89 | "prompt": self.prompt, 90 | "image_id": ann["video"], 91 | } 92 | 93 | def __len__(self): 94 | return len(self.video_id_list) 95 | 96 | 97 | class MSRVTTCapEvalDataset(MSRVTTCapDataset): 98 | def __init__( 99 | self, 100 | vis_processor, 101 | text_processor, 102 | vis_root, 103 
| ann_paths, 104 | num_frames, 105 | prompt, 106 | split="val", 107 | ): 108 | super().__init__( 109 | vis_processor, 110 | text_processor, 111 | vis_root, 112 | ann_paths, 113 | num_frames, 114 | prompt, 115 | split="val", 116 | ) 117 | 118 | def __getitem__(self, index): 119 | video_id = self.video_id_list[index] 120 | ann = self.annotation[video_id] 121 | 122 | selected_frame_index = ( 123 | np.rint(np.linspace(0, ann["frame_length"] - 1, self.num_frames)) 124 | .astype(int) 125 | .tolist() 126 | ) 127 | frame_list = [] 128 | for frame_index in selected_frame_index: 129 | frame = Image.open( 130 | os.path.join( 131 | self.vis_root, 132 | ann["video"], 133 | "frame{:06d}.jpg".format(frame_index + 1), 134 | ) 135 | ).convert("RGB") 136 | frame = pil_to_tensor(frame).to(torch.float32) 137 | frame_list.append(frame) 138 | video = torch.stack(frame_list, dim=1) 139 | video = self.vis_processor(video) 140 | # print(selected_frame_index, video.shape) 141 | 142 | text_input = self.prompt 143 | caption = self.text_processor.pre_caption(ann["caption"]) 144 | 145 | return { 146 | "image": video, 147 | "text_input": text_input, 148 | "text_output": caption, 149 | "prompt": self.prompt, 150 | "image_id": ann["video"], 151 | } 152 | -------------------------------------------------------------------------------- /lavis/tasks/classification.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | import numpy as np 11 | 12 | import torch.distributed as dist 13 | from lavis.common.dist_utils import main_process, get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized 14 | from lavis.common.registry import registry 15 | from lavis.tasks.base_task import BaseTask 16 | from lavis.common.logger import MetricLogger, SmoothedValue 17 | from lavis.datasets.data_utils import prepare_sample 18 | 19 | import pdb 20 | 21 | @registry.register_task("classification") 22 | class ClassificationTask(BaseTask): 23 | def __init__(self, num_beams, max_len, min_len, evaluate, 24 | report_metric=True, verb_only=False, noun_only=False, 25 | dataset_name=None, log_dir=None): 26 | super().__init__() 27 | 28 | self.num_beams = num_beams 29 | self.max_len = max_len 30 | self.min_len = min_len 31 | self.evaluate = evaluate 32 | 33 | self.report_metric = report_metric 34 | 35 | self.verb_only = verb_only 36 | self.noun_only = noun_only 37 | self.dataset_name = dataset_name 38 | self.log_dir = log_dir 39 | 40 | @classmethod 41 | def setup_task(cls, cfg): 42 | run_cfg = cfg.run_cfg 43 | 44 | num_beams = run_cfg.num_beams 45 | max_len = run_cfg.max_len 46 | min_len = run_cfg.min_len 47 | evaluate = run_cfg.evaluate 48 | log_dir = run_cfg.log_dir 49 | 50 | report_metric = run_cfg.get("report_metric", True) 51 | return cls( 52 | num_beams=num_beams, 53 | max_len=max_len, 54 | min_len=min_len, 55 | evaluate=evaluate, 56 | report_metric=report_metric, 57 | dataset_name=list(cfg.datasets_cfg.keys())[0], 58 | log_dir=log_dir, 59 | ) 60 | 61 | def valid_step(self, model, samples): 62 | results = [] 63 | captions = model.generate( 64 | samples, 65 | use_nucleus_sampling=False, 66 | num_beams=self.num_beams, 67 | max_length=self.max_len, 68 | min_length=self.min_len, 69 | num_captions=self.num_beams, 70 | ) 71 | 72 | 
img_ids = samples["image_id"] 73 | batch_size = len(img_ids) 74 | for i, img_id in enumerate(img_ids): 75 | caption_list = captions[i * self.num_beams : (i + 1) * self.num_beams] 76 | results.append({"caption": caption_list, "image_id": img_id}) 77 | return results 78 | 79 | def evaluation(self, model, data_loader, cuda_enabled=True): 80 | metric_logger = MetricLogger(delimiter=" ") 81 | header = "Evaluation" 82 | print_freq = 10 83 | 84 | results = [] 85 | for samples in metric_logger.log_every(data_loader, print_freq, header): 86 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 87 | eval_output = self.valid_step(model=model, samples=samples) 88 | results.extend(eval_output) 89 | 90 | if is_dist_avail_and_initialized(): 91 | dist.barrier() 92 | return results 93 | 94 | def after_evaluation(self, val_result, split_name, epoch, dataset, **kwargs): 95 | eval_result_file = self.save_result( 96 | result=val_result, 97 | result_dir=registry.get_path("result_dir"), 98 | filename="{}_epoch{}".format(split_name, epoch), 99 | remove_duplicate="image_id", 100 | ) 101 | 102 | if self.report_metric: 103 | metrics = self._report_metrics_cls( 104 | eval_result_file=eval_result_file, split_name=split_name, dataset=dataset 105 | ) 106 | else: 107 | metrics = {"agg_metrics": 0.0} 108 | 109 | return metrics 110 | 111 | @main_process 112 | def _report_metrics_cls(self, eval_result_file, split_name, dataset): 113 | gt_dict = dataset.annotation 114 | 115 | with open(eval_result_file, 'r') as fp: 116 | prediction_list = json.load(fp) 117 | 118 | dataset_size = len(prediction_list) 119 | 120 | acc_sum = 0 121 | image_id_list = [] 122 | 123 | match_video_list = [] 124 | for prediction in prediction_list: 125 | image_id = prediction['image_id'] 126 | caption_list = prediction['caption'] 127 | image_id_list.append(image_id) 128 | label = gt_dict[image_id]['label'] 129 | 130 | match_video = [1 if caption == label else 0 for caption in caption_list] 131 | match_video_list.append(match_video) 132 | match = np.array(match_video_list) 133 | 134 | top_1 = match[:, :1].max(1).mean() * 100 135 | top_5 = match[:, :5].max(1).mean() * 100 136 | 137 | result = { 138 | 'top1': top_1, 'top5': top_5, 139 | } 140 | 141 | print(f"top1: {top_1:.2f} top5: {top_5:.2f}\n") 142 | result['agg_metrics'] = result['top1'] 143 | return result 144 | 145 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/msvd_vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | import pdb 11 | import re 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | from PIL import Image 17 | from torchvision.transforms.functional import pil_to_tensor 18 | 19 | from lavis.datasets.data_utils import load_video 20 | from lavis.datasets.datasets.video_vqa_datasets import VideoQADataset 21 | 22 | 23 | class MSVDVQADataset(VideoQADataset): 24 | def __init__( 25 | self, 26 | vis_processor, 27 | text_processor, 28 | vis_root, 29 | ann_paths, 30 | num_frames, 31 | prompt="", 32 | split="train", 33 | ): 34 | self.vis_root = vis_root 35 | 36 | self.annotation = {} 37 | for ann_path in ann_paths: 38 | self.annotation.update(json.load(open(ann_path))) 39 | self.question_id_list = list(self.annotation.keys()) 40 | self.question_id_list.sort() 41 | self.fps = 10 42 | 43 | self.num_frames = num_frames 44 | self.vis_processor = vis_processor 45 | self.text_processor = text_processor 46 | self.prompt = prompt 47 | # self._add_instance_ids() 48 | # pdb.set_trace() 49 | 50 | def __getitem__(self, index): 51 | assert ( 52 | self.class_labels 53 | ), f"class_labels of {__class__.__name__} is not built yet." 54 | question_id = self.question_id_list[index] 55 | ann = self.annotation[question_id] 56 | 57 | # Divide the range into num_frames segments and select a random index from each segment 58 | segment_list = np.linspace( 59 | 0, ann["frame_length"], self.num_frames + 1, dtype=int 60 | ) 61 | segment_start_list = segment_list[:-1] 62 | segment_end_list = segment_list[1:] 63 | selected_frame_index = [] 64 | for start, end in zip(segment_start_list, segment_end_list): 65 | if start == end: 66 | selected_frame_index.append(start) 67 | else: 68 | selected_frame_index.append(np.random.randint(start, end)) 69 | 70 | frame_list = [] 71 | for frame_index in selected_frame_index: 72 | frame = Image.open( 73 | os.path.join( 74 | self.vis_root, 75 | ann["video"], 76 | "frame{:06d}.jpg".format(frame_index + 1), 77 | ) 78 | ).convert("RGB") 79 | frame = pil_to_tensor(frame).to(torch.float32) 80 | frame_list.append(frame) 81 | video = torch.stack(frame_list, dim=1) 82 | video = self.vis_processor(video) 83 | # print(selected_frame_index, video.shape) 84 | 85 | question = self.text_processor(ann["question"]) 86 | if len(self.prompt) > 0: 87 | question = self.prompt.format(question) 88 | answer = self.text_processor(ann["answer"]) 89 | 90 | return { 91 | "image": video, 92 | "text_input": question, 93 | "text_output": answer, 94 | "question_id": ann["question_id"], 95 | # "instance_id": ann["instance_id"], 96 | } 97 | 98 | def __len__(self): 99 | return len(self.question_id_list) 100 | 101 | 102 | class MSVDVQAEvalDataset(MSVDVQADataset): 103 | def __init__( 104 | self, 105 | vis_processor, 106 | text_processor, 107 | vis_root, 108 | ann_paths, 109 | num_frames, 110 | prompt, 111 | split="test", 112 | ): 113 | super().__init__( 114 | vis_processor, 115 | text_processor, 116 | vis_root, 117 | ann_paths, 118 | num_frames, 119 | prompt, 120 | split="test", 121 | ) 122 | 123 | def __getitem__(self, index): 124 | assert ( 125 | self.class_labels 126 | ), f"class_labels of {__class__.__name__} is not built yet." 
127 | question_id = self.question_id_list[index] 128 | ann = self.annotation[question_id] 129 | 130 | selected_frame_index = ( 131 | np.rint(np.linspace(0, ann["frame_length"] - 1, self.num_frames)) 132 | .astype(int) 133 | .tolist() 134 | ) 135 | frame_list = [] 136 | for frame_index in selected_frame_index: 137 | frame = Image.open( 138 | os.path.join( 139 | self.vis_root, 140 | ann["video"], 141 | "frame{:06d}.jpg".format(frame_index + 1), 142 | ) 143 | ).convert("RGB") 144 | frame = pil_to_tensor(frame).to(torch.float32) 145 | frame_list.append(frame) 146 | video = torch.stack(frame_list, dim=1) 147 | video = self.vis_processor(video) 148 | 149 | question = self.text_processor(ann["question"]) 150 | if len(self.prompt) > 0: 151 | question = self.prompt.format(question) 152 | answer = self.text_processor(ann["answer"]) 153 | 154 | return { 155 | "image": video, 156 | "text_input": question, 157 | "text_output": answer, 158 | "question_id": ann["question_id"], 159 | # "instance_id": ann["instance_id"], 160 | } 161 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/msrvtt_vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | import pdb 11 | import re 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | from PIL import Image 17 | from torchvision.transforms.functional import pil_to_tensor 18 | 19 | from lavis.datasets.data_utils import load_video 20 | from lavis.datasets.datasets.video_vqa_datasets import VideoQADataset 21 | 22 | 23 | class MSRVTTVQADataset(VideoQADataset): 24 | def __init__( 25 | self, 26 | vis_processor, 27 | text_processor, 28 | vis_root, 29 | ann_paths, 30 | num_frames, 31 | prompt="", 32 | split="train", 33 | ): 34 | self.vis_root = vis_root 35 | 36 | self.annotation = {} 37 | for ann_path in ann_paths: 38 | self.annotation.update(json.load(open(ann_path))) 39 | self.question_id_list = list(self.annotation.keys()) 40 | self.question_id_list.sort() 41 | self.fps = 10 42 | 43 | self.num_frames = num_frames 44 | self.vis_processor = vis_processor 45 | self.text_processor = text_processor 46 | self.prompt = prompt 47 | # self._add_instance_ids() 48 | # pdb.set_trace() 49 | 50 | def __getitem__(self, index): 51 | assert ( 52 | self.class_labels 53 | ), f"class_labels of {__class__.__name__} is not built yet." 
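# Training-time frame selection (lines below): the available frame indices are divided
# into num_frames equal segments and one index is drawn uniformly at random from each
# segment (falling back to the segment start when the segment is empty). Frames are read
# from pre-extracted JPEGs named frame000001.jpg onward (1-indexed), then stacked along
# dim=1, giving a (C, T, H, W) clip before the visual processor is applied.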
54 | question_id = self.question_id_list[index] 55 | ann = self.annotation[question_id] 56 | 57 | # Divide the range into num_frames segments and select a random index from each segment 58 | segment_list = np.linspace( 59 | 0, ann["frame_length"], self.num_frames + 1, dtype=int 60 | ) 61 | segment_start_list = segment_list[:-1] 62 | segment_end_list = segment_list[1:] 63 | selected_frame_index = [] 64 | for start, end in zip(segment_start_list, segment_end_list): 65 | if start == end: 66 | selected_frame_index.append(start) 67 | else: 68 | selected_frame_index.append(np.random.randint(start, end)) 69 | 70 | frame_list = [] 71 | for frame_index in selected_frame_index: 72 | frame = Image.open( 73 | os.path.join( 74 | self.vis_root, 75 | ann["video"], 76 | "frame{:06d}.jpg".format(frame_index + 1), 77 | ) 78 | ).convert("RGB") 79 | frame = pil_to_tensor(frame).to(torch.float32) 80 | frame_list.append(frame) 81 | video = torch.stack(frame_list, dim=1) 82 | video = self.vis_processor(video) 83 | # print(selected_frame_index, video.shape) 84 | 85 | question = self.text_processor(ann["question"]) 86 | if len(self.prompt) > 0: 87 | question = self.prompt.format(question) 88 | answer = self.text_processor(ann["answer"]) 89 | 90 | return { 91 | "image": video, 92 | "text_input": question, 93 | "text_output": answer, 94 | "question_id": ann["question_id"], 95 | # "instance_id": ann["instance_id"], 96 | } 97 | 98 | def __len__(self): 99 | return len(self.question_id_list) 100 | 101 | 102 | class MSRVTTVQAEvalDataset(MSRVTTVQADataset): 103 | def __init__( 104 | self, 105 | vis_processor, 106 | text_processor, 107 | vis_root, 108 | ann_paths, 109 | num_frames, 110 | prompt, 111 | split="test", 112 | ): 113 | super().__init__( 114 | vis_processor, 115 | text_processor, 116 | vis_root, 117 | ann_paths, 118 | num_frames, 119 | prompt, 120 | split="test", 121 | ) 122 | 123 | def __getitem__(self, index): 124 | assert ( 125 | self.class_labels 126 | ), f"class_labels of {__class__.__name__} is not built yet." 127 | question_id = self.question_id_list[index] 128 | ann = self.annotation[question_id] 129 | 130 | selected_frame_index = ( 131 | np.rint(np.linspace(0, ann["frame_length"] - 1, self.num_frames)) 132 | .astype(int) 133 | .tolist() 134 | ) 135 | frame_list = [] 136 | for frame_index in selected_frame_index: 137 | frame = Image.open( 138 | os.path.join( 139 | self.vis_root, 140 | ann["video"], 141 | "frame{:06d}.jpg".format(frame_index + 1), 142 | ) 143 | ).convert("RGB") 144 | frame = pil_to_tensor(frame).to(torch.float32) 145 | frame_list.append(frame) 146 | video = torch.stack(frame_list, dim=1) 147 | video = self.vis_processor(video) 148 | 149 | question = self.text_processor(ann["question"]) 150 | if len(self.prompt) > 0: 151 | question = self.prompt.format(question) 152 | answer = self.text_processor(ann["answer"]) 153 | 154 | return { 155 | "image": video, 156 | "text_input": question, 157 | "text_output": answer, 158 | "question_id": ann["question_id"], 159 | # "instance_id": ann["instance_id"], 160 | } 161 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/activitynet_vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | import pdb 11 | import re 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | from PIL import Image 17 | from torchvision.transforms.functional import pil_to_tensor 18 | 19 | from lavis.datasets.data_utils import load_video 20 | from lavis.datasets.datasets.video_vqa_datasets import VideoQADataset 21 | 22 | 23 | class ActivityNetVQADataset(VideoQADataset): 24 | def __init__( 25 | self, 26 | vis_processor, 27 | text_processor, 28 | vis_root, 29 | ann_paths, 30 | num_frames, 31 | prompt="", 32 | split="train", 33 | ): 34 | self.vis_root = vis_root 35 | 36 | self.annotation = {} 37 | for ann_path in ann_paths: 38 | self.annotation.update(json.load(open(ann_path))) 39 | self.question_id_list = list(self.annotation.keys()) 40 | self.question_id_list.sort() 41 | self.fps = 10 42 | 43 | self.num_frames = num_frames 44 | self.vis_processor = vis_processor 45 | self.text_processor = text_processor 46 | self.prompt = prompt 47 | # self._add_instance_ids() 48 | # pdb.set_trace() 49 | 50 | def __getitem__(self, index): 51 | # assert ( 52 | # self.class_labels 53 | # ), f"class_labels of {__class__.__name__} is not built yet." 54 | question_id = self.question_id_list[index] 55 | ann = self.annotation[question_id] 56 | 57 | # Divide the range into num_frames segments and select a random index from each segment 58 | segment_list = np.linspace( 59 | 0, ann["frame_length"], self.num_frames + 1, dtype=int 60 | ) 61 | segment_start_list = segment_list[:-1] 62 | segment_end_list = segment_list[1:] 63 | selected_frame_index = [] 64 | for start, end in zip(segment_start_list, segment_end_list): 65 | if start == end: 66 | selected_frame_index.append(start) 67 | else: 68 | selected_frame_index.append(np.random.randint(start, end)) 69 | 70 | frame_list = [] 71 | for frame_index in selected_frame_index: 72 | frame = Image.open( 73 | os.path.join( 74 | self.vis_root, 75 | ann["video"], 76 | "frame{:06d}.jpg".format(frame_index + 1), 77 | ) 78 | ).convert("RGB") 79 | frame = pil_to_tensor(frame).to(torch.float32) 80 | frame_list.append(frame) 81 | video = torch.stack(frame_list, dim=1) 82 | video = self.vis_processor(video) 83 | # print(selected_frame_index, video.shape) 84 | 85 | question = self.text_processor(ann["question"]) 86 | if len(self.prompt) > 0: 87 | question = self.prompt.format(question) 88 | answer = self.text_processor(ann["answer"]) 89 | 90 | return { 91 | "image": video, 92 | "text_input": question, 93 | "text_output": answer, 94 | "question_id": ann["question_id"], 95 | # "instance_id": ann["instance_id"], 96 | } 97 | 98 | def __len__(self): 99 | return len(self.question_id_list) 100 | 101 | 102 | class ActivityNetVQAEvalDataset(ActivityNetVQADataset): 103 | def __init__( 104 | self, 105 | vis_processor, 106 | text_processor, 107 | vis_root, 108 | ann_paths, 109 | num_frames, 110 | prompt, 111 | split="test", 112 | ): 113 | super().__init__( 114 | vis_processor, 115 | text_processor, 116 | vis_root, 117 | ann_paths, 118 | num_frames, 119 | prompt, 120 | split="test", 121 | ) 122 | 123 | def __getitem__(self, index): 124 | # assert ( 125 | # self.class_labels 126 | # ), f"class_labels of {__class__.__name__} is not built yet." 
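# Evaluation-time frame selection (lines below) is deterministic: num_frames indices are
# evenly spaced over [0, frame_length - 1] via np.rint(np.linspace(...)), so the same
# frames are decoded on every pass and validation/test metrics stay reproducible, in
# contrast to the per-segment random sampling used by the training dataset above.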
127 | question_id = self.question_id_list[index] 128 | ann = self.annotation[question_id] 129 | 130 | selected_frame_index = ( 131 | np.rint(np.linspace(0, ann["frame_length"] - 1, self.num_frames)) 132 | .astype(int) 133 | .tolist() 134 | ) 135 | frame_list = [] 136 | for frame_index in selected_frame_index: 137 | frame = Image.open( 138 | os.path.join( 139 | self.vis_root, 140 | ann["video"], 141 | "frame{:06d}.jpg".format(frame_index + 1), 142 | ) 143 | ).convert("RGB") 144 | frame = pil_to_tensor(frame).to(torch.float32) 145 | frame_list.append(frame) 146 | video = torch.stack(frame_list, dim=1) 147 | video = self.vis_processor(video) 148 | 149 | question = self.text_processor(ann["question"]) 150 | if len(self.prompt) > 0: 151 | question = self.prompt.format(question) 152 | answer = self.text_processor(ann["answer"]) 153 | 154 | return { 155 | "image": video, 156 | "text_input": question, 157 | "text_output": answer, 158 | "question_id": ann["question_id"], 159 | # "instance_id": ann["instance_id"], 160 | } 161 | -------------------------------------------------------------------------------- /lavis/datasets/datasets/aok_vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | import os 10 | from collections import OrderedDict 11 | 12 | import torch 13 | from PIL import Image 14 | 15 | from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset 16 | 17 | 18 | class __DisplMixin: 19 | def displ_item(self, index): 20 | sample, ann = self.__getitem__(index), self.annotation[index] 21 | return OrderedDict( 22 | { 23 | "file": ann["image"], 24 | "question": ann["question"], 25 | "question_id": ann["question_id"], 26 | "direct_answers": "; ".join(ann["direct_answers"]), 27 | "choices": "; ".join(ann["choices"]), 28 | "correct_choice": ann["choices"][ann["correct_choice_idx"]], 29 | "image": sample["image"], 30 | } 31 | ) 32 | 33 | 34 | class AOKVQADataset(VQADataset, __DisplMixin): 35 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 36 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 37 | 38 | def __getitem__(self, index): 39 | ann = self.annotation[index] 40 | 41 | image_path = os.path.join(self.vis_root, ann["image"]) 42 | image = Image.open(image_path).convert("RGB") 43 | 44 | image = self.vis_processor(image) 45 | question = self.text_processor(ann["question"]) 46 | 47 | answer_key = "direct_answers" 48 | 49 | answer_weight = {} 50 | for answer in ann[answer_key]: 51 | if answer in answer_weight.keys(): 52 | answer_weight[answer] += 1 / len(ann[answer_key]) 53 | else: 54 | answer_weight[answer] = 1 / len(ann[answer_key]) 55 | 56 | answers = list(answer_weight.keys()) 57 | weights = list(answer_weight.values()) 58 | 59 | return { 60 | "image": image, 61 | "text_input": question, 62 | "answers": answers, 63 | "weights": weights, 64 | } 65 | 66 | 67 | class AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin): 68 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 69 | """ 70 | vis_root (string): Root directory of images (e.g. 
coco/images/) 71 | ann_root (string): directory to store the annotation file 72 | """ 73 | 74 | self.vis_root = vis_root 75 | 76 | self.annotation = json.load(open(ann_paths[0])) 77 | 78 | answer_list_path = ann_paths[1] 79 | if os.path.exists(answer_list_path): 80 | self.answer_list = json.load(open(answer_list_path)) 81 | else: 82 | self.answer_list = None 83 | 84 | try: 85 | self.coco_fmt_qust_file = ann_paths[2] 86 | self.coco_fmt_anno_file = ann_paths[3] 87 | except IndexError: 88 | self.coco_fmt_qust_file = None 89 | self.coco_fmt_anno_file = None 90 | 91 | self.vis_processor = vis_processor 92 | self.text_processor = text_processor 93 | 94 | self._add_instance_ids() 95 | 96 | def collater(self, samples): 97 | ( 98 | image_list, 99 | question_list, 100 | question_id_list, 101 | instance_id_list, 102 | choices_list, 103 | correct_choice_idx_list, 104 | direct_answers_list, 105 | ) = ([], [], [], [], [], [], []) 106 | 107 | for sample in samples: 108 | image_list.append(sample["image"]) 109 | question_list.append(sample["text_input"]) 110 | question_id_list.append(sample["question_id"]) 111 | instance_id_list.append(sample["instance_id"]) 112 | choices_list.append(sample["choices"]) 113 | correct_choice_idx_list.append(sample["correct_choice_idx"]) 114 | direct_answers_list.append(sample["direct_answers"]) 115 | 116 | return { 117 | "image": torch.stack(image_list, dim=0), 118 | "text_input": question_list, 119 | "question_id": question_id_list, 120 | "instance_id": instance_id_list, 121 | "choices": choices_list, 122 | "correct_choice_idx": correct_choice_idx_list, 123 | "direct_answers": direct_answers_list, 124 | } 125 | 126 | def __getitem__(self, index): 127 | ann = self.annotation[index] 128 | 129 | image_path = os.path.join(self.vis_root, ann["image"]) 130 | image = Image.open(image_path).convert("RGB") 131 | 132 | image = self.vis_processor(image) 133 | question = self.text_processor(ann["question"]) 134 | 135 | choices = ann["choices"] 136 | if "correct_choice_idx" in ann: 137 | correct_choice_idx = ann["correct_choice_idx"] 138 | else: 139 | correct_choice_idx = None 140 | 141 | if "direct_answers" in ann: 142 | direct_answers = ann["direct_answers"] 143 | else: 144 | direct_answers = None 145 | 146 | return { 147 | "image": image, 148 | "text_input": question, 149 | "question_id": ann["question_id"], 150 | "instance_id": ann["instance_id"], 151 | "choices": choices, 152 | "correct_choice_idx": correct_choice_idx, 153 | "direct_answers": direct_answers, 154 | } 155 | -------------------------------------------------------------------------------- /lavis/models/clip_models/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import torch 10 | import torch.distributed.nn 11 | from torch import distributed as dist, nn as nn 12 | from torch.nn import functional as F 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | 20 | def gather_features( 21 | image_features, 22 | text_features, 23 | local_loss=False, 24 | gather_with_grad=False, 25 | rank=0, 26 | world_size=1, 27 | use_horovod=False, 28 | ): 29 | if use_horovod: 30 | assert hvd is not None, "Please install horovod" 31 | if gather_with_grad: 32 | all_image_features = hvd.allgather(image_features) 33 | all_text_features = hvd.allgather(text_features) 34 | else: 35 | with torch.no_grad(): 36 | all_image_features = hvd.allgather(image_features) 37 | all_text_features = hvd.allgather(text_features) 38 | if not local_loss: 39 | # ensure grads for local rank when all_* features don't have a gradient 40 | gathered_image_features = list( 41 | all_image_features.chunk(world_size, dim=0) 42 | ) 43 | gathered_text_features = list( 44 | all_text_features.chunk(world_size, dim=0) 45 | ) 46 | gathered_image_features[rank] = image_features 47 | gathered_text_features[rank] = text_features 48 | all_image_features = torch.cat(gathered_image_features, dim=0) 49 | all_text_features = torch.cat(gathered_text_features, dim=0) 50 | else: 51 | # We gather tensors from all gpus 52 | if gather_with_grad: 53 | all_image_features = torch.cat( 54 | torch.distributed.nn.all_gather(image_features), dim=0 55 | ) 56 | all_text_features = torch.cat( 57 | torch.distributed.nn.all_gather(text_features), dim=0 58 | ) 59 | else: 60 | gathered_image_features = [ 61 | torch.zeros_like(image_features) for _ in range(world_size) 62 | ] 63 | gathered_text_features = [ 64 | torch.zeros_like(text_features) for _ in range(world_size) 65 | ] 66 | dist.all_gather(gathered_image_features, image_features) 67 | dist.all_gather(gathered_text_features, text_features) 68 | if not local_loss: 69 | # ensure grads for local rank when all_* features don't have a gradient 70 | gathered_image_features[rank] = image_features 71 | gathered_text_features[rank] = text_features 72 | all_image_features = torch.cat(gathered_image_features, dim=0) 73 | all_text_features = torch.cat(gathered_text_features, dim=0) 74 | 75 | return all_image_features, all_text_features 76 | 77 | 78 | class ClipLoss(nn.Module): 79 | def __init__( 80 | self, 81 | local_loss=False, 82 | gather_with_grad=False, 83 | cache_labels=False, 84 | rank=0, 85 | world_size=1, 86 | use_horovod=False, 87 | ): 88 | super().__init__() 89 | self.local_loss = local_loss 90 | self.gather_with_grad = gather_with_grad 91 | self.cache_labels = cache_labels 92 | self.rank = rank 93 | self.world_size = world_size 94 | self.use_horovod = use_horovod 95 | 96 | # cache state 97 | self.prev_num_logits = 0 98 | self.labels = {} 99 | 100 | def forward(self, image_features, text_features, logit_scale): 101 | device = image_features.device 102 | if self.world_size > 1: 103 | all_image_features, all_text_features = gather_features( 104 | image_features, 105 | text_features, 106 | self.local_loss, 107 | self.gather_with_grad, 108 | self.rank, 109 | self.world_size, 110 | self.use_horovod, 111 | ) 112 | 113 | if self.local_loss: 114 | logits_per_image = logit_scale * image_features @ all_text_features.T 115 | logits_per_text = logit_scale * 
text_features @ all_image_features.T 116 | else: 117 | logits_per_image = ( 118 | logit_scale * all_image_features @ all_text_features.T 119 | ) 120 | logits_per_text = logits_per_image.T 121 | else: 122 | logits_per_image = logit_scale * image_features @ text_features.T 123 | logits_per_text = logit_scale * text_features @ image_features.T 124 | 125 | # calculated ground-truth and cache if enabled 126 | num_logits = logits_per_image.shape[0] 127 | if self.prev_num_logits != num_logits or device not in self.labels: 128 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 129 | if self.world_size > 1 and self.local_loss: 130 | labels = labels + num_logits * self.rank 131 | if self.cache_labels: 132 | self.labels[device] = labels 133 | self.prev_num_logits = num_logits 134 | else: 135 | labels = self.labels[device] 136 | 137 | total_loss = ( 138 | F.cross_entropy(logits_per_image, labels) 139 | + F.cross_entropy(logits_per_text, labels) 140 | ) / 2 141 | return total_loss 142 | --------------------------------------------------------------------------------
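The ClipLoss module above can be sanity-checked without any distributed setup by exercising its single-process (world_size == 1) branch. The sketch below is illustrative only: the batch size, embedding dimension, random features, and logit-scale value are placeholders, not settings used anywhere in this repository.

import torch
import torch.nn.functional as F

from lavis.models.clip_models.loss import ClipLoss

# Illustrative shapes; CLIP-style features are L2-normalized before the dot products.
batch_size, embed_dim = 8, 512
image_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
text_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)

# Placeholder temperature; in CLIP this is the exponentiated learnable logit scale.
logit_scale = torch.tensor(100.0)

criterion = ClipLoss(local_loss=False, gather_with_grad=False, world_size=1)
loss = criterion(image_features, text_features, logit_scale)
print(loss.item())  # symmetric cross-entropy over the in-batch image/text pairings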