├── minigpt4 ├── common │ ├── __init__.py │ ├── gradcam.py │ ├── optims.py │ ├── dist_utils.py │ ├── logger.py │ ├── registry.py │ ├── utils.py │ └── config.py ├── conversation │ ├── __init__.py │ └── conversation.py ├── datasets │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── laion_dataset.py │ │ ├── cc_combine_dataset.py │ │ ├── base_dataset.py │ │ ├── caption_datasets.py │ │ └── dataloader_utils.py │ ├── builders │ │ ├── __init__.py │ │ ├── image_text_pair_builder.py │ │ └── base_dataset_builder.py │ └── data_utils.py ├── runners │ └── __init__.py ├── configs │ ├── default.yaml │ ├── datasets │ │ ├── cc_combine │ │ │ ├── defaults.yaml │ │ │ └── align.yaml │ │ └── laion │ │ │ └── defaults.yaml │ └── models │ │ └── minigpt4.yaml ├── tasks │ ├── image_text_pretrain.py │ ├── __init__.py │ └── base_task.py ├── processors │ ├── base_processor.py │ ├── __init__.py │ ├── blip_processors.py │ └── randaugment.py ├── __init__.py └── models │ ├── blip2_outputs.py │ ├── __init__.py │ ├── blip2.py │ ├── base_model.py │ └── mini_gpt4.py ├── mycaptions └── put your captions here ├── images └── SAMPLE IMAGE DELETE IF U WANT.png ├── CODEOWNERS ├── requirements.txt ├── MANIFEST.in ├── prompts └── alignment.txt ├── extract.py ├── run.bat ├── eval_configs └── minigpt4.yaml ├── train_configs ├── minigpt4_stage1_laion.yaml └── minigpt4_stage2_align.yaml ├── LICENSE.txt ├── .gitattributes ├── setup.bat ├── app.py ├── README.md └── backup_app.py /minigpt4/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /minigpt4/conversation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /minigpt4/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mycaptions/put your captions here: -------------------------------------------------------------------------------- 1 | pls pls see your captions must be here -------------------------------------------------------------------------------- /images/SAMPLE IMAGE DELETE IF U WANT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pipinstallyp/minigpt4-batch/HEAD/images/SAMPLE IMAGE DELETE IF U WANT.png -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 
2 | #ECCN:Open Source -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torch==1.12.1 3 | torchvision==0.13.1 4 | salesforce-lavis 5 | bitsandbytes 6 | accelerate 7 | git+https://github.com/huggingface/transformers.git -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include minigpt4/configs *.yaml *.json 2 | recursive-include minigpt4/projects *.yaml *.json 3 | 4 | recursive-exclude minigpt4/datasets/download_scripts * 5 | recursive-exclude minigpt4/output * 6 | 7 | include requirements.txt 8 | -------------------------------------------------------------------------------- /prompts/alignment.txt: -------------------------------------------------------------------------------- 1 | Describe this image in detail. 2 | Take a look at this image and describe what you notice. 3 | Please provide a detailed description of the picture. 4 | Could you describe the contents of this image for me? -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | 3 | def unzip_file(zip_filepath, dest_path): 4 | # Open the zip file in read mode 5 | with zipfile.ZipFile(zip_filepath, 'r') as zip_ref: 6 | # Extract all the contents of the zip file in the destination directory 7 | zip_ref.extractall(dest_path) 8 | 9 | # Call the function 10 | unzip_file('./models.zip', './') -------------------------------------------------------------------------------- /minigpt4/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.runners.runner_base import RunnerBase 9 | 10 | __all__ = ["RunnerBase"] 11 | -------------------------------------------------------------------------------- /minigpt4/configs/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | env: 7 | # For default users 8 | # cache_root: "cache" 9 | # For internal use with persistent storage 10 | cache_root: "/export/home/.cache/minigpt4" 11 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/cc_combine/defaults.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | cc_combine: 8 | data_type: images 9 | build_info: 10 | # Be careful not to append minus sign (-) before split to avoid itemizing 11 | storage: /ibex/project/c2133/blip_dataset/cc3m/cc3m_cc12m_sbu/{00000..01255}.tar 12 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/laion/defaults.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | laion: 8 | 9 | data_type: images 10 | 11 | build_info: 12 | # Be careful not to append minus sign (-) before split to avoid itemizing 13 | storage: /ibex/project/c2133/blip_dataset/laion_1b/laion_gpu/{00000..10488}.tar 14 | -------------------------------------------------------------------------------- /minigpt4/tasks/image_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.common.registry import registry 9 | from minigpt4.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("image_text_pretrain") 13 | class ImageTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/cc_combine/align.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | cc_align: 8 | data_type: images 9 | build_info: 10 | # Be careful not to append minus sign (-) before split to avoid itemizing 11 | annotations: 12 | train: 13 | url: placeholder 14 | storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/filter_cap.json 15 | images: 16 | storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/ 17 | -------------------------------------------------------------------------------- /minigpt4/processors/base_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /run.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set PYTHON_VER=3.10.11 4 | 5 | :: Check if Python version meets the recommended version 6 | python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul 7 | if errorlevel 1 ( 8 | echo Warning: Python version %PYTHON_VER% is recommended. 9 | ) 10 | 11 | IF NOT EXIST venv ( 12 | echo Error: Virtual environment venv not found. Please run skeleton.bat first to create the environment and install required packages. 13 | exit /b 1 14 | ) 15 | 16 | :: Deactivate the virtual environment 17 | call .\venv\Scripts\deactivate.bat 18 | 19 | :: Activate the virtual environment 20 | call .\venv\Scripts\activate.bat 21 | 22 | echo Running app.py within the virtual environment... 23 | python app.py --image-folder ./images --beam-search-numbers 2 -------------------------------------------------------------------------------- /eval_configs/minigpt4.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: mini_gpt4 8 | model_type: pretrain_vicuna 9 | freeze_vit: True 10 | freeze_qformer: True 11 | max_txt_len: 160 12 | end_sym: "###" 13 | prompt_path: "prompts/alignment.txt" 14 | prompt_template: '###Human: {} ###Assistant: ' 15 | ckpt: 'checkpoint.pth' 16 | 17 | 18 | datasets: 19 | cc_align: 20 | vis_processor: 21 | train: 22 | name: "blip2_image_eval" 23 | image_size: 224 24 | text_processor: 25 | train: 26 | name: "blip_caption" 27 | 28 | run: 29 | task: image_text_pretrain 30 | 31 | -------------------------------------------------------------------------------- /minigpt4/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.common.registry import registry 9 | from minigpt4.tasks.base_task import BaseTask 10 | from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask 11 | 12 | 13 | def setup_task(cfg): 14 | assert "task" in cfg.run_cfg, "Task name must be provided." 
15 | 16 | task_name = cfg.run_cfg.task 17 | task = registry.get_task_class(task_name).setup_task(cfg=cfg) 18 | assert task is not None, "Task {} not properly registered.".format(task_name) 19 | 20 | return task 21 | 22 | 23 | __all__ = [ 24 | "BaseTask", 25 | "ImageTextPretrainTask", 26 | ] 27 | -------------------------------------------------------------------------------- /minigpt4/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /minigpt4/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.processors.base_processor import BaseProcessor 9 | from minigpt4.processors.blip_processors import ( 10 | Blip2ImageTrainProcessor, 11 | Blip2ImageEvalProcessor, 12 | BlipCaptionProcessor, 13 | ) 14 | 15 | from minigpt4.common.registry import registry 16 | 17 | __all__ = [ 18 | "BaseProcessor", 19 | "Blip2ImageTrainProcessor", 20 | "Blip2ImageEvalProcessor", 21 | "BlipCaptionProcessor", 22 | ] 23 | 24 | 25 | def load_processor(name, cfg=None): 26 | """ 27 | Example 28 | 29 | >>> processor = load_processor("alpro_video_train", cfg=None) 30 | """ 31 | processor = registry.get_processor_class(name).from_config(cfg) 32 | 33 | return processor 34 | -------------------------------------------------------------------------------- /minigpt4/configs/models/minigpt4.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: mini_gpt4 8 | 9 | # vit encoder 10 | image_size: 224 11 | drop_path_rate: 0 12 | use_grad_checkpoint: False 13 | vit_precision: "fp16" 14 | freeze_vit: True 15 | freeze_qformer: True 16 | 17 | # Q-Former 18 | num_query_token: 32 19 | 20 | # Vicuna 21 | llama_model: "vicuna" 22 | 23 | # generation configs 24 | prompt: "" 25 | 26 | 27 | preprocess: 28 | vis_processor: 29 | train: 30 | name: "blip2_image_train" 31 | image_size: 224 32 | eval: 33 | name: "blip2_image_eval" 34 | image_size: 224 35 | text_processor: 36 | train: 37 | name: "blip_caption" 38 | eval: 39 | name: "blip_caption" 40 | -------------------------------------------------------------------------------- /minigpt4/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | 11 | from omegaconf import OmegaConf 12 | 13 | from minigpt4.common.registry import registry 14 | 15 | from minigpt4.datasets.builders import * 16 | from minigpt4.models import * 17 | from minigpt4.processors import * 18 | from minigpt4.tasks import * 19 | 20 | 21 | root_dir = os.path.dirname(os.path.abspath(__file__)) 22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) 23 | 24 | registry.register_path("library_root", root_dir) 25 | repo_root = os.path.join(root_dir, "..") 26 | registry.register_path("repo_root", repo_root) 27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root) 28 | registry.register_path("cache_root", cache_root) 29 | 30 | registry.register("MAX_INT", sys.maxsize) 31 | registry.register("SPLIT_NAMES", ["train", "val", "test"]) 32 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/laion_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import webdataset as wds 9 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 10 | 11 | 12 | class LaionDataset(BaseDataset): 13 | def __init__(self, vis_processor, text_processor, location): 14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 15 | 16 | self.inner_dataset = wds.DataPipeline( 17 | wds.ResampledShards(location), 18 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 19 | wds.shuffle(1000, handler=wds.warn_and_continue), 20 | wds.decode("pilrgb", handler=wds.warn_and_continue), 21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 23 | wds.map(self.to_dict, handler=wds.warn_and_continue), 24 | ) 25 | 26 | def to_dict(self, sample): 27 | return { 28 | "image": sample[0], 29 | "text_input": self.text_processor(sample[1]["caption"]), 30 | } 31 | 32 | -------------------------------------------------------------------------------- /train_configs/minigpt4_stage1_laion.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: mini_gpt4 8 | model_type: pretrain_vicuna 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | 13 | datasets: 14 | laion: 15 | vis_processor: 16 | train: 17 | name: "blip2_image_train" 18 | image_size: 224 19 | text_processor: 20 | train: 21 | name: "blip_caption" 22 | sample_ratio: 115 23 | cc_combine: 24 | vis_processor: 25 | train: 26 | name: "blip2_image_train" 27 | image_size: 224 28 | text_processor: 29 | train: 30 | name: "blip_caption" 31 | sample_ratio: 14 32 | 33 | 34 | run: 35 | task: image_text_pretrain 36 | # optimizer 37 | lr_sched: "linear_warmup_cosine_lr" 38 | init_lr: 1e-4 39 | min_lr: 3e-5 40 | warmup_lr: 1e-6 41 | 42 | weight_decay: 0.05 43 | max_epoch: 4 44 | batch_size_train: 64 45 | batch_size_eval: 64 46 | num_workers: 4 47 | warmup_steps: 5000 48 | iters_per_epoch: 5000 49 | 50 | seed: 42 51 | output_dir: "/path/to/save/your/model/" 52 | 53 | amp: True 54 | resume_ckpt_path: null 55 | 56 | evaluate: False 57 | train_splits: ["train"] 58 | 59 | device: "cuda" 60 | world_size: 1 61 | dist_url: "env://" 62 | distributed: True -------------------------------------------------------------------------------- /train_configs/minigpt4_stage2_align.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | model: 7 | arch: mini_gpt4 8 | model_type: pretrain_vicuna 9 | freeze_vit: True 10 | freeze_qformer: True 11 | max_txt_len: 160 12 | end_sym: "###" 13 | prompt_path: "prompts/alignment.txt" 14 | prompt_template: '###Human: {} ###Assistant: ' 15 | ckpt: '/ibex/project/c2133/vicuna_jun_checkpoint_wihtout_prompt/20230412162/checkpoint_3.pth' 16 | 17 | 18 | datasets: 19 | cc_align: 20 | vis_processor: 21 | train: 22 | name: "blip2_image_train" 23 | image_size: 224 24 | text_processor: 25 | train: 26 | name: "blip_caption" 27 | 28 | run: 29 | task: image_text_pretrain 30 | # optimizer 31 | lr_sched: "linear_warmup_cosine_lr" 32 | init_lr: 3e-5 33 | min_lr: 1e-5 34 | warmup_lr: 1e-6 35 | 36 | weight_decay: 0.05 37 | max_epoch: 5 38 | iters_per_epoch: 200 39 | batch_size_train: 12 40 | batch_size_eval: 12 41 | num_workers: 4 42 | warmup_steps: 200 43 | 44 | seed: 42 45 | output_dir: "/ibex/project/c2133/vicuna_ckpt_test/minigpt4_stage2_align" 46 | 47 | amp: True 48 | resume_ckpt_path: null 49 | 50 | evaluate: False 51 | train_splits: ["train"] 52 | 53 | device: "cuda" 54 | world_size: 1 55 | dist_url: "env://" 56 | distributed: True -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022 Salesforce, Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
15 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tflite filter=lfs diff=lfs merge=lfs -text 29 | *.tgz filter=lfs diff=lfs merge=lfs -text 30 | *.wasm filter=lfs diff=lfs merge=lfs -text 31 | *.xz filter=lfs diff=lfs merge=lfs -text 32 | *.zip filter=lfs diff=lfs merge=lfs -text 33 | *.zst filter=lfs diff=lfs merge=lfs -text 34 | *tfevents* filter=lfs diff=lfs merge=lfs -text 35 | MiniGPT_4.pdf filter=lfs diff=lfs merge=lfs -text 36 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/cc_combine_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import os 8 | from PIL import Image 9 | import webdataset as wds 10 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 11 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset 12 | 13 | 14 | class CCCombineDataset(BaseDataset): 15 | def __init__(self, vis_processor, text_processor, location): 16 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 17 | 18 | self.inner_dataset = wds.DataPipeline( 19 | wds.ResampledShards(location), 20 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 21 | wds.shuffle(1000, handler=wds.warn_and_continue), 22 | wds.decode("pilrgb", handler=wds.warn_and_continue), 23 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 24 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 25 | wds.map(self.to_dict, handler=wds.warn_and_continue), 26 | ) 27 | 28 | def to_dict(self, sample): 29 | return { 30 | "image": sample[0], 31 | "text_input": self.text_processor(sample[1]["caption"]), 32 | } 33 | 34 | 35 | class CCAlignDataset(CaptionDataset): 36 | 37 | def __getitem__(self, index): 38 | 39 | # TODO this assumes image input, not general enough 40 | ann = self.annotation[index] 41 | 42 | img_file = '{}.jpg'.format(ann["image_id"]) 43 | image_path = os.path.join(self.vis_root, img_file) 44 | image = Image.open(image_path).convert("RGB") 45 | 46 | image = self.vis_processor(image) 47 | caption = ann["caption"] 48 | 49 | return { 50 | "image": image, 51 | "text_input": caption, 52 | "image_id": self.img_ids[ann["image_id"]], 53 | } -------------------------------------------------------------------------------- /setup.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set PYTHON_VER=3.10.11 4 | 5 | :: Check if Python version meets the recommended version 6 | python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul 7 | if errorlevel 1 ( 8 | echo Warning: Python version %PYTHON_VER% is recommended. 9 | ) 10 | 11 | IF NOT EXIST venv ( 12 | echo Creating venv... 13 | python -m venv venv 14 | ) 15 | 16 | :: Deactivate the virtual environment 17 | call .\venv\Scripts\deactivate.bat 18 | 19 | :: Activate the virtual environment 20 | call .\venv\Scripts\activate.bat 21 | 22 | echo Installing Torch and torchvision... 23 | pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 24 | 25 | echo Cloning and installing bitsandbytes-windows... 26 | git clone https://github.com/Keith-Hon/bitsandbytes-windows.git 27 | cd bitsandbytes-windows 28 | pip3 install -e . 29 | cd .. 30 | 31 | echo Downloading pretrained models... 32 | curl -L -o ./checkpoint.pth https://huggingface.co/ckpt/minigpt4-7B/resolve/main/prerained_minigpt4_7b.pth 33 | curl -L -o ./blip2_pretrained_flant5xxl.pth https://huggingface.co/ckpt/minigpt4/resolve/main/blip2_pretrained_flant5xxl.pth 34 | curl -L -o ./models.zip https://huggingface.co/pipyp/minigpt4py/resolve/main/models.zip 35 | python extract.py 36 | 37 | echo Installing cmake, lit, salesforce-lavis, accelerate, and transformers... 38 | pip install cmake 39 | pip install lit 40 | pip install -q salesforce-lavis 41 | pip install -q accelerate 42 | pip install -q git+https://github.com/huggingface/transformers.git -U 43 | 44 | :: Adding the extra required libraries... 
45 | echo Installing numpy, Pillow, OpenCV, tqdm, tensorflow, huggingface-hub, and keras... 46 | :: argparse, csv, os, random, glob, time, pathlib, and copy ship with the Python standard library and do not need to be installed with pip 47 | pip install numpy 48 | pip install Pillow 49 | pip install opencv-python-headless 50 | pip install tqdm 51 | pip install tensorflow 52 | pip install huggingface-hub 53 | pip install keras 54 | 55 | echo Setup complete within virtual environment! -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config 9 | from minigpt4.datasets.builders.image_text_pair_builder import ( 10 | CCCombineBuilder, 11 | LaionBuilder, 12 | CCAlignBuilder 13 | ) 14 | from minigpt4.common.registry import registry 15 | 16 | __all__ = [ 17 | "CCCombineBuilder", 18 | "LaionBuilder", 19 | "CCAlignBuilder" 20 | ] 21 | 22 | 23 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 24 | """ 25 | Example 26 | 27 | >>> dataset = load_dataset("coco_caption", cfg=None) 28 | >>> splits = dataset.keys() 29 | >>> print([len(dataset[split]) for split in splits]) 30 | 31 | """ 32 | if cfg_path is None: 33 | cfg = None 34 | else: 35 | cfg = load_dataset_config(cfg_path) 36 | 37 | try: 38 | builder = registry.get_builder_class(name)(cfg) 39 | except TypeError: 40 | print( 41 | f"Dataset {name} not found. Available datasets:\n" 42 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 43 | ) 44 | exit(1) 45 | 46 | if vis_path is not None: 47 | if data_type is None: 48 | # use default data type in the config 49 | data_type = builder.config.data_type 50 | 51 | assert ( 52 | data_type in builder.config.build_info 53 | ), f"Invalid data_type {data_type} for {name}." 54 | 55 | builder.config.build_info.get(data_type).storage = vis_path 56 | 57 | dataset = builder.build_datasets() 58 | return dataset 59 | 60 | 61 | class DatasetZoo: 62 | def __init__(self) -> None: 63 | self.dataset_zoo = { 64 | k: list(v.DATASET_CONFIG_DICT.keys()) 65 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 66 | } 67 | 68 | def get_names(self): 69 | return list(self.dataset_zoo.keys()) 70 | 71 | 72 | dataset_zoo = DatasetZoo() 73 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class BaseDataset(Dataset): 16 | def __init__( 17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] 18 | ): 19 | """ 20 | vis_root (string): Root directory of images (e.g. coco/images/) 21 | ann_root (string): directory to store the annotation file 22 | """ 23 | self.vis_root = vis_root 24 | 25 | self.annotation = [] 26 | for ann_path in ann_paths: 27 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) 28 | 29 | self.vis_processor = vis_processor 30 | self.text_processor = text_processor 31 | 32 | self._add_instance_ids() 33 | 34 | def __len__(self): 35 | return len(self.annotation) 36 | 37 | def collater(self, samples): 38 | return default_collate(samples) 39 | 40 | def set_processors(self, vis_processor, text_processor): 41 | self.vis_processor = vis_processor 42 | self.text_processor = text_processor 43 | 44 | def _add_instance_ids(self, key="instance_id"): 45 | for idx, ann in enumerate(self.annotation): 46 | ann[key] = str(idx) 47 | 48 | 49 | class ConcatDataset(ConcatDataset): 50 | def __init__(self, datasets: Iterable[Dataset]) -> None: 51 | super().__init__(datasets) 52 | 53 | def collater(self, samples): 54 | # TODO For now only supports datasets with same underlying collater implementations 55 | 56 | all_keys = set() 57 | for s in samples: 58 | all_keys.update(s) 59 | 60 | shared_keys = all_keys 61 | for s in samples: 62 | shared_keys = shared_keys & set(s.keys()) 63 | 64 | samples_shared_keys = [] 65 | for s in samples: 66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 67 | 68 | return self.datasets[0].collater(samples_shared_keys) 69 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/image_text_pair_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | 10 | from minigpt4.common.registry import registry 11 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder 12 | from minigpt4.datasets.datasets.laion_dataset import LaionDataset 13 | from minigpt4.datasets.datasets.cc_combine_dataset import CCCombineDataset, CCAlignDataset 14 | 15 | 16 | @registry.register_builder("cc_combine") 17 | class CCCombineBuilder(BaseDatasetBuilder): 18 | train_dataset_cls = CCCombineDataset 19 | 20 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_combine/defaults.yaml"} 21 | 22 | def _download_ann(self): 23 | pass 24 | 25 | def _download_vis(self): 26 | pass 27 | 28 | def build(self): 29 | self.build_processors() 30 | 31 | build_info = self.config.build_info 32 | 33 | datasets = dict() 34 | split = "train" 35 | 36 | # create datasets 37 | # [NOTE] return inner_datasets (wds.DataPipeline) 38 | dataset_cls = self.train_dataset_cls 39 | datasets[split] = dataset_cls( 40 | vis_processor=self.vis_processors[split], 41 | text_processor=self.text_processors[split], 42 | location=build_info.storage, 43 | ).inner_dataset 44 | 45 | return datasets 46 | 47 | 48 | @registry.register_builder("laion") 49 | class LaionBuilder(BaseDatasetBuilder): 50 | train_dataset_cls = LaionDataset 51 | 52 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} 53 | 54 | def _download_ann(self): 55 | pass 56 | 57 | def _download_vis(self): 58 | pass 59 | 60 | def build(self): 61 | self.build_processors() 62 | 63 | build_info = self.config.build_info 64 | 65 | datasets = dict() 66 | split = "train" 67 | 68 | # create datasets 69 | # [NOTE] return inner_datasets (wds.DataPipeline) 70 | dataset_cls = self.train_dataset_cls 71 | datasets[split] = dataset_cls( 72 | vis_processor=self.vis_processors[split], 73 | text_processor=self.text_processors[split], 74 | location=build_info.storage, 75 | ).inner_dataset 76 | 77 | return datasets 78 | 79 | 80 | @registry.register_builder("cc_align") 81 | class CCAlignBuilder(BaseDatasetBuilder): 82 | train_dataset_cls = CCAlignDataset 83 | 84 | DATASET_CONFIG_DICT = { 85 | "default": "configs/datasets/cc_combine/align.yaml", 86 | } -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/caption_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | from collections import OrderedDict 10 | 11 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 12 | from PIL import Image 13 | 14 | 15 | class __DisplMixin: 16 | def displ_item(self, index): 17 | sample, ann = self.__getitem__(index), self.annotation[index] 18 | 19 | return OrderedDict( 20 | { 21 | "file": ann["image"], 22 | "caption": ann["caption"], 23 | "image": sample["image"], 24 | } 25 | ) 26 | 27 | 28 | class CaptionDataset(BaseDataset, __DisplMixin): 29 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 30 | """ 31 | vis_root (string): Root directory of images (e.g. 
coco/images/) 32 | ann_root (string): directory to store the annotation file 33 | """ 34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 35 | 36 | self.img_ids = {} 37 | n = 0 38 | for ann in self.annotation: 39 | img_id = ann["image_id"] 40 | if img_id not in self.img_ids.keys(): 41 | self.img_ids[img_id] = n 42 | n += 1 43 | 44 | def __getitem__(self, index): 45 | 46 | # TODO this assumes image input, not general enough 47 | ann = self.annotation[index] 48 | 49 | img_file = '{:0>12}.jpg'.format(ann["image_id"]) 50 | image_path = os.path.join(self.vis_root, img_file) 51 | image = Image.open(image_path).convert("RGB") 52 | 53 | image = self.vis_processor(image) 54 | caption = self.text_processor(ann["caption"]) 55 | 56 | return { 57 | "image": image, 58 | "text_input": caption, 59 | "image_id": self.img_ids[ann["image_id"]], 60 | } 61 | 62 | 63 | class CaptionEvalDataset(BaseDataset, __DisplMixin): 64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 65 | """ 66 | vis_root (string): Root directory of images (e.g. coco/images/) 67 | ann_root (string): directory to store the annotation file 68 | split (string): val or test 69 | """ 70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.vis_root, ann["image"]) 77 | image = Image.open(image_path).convert("RGB") 78 | 79 | image = self.vis_processor(image) 80 | 81 | return { 82 | "image": image, 83 | "image_id": ann["image_id"], 84 | "instance_id": ann["instance_id"], 85 | } 86 | -------------------------------------------------------------------------------- /minigpt4/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from minigpt4.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | **kwargs 25 | ): 26 | self.optimizer = optimizer 27 | 28 | self.max_epoch = max_epoch 29 | self.min_lr = min_lr 30 | 31 | self.decay_rate = decay_rate 32 | 33 | self.init_lr = init_lr 34 | self.warmup_steps = warmup_steps 35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 36 | 37 | def step(self, cur_epoch, cur_step): 38 | if cur_epoch == 0: 39 | warmup_lr_schedule( 40 | step=cur_step, 41 | optimizer=self.optimizer, 42 | max_step=self.warmup_steps, 43 | init_lr=self.warmup_start_lr, 44 | max_lr=self.init_lr, 45 | ) 46 | else: 47 | step_lr_schedule( 48 | epoch=cur_epoch, 49 | optimizer=self.optimizer, 50 | init_lr=self.init_lr, 51 | min_lr=self.min_lr, 52 | decay_rate=self.decay_rate, 53 | ) 54 | 55 | 56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 57 | class LinearWarmupCosineLRScheduler: 58 | def __init__( 59 | self, 60 | optimizer, 61 | max_epoch, 62 | iters_per_epoch, 63 | min_lr, 64 | init_lr, 65 | warmup_steps=0, 66 | warmup_start_lr=-1, 67 | **kwargs 68 | ): 69 | self.optimizer = optimizer 70 | 71 | self.max_epoch = max_epoch 72 | self.iters_per_epoch = iters_per_epoch 73 | self.min_lr = min_lr 74 | 75 | self.init_lr = init_lr 76 | self.warmup_steps = warmup_steps 77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 78 | 79 | def step(self, cur_epoch, cur_step): 80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 81 | if total_cur_step < self.warmup_steps: 82 | warmup_lr_schedule( 83 | step=cur_step, 84 | optimizer=self.optimizer, 85 | max_step=self.warmup_steps, 86 | init_lr=self.warmup_start_lr, 87 | max_lr=self.init_lr, 88 | ) 89 | else: 90 | cosine_lr_schedule( 91 | epoch=total_cur_step, 92 | optimizer=self.optimizer, 93 | max_epoch=self.max_epoch * self.iters_per_epoch, 94 | init_lr=self.init_lr, 95 | min_lr=self.min_lr, 96 | ) 97 | 98 | 99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 100 | """Decay the learning rate""" 101 | lr = (init_lr - min_lr) * 0.5 * ( 102 | 1.0 + math.cos(math.pi * epoch / max_epoch) 103 | ) + min_lr 104 | for param_group in optimizer.param_groups: 105 | param_group["lr"] = lr 106 | 107 | 108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 109 | """Warmup the learning rate""" 110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 111 | for param_group in optimizer.param_groups: 112 | param_group["lr"] = lr 113 | 114 | 115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 116 | """Decay the learning rate""" 117 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 118 | for param_group in optimizer.param_groups: 119 | param_group["lr"] = lr 120 | -------------------------------------------------------------------------------- /minigpt4/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | torch.distributed.init_process_group( 81 | backend=args.dist_backend, 82 | init_method=args.dist_url, 83 | world_size=args.world_size, 84 | rank=args.rank, 85 | timeout=datetime.timedelta( 86 | days=365 87 | ), # allow auto-downloading and de-compressing 88 | ) 89 | torch.distributed.barrier() 90 | setup_for_distributed(args.rank == 0) 91 | 92 | 93 | def get_dist_info(): 94 | if torch.__version__ < "1.0": 95 | initialized = dist._initialized 96 | else: 97 | initialized = dist.is_initialized() 98 | if initialized: 99 | rank = dist.get_rank() 100 | world_size = dist.get_world_size() 101 | else: # non-distributed training 102 | rank = 0 103 | world_size = 1 104 | return rank, world_size 105 | 106 | 107 | def main_process(func): 108 | @functools.wraps(func) 109 | def wrapper(*args, **kwargs): 110 | rank, _ = get_dist_info() 111 | if rank == 0: 112 | return func(*args, **kwargs) 113 | 114 | return wrapper 115 | 116 | 117 | def download_cached_file(url, check_hash=True, progress=False): 118 | """ 119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 
121 | """ 122 | 123 | def get_cached_file_path(): 124 | # a hack to sync the file path across processes 125 | parts = torch.hub.urlparse(url) 126 | filename = os.path.basename(parts.path) 127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 128 | 129 | return cached_file 130 | 131 | if is_main_process(): 132 | timm_hub.download_cached_file(url, check_hash, progress) 133 | 134 | if is_dist_avail_and_initialized(): 135 | dist.barrier() 136 | 137 | return get_cached_file_path() 138 | -------------------------------------------------------------------------------- /minigpt4/models/blip2_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Optional 10 | 11 | import torch 12 | from transformers.modeling_outputs import ( 13 | ModelOutput, 14 | BaseModelOutputWithPoolingAndCrossAttentions, 15 | CausalLMOutputWithCrossAttentions, 16 | ) 17 | 18 | 19 | @dataclass 20 | class BlipSimilarity(ModelOutput): 21 | sim_i2t: torch.FloatTensor = None 22 | sim_t2i: torch.FloatTensor = None 23 | 24 | sim_i2t_m: Optional[torch.FloatTensor] = None 25 | sim_t2i_m: Optional[torch.FloatTensor] = None 26 | 27 | sim_i2t_targets: Optional[torch.FloatTensor] = None 28 | sim_t2i_targets: Optional[torch.FloatTensor] = None 29 | 30 | 31 | @dataclass 32 | class BlipIntermediateOutput(ModelOutput): 33 | """ 34 | Data class for intermediate outputs of BLIP models. 35 | 36 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). 37 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). 38 | 39 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). 40 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). 41 | 42 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. 43 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. 44 | 45 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. 46 | decoder_labels (torch.LongTensor): labels for the captioning loss. 47 | 48 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). 
49 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) 50 | 51 | """ 52 | 53 | # uni-modal features 54 | image_embeds: torch.FloatTensor = None 55 | text_embeds: Optional[torch.FloatTensor] = None 56 | 57 | image_embeds_m: Optional[torch.FloatTensor] = None 58 | text_embeds_m: Optional[torch.FloatTensor] = None 59 | 60 | # intermediate outputs of multimodal encoder 61 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 62 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 63 | 64 | itm_logits: Optional[torch.FloatTensor] = None 65 | itm_labels: Optional[torch.LongTensor] = None 66 | 67 | # intermediate outputs of multimodal decoder 68 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None 69 | decoder_labels: Optional[torch.LongTensor] = None 70 | 71 | 72 | @dataclass 73 | class BlipOutput(ModelOutput): 74 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. 75 | sims: Optional[BlipSimilarity] = None 76 | 77 | intermediate_output: BlipIntermediateOutput = None 78 | 79 | loss: Optional[torch.FloatTensor] = None 80 | 81 | loss_itc: Optional[torch.FloatTensor] = None 82 | 83 | loss_itm: Optional[torch.FloatTensor] = None 84 | 85 | loss_lm: Optional[torch.FloatTensor] = None 86 | 87 | 88 | @dataclass 89 | class BlipOutputFeatures(ModelOutput): 90 | """ 91 | Data class of features from BlipFeatureExtractor. 92 | 93 | Args: 94 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional 95 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional 96 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional 97 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional 98 | 99 | The first embedding or feature is for the [CLS] token. 100 | 101 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. 
102 | """ 103 | 104 | image_embeds: Optional[torch.FloatTensor] = None 105 | image_embeds_proj: Optional[torch.FloatTensor] = None 106 | 107 | text_embeds: Optional[torch.FloatTensor] = None 108 | text_embeds_proj: Optional[torch.FloatTensor] = None 109 | 110 | multimodal_embeds: Optional[torch.FloatTensor] = None 111 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import glob 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | from PIL import Image 11 | 12 | from minigpt4.common.config import Config 13 | from minigpt4.common.dist_utils import get_rank 14 | from minigpt4.common.registry import registry 15 | from minigpt4.conversation.conversation import Chat, CONV_VISION 16 | 17 | # imports modules for registration 18 | from minigpt4.datasets.builders import * 19 | from minigpt4.models import * 20 | from minigpt4.processors import * 21 | from minigpt4.runners import * 22 | from minigpt4.tasks import * 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description="Demo") 27 | parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.") 28 | parser.add_argument( 29 | "--options", 30 | nargs="+", 31 | help="override some settings in the used config, the key-value pair " 32 | "in xxx=yyy format will be merged into config file (deprecate), " 33 | "change to --cfg-options instead.", 34 | ) 35 | parser.add_argument("--image-folder", type=str, required=True, help="path to the input image folder") 36 | parser.add_argument("--beam-search-numbers", type=int, default=1, help="beam search numbers") 37 | parser.add_argument("--model", type=str, default='llama', help="Model to be used for generation. 
Options: 'llama' (default), 'llama7b'") 38 | parser.add_argument("--save-in-imgfolder", action="store_true", help="save captions in the input image folder") 39 | options = parser.parse_args() 40 | return options 41 | 42 | 43 | def setup_seeds(config): 44 | seed = config.run_cfg.seed + get_rank() 45 | 46 | random.seed(seed) 47 | np.random.seed(seed) 48 | torch.manual_seed(seed) 49 | 50 | cudnn.benchmark = False 51 | cudnn.deterministic = True 52 | 53 | 54 | def describe_image(image_path, chat, chat_state, img, num_beams=1, temperature=1.0): 55 | chat_state = CONV_VISION.copy() 56 | img_list = [] 57 | 58 | gr_img = Image.open(image_path) 59 | llm_message = chat.upload_img(gr_img, chat_state, img_list) 60 | 61 | chat.ask("Describe this image.", chat_state) 62 | generated_caption = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=num_beams, temperature=temperature, max_length=2000)[0] 63 | 64 | return generated_caption 65 | 66 | 67 | if __name__ == '__main__': 68 | args = parse_args() 69 | 70 | cfg = Config(args) 71 | 72 | model_config = cfg.model_cfg 73 | if args.model == "llama7b": 74 | model_config.llama_model = "camenduru/MiniGPT4-7B" 75 | 76 | model_cls = registry.get_model_class(model_config.arch) 77 | model = model_cls.from_config(model_config).to('cuda:0') 78 | 79 | vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train 80 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) 81 | chat = Chat(model, vis_processor) 82 | 83 | chat_state = CONV_VISION.copy() 84 | img_list = [] 85 | 86 | image_folder = args.image_folder 87 | num_beams = args.beam_search_numbers 88 | temperature = 1.0 # default temperature 89 | 90 | image_extensions = ['jpg', 'jpeg', 'png', 'bmp', "webp"] 91 | image_paths = [] 92 | 93 | for ext in image_extensions: 94 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext}'))) 95 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext.upper()}'))) 96 | 97 | if not args.save_in_imgfolder: 98 | if not os.path.exists("mycaptions"): 99 | os.makedirs("mycaptions") 100 | 101 | for image_path in image_paths: 102 | start_time = time.time() 103 | caption = describe_image(image_path, chat, chat_state, img_list, num_beams, temperature) 104 | 105 | if args.save_in_imgfolder: 106 | output_path = os.path.join(image_folder, "{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0])) 107 | else: 108 | output_path = "mycaptions/{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0]) 109 | 110 | with open(output_path, "w") as f: 111 | f.write(caption) 112 | 113 | end_time = time.time() 114 | time_taken = end_time - start_time 115 | print(f"Caption for {os.path.basename(image_path)} saved in '{output_path}'") 116 | print(f"Time taken to process caption for {os.path.basename(image_path)} is: {time_taken:.2f} s") -------------------------------------------------------------------------------- /minigpt4/processors/blip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import re 9 | 10 | from minigpt4.common.registry import registry 11 | from minigpt4.processors.base_processor import BaseProcessor 12 | from minigpt4.processors.randaugment import RandomAugment 13 | from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | class BlipImageBaseProcessor(BaseProcessor): 19 | def __init__(self, mean=None, std=None): 20 | if mean is None: 21 | mean = (0.48145466, 0.4578275, 0.40821073) 22 | if std is None: 23 | std = (0.26862954, 0.26130258, 0.27577711) 24 | 25 | self.normalize = transforms.Normalize(mean, std) 26 | 27 | 28 | @registry.register_processor("blip_caption") 29 | class BlipCaptionProcessor(BaseProcessor): 30 | def __init__(self, prompt="", max_words=50): 31 | self.prompt = prompt 32 | self.max_words = max_words 33 | 34 | def __call__(self, caption): 35 | caption = self.prompt + self.pre_caption(caption) 36 | 37 | return caption 38 | 39 | @classmethod 40 | def from_config(cls, cfg=None): 41 | if cfg is None: 42 | cfg = OmegaConf.create() 43 | 44 | prompt = cfg.get("prompt", "") 45 | max_words = cfg.get("max_words", 50) 46 | 47 | return cls(prompt=prompt, max_words=max_words) 48 | 49 | def pre_caption(self, caption): 50 | caption = re.sub( 51 | r"([.!\"()*#:;~])", 52 | " ", 53 | caption.lower(), 54 | ) 55 | caption = re.sub( 56 | r"\s{2,}", 57 | " ", 58 | caption, 59 | ) 60 | caption = caption.rstrip("\n") 61 | caption = caption.strip(" ") 62 | 63 | # truncate caption 64 | caption_words = caption.split(" ") 65 | if len(caption_words) > self.max_words: 66 | caption = " ".join(caption_words[: self.max_words]) 67 | 68 | return caption 69 | 70 | 71 | @registry.register_processor("blip2_image_train") 72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor): 73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 74 | super().__init__(mean=mean, std=std) 75 | 76 | self.transform = transforms.Compose( 77 | [ 78 | transforms.RandomResizedCrop( 79 | image_size, 80 | scale=(min_scale, max_scale), 81 | interpolation=InterpolationMode.BICUBIC, 82 | ), 83 | transforms.ToTensor(), 84 | self.normalize, 85 | ] 86 | ) 87 | 88 | def __call__(self, item): 89 | return self.transform(item) 90 | 91 | @classmethod 92 | def from_config(cls, cfg=None): 93 | if cfg is None: 94 | cfg = OmegaConf.create() 95 | 96 | image_size = cfg.get("image_size", 224) 97 | 98 | mean = cfg.get("mean", None) 99 | std = cfg.get("std", None) 100 | 101 | min_scale = cfg.get("min_scale", 0.5) 102 | max_scale = cfg.get("max_scale", 1.0) 103 | 104 | return cls( 105 | image_size=image_size, 106 | mean=mean, 107 | std=std, 108 | min_scale=min_scale, 109 | max_scale=max_scale, 110 | ) 111 | 112 | 113 | @registry.register_processor("blip2_image_eval") 114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor): 115 | def __init__(self, image_size=224, mean=None, std=None): 116 | super().__init__(mean=mean, std=std) 117 | 118 | self.transform = transforms.Compose( 119 | [ 120 | transforms.Resize( 121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 122 | ), 123 | transforms.ToTensor(), 124 | self.normalize, 125 | ] 126 | ) 127 | 128 | def __call__(self, item): 129 | return self.transform(item) 130 | 131 | @classmethod 132 | def from_config(cls, cfg=None): 133 | if cfg is None: 134 | 
cfg = OmegaConf.create() 135 | 136 | image_size = cfg.get("image_size", 224) 137 | 138 | mean = cfg.get("mean", None) 139 | std = cfg.get("std", None) 140 | 141 | return cls(image_size=image_size, mean=mean, std=std) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Welcome to the MiniGPT-4 Batch repo! This repository provides an implementation of MiniGPT-4 to mass caption Stable Diffusion images. It utilizes LLaMA weights that are downloaded automatically if not already present. The tool was originally Linux-only; Windows is now supported through the provided batch scripts (see below). Either way it needs a high-end GPU and will not run on the free Colab tier. 2 | 3 | ## Windows Installation Instructions 4 | 5 | To install and run MiniGPT-4 Batch on Windows, please follow these steps: 6 | 7 | 1. Run the `Setup.bat` script: 8 | 9 | ``` 10 | Setup.bat 11 | ``` 12 | 13 | 2. Check your `/images` and `/mycaptions` folders. In the `/images` folder, one sample image is provided; feel free to delete it. 14 | 15 | 3. If you're just testing it out, simply execute the `run.bat` script. 16 | 17 | **OR** 18 | 19 | 3. If you want to run the script manually, you need to: 20 | 21 | a. Activate the virtual environment: 22 | 23 | ``` 24 | .\venv\Scripts\activate.bat 25 | ``` 26 | 27 | b. Run the `app.py` script with the desired options: 28 | 29 | ``` 30 | python app.py --image-folder ./images --beam-search-numbers 2 31 | ``` 32 | 33 | NEW: We're testing a way to combine WD tags with minigpt4-batch. If you want WD tags alongside the MiniGPT-4 captions, consider running `backup_app.py`. Note that WD tagging is currently mandatory in `backup_app.py`; we're working on making it optional. 34 | 35 | ``` 36 | python backup_app.py --image-folder ./images --beam-search-numbers 2 --model-dir models/wd14_tagger --undesired-tags '1girl,1boy,solo' 37 | ``` 38 | 39 | Now you're all set to use MiniGPT-4 Batch on Windows! 40 | 41 | ## Getting Started (LINUX) 42 | 43 | If you're installing MiniGPT-4 Batch for the first time, please follow these steps: 44 | 45 | 1. Clone the GitHub repository: 46 | 47 | ```git 48 | git clone https://github.com/pipinstallyp/minigpt4-batch 49 | ``` 50 | Change directory to minigpt4-batch: 51 | 52 | ``` 53 | cd minigpt4-batch 54 | ``` 55 | 2. Download the necessary files: 56 | 57 | ``` 58 | wget https://huggingface.co/ckpt/minigpt4/resolve/main/minigpt4.pth -O ./checkpoint.pth 59 | wget https://huggingface.co/ckpt/minigpt4/resolve/main/blip2_pretrained_flant5xxl.pth -O ./blip2_pretrained_flant5xxl.pth 60 | ``` 61 | 62 | For the 7B model, use this instead: 63 | ``` 64 | wget https://huggingface.co/ckpt/minigpt4-7B/resolve/main/prerained_minigpt4_7b.pth -O ./checkpoint.pth 65 | wget https://huggingface.co/ckpt/minigpt4/resolve/main/blip2_pretrained_flant5xxl.pth -O ./blip2_pretrained_flant5xxl.pth 66 | ``` 67 | 68 | If your config refers to `./minigpt4/checkpoint.pth`, replace that path with the actual location of `checkpoint.pth` in your minigpt4 directory. 69 | 70 | 3. Install the required packages: 71 | 72 | ``` 73 | pip install cmake 74 | pip install lit 75 | pip install -q salesforce-lavis 76 | pip install -q bitsandbytes 77 | pip install -q accelerate 78 | pip install -q git+https://github.com/huggingface/transformers.git -U 79 | ``` 80 | 81 | 4.
Now, you can run the script: 82 | 83 | ``` 84 | python app.py --image-folder path_to_image_folder --beam-search-numbers value 85 | ``` 86 | 87 | If you want to try the LLaMA 7B model, use this: 88 | 89 | ``` 90 | python app.py --image-folder path_to_image_folder --beam-search-numbers 2 --model llama7b 91 | ``` 92 | 93 | In your repository directory, create two folders: 94 | ``` 95 | images 96 | mycaptions 97 | ``` 98 | 99 | In this case, your `path_to_image_folder` is `images`. 100 | ## Features 101 | 1. Shows the time taken to process each caption. 102 | 2. Use `--save-in-imgfolder` to save captions in your images folder instead. 103 | 3. One-click setup (`setup.bat`) for Windows. 104 | 105 | ## To-Do List 106 | 107 | - [x] ~~Make it work on Windows~~ 108 | - [ ] Implement for MiniGPT-4 7B 109 | - [ ] Include inputs from Segment Anything 110 | - [ ] Docker support coming soon! 111 | 112 | 113 | ## Acknowledgment 114 | 115 | A huge thank you to [Camenduru](https://github.com/camenduru) for developing the awesome MiniGPT-4 Colab, which has served as the foundation for most of this work. Huge thanks to [rafraf](https://www.instagram.com/rafstahelin/) for making the features what they are. This project is primarily aimed at helping people mass caption their images for training Stable Diffusion models. 116 | 117 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference 118 | 119 | Check out the https://github.com/gessyoo/minigpt4-batch-tweaked fork, which removes trivial lead-ins such as "The image shows" and "The image is" from the captions, as well as the `_caption` suffix from the generated text files. 120 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/dataloader_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import time 9 | import random 10 | import torch 11 | from minigpt4.datasets.data_utils import move_to_cuda 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | class MultiIterLoader: 16 | """ 17 | A simple wrapper for iterating over multiple iterators. 18 | 19 | Args: 20 | loaders (List[Loader]): List of Iterator loaders. 21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. 22 | """ 23 | 24 | def __init__(self, loaders, ratios=None): 25 | # assert all loaders have a __next__ method 26 | for loader in loaders: 27 | assert hasattr( 28 | loader, "__next__" 29 | ), "Loader {} has no __next__ method.".format(loader) 30 | 31 | if ratios is None: 32 | ratios = [1.0] * len(loaders) 33 | else: 34 | assert len(ratios) == len(loaders) 35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios] 36 | 37 | self.loaders = loaders 38 | self.ratios = ratios 39 | 40 | def __next__(self): 41 | # random sample from each loader by ratio 42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] 43 | return next(self.loaders[loader_idx]) 44 | 45 | 46 | class PrefetchLoader(object): 47 | """ 48 | Modified from https://github.com/ChenRocks/UNITER.
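Typical use is simply to wrap an existing DataLoader; a rough sketch (dataset and batch size here are placeholders):
>>> train_loader = PrefetchLoader(DataLoader(dataset, batch_size=32, pin_memory=True))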
49 | 50 | overlap compute and cuda data transfer 51 | (copied and then modified from nvidia apex) 52 | """ 53 | 54 | def __init__(self, loader): 55 | self.loader = loader 56 | self.stream = torch.cuda.Stream() 57 | 58 | def __iter__(self): 59 | loader_it = iter(self.loader) 60 | self.preload(loader_it) 61 | batch = self.next(loader_it) 62 | while batch is not None: 63 | is_tuple = isinstance(batch, tuple) 64 | if is_tuple: 65 | task, batch = batch 66 | 67 | if is_tuple: 68 | yield task, batch 69 | else: 70 | yield batch 71 | batch = self.next(loader_it) 72 | 73 | def __len__(self): 74 | return len(self.loader) 75 | 76 | def preload(self, it): 77 | try: 78 | self.batch = next(it) 79 | except StopIteration: 80 | self.batch = None 81 | return 82 | # if record_stream() doesn't work, another option is to make sure 83 | # device inputs are created on the main stream. 84 | # self.next_input_gpu = torch.empty_like(self.next_input, 85 | # device='cuda') 86 | # self.next_target_gpu = torch.empty_like(self.next_target, 87 | # device='cuda') 88 | # Need to make sure the memory allocated for next_* is not still in use 89 | # by the main stream at the time we start copying to next_*: 90 | # self.stream.wait_stream(torch.cuda.current_stream()) 91 | with torch.cuda.stream(self.stream): 92 | self.batch = move_to_cuda(self.batch) 93 | # more code for the alternative if record_stream() doesn't work: 94 | # copy_ will record the use of the pinned source tensor in this 95 | # side stream. 96 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 97 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 98 | # self.next_input = self.next_input_gpu 99 | # self.next_target = self.next_target_gpu 100 | 101 | def next(self, it): 102 | torch.cuda.current_stream().wait_stream(self.stream) 103 | batch = self.batch 104 | if batch is not None: 105 | record_cuda_stream(batch) 106 | self.preload(it) 107 | return batch 108 | 109 | def __getattr__(self, name): 110 | method = self.loader.__getattribute__(name) 111 | return method 112 | 113 | 114 | def record_cuda_stream(batch): 115 | if isinstance(batch, torch.Tensor): 116 | batch.record_stream(torch.cuda.current_stream()) 117 | elif isinstance(batch, list) or isinstance(batch, tuple): 118 | for t in batch: 119 | record_cuda_stream(t) 120 | elif isinstance(batch, dict): 121 | for t in batch.values(): 122 | record_cuda_stream(t) 123 | else: 124 | pass 125 | 126 | 127 | class IterLoader: 128 | """ 129 | A wrapper to convert DataLoader as an infinite iterator. 
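A minimal usage sketch (dataloader is any torch DataLoader; the epoch counter advances automatically whenever the underlying loader is exhausted):
>>> loader = IterLoader(dataloader, use_distributed=False)
>>> batch = next(loader)  # keeps yielding batches across epoch boundaries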
130 | 131 | Modified from: 132 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py 133 | """ 134 | 135 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False): 136 | self._dataloader = dataloader 137 | self.iter_loader = iter(self._dataloader) 138 | self._use_distributed = use_distributed 139 | self._epoch = 0 140 | 141 | @property 142 | def epoch(self) -> int: 143 | return self._epoch 144 | 145 | def __next__(self): 146 | try: 147 | data = next(self.iter_loader) 148 | except StopIteration: 149 | self._epoch += 1 150 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: 151 | self._dataloader.sampler.set_epoch(self._epoch) 152 | time.sleep(2) # Prevent possible deadlock during epoch transition 153 | self.iter_loader = iter(self._dataloader) 154 | data = next(self.iter_loader) 155 | 156 | return data 157 | 158 | def __iter__(self): 159 | return self 160 | 161 | def __len__(self): 162 | return len(self._dataloader) 163 | -------------------------------------------------------------------------------- /minigpt4/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import torch 10 | from omegaconf import OmegaConf 11 | 12 | from minigpt4.common.registry import registry 13 | from minigpt4.models.base_model import BaseModel 14 | from minigpt4.models.blip2 import Blip2Base 15 | from minigpt4.models.mini_gpt4 import MiniGPT4 16 | from minigpt4.processors.base_processor import BaseProcessor 17 | 18 | 19 | __all__ = [ 20 | "load_model", 21 | "BaseModel", 22 | "Blip2Base", 23 | "MiniGPT4", 24 | ] 25 | 26 | 27 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): 28 | """ 29 | Load supported models. 30 | 31 | To list all available models and types in registry: 32 | >>> from minigpt4.models import model_zoo 33 | >>> print(model_zoo) 34 | 35 | Args: 36 | name (str): name of the model. 37 | model_type (str): type of the model. 38 | is_eval (bool): whether the model is in eval mode. Default: False. 39 | device (str): device to use. Default: "cpu". 40 | checkpoint (str): path or to checkpoint. Default: None. 41 | Note that expecting the checkpoint to have the same keys in state_dict as the model. 42 | 43 | Returns: 44 | model (torch.nn.Module): model. 45 | """ 46 | 47 | model = registry.get_model_class(name).from_pretrained(model_type=model_type) 48 | 49 | if checkpoint is not None: 50 | model.load_checkpoint(checkpoint) 51 | 52 | if is_eval: 53 | model.eval() 54 | 55 | if device == "cpu": 56 | model = model.float() 57 | 58 | return model.to(device) 59 | 60 | 61 | def load_preprocess(config): 62 | """ 63 | Load preprocessor configs and construct preprocessors. 64 | 65 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. 66 | 67 | Args: 68 | config (dict): preprocessor configs. 69 | 70 | Returns: 71 | vis_processors (dict): preprocessors for visual inputs. 72 | txt_processors (dict): preprocessors for text inputs. 73 | 74 | Key is "train" or "eval" for processors used in training and evaluation respectively. 
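A rough usage sketch (assumes the model yaml exposes a preprocess section, as read by load_model_and_preprocess below):
>>> from omegaconf import OmegaConf
>>> cfg = OmegaConf.load("minigpt4/configs/models/minigpt4.yaml")
>>> vis_processors, txt_processors = load_preprocess(cfg.preprocess)
>>> image_tensor = vis_processors["eval"](pil_image)  # pil_image: a PIL.Image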
75 | """ 76 | 77 | def _build_proc_from_cfg(cfg): 78 | return ( 79 | registry.get_processor_class(cfg.name).from_config(cfg) 80 | if cfg is not None 81 | else BaseProcessor() 82 | ) 83 | 84 | vis_processors = dict() 85 | txt_processors = dict() 86 | 87 | vis_proc_cfg = config.get("vis_processor") 88 | txt_proc_cfg = config.get("text_processor") 89 | 90 | if vis_proc_cfg is not None: 91 | vis_train_cfg = vis_proc_cfg.get("train") 92 | vis_eval_cfg = vis_proc_cfg.get("eval") 93 | else: 94 | vis_train_cfg = None 95 | vis_eval_cfg = None 96 | 97 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) 98 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) 99 | 100 | if txt_proc_cfg is not None: 101 | txt_train_cfg = txt_proc_cfg.get("train") 102 | txt_eval_cfg = txt_proc_cfg.get("eval") 103 | else: 104 | txt_train_cfg = None 105 | txt_eval_cfg = None 106 | 107 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) 108 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) 109 | 110 | return vis_processors, txt_processors 111 | 112 | 113 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): 114 | """ 115 | Load model and its related preprocessors. 116 | 117 | List all available models and types in registry: 118 | >>> from minigpt4.models import model_zoo 119 | >>> print(model_zoo) 120 | 121 | Args: 122 | name (str): name of the model. 123 | model_type (str): type of the model. 124 | is_eval (bool): whether the model is in eval mode. Default: False. 125 | device (str): device to use. Default: "cpu". 126 | 127 | Returns: 128 | model (torch.nn.Module): model. 129 | vis_processors (dict): preprocessors for visual inputs. 130 | txt_processors (dict): preprocessors for text inputs. 131 | """ 132 | model_cls = registry.get_model_class(name) 133 | 134 | # load model 135 | model = model_cls.from_pretrained(model_type=model_type) 136 | 137 | if is_eval: 138 | model.eval() 139 | 140 | # load preprocess 141 | cfg = OmegaConf.load(model_cls.default_config_path(model_type)) 142 | if cfg is not None: 143 | preprocess_cfg = cfg.preprocess 144 | 145 | vis_processors, txt_processors = load_preprocess(preprocess_cfg) 146 | else: 147 | vis_processors, txt_processors = None, None 148 | logging.info( 149 | f"""No default preprocess for model {name} ({model_type}). 150 | This can happen if the model is not finetuned on downstream datasets, 151 | or it is not intended for direct use without finetuning. 152 | """ 153 | ) 154 | 155 | if device == "cpu" or device == torch.device("cpu"): 156 | model = model.float() 157 | 158 | return model.to(device), vis_processors, txt_processors 159 | 160 | 161 | class ModelZoo: 162 | """ 163 | A utility class to create string representation of available model architectures and types. 
164 | 165 | >>> from minigpt4.models import model_zoo 166 | >>> # list all available models 167 | >>> print(model_zoo) 168 | >>> # show total number of models 169 | >>> print(len(model_zoo)) 170 | """ 171 | 172 | def __init__(self) -> None: 173 | self.model_zoo = { 174 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) 175 | for k, v in registry.mapping["model_name_mapping"].items() 176 | } 177 | 178 | def __str__(self) -> str: 179 | return ( 180 | "=" * 50 181 | + "\n" 182 | + f"{'Architectures':<30} {'Types'}\n" 183 | + "=" * 50 184 | + "\n" 185 | + "\n".join( 186 | [ 187 | f"{name:<30} {', '.join(types)}" 188 | for name, types in self.model_zoo.items() 189 | ] 190 | ) 191 | ) 192 | 193 | def __iter__(self): 194 | return iter(self.model_zoo.items()) 195 | 196 | def __len__(self): 197 | return sum([len(v) for v in self.model_zoo.values()]) 198 | 199 | 200 | model_zoo = ModelZoo() 201 | -------------------------------------------------------------------------------- /minigpt4/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import time 11 | from collections import defaultdict, deque 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | from minigpt4.common import dist_utils 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 
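Only count and total are all-reduced across workers, so global_avg becomes a global statistic while median, avg, max and value remain per-process window values.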
40 | """ 41 | if not dist_utils.is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value, 79 | ) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError( 100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 101 | ) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def global_avg(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) 113 | return self.delimiter.join(loss_str) 114 | 115 | def synchronize_between_processes(self): 116 | for meter in self.meters.values(): 117 | meter.synchronize_between_processes() 118 | 119 | def add_meter(self, name, meter): 120 | self.meters[name] = meter 121 | 122 | def log_every(self, iterable, print_freq, header=None): 123 | i = 0 124 | if not header: 125 | header = "" 126 | start_time = time.time() 127 | end = time.time() 128 | iter_time = SmoothedValue(fmt="{avg:.4f}") 129 | data_time = SmoothedValue(fmt="{avg:.4f}") 130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 131 | log_msg = [ 132 | header, 133 | "[{0" + space_fmt + "}/{1}]", 134 | "eta: {eta}", 135 | "{meters}", 136 | "time: {time}", 137 | "data: {data}", 138 | ] 139 | if torch.cuda.is_available(): 140 | log_msg.append("max mem: {memory:.0f}") 141 | log_msg = self.delimiter.join(log_msg) 142 | MB = 1024.0 * 1024.0 143 | for obj in iterable: 144 | data_time.update(time.time() - end) 145 | yield obj 146 | iter_time.update(time.time() - end) 147 | if i % print_freq == 0 or i == len(iterable) - 1: 148 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 150 | if torch.cuda.is_available(): 151 | print( 152 | log_msg.format( 153 | i, 154 | len(iterable), 155 | eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), 158 | data=str(data_time), 159 | memory=torch.cuda.max_memory_allocated() / MB, 160 | ) 161 | ) 162 | else: 163 | print( 164 | log_msg.format( 165 | i, 166 | len(iterable), 167 | eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), 
170 | data=str(data_time), 171 | ) 172 | ) 173 | i += 1 174 | end = time.time() 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print( 178 | "{} Total time: {} ({:.4f} s / it)".format( 179 | header, total_time_str, total_time / len(iterable) 180 | ) 181 | ) 182 | 183 | 184 | class AttrDict(dict): 185 | def __init__(self, *args, **kwargs): 186 | super(AttrDict, self).__init__(*args, **kwargs) 187 | self.__dict__ = self 188 | 189 | 190 | def setup_logger(): 191 | logging.basicConfig( 192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN, 193 | format="%(asctime)s [%(levelname)s] %(message)s", 194 | handlers=[logging.StreamHandler()], 195 | ) 196 | -------------------------------------------------------------------------------- /minigpt4/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import gzip 9 | import logging 10 | import os 11 | import random as rnd 12 | import tarfile 13 | import zipfile 14 | import random 15 | from typing import List 16 | from tqdm import tqdm 17 | 18 | import decord 19 | from decord import VideoReader 20 | import webdataset as wds 21 | import numpy as np 22 | import torch 23 | from torch.utils.data.dataset import IterableDataset 24 | 25 | from minigpt4.common.registry import registry 26 | from minigpt4.datasets.datasets.base_dataset import ConcatDataset 27 | 28 | 29 | decord.bridge.set_bridge("torch") 30 | MAX_INT = registry.get("MAX_INT") 31 | 32 | 33 | class ChainDataset(wds.DataPipeline): 34 | r"""Dataset for chaining multiple :class:`DataPipeline` s. 35 | 36 | This class is useful to assemble different existing dataset streams. The 37 | chaining operation is done on-the-fly, so concatenating large-scale 38 | datasets with this class will be efficient. 
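Sampling weights come from each dataset's sample_ratio attribute, defaulting to 1. A rough sketch (pipeline_a and pipeline_b are placeholder DataPipeline instances):
>>> pipeline_a.sample_ratio = 2
>>> pipeline_b.sample_ratio = 1
>>> chained = ChainDataset([pipeline_a, pipeline_b])  # draws from a twice as often as b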
39 | 40 | Args: 41 | datasets (iterable of IterableDataset): datasets to be chained together 42 | """ 43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None: 44 | super().__init__() 45 | self.datasets = datasets 46 | self.prob = [] 47 | self.names = [] 48 | for dataset in self.datasets: 49 | if hasattr(dataset, 'name'): 50 | self.names.append(dataset.name) 51 | else: 52 | self.names.append('Unknown') 53 | if hasattr(dataset, 'sample_ratio'): 54 | self.prob.append(dataset.sample_ratio) 55 | else: 56 | self.prob.append(1) 57 | logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") 58 | 59 | def __iter__(self): 60 | datastreams = [iter(dataset) for dataset in self.datasets] 61 | while True: 62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] 63 | yield next(select_datastream) 64 | 65 | 66 | def apply_to_sample(f, sample): 67 | if len(sample) == 0: 68 | return {} 69 | 70 | def _apply(x): 71 | if torch.is_tensor(x): 72 | return f(x) 73 | elif isinstance(x, dict): 74 | return {key: _apply(value) for key, value in x.items()} 75 | elif isinstance(x, list): 76 | return [_apply(x) for x in x] 77 | else: 78 | return x 79 | 80 | return _apply(sample) 81 | 82 | 83 | def move_to_cuda(sample): 84 | def _move_to_cuda(tensor): 85 | return tensor.cuda() 86 | 87 | return apply_to_sample(_move_to_cuda, sample) 88 | 89 | 90 | def prepare_sample(samples, cuda_enabled=True): 91 | if cuda_enabled: 92 | samples = move_to_cuda(samples) 93 | 94 | # TODO fp16 support 95 | 96 | return samples 97 | 98 | 99 | def reorg_datasets_by_split(datasets): 100 | """ 101 | Organizes datasets by split. 102 | 103 | Args: 104 | datasets: dict of torch.utils.data.Dataset objects by name. 105 | 106 | Returns: 107 | Dict of datasets by split {split_name: List[Datasets]}. 108 | """ 109 | # if len(datasets) == 1: 110 | # return datasets[list(datasets.keys())[0]] 111 | # else: 112 | reorg_datasets = dict() 113 | 114 | # reorganize by split 115 | for _, dataset in datasets.items(): 116 | for split_name, dataset_split in dataset.items(): 117 | if split_name not in reorg_datasets: 118 | reorg_datasets[split_name] = [dataset_split] 119 | else: 120 | reorg_datasets[split_name].append(dataset_split) 121 | 122 | return reorg_datasets 123 | 124 | 125 | def concat_datasets(datasets): 126 | """ 127 | Concatenates multiple datasets into a single dataset. 128 | 129 | It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support 130 | generic IterableDataset because it requires creating separate samplers. 131 | 132 | Now only supports conctenating training datasets and assuming validation and testing 133 | have only a single dataset. This is because metrics should not be computed on the concatenated 134 | datasets. 135 | 136 | Args: 137 | datasets: dict of torch.utils.data.Dataset objects by split. 138 | 139 | Returns: 140 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, 141 | "val" and "test" remain the same. 142 | 143 | If the input training datasets contain both map-style and DataPipeline datasets, returns 144 | a tuple, where the first element is a concatenated map-style dataset and the second 145 | element is a chained DataPipeline dataset. 
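A typical call chains it with reorg_datasets_by_split above (sketch; datasets_by_name is a placeholder dict of datasets keyed by name):
>>> datasets_by_split = reorg_datasets_by_split(datasets_by_name)
>>> datasets_by_split = concat_datasets(datasets_by_split)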
146 | 147 | """ 148 | # concatenate datasets in the same split 149 | for split_name in datasets: 150 | if split_name != "train": 151 | assert ( 152 | len(datasets[split_name]) == 1 153 | ), "Do not support multiple {} datasets.".format(split_name) 154 | datasets[split_name] = datasets[split_name][0] 155 | else: 156 | iterable_datasets, map_datasets = [], [] 157 | for dataset in datasets[split_name]: 158 | if isinstance(dataset, wds.DataPipeline): 159 | logging.info( 160 | "Dataset {} is IterableDataset, can't be concatenated.".format( 161 | dataset 162 | ) 163 | ) 164 | iterable_datasets.append(dataset) 165 | elif isinstance(dataset, IterableDataset): 166 | raise NotImplementedError( 167 | "Do not support concatenation of generic IterableDataset." 168 | ) 169 | else: 170 | map_datasets.append(dataset) 171 | 172 | # if len(iterable_datasets) > 0: 173 | # concatenate map-style datasets and iterable-style datasets separately 174 | if len(iterable_datasets) > 1: 175 | chained_datasets = ( 176 | ChainDataset(iterable_datasets) 177 | ) 178 | elif len(iterable_datasets) == 1: 179 | chained_datasets = iterable_datasets[0] 180 | else: 181 | chained_datasets = None 182 | 183 | concat_datasets = ( 184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None 185 | ) 186 | 187 | train_datasets = concat_datasets, chained_datasets 188 | train_datasets = tuple([x for x in train_datasets if x is not None]) 189 | train_datasets = ( 190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets 191 | ) 192 | 193 | datasets[split_name] = train_datasets 194 | 195 | return datasets 196 | 197 | -------------------------------------------------------------------------------- /minigpt4/conversation/conversation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from PIL import Image 4 | 5 | import torch 6 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer 7 | from transformers import StoppingCriteria, StoppingCriteriaList 8 | 9 | import dataclasses 10 | from enum import auto, Enum 11 | from typing import List, Tuple, Any 12 | 13 | from minigpt4.common.registry import registry 14 | 15 | 16 | class SeparatorStyle(Enum): 17 | """Different separator style.""" 18 | SINGLE = auto() 19 | TWO = auto() 20 | 21 | 22 | @dataclasses.dataclass 23 | class Conversation: 24 | """A class that keeps all conversation history.""" 25 | system: str 26 | roles: List[str] 27 | messages: List[List[str]] 28 | offset: int 29 | # system_img: List[Image.Image] = [] 30 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 31 | sep: str = "###" 32 | sep2: str = None 33 | 34 | skip_next: bool = False 35 | conv_id: Any = None 36 | 37 | def get_prompt(self): 38 | if self.sep_style == SeparatorStyle.SINGLE: 39 | ret = self.system + self.sep 40 | for role, message in self.messages: 41 | if message: 42 | ret += role + ": " + message + self.sep 43 | else: 44 | ret += role + ":" 45 | return ret 46 | elif self.sep_style == SeparatorStyle.TWO: 47 | seps = [self.sep, self.sep2] 48 | ret = self.system + seps[0] 49 | for i, (role, message) in enumerate(self.messages): 50 | if message: 51 | ret += role + ": " + message + seps[i % 2] 52 | else: 53 | ret += role + ":" 54 | return ret 55 | else: 56 | raise ValueError(f"Invalid style: {self.sep_style}") 57 | 58 | def append_message(self, role, message): 59 | self.messages.append([role, message]) 60 | 61 | def to_gradio_chatbot(self): 62 | ret = [] 63 | for i, (role, msg) in 
enumerate(self.messages[self.offset:]): 64 | if i % 2 == 0: 65 | ret.append([msg, None]) 66 | else: 67 | ret[-1][-1] = msg 68 | return ret 69 | 70 | def copy(self): 71 | return Conversation( 72 | system=self.system, 73 | # system_img=self.system_img, 74 | roles=self.roles, 75 | messages=[[x, y] for x, y in self.messages], 76 | offset=self.offset, 77 | sep_style=self.sep_style, 78 | sep=self.sep, 79 | sep2=self.sep2, 80 | conv_id=self.conv_id) 81 | 82 | def dict(self): 83 | return { 84 | "system": self.system, 85 | # "system_img": self.system_img, 86 | "roles": self.roles, 87 | "messages": self.messages, 88 | "offset": self.offset, 89 | "sep": self.sep, 90 | "sep2": self.sep2, 91 | "conv_id": self.conv_id, 92 | } 93 | 94 | 95 | class StoppingCriteriaSub(StoppingCriteria): 96 | 97 | def __init__(self, stops=[], encounters=1): 98 | super().__init__() 99 | self.stops = stops 100 | 101 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 102 | for stop in self.stops: 103 | if torch.all((stop == input_ids[0][-len(stop):])).item(): 104 | return True 105 | 106 | return False 107 | 108 | 109 | CONV_VISION = Conversation( 110 | system="Give the following image: <Img>ImageContent</Img>. " 111 | "You will be able to see the image once I provide it to you. Please answer my questions.", 112 | roles=("Human", "Assistant"), 113 | messages=[], 114 | offset=2, 115 | sep_style=SeparatorStyle.SINGLE, 116 | sep="###", 117 | ) 118 | 119 | 120 | 121 | class Chat: 122 | def __init__(self, model, vis_processor, device='cuda:0'): 123 | self.device = device 124 | self.model = model 125 | self.vis_processor = vis_processor 126 | stop_words_ids = [torch.tensor([835]).to(self.device), 127 | torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. 128 | self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 129 | 130 | def ask(self, text, conv): 131 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ 132 | and conv.messages[-1][1][-6:] == '</Img>': # last message is image. 133 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) 134 | else: 135 | conv.append_message(conv.roles[0], text) 136 | 137 | def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9, 138 | repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000): 139 | conv.append_message(conv.roles[1], None) 140 | embs = self.get_context_emb(conv, img_list) 141 | 142 | # current_max_len = embs.shape[1] + max_new_tokens + 100 143 | # begin_idx = max(0, current_max_len - max_length) 144 | # embs = embs[:, begin_idx:] 145 | outputs = self.model.llama_model.generate( 146 | inputs_embeds=embs, 147 | max_new_tokens=max_new_tokens, 148 | stopping_criteria=self.stopping_criteria, 149 | num_beams=num_beams, 150 | min_length=min_length, 151 | top_p=top_p, 152 | repetition_penalty=repetition_penalty, 153 | length_penalty=length_penalty, 154 | temperature=temperature, 155 | ) 156 | output_token = outputs[0] 157 | if output_token[0] == 0: 158 | output_token = output_token[1:] 159 | output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) 160 | output_text = output_text.split('###')[0] # remove the stop sign '###' 161 | output_text = output_text.split('Assistant:')[-1].strip() 162 | conv.messages[-1][1] = output_text 163 | return output_text, output_token.cpu().numpy() 164 | 165 | def upload_img(self, image, conv, img_list): 166 | if isinstance(image, str): # is an image path 167 | raw_image = Image.open(image).convert('RGB') 168 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 169 | elif isinstance(image, Image.Image): 170 | raw_image = image 171 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 172 | elif isinstance(image, torch.Tensor): 173 | if len(image.shape) == 3: 174 | image = image.unsqueeze(0) 175 | image = image.to(self.device) 176 | 177 | image_emb, _ = self.model.encode_img(image) 178 | img_list.append(image_emb) 179 | conv.append_message(conv.roles[0], "<Img><ImageHere></Img>") 180 | msg = "Received." 181 | # self.conv.append_message(self.conv.roles[1], msg) 182 | return msg 183 | 184 | def get_context_emb(self, conv, img_list): 185 | prompt = conv.get_prompt() 186 | prompt_segs = prompt.split('<ImageHere>') 187 | assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." 188 | seg_tokens = [ 189 | self.model.llama_tokenizer( 190 | seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids 191 | # only add bos to the first seg 192 | for i, seg in enumerate(prompt_segs) 193 | ] 194 | seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] 195 | mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] 196 | mixed_embs = torch.cat(mixed_embs, dim=1) 197 | return mixed_embs 198 | 199 | 200 | -------------------------------------------------------------------------------- /minigpt4/models/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import contextlib 8 | import logging 9 | import os 10 | import time 11 | import datetime 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.distributed as dist 16 | import torch.nn.functional as F 17 | 18 | import minigpt4.common.dist_utils as dist_utils 19 | from minigpt4.common.dist_utils import download_cached_file 20 | from minigpt4.common.utils import is_url 21 | from minigpt4.common.logger import MetricLogger 22 | from minigpt4.models.base_model import BaseModel 23 | from minigpt4.models.Qformer import BertConfig, BertLMHeadModel 24 | from minigpt4.models.eva_vit import create_eva_vit_g 25 | from transformers import BertTokenizer 26 | 27 | 28 | class Blip2Base(BaseModel): 29 | @classmethod 30 | def init_tokenizer(cls): 31 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 32 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 33 | return tokenizer 34 | 35 | def maybe_autocast(self, dtype=torch.float16): 36 | # if on cpu, don't use autocast 37 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 38 | enable_autocast = self.device != torch.device("cpu") 39 | 40 | if enable_autocast: 41 | return torch.cuda.amp.autocast(dtype=dtype) 42 | else: 43 | return contextlib.nullcontext() 44 | 45 | @classmethod 46 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): 47 | encoder_config = BertConfig.from_pretrained("bert-base-uncased") 48 | encoder_config.encoder_width = vision_width 49 | # insert cross-attention layer every other block 50 | encoder_config.add_cross_attention = True 51 | encoder_config.cross_attention_freq = cross_attention_freq 52 | encoder_config.query_length = num_query_token 53 | Qformer = BertLMHeadModel(config=encoder_config) 54 | query_tokens = nn.Parameter( 55 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 56 | ) 57 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 58 | return Qformer, query_tokens 59 | 60 | @classmethod 61 | def init_vision_encoder( 62 | cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision 63 | ): 64 | assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" 65 | visual_encoder = create_eva_vit_g( 66 | img_size, drop_path_rate, use_grad_checkpoint, precision 67 | ) 68 | 69 | ln_vision = LayerNorm(visual_encoder.num_features) 70 | return visual_encoder, ln_vision 71 | 72 | def load_from_pretrained(self, url_or_filename): 73 | if is_url(url_or_filename): 74 | cached_file = download_cached_file( 75 | url_or_filename, check_hash=False, progress=True 76 | ) 77 | checkpoint = torch.load(cached_file, map_location="cpu") 78 | elif os.path.isfile(url_or_filename): 79 | checkpoint = torch.load(url_or_filename, map_location="cpu") 80 | else: 81 | raise RuntimeError("checkpoint url or path is invalid") 82 | 83 | state_dict = checkpoint["model"] 84 | 85 | msg = self.load_state_dict(state_dict, strict=False) 86 | 87 | # logging.info("Missing keys {}".format(msg.missing_keys)) 88 | logging.info("load checkpoint from %s" % url_or_filename) 89 | 90 | return msg 91 | 92 | 93 | def disabled_train(self, mode=True): 94 | """Overwrite model.train with this function to make sure train/eval mode 95 | does not change anymore.""" 96 | return self 97 | 98 | 99 | class LayerNorm(nn.LayerNorm): 100 | """Subclass torch's LayerNorm 
to handle fp16.""" 101 | 102 | def forward(self, x: torch.Tensor): 103 | orig_type = x.dtype 104 | ret = super().forward(x.type(torch.float32)) 105 | return ret.type(orig_type) 106 | 107 | 108 | def compute_sim_matrix(model, data_loader, **kwargs): 109 | k_test = kwargs.pop("k_test") 110 | 111 | metric_logger = MetricLogger(delimiter=" ") 112 | header = "Evaluation:" 113 | 114 | logging.info("Computing features for evaluation...") 115 | start_time = time.time() 116 | 117 | texts = data_loader.dataset.text 118 | num_text = len(texts) 119 | text_bs = 256 120 | text_ids = [] 121 | text_embeds = [] 122 | text_atts = [] 123 | for i in range(0, num_text, text_bs): 124 | text = texts[i : min(num_text, i + text_bs)] 125 | text_input = model.tokenizer( 126 | text, 127 | padding="max_length", 128 | truncation=True, 129 | max_length=35, 130 | return_tensors="pt", 131 | ).to(model.device) 132 | text_feat = model.forward_text(text_input) 133 | text_embed = F.normalize(model.text_proj(text_feat)) 134 | text_embeds.append(text_embed) 135 | text_ids.append(text_input.input_ids) 136 | text_atts.append(text_input.attention_mask) 137 | 138 | text_embeds = torch.cat(text_embeds, dim=0) 139 | text_ids = torch.cat(text_ids, dim=0) 140 | text_atts = torch.cat(text_atts, dim=0) 141 | 142 | vit_feats = [] 143 | image_embeds = [] 144 | for samples in data_loader: 145 | image = samples["image"] 146 | 147 | image = image.to(model.device) 148 | image_feat, vit_feat = model.forward_image(image) 149 | image_embed = model.vision_proj(image_feat) 150 | image_embed = F.normalize(image_embed, dim=-1) 151 | 152 | vit_feats.append(vit_feat.cpu()) 153 | image_embeds.append(image_embed) 154 | 155 | vit_feats = torch.cat(vit_feats, dim=0) 156 | image_embeds = torch.cat(image_embeds, dim=0) 157 | 158 | sims_matrix = [] 159 | for image_embed in image_embeds: 160 | sim_q2t = image_embed @ text_embeds.t() 161 | sim_i2t, _ = sim_q2t.max(0) 162 | sims_matrix.append(sim_i2t) 163 | sims_matrix = torch.stack(sims_matrix, dim=0) 164 | 165 | score_matrix_i2t = torch.full( 166 | (len(data_loader.dataset.image), len(texts)), -100.0 167 | ).to(model.device) 168 | 169 | num_tasks = dist_utils.get_world_size() 170 | rank = dist_utils.get_rank() 171 | step = sims_matrix.size(0) // num_tasks + 1 172 | start = rank * step 173 | end = min(sims_matrix.size(0), start + step) 174 | 175 | for i, sims in enumerate( 176 | metric_logger.log_every(sims_matrix[start:end], 50, header) 177 | ): 178 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 179 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) 180 | score = model.compute_itm( 181 | image_inputs=image_inputs, 182 | text_ids=text_ids[topk_idx], 183 | text_atts=text_atts[topk_idx], 184 | ).float() 185 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim 186 | 187 | sims_matrix = sims_matrix.t() 188 | score_matrix_t2i = torch.full( 189 | (len(texts), len(data_loader.dataset.image)), -100.0 190 | ).to(model.device) 191 | 192 | step = sims_matrix.size(0) // num_tasks + 1 193 | start = rank * step 194 | end = min(sims_matrix.size(0), start + step) 195 | 196 | for i, sims in enumerate( 197 | metric_logger.log_every(sims_matrix[start:end], 50, header) 198 | ): 199 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 200 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device) 201 | score = model.compute_itm( 202 | image_inputs=image_inputs, 203 | text_ids=text_ids[start + i].repeat(k_test, 1), 204 | text_atts=text_atts[start + i].repeat(k_test, 1), 205 | ).float() 206 | 
score_matrix_t2i[start + i, topk_idx] = score + topk_sim 207 | 208 | if dist_utils.is_dist_avail_and_initialized(): 209 | dist.barrier() 210 | torch.distributed.all_reduce( 211 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM 212 | ) 213 | torch.distributed.all_reduce( 214 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM 215 | ) 216 | 217 | total_time = time.time() - start_time 218 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 219 | logging.info("Evaluation time {}".format(total_time_str)) 220 | 221 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 222 | -------------------------------------------------------------------------------- /minigpt4/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized 15 | from minigpt4.common.utils import get_abs_path, is_url 16 | from omegaconf import OmegaConf 17 | 18 | 19 | class BaseModel(nn.Module): 20 | """Base class for models.""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | @property 26 | def device(self): 27 | return list(self.parameters())[0].device 28 | 29 | def load_checkpoint(self, url_or_filename): 30 | """ 31 | Load from a finetuned checkpoint. 32 | 33 | This should expect no mismatch in the model keys and the checkpoint keys. 34 | """ 35 | 36 | if is_url(url_or_filename): 37 | cached_file = download_cached_file( 38 | url_or_filename, check_hash=False, progress=True 39 | ) 40 | checkpoint = torch.load(cached_file, map_location="cpu") 41 | elif os.path.isfile(url_or_filename): 42 | checkpoint = torch.load(url_or_filename, map_location="cpu") 43 | else: 44 | raise RuntimeError("checkpoint url or path is invalid") 45 | 46 | if "model" in checkpoint.keys(): 47 | state_dict = checkpoint["model"] 48 | else: 49 | state_dict = checkpoint 50 | 51 | msg = self.load_state_dict(state_dict, strict=False) 52 | 53 | logging.info("Missing keys {}".format(msg.missing_keys)) 54 | logging.info("load checkpoint from %s" % url_or_filename) 55 | 56 | return msg 57 | 58 | @classmethod 59 | def from_pretrained(cls, model_type): 60 | """ 61 | Build a pretrained model from default configuration file, specified by model_type. 62 | 63 | Args: 64 | - model_type (str): model type, specifying architecture and checkpoints. 65 | 66 | Returns: 67 | - model (nn.Module): pretrained or finetuned model, depending on the configuration. 68 | """ 69 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model 70 | model = cls.from_config(model_cfg) 71 | 72 | return model 73 | 74 | @classmethod 75 | def default_config_path(cls, model_type): 76 | assert ( 77 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT 78 | ), "Unknown model type {}".format(model_type) 79 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) 80 | 81 | def load_checkpoint_from_config(self, cfg, **kwargs): 82 | """ 83 | Load checkpoint as specified in the config file. 84 | 85 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. 
86 | When loading the pretrained model, each task-specific architecture may define their 87 | own load_from_pretrained() method. 88 | """ 89 | load_finetuned = cfg.get("load_finetuned", True) 90 | if load_finetuned: 91 | finetune_path = cfg.get("finetuned", None) 92 | assert ( 93 | finetune_path is not None 94 | ), "Found load_finetuned is True, but finetune_path is None." 95 | self.load_checkpoint(url_or_filename=finetune_path) 96 | else: 97 | # load pre-trained weights 98 | pretrain_path = cfg.get("pretrained", None) 99 | assert "Found load_finetuned is False, but pretrain_path is None." 100 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) 101 | 102 | def before_evaluation(self, **kwargs): 103 | pass 104 | 105 | def show_n_params(self, return_str=True): 106 | tot = 0 107 | for p in self.parameters(): 108 | w = 1 109 | for x in p.shape: 110 | w *= x 111 | tot += w 112 | if return_str: 113 | if tot >= 1e6: 114 | return "{:.1f}M".format(tot / 1e6) 115 | else: 116 | return "{:.1f}K".format(tot / 1e3) 117 | else: 118 | return tot 119 | 120 | 121 | class BaseEncoder(nn.Module): 122 | """ 123 | Base class for primitive encoders, such as ViT, TimeSformer, etc. 124 | """ 125 | 126 | def __init__(self): 127 | super().__init__() 128 | 129 | def forward_features(self, samples, **kwargs): 130 | raise NotImplementedError 131 | 132 | @property 133 | def device(self): 134 | return list(self.parameters())[0].device 135 | 136 | 137 | class SharedQueueMixin: 138 | @torch.no_grad() 139 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): 140 | # gather keys before updating queue 141 | image_feats = concat_all_gather(image_feat) 142 | text_feats = concat_all_gather(text_feat) 143 | 144 | batch_size = image_feats.shape[0] 145 | 146 | ptr = int(self.queue_ptr) 147 | assert self.queue_size % batch_size == 0 # for simplicity 148 | 149 | # replace the keys at ptr (dequeue and enqueue) 150 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T 151 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T 152 | 153 | if idxs is not None: 154 | idxs = concat_all_gather(idxs) 155 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T 156 | 157 | ptr = (ptr + batch_size) % self.queue_size # move pointer 158 | self.queue_ptr[0] = ptr 159 | 160 | 161 | class MomentumDistilationMixin: 162 | @torch.no_grad() 163 | def copy_params(self): 164 | for model_pair in self.model_pairs: 165 | for param, param_m in zip( 166 | model_pair[0].parameters(), model_pair[1].parameters() 167 | ): 168 | param_m.data.copy_(param.data) # initialize 169 | param_m.requires_grad = False # not update by gradient 170 | 171 | @torch.no_grad() 172 | def _momentum_update(self): 173 | for model_pair in self.model_pairs: 174 | for param, param_m in zip( 175 | model_pair[0].parameters(), model_pair[1].parameters() 176 | ): 177 | param_m.data = param_m.data * self.momentum + param.data * ( 178 | 1.0 - self.momentum 179 | ) 180 | 181 | 182 | class GatherLayer(torch.autograd.Function): 183 | """ 184 | Gather tensors from all workers with support for backward propagation: 185 | This implementation does not cut the gradients as torch.distributed.all_gather does. 
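It is applied through all_gather_with_grad below; concat_all_gather is the detached, no-gradient counterpart.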
186 | """ 187 | 188 | @staticmethod 189 | def forward(ctx, x): 190 | output = [ 191 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 192 | ] 193 | torch.distributed.all_gather(output, x) 194 | return tuple(output) 195 | 196 | @staticmethod 197 | def backward(ctx, *grads): 198 | all_gradients = torch.stack(grads) 199 | torch.distributed.all_reduce(all_gradients) 200 | return all_gradients[torch.distributed.get_rank()] 201 | 202 | 203 | def all_gather_with_grad(tensors): 204 | """ 205 | Performs all_gather operation on the provided tensors. 206 | Graph remains connected for backward grad computation. 207 | """ 208 | # Queue the gathered tensors 209 | world_size = torch.distributed.get_world_size() 210 | # There is no need for reduction in the single-proc case 211 | if world_size == 1: 212 | return tensors 213 | 214 | # tensor_all = GatherLayer.apply(tensors) 215 | tensor_all = GatherLayer.apply(tensors) 216 | 217 | return torch.cat(tensor_all, dim=0) 218 | 219 | 220 | @torch.no_grad() 221 | def concat_all_gather(tensor): 222 | """ 223 | Performs all_gather operation on the provided tensors. 224 | *** Warning ***: torch.distributed.all_gather has no gradient. 225 | """ 226 | # if use distributed training 227 | if not is_dist_avail_and_initialized(): 228 | return tensor 229 | 230 | tensors_gather = [ 231 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 232 | ] 233 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 234 | 235 | output = torch.cat(tensors_gather, dim=0) 236 | return output 237 | 238 | 239 | def tile(x, dim, n_tile): 240 | init_dim = x.size(dim) 241 | repeat_idx = [1] * x.dim() 242 | repeat_idx[dim] = n_tile 243 | x = x.repeat(*(repeat_idx)) 244 | order_index = torch.LongTensor( 245 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 246 | ) 247 | return torch.index_select(x, dim, order_index.to(x.device)) 248 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/base_dataset_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | import shutil 11 | import warnings 12 | 13 | from omegaconf import OmegaConf 14 | import torch.distributed as dist 15 | from torchvision.datasets.utils import download_url 16 | 17 | import minigpt4.common.utils as utils 18 | from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process 19 | from minigpt4.common.registry import registry 20 | from minigpt4.processors.base_processor import BaseProcessor 21 | 22 | 23 | 24 | class BaseDatasetBuilder: 25 | train_dataset_cls, eval_dataset_cls = None, None 26 | 27 | def __init__(self, cfg=None): 28 | super().__init__() 29 | 30 | if cfg is None: 31 | # help to create datasets from default config. 
32 | self.config = load_dataset_config(self.default_config_path()) 33 | elif isinstance(cfg, str): 34 | self.config = load_dataset_config(cfg) 35 | else: 36 | # when called from task.build_dataset() 37 | self.config = cfg 38 | 39 | self.data_type = self.config.data_type 40 | 41 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 42 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 43 | 44 | def build_datasets(self): 45 | # download, split, etc... 46 | # only called on 1 GPU/TPU in distributed 47 | 48 | if is_main_process(): 49 | self._download_data() 50 | 51 | if is_dist_avail_and_initialized(): 52 | dist.barrier() 53 | 54 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 55 | logging.info("Building datasets...") 56 | datasets = self.build() # dataset['train'/'val'/'test'] 57 | 58 | return datasets 59 | 60 | def build_processors(self): 61 | vis_proc_cfg = self.config.get("vis_processor") 62 | txt_proc_cfg = self.config.get("text_processor") 63 | 64 | if vis_proc_cfg is not None: 65 | vis_train_cfg = vis_proc_cfg.get("train") 66 | vis_eval_cfg = vis_proc_cfg.get("eval") 67 | 68 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) 69 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) 70 | 71 | if txt_proc_cfg is not None: 72 | txt_train_cfg = txt_proc_cfg.get("train") 73 | txt_eval_cfg = txt_proc_cfg.get("eval") 74 | 75 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) 76 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) 77 | 78 | @staticmethod 79 | def _build_proc_from_cfg(cfg): 80 | return ( 81 | registry.get_processor_class(cfg.name).from_config(cfg) 82 | if cfg is not None 83 | else None 84 | ) 85 | 86 | @classmethod 87 | def default_config_path(cls, type="default"): 88 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) 89 | 90 | def _download_data(self): 91 | self._download_ann() 92 | self._download_vis() 93 | 94 | def _download_ann(self): 95 | """ 96 | Download annotation files if necessary. 97 | All the vision-language datasets should have annotations of unified format. 98 | 99 | storage_path can be: 100 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. 101 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. 102 | 103 | Local annotation paths should be relative. 104 | """ 105 | anns = self.config.build_info.annotations 106 | 107 | splits = anns.keys() 108 | 109 | cache_root = registry.get_path("cache_root") 110 | 111 | for split in splits: 112 | info = anns[split] 113 | 114 | urls, storage_paths = info.get("url", None), info.storage 115 | 116 | if isinstance(urls, str): 117 | urls = [urls] 118 | if isinstance(storage_paths, str): 119 | storage_paths = [storage_paths] 120 | 121 | assert len(urls) == len(storage_paths) 122 | 123 | for url_or_filename, storage_path in zip(urls, storage_paths): 124 | # if storage_path is relative, make it full by prefixing with cache_root. 
125 | if not os.path.isabs(storage_path): 126 | storage_path = os.path.join(cache_root, storage_path) 127 | 128 | dirname = os.path.dirname(storage_path) 129 | if not os.path.exists(dirname): 130 | os.makedirs(dirname) 131 | 132 | if os.path.isfile(url_or_filename): 133 | src, dst = url_or_filename, storage_path 134 | if not os.path.exists(dst): 135 | shutil.copyfile(src=src, dst=dst) 136 | else: 137 | logging.info("Using existing file {}.".format(dst)) 138 | else: 139 | if os.path.isdir(storage_path): 140 | # if only dirname is provided, suffix with basename of URL. 141 | raise ValueError( 142 | "Expecting storage_path to be a file path, got directory {}".format( 143 | storage_path 144 | ) 145 | ) 146 | else: 147 | filename = os.path.basename(storage_path) 148 | 149 | download_url(url=url_or_filename, root=dirname, filename=filename) 150 | 151 | def _download_vis(self): 152 | 153 | storage_path = self.config.build_info.get(self.data_type).storage 154 | storage_path = utils.get_cache_path(storage_path) 155 | 156 | if not os.path.exists(storage_path): 157 | warnings.warn( 158 | f""" 159 | The specified path {storage_path} for visual inputs does not exist. 160 | Please provide a correct path to the visual inputs or 161 | refer to datasets/download_scripts/README.md for downloading instructions. 162 | """ 163 | ) 164 | 165 | def build(self): 166 | """ 167 | Create by split datasets inheriting torch.utils.data.Datasets. 168 | 169 | # build() can be dataset-specific. Overwrite to customize. 170 | """ 171 | self.build_processors() 172 | 173 | build_info = self.config.build_info 174 | 175 | ann_info = build_info.annotations 176 | vis_info = build_info.get(self.data_type) 177 | 178 | datasets = dict() 179 | for split in ann_info.keys(): 180 | if split not in ["train", "val", "test"]: 181 | continue 182 | 183 | is_train = split == "train" 184 | 185 | # processors 186 | vis_processor = ( 187 | self.vis_processors["train"] 188 | if is_train 189 | else self.vis_processors["eval"] 190 | ) 191 | text_processor = ( 192 | self.text_processors["train"] 193 | if is_train 194 | else self.text_processors["eval"] 195 | ) 196 | 197 | # annotation path 198 | ann_paths = ann_info.get(split).storage 199 | if isinstance(ann_paths, str): 200 | ann_paths = [ann_paths] 201 | 202 | abs_ann_paths = [] 203 | for ann_path in ann_paths: 204 | if not os.path.isabs(ann_path): 205 | ann_path = utils.get_cache_path(ann_path) 206 | abs_ann_paths.append(ann_path) 207 | ann_paths = abs_ann_paths 208 | 209 | # visual data storage path 210 | vis_path = os.path.join(vis_info.storage, split) 211 | 212 | if not os.path.isabs(vis_path): 213 | # vis_path = os.path.join(utils.get_cache_path(), vis_path) 214 | vis_path = utils.get_cache_path(vis_path) 215 | 216 | if not os.path.exists(vis_path): 217 | warnings.warn("storage path {} does not exist.".format(vis_path)) 218 | 219 | # create datasets 220 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls 221 | datasets[split] = dataset_cls( 222 | vis_processor=vis_processor, 223 | text_processor=text_processor, 224 | ann_paths=ann_paths, 225 | vis_root=vis_path, 226 | ) 227 | 228 | return datasets 229 | 230 | 231 | def load_dataset_config(cfg_path): 232 | cfg = OmegaConf.load(cfg_path).datasets 233 | cfg = cfg[list(cfg.keys())[0]] 234 | 235 | return cfg 236 | -------------------------------------------------------------------------------- /minigpt4/tasks/base_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | import torch 12 | import torch.distributed as dist 13 | from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized 14 | from minigpt4.common.logger import MetricLogger, SmoothedValue 15 | from minigpt4.common.registry import registry 16 | from minigpt4.datasets.data_utils import prepare_sample 17 | 18 | 19 | class BaseTask: 20 | def __init__(self, **kwargs): 21 | super().__init__() 22 | 23 | self.inst_id_key = "instance_id" 24 | 25 | @classmethod 26 | def setup_task(cls, **kwargs): 27 | return cls() 28 | 29 | def build_model(self, cfg): 30 | model_config = cfg.model_cfg 31 | 32 | model_cls = registry.get_model_class(model_config.arch) 33 | return model_cls.from_config(model_config) 34 | 35 | def build_datasets(self, cfg): 36 | """ 37 | Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. 38 | Download dataset and annotations automatically if not exist. 39 | 40 | Args: 41 | cfg (common.config.Config): _description_ 42 | 43 | Returns: 44 | dict: Dictionary of torch.utils.data.Dataset objects by split. 45 | """ 46 | 47 | datasets = dict() 48 | 49 | datasets_config = cfg.datasets_cfg 50 | 51 | assert len(datasets_config) > 0, "At least one dataset has to be specified." 52 | 53 | for name in datasets_config: 54 | dataset_config = datasets_config[name] 55 | 56 | builder = registry.get_builder_class(name)(dataset_config) 57 | dataset = builder.build_datasets() 58 | 59 | dataset['train'].name = name 60 | if 'sample_ratio' in dataset_config: 61 | dataset['train'].sample_ratio = dataset_config.sample_ratio 62 | 63 | datasets[name] = dataset 64 | 65 | return datasets 66 | 67 | def train_step(self, model, samples): 68 | loss = model(samples)["loss"] 69 | return loss 70 | 71 | def valid_step(self, model, samples): 72 | raise NotImplementedError 73 | 74 | def before_evaluation(self, model, dataset, **kwargs): 75 | model.before_evaluation(dataset=dataset, task_type=type(self)) 76 | 77 | def after_evaluation(self, **kwargs): 78 | pass 79 | 80 | def inference_step(self): 81 | raise NotImplementedError 82 | 83 | def evaluation(self, model, data_loader, cuda_enabled=True): 84 | metric_logger = MetricLogger(delimiter=" ") 85 | header = "Evaluation" 86 | # TODO make it configurable 87 | print_freq = 10 88 | 89 | results = [] 90 | 91 | for samples in metric_logger.log_every(data_loader, print_freq, header): 92 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 93 | 94 | eval_output = self.valid_step(model=model, samples=samples) 95 | results.extend(eval_output) 96 | 97 | if is_dist_avail_and_initialized(): 98 | dist.barrier() 99 | 100 | return results 101 | 102 | def train_epoch( 103 | self, 104 | epoch, 105 | model, 106 | data_loader, 107 | optimizer, 108 | lr_scheduler, 109 | scaler=None, 110 | cuda_enabled=False, 111 | log_freq=50, 112 | accum_grad_iters=1, 113 | ): 114 | return self._train_inner_loop( 115 | epoch=epoch, 116 | iters_per_epoch=lr_scheduler.iters_per_epoch, 117 | model=model, 118 | data_loader=data_loader, 119 | optimizer=optimizer, 120 | scaler=scaler, 121 | lr_scheduler=lr_scheduler, 122 | log_freq=log_freq, 123 | cuda_enabled=cuda_enabled, 124 | accum_grad_iters=accum_grad_iters, 125 | ) 126 | 127 | def train_iters( 128 | self, 129 | 
epoch, 130 | start_iters, 131 | iters_per_inner_epoch, 132 | model, 133 | data_loader, 134 | optimizer, 135 | lr_scheduler, 136 | scaler=None, 137 | cuda_enabled=False, 138 | log_freq=50, 139 | accum_grad_iters=1, 140 | ): 141 | return self._train_inner_loop( 142 | epoch=epoch, 143 | start_iters=start_iters, 144 | iters_per_epoch=iters_per_inner_epoch, 145 | model=model, 146 | data_loader=data_loader, 147 | optimizer=optimizer, 148 | scaler=scaler, 149 | lr_scheduler=lr_scheduler, 150 | log_freq=log_freq, 151 | cuda_enabled=cuda_enabled, 152 | accum_grad_iters=accum_grad_iters, 153 | ) 154 | 155 | def _train_inner_loop( 156 | self, 157 | epoch, 158 | iters_per_epoch, 159 | model, 160 | data_loader, 161 | optimizer, 162 | lr_scheduler, 163 | scaler=None, 164 | start_iters=None, 165 | log_freq=50, 166 | cuda_enabled=False, 167 | accum_grad_iters=1, 168 | ): 169 | """ 170 | An inner training loop compatible with both epoch-based and iter-based training. 171 | 172 | When using epoch-based, training stops after one epoch; when using iter-based, 173 | training stops after #iters_per_epoch iterations. 174 | """ 175 | use_amp = scaler is not None 176 | 177 | if not hasattr(data_loader, "__next__"): 178 | # convert to iterator if not already 179 | data_loader = iter(data_loader) 180 | 181 | metric_logger = MetricLogger(delimiter=" ") 182 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 183 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) 184 | 185 | # if iter-based runner, schedule lr based on inner epoch. 186 | logging.info( 187 | "Start training epoch {}, {} iters per inner epoch.".format( 188 | epoch, iters_per_epoch 189 | ) 190 | ) 191 | header = "Train: data epoch: [{}]".format(epoch) 192 | if start_iters is None: 193 | # epoch-based runner 194 | inner_epoch = epoch 195 | else: 196 | # In iter-based runner, we schedule the learning rate based on iterations. 197 | inner_epoch = start_iters // iters_per_epoch 198 | header = header + "; inner epoch [{}]".format(inner_epoch) 199 | 200 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): 201 | # if using iter-based runner, we stop after iters_per_epoch iterations. 
202 | if i >= iters_per_epoch: 203 | break 204 | 205 | samples = next(data_loader) 206 | 207 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 208 | samples.update( 209 | { 210 | "epoch": inner_epoch, 211 | "num_iters_per_epoch": iters_per_epoch, 212 | "iters": i, 213 | } 214 | ) 215 | 216 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) 217 | 218 | with torch.cuda.amp.autocast(enabled=use_amp): 219 | loss = self.train_step(model=model, samples=samples) 220 | 221 | # after_train_step() 222 | if use_amp: 223 | scaler.scale(loss).backward() 224 | else: 225 | loss.backward() 226 | 227 | # update gradients every accum_grad_iters iterations 228 | if (i + 1) % accum_grad_iters == 0: 229 | if use_amp: 230 | scaler.step(optimizer) 231 | scaler.update() 232 | else: 233 | optimizer.step() 234 | optimizer.zero_grad() 235 | 236 | metric_logger.update(loss=loss.item()) 237 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 238 | 239 | # after train_epoch() 240 | # gather the stats from all processes 241 | metric_logger.synchronize_between_processes() 242 | logging.info("Averaged stats: " + str(metric_logger.global_avg())) 243 | return { 244 | k: "{:.3f}".format(meter.global_avg) 245 | for k, meter in metric_logger.meters.items() 246 | } 247 | 248 | @staticmethod 249 | def save_result(result, result_dir, filename, remove_duplicate=""): 250 | import json 251 | 252 | result_file = os.path.join( 253 | result_dir, "%s_rank%d.json" % (filename, get_rank()) 254 | ) 255 | final_result_file = os.path.join(result_dir, "%s.json" % filename) 256 | 257 | json.dump(result, open(result_file, "w")) 258 | 259 | if is_dist_avail_and_initialized(): 260 | dist.barrier() 261 | 262 | if is_main_process(): 263 | logging.warning("rank %d starts merging results." % get_rank()) 264 | # combine results from all processes 265 | result = [] 266 | 267 | for rank in range(get_world_size()): 268 | result_file = os.path.join( 269 | result_dir, "%s_rank%d.json" % (filename, rank) 270 | ) 271 | res = json.load(open(result_file, "r")) 272 | result += res 273 | 274 | if remove_duplicate: 275 | result_new = [] 276 | id_list = [] 277 | for res in result: 278 | if res[remove_duplicate] not in id_list: 279 | id_list.append(res[remove_duplicate]) 280 | result_new.append(res) 281 | result = result_new 282 | 283 | json.dump(result, open(final_result_file, "w")) 284 | print("result file saved to %s" % final_result_file) 285 | 286 | return final_result_file 287 | -------------------------------------------------------------------------------- /minigpt4/common/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | class Registry: 10 | mapping = { 11 | "builder_name_mapping": {}, 12 | "task_name_mapping": {}, 13 | "processor_name_mapping": {}, 14 | "model_name_mapping": {}, 15 | "lr_scheduler_name_mapping": {}, 16 | "runner_name_mapping": {}, 17 | "state": {}, 18 | "paths": {}, 19 | } 20 | 21 | @classmethod 22 | def register_builder(cls, name): 23 | r"""Register a dataset builder to registry with key 'name' 24 | 25 | Args: 26 | name: Key with which the builder will be registered. 
27 | 28 | Usage: 29 | 30 | from minigpt4.common.registry import registry 31 | from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder 32 | """ 33 | 34 | def wrap(builder_cls): 35 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder 36 | 37 | assert issubclass( 38 | builder_cls, BaseDatasetBuilder 39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 40 | builder_cls 41 | ) 42 | if name in cls.mapping["builder_name_mapping"]: 43 | raise KeyError( 44 | "Name '{}' already registered for {}.".format( 45 | name, cls.mapping["builder_name_mapping"][name] 46 | ) 47 | ) 48 | cls.mapping["builder_name_mapping"][name] = builder_cls 49 | return builder_cls 50 | 51 | return wrap 52 | 53 | @classmethod 54 | def register_task(cls, name): 55 | r"""Register a task to registry with key 'name' 56 | 57 | Args: 58 | name: Key with which the task will be registered. 59 | 60 | Usage: 61 | 62 | from minigpt4.common.registry import registry 63 | """ 64 | 65 | def wrap(task_cls): 66 | from minigpt4.tasks.base_task import BaseTask 67 | 68 | assert issubclass( 69 | task_cls, BaseTask 70 | ), "All tasks must inherit BaseTask class" 71 | if name in cls.mapping["task_name_mapping"]: 72 | raise KeyError( 73 | "Name '{}' already registered for {}.".format( 74 | name, cls.mapping["task_name_mapping"][name] 75 | ) 76 | ) 77 | cls.mapping["task_name_mapping"][name] = task_cls 78 | return task_cls 79 | 80 | return wrap 81 | 82 | @classmethod 83 | def register_model(cls, name): 84 | r"""Register a task to registry with key 'name' 85 | 86 | Args: 87 | name: Key with which the task will be registered. 88 | 89 | Usage: 90 | 91 | from minigpt4.common.registry import registry 92 | """ 93 | 94 | def wrap(model_cls): 95 | from minigpt4.models import BaseModel 96 | 97 | assert issubclass( 98 | model_cls, BaseModel 99 | ), "All models must inherit BaseModel class" 100 | if name in cls.mapping["model_name_mapping"]: 101 | raise KeyError( 102 | "Name '{}' already registered for {}.".format( 103 | name, cls.mapping["model_name_mapping"][name] 104 | ) 105 | ) 106 | cls.mapping["model_name_mapping"][name] = model_cls 107 | return model_cls 108 | 109 | return wrap 110 | 111 | @classmethod 112 | def register_processor(cls, name): 113 | r"""Register a processor to registry with key 'name' 114 | 115 | Args: 116 | name: Key with which the task will be registered. 117 | 118 | Usage: 119 | 120 | from minigpt4.common.registry import registry 121 | """ 122 | 123 | def wrap(processor_cls): 124 | from minigpt4.processors import BaseProcessor 125 | 126 | assert issubclass( 127 | processor_cls, BaseProcessor 128 | ), "All processors must inherit BaseProcessor class" 129 | if name in cls.mapping["processor_name_mapping"]: 130 | raise KeyError( 131 | "Name '{}' already registered for {}.".format( 132 | name, cls.mapping["processor_name_mapping"][name] 133 | ) 134 | ) 135 | cls.mapping["processor_name_mapping"][name] = processor_cls 136 | return processor_cls 137 | 138 | return wrap 139 | 140 | @classmethod 141 | def register_lr_scheduler(cls, name): 142 | r"""Register a model to registry with key 'name' 143 | 144 | Args: 145 | name: Key with which the task will be registered. 
146 | 147 | Usage: 148 | 149 | from minigpt4.common.registry import registry 150 | """ 151 | 152 | def wrap(lr_sched_cls): 153 | if name in cls.mapping["lr_scheduler_name_mapping"]: 154 | raise KeyError( 155 | "Name '{}' already registered for {}.".format( 156 | name, cls.mapping["lr_scheduler_name_mapping"][name] 157 | ) 158 | ) 159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls 160 | return lr_sched_cls 161 | 162 | return wrap 163 | 164 | @classmethod 165 | def register_runner(cls, name): 166 | r"""Register a model to registry with key 'name' 167 | 168 | Args: 169 | name: Key with which the task will be registered. 170 | 171 | Usage: 172 | 173 | from minigpt4.common.registry import registry 174 | """ 175 | 176 | def wrap(runner_cls): 177 | if name in cls.mapping["runner_name_mapping"]: 178 | raise KeyError( 179 | "Name '{}' already registered for {}.".format( 180 | name, cls.mapping["runner_name_mapping"][name] 181 | ) 182 | ) 183 | cls.mapping["runner_name_mapping"][name] = runner_cls 184 | return runner_cls 185 | 186 | return wrap 187 | 188 | @classmethod 189 | def register_path(cls, name, path): 190 | r"""Register a path to registry with key 'name' 191 | 192 | Args: 193 | name: Key with which the path will be registered. 194 | 195 | Usage: 196 | 197 | from minigpt4.common.registry import registry 198 | """ 199 | assert isinstance(path, str), "All path must be str." 200 | if name in cls.mapping["paths"]: 201 | raise KeyError("Name '{}' already registered.".format(name)) 202 | cls.mapping["paths"][name] = path 203 | 204 | @classmethod 205 | def register(cls, name, obj): 206 | r"""Register an item to registry with key 'name' 207 | 208 | Args: 209 | name: Key with which the item will be registered. 210 | 211 | Usage:: 212 | 213 | from minigpt4.common.registry import registry 214 | 215 | registry.register("config", {}) 216 | """ 217 | path = name.split(".") 218 | current = cls.mapping["state"] 219 | 220 | for part in path[:-1]: 221 | if part not in current: 222 | current[part] = {} 223 | current = current[part] 224 | 225 | current[path[-1]] = obj 226 | 227 | # @classmethod 228 | # def get_trainer_class(cls, name): 229 | # return cls.mapping["trainer_name_mapping"].get(name, None) 230 | 231 | @classmethod 232 | def get_builder_class(cls, name): 233 | return cls.mapping["builder_name_mapping"].get(name, None) 234 | 235 | @classmethod 236 | def get_model_class(cls, name): 237 | return cls.mapping["model_name_mapping"].get(name, None) 238 | 239 | @classmethod 240 | def get_task_class(cls, name): 241 | return cls.mapping["task_name_mapping"].get(name, None) 242 | 243 | @classmethod 244 | def get_processor_class(cls, name): 245 | return cls.mapping["processor_name_mapping"].get(name, None) 246 | 247 | @classmethod 248 | def get_lr_scheduler_class(cls, name): 249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None) 250 | 251 | @classmethod 252 | def get_runner_class(cls, name): 253 | return cls.mapping["runner_name_mapping"].get(name, None) 254 | 255 | @classmethod 256 | def list_runners(cls): 257 | return sorted(cls.mapping["runner_name_mapping"].keys()) 258 | 259 | @classmethod 260 | def list_models(cls): 261 | return sorted(cls.mapping["model_name_mapping"].keys()) 262 | 263 | @classmethod 264 | def list_tasks(cls): 265 | return sorted(cls.mapping["task_name_mapping"].keys()) 266 | 267 | @classmethod 268 | def list_processors(cls): 269 | return sorted(cls.mapping["processor_name_mapping"].keys()) 270 | 271 | @classmethod 272 | def list_lr_schedulers(cls): 273 | 
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) 274 | 275 | @classmethod 276 | def list_datasets(cls): 277 | return sorted(cls.mapping["builder_name_mapping"].keys()) 278 | 279 | @classmethod 280 | def get_path(cls, name): 281 | return cls.mapping["paths"].get(name, None) 282 | 283 | @classmethod 284 | def get(cls, name, default=None, no_warning=False): 285 | r"""Get an item from registry with key 'name' 286 | 287 | Args: 288 | name (string): Key whose value needs to be retrieved. 289 | default: If passed and key is not in registry, default value will 290 | be returned with a warning. Default: None 291 | no_warning (bool): If passed as True, warning when key doesn't exist 292 | will not be generated. Useful for MMF's 293 | internal operations. Default: False 294 | """ 295 | original_name = name 296 | name = name.split(".") 297 | value = cls.mapping["state"] 298 | for subname in name: 299 | value = value.get(subname, default) 300 | if value is default: 301 | break 302 | 303 | if ( 304 | "writer" in cls.mapping["state"] 305 | and value == default 306 | and no_warning is False 307 | ): 308 | cls.mapping["state"]["writer"].warning( 309 | "Key {} is not present in registry, returning default value " 310 | "of {}".format(original_name, default) 311 | ) 312 | return value 313 | 314 | @classmethod 315 | def unregister(cls, name): 316 | r"""Remove an item from registry with key 'name' 317 | 318 | Args: 319 | name: Key which needs to be removed. 320 | Usage:: 321 | 322 | from mmf.common.registry import registry 323 | 324 | config = registry.unregister("config") 325 | """ 326 | return cls.mapping["state"].pop(name, None) 327 | 328 | 329 | registry = Registry() 330 | -------------------------------------------------------------------------------- /backup_app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import random 5 | import glob 6 | import time 7 | 8 | import numpy as np 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from PIL import Image 12 | 13 | from minigpt4.common.config import Config 14 | from minigpt4.common.dist_utils import get_rank 15 | from minigpt4.common.registry import registry 16 | from minigpt4.conversation.conversation import Chat, CONV_VISION 17 | 18 | # imports modules for registration 19 | from minigpt4.datasets.builders import * 20 | from minigpt4.models import * 21 | from minigpt4.processors import * 22 | from minigpt4.runners import * 23 | from minigpt4.tasks import * 24 | 25 | import cv2 26 | from tqdm import tqdm 27 | from tensorflow.keras.models import load_model 28 | from huggingface_hub import hf_hub_download 29 | from pathlib import Path 30 | from copy import deepcopy 31 | import keras 32 | 33 | 34 | #parsing the arguments 35 | 36 | def parse_args(): 37 | parser = argparse.ArgumentParser(description="Combined Demo") 38 | parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.") 39 | parser.add_argument( 40 | "--options", 41 | nargs="+", 42 | help="override some settings in the used config, the key-value pair " 43 | "in xxx=yyy format will be merged into config file (deprecate), " 44 | "change to --cfg-options instead.", 45 | ) 46 | parser.add_argument("--image-folder", type=str, required=True, help="Path to the input image folder.") 47 | parser.add_argument("--model", type=str, default='llama', help="Model to be used for generation. 
Options: 'llama' (default), 'llama7b'") 48 | parser.add_argument("--beam-search-numbers", type=int, default=1, help="beam search numbers") 49 | parser.add_argument("--model-dir", type=str, required=True, help="Path to the model directory.") 50 | parser.add_argument("--repo-id", type=str, default=DEFAULT_WD14_TAGGER_REPO, help="Hugging Face model repository ID.") 51 | parser.add_argument("--force-download", action="store_true", help="Force download the model.") 52 | parser.add_argument("--general-threshold", type=float, default=0.5, help="Threshold for general tags.") 53 | parser.add_argument("--character-threshold", type=float, default=0.5, help="Threshold for character tags.") 54 | parser.add_argument("--remove-underscore", action="store_true", help="Remove underscores from captions.") 55 | parser.add_argument("--undesired-tags", type=str, default="", help="Comma separated list of undesired tags.") 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | # these are functions taken from minigpt4 app.py 61 | def setup_seeds(config): 62 | seed = config.run_cfg.seed + get_rank() 63 | 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | torch.manual_seed(seed) 67 | 68 | cudnn.benchmark = False 69 | cudnn.deterministic = True 70 | 71 | 72 | def describe_image(image_path, chat, chat_state, img, num_beams=1, temperature=1.0): 73 | chat_state = CONV_VISION.copy() 74 | img_list = [] 75 | 76 | gr_img = Image.open(image_path) 77 | llm_message = chat.upload_img(gr_img, chat_state, img_list) 78 | 79 | chat.ask("Describe this image.", chat_state) 80 | generated_caption = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=num_beams, temperature=temperature, max_length=2000)[0] 81 | 82 | return generated_caption 83 | 84 | 85 | #these are functions taken from wd_tags.py 86 | 87 | IMAGE_SIZE = 448 88 | IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".bmp",".webp"] 89 | 90 | # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 91 | DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" 92 | FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] 93 | SUB_DIR = "variables" 94 | SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] 95 | CSV_FILE = FILES[-1] 96 | 97 | def glob_images_pathlib(dir_path, recursive): 98 | image_paths = [] 99 | if recursive: 100 | for ext in IMAGE_EXTENSIONS: 101 | image_paths += list(dir_path.rglob("*" + ext)) 102 | else: 103 | for ext in IMAGE_EXTENSIONS: 104 | image_paths += list(dir_path.glob("*" + ext)) 105 | image_paths = list(set(image_paths)) # Remove duplicates 106 | image_paths.sort() 107 | return image_paths 108 | 109 | 110 | def preprocess_image(image): 111 | image = np.array(image) 112 | image = image[:, :, ::-1] # RGB->BGR 113 | 114 | # pad to square 115 | size = max(image.shape[0:2]) 116 | pad_x = size - image.shape[1] 117 | pad_y = size - image.shape[0] 118 | pad_l = pad_x // 2 119 | pad_t = pad_y // 2 120 | image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) 121 | 122 | interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 123 | image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) 124 | 125 | image = image.astype(np.float32) 126 | return image 127 | 128 | 129 | class ImageLoadingPrepDataset(torch.utils.data.Dataset): 130 | def __init__(self, image_paths): 131 | self.images = image_paths 132 | 133 | def 
__len__(self): 134 | return len(self.images) 135 | 136 | def __getitem__(self, idx): 137 | img_path = str(self.images[idx]) 138 | 139 | try: 140 | image = Image.open(img_path).convert("RGB") 141 | image = preprocess_image(image) 142 | tensor = torch.tensor(image) 143 | except Exception as e: 144 | print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") 145 | return None 146 | 147 | return (tensor, img_path) 148 | 149 | 150 | def collate_fn_remove_corrupted(batch): 151 | """Collate function that allows to remove corrupted examples in the 152 | dataloader. It expects that the dataloader returns 'None' when that occurs. 153 | The 'None's in the batch are removed. 154 | """ 155 | # Filter out all the Nones (corrupted examples) 156 | batch = list(filter(lambda x: x is not None, batch)) 157 | return batch 158 | 159 | def run_batch(images, model, args): 160 | # define the tags 161 | with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: 162 | reader = csv.reader(f) 163 | l = [row for row in reader] 164 | header = l[0] # tag_id, name, category, count 165 | rows = l[1:] 166 | assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" 167 | general_tags = [row[1] for row in rows[1:] if row[2] == "0"] 168 | character_tags = [row[1] for row in rows[1:] if row[2] == "4"] 169 | 170 | undesired_tags = set(args.undesired_tags.split(",")) 171 | 172 | # Process images to generate captions 173 | probs = model(np.array(images), training=False) 174 | captions = [] 175 | for batch_probs in probs.numpy(): 176 | tag_text = "" 177 | for i, p in enumerate(batch_probs[4:]): 178 | if i < len(general_tags) and p >= args.general_threshold: 179 | tag_name = general_tags[i] 180 | tag_name = tag_name if not args.remove_underscore or len(tag_name) <= 3 else tag_name.replace("_", " ") 181 | if tag_name not in undesired_tags: 182 | tag_text += ", " + tag_name 183 | elif i >= len(general_tags) and p >= args.character_threshold: 184 | tag_name = character_tags[i - len(general_tags)] 185 | tag_name = tag_name if not args.remove_underscore or len(tag_name) <= 3 else tag_name.replace("_", " ") 186 | if tag_name not in undesired_tags: 187 | tag_text += ", " + tag_name 188 | tag_text = tag_text[2:] if len(tag_text) > 0 else '' 189 | captions.append(tag_text) 190 | return captions 191 | 192 | def wd_pass(image_paths, model, args): 193 | # Preprocess the image 194 | captions = [] 195 | for image_path in image_paths: 196 | image = Image.open(image_path) 197 | image = preprocess_image(image) 198 | captions.append(run_batch([image], model, args)) 199 | return captions 200 | 201 | def main(): 202 | args = parse_args() 203 | 204 | # check for the model 205 | if not os.path.exists(args.model_dir) or args.force_download: 206 | print(f"downloading wd14 tagger model from hf_hub. 
id: {args.repo_id}") 207 | for file in FILES: 208 | hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True) 209 | 210 | for file in SUB_DIR_FILES: 211 | hf_hub_download( 212 | args.repo_id, 213 | file, 214 | subfolder=SUB_DIR, 215 | cache_dir=os.path.join(args.model_dir, SUB_DIR), 216 | force_download=True, 217 | ) 218 | 219 | cfg = Config(args) 220 | model_config = cfg.model_cfg 221 | 222 | model_cls = registry.get_model_class(model_config.arch) 223 | model = model_cls.from_config(model_config) 224 | 225 | model = model.to(torch.device('cuda')) 226 | 227 | vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train 228 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) 229 | chat = Chat(model, vis_processor) 230 | 231 | chat_state = deepcopy(CONV_VISION) 232 | img_list = [] 233 | 234 | image_folder = args.image_folder 235 | num_beams = args.beam_search_numbers 236 | temperature = 1.0 # default temperature 237 | 238 | image_extensions = ['jpg', 'jpeg', 'png', 'bmp'] 239 | image_paths = [] 240 | 241 | for ext in image_extensions: 242 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext}'))) 243 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext.upper()}'))) 244 | 245 | if not os.path.exists("mycaptions"): 246 | os.makedirs("mycaptions") 247 | 248 | for image_path in image_paths: 249 | start_time = time.time() 250 | caption = describe_image(image_path, chat, chat_state, img_list, num_beams, temperature) 251 | 252 | with open("mycaptions/{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0]), "w") as f: 253 | f.write(caption) 254 | 255 | end_time = time.time() 256 | time_taken = end_time - start_time 257 | print(f"Caption for {os.path.basename(image_path)} saved in 'mycaptions' folder") 258 | print(f"Time taken to process caption for {os.path.basename(image_path)} is: {time_taken:.2f} s") 259 | 260 | del model # Unload pytorch model from memory 261 | torch.cuda.empty_cache() 262 | 263 | # Load Keras model 264 | keras.backend.clear_session() 265 | model = load_model(args.model_dir) 266 | 267 | wd_captions = wd_pass(image_paths, model, args) 268 | 269 | for image_path, wd_caption in zip(image_paths, wd_captions): 270 | wd_caption = wd_caption[0] 271 | with open("mycaptions/{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0]), "a") as f: 272 | f.write(str(wd_caption)) 273 | 274 | del model # Unload keras model from memory 275 | keras.backend.clear_session() 276 | 277 | if __name__ == '__main__': 278 | main() -------------------------------------------------------------------------------- /minigpt4/models/mini_gpt4.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import logging 8 | import random 9 | import os 10 | import torch 11 | from torch.cuda.amp import autocast as autocast 12 | import torch.nn as nn 13 | 14 | from minigpt4.common.registry import registry 15 | from minigpt4.models.blip2 import Blip2Base, disabled_train 16 | from minigpt4.models.modeling_llama import LlamaForCausalLM 17 | from transformers import LlamaTokenizer 18 | 19 | 20 | @registry.register_model("mini_gpt4") 21 | class MiniGPT4(Blip2Base): 22 | """ 23 | BLIP2 GPT-LLAMA model. 
24 | """ 25 | 26 | PRETRAINED_MODEL_CONFIG_DICT = { 27 | "pretrain_vicuna": "configs/models/minigpt4.yaml", 28 | } 29 | 30 | def __init__( 31 | self, 32 | vit_model="eva_clip_g", 33 | q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", 34 | img_size=224, 35 | drop_path_rate=0, 36 | use_grad_checkpoint=False, 37 | vit_precision="fp16", 38 | freeze_vit=True, 39 | freeze_qformer=True, 40 | num_query_token=32, 41 | llama_model="", 42 | llama_cache_dir='', 43 | prompt_path="", 44 | prompt_template="", 45 | max_txt_len=32, 46 | end_sym='\n', 47 | ): 48 | super().__init__() 49 | 50 | self.tokenizer = self.init_tokenizer() 51 | 52 | print('Loading VIT') 53 | self.visual_encoder, self.ln_vision = self.init_vision_encoder( 54 | vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision 55 | ) 56 | if freeze_vit: 57 | for name, param in self.visual_encoder.named_parameters(): 58 | param.requires_grad = False 59 | self.visual_encoder = self.visual_encoder.eval() 60 | self.visual_encoder.train = disabled_train 61 | for name, param in self.ln_vision.named_parameters(): 62 | param.requires_grad = False 63 | self.ln_vision = self.ln_vision.eval() 64 | self.ln_vision.train = disabled_train 65 | logging.info("freeze vision encoder") 66 | print('Loading VIT Done') 67 | 68 | print('Loading Q-Former') 69 | self.Qformer, self.query_tokens = self.init_Qformer( 70 | num_query_token, self.visual_encoder.num_features 71 | ) 72 | self.Qformer.cls = None 73 | self.Qformer.bert.embeddings.word_embeddings = None 74 | self.Qformer.bert.embeddings.position_embeddings = None 75 | for layer in self.Qformer.bert.encoder.layer: 76 | layer.output = None 77 | layer.intermediate = None 78 | self.load_from_pretrained(url_or_filename=q_former_model) 79 | 80 | if freeze_qformer: 81 | for name, param in self.Qformer.named_parameters(): 82 | param.requires_grad = False 83 | self.Qformer = self.Qformer.eval() 84 | self.Qformer.train = disabled_train 85 | self.query_tokens.requires_grad = False 86 | logging.info("freeze Qformer") 87 | print('Loading Q-Former Done') 88 | 89 | print('Loading LLAMA') 90 | self.llama_tokenizer = LlamaTokenizer.from_pretrained('camenduru/MiniGPT4-7B', use_fast=False) 91 | self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token 92 | 93 | if llama_cache_dir: 94 | self.llama_model = LlamaForCausalLM.from_pretrained( 95 | 'camenduru/MiniGPT4-7B', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto" 96 | ) 97 | else: 98 | self.llama_model = LlamaForCausalLM.from_pretrained( 99 | 'camenduru/MiniGPT4-7B', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto" 100 | ) 101 | for name, param in self.llama_model.named_parameters(): 102 | param.requires_grad = False 103 | print('Loading LLAMA Done') 104 | 105 | self.llama_proj = nn.Linear( 106 | self.Qformer.config.hidden_size, self.llama_model.config.hidden_size 107 | ) 108 | self.max_txt_len = max_txt_len 109 | self.end_sym = end_sym 110 | 111 | if prompt_path: 112 | with open(prompt_path, 'r') as f: 113 | raw_prompts = f.read().splitlines() 114 | filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] 115 | self.prompt_list = [prompt_template.format(p) for p in filted_prompts] 116 | print('Load {} training prompts'.format(len(self.prompt_list))) 117 | print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) 118 | else: 119 | self.prompt_list = [] 120 | 121 | def vit_to_cpu(self): 122 | self.ln_vision.to("cpu") 123 
| self.ln_vision.float() 124 | self.visual_encoder.to("cpu") 125 | self.visual_encoder.float() 126 | 127 | def encode_img(self, image): 128 | device = image.device 129 | self.vit_to_cpu() 130 | image = image.to("cpu") 131 | with self.maybe_autocast(): 132 | image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) 133 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) 134 | 135 | query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) 136 | query_output = self.Qformer.bert( 137 | query_embeds=query_tokens, 138 | encoder_hidden_states=image_embeds, 139 | encoder_attention_mask=image_atts, 140 | return_dict=True, 141 | ) 142 | 143 | inputs_llama = self.llama_proj(query_output.last_hidden_state) 144 | atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) 145 | return inputs_llama, atts_llama 146 | 147 | def prompt_wrap(self, img_embeds, atts_img, prompt): 148 | if prompt: 149 | batch_size = img_embeds.shape[0] 150 | p_before, p_after = prompt.split('') 151 | p_before_tokens = self.llama_tokenizer( 152 | p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) 153 | p_after_tokens = self.llama_tokenizer( 154 | p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) 155 | p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) 156 | p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) 157 | wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1) 158 | wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1]) 159 | return wrapped_img_embeds, wrapped_atts_img 160 | else: 161 | return img_embeds, atts_img 162 | 163 | def forward(self, samples): 164 | image = samples["image"] 165 | img_embeds, atts_img = self.encode_img(image) 166 | if hasattr(samples, 'question_split'): # VQA dataset 167 | print('VQA Batch') 168 | vqa_prompt = '###Human: ' 169 | img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt) 170 | elif self.prompt_list: 171 | prompt = random.choice(self.prompt_list) 172 | img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt) 173 | 174 | self.llama_tokenizer.padding_side = "right" 175 | 176 | text = [t + self.end_sym for t in samples["text_input"]] 177 | 178 | to_regress_tokens = self.llama_tokenizer( 179 | text, 180 | return_tensors="pt", 181 | padding="longest", 182 | truncation=True, 183 | max_length=self.max_txt_len, 184 | add_special_tokens=False 185 | ).to(image.device) 186 | 187 | targets = to_regress_tokens.input_ids.masked_fill( 188 | to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 189 | ) 190 | 191 | empty_targets = ( 192 | torch.ones([atts_img.shape[0], atts_img.shape[1]+1], 193 | dtype=torch.long).to(image.device).fill_(-100) # plus one for bos 194 | ) 195 | targets = torch.cat([empty_targets, targets], dim=1) 196 | 197 | batch_size = img_embeds.shape[0] 198 | bos = torch.ones([batch_size, 1], 199 | dtype=to_regress_tokens.input_ids.dtype, 200 | device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id 201 | bos_embeds = self.llama_model.model.embed_tokens(bos) 202 | atts_bos = atts_img[:, :1] 203 | 204 | to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) 205 | inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1) 206 | attention_mask = torch.cat([atts_bos, atts_img, 
to_regress_tokens.attention_mask], dim=1) 207 | 208 | with self.maybe_autocast(): 209 | outputs = self.llama_model( 210 | inputs_embeds=inputs_embeds, 211 | attention_mask=attention_mask, 212 | return_dict=True, 213 | labels=targets, 214 | ) 215 | loss = outputs.loss 216 | 217 | return {"loss": loss} 218 | 219 | @classmethod 220 | def from_config(cls, cfg): 221 | vit_model = cfg.get("vit_model", "eva_clip_g") 222 | q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") 223 | img_size = cfg.get("image_size") 224 | num_query_token = cfg.get("num_query_token") 225 | llama_model = cfg.get("llama_model") 226 | 227 | drop_path_rate = cfg.get("drop_path_rate", 0) 228 | use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) 229 | vit_precision = cfg.get("vit_precision", "fp16") 230 | freeze_vit = cfg.get("freeze_vit", True) 231 | freeze_qformer = cfg.get("freeze_qformer", True) 232 | llama_cache_dir = cfg.get("llama_cache_dir", "") 233 | 234 | prompt_path = cfg.get("prompt_path", "") 235 | prompt_template = cfg.get("prompt_template", "") 236 | max_txt_len = cfg.get("max_txt_len", 32) 237 | end_sym = cfg.get("end_sym", '\n') 238 | 239 | model = cls( 240 | vit_model=vit_model, 241 | q_former_model=q_former_model, 242 | img_size=img_size, 243 | drop_path_rate=drop_path_rate, 244 | use_grad_checkpoint=use_grad_checkpoint, 245 | vit_precision=vit_precision, 246 | freeze_vit=freeze_vit, 247 | freeze_qformer=freeze_qformer, 248 | llama_cache_dir=llama_cache_dir, 249 | num_query_token=num_query_token, 250 | llama_model=llama_model, 251 | prompt_path=prompt_path, 252 | prompt_template=prompt_template, 253 | max_txt_len=max_txt_len, 254 | end_sym=end_sym 255 | ) 256 | 257 | ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 258 | if ckpt_path: 259 | print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path)) 260 | ckpt = torch.load(ckpt_path, map_location="cpu") 261 | msg = model.load_state_dict(ckpt['model'], strict=False) 262 | 263 | return model 264 | -------------------------------------------------------------------------------- /minigpt4/processors/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | import torch 12 | 13 | 14 | ## aug functions 15 | def identity_func(img): 16 | return img 17 | 18 | 19 | def autocontrast_func(img, cutoff=0): 20 | """ 21 | same output as PIL.ImageOps.autocontrast 22 | """ 23 | n_bins = 256 24 | 25 | def tune_channel(ch): 26 | n = ch.size 27 | cut = cutoff * n // 100 28 | if cut == 0: 29 | high, low = ch.max(), ch.min() 30 | else: 31 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 32 | low = np.argwhere(np.cumsum(hist) > cut) 33 | low = 0 if low.shape[0] == 0 else low[0] 34 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 35 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 36 | if high <= low: 37 | table = np.arange(n_bins) 38 | else: 39 | scale = (n_bins - 1) / (high - low) 40 | offset = -low * scale 41 | table = np.arange(n_bins) * scale + offset 42 | table[table < 0] = 0 43 | table[table > n_bins - 1] = n_bins - 1 44 | table = table.clip(0, 255).astype(np.uint8) 45 | return table[ch] 46 | 47 | channels = [tune_channel(ch) for ch in cv2.split(img)] 48 | out = cv2.merge(channels) 49 | return out 50 | 51 | 52 | def equalize_func(img): 53 | """ 54 | same output as PIL.ImageOps.equalize 55 | PIL's implementation is different from cv2.equalize 56 | """ 57 | n_bins = 256 58 | 59 | def tune_channel(ch): 60 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 61 | non_zero_hist = hist[hist != 0].reshape(-1) 62 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 63 | if step == 0: 64 | return ch 65 | n = np.empty_like(hist) 66 | n[0] = step // 2 67 | n[1:] = hist[:-1] 68 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 69 | return table[ch] 70 | 71 | channels = [tune_channel(ch) for ch in cv2.split(img)] 72 | out = cv2.merge(channels) 73 | return out 74 | 75 | 76 | def rotate_func(img, degree, fill=(0, 0, 0)): 77 | """ 78 | like PIL, rotate by degree, not radians 79 | """ 80 | H, W = img.shape[0], img.shape[1] 81 | center = W / 2, H / 2 82 | M = cv2.getRotationMatrix2D(center, degree, 1) 83 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 84 | return out 85 | 86 | 87 | def solarize_func(img, thresh=128): 88 | """ 89 | same output as PIL.ImageOps.posterize 90 | """ 91 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 92 | table = table.clip(0, 255).astype(np.uint8) 93 | out = table[img] 94 | return out 95 | 96 | 97 | def color_func(img, factor): 98 | """ 99 | same output as PIL.ImageEnhance.Color 100 | """ 101 | ## implementation according to PIL definition, quite slow 102 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 103 | # out = blend(degenerate, img, factor) 104 | # M = ( 105 | # np.eye(3) * factor 106 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) 107 | # )[np.newaxis, np.newaxis, :] 108 | M = np.float32( 109 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] 110 | ) * factor + np.float32([[0.114], [0.587], [0.299]]) 111 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 112 | return out 113 | 114 | 115 | def contrast_func(img, factor): 116 | """ 117 | same output as PIL.ImageEnhance.Contrast 118 | """ 119 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 120 | table = ( 121 | np.array([(el - mean) * factor + mean for el in range(256)]) 122 | .clip(0, 255) 123 | .astype(np.uint8) 124 | ) 125 | out = table[img] 126 | return out 127 | 128 | 129 | def brightness_func(img, factor): 130 | """ 131 | same output as PIL.ImageEnhance.Contrast 132 | """ 133 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 134 | out = table[img] 135 | return out 136 | 137 | 138 | def sharpness_func(img, factor): 139 | """ 140 | The differences the this result and PIL are all on the 4 boundaries, the center 141 | areas are same 142 | """ 143 | kernel = np.ones((3, 3), dtype=np.float32) 144 | kernel[1][1] = 5 145 | kernel /= 13 146 | degenerate = cv2.filter2D(img, -1, kernel) 147 | if factor == 0.0: 148 | out = degenerate 149 | elif factor == 1.0: 150 | out = img 151 | else: 152 | out = img.astype(np.float32) 153 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 154 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 155 | out = out.astype(np.uint8) 156 | return out 157 | 158 | 159 | def shear_x_func(img, factor, fill=(0, 0, 0)): 160 | H, W = img.shape[0], img.shape[1] 161 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 162 | out = cv2.warpAffine( 163 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 164 | ).astype(np.uint8) 165 | return out 166 | 167 | 168 | def translate_x_func(img, offset, fill=(0, 0, 0)): 169 | """ 170 | same output as PIL.Image.transform 171 | """ 172 | H, W = img.shape[0], img.shape[1] 173 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 174 | out = cv2.warpAffine( 175 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 176 | ).astype(np.uint8) 177 | return out 178 | 179 | 180 | def translate_y_func(img, offset, fill=(0, 0, 0)): 181 | """ 182 | same output as PIL.Image.transform 183 | """ 184 | H, W = img.shape[0], img.shape[1] 185 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 186 | out = cv2.warpAffine( 187 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 188 | ).astype(np.uint8) 189 | return out 190 | 191 | 192 | def posterize_func(img, bits): 193 | """ 194 | same output as PIL.ImageOps.posterize 195 | """ 196 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 197 | return out 198 | 199 | 200 | def shear_y_func(img, factor, fill=(0, 0, 0)): 201 | H, W = img.shape[0], img.shape[1] 202 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 203 | out = cv2.warpAffine( 204 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 205 | ).astype(np.uint8) 206 | return out 207 | 208 | 209 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 210 | replace = np.array(replace, dtype=np.uint8) 211 | H, W = img.shape[0], img.shape[1] 212 | rh, rw = np.random.random(2) 213 | pad_size = pad_size // 2 214 | ch, cw = int(rh * H), int(rw * W) 215 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 216 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 217 | out = img.copy() 218 | out[x1:x2, y1:y2, :] = replace 219 | return out 220 | 221 | 222 | ### level to args 223 | def 
enhance_level_to_args(MAX_LEVEL): 224 | def level_to_args(level): 225 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 226 | 227 | return level_to_args 228 | 229 | 230 | def shear_level_to_args(MAX_LEVEL, replace_value): 231 | def level_to_args(level): 232 | level = (level / MAX_LEVEL) * 0.3 233 | if np.random.random() > 0.5: 234 | level = -level 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 241 | def level_to_args(level): 242 | level = (level / MAX_LEVEL) * float(translate_const) 243 | if np.random.random() > 0.5: 244 | level = -level 245 | return (level, replace_value) 246 | 247 | return level_to_args 248 | 249 | 250 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 251 | def level_to_args(level): 252 | level = int((level / MAX_LEVEL) * cutout_const) 253 | return (level, replace_value) 254 | 255 | return level_to_args 256 | 257 | 258 | def solarize_level_to_args(MAX_LEVEL): 259 | def level_to_args(level): 260 | level = int((level / MAX_LEVEL) * 256) 261 | return (level,) 262 | 263 | return level_to_args 264 | 265 | 266 | def none_level_to_args(level): 267 | return () 268 | 269 | 270 | def posterize_level_to_args(MAX_LEVEL): 271 | def level_to_args(level): 272 | level = int((level / MAX_LEVEL) * 4) 273 | return (level,) 274 | 275 | return level_to_args 276 | 277 | 278 | def rotate_level_to_args(MAX_LEVEL, replace_value): 279 | def level_to_args(level): 280 | level = (level / MAX_LEVEL) * 30 281 | if np.random.random() < 0.5: 282 | level = -level 283 | return (level, replace_value) 284 | 285 | return level_to_args 286 | 287 | 288 | func_dict = { 289 | "Identity": identity_func, 290 | "AutoContrast": autocontrast_func, 291 | "Equalize": equalize_func, 292 | "Rotate": rotate_func, 293 | "Solarize": solarize_func, 294 | "Color": color_func, 295 | "Contrast": contrast_func, 296 | "Brightness": brightness_func, 297 | "Sharpness": sharpness_func, 298 | "ShearX": shear_x_func, 299 | "TranslateX": translate_x_func, 300 | "TranslateY": translate_y_func, 301 | "Posterize": posterize_func, 302 | "ShearY": shear_y_func, 303 | } 304 | 305 | translate_const = 10 306 | MAX_LEVEL = 10 307 | replace_value = (128, 128, 128) 308 | arg_dict = { 309 | "Identity": none_level_to_args, 310 | "AutoContrast": none_level_to_args, 311 | "Equalize": none_level_to_args, 312 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), 313 | "Solarize": solarize_level_to_args(MAX_LEVEL), 314 | "Color": enhance_level_to_args(MAX_LEVEL), 315 | "Contrast": enhance_level_to_args(MAX_LEVEL), 316 | "Brightness": enhance_level_to_args(MAX_LEVEL), 317 | "Sharpness": enhance_level_to_args(MAX_LEVEL), 318 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), 319 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 320 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 321 | "Posterize": posterize_level_to_args(MAX_LEVEL), 322 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), 323 | } 324 | 325 | 326 | class RandomAugment(object): 327 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 328 | self.N = N 329 | self.M = M 330 | self.isPIL = isPIL 331 | if augs: 332 | self.augs = augs 333 | else: 334 | self.augs = list(arg_dict.keys()) 335 | 336 | def get_random_ops(self): 337 | sampled_ops = np.random.choice(self.augs, self.N) 338 | return [(op, 0.5, self.M) for op in sampled_ops] 339 | 340 | def __call__(self, img): 341 | if 
self.isPIL: 342 | img = np.array(img) 343 | ops = self.get_random_ops() 344 | for name, prob, level in ops: 345 | if np.random.random() > prob: 346 | continue 347 | args = arg_dict[name](level) 348 | img = func_dict[name](img, *args) 349 | return img 350 | 351 | 352 | class VideoRandomAugment(object): 353 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): 354 | self.N = N 355 | self.M = M 356 | self.p = p 357 | self.tensor_in_tensor_out = tensor_in_tensor_out 358 | if augs: 359 | self.augs = augs 360 | else: 361 | self.augs = list(arg_dict.keys()) 362 | 363 | def get_random_ops(self): 364 | sampled_ops = np.random.choice(self.augs, self.N, replace=False) 365 | return [(op, self.M) for op in sampled_ops] 366 | 367 | def __call__(self, frames): 368 | assert ( 369 | frames.shape[-1] == 3 370 | ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." 371 | 372 | if self.tensor_in_tensor_out: 373 | frames = frames.numpy().astype(np.uint8) 374 | 375 | num_frames = frames.shape[0] 376 | 377 | ops = num_frames * [self.get_random_ops()] 378 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] 379 | 380 | frames = torch.stack( 381 | list(map(self._aug, frames, ops, apply_or_not)), dim=0 382 | ).float() 383 | 384 | return frames 385 | 386 | def _aug(self, img, ops, apply_or_not): 387 | for i, (name, level) in enumerate(ops): 388 | if not apply_or_not[i]: 389 | continue 390 | args = arg_dict[name](level) 391 | img = func_dict[name](img, *args) 392 | return torch.from_numpy(img) 393 | 394 | 395 | if __name__ == "__main__": 396 | a = RandomAugment() 397 | img = np.random.randn(32, 32, 3) 398 | a(img) 399 | -------------------------------------------------------------------------------- /minigpt4/common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import io 9 | import json 10 | import logging 11 | import os 12 | import pickle 13 | import re 14 | import shutil 15 | import urllib 16 | import urllib.error 17 | import urllib.request 18 | from typing import Optional 19 | from urllib.parse import urlparse 20 | 21 | import numpy as np 22 | import pandas as pd 23 | import yaml 24 | from iopath.common.download import download 25 | from iopath.common.file_io import file_lock, g_pathmgr 26 | from minigpt4.common.registry import registry 27 | from torch.utils.model_zoo import tqdm 28 | from torchvision.datasets.utils import ( 29 | check_integrity, 30 | download_file_from_google_drive, 31 | extract_archive, 32 | ) 33 | 34 | 35 | def now(): 36 | from datetime import datetime 37 | 38 | return datetime.now().strftime("%Y%m%d%H%M")[:-1] 39 | 40 | 41 | def is_url(url_or_filename): 42 | parsed = urlparse(url_or_filename) 43 | return parsed.scheme in ("http", "https") 44 | 45 | 46 | def get_cache_path(rel_path): 47 | return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) 48 | 49 | 50 | def get_abs_path(rel_path): 51 | return os.path.join(registry.get_path("library_root"), rel_path) 52 | 53 | 54 | def load_json(filename): 55 | with open(filename, "r") as f: 56 | return json.load(f) 57 | 58 | 59 | # The following are adapted from torchvision and vissl 60 | # torchvision: https://github.com/pytorch/vision 61 | # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py 62 | 63 | 64 | def makedir(dir_path): 65 | """ 66 | Create the directory if it does not exist. 67 | """ 68 | is_success = False 69 | try: 70 | if not g_pathmgr.exists(dir_path): 71 | g_pathmgr.mkdirs(dir_path) 72 | is_success = True 73 | except BaseException: 74 | print(f"Error creating directory: {dir_path}") 75 | return is_success 76 | 77 | 78 | def get_redirected_url(url: str): 79 | """ 80 | Given a URL, returns the URL it redirects to or the 81 | original URL in case of no indirection 82 | """ 83 | import requests 84 | 85 | with requests.Session() as session: 86 | with session.get(url, stream=True, allow_redirects=True) as response: 87 | if response.history: 88 | return response.url 89 | else: 90 | return url 91 | 92 | 93 | def to_google_drive_download_url(view_url: str) -> str: 94 | """ 95 | Utility function to transform a view URL of google drive 96 | to a download URL for google drive 97 | Example input: 98 | https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view 99 | Example output: 100 | https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp 101 | """ 102 | splits = view_url.split("/") 103 | assert splits[-1] == "view" 104 | file_id = splits[-2] 105 | return f"https://drive.google.com/uc?export=download&id={file_id}" 106 | 107 | 108 | def download_google_drive_url(url: str, output_path: str, output_file_name: str): 109 | """ 110 | Download a file from google drive 111 | Downloading an URL from google drive requires confirmation when 112 | the file of the size is too big (google drive notifies that 113 | anti-viral checks cannot be performed on such files) 114 | """ 115 | import requests 116 | 117 | with requests.Session() as session: 118 | 119 | # First get the confirmation token and append it to the URL 120 | with session.get(url, stream=True, allow_redirects=True) as response: 121 | for k, v in 
response.cookies.items():
122 |                 if k.startswith("download_warning"):
123 |                     url = url + "&confirm=" + v
124 | 
125 |         # Then download the content of the file
126 |         with session.get(url, stream=True, verify=True) as response:
127 |             makedir(output_path)
128 |             path = os.path.join(output_path, output_file_name)
129 |             total_size = int(response.headers.get("Content-length", 0))
130 |             with open(path, "wb") as file:
131 |                 from tqdm import tqdm
132 | 
133 |                 with tqdm(total=total_size) as progress_bar:
134 |                     for block in response.iter_content(
135 |                         chunk_size=io.DEFAULT_BUFFER_SIZE
136 |                     ):
137 |                         file.write(block)
138 |                         progress_bar.update(len(block))
139 | 
140 | 
141 | def _get_google_drive_file_id(url: str) -> Optional[str]:
142 |     parts = urlparse(url)
143 | 
144 |     if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145 |         return None
146 | 
147 |     match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148 |     if match is None:
149 |         return None
150 | 
151 |     return match.group("id")
152 | 
153 | 
154 | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155 |     with open(filename, "wb") as fh:
156 |         with urllib.request.urlopen(
157 |             urllib.request.Request(url, headers={"User-Agent": "vissl"})
158 |         ) as response:
159 |             with tqdm(total=response.length) as pbar:
160 |                 for chunk in iter(lambda: response.read(chunk_size), ""):
161 |                     if not chunk:
162 |                         break
163 |                     pbar.update(chunk_size)
164 |                     fh.write(chunk)
165 | 
166 | 
167 | def download_url(
168 |     url: str,
169 |     root: str,
170 |     filename: Optional[str] = None,
171 |     md5: Optional[str] = None,
172 | ) -> None:
173 |     """Download a file from a url and place it in root.
174 |     Args:
175 |         url (str): URL to download file from
176 |         root (str): Directory to place downloaded file in
177 |         filename (str, optional): Name to save the file under.
178 |             If None, use the basename of the URL.
179 |         md5 (str, optional): MD5 checksum of the download. If None, do not check
180 |     """
181 |     root = os.path.expanduser(root)
182 |     if not filename:
183 |         filename = os.path.basename(url)
184 |     fpath = os.path.join(root, filename)
185 | 
186 |     makedir(root)
187 | 
188 |     # check if file is already present locally
189 |     if check_integrity(fpath, md5):
190 |         print("Using downloaded and verified file: " + fpath)
191 |         return
192 | 
193 |     # expand redirect chain if needed
194 |     url = get_redirected_url(url)
195 | 
196 |     # check if file is located on Google Drive
197 |     file_id = _get_google_drive_file_id(url)
198 |     if file_id is not None:
199 |         return download_file_from_google_drive(file_id, root, filename, md5)
200 | 
201 |     # download the file
202 |     try:
203 |         print("Downloading " + url + " to " + fpath)
204 |         _urlretrieve(url, fpath)
205 |     except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
206 |         if url[:5] == "https":
207 |             url = url.replace("https:", "http:")
208 |             print(
209 |                 "Failed download. Trying https -> http instead."
210 | " Downloading " + url + " to " + fpath 211 | ) 212 | _urlretrieve(url, fpath) 213 | else: 214 | raise e 215 | 216 | # check integrity of downloaded file 217 | if not check_integrity(fpath, md5): 218 | raise RuntimeError("File not found or corrupted.") 219 | 220 | 221 | def download_and_extract_archive( 222 | url: str, 223 | download_root: str, 224 | extract_root: Optional[str] = None, 225 | filename: Optional[str] = None, 226 | md5: Optional[str] = None, 227 | remove_finished: bool = False, 228 | ) -> None: 229 | download_root = os.path.expanduser(download_root) 230 | if extract_root is None: 231 | extract_root = download_root 232 | if not filename: 233 | filename = os.path.basename(url) 234 | 235 | download_url(url, download_root, filename, md5) 236 | 237 | archive = os.path.join(download_root, filename) 238 | print("Extracting {} to {}".format(archive, extract_root)) 239 | extract_archive(archive, extract_root, remove_finished) 240 | 241 | 242 | def cache_url(url: str, cache_dir: str) -> str: 243 | """ 244 | This implementation downloads the remote resource and caches it locally. 245 | The resource will only be downloaded if not previously requested. 246 | """ 247 | parsed_url = urlparse(url) 248 | dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) 249 | makedir(dirname) 250 | filename = url.split("/")[-1] 251 | cached = os.path.join(dirname, filename) 252 | with file_lock(cached): 253 | if not os.path.isfile(cached): 254 | logging.info(f"Downloading {url} to {cached} ...") 255 | cached = download(url, dirname, filename=filename) 256 | logging.info(f"URL {url} cached in {cached}") 257 | return cached 258 | 259 | 260 | # TODO (prigoyal): convert this into RAII-style API 261 | def create_file_symlink(file1, file2): 262 | """ 263 | Simply create the symlinks for a given file1 to file2. 264 | Useful during model checkpointing to symlinks to the 265 | latest successful checkpoint. 266 | """ 267 | try: 268 | if g_pathmgr.exists(file2): 269 | g_pathmgr.rm(file2) 270 | g_pathmgr.symlink(file1, file2) 271 | except Exception as e: 272 | logging.info(f"Could NOT create symlink. Error: {e}") 273 | 274 | 275 | def save_file(data, filename, append_to_json=True, verbose=True): 276 | """ 277 | Common i/o utility to handle saving data to various file formats. 278 | Supported: 279 | .pkl, .pickle, .npy, .json 280 | Specifically for .json, users have the option to either append (default) 281 | or rewrite by passing in Boolean value to append_to_json. 
282 | """ 283 | if verbose: 284 | logging.info(f"Saving data to file: {filename}") 285 | file_ext = os.path.splitext(filename)[1] 286 | if file_ext in [".pkl", ".pickle"]: 287 | with g_pathmgr.open(filename, "wb") as fopen: 288 | pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) 289 | elif file_ext == ".npy": 290 | with g_pathmgr.open(filename, "wb") as fopen: 291 | np.save(fopen, data) 292 | elif file_ext == ".json": 293 | if append_to_json: 294 | with g_pathmgr.open(filename, "a") as fopen: 295 | fopen.write(json.dumps(data, sort_keys=True) + "\n") 296 | fopen.flush() 297 | else: 298 | with g_pathmgr.open(filename, "w") as fopen: 299 | fopen.write(json.dumps(data, sort_keys=True) + "\n") 300 | fopen.flush() 301 | elif file_ext == ".yaml": 302 | with g_pathmgr.open(filename, "w") as fopen: 303 | dump = yaml.dump(data) 304 | fopen.write(dump) 305 | fopen.flush() 306 | else: 307 | raise Exception(f"Saving {file_ext} is not supported yet") 308 | 309 | if verbose: 310 | logging.info(f"Saved data to file: {filename}") 311 | 312 | 313 | def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): 314 | """ 315 | Common i/o utility to handle loading data from various file formats. 316 | Supported: 317 | .pkl, .pickle, .npy, .json 318 | For the npy files, we support reading the files in mmap_mode. 319 | If the mmap_mode of reading is not successful, we load data without the 320 | mmap_mode. 321 | """ 322 | if verbose: 323 | logging.info(f"Loading data from file: {filename}") 324 | 325 | file_ext = os.path.splitext(filename)[1] 326 | if file_ext == ".txt": 327 | with g_pathmgr.open(filename, "r") as fopen: 328 | data = fopen.readlines() 329 | elif file_ext in [".pkl", ".pickle"]: 330 | with g_pathmgr.open(filename, "rb") as fopen: 331 | data = pickle.load(fopen, encoding="latin1") 332 | elif file_ext == ".npy": 333 | if mmap_mode: 334 | try: 335 | with g_pathmgr.open(filename, "rb") as fopen: 336 | data = np.load( 337 | fopen, 338 | allow_pickle=allow_pickle, 339 | encoding="latin1", 340 | mmap_mode=mmap_mode, 341 | ) 342 | except ValueError as e: 343 | logging.info( 344 | f"Could not mmap {filename}: {e}. Trying without g_pathmgr" 345 | ) 346 | data = np.load( 347 | filename, 348 | allow_pickle=allow_pickle, 349 | encoding="latin1", 350 | mmap_mode=mmap_mode, 351 | ) 352 | logging.info("Successfully loaded without g_pathmgr") 353 | except Exception: 354 | logging.info("Could not mmap without g_pathmgr. 
Trying without mmap") 355 | with g_pathmgr.open(filename, "rb") as fopen: 356 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") 357 | else: 358 | with g_pathmgr.open(filename, "rb") as fopen: 359 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") 360 | elif file_ext == ".json": 361 | with g_pathmgr.open(filename, "r") as fopen: 362 | data = json.load(fopen) 363 | elif file_ext == ".yaml": 364 | with g_pathmgr.open(filename, "r") as fopen: 365 | data = yaml.load(fopen, Loader=yaml.FullLoader) 366 | elif file_ext == ".csv": 367 | with g_pathmgr.open(filename, "r") as fopen: 368 | data = pd.read_csv(fopen) 369 | else: 370 | raise Exception(f"Reading from {file_ext} is not supported yet") 371 | return data 372 | 373 | 374 | def abspath(resource_path: str): 375 | """ 376 | Make a path absolute, but take into account prefixes like 377 | "http://" or "manifold://" 378 | """ 379 | regex = re.compile(r"^\w+://") 380 | if regex.match(resource_path) is None: 381 | return os.path.abspath(resource_path) 382 | else: 383 | return resource_path 384 | 385 | 386 | def makedir(dir_path): 387 | """ 388 | Create the directory if it does not exist. 389 | """ 390 | is_success = False 391 | try: 392 | if not g_pathmgr.exists(dir_path): 393 | g_pathmgr.mkdirs(dir_path) 394 | is_success = True 395 | except BaseException: 396 | logging.info(f"Error creating directory: {dir_path}") 397 | return is_success 398 | 399 | 400 | def is_url(input_url): 401 | """ 402 | Check if an input string is a url. look for http(s):// and ignoring the case 403 | """ 404 | is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None 405 | return is_url 406 | 407 | 408 | def cleanup_dir(dir): 409 | """ 410 | Utility for deleting a directory. Useful for cleaning the storage space 411 | that contains various training artifacts like checkpoints, data etc. 412 | """ 413 | if os.path.exists(dir): 414 | logging.info(f"Deleting directory: {dir}") 415 | shutil.rmtree(dir) 416 | logging.info(f"Deleted contents of directory: {dir}") 417 | 418 | 419 | def get_file_size(filename): 420 | """ 421 | Given a file, get the size of file in MB 422 | """ 423 | size_in_mb = os.path.getsize(filename) / float(1024**2) 424 | return size_in_mb 425 | -------------------------------------------------------------------------------- /minigpt4/common/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import json 10 | from typing import Dict 11 | 12 | from omegaconf import OmegaConf 13 | from minigpt4.common.registry import registry 14 | 15 | 16 | class Config: 17 | def __init__(self, args): 18 | self.config = {} 19 | 20 | self.args = args 21 | 22 | # Register the config and configuration for setup 23 | registry.register("configuration", self) 24 | 25 | user_config = self._build_opt_list(self.args.options) 26 | 27 | config = OmegaConf.load(self.args.cfg_path) 28 | 29 | runner_config = self.build_runner_config(config) 30 | model_config = self.build_model_config(config, **user_config) 31 | dataset_config = self.build_dataset_config(config) 32 | 33 | # Validate the user-provided runner configuration 34 | # model and dataset configuration are supposed to be validated by the respective classes 35 | # [TODO] validate the model/dataset configuration 36 | # self._validate_runner_config(runner_config) 37 | 38 | # Override the default configuration with user options. 39 | self.config = OmegaConf.merge( 40 | runner_config, model_config, dataset_config, user_config 41 | ) 42 | 43 | def _validate_runner_config(self, runner_config): 44 | """ 45 | This method validates the configuration, such that 46 | 1) all the user specified options are valid; 47 | 2) no type mismatches between the user specified options and the config. 48 | """ 49 | runner_config_validator = create_runner_config_validator() 50 | runner_config_validator.validate(runner_config) 51 | 52 | def _build_opt_list(self, opts): 53 | opts_dot_list = self._convert_to_dot_list(opts) 54 | return OmegaConf.from_dotlist(opts_dot_list) 55 | 56 | @staticmethod 57 | def build_model_config(config, **kwargs): 58 | model = config.get("model", None) 59 | assert model is not None, "Missing model configuration file." 60 | 61 | model_cls = registry.get_model_class(model.arch) 62 | assert model_cls is not None, f"Model '{model.arch}' has not been registered." 63 | 64 | model_type = kwargs.get("model.model_type", None) 65 | if not model_type: 66 | model_type = model.get("model_type", None) 67 | # else use the model type selected by user. 68 | 69 | assert model_type is not None, "Missing model_type." 70 | 71 | model_config_path = model_cls.default_config_path(model_type=model_type) 72 | 73 | model_config = OmegaConf.create() 74 | # hiararchy override, customized config > default config 75 | model_config = OmegaConf.merge( 76 | model_config, 77 | OmegaConf.load(model_config_path), 78 | {"model": config["model"]}, 79 | ) 80 | 81 | return model_config 82 | 83 | @staticmethod 84 | def build_runner_config(config): 85 | return {"run": config.run} 86 | 87 | @staticmethod 88 | def build_dataset_config(config): 89 | datasets = config.get("datasets", None) 90 | if datasets is None: 91 | raise KeyError( 92 | "Expecting 'datasets' as the root key for dataset configuration." 
93 | ) 94 | 95 | dataset_config = OmegaConf.create() 96 | 97 | for dataset_name in datasets: 98 | builder_cls = registry.get_builder_class(dataset_name) 99 | 100 | dataset_config_type = datasets[dataset_name].get("type", "default") 101 | dataset_config_path = builder_cls.default_config_path( 102 | type=dataset_config_type 103 | ) 104 | 105 | # hiararchy override, customized config > default config 106 | dataset_config = OmegaConf.merge( 107 | dataset_config, 108 | OmegaConf.load(dataset_config_path), 109 | {"datasets": {dataset_name: config["datasets"][dataset_name]}}, 110 | ) 111 | 112 | return dataset_config 113 | 114 | def _convert_to_dot_list(self, opts): 115 | if opts is None: 116 | opts = [] 117 | 118 | if len(opts) == 0: 119 | return opts 120 | 121 | has_equal = opts[0].find("=") != -1 122 | 123 | if has_equal: 124 | return opts 125 | 126 | return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] 127 | 128 | def get_config(self): 129 | return self.config 130 | 131 | @property 132 | def run_cfg(self): 133 | return self.config.run 134 | 135 | @property 136 | def datasets_cfg(self): 137 | return self.config.datasets 138 | 139 | @property 140 | def model_cfg(self): 141 | return self.config.model 142 | 143 | def pretty_print(self): 144 | logging.info("\n===== Running Parameters =====") 145 | logging.info(self._convert_node_to_json(self.config.run)) 146 | 147 | logging.info("\n====== Dataset Attributes ======") 148 | datasets = self.config.datasets 149 | 150 | for dataset in datasets: 151 | if dataset in self.config.datasets: 152 | logging.info(f"\n======== {dataset} =======") 153 | dataset_config = self.config.datasets[dataset] 154 | logging.info(self._convert_node_to_json(dataset_config)) 155 | else: 156 | logging.warning(f"No dataset named '{dataset}' in config. Skipping") 157 | 158 | logging.info(f"\n====== Model Attributes ======") 159 | logging.info(self._convert_node_to_json(self.config.model)) 160 | 161 | def _convert_node_to_json(self, node): 162 | container = OmegaConf.to_container(node, resolve=True) 163 | return json.dumps(container, indent=4, sort_keys=True) 164 | 165 | def to_dict(self): 166 | return OmegaConf.to_container(self.config) 167 | 168 | 169 | def node_to_dict(node): 170 | return OmegaConf.to_container(node) 171 | 172 | 173 | class ConfigValidator: 174 | """ 175 | This is a preliminary implementation to centralize and validate the configuration. 176 | May be altered in the future. 177 | 178 | A helper class to validate configurations from yaml file. 179 | 180 | This serves the following purposes: 181 | 1. Ensure all the options in the yaml are defined, raise error if not. 182 | 2. when type mismatches are found, the validator will raise an error. 183 | 3. a central place to store and display helpful messages for supported configurations. 
184 | 185 | """ 186 | 187 | class _Argument: 188 | def __init__(self, name, choices=None, type=None, help=None): 189 | self.name = name 190 | self.val = None 191 | self.choices = choices 192 | self.type = type 193 | self.help = help 194 | 195 | def __str__(self): 196 | s = f"{self.name}={self.val}" 197 | if self.type is not None: 198 | s += f", ({self.type})" 199 | if self.choices is not None: 200 | s += f", choices: {self.choices}" 201 | if self.help is not None: 202 | s += f", ({self.help})" 203 | return s 204 | 205 | def __init__(self, description): 206 | self.description = description 207 | 208 | self.arguments = dict() 209 | 210 | self.parsed_args = None 211 | 212 | def __getitem__(self, key): 213 | assert self.parsed_args is not None, "No arguments parsed yet." 214 | 215 | return self.parsed_args[key] 216 | 217 | def __str__(self) -> str: 218 | return self.format_help() 219 | 220 | def add_argument(self, *args, **kwargs): 221 | """ 222 | Assume the first argument is the name of the argument. 223 | """ 224 | self.arguments[args[0]] = self._Argument(*args, **kwargs) 225 | 226 | def validate(self, config=None): 227 | """ 228 | Convert yaml config (dict-like) to list, required by argparse. 229 | """ 230 | for k, v in config.items(): 231 | assert ( 232 | k in self.arguments 233 | ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}.""" 234 | 235 | if self.arguments[k].type is not None: 236 | try: 237 | self.arguments[k].val = self.arguments[k].type(v) 238 | except ValueError: 239 | raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") 240 | 241 | if self.arguments[k].choices is not None: 242 | assert ( 243 | v in self.arguments[k].choices 244 | ), f"""{k} must be one of {self.arguments[k].choices}.""" 245 | 246 | return config 247 | 248 | def format_arguments(self): 249 | return str([f"{k}" for k in sorted(self.arguments.keys())]) 250 | 251 | def format_help(self): 252 | # description + key-value pair string for each argument 253 | help_msg = str(self.description) 254 | return help_msg + ", available arguments: " + self.format_arguments() 255 | 256 | def print_help(self): 257 | # display help message 258 | print(self.format_help()) 259 | 260 | 261 | def create_runner_config_validator(): 262 | validator = ConfigValidator(description="Runner configurations") 263 | 264 | validator.add_argument( 265 | "runner", 266 | type=str, 267 | choices=["runner_base", "runner_iter"], 268 | help="""Runner to use. The "runner_base" uses epoch-based training while iter-based 269 | runner runs based on iters. Default: runner_base""", 270 | ) 271 | # add argumetns for training dataset ratios 272 | validator.add_argument( 273 | "train_dataset_ratios", 274 | type=Dict[str, float], 275 | help="""Ratios of training dataset. This is used in iteration-based runner. 276 | Do not support for epoch-based runner because how to define an epoch becomes tricky. 277 | Default: None""", 278 | ) 279 | validator.add_argument( 280 | "max_iters", 281 | type=float, 282 | help="Maximum number of iterations to run.", 283 | ) 284 | validator.add_argument( 285 | "max_epoch", 286 | type=int, 287 | help="Maximum number of epochs to run.", 288 | ) 289 | # add arguments for iters_per_inner_epoch 290 | validator.add_argument( 291 | "iters_per_inner_epoch", 292 | type=float, 293 | help="Number of iterations per inner epoch. 
This is required when runner is runner_iter.", 294 | ) 295 | lr_scheds_choices = registry.list_lr_schedulers() 296 | validator.add_argument( 297 | "lr_sched", 298 | type=str, 299 | choices=lr_scheds_choices, 300 | help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), 301 | ) 302 | task_choices = registry.list_tasks() 303 | validator.add_argument( 304 | "task", 305 | type=str, 306 | choices=task_choices, 307 | help="Task to use, from {}".format(task_choices), 308 | ) 309 | # add arguments for init_lr 310 | validator.add_argument( 311 | "init_lr", 312 | type=float, 313 | help="Initial learning rate. This will be the learning rate after warmup and before decay.", 314 | ) 315 | # add arguments for min_lr 316 | validator.add_argument( 317 | "min_lr", 318 | type=float, 319 | help="Minimum learning rate (after decay).", 320 | ) 321 | # add arguments for warmup_lr 322 | validator.add_argument( 323 | "warmup_lr", 324 | type=float, 325 | help="Starting learning rate for warmup.", 326 | ) 327 | # add arguments for learning rate decay rate 328 | validator.add_argument( 329 | "lr_decay_rate", 330 | type=float, 331 | help="Learning rate decay rate. Required if using a decaying learning rate scheduler.", 332 | ) 333 | # add arguments for weight decay 334 | validator.add_argument( 335 | "weight_decay", 336 | type=float, 337 | help="Weight decay rate.", 338 | ) 339 | # add arguments for training batch size 340 | validator.add_argument( 341 | "batch_size_train", 342 | type=int, 343 | help="Training batch size.", 344 | ) 345 | # add arguments for evaluation batch size 346 | validator.add_argument( 347 | "batch_size_eval", 348 | type=int, 349 | help="Evaluation batch size, including validation and testing.", 350 | ) 351 | # add arguments for number of workers for data loading 352 | validator.add_argument( 353 | "num_workers", 354 | help="Number of workers for data loading.", 355 | ) 356 | # add arguments for warm up steps 357 | validator.add_argument( 358 | "warmup_steps", 359 | type=int, 360 | help="Number of warmup steps. Required if a warmup schedule is used.", 361 | ) 362 | # add arguments for random seed 363 | validator.add_argument( 364 | "seed", 365 | type=int, 366 | help="Random seed.", 367 | ) 368 | # add arguments for output directory 369 | validator.add_argument( 370 | "output_dir", 371 | type=str, 372 | help="Output directory to save checkpoints and logs.", 373 | ) 374 | # add arguments for whether only use evaluation 375 | validator.add_argument( 376 | "evaluate", 377 | help="Whether to only evaluate the model. If true, training will not be performed.", 378 | ) 379 | # add arguments for splits used for training, e.g. ["train", "val"] 380 | validator.add_argument( 381 | "train_splits", 382 | type=list, 383 | help="Splits to use for training.", 384 | ) 385 | # add arguments for splits used for validation, e.g. ["val"] 386 | validator.add_argument( 387 | "valid_splits", 388 | type=list, 389 | help="Splits to use for validation. If not provided, will skip the validation.", 390 | ) 391 | # add arguments for splits used for testing, e.g. ["test"] 392 | validator.add_argument( 393 | "test_splits", 394 | type=list, 395 | help="Splits to use for testing. 
If not provided, will skip the testing.", 396 | ) 397 | # add arguments for accumulating gradient for iterations 398 | validator.add_argument( 399 | "accum_grad_iters", 400 | type=int, 401 | help="Number of iterations to accumulate gradient for.", 402 | ) 403 | 404 | # ====== distributed training ====== 405 | validator.add_argument( 406 | "device", 407 | type=str, 408 | choices=["cpu", "cuda"], 409 | help="Device to use. Support 'cuda' or 'cpu' as for now.", 410 | ) 411 | validator.add_argument( 412 | "world_size", 413 | type=int, 414 | help="Number of processes participating in the job.", 415 | ) 416 | validator.add_argument("dist_url", type=str) 417 | validator.add_argument("distributed", type=bool) 418 | # add arguments to opt using distributed sampler during evaluation or not 419 | validator.add_argument( 420 | "use_dist_eval_sampler", 421 | type=bool, 422 | help="Whether to use distributed sampler during evaluation or not.", 423 | ) 424 | 425 | # ====== task specific ====== 426 | # generation task specific arguments 427 | # add arguments for maximal length of text output 428 | validator.add_argument( 429 | "max_len", 430 | type=int, 431 | help="Maximal length of text output.", 432 | ) 433 | # add arguments for minimal length of text output 434 | validator.add_argument( 435 | "min_len", 436 | type=int, 437 | help="Minimal length of text output.", 438 | ) 439 | # add arguments number of beams 440 | validator.add_argument( 441 | "num_beams", 442 | type=int, 443 | help="Number of beams used for beam search.", 444 | ) 445 | 446 | # vqa task specific arguments 447 | # add arguments for number of answer candidates 448 | validator.add_argument( 449 | "num_ans_candidates", 450 | type=int, 451 | help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", 452 | ) 453 | # add arguments for inference method 454 | validator.add_argument( 455 | "inference_method", 456 | type=str, 457 | choices=["genearte", "rank"], 458 | help="""Inference method to use for question answering. If rank, requires a answer list.""", 459 | ) 460 | 461 | # ====== model specific ====== 462 | validator.add_argument( 463 | "k_test", 464 | type=int, 465 | help="Number of top k most similar samples from ITC/VTC selection to be tested.", 466 | ) 467 | 468 | return validator 469 | --------------------------------------------------------------------------------
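For orientation, below is a minimal driver sketch showing how the Config class above is typically constructed from an argparse namespace. It is a hypothetical snippet, not a file in this repository: it assumes the bundled eval_configs/minigpt4.yaml defines the usual "model", "datasets" and "run" sections, and that importing the minigpt4 package registers the model and dataset builder classes that file refers to.

# Hypothetical usage sketch -- not part of the repository.
# Assumes eval_configs/minigpt4.yaml provides "model", "datasets" and "run"
# sections, and that importing minigpt4 registers the referenced classes.
import argparse

from minigpt4.common.config import Config

parser = argparse.ArgumentParser()
parser.add_argument("--cfg-path", default="eval_configs/minigpt4.yaml")
parser.add_argument(
    "--options",
    nargs="+",
    default=None,
    help="override settings as dot-list pairs, e.g. run.seed=42",
)
args = parser.parse_args()

cfg = Config(args)       # merges runner, model and dataset configs
cfg.pretty_print()       # logs the merged parameters via logging.info
run_cfg = cfg.run_cfg    # OmegaConf node for the "run" section

Values passed via --options take precedence over both the yaml file and the per-model defaults, because user_config is merged last in Config.__init__.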