├── minigpt4
│   ├── common
│   │   ├── __init__.py
│   │   ├── gradcam.py
│   │   ├── optims.py
│   │   ├── dist_utils.py
│   │   ├── logger.py
│   │   ├── registry.py
│   │   ├── utils.py
│   │   └── config.py
│   ├── conversation
│   │   ├── __init__.py
│   │   └── conversation.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── laion_dataset.py
│   │   │   ├── cc_combine_dataset.py
│   │   │   ├── base_dataset.py
│   │   │   ├── caption_datasets.py
│   │   │   └── dataloader_utils.py
│   │   ├── builders
│   │   │   ├── __init__.py
│   │   │   ├── image_text_pair_builder.py
│   │   │   └── base_dataset_builder.py
│   │   └── data_utils.py
│   ├── runners
│   │   └── __init__.py
│   ├── configs
│   │   ├── default.yaml
│   │   ├── datasets
│   │   │   ├── cc_combine
│   │   │   │   ├── defaults.yaml
│   │   │   │   └── align.yaml
│   │   │   └── laion
│   │   │       └── defaults.yaml
│   │   └── models
│   │       └── minigpt4.yaml
│   ├── tasks
│   │   ├── image_text_pretrain.py
│   │   ├── __init__.py
│   │   └── base_task.py
│   ├── processors
│   │   ├── base_processor.py
│   │   ├── __init__.py
│   │   ├── blip_processors.py
│   │   └── randaugment.py
│   ├── __init__.py
│   └── models
│       ├── blip2_outputs.py
│       ├── __init__.py
│       ├── blip2.py
│       ├── base_model.py
│       └── mini_gpt4.py
├── mycaptions
│   └── put your captions here
├── images
│   └── SAMPLE IMAGE DELETE IF U WANT.png
├── CODEOWNERS
├── requirements.txt
├── MANIFEST.in
├── prompts
│   └── alignment.txt
├── extract.py
├── run.bat
├── eval_configs
│   └── minigpt4.yaml
├── train_configs
│   ├── minigpt4_stage1_laion.yaml
│   └── minigpt4_stage2_align.yaml
├── LICENSE.txt
├── .gitattributes
├── setup.bat
├── app.py
├── README.md
└── backup_app.py
/minigpt4/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/minigpt4/conversation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/minigpt4/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mycaptions/put your captions here:
--------------------------------------------------------------------------------
1 | Please note: your generated captions will be saved in this folder.
--------------------------------------------------------------------------------
/images/SAMPLE IMAGE DELETE IF U WANT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pipinstallyp/minigpt4-batch/HEAD/images/SAMPLE IMAGE DELETE IF U WANT.png
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu113
2 | torch==1.12.1
3 | torchvision==0.13.1
4 | salesforce-lavis
5 | bitsandbytes
6 | accelerate
7 | git+https://github.com/huggingface/transformers.git
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include minigpt4/configs *.yaml *.json
2 | recursive-include minigpt4/projects *.yaml *.json
3 |
4 | recursive-exclude minigpt4/datasets/download_scripts *
5 | recursive-exclude minigpt4/output *
6 |
7 | include requirements.txt
8 |
--------------------------------------------------------------------------------
/prompts/alignment.txt:
--------------------------------------------------------------------------------
1 | Describe this image in detail.
2 | Take a look at this image and describe what you notice.
3 | Please provide a detailed description of the picture.
4 | Could you describe the contents of this image for me?
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 |
3 | def unzip_file(zip_filepath, dest_path):
4 | # Open the zip file in read mode
5 | with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
6 | # Extract all the contents of the zip file in the destination directory
7 | zip_ref.extractall(dest_path)
8 |
9 | # Call the function
10 | unzip_file('./models.zip', './')
--------------------------------------------------------------------------------
/minigpt4/runners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.runners.runner_base import RunnerBase
9 |
10 | __all__ = ["RunnerBase"]
11 |
--------------------------------------------------------------------------------
/minigpt4/configs/default.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | env:
7 | # For default users
8 | # cache_root: "cache"
9 | # For internal use with persistent storage
10 | cache_root: "/export/home/.cache/minigpt4"
11 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/cc_combine/defaults.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | cc_combine:
8 | data_type: images
9 | build_info:
10 | # Be careful not to append minus sign (-) before split to avoid itemizing
11 | storage: /ibex/project/c2133/blip_dataset/cc3m/cc3m_cc12m_sbu/{00000..01255}.tar
12 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/laion/defaults.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | laion:
8 |
9 | data_type: images
10 |
11 | build_info:
12 | # Be careful not to append minus sign (-) before split to avoid itemizing
13 | storage: /ibex/project/c2133/blip_dataset/laion_1b/laion_gpu/{00000..10488}.tar
14 |
--------------------------------------------------------------------------------
/minigpt4/tasks/image_text_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.common.registry import registry
9 | from minigpt4.tasks.base_task import BaseTask
10 |
11 |
12 | @registry.register_task("image_text_pretrain")
13 | class ImageTextPretrainTask(BaseTask):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def evaluation(self, model, data_loader, cuda_enabled=True):
18 | pass
19 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/cc_combine/align.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | cc_align:
8 | data_type: images
9 | build_info:
10 | # Be careful not to append minus sign (-) before split to avoid itemizing
11 | annotations:
12 | train:
13 | url: placeholder
14 | storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/filter_cap.json
15 | images:
16 | storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/
17 |
--------------------------------------------------------------------------------
/minigpt4/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from omegaconf import OmegaConf
9 |
10 |
11 | class BaseProcessor:
12 | def __init__(self):
13 | self.transform = lambda x: x
14 | return
15 |
16 | def __call__(self, item):
17 | return self.transform(item)
18 |
19 | @classmethod
20 | def from_config(cls, cfg=None):
21 | return cls()
22 |
23 | def build(self, **kwargs):
24 | cfg = OmegaConf.create(kwargs)
25 |
26 | return self.from_config(cfg)
27 |
--------------------------------------------------------------------------------
/run.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set PYTHON_VER=3.10.11
4 |
5 | :: Check if Python version meets the recommended version
6 | python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
7 | if errorlevel 1 (
8 | echo Warning: Python version %PYTHON_VER% is recommended.
9 | )
10 |
11 | IF NOT EXIST venv (
12 | echo Error: Virtual environment venv not found. Please run setup.bat first to create the environment and install required packages.
13 | exit /b 1
14 | )
15 |
16 | :: Deactivate the virtual environment
17 | call .\venv\Scripts\deactivate.bat
18 |
19 | :: Activate the virtual environment
20 | call .\venv\Scripts\activate.bat
21 |
22 | echo Running app.py within the virtual environment...
23 | python app.py --image-folder ./images --beam-search-numbers 2
--------------------------------------------------------------------------------
/eval_configs/minigpt4.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | model:
7 | arch: mini_gpt4
8 | model_type: pretrain_vicuna
9 | freeze_vit: True
10 | freeze_qformer: True
11 | max_txt_len: 160
12 | end_sym: "###"
13 | prompt_path: "prompts/alignment.txt"
14 | prompt_template: '###Human: {} ###Assistant: '
15 | ckpt: 'checkpoint.pth'
16 |
17 |
18 | datasets:
19 | cc_align:
20 | vis_processor:
21 | train:
22 | name: "blip2_image_eval"
23 | image_size: 224
24 | text_processor:
25 | train:
26 | name: "blip_caption"
27 |
28 | run:
29 | task: image_text_pretrain
30 |
31 |
--------------------------------------------------------------------------------
/minigpt4/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.common.registry import registry
9 | from minigpt4.tasks.base_task import BaseTask
10 | from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask
11 |
12 |
13 | def setup_task(cfg):
14 | assert "task" in cfg.run_cfg, "Task name must be provided."
15 |
16 | task_name = cfg.run_cfg.task
17 | task = registry.get_task_class(task_name).setup_task(cfg=cfg)
18 | assert task is not None, "Task {} not properly registered.".format(task_name)
19 |
20 | return task
21 |
22 |
23 | __all__ = [
24 | "BaseTask",
25 | "ImageTextPretrainTask",
26 | ]
27 |
--------------------------------------------------------------------------------
/minigpt4/common/gradcam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from scipy.ndimage import filters
4 | from skimage import transform as skimage_transform
5 |
6 |
7 | def getAttMap(img, attMap, blur=True, overlap=True):
8 | attMap -= attMap.min()
9 | if attMap.max() > 0:
10 | attMap /= attMap.max()
11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12 | if blur:
13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14 | attMap -= attMap.min()
15 | attMap /= attMap.max()
16 | cmap = plt.get_cmap("jet")
17 | attMapV = cmap(attMap)
18 | attMapV = np.delete(attMapV, 3, 2)
19 | if overlap:
20 | attMap = (
21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23 | )
24 | return attMap
25 |
--------------------------------------------------------------------------------
/minigpt4/processors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.processors.base_processor import BaseProcessor
9 | from minigpt4.processors.blip_processors import (
10 | Blip2ImageTrainProcessor,
11 | Blip2ImageEvalProcessor,
12 | BlipCaptionProcessor,
13 | )
14 |
15 | from minigpt4.common.registry import registry
16 |
17 | __all__ = [
18 | "BaseProcessor",
19 | "Blip2ImageTrainProcessor",
20 | "Blip2ImageEvalProcessor",
21 | "BlipCaptionProcessor",
22 | ]
23 |
24 |
25 | def load_processor(name, cfg=None):
26 | """
27 | Example
28 |
29 | >>> processor = load_processor("alpro_video_train", cfg=None)
30 | """
31 | processor = registry.get_processor_class(name).from_config(cfg)
32 |
33 | return processor
34 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | model:
7 | arch: mini_gpt4
8 |
9 | # vit encoder
10 | image_size: 224
11 | drop_path_rate: 0
12 | use_grad_checkpoint: False
13 | vit_precision: "fp16"
14 | freeze_vit: True
15 | freeze_qformer: True
16 |
17 | # Q-Former
18 | num_query_token: 32
19 |
20 | # Vicuna
21 | llama_model: "vicuna"
22 |
23 | # generation configs
24 | prompt: ""
25 |
26 |
27 | preprocess:
28 | vis_processor:
29 | train:
30 | name: "blip2_image_train"
31 | image_size: 224
32 | eval:
33 | name: "blip2_image_eval"
34 | image_size: 224
35 | text_processor:
36 | train:
37 | name: "blip_caption"
38 | eval:
39 | name: "blip_caption"
40 |
--------------------------------------------------------------------------------
/minigpt4/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | import sys
10 |
11 | from omegaconf import OmegaConf
12 |
13 | from minigpt4.common.registry import registry
14 |
15 | from minigpt4.datasets.builders import *
16 | from minigpt4.models import *
17 | from minigpt4.processors import *
18 | from minigpt4.tasks import *
19 |
20 |
21 | root_dir = os.path.dirname(os.path.abspath(__file__))
22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23 |
24 | registry.register_path("library_root", root_dir)
25 | repo_root = os.path.join(root_dir, "..")
26 | registry.register_path("repo_root", repo_root)
27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28 | registry.register_path("cache_root", cache_root)
29 |
30 | registry.register("MAX_INT", sys.maxsize)
31 | registry.register("SPLIT_NAMES", ["train", "val", "test"])
32 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/laion_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import webdataset as wds
9 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
10 |
11 |
12 | class LaionDataset(BaseDataset):
13 | def __init__(self, vis_processor, text_processor, location):
14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15 |
16 | self.inner_dataset = wds.DataPipeline(
17 | wds.ResampledShards(location),
18 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
19 | wds.shuffle(1000, handler=wds.warn_and_continue),
20 | wds.decode("pilrgb", handler=wds.warn_and_continue),
21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23 | wds.map(self.to_dict, handler=wds.warn_and_continue),
24 | )
25 |
26 | def to_dict(self, sample):
27 | return {
28 | "image": sample[0],
29 | "text_input": self.text_processor(sample[1]["caption"]),
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/train_configs/minigpt4_stage1_laion.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | model:
7 | arch: mini_gpt4
8 | model_type: pretrain_vicuna
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 |
13 | datasets:
14 | laion:
15 | vis_processor:
16 | train:
17 | name: "blip2_image_train"
18 | image_size: 224
19 | text_processor:
20 | train:
21 | name: "blip_caption"
22 | sample_ratio: 115
23 | cc_combine:
24 | vis_processor:
25 | train:
26 | name: "blip2_image_train"
27 | image_size: 224
28 | text_processor:
29 | train:
30 | name: "blip_caption"
31 | sample_ratio: 14
32 |
33 |
34 | run:
35 | task: image_text_pretrain
36 | # optimizer
37 | lr_sched: "linear_warmup_cosine_lr"
38 | init_lr: 1e-4
39 | min_lr: 3e-5
40 | warmup_lr: 1e-6
41 |
42 | weight_decay: 0.05
43 | max_epoch: 4
44 | batch_size_train: 64
45 | batch_size_eval: 64
46 | num_workers: 4
47 | warmup_steps: 5000
48 | iters_per_epoch: 5000
49 |
50 | seed: 42
51 | output_dir: "/path/to/save/your/model/"
52 |
53 | amp: True
54 | resume_ckpt_path: null
55 |
56 | evaluate: False
57 | train_splits: ["train"]
58 |
59 | device: "cuda"
60 | world_size: 1
61 | dist_url: "env://"
62 | distributed: True
--------------------------------------------------------------------------------
/train_configs/minigpt4_stage2_align.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | model:
7 | arch: mini_gpt4
8 | model_type: pretrain_vicuna
9 | freeze_vit: True
10 | freeze_qformer: True
11 | max_txt_len: 160
12 | end_sym: "###"
13 | prompt_path: "prompts/alignment.txt"
14 | prompt_template: '###Human: {} ###Assistant: '
15 | ckpt: '/ibex/project/c2133/vicuna_jun_checkpoint_wihtout_prompt/20230412162/checkpoint_3.pth'
16 |
17 |
18 | datasets:
19 | cc_align:
20 | vis_processor:
21 | train:
22 | name: "blip2_image_train"
23 | image_size: 224
24 | text_processor:
25 | train:
26 | name: "blip_caption"
27 |
28 | run:
29 | task: image_text_pretrain
30 | # optimizer
31 | lr_sched: "linear_warmup_cosine_lr"
32 | init_lr: 3e-5
33 | min_lr: 1e-5
34 | warmup_lr: 1e-6
35 |
36 | weight_decay: 0.05
37 | max_epoch: 5
38 | iters_per_epoch: 200
39 | batch_size_train: 12
40 | batch_size_eval: 12
41 | num_workers: 4
42 | warmup_steps: 200
43 |
44 | seed: 42
45 | output_dir: "/ibex/project/c2133/vicuna_ckpt_test/minigpt4_stage2_align"
46 |
47 | amp: True
48 | resume_ckpt_path: null
49 |
50 | evaluate: False
51 | train_splits: ["train"]
52 |
53 | device: "cuda"
54 | world_size: 1
55 | dist_url: "env://"
56 | distributed: True
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022 Salesforce, Inc.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11 |
12 | 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tflite filter=lfs diff=lfs merge=lfs -text
29 | *.tgz filter=lfs diff=lfs merge=lfs -text
30 | *.wasm filter=lfs diff=lfs merge=lfs -text
31 | *.xz filter=lfs diff=lfs merge=lfs -text
32 | *.zip filter=lfs diff=lfs merge=lfs -text
33 | *.zst filter=lfs diff=lfs merge=lfs -text
34 | *tfevents* filter=lfs diff=lfs merge=lfs -text
35 | MiniGPT_4.pdf filter=lfs diff=lfs merge=lfs -text
36 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/cc_combine_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import os
8 | from PIL import Image
9 | import webdataset as wds
10 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
11 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
12 |
13 |
14 | class CCCombineDataset(BaseDataset):
15 | def __init__(self, vis_processor, text_processor, location):
16 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
17 |
18 | self.inner_dataset = wds.DataPipeline(
19 | wds.ResampledShards(location),
20 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
21 | wds.shuffle(1000, handler=wds.warn_and_continue),
22 | wds.decode("pilrgb", handler=wds.warn_and_continue),
23 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
24 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
25 | wds.map(self.to_dict, handler=wds.warn_and_continue),
26 | )
27 |
28 | def to_dict(self, sample):
29 | return {
30 | "image": sample[0],
31 | "text_input": self.text_processor(sample[1]["caption"]),
32 | }
33 |
34 |
35 | class CCAlignDataset(CaptionDataset):
36 |
37 | def __getitem__(self, index):
38 |
39 | # TODO this assumes image input, not general enough
40 | ann = self.annotation[index]
41 |
42 | img_file = '{}.jpg'.format(ann["image_id"])
43 | image_path = os.path.join(self.vis_root, img_file)
44 | image = Image.open(image_path).convert("RGB")
45 |
46 | image = self.vis_processor(image)
47 | caption = ann["caption"]
48 |
49 | return {
50 | "image": image,
51 | "text_input": caption,
52 | "image_id": self.img_ids[ann["image_id"]],
53 | }
--------------------------------------------------------------------------------
/setup.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set PYTHON_VER=3.10.11
4 |
5 | :: Check if Python version meets the recommended version
6 | python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
7 | if errorlevel 1 (
8 | echo Warning: Python version %PYTHON_VER% is recommended.
9 | )
10 |
11 | IF NOT EXIST venv (
12 | echo Creating venv...
13 | python -m venv venv
14 | )
15 |
16 | :: Deactivate the virtual environment
17 | call .\venv\Scripts\deactivate.bat
18 |
19 | :: Activate the virtual environment
20 | call .\venv\Scripts\activate.bat
21 |
22 | echo Installing Torch and torchvision...
23 | pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
24 |
25 | echo Cloning and installing bitsandbytes-windows...
26 | git clone https://github.com/Keith-Hon/bitsandbytes-windows.git
27 | cd bitsandbytes-windows
28 | pip3 install -e .
29 | cd ..
30 |
31 | echo Downloading pretrained models...
32 | curl -L -o ./checkpoint.pth https://huggingface.co/ckpt/minigpt4-7B/resolve/main/prerained_minigpt4_7b.pth
33 | curl -L -o ./blip2_pretrained_flant5xxl.pth https://huggingface.co/ckpt/minigpt4/resolve/main/blip2_pretrained_flant5xxl.pth
34 | curl -L -o ./models.zip https://huggingface.co/pipyp/minigpt4py/resolve/main/models.zip
35 | python extract.py
36 |
37 | echo Installing cmake, lit, salesforce-lavis, accelerate, and transformers...
38 | pip install cmake
39 | pip install lit
40 | pip install -q salesforce-lavis
41 | pip install -q accelerate
42 | pip install -q git+https://github.com/huggingface/transformers.git -U
43 |
44 | :: Adding the extra required libraries...
45 | echo Installing numpy, Pillow, OpenCV, tqdm, tensorflow, huggingface-hub, and keras...
46 | :: argparse, csv, os, random, glob, time, pathlib, and copy are part of the Python standard library and do not need to be pip-installed
47 | pip install numpy
48 | pip install Pillow
49 | pip install opencv-python-headless
50 | pip install tqdm
51 | pip install tensorflow
52 | pip install huggingface-hub
53 | pip install keras
54 |
55 | echo Setup complete within virtual environment!
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
9 | from minigpt4.datasets.builders.image_text_pair_builder import (
10 | CCCombineBuilder,
11 | LaionBuilder,
12 | CCAlignBuilder
13 | )
14 | from minigpt4.common.registry import registry
15 |
16 | __all__ = [
17 | "CCCombineBuilder",
18 | "LaionBuilder",
19 | "CCAlignBuilder"
20 | ]
21 |
22 |
23 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
24 | """
25 | Example
26 |
27 | >>> dataset = load_dataset("coco_caption", cfg=None)
28 | >>> splits = dataset.keys()
29 | >>> print([len(dataset[split]) for split in splits])
30 |
31 | """
32 | if cfg_path is None:
33 | cfg = None
34 | else:
35 | cfg = load_dataset_config(cfg_path)
36 |
37 | try:
38 | builder = registry.get_builder_class(name)(cfg)
39 | except TypeError:
40 | print(
41 | f"Dataset {name} not found. Available datasets:\n"
42 | + ", ".join([str(k) for k in dataset_zoo.get_names()])
43 | )
44 | exit(1)
45 |
46 | if vis_path is not None:
47 | if data_type is None:
48 | # use default data type in the config
49 | data_type = builder.config.data_type
50 |
51 | assert (
52 | data_type in builder.config.build_info
53 | ), f"Invalid data_type {data_type} for {name}."
54 |
55 | builder.config.build_info.get(data_type).storage = vis_path
56 |
57 | dataset = builder.build_datasets()
58 | return dataset
59 |
60 |
61 | class DatasetZoo:
62 | def __init__(self) -> None:
63 | self.dataset_zoo = {
64 | k: list(v.DATASET_CONFIG_DICT.keys())
65 | for k, v in sorted(registry.mapping["builder_name_mapping"].items())
66 | }
67 |
68 | def get_names(self):
69 | return list(self.dataset_zoo.keys())
70 |
71 |
72 | dataset_zoo = DatasetZoo()
73 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import json
9 | from typing import Iterable
10 |
11 | from torch.utils.data import Dataset, ConcatDataset
12 | from torch.utils.data.dataloader import default_collate
13 |
14 |
15 | class BaseDataset(Dataset):
16 | def __init__(
17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
18 | ):
19 | """
20 | vis_root (string): Root directory of images (e.g. coco/images/)
21 | ann_root (string): directory to store the annotation file
22 | """
23 | self.vis_root = vis_root
24 |
25 | self.annotation = []
26 | for ann_path in ann_paths:
27 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28 |
29 | self.vis_processor = vis_processor
30 | self.text_processor = text_processor
31 |
32 | self._add_instance_ids()
33 |
34 | def __len__(self):
35 | return len(self.annotation)
36 |
37 | def collater(self, samples):
38 | return default_collate(samples)
39 |
40 | def set_processors(self, vis_processor, text_processor):
41 | self.vis_processor = vis_processor
42 | self.text_processor = text_processor
43 |
44 | def _add_instance_ids(self, key="instance_id"):
45 | for idx, ann in enumerate(self.annotation):
46 | ann[key] = str(idx)
47 |
48 |
49 | class ConcatDataset(ConcatDataset):
50 | def __init__(self, datasets: Iterable[Dataset]) -> None:
51 | super().__init__(datasets)
52 |
53 | def collater(self, samples):
54 | # TODO For now only supports datasets with same underlying collater implementations
55 |
56 | all_keys = set()
57 | for s in samples:
58 | all_keys.update(s)
59 |
60 | shared_keys = all_keys
61 | for s in samples:
62 | shared_keys = shared_keys & set(s.keys())
63 |
64 | samples_shared_keys = []
65 | for s in samples:
66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67 |
68 | return self.datasets[0].collater(samples_shared_keys)
69 |
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/image_text_pair_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 |
10 | from minigpt4.common.registry import registry
11 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
12 | from minigpt4.datasets.datasets.laion_dataset import LaionDataset
13 | from minigpt4.datasets.datasets.cc_combine_dataset import CCCombineDataset, CCAlignDataset
14 |
15 |
16 | @registry.register_builder("cc_combine")
17 | class CCCombineBuilder(BaseDatasetBuilder):
18 | train_dataset_cls = CCCombineDataset
19 |
20 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_combine/defaults.yaml"}
21 |
22 | def _download_ann(self):
23 | pass
24 |
25 | def _download_vis(self):
26 | pass
27 |
28 | def build(self):
29 | self.build_processors()
30 |
31 | build_info = self.config.build_info
32 |
33 | datasets = dict()
34 | split = "train"
35 |
36 | # create datasets
37 | # [NOTE] return inner_datasets (wds.DataPipeline)
38 | dataset_cls = self.train_dataset_cls
39 | datasets[split] = dataset_cls(
40 | vis_processor=self.vis_processors[split],
41 | text_processor=self.text_processors[split],
42 | location=build_info.storage,
43 | ).inner_dataset
44 |
45 | return datasets
46 |
47 |
48 | @registry.register_builder("laion")
49 | class LaionBuilder(BaseDatasetBuilder):
50 | train_dataset_cls = LaionDataset
51 |
52 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
53 |
54 | def _download_ann(self):
55 | pass
56 |
57 | def _download_vis(self):
58 | pass
59 |
60 | def build(self):
61 | self.build_processors()
62 |
63 | build_info = self.config.build_info
64 |
65 | datasets = dict()
66 | split = "train"
67 |
68 | # create datasets
69 | # [NOTE] return inner_datasets (wds.DataPipeline)
70 | dataset_cls = self.train_dataset_cls
71 | datasets[split] = dataset_cls(
72 | vis_processor=self.vis_processors[split],
73 | text_processor=self.text_processors[split],
74 | location=build_info.storage,
75 | ).inner_dataset
76 |
77 | return datasets
78 |
79 |
80 | @registry.register_builder("cc_align")
81 | class CCAlignBuilder(BaseDatasetBuilder):
82 | train_dataset_cls = CCAlignDataset
83 |
84 | DATASET_CONFIG_DICT = {
85 | "default": "configs/datasets/cc_combine/align.yaml",
86 | }
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/caption_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | from collections import OrderedDict
10 |
11 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
12 | from PIL import Image
13 |
14 |
15 | class __DisplMixin:
16 | def displ_item(self, index):
17 | sample, ann = self.__getitem__(index), self.annotation[index]
18 |
19 | return OrderedDict(
20 | {
21 | "file": ann["image"],
22 | "caption": ann["caption"],
23 | "image": sample["image"],
24 | }
25 | )
26 |
27 |
28 | class CaptionDataset(BaseDataset, __DisplMixin):
29 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
30 | """
31 | vis_root (string): Root directory of images (e.g. coco/images/)
32 | ann_root (string): directory to store the annotation file
33 | """
34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
35 |
36 | self.img_ids = {}
37 | n = 0
38 | for ann in self.annotation:
39 | img_id = ann["image_id"]
40 | if img_id not in self.img_ids.keys():
41 | self.img_ids[img_id] = n
42 | n += 1
43 |
44 | def __getitem__(self, index):
45 |
46 | # TODO this assumes image input, not general enough
47 | ann = self.annotation[index]
48 |
49 | img_file = '{:0>12}.jpg'.format(ann["image_id"])
50 | image_path = os.path.join(self.vis_root, img_file)
51 | image = Image.open(image_path).convert("RGB")
52 |
53 | image = self.vis_processor(image)
54 | caption = self.text_processor(ann["caption"])
55 |
56 | return {
57 | "image": image,
58 | "text_input": caption,
59 | "image_id": self.img_ids[ann["image_id"]],
60 | }
61 |
62 |
63 | class CaptionEvalDataset(BaseDataset, __DisplMixin):
64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
65 | """
66 | vis_root (string): Root directory of images (e.g. coco/images/)
67 | ann_root (string): directory to store the annotation file
68 | split (string): val or test
69 | """
70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
71 |
72 | def __getitem__(self, index):
73 |
74 | ann = self.annotation[index]
75 |
76 | image_path = os.path.join(self.vis_root, ann["image"])
77 | image = Image.open(image_path).convert("RGB")
78 |
79 | image = self.vis_processor(image)
80 |
81 | return {
82 | "image": image,
83 | "image_id": ann["image_id"],
84 | "instance_id": ann["instance_id"],
85 | }
86 |
--------------------------------------------------------------------------------
/minigpt4/common/optims.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import math
9 |
10 | from minigpt4.common.registry import registry
11 |
12 |
13 | @registry.register_lr_scheduler("linear_warmup_step_lr")
14 | class LinearWarmupStepLRScheduler:
15 | def __init__(
16 | self,
17 | optimizer,
18 | max_epoch,
19 | min_lr,
20 | init_lr,
21 | decay_rate=1,
22 | warmup_start_lr=-1,
23 | warmup_steps=0,
24 | **kwargs
25 | ):
26 | self.optimizer = optimizer
27 |
28 | self.max_epoch = max_epoch
29 | self.min_lr = min_lr
30 |
31 | self.decay_rate = decay_rate
32 |
33 | self.init_lr = init_lr
34 | self.warmup_steps = warmup_steps
35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36 |
37 | def step(self, cur_epoch, cur_step):
38 | if cur_epoch == 0:
39 | warmup_lr_schedule(
40 | step=cur_step,
41 | optimizer=self.optimizer,
42 | max_step=self.warmup_steps,
43 | init_lr=self.warmup_start_lr,
44 | max_lr=self.init_lr,
45 | )
46 | else:
47 | step_lr_schedule(
48 | epoch=cur_epoch,
49 | optimizer=self.optimizer,
50 | init_lr=self.init_lr,
51 | min_lr=self.min_lr,
52 | decay_rate=self.decay_rate,
53 | )
54 |
55 |
56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57 | class LinearWarmupCosineLRScheduler:
58 | def __init__(
59 | self,
60 | optimizer,
61 | max_epoch,
62 | iters_per_epoch,
63 | min_lr,
64 | init_lr,
65 | warmup_steps=0,
66 | warmup_start_lr=-1,
67 | **kwargs
68 | ):
69 | self.optimizer = optimizer
70 |
71 | self.max_epoch = max_epoch
72 | self.iters_per_epoch = iters_per_epoch
73 | self.min_lr = min_lr
74 |
75 | self.init_lr = init_lr
76 | self.warmup_steps = warmup_steps
77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78 |
79 | def step(self, cur_epoch, cur_step):
80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81 | if total_cur_step < self.warmup_steps:
82 | warmup_lr_schedule(
83 | step=cur_step,
84 | optimizer=self.optimizer,
85 | max_step=self.warmup_steps,
86 | init_lr=self.warmup_start_lr,
87 | max_lr=self.init_lr,
88 | )
89 | else:
90 | cosine_lr_schedule(
91 | epoch=total_cur_step,
92 | optimizer=self.optimizer,
93 | max_epoch=self.max_epoch * self.iters_per_epoch,
94 | init_lr=self.init_lr,
95 | min_lr=self.min_lr,
96 | )
97 |
98 |
99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100 | """Decay the learning rate"""
101 | lr = (init_lr - min_lr) * 0.5 * (
102 | 1.0 + math.cos(math.pi * epoch / max_epoch)
103 | ) + min_lr
104 | for param_group in optimizer.param_groups:
105 | param_group["lr"] = lr
106 |
107 |
108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109 | """Warmup the learning rate"""
110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111 | for param_group in optimizer.param_groups:
112 | param_group["lr"] = lr
113 |
114 |
115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116 | """Decay the learning rate"""
117 | lr = max(min_lr, init_lr * (decay_rate**epoch))
118 | for param_group in optimizer.param_groups:
119 | param_group["lr"] = lr
120 |
--------------------------------------------------------------------------------
/minigpt4/common/dist_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import functools
10 | import os
11 |
12 | import torch
13 | import torch.distributed as dist
14 | import timm.models.hub as timm_hub
15 |
16 |
17 | def setup_for_distributed(is_master):
18 | """
19 | This function disables printing when not in master process
20 | """
21 | import builtins as __builtin__
22 |
23 | builtin_print = __builtin__.print
24 |
25 | def print(*args, **kwargs):
26 | force = kwargs.pop("force", False)
27 | if is_master or force:
28 | builtin_print(*args, **kwargs)
29 |
30 | __builtin__.print = print
31 |
32 |
33 | def is_dist_avail_and_initialized():
34 | if not dist.is_available():
35 | return False
36 | if not dist.is_initialized():
37 | return False
38 | return True
39 |
40 |
41 | def get_world_size():
42 | if not is_dist_avail_and_initialized():
43 | return 1
44 | return dist.get_world_size()
45 |
46 |
47 | def get_rank():
48 | if not is_dist_avail_and_initialized():
49 | return 0
50 | return dist.get_rank()
51 |
52 |
53 | def is_main_process():
54 | return get_rank() == 0
55 |
56 |
57 | def init_distributed_mode(args):
58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59 | args.rank = int(os.environ["RANK"])
60 | args.world_size = int(os.environ["WORLD_SIZE"])
61 | args.gpu = int(os.environ["LOCAL_RANK"])
62 | elif "SLURM_PROCID" in os.environ:
63 | args.rank = int(os.environ["SLURM_PROCID"])
64 | args.gpu = args.rank % torch.cuda.device_count()
65 | else:
66 | print("Not using distributed mode")
67 | args.distributed = False
68 | return
69 |
70 | args.distributed = True
71 |
72 | torch.cuda.set_device(args.gpu)
73 | args.dist_backend = "nccl"
74 | print(
75 | "| distributed init (rank {}, world {}): {}".format(
76 | args.rank, args.world_size, args.dist_url
77 | ),
78 | flush=True,
79 | )
80 | torch.distributed.init_process_group(
81 | backend=args.dist_backend,
82 | init_method=args.dist_url,
83 | world_size=args.world_size,
84 | rank=args.rank,
85 | timeout=datetime.timedelta(
86 | days=365
87 | ), # allow auto-downloading and de-compressing
88 | )
89 | torch.distributed.barrier()
90 | setup_for_distributed(args.rank == 0)
91 |
92 |
93 | def get_dist_info():
94 | if torch.__version__ < "1.0":
95 | initialized = dist._initialized
96 | else:
97 | initialized = dist.is_initialized()
98 | if initialized:
99 | rank = dist.get_rank()
100 | world_size = dist.get_world_size()
101 | else: # non-distributed training
102 | rank = 0
103 | world_size = 1
104 | return rank, world_size
105 |
106 |
107 | def main_process(func):
108 | @functools.wraps(func)
109 | def wrapper(*args, **kwargs):
110 | rank, _ = get_dist_info()
111 | if rank == 0:
112 | return func(*args, **kwargs)
113 |
114 | return wrapper
115 |
116 |
117 | def download_cached_file(url, check_hash=True, progress=False):
118 | """
119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121 | """
122 |
123 | def get_cached_file_path():
124 | # a hack to sync the file path across processes
125 | parts = torch.hub.urlparse(url)
126 | filename = os.path.basename(parts.path)
127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128 |
129 | return cached_file
130 |
131 | if is_main_process():
132 | timm_hub.download_cached_file(url, check_hash, progress)
133 |
134 | if is_dist_avail_and_initialized():
135 | dist.barrier()
136 |
137 | return get_cached_file_path()
138 |
--------------------------------------------------------------------------------
/minigpt4/models/blip2_outputs.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from dataclasses import dataclass
9 | from typing import Optional
10 |
11 | import torch
12 | from transformers.modeling_outputs import (
13 | ModelOutput,
14 | BaseModelOutputWithPoolingAndCrossAttentions,
15 | CausalLMOutputWithCrossAttentions,
16 | )
17 |
18 |
19 | @dataclass
20 | class BlipSimilarity(ModelOutput):
21 | sim_i2t: torch.FloatTensor = None
22 | sim_t2i: torch.FloatTensor = None
23 |
24 | sim_i2t_m: Optional[torch.FloatTensor] = None
25 | sim_t2i_m: Optional[torch.FloatTensor] = None
26 |
27 | sim_i2t_targets: Optional[torch.FloatTensor] = None
28 | sim_t2i_targets: Optional[torch.FloatTensor] = None
29 |
30 |
31 | @dataclass
32 | class BlipIntermediateOutput(ModelOutput):
33 | """
34 | Data class for intermediate outputs of BLIP models.
35 |
36 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38 |
39 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41 |
42 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44 |
45 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46 | decoder_labels (torch.LongTensor): labels for the captioning loss.
47 |
48 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50 |
51 | """
52 |
53 | # uni-modal features
54 | image_embeds: torch.FloatTensor = None
55 | text_embeds: Optional[torch.FloatTensor] = None
56 |
57 | image_embeds_m: Optional[torch.FloatTensor] = None
58 | text_embeds_m: Optional[torch.FloatTensor] = None
59 |
60 | # intermediate outputs of multimodal encoder
61 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63 |
64 | itm_logits: Optional[torch.FloatTensor] = None
65 | itm_labels: Optional[torch.LongTensor] = None
66 |
67 | # intermediate outputs of multimodal decoder
68 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69 | decoder_labels: Optional[torch.LongTensor] = None
70 |
71 |
72 | @dataclass
73 | class BlipOutput(ModelOutput):
74 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75 | sims: Optional[BlipSimilarity] = None
76 |
77 | intermediate_output: BlipIntermediateOutput = None
78 |
79 | loss: Optional[torch.FloatTensor] = None
80 |
81 | loss_itc: Optional[torch.FloatTensor] = None
82 |
83 | loss_itm: Optional[torch.FloatTensor] = None
84 |
85 | loss_lm: Optional[torch.FloatTensor] = None
86 |
87 |
88 | @dataclass
89 | class BlipOutputFeatures(ModelOutput):
90 | """
91 | Data class of features from BlipFeatureExtractor.
92 |
93 | Args:
94 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
95 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
96 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
97 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
98 |
99 | The first embedding or feature is for the [CLS] token.
100 |
101 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
102 | """
103 |
104 | image_embeds: Optional[torch.FloatTensor] = None
105 | image_embeds_proj: Optional[torch.FloatTensor] = None
106 |
107 | text_embeds: Optional[torch.FloatTensor] = None
108 | text_embeds_proj: Optional[torch.FloatTensor] = None
109 |
110 | multimodal_embeds: Optional[torch.FloatTensor] = None
111 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import glob
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | from PIL import Image
11 |
12 | from minigpt4.common.config import Config
13 | from minigpt4.common.dist_utils import get_rank
14 | from minigpt4.common.registry import registry
15 | from minigpt4.conversation.conversation import Chat, CONV_VISION
16 |
17 | # imports modules for registration
18 | from minigpt4.datasets.builders import *
19 | from minigpt4.models import *
20 | from minigpt4.processors import *
21 | from minigpt4.runners import *
22 | from minigpt4.tasks import *
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(description="Demo")
27 | parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
28 | parser.add_argument(
29 | "--options",
30 | nargs="+",
31 | help="override some settings in the used config, the key-value pair "
32 | "in xxx=yyy format will be merged into config file (deprecate), "
33 | "change to --cfg-options instead.",
34 | )
35 | parser.add_argument("--image-folder", type=str, required=True, help="path to the input image folder")
36 | parser.add_argument("--beam-search-numbers", type=int, default=1, help="beam search numbers")
37 | parser.add_argument("--model", type=str, default='llama', help="Model to be used for generation. Options: 'llama' (default), 'llama7b'")
38 | parser.add_argument("--save-in-imgfolder", action="store_true", help="save captions in the input image folder")
39 | options = parser.parse_args()
40 | return options
41 |
42 |
43 | def setup_seeds(config):
44 | seed = config.run_cfg.seed + get_rank()
45 |
46 | random.seed(seed)
47 | np.random.seed(seed)
48 | torch.manual_seed(seed)
49 |
50 | cudnn.benchmark = False
51 | cudnn.deterministic = True
52 |
53 |
54 | def describe_image(image_path, chat, chat_state, img, num_beams=1, temperature=1.0):
55 | chat_state = CONV_VISION.copy()
56 | img_list = []
57 |
58 | gr_img = Image.open(image_path)
59 | llm_message = chat.upload_img(gr_img, chat_state, img_list)
60 |
61 | chat.ask("Describe this image.", chat_state)
62 | generated_caption = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=num_beams, temperature=temperature, max_length=2000)[0]
63 |
64 | return generated_caption
65 |
66 |
67 | if __name__ == '__main__':
68 | args = parse_args()
69 |
70 | cfg = Config(args)
71 |
72 | model_config = cfg.model_cfg
73 | if args.model == "llama7b":
74 | model_config.llama_model = "camenduru/MiniGPT4-7B"
75 |
76 | model_cls = registry.get_model_class(model_config.arch)
77 | model = model_cls.from_config(model_config).to('cuda:0')
78 |
79 | vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
80 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
81 | chat = Chat(model, vis_processor)
82 |
83 | chat_state = CONV_VISION.copy()
84 | img_list = []
85 |
86 | image_folder = args.image_folder
87 | num_beams = args.beam_search_numbers
88 | temperature = 1.0 # default temperature
89 |
90 | image_extensions = ['jpg', 'jpeg', 'png', 'bmp', "webp"]
91 | image_paths = []
92 |
93 | for ext in image_extensions:
94 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext}')))
95 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext.upper()}')))
96 |
97 | if not args.save_in_imgfolder:
98 | if not os.path.exists("mycaptions"):
99 | os.makedirs("mycaptions")
100 |
101 | for image_path in image_paths:
102 | start_time = time.time()
103 | caption = describe_image(image_path, chat, chat_state, img_list, num_beams, temperature)
104 |
105 | if args.save_in_imgfolder:
106 | output_path = os.path.join(image_folder, "{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0]))
107 | else:
108 | output_path = "mycaptions/{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0])
109 |
110 | with open(output_path, "w") as f:
111 | f.write(caption)
112 |
113 | end_time = time.time()
114 | time_taken = end_time - start_time
115 | print(f"Caption for {os.path.basename(image_path)} saved in '{output_path}'")
116 | print(f"Time taken to process caption for {os.path.basename(image_path)} is: {time_taken:.2f} s")
--------------------------------------------------------------------------------
/minigpt4/processors/blip_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import re
9 |
10 | from minigpt4.common.registry import registry
11 | from minigpt4.processors.base_processor import BaseProcessor
12 | from minigpt4.processors.randaugment import RandomAugment
13 | from omegaconf import OmegaConf
14 | from torchvision import transforms
15 | from torchvision.transforms.functional import InterpolationMode
16 |
17 |
18 | class BlipImageBaseProcessor(BaseProcessor):
19 | def __init__(self, mean=None, std=None):
20 | if mean is None:
21 | mean = (0.48145466, 0.4578275, 0.40821073)
22 | if std is None:
23 | std = (0.26862954, 0.26130258, 0.27577711)
24 |
25 | self.normalize = transforms.Normalize(mean, std)
26 |
27 |
28 | @registry.register_processor("blip_caption")
29 | class BlipCaptionProcessor(BaseProcessor):
30 | def __init__(self, prompt="", max_words=50):
31 | self.prompt = prompt
32 | self.max_words = max_words
33 |
34 | def __call__(self, caption):
35 | caption = self.prompt + self.pre_caption(caption)
36 |
37 | return caption
38 |
39 | @classmethod
40 | def from_config(cls, cfg=None):
41 | if cfg is None:
42 | cfg = OmegaConf.create()
43 |
44 | prompt = cfg.get("prompt", "")
45 | max_words = cfg.get("max_words", 50)
46 |
47 | return cls(prompt=prompt, max_words=max_words)
48 |
49 | def pre_caption(self, caption):
50 | caption = re.sub(
51 | r"([.!\"()*#:;~])",
52 | " ",
53 | caption.lower(),
54 | )
55 | caption = re.sub(
56 | r"\s{2,}",
57 | " ",
58 | caption,
59 | )
60 | caption = caption.rstrip("\n")
61 | caption = caption.strip(" ")
62 |
63 | # truncate caption
64 | caption_words = caption.split(" ")
65 | if len(caption_words) > self.max_words:
66 | caption = " ".join(caption_words[: self.max_words])
67 |
68 | return caption
69 |
70 |
71 | @registry.register_processor("blip2_image_train")
72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
74 | super().__init__(mean=mean, std=std)
75 |
76 | self.transform = transforms.Compose(
77 | [
78 | transforms.RandomResizedCrop(
79 | image_size,
80 | scale=(min_scale, max_scale),
81 | interpolation=InterpolationMode.BICUBIC,
82 | ),
83 | transforms.ToTensor(),
84 | self.normalize,
85 | ]
86 | )
87 |
88 | def __call__(self, item):
89 | return self.transform(item)
90 |
91 | @classmethod
92 | def from_config(cls, cfg=None):
93 | if cfg is None:
94 | cfg = OmegaConf.create()
95 |
96 | image_size = cfg.get("image_size", 224)
97 |
98 | mean = cfg.get("mean", None)
99 | std = cfg.get("std", None)
100 |
101 | min_scale = cfg.get("min_scale", 0.5)
102 | max_scale = cfg.get("max_scale", 1.0)
103 |
104 | return cls(
105 | image_size=image_size,
106 | mean=mean,
107 | std=std,
108 | min_scale=min_scale,
109 | max_scale=max_scale,
110 | )
111 |
112 |
113 | @registry.register_processor("blip2_image_eval")
114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
115 | def __init__(self, image_size=224, mean=None, std=None):
116 | super().__init__(mean=mean, std=std)
117 |
118 | self.transform = transforms.Compose(
119 | [
120 | transforms.Resize(
121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
122 | ),
123 | transforms.ToTensor(),
124 | self.normalize,
125 | ]
126 | )
127 |
128 | def __call__(self, item):
129 | return self.transform(item)
130 |
131 | @classmethod
132 | def from_config(cls, cfg=None):
133 | if cfg is None:
134 | cfg = OmegaConf.create()
135 |
136 | image_size = cfg.get("image_size", 224)
137 |
138 | mean = cfg.get("mean", None)
139 | std = cfg.get("std", None)
140 |
141 | return cls(image_size=image_size, mean=mean, std=std)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Welcome to the MiniGPT-4 Batch repo! This repository provides an implementation of MiniGPT-4 for mass-captioning images for Stable Diffusion training. It uses LLaMA weights that are downloaded automatically if not already present. Please note that this implementation needs a high-end machine (it will not run on the free Colab tier); it was originally Linux-only, but Windows is now supported via the setup script below.
2 |
3 | ## Windows Installation Instructions
4 |
5 | To install and run MiniGPT-4 Batch on Windows, please follow these steps:
6 |
7 | 1. Run the `setup.bat` script:
8 |
9 | ```
10 | setup.bat
11 | ```
12 |
13 | 2. Check your `/images` and `/mycaptions` folders. In the `/images` folder, one sample image is provided; feel free to delete it.
14 |
15 | 3. If you're just testing it out, simply execute the `run.bat` script.
16 |
17 | **OR**
18 |
19 | 3. If you want to run the script manually, you need to:
20 |
21 | a. Activate the virtual environment:
22 |
23 | ```
24 | .\venv\Scripts\activate.bat
25 | ```
26 |
27 | b. Run the `app.py` script with the desired options:
28 |
29 | ```
30 | python app.py --image-folder ./images --beam-search-numbers 2
31 | ```
32 |
33 | NEW: We're testing a way to combine WD tags with minigpt4-batch. If you want to include WD tags along with the MiniGPT-4 captions, consider running `backup_app.py`. Note that WD tagging is currently mandatory in `backup_app.py`; we're working on making it optional!
34 |
35 | ```
36 | python backup_app.py --image-folder ./images --beam-search-numbers 2 --model-dir models/wd14_tagger --undesired-tags '1girl,1boy,solo'
37 | ```
38 |
39 | Now you're all set to use MiniGPT-4 Batch on Windows!
40 |
41 | ## Getting Started (LINUX)
42 |
43 | If you're installing MiniGPT-4 Batch for the first time, please follow these steps:
44 |
45 | 1. Clone the GitHub repository:
46 |
47 | ```
48 | git clone https://github.com/pipinstallyp/minigpt4-batch
49 | ```
50 | Change directory to minigpt4-batch:
51 |
52 | ```
53 | cd minigpt4-batch
54 | ```
55 | 2. Download the necessary files:
56 |
57 | ```
58 | wget https://huggingface.co/ckpt/minigpt4/resolve/main/minigpt4.pth -O ./checkpoint.pth
59 | wget https://huggingface.co/ckpt/minigpt4/resolve/main/blip2_pretrained_flant5xxl.pth -O ./blip2_pretrained_flant5xxl.pth
60 | ```
61 |
62 | For the 7B model, use these commands instead:
63 | ```
64 | wget https://huggingface.co/ckpt/minigpt4-7B/resolve/main/prerained_minigpt4_7b.pth -O ./checkpoint.pth
65 | wget https://huggingface.co/ckpt/minigpt4/resolve/main/blip2_pretrained_flant5xxl.pth -O ./blip2_pretrained_flant5xxl.pth
66 | ```
67 |
68 | To get this working, make sure the configured checkpoint path (for example `./minigpt4/checkpoint.pth`) points at the file you just downloaded, i.e. your minigpt4-batch directory plus `checkpoint.pth`.
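
If you prefer to keep the checkpoint somewhere else, you can download it straight to that location and point the configuration at it. A minimal sketch (the destination path below is only a placeholder for your own directory):

```
wget https://huggingface.co/ckpt/minigpt4/resolve/main/minigpt4.pth -O /full/path/to/minigpt4-batch/checkpoint.pth
```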
69 |
70 | 3. Install the required packages:
71 |
72 | ```
73 | pip install cmake
74 | pip install lit
75 | pip install -q salesforce-lavis
76 | pip install -q bitsandbytes
77 | pip install -q accelerate
78 | pip install -q git+https://github.com/huggingface/transformers.git -U
79 | ```
80 |
81 | 4. Now you can run the script:
82 |
83 | ```
84 | python app.py --image-folder path_to_image_folder --beam-search-numbers value
85 | ```
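
For each processed image you should see console output along these lines (the file name and timing are illustrative):

```
Caption for example.png saved in 'mycaptions/example_caption.txt'
Time taken to process caption for example.png is: 14.52 s
```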
86 |
87 | If you want to test the LLaMA 7B model, use this:
88 |
89 | ```
90 | python app.py --image-folder path_to_image_folder --beam-search-numbers 2 --model llama7b
91 | ```
92 |
93 | In your repository directory, create two folders:
94 | ```
95 | images
96 | mycaptions
97 | ```
98 |
99 | In this case, your `path_to_image_folder` is `images`; see the example below.
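
For example, a minimal end-to-end sketch (run from the repository root on Linux):

```
mkdir images mycaptions
python app.py --image-folder images --beam-search-numbers 2
```
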
100 | ## Features
101 | 1. Shows the time taken to process each caption.
102 | 2. Use `--save-in-imgfolder` to save captions next to the images in your image folder instead (see the example below).
103 | 3. One-click setup (`setup.bat`) for Windows.
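
For example, to write each caption next to its image instead of into `mycaptions/`:

```
python app.py --image-folder ./images --beam-search-numbers 2 --save-in-imgfolder
```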
104 |
105 | ## To-Do List
106 |
107 | - [x] ~~Make it work on Windows~~
108 | - [ ] Implement for MiniGPT-4 7B
109 | - [ ] Include inputs from Segment Anything
110 | - [ ] Docker support (coming soon!)
111 |
112 |
113 | ## Acknowledgment
114 |
115 | A huge thank you to [Camenduru](https://github.com/camenduru) for developing the awesome MiniGPT-4 Colab, which has served as the foundation for most of this work. Huge thanks to [rafraf](https://www.instagram.com/rafstahelin/) for shaping the features into what they are. This project is primarily aimed at helping people mass-caption their images for training Stable Diffusion models.
116 |
117 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
118 |
119 | Also check out the https://github.com/gessyoo/minigpt4-batch-tweaked fork, which removes trivial lead-in phrases such as "The image shows" and "The image is" from the captions and drops the `_caption` suffix from the output file names.
120 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/dataloader_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import time
9 | import random
10 | import torch
11 | from minigpt4.datasets.data_utils import move_to_cuda
12 | from torch.utils.data import DataLoader
13 |
14 |
15 | class MultiIterLoader:
16 | """
17 | A simple wrapper for iterating over multiple iterators.
18 |
19 | Args:
20 | loaders (List[Loader]): List of Iterator loaders.
21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22 | """
23 |
24 | def __init__(self, loaders, ratios=None):
25 | # assert all loaders has __next__ method
26 | for loader in loaders:
27 | assert hasattr(
28 | loader, "__next__"
29 | ), "Loader {} has no __next__ method.".format(loader)
30 |
31 | if ratios is None:
32 | ratios = [1.0] * len(loaders)
33 | else:
34 | assert len(ratios) == len(loaders)
35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36 |
37 | self.loaders = loaders
38 | self.ratios = ratios
39 |
40 | def __next__(self):
41 | # random sample from each loader by ratio
42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43 | return next(self.loaders[loader_idx])
44 |
45 |
46 | class PrefetchLoader(object):
47 | """
48 | Modified from https://github.com/ChenRocks/UNITER.
49 |
50 | overlap compute and cuda data transfer
51 | (copied and then modified from nvidia apex)
52 | """
53 |
54 | def __init__(self, loader):
55 | self.loader = loader
56 | self.stream = torch.cuda.Stream()
57 |
58 | def __iter__(self):
59 | loader_it = iter(self.loader)
60 | self.preload(loader_it)
61 | batch = self.next(loader_it)
62 | while batch is not None:
63 | is_tuple = isinstance(batch, tuple)
64 | if is_tuple:
65 | task, batch = batch
66 |
67 | if is_tuple:
68 | yield task, batch
69 | else:
70 | yield batch
71 | batch = self.next(loader_it)
72 |
73 | def __len__(self):
74 | return len(self.loader)
75 |
76 | def preload(self, it):
77 | try:
78 | self.batch = next(it)
79 | except StopIteration:
80 | self.batch = None
81 | return
82 | # if record_stream() doesn't work, another option is to make sure
83 | # device inputs are created on the main stream.
84 | # self.next_input_gpu = torch.empty_like(self.next_input,
85 | # device='cuda')
86 | # self.next_target_gpu = torch.empty_like(self.next_target,
87 | # device='cuda')
88 | # Need to make sure the memory allocated for next_* is not still in use
89 | # by the main stream at the time we start copying to next_*:
90 | # self.stream.wait_stream(torch.cuda.current_stream())
91 | with torch.cuda.stream(self.stream):
92 | self.batch = move_to_cuda(self.batch)
93 | # more code for the alternative if record_stream() doesn't work:
94 | # copy_ will record the use of the pinned source tensor in this
95 | # side stream.
96 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98 | # self.next_input = self.next_input_gpu
99 | # self.next_target = self.next_target_gpu
100 |
101 | def next(self, it):
102 | torch.cuda.current_stream().wait_stream(self.stream)
103 | batch = self.batch
104 | if batch is not None:
105 | record_cuda_stream(batch)
106 | self.preload(it)
107 | return batch
108 |
109 | def __getattr__(self, name):
110 | method = self.loader.__getattribute__(name)
111 | return method
112 |
113 |
114 | def record_cuda_stream(batch):
115 | if isinstance(batch, torch.Tensor):
116 | batch.record_stream(torch.cuda.current_stream())
117 | elif isinstance(batch, list) or isinstance(batch, tuple):
118 | for t in batch:
119 | record_cuda_stream(t)
120 | elif isinstance(batch, dict):
121 | for t in batch.values():
122 | record_cuda_stream(t)
123 | else:
124 | pass
125 |
126 |
127 | class IterLoader:
128 | """
129 | A wrapper to convert DataLoader as an infinite iterator.
130 |
131 | Modified from:
132 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133 | """
134 |
135 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136 | self._dataloader = dataloader
137 | self.iter_loader = iter(self._dataloader)
138 | self._use_distributed = use_distributed
139 | self._epoch = 0
140 |
141 | @property
142 | def epoch(self) -> int:
143 | return self._epoch
144 |
145 | def __next__(self):
146 | try:
147 | data = next(self.iter_loader)
148 | except StopIteration:
149 | self._epoch += 1
150 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151 | self._dataloader.sampler.set_epoch(self._epoch)
152 | time.sleep(2) # Prevent possible deadlock during epoch transition
153 | self.iter_loader = iter(self._dataloader)
154 | data = next(self.iter_loader)
155 |
156 | return data
157 |
158 | def __iter__(self):
159 | return self
160 |
161 | def __len__(self):
162 | return len(self._dataloader)
163 |
--------------------------------------------------------------------------------
/minigpt4/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import torch
10 | from omegaconf import OmegaConf
11 |
12 | from minigpt4.common.registry import registry
13 | from minigpt4.models.base_model import BaseModel
14 | from minigpt4.models.blip2 import Blip2Base
15 | from minigpt4.models.mini_gpt4 import MiniGPT4
16 | from minigpt4.processors.base_processor import BaseProcessor
17 |
18 |
19 | __all__ = [
20 | "load_model",
21 | "BaseModel",
22 | "Blip2Base",
23 | "MiniGPT4",
24 | ]
25 |
26 |
27 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
28 | """
29 | Load supported models.
30 |
31 | To list all available models and types in registry:
32 | >>> from minigpt4.models import model_zoo
33 | >>> print(model_zoo)
34 |
35 | Args:
36 | name (str): name of the model.
37 | model_type (str): type of the model.
38 | is_eval (bool): whether the model is in eval mode. Default: False.
39 | device (str): device to use. Default: "cpu".
40 |         checkpoint (str): path or URL to a checkpoint. Default: None.
41 |             Note that the checkpoint is expected to have the same state_dict keys as the model.
42 |
43 | Returns:
44 | model (torch.nn.Module): model.
45 | """
46 |
47 | model = registry.get_model_class(name).from_pretrained(model_type=model_type)
48 |
49 | if checkpoint is not None:
50 | model.load_checkpoint(checkpoint)
51 |
52 | if is_eval:
53 | model.eval()
54 |
55 | if device == "cpu":
56 | model = model.float()
57 |
58 | return model.to(device)
59 |
60 |
61 | def load_preprocess(config):
62 | """
63 | Load preprocessor configs and construct preprocessors.
64 |
65 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
66 |
67 | Args:
68 | config (dict): preprocessor configs.
69 |
70 | Returns:
71 | vis_processors (dict): preprocessors for visual inputs.
72 | txt_processors (dict): preprocessors for text inputs.
73 |
74 | Key is "train" or "eval" for processors used in training and evaluation respectively.
75 | """
76 |
77 | def _build_proc_from_cfg(cfg):
78 | return (
79 | registry.get_processor_class(cfg.name).from_config(cfg)
80 | if cfg is not None
81 | else BaseProcessor()
82 | )
83 |
84 | vis_processors = dict()
85 | txt_processors = dict()
86 |
87 | vis_proc_cfg = config.get("vis_processor")
88 | txt_proc_cfg = config.get("text_processor")
89 |
90 | if vis_proc_cfg is not None:
91 | vis_train_cfg = vis_proc_cfg.get("train")
92 | vis_eval_cfg = vis_proc_cfg.get("eval")
93 | else:
94 | vis_train_cfg = None
95 | vis_eval_cfg = None
96 |
97 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
98 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
99 |
100 | if txt_proc_cfg is not None:
101 | txt_train_cfg = txt_proc_cfg.get("train")
102 | txt_eval_cfg = txt_proc_cfg.get("eval")
103 | else:
104 | txt_train_cfg = None
105 | txt_eval_cfg = None
106 |
107 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
108 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
109 |
110 | return vis_processors, txt_processors
111 |
112 |
113 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
114 | """
115 | Load model and its related preprocessors.
116 |
117 | List all available models and types in registry:
118 | >>> from minigpt4.models import model_zoo
119 | >>> print(model_zoo)
120 |
121 | Args:
122 | name (str): name of the model.
123 | model_type (str): type of the model.
124 | is_eval (bool): whether the model is in eval mode. Default: False.
125 | device (str): device to use. Default: "cpu".
126 |
127 | Returns:
128 | model (torch.nn.Module): model.
129 | vis_processors (dict): preprocessors for visual inputs.
130 | txt_processors (dict): preprocessors for text inputs.
131 | """
132 | model_cls = registry.get_model_class(name)
133 |
134 | # load model
135 | model = model_cls.from_pretrained(model_type=model_type)
136 |
137 | if is_eval:
138 | model.eval()
139 |
140 | # load preprocess
141 | cfg = OmegaConf.load(model_cls.default_config_path(model_type))
142 | if cfg is not None:
143 | preprocess_cfg = cfg.preprocess
144 |
145 | vis_processors, txt_processors = load_preprocess(preprocess_cfg)
146 | else:
147 | vis_processors, txt_processors = None, None
148 | logging.info(
149 | f"""No default preprocess for model {name} ({model_type}).
150 | This can happen if the model is not finetuned on downstream datasets,
151 | or it is not intended for direct use without finetuning.
152 | """
153 | )
154 |
155 | if device == "cpu" or device == torch.device("cpu"):
156 | model = model.float()
157 |
158 | return model.to(device), vis_processors, txt_processors
159 |
160 |
161 | class ModelZoo:
162 | """
163 | A utility class to create string representation of available model architectures and types.
164 |
165 | >>> from minigpt4.models import model_zoo
166 | >>> # list all available models
167 | >>> print(model_zoo)
168 | >>> # show total number of models
169 | >>> print(len(model_zoo))
170 | """
171 |
172 | def __init__(self) -> None:
173 | self.model_zoo = {
174 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
175 | for k, v in registry.mapping["model_name_mapping"].items()
176 | }
177 |
178 | def __str__(self) -> str:
179 | return (
180 | "=" * 50
181 | + "\n"
182 | + f"{'Architectures':<30} {'Types'}\n"
183 | + "=" * 50
184 | + "\n"
185 | + "\n".join(
186 | [
187 | f"{name:<30} {', '.join(types)}"
188 | for name, types in self.model_zoo.items()
189 | ]
190 | )
191 | )
192 |
193 | def __iter__(self):
194 | return iter(self.model_zoo.items())
195 |
196 | def __len__(self):
197 | return sum([len(v) for v in self.model_zoo.values()])
198 |
199 |
200 | model_zoo = ModelZoo()
201 |
--------------------------------------------------------------------------------
/minigpt4/common/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import logging
10 | import time
11 | from collections import defaultdict, deque
12 |
13 | import torch
14 | import torch.distributed as dist
15 |
16 | from minigpt4.common import dist_utils
17 |
18 |
19 | class SmoothedValue(object):
20 | """Track a series of values and provide access to smoothed values over a
21 | window or the global series average.
22 | """
23 |
24 | def __init__(self, window_size=20, fmt=None):
25 | if fmt is None:
26 | fmt = "{median:.4f} ({global_avg:.4f})"
27 | self.deque = deque(maxlen=window_size)
28 | self.total = 0.0
29 | self.count = 0
30 | self.fmt = fmt
31 |
32 | def update(self, value, n=1):
33 | self.deque.append(value)
34 | self.count += n
35 | self.total += value * n
36 |
37 | def synchronize_between_processes(self):
38 | """
39 | Warning: does not synchronize the deque!
40 | """
41 | if not dist_utils.is_dist_avail_and_initialized():
42 | return
43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median,
75 | avg=self.avg,
76 | global_avg=self.global_avg,
77 | max=self.max,
78 | value=self.value,
79 | )
80 |
81 |
82 | class MetricLogger(object):
83 | def __init__(self, delimiter="\t"):
84 | self.meters = defaultdict(SmoothedValue)
85 | self.delimiter = delimiter
86 |
87 | def update(self, **kwargs):
88 | for k, v in kwargs.items():
89 | if isinstance(v, torch.Tensor):
90 | v = v.item()
91 | assert isinstance(v, (float, int))
92 | self.meters[k].update(v)
93 |
94 | def __getattr__(self, attr):
95 | if attr in self.meters:
96 | return self.meters[attr]
97 | if attr in self.__dict__:
98 | return self.__dict__[attr]
99 | raise AttributeError(
100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101 | )
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append("{}: {}".format(name, str(meter)))
107 | return self.delimiter.join(loss_str)
108 |
109 | def global_avg(self):
110 | loss_str = []
111 | for name, meter in self.meters.items():
112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113 | return self.delimiter.join(loss_str)
114 |
115 | def synchronize_between_processes(self):
116 | for meter in self.meters.values():
117 | meter.synchronize_between_processes()
118 |
119 | def add_meter(self, name, meter):
120 | self.meters[name] = meter
121 |
122 | def log_every(self, iterable, print_freq, header=None):
123 | i = 0
124 | if not header:
125 | header = ""
126 | start_time = time.time()
127 | end = time.time()
128 | iter_time = SmoothedValue(fmt="{avg:.4f}")
129 | data_time = SmoothedValue(fmt="{avg:.4f}")
130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131 | log_msg = [
132 | header,
133 | "[{0" + space_fmt + "}/{1}]",
134 | "eta: {eta}",
135 | "{meters}",
136 | "time: {time}",
137 | "data: {data}",
138 | ]
139 | if torch.cuda.is_available():
140 | log_msg.append("max mem: {memory:.0f}")
141 | log_msg = self.delimiter.join(log_msg)
142 | MB = 1024.0 * 1024.0
143 | for obj in iterable:
144 | data_time.update(time.time() - end)
145 | yield obj
146 | iter_time.update(time.time() - end)
147 | if i % print_freq == 0 or i == len(iterable) - 1:
148 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150 | if torch.cuda.is_available():
151 | print(
152 | log_msg.format(
153 | i,
154 | len(iterable),
155 | eta=eta_string,
156 | meters=str(self),
157 | time=str(iter_time),
158 | data=str(data_time),
159 | memory=torch.cuda.max_memory_allocated() / MB,
160 | )
161 | )
162 | else:
163 | print(
164 | log_msg.format(
165 | i,
166 | len(iterable),
167 | eta=eta_string,
168 | meters=str(self),
169 | time=str(iter_time),
170 | data=str(data_time),
171 | )
172 | )
173 | i += 1
174 | end = time.time()
175 | total_time = time.time() - start_time
176 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177 | print(
178 | "{} Total time: {} ({:.4f} s / it)".format(
179 | header, total_time_str, total_time / len(iterable)
180 | )
181 | )
182 |
183 |
184 | class AttrDict(dict):
185 | def __init__(self, *args, **kwargs):
186 | super(AttrDict, self).__init__(*args, **kwargs)
187 | self.__dict__ = self
188 |
189 |
190 | def setup_logger():
191 | logging.basicConfig(
192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193 | format="%(asctime)s [%(levelname)s] %(message)s",
194 | handlers=[logging.StreamHandler()],
195 | )
196 |
--------------------------------------------------------------------------------
/minigpt4/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import gzip
9 | import logging
10 | import os
11 | import random as rnd
12 | import tarfile
13 | import zipfile
14 | import random
15 | from typing import List
16 | from tqdm import tqdm
17 |
18 | import decord
19 | from decord import VideoReader
20 | import webdataset as wds
21 | import numpy as np
22 | import torch
23 | from torch.utils.data.dataset import IterableDataset
24 |
25 | from minigpt4.common.registry import registry
26 | from minigpt4.datasets.datasets.base_dataset import ConcatDataset
27 |
28 |
29 | decord.bridge.set_bridge("torch")
30 | MAX_INT = registry.get("MAX_INT")
31 |
32 |
33 | class ChainDataset(wds.DataPipeline):
34 | r"""Dataset for chaining multiple :class:`DataPipeline` s.
35 |
36 | This class is useful to assemble different existing dataset streams. The
37 | chaining operation is done on-the-fly, so concatenating large-scale
38 | datasets with this class will be efficient.
39 |
40 | Args:
41 | datasets (iterable of IterableDataset): datasets to be chained together
42 | """
43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44 | super().__init__()
45 | self.datasets = datasets
46 | self.prob = []
47 | self.names = []
48 | for dataset in self.datasets:
49 | if hasattr(dataset, 'name'):
50 | self.names.append(dataset.name)
51 | else:
52 | self.names.append('Unknown')
53 | if hasattr(dataset, 'sample_ratio'):
54 | self.prob.append(dataset.sample_ratio)
55 | else:
56 | self.prob.append(1)
57 |                 logging.info("One of the datapipelines doesn't define a sample ratio; defaulting it to 1.")
58 |
59 | def __iter__(self):
60 | datastreams = [iter(dataset) for dataset in self.datasets]
61 | while True:
62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63 | yield next(select_datastream)
64 |
65 |
66 | def apply_to_sample(f, sample):
67 | if len(sample) == 0:
68 | return {}
69 |
70 | def _apply(x):
71 | if torch.is_tensor(x):
72 | return f(x)
73 | elif isinstance(x, dict):
74 | return {key: _apply(value) for key, value in x.items()}
75 | elif isinstance(x, list):
76 | return [_apply(x) for x in x]
77 | else:
78 | return x
79 |
80 | return _apply(sample)
81 |
82 |
83 | def move_to_cuda(sample):
84 | def _move_to_cuda(tensor):
85 | return tensor.cuda()
86 |
87 | return apply_to_sample(_move_to_cuda, sample)
88 |
89 |
90 | def prepare_sample(samples, cuda_enabled=True):
91 | if cuda_enabled:
92 | samples = move_to_cuda(samples)
93 |
94 | # TODO fp16 support
95 |
96 | return samples
97 |
98 |
99 | def reorg_datasets_by_split(datasets):
100 | """
101 | Organizes datasets by split.
102 |
103 | Args:
104 | datasets: dict of torch.utils.data.Dataset objects by name.
105 |
106 | Returns:
107 | Dict of datasets by split {split_name: List[Datasets]}.
108 | """
109 | # if len(datasets) == 1:
110 | # return datasets[list(datasets.keys())[0]]
111 | # else:
112 | reorg_datasets = dict()
113 |
114 | # reorganize by split
115 | for _, dataset in datasets.items():
116 | for split_name, dataset_split in dataset.items():
117 | if split_name not in reorg_datasets:
118 | reorg_datasets[split_name] = [dataset_split]
119 | else:
120 | reorg_datasets[split_name].append(dataset_split)
121 |
122 | return reorg_datasets
123 |
124 |
125 | def concat_datasets(datasets):
126 | """
127 | Concatenates multiple datasets into a single dataset.
128 |
129 |     It supports map-style datasets and DataPipeline from WebDataset. It currently does not support
130 |     generic IterableDataset because that would require creating separate samplers.
131 | 
132 |     Only training datasets are concatenated for now; validation and testing are assumed to
133 |     have a single dataset each, because metrics should not be computed on concatenated
134 |     datasets.
135 |
136 | Args:
137 | datasets: dict of torch.utils.data.Dataset objects by split.
138 |
139 | Returns:
140 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141 | "val" and "test" remain the same.
142 |
143 | If the input training datasets contain both map-style and DataPipeline datasets, returns
144 | a tuple, where the first element is a concatenated map-style dataset and the second
145 | element is a chained DataPipeline dataset.
146 |
147 | """
148 | # concatenate datasets in the same split
149 | for split_name in datasets:
150 | if split_name != "train":
151 | assert (
152 | len(datasets[split_name]) == 1
153 | ), "Do not support multiple {} datasets.".format(split_name)
154 | datasets[split_name] = datasets[split_name][0]
155 | else:
156 | iterable_datasets, map_datasets = [], []
157 | for dataset in datasets[split_name]:
158 | if isinstance(dataset, wds.DataPipeline):
159 | logging.info(
160 | "Dataset {} is IterableDataset, can't be concatenated.".format(
161 | dataset
162 | )
163 | )
164 | iterable_datasets.append(dataset)
165 | elif isinstance(dataset, IterableDataset):
166 | raise NotImplementedError(
167 | "Do not support concatenation of generic IterableDataset."
168 | )
169 | else:
170 | map_datasets.append(dataset)
171 |
172 | # if len(iterable_datasets) > 0:
173 | # concatenate map-style datasets and iterable-style datasets separately
174 | if len(iterable_datasets) > 1:
175 | chained_datasets = (
176 | ChainDataset(iterable_datasets)
177 | )
178 | elif len(iterable_datasets) == 1:
179 | chained_datasets = iterable_datasets[0]
180 | else:
181 | chained_datasets = None
182 |
183 | concat_datasets = (
184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185 | )
186 |
187 | train_datasets = concat_datasets, chained_datasets
188 | train_datasets = tuple([x for x in train_datasets if x is not None])
189 | train_datasets = (
190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets
191 | )
192 |
193 | datasets[split_name] = train_datasets
194 |
195 | return datasets
196 |
197 |
--------------------------------------------------------------------------------
/minigpt4/conversation/conversation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from PIL import Image
4 |
5 | import torch
6 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7 | from transformers import StoppingCriteria, StoppingCriteriaList
8 |
9 | import dataclasses
10 | from enum import auto, Enum
11 | from typing import List, Tuple, Any
12 |
13 | from minigpt4.common.registry import registry
14 |
15 |
16 | class SeparatorStyle(Enum):
17 | """Different separator style."""
18 | SINGLE = auto()
19 | TWO = auto()
20 |
21 |
22 | @dataclasses.dataclass
23 | class Conversation:
24 | """A class that keeps all conversation history."""
25 | system: str
26 | roles: List[str]
27 | messages: List[List[str]]
28 | offset: int
29 | # system_img: List[Image.Image] = []
30 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31 | sep: str = "###"
32 | sep2: str = None
33 |
34 | skip_next: bool = False
35 | conv_id: Any = None
36 |
37 | def get_prompt(self):
38 | if self.sep_style == SeparatorStyle.SINGLE:
39 | ret = self.system + self.sep
40 | for role, message in self.messages:
41 | if message:
42 | ret += role + ": " + message + self.sep
43 | else:
44 | ret += role + ":"
45 | return ret
46 | elif self.sep_style == SeparatorStyle.TWO:
47 | seps = [self.sep, self.sep2]
48 | ret = self.system + seps[0]
49 | for i, (role, message) in enumerate(self.messages):
50 | if message:
51 | ret += role + ": " + message + seps[i % 2]
52 | else:
53 | ret += role + ":"
54 | return ret
55 | else:
56 | raise ValueError(f"Invalid style: {self.sep_style}")
57 |
58 | def append_message(self, role, message):
59 | self.messages.append([role, message])
60 |
61 | def to_gradio_chatbot(self):
62 | ret = []
63 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
64 | if i % 2 == 0:
65 | ret.append([msg, None])
66 | else:
67 | ret[-1][-1] = msg
68 | return ret
69 |
70 | def copy(self):
71 | return Conversation(
72 | system=self.system,
73 | # system_img=self.system_img,
74 | roles=self.roles,
75 | messages=[[x, y] for x, y in self.messages],
76 | offset=self.offset,
77 | sep_style=self.sep_style,
78 | sep=self.sep,
79 | sep2=self.sep2,
80 | conv_id=self.conv_id)
81 |
82 | def dict(self):
83 | return {
84 | "system": self.system,
85 | # "system_img": self.system_img,
86 | "roles": self.roles,
87 | "messages": self.messages,
88 | "offset": self.offset,
89 | "sep": self.sep,
90 | "sep2": self.sep2,
91 | "conv_id": self.conv_id,
92 | }
93 |
94 |
95 | class StoppingCriteriaSub(StoppingCriteria):
96 |
97 | def __init__(self, stops=[], encounters=1):
98 | super().__init__()
99 | self.stops = stops
100 |
101 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
102 | for stop in self.stops:
103 | if torch.all((stop == input_ids[0][-len(stop):])).item():
104 | return True
105 |
106 | return False
107 |
108 |
109 | CONV_VISION = Conversation(
110 |     system="Give the following image: <Img>ImageContent</Img>. "
111 |            "You will be able to see the image once I provide it to you. Please answer my questions.",
112 | roles=("Human", "Assistant"),
113 | messages=[],
114 | offset=2,
115 | sep_style=SeparatorStyle.SINGLE,
116 | sep="###",
117 | )
118 |
119 |
120 |
121 | class Chat:
122 | def __init__(self, model, vis_processor, device='cuda:0'):
123 | self.device = device
124 | self.model = model
125 | self.vis_processor = vis_processor
126 | stop_words_ids = [torch.tensor([835]).to(self.device),
127 | torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
128 | self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
129 |
130 | def ask(self, text, conv):
131 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
132 |                 and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
133 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
134 | else:
135 | conv.append_message(conv.roles[0], text)
136 |
137 | def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
138 | repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
139 | conv.append_message(conv.roles[1], None)
140 | embs = self.get_context_emb(conv, img_list)
141 |
142 | # current_max_len = embs.shape[1] + max_new_tokens + 100
143 | # begin_idx = max(0, current_max_len - max_length)
144 | # embs = embs[:, begin_idx:]
145 | outputs = self.model.llama_model.generate(
146 | inputs_embeds=embs,
147 | max_new_tokens=max_new_tokens,
148 | stopping_criteria=self.stopping_criteria,
149 | num_beams=num_beams,
150 | min_length=min_length,
151 | top_p=top_p,
152 | repetition_penalty=repetition_penalty,
153 | length_penalty=length_penalty,
154 | temperature=temperature,
155 | )
156 | output_token = outputs[0]
157 | if output_token[0] == 0:
158 | output_token = output_token[1:]
159 | output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
160 | output_text = output_text.split('###')[0] # remove the stop sign '###'
161 | output_text = output_text.split('Assistant:')[-1].strip()
162 | conv.messages[-1][1] = output_text
163 | return output_text, output_token.cpu().numpy()
164 |
165 | def upload_img(self, image, conv, img_list):
166 | if isinstance(image, str): # is a image path
167 | raw_image = Image.open(image).convert('RGB')
168 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
169 | elif isinstance(image, Image.Image):
170 | raw_image = image
171 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
172 | elif isinstance(image, torch.Tensor):
173 | if len(image.shape) == 3:
174 | image = image.unsqueeze(0)
175 | image = image.to(self.device)
176 |
177 | image_emb, _ = self.model.encode_img(image)
178 | img_list.append(image_emb)
179 |         conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
180 | msg = "Received."
181 | # self.conv.append_message(self.conv.roles[1], msg)
182 | return msg
183 |
184 | def get_context_emb(self, conv, img_list):
185 | prompt = conv.get_prompt()
186 |         prompt_segs = prompt.split('<ImageHere>')
187 | assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
188 | seg_tokens = [
189 | self.model.llama_tokenizer(
190 | seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
191 | # only add bos to the first seg
192 | for i, seg in enumerate(prompt_segs)
193 | ]
194 | seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
195 | mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
196 | mixed_embs = torch.cat(mixed_embs, dim=1)
197 | return mixed_embs
198 |
199 |
200 |
--------------------------------------------------------------------------------
/minigpt4/models/blip2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import contextlib
8 | import logging
9 | import os
10 | import time
11 | import datetime
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.distributed as dist
16 | import torch.nn.functional as F
17 |
18 | import minigpt4.common.dist_utils as dist_utils
19 | from minigpt4.common.dist_utils import download_cached_file
20 | from minigpt4.common.utils import is_url
21 | from minigpt4.common.logger import MetricLogger
22 | from minigpt4.models.base_model import BaseModel
23 | from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
24 | from minigpt4.models.eva_vit import create_eva_vit_g
25 | from transformers import BertTokenizer
26 |
27 |
28 | class Blip2Base(BaseModel):
29 | @classmethod
30 | def init_tokenizer(cls):
31 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
32 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
33 | return tokenizer
34 |
35 | def maybe_autocast(self, dtype=torch.float16):
36 | # if on cpu, don't use autocast
37 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
38 | enable_autocast = self.device != torch.device("cpu")
39 |
40 | if enable_autocast:
41 | return torch.cuda.amp.autocast(dtype=dtype)
42 | else:
43 | return contextlib.nullcontext()
44 |
45 | @classmethod
46 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
47 | encoder_config = BertConfig.from_pretrained("bert-base-uncased")
48 | encoder_config.encoder_width = vision_width
49 | # insert cross-attention layer every other block
50 | encoder_config.add_cross_attention = True
51 | encoder_config.cross_attention_freq = cross_attention_freq
52 | encoder_config.query_length = num_query_token
53 | Qformer = BertLMHeadModel(config=encoder_config)
54 | query_tokens = nn.Parameter(
55 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
56 | )
57 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
58 | return Qformer, query_tokens
59 |
60 | @classmethod
61 | def init_vision_encoder(
62 | cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
63 | ):
64 | assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
65 | visual_encoder = create_eva_vit_g(
66 | img_size, drop_path_rate, use_grad_checkpoint, precision
67 | )
68 |
69 | ln_vision = LayerNorm(visual_encoder.num_features)
70 | return visual_encoder, ln_vision
71 |
72 | def load_from_pretrained(self, url_or_filename):
73 | if is_url(url_or_filename):
74 | cached_file = download_cached_file(
75 | url_or_filename, check_hash=False, progress=True
76 | )
77 | checkpoint = torch.load(cached_file, map_location="cpu")
78 | elif os.path.isfile(url_or_filename):
79 | checkpoint = torch.load(url_or_filename, map_location="cpu")
80 | else:
81 | raise RuntimeError("checkpoint url or path is invalid")
82 |
83 | state_dict = checkpoint["model"]
84 |
85 | msg = self.load_state_dict(state_dict, strict=False)
86 |
87 | # logging.info("Missing keys {}".format(msg.missing_keys))
88 | logging.info("load checkpoint from %s" % url_or_filename)
89 |
90 | return msg
91 |
92 |
93 | def disabled_train(self, mode=True):
94 | """Overwrite model.train with this function to make sure train/eval mode
95 | does not change anymore."""
96 | return self
97 |
98 |
99 | class LayerNorm(nn.LayerNorm):
100 | """Subclass torch's LayerNorm to handle fp16."""
101 |
102 | def forward(self, x: torch.Tensor):
103 | orig_type = x.dtype
104 | ret = super().forward(x.type(torch.float32))
105 | return ret.type(orig_type)
106 |
107 |
108 | def compute_sim_matrix(model, data_loader, **kwargs):
109 | k_test = kwargs.pop("k_test")
110 |
111 | metric_logger = MetricLogger(delimiter=" ")
112 | header = "Evaluation:"
113 |
114 | logging.info("Computing features for evaluation...")
115 | start_time = time.time()
116 |
117 | texts = data_loader.dataset.text
118 | num_text = len(texts)
119 | text_bs = 256
120 | text_ids = []
121 | text_embeds = []
122 | text_atts = []
123 | for i in range(0, num_text, text_bs):
124 | text = texts[i : min(num_text, i + text_bs)]
125 | text_input = model.tokenizer(
126 | text,
127 | padding="max_length",
128 | truncation=True,
129 | max_length=35,
130 | return_tensors="pt",
131 | ).to(model.device)
132 | text_feat = model.forward_text(text_input)
133 | text_embed = F.normalize(model.text_proj(text_feat))
134 | text_embeds.append(text_embed)
135 | text_ids.append(text_input.input_ids)
136 | text_atts.append(text_input.attention_mask)
137 |
138 | text_embeds = torch.cat(text_embeds, dim=0)
139 | text_ids = torch.cat(text_ids, dim=0)
140 | text_atts = torch.cat(text_atts, dim=0)
141 |
142 | vit_feats = []
143 | image_embeds = []
144 | for samples in data_loader:
145 | image = samples["image"]
146 |
147 | image = image.to(model.device)
148 | image_feat, vit_feat = model.forward_image(image)
149 | image_embed = model.vision_proj(image_feat)
150 | image_embed = F.normalize(image_embed, dim=-1)
151 |
152 | vit_feats.append(vit_feat.cpu())
153 | image_embeds.append(image_embed)
154 |
155 | vit_feats = torch.cat(vit_feats, dim=0)
156 | image_embeds = torch.cat(image_embeds, dim=0)
157 |
158 | sims_matrix = []
159 | for image_embed in image_embeds:
160 | sim_q2t = image_embed @ text_embeds.t()
161 | sim_i2t, _ = sim_q2t.max(0)
162 | sims_matrix.append(sim_i2t)
163 | sims_matrix = torch.stack(sims_matrix, dim=0)
164 |
165 | score_matrix_i2t = torch.full(
166 | (len(data_loader.dataset.image), len(texts)), -100.0
167 | ).to(model.device)
168 |
169 | num_tasks = dist_utils.get_world_size()
170 | rank = dist_utils.get_rank()
171 | step = sims_matrix.size(0) // num_tasks + 1
172 | start = rank * step
173 | end = min(sims_matrix.size(0), start + step)
174 |
175 | for i, sims in enumerate(
176 | metric_logger.log_every(sims_matrix[start:end], 50, header)
177 | ):
178 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
179 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
180 | score = model.compute_itm(
181 | image_inputs=image_inputs,
182 | text_ids=text_ids[topk_idx],
183 | text_atts=text_atts[topk_idx],
184 | ).float()
185 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim
186 |
187 | sims_matrix = sims_matrix.t()
188 | score_matrix_t2i = torch.full(
189 | (len(texts), len(data_loader.dataset.image)), -100.0
190 | ).to(model.device)
191 |
192 | step = sims_matrix.size(0) // num_tasks + 1
193 | start = rank * step
194 | end = min(sims_matrix.size(0), start + step)
195 |
196 | for i, sims in enumerate(
197 | metric_logger.log_every(sims_matrix[start:end], 50, header)
198 | ):
199 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
200 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
201 | score = model.compute_itm(
202 | image_inputs=image_inputs,
203 | text_ids=text_ids[start + i].repeat(k_test, 1),
204 | text_atts=text_atts[start + i].repeat(k_test, 1),
205 | ).float()
206 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim
207 |
208 | if dist_utils.is_dist_avail_and_initialized():
209 | dist.barrier()
210 | torch.distributed.all_reduce(
211 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
212 | )
213 | torch.distributed.all_reduce(
214 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
215 | )
216 |
217 | total_time = time.time() - start_time
218 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
219 | logging.info("Evaluation time {}".format(total_time_str))
220 |
221 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
222 |
--------------------------------------------------------------------------------
/minigpt4/models/base_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import os
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
15 | from minigpt4.common.utils import get_abs_path, is_url
16 | from omegaconf import OmegaConf
17 |
18 |
19 | class BaseModel(nn.Module):
20 | """Base class for models."""
21 |
22 | def __init__(self):
23 | super().__init__()
24 |
25 | @property
26 | def device(self):
27 | return list(self.parameters())[0].device
28 |
29 | def load_checkpoint(self, url_or_filename):
30 | """
31 | Load from a finetuned checkpoint.
32 |
33 | This should expect no mismatch in the model keys and the checkpoint keys.
34 | """
35 |
36 | if is_url(url_or_filename):
37 | cached_file = download_cached_file(
38 | url_or_filename, check_hash=False, progress=True
39 | )
40 | checkpoint = torch.load(cached_file, map_location="cpu")
41 | elif os.path.isfile(url_or_filename):
42 | checkpoint = torch.load(url_or_filename, map_location="cpu")
43 | else:
44 | raise RuntimeError("checkpoint url or path is invalid")
45 |
46 | if "model" in checkpoint.keys():
47 | state_dict = checkpoint["model"]
48 | else:
49 | state_dict = checkpoint
50 |
51 | msg = self.load_state_dict(state_dict, strict=False)
52 |
53 | logging.info("Missing keys {}".format(msg.missing_keys))
54 | logging.info("load checkpoint from %s" % url_or_filename)
55 |
56 | return msg
57 |
58 | @classmethod
59 | def from_pretrained(cls, model_type):
60 | """
61 | Build a pretrained model from default configuration file, specified by model_type.
62 |
63 | Args:
64 | - model_type (str): model type, specifying architecture and checkpoints.
65 |
66 | Returns:
67 | - model (nn.Module): pretrained or finetuned model, depending on the configuration.
68 | """
69 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
70 | model = cls.from_config(model_cfg)
71 |
72 | return model
73 |
74 | @classmethod
75 | def default_config_path(cls, model_type):
76 | assert (
77 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
78 | ), "Unknown model type {}".format(model_type)
79 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
80 |
81 | def load_checkpoint_from_config(self, cfg, **kwargs):
82 | """
83 | Load checkpoint as specified in the config file.
84 |
85 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
86 | When loading the pretrained model, each task-specific architecture may define their
87 | own load_from_pretrained() method.
88 | """
89 | load_finetuned = cfg.get("load_finetuned", True)
90 | if load_finetuned:
91 | finetune_path = cfg.get("finetuned", None)
92 | assert (
93 | finetune_path is not None
94 | ), "Found load_finetuned is True, but finetune_path is None."
95 | self.load_checkpoint(url_or_filename=finetune_path)
96 | else:
97 | # load pre-trained weights
98 | pretrain_path = cfg.get("pretrained", None)
99 |             assert pretrain_path is not None, "Found load_finetuned is False, but pretrain_path is None."
100 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
101 |
102 | def before_evaluation(self, **kwargs):
103 | pass
104 |
105 | def show_n_params(self, return_str=True):
106 | tot = 0
107 | for p in self.parameters():
108 | w = 1
109 | for x in p.shape:
110 | w *= x
111 | tot += w
112 | if return_str:
113 | if tot >= 1e6:
114 | return "{:.1f}M".format(tot / 1e6)
115 | else:
116 | return "{:.1f}K".format(tot / 1e3)
117 | else:
118 | return tot
119 |
120 |
121 | class BaseEncoder(nn.Module):
122 | """
123 | Base class for primitive encoders, such as ViT, TimeSformer, etc.
124 | """
125 |
126 | def __init__(self):
127 | super().__init__()
128 |
129 | def forward_features(self, samples, **kwargs):
130 | raise NotImplementedError
131 |
132 | @property
133 | def device(self):
134 | return list(self.parameters())[0].device
135 |
136 |
137 | class SharedQueueMixin:
138 | @torch.no_grad()
139 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
140 | # gather keys before updating queue
141 | image_feats = concat_all_gather(image_feat)
142 | text_feats = concat_all_gather(text_feat)
143 |
144 | batch_size = image_feats.shape[0]
145 |
146 | ptr = int(self.queue_ptr)
147 | assert self.queue_size % batch_size == 0 # for simplicity
148 |
149 | # replace the keys at ptr (dequeue and enqueue)
150 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
151 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
152 |
153 | if idxs is not None:
154 | idxs = concat_all_gather(idxs)
155 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
156 |
157 | ptr = (ptr + batch_size) % self.queue_size # move pointer
158 | self.queue_ptr[0] = ptr
159 |
160 |
161 | class MomentumDistilationMixin:
162 | @torch.no_grad()
163 | def copy_params(self):
164 | for model_pair in self.model_pairs:
165 | for param, param_m in zip(
166 | model_pair[0].parameters(), model_pair[1].parameters()
167 | ):
168 | param_m.data.copy_(param.data) # initialize
169 | param_m.requires_grad = False # not update by gradient
170 |
171 | @torch.no_grad()
172 | def _momentum_update(self):
173 | for model_pair in self.model_pairs:
174 | for param, param_m in zip(
175 | model_pair[0].parameters(), model_pair[1].parameters()
176 | ):
177 | param_m.data = param_m.data * self.momentum + param.data * (
178 | 1.0 - self.momentum
179 | )
180 |
181 |
182 | class GatherLayer(torch.autograd.Function):
183 | """
184 | Gather tensors from all workers with support for backward propagation:
185 | This implementation does not cut the gradients as torch.distributed.all_gather does.
186 | """
187 |
188 | @staticmethod
189 | def forward(ctx, x):
190 | output = [
191 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
192 | ]
193 | torch.distributed.all_gather(output, x)
194 | return tuple(output)
195 |
196 | @staticmethod
197 | def backward(ctx, *grads):
198 | all_gradients = torch.stack(grads)
199 | torch.distributed.all_reduce(all_gradients)
200 | return all_gradients[torch.distributed.get_rank()]
201 |
202 |
203 | def all_gather_with_grad(tensors):
204 | """
205 | Performs all_gather operation on the provided tensors.
206 | Graph remains connected for backward grad computation.
207 | """
208 | # Queue the gathered tensors
209 | world_size = torch.distributed.get_world_size()
210 | # There is no need for reduction in the single-proc case
211 | if world_size == 1:
212 | return tensors
213 |
214 | # tensor_all = GatherLayer.apply(tensors)
215 | tensor_all = GatherLayer.apply(tensors)
216 |
217 | return torch.cat(tensor_all, dim=0)
218 |
219 |
220 | @torch.no_grad()
221 | def concat_all_gather(tensor):
222 | """
223 | Performs all_gather operation on the provided tensors.
224 | *** Warning ***: torch.distributed.all_gather has no gradient.
225 | """
226 | # if use distributed training
227 | if not is_dist_avail_and_initialized():
228 | return tensor
229 |
230 | tensors_gather = [
231 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
232 | ]
233 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
234 |
235 | output = torch.cat(tensors_gather, dim=0)
236 | return output
237 |
238 |
239 | def tile(x, dim, n_tile):
240 | init_dim = x.size(dim)
241 | repeat_idx = [1] * x.dim()
242 | repeat_idx[dim] = n_tile
243 | x = x.repeat(*(repeat_idx))
244 | order_index = torch.LongTensor(
245 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
246 | )
247 | return torch.index_select(x, dim, order_index.to(x.device))
248 |
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/base_dataset_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import os
10 | import shutil
11 | import warnings
12 |
13 | from omegaconf import OmegaConf
14 | import torch.distributed as dist
15 | from torchvision.datasets.utils import download_url
16 |
17 | import minigpt4.common.utils as utils
18 | from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
19 | from minigpt4.common.registry import registry
20 | from minigpt4.processors.base_processor import BaseProcessor
21 |
22 |
23 |
24 | class BaseDatasetBuilder:
25 | train_dataset_cls, eval_dataset_cls = None, None
26 |
27 | def __init__(self, cfg=None):
28 | super().__init__()
29 |
30 | if cfg is None:
31 | # help to create datasets from default config.
32 | self.config = load_dataset_config(self.default_config_path())
33 | elif isinstance(cfg, str):
34 | self.config = load_dataset_config(cfg)
35 | else:
36 | # when called from task.build_dataset()
37 | self.config = cfg
38 |
39 | self.data_type = self.config.data_type
40 |
41 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
42 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43 |
44 | def build_datasets(self):
45 | # download, split, etc...
46 | # only called on 1 GPU/TPU in distributed
47 |
48 | if is_main_process():
49 | self._download_data()
50 |
51 | if is_dist_avail_and_initialized():
52 | dist.barrier()
53 |
54 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
55 | logging.info("Building datasets...")
56 | datasets = self.build() # dataset['train'/'val'/'test']
57 |
58 | return datasets
59 |
60 | def build_processors(self):
61 | vis_proc_cfg = self.config.get("vis_processor")
62 | txt_proc_cfg = self.config.get("text_processor")
63 |
64 | if vis_proc_cfg is not None:
65 | vis_train_cfg = vis_proc_cfg.get("train")
66 | vis_eval_cfg = vis_proc_cfg.get("eval")
67 |
68 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
69 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
70 |
71 | if txt_proc_cfg is not None:
72 | txt_train_cfg = txt_proc_cfg.get("train")
73 | txt_eval_cfg = txt_proc_cfg.get("eval")
74 |
75 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
76 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
77 |
78 | @staticmethod
79 | def _build_proc_from_cfg(cfg):
80 | return (
81 | registry.get_processor_class(cfg.name).from_config(cfg)
82 | if cfg is not None
83 | else None
84 | )
85 |
86 | @classmethod
87 | def default_config_path(cls, type="default"):
88 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
89 |
90 | def _download_data(self):
91 | self._download_ann()
92 | self._download_vis()
93 |
94 | def _download_ann(self):
95 | """
96 | Download annotation files if necessary.
97 |         All the vision-language datasets should have annotations in a unified format.
98 |
99 |         storage_path can be:
100 |         (1) relative/absolute: a relative path is prefixed with env.cache_root to form the full path.
101 |         (2) basename/dirname: if only a dirname is provided, it is suffixed with the base name of the URL.
102 |
103 | Local annotation paths should be relative.
104 | """
105 | anns = self.config.build_info.annotations
106 |
107 | splits = anns.keys()
108 |
109 | cache_root = registry.get_path("cache_root")
110 |
111 | for split in splits:
112 | info = anns[split]
113 |
114 | urls, storage_paths = info.get("url", None), info.storage
115 |
116 | if isinstance(urls, str):
117 | urls = [urls]
118 | if isinstance(storage_paths, str):
119 | storage_paths = [storage_paths]
120 |
121 | assert len(urls) == len(storage_paths)
122 |
123 | for url_or_filename, storage_path in zip(urls, storage_paths):
124 | # if storage_path is relative, make it full by prefixing with cache_root.
125 | if not os.path.isabs(storage_path):
126 | storage_path = os.path.join(cache_root, storage_path)
127 |
128 | dirname = os.path.dirname(storage_path)
129 | if not os.path.exists(dirname):
130 | os.makedirs(dirname)
131 |
132 | if os.path.isfile(url_or_filename):
133 | src, dst = url_or_filename, storage_path
134 | if not os.path.exists(dst):
135 | shutil.copyfile(src=src, dst=dst)
136 | else:
137 | logging.info("Using existing file {}.".format(dst))
138 | else:
139 | if os.path.isdir(storage_path):
140 | # if only dirname is provided, suffix with basename of URL.
141 | raise ValueError(
142 | "Expecting storage_path to be a file path, got directory {}".format(
143 | storage_path
144 | )
145 | )
146 | else:
147 | filename = os.path.basename(storage_path)
148 |
149 | download_url(url=url_or_filename, root=dirname, filename=filename)
150 |
151 | def _download_vis(self):
152 |
153 | storage_path = self.config.build_info.get(self.data_type).storage
154 | storage_path = utils.get_cache_path(storage_path)
155 |
156 | if not os.path.exists(storage_path):
157 | warnings.warn(
158 | f"""
159 | The specified path {storage_path} for visual inputs does not exist.
160 | Please provide a correct path to the visual inputs or
161 | refer to datasets/download_scripts/README.md for downloading instructions.
162 | """
163 | )
164 |
165 | def build(self):
166 | """
167 |         Create datasets per split, each inheriting torch.utils.data.Dataset.
168 |
169 |         build() can be dataset-specific. Override this method to customize.
170 | """
171 | self.build_processors()
172 |
173 | build_info = self.config.build_info
174 |
175 | ann_info = build_info.annotations
176 | vis_info = build_info.get(self.data_type)
177 |
178 | datasets = dict()
179 | for split in ann_info.keys():
180 | if split not in ["train", "val", "test"]:
181 | continue
182 |
183 | is_train = split == "train"
184 |
185 | # processors
186 | vis_processor = (
187 | self.vis_processors["train"]
188 | if is_train
189 | else self.vis_processors["eval"]
190 | )
191 | text_processor = (
192 | self.text_processors["train"]
193 | if is_train
194 | else self.text_processors["eval"]
195 | )
196 |
197 | # annotation path
198 | ann_paths = ann_info.get(split).storage
199 | if isinstance(ann_paths, str):
200 | ann_paths = [ann_paths]
201 |
202 | abs_ann_paths = []
203 | for ann_path in ann_paths:
204 | if not os.path.isabs(ann_path):
205 | ann_path = utils.get_cache_path(ann_path)
206 | abs_ann_paths.append(ann_path)
207 | ann_paths = abs_ann_paths
208 |
209 | # visual data storage path
210 | vis_path = os.path.join(vis_info.storage, split)
211 |
212 | if not os.path.isabs(vis_path):
213 | # vis_path = os.path.join(utils.get_cache_path(), vis_path)
214 | vis_path = utils.get_cache_path(vis_path)
215 |
216 | if not os.path.exists(vis_path):
217 | warnings.warn("storage path {} does not exist.".format(vis_path))
218 |
219 | # create datasets
220 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
221 | datasets[split] = dataset_cls(
222 | vis_processor=vis_processor,
223 | text_processor=text_processor,
224 | ann_paths=ann_paths,
225 | vis_root=vis_path,
226 | )
227 |
228 | return datasets
229 |
230 |
231 | def load_dataset_config(cfg_path):
232 | cfg = OmegaConf.load(cfg_path).datasets
233 | cfg = cfg[list(cfg.keys())[0]]
234 |
235 | return cfg
236 |
--------------------------------------------------------------------------------
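
Usage-wise, a concrete builder usually just points BaseDatasetBuilder at its dataset classes and a default config path, then registers itself; everything else (downloading, processor wiring, split handling) comes from the base class above. The sketch below is hypothetical: the "my_caption" name, the config path, and the assumption that CaptionDataset is defined in caption_datasets.py are illustrative, not taken verbatim from this repo.

# Hypothetical builder sketch; names and the config path are illustrative assumptions.
from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset  # assumed to exist


@registry.register_builder("my_caption")
class MyCaptionBuilder(BaseDatasetBuilder):
    # the same dataset class is used for train and eval in this sketch
    train_dataset_cls = CaptionDataset
    eval_dataset_cls = CaptionDataset

    # resolved via utils.get_abs_path() inside default_config_path()
    DATASET_CONFIG_DICT = {"default": "configs/datasets/my_caption/defaults.yaml"}


# builder = MyCaptionBuilder()         # loads the default config above
# datasets = builder.build_datasets()  # dict keyed by "train" / "val" / "test"
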
/minigpt4/tasks/base_task.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import os
10 |
11 | import torch
12 | import torch.distributed as dist
13 | from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
14 | from minigpt4.common.logger import MetricLogger, SmoothedValue
15 | from minigpt4.common.registry import registry
16 | from minigpt4.datasets.data_utils import prepare_sample
17 |
18 |
19 | class BaseTask:
20 | def __init__(self, **kwargs):
21 | super().__init__()
22 |
23 | self.inst_id_key = "instance_id"
24 |
25 | @classmethod
26 | def setup_task(cls, **kwargs):
27 | return cls()
28 |
29 | def build_model(self, cfg):
30 | model_config = cfg.model_cfg
31 |
32 | model_cls = registry.get_model_class(model_config.arch)
33 | return model_cls.from_config(model_config)
34 |
35 | def build_datasets(self, cfg):
36 | """
37 | Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
38 |         Download datasets and annotations automatically if they do not exist.
39 |
40 |         Args:
41 |             cfg (common.config.Config): run configuration; its datasets_cfg entry lists the datasets to build.
42 |
43 | Returns:
44 | dict: Dictionary of torch.utils.data.Dataset objects by split.
45 | """
46 |
47 | datasets = dict()
48 |
49 | datasets_config = cfg.datasets_cfg
50 |
51 | assert len(datasets_config) > 0, "At least one dataset has to be specified."
52 |
53 | for name in datasets_config:
54 | dataset_config = datasets_config[name]
55 |
56 | builder = registry.get_builder_class(name)(dataset_config)
57 | dataset = builder.build_datasets()
58 |
59 | dataset['train'].name = name
60 | if 'sample_ratio' in dataset_config:
61 | dataset['train'].sample_ratio = dataset_config.sample_ratio
62 |
63 | datasets[name] = dataset
64 |
65 | return datasets
66 |
67 | def train_step(self, model, samples):
68 | loss = model(samples)["loss"]
69 | return loss
70 |
71 | def valid_step(self, model, samples):
72 | raise NotImplementedError
73 |
74 | def before_evaluation(self, model, dataset, **kwargs):
75 | model.before_evaluation(dataset=dataset, task_type=type(self))
76 |
77 | def after_evaluation(self, **kwargs):
78 | pass
79 |
80 | def inference_step(self):
81 | raise NotImplementedError
82 |
83 | def evaluation(self, model, data_loader, cuda_enabled=True):
84 | metric_logger = MetricLogger(delimiter=" ")
85 | header = "Evaluation"
86 | # TODO make it configurable
87 | print_freq = 10
88 |
89 | results = []
90 |
91 | for samples in metric_logger.log_every(data_loader, print_freq, header):
92 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
93 |
94 | eval_output = self.valid_step(model=model, samples=samples)
95 | results.extend(eval_output)
96 |
97 | if is_dist_avail_and_initialized():
98 | dist.barrier()
99 |
100 | return results
101 |
102 | def train_epoch(
103 | self,
104 | epoch,
105 | model,
106 | data_loader,
107 | optimizer,
108 | lr_scheduler,
109 | scaler=None,
110 | cuda_enabled=False,
111 | log_freq=50,
112 | accum_grad_iters=1,
113 | ):
114 | return self._train_inner_loop(
115 | epoch=epoch,
116 | iters_per_epoch=lr_scheduler.iters_per_epoch,
117 | model=model,
118 | data_loader=data_loader,
119 | optimizer=optimizer,
120 | scaler=scaler,
121 | lr_scheduler=lr_scheduler,
122 | log_freq=log_freq,
123 | cuda_enabled=cuda_enabled,
124 | accum_grad_iters=accum_grad_iters,
125 | )
126 |
127 | def train_iters(
128 | self,
129 | epoch,
130 | start_iters,
131 | iters_per_inner_epoch,
132 | model,
133 | data_loader,
134 | optimizer,
135 | lr_scheduler,
136 | scaler=None,
137 | cuda_enabled=False,
138 | log_freq=50,
139 | accum_grad_iters=1,
140 | ):
141 | return self._train_inner_loop(
142 | epoch=epoch,
143 | start_iters=start_iters,
144 | iters_per_epoch=iters_per_inner_epoch,
145 | model=model,
146 | data_loader=data_loader,
147 | optimizer=optimizer,
148 | scaler=scaler,
149 | lr_scheduler=lr_scheduler,
150 | log_freq=log_freq,
151 | cuda_enabled=cuda_enabled,
152 | accum_grad_iters=accum_grad_iters,
153 | )
154 |
155 | def _train_inner_loop(
156 | self,
157 | epoch,
158 | iters_per_epoch,
159 | model,
160 | data_loader,
161 | optimizer,
162 | lr_scheduler,
163 | scaler=None,
164 | start_iters=None,
165 | log_freq=50,
166 | cuda_enabled=False,
167 | accum_grad_iters=1,
168 | ):
169 | """
170 | An inner training loop compatible with both epoch-based and iter-based training.
171 |
172 | When using epoch-based, training stops after one epoch; when using iter-based,
173 | training stops after #iters_per_epoch iterations.
174 | """
175 | use_amp = scaler is not None
176 |
177 | if not hasattr(data_loader, "__next__"):
178 | # convert to iterator if not already
179 | data_loader = iter(data_loader)
180 |
181 | metric_logger = MetricLogger(delimiter=" ")
182 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
183 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
184 |
185 | # if iter-based runner, schedule lr based on inner epoch.
186 | logging.info(
187 | "Start training epoch {}, {} iters per inner epoch.".format(
188 | epoch, iters_per_epoch
189 | )
190 | )
191 | header = "Train: data epoch: [{}]".format(epoch)
192 | if start_iters is None:
193 | # epoch-based runner
194 | inner_epoch = epoch
195 | else:
196 | # In iter-based runner, we schedule the learning rate based on iterations.
197 | inner_epoch = start_iters // iters_per_epoch
198 | header = header + "; inner epoch [{}]".format(inner_epoch)
199 |
200 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
201 | # if using iter-based runner, we stop after iters_per_epoch iterations.
202 | if i >= iters_per_epoch:
203 | break
204 |
205 | samples = next(data_loader)
206 |
207 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
208 | samples.update(
209 | {
210 | "epoch": inner_epoch,
211 | "num_iters_per_epoch": iters_per_epoch,
212 | "iters": i,
213 | }
214 | )
215 |
216 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
217 |
218 | with torch.cuda.amp.autocast(enabled=use_amp):
219 | loss = self.train_step(model=model, samples=samples)
220 |
221 | # after_train_step()
222 | if use_amp:
223 | scaler.scale(loss).backward()
224 | else:
225 | loss.backward()
226 |
227 | # update gradients every accum_grad_iters iterations
228 | if (i + 1) % accum_grad_iters == 0:
229 | if use_amp:
230 | scaler.step(optimizer)
231 | scaler.update()
232 | else:
233 | optimizer.step()
234 | optimizer.zero_grad()
235 |
236 | metric_logger.update(loss=loss.item())
237 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
238 |
239 | # after train_epoch()
240 | # gather the stats from all processes
241 | metric_logger.synchronize_between_processes()
242 | logging.info("Averaged stats: " + str(metric_logger.global_avg()))
243 | return {
244 | k: "{:.3f}".format(meter.global_avg)
245 | for k, meter in metric_logger.meters.items()
246 | }
247 |
248 | @staticmethod
249 | def save_result(result, result_dir, filename, remove_duplicate=""):
250 | import json
251 |
252 | result_file = os.path.join(
253 | result_dir, "%s_rank%d.json" % (filename, get_rank())
254 | )
255 | final_result_file = os.path.join(result_dir, "%s.json" % filename)
256 |
257 | json.dump(result, open(result_file, "w"))
258 |
259 | if is_dist_avail_and_initialized():
260 | dist.barrier()
261 |
262 | if is_main_process():
263 | logging.warning("rank %d starts merging results." % get_rank())
264 | # combine results from all processes
265 | result = []
266 |
267 | for rank in range(get_world_size()):
268 | result_file = os.path.join(
269 | result_dir, "%s_rank%d.json" % (filename, rank)
270 | )
271 | res = json.load(open(result_file, "r"))
272 | result += res
273 |
274 | if remove_duplicate:
275 | result_new = []
276 | id_list = []
277 | for res in result:
278 | if res[remove_duplicate] not in id_list:
279 | id_list.append(res[remove_duplicate])
280 | result_new.append(res)
281 | result = result_new
282 |
283 | json.dump(result, open(final_result_file, "w"))
284 | print("result file saved to %s" % final_result_file)
285 |
286 | return final_result_file
287 |
--------------------------------------------------------------------------------
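
For orientation, a concrete task only needs to register itself and fill in the evaluation hooks; training is driven entirely by the inherited train_epoch/_train_inner_loop above. The subclass below is a hypothetical sketch: the task name and the valid_step body are illustrative assumptions, not code from this repository.

# Hypothetical task sketch; "caption_finetune" and the valid_step body are illustrative.
from minigpt4.common.registry import registry
from minigpt4.tasks.base_task import BaseTask


@registry.register_task("caption_finetune")
class CaptionFinetuneTask(BaseTask):
    # train_step is inherited: it simply returns model(samples)["loss"].
    def valid_step(self, model, samples):
        # evaluation() extends its results list with whatever is returned here
        return [{"instance_id": i} for i in range(samples["image"].size(0))]
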
/minigpt4/common/registry.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 |
9 | class Registry:
10 | mapping = {
11 | "builder_name_mapping": {},
12 | "task_name_mapping": {},
13 | "processor_name_mapping": {},
14 | "model_name_mapping": {},
15 | "lr_scheduler_name_mapping": {},
16 | "runner_name_mapping": {},
17 | "state": {},
18 | "paths": {},
19 | }
20 |
21 | @classmethod
22 | def register_builder(cls, name):
23 | r"""Register a dataset builder to registry with key 'name'
24 |
25 | Args:
26 | name: Key with which the builder will be registered.
27 |
28 | Usage:
29 |
30 | from minigpt4.common.registry import registry
31 | from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
32 | """
33 |
34 | def wrap(builder_cls):
35 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36 |
37 | assert issubclass(
38 | builder_cls, BaseDatasetBuilder
39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40 | builder_cls
41 | )
42 | if name in cls.mapping["builder_name_mapping"]:
43 | raise KeyError(
44 | "Name '{}' already registered for {}.".format(
45 | name, cls.mapping["builder_name_mapping"][name]
46 | )
47 | )
48 | cls.mapping["builder_name_mapping"][name] = builder_cls
49 | return builder_cls
50 |
51 | return wrap
52 |
53 | @classmethod
54 | def register_task(cls, name):
55 | r"""Register a task to registry with key 'name'
56 |
57 | Args:
58 | name: Key with which the task will be registered.
59 |
60 | Usage:
61 |
62 | from minigpt4.common.registry import registry
63 | """
64 |
65 | def wrap(task_cls):
66 | from minigpt4.tasks.base_task import BaseTask
67 |
68 | assert issubclass(
69 | task_cls, BaseTask
70 | ), "All tasks must inherit BaseTask class"
71 | if name in cls.mapping["task_name_mapping"]:
72 | raise KeyError(
73 | "Name '{}' already registered for {}.".format(
74 | name, cls.mapping["task_name_mapping"][name]
75 | )
76 | )
77 | cls.mapping["task_name_mapping"][name] = task_cls
78 | return task_cls
79 |
80 | return wrap
81 |
82 | @classmethod
83 | def register_model(cls, name):
84 |         r"""Register a model to registry with key 'name'
85 |
86 |         Args:
87 |             name: Key with which the model will be registered.
88 |
89 | Usage:
90 |
91 | from minigpt4.common.registry import registry
92 | """
93 |
94 | def wrap(model_cls):
95 | from minigpt4.models import BaseModel
96 |
97 | assert issubclass(
98 | model_cls, BaseModel
99 | ), "All models must inherit BaseModel class"
100 | if name in cls.mapping["model_name_mapping"]:
101 | raise KeyError(
102 | "Name '{}' already registered for {}.".format(
103 | name, cls.mapping["model_name_mapping"][name]
104 | )
105 | )
106 | cls.mapping["model_name_mapping"][name] = model_cls
107 | return model_cls
108 |
109 | return wrap
110 |
111 | @classmethod
112 | def register_processor(cls, name):
113 | r"""Register a processor to registry with key 'name'
114 |
115 | Args:
116 |             name: Key with which the processor will be registered.
117 |
118 | Usage:
119 |
120 | from minigpt4.common.registry import registry
121 | """
122 |
123 | def wrap(processor_cls):
124 | from minigpt4.processors import BaseProcessor
125 |
126 | assert issubclass(
127 | processor_cls, BaseProcessor
128 | ), "All processors must inherit BaseProcessor class"
129 | if name in cls.mapping["processor_name_mapping"]:
130 | raise KeyError(
131 | "Name '{}' already registered for {}.".format(
132 | name, cls.mapping["processor_name_mapping"][name]
133 | )
134 | )
135 | cls.mapping["processor_name_mapping"][name] = processor_cls
136 | return processor_cls
137 |
138 | return wrap
139 |
140 | @classmethod
141 | def register_lr_scheduler(cls, name):
142 |         r"""Register a learning rate scheduler to registry with key 'name'
143 |
144 |         Args:
145 |             name: Key with which the lr scheduler will be registered.
146 |
147 | Usage:
148 |
149 | from minigpt4.common.registry import registry
150 | """
151 |
152 | def wrap(lr_sched_cls):
153 | if name in cls.mapping["lr_scheduler_name_mapping"]:
154 | raise KeyError(
155 | "Name '{}' already registered for {}.".format(
156 | name, cls.mapping["lr_scheduler_name_mapping"][name]
157 | )
158 | )
159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160 | return lr_sched_cls
161 |
162 | return wrap
163 |
164 | @classmethod
165 | def register_runner(cls, name):
166 |         r"""Register a runner to registry with key 'name'
167 |
168 |         Args:
169 |             name: Key with which the runner will be registered.
170 |
171 | Usage:
172 |
173 | from minigpt4.common.registry import registry
174 | """
175 |
176 | def wrap(runner_cls):
177 | if name in cls.mapping["runner_name_mapping"]:
178 | raise KeyError(
179 | "Name '{}' already registered for {}.".format(
180 | name, cls.mapping["runner_name_mapping"][name]
181 | )
182 | )
183 | cls.mapping["runner_name_mapping"][name] = runner_cls
184 | return runner_cls
185 |
186 | return wrap
187 |
188 | @classmethod
189 | def register_path(cls, name, path):
190 | r"""Register a path to registry with key 'name'
191 |
192 | Args:
193 | name: Key with which the path will be registered.
194 |
195 | Usage:
196 |
197 | from minigpt4.common.registry import registry
198 | """
199 |         assert isinstance(path, str), "All paths must be str."
200 | if name in cls.mapping["paths"]:
201 | raise KeyError("Name '{}' already registered.".format(name))
202 | cls.mapping["paths"][name] = path
203 |
204 | @classmethod
205 | def register(cls, name, obj):
206 | r"""Register an item to registry with key 'name'
207 |
208 | Args:
209 | name: Key with which the item will be registered.
210 |
211 | Usage::
212 |
213 | from minigpt4.common.registry import registry
214 |
215 | registry.register("config", {})
216 | """
217 | path = name.split(".")
218 | current = cls.mapping["state"]
219 |
220 | for part in path[:-1]:
221 | if part not in current:
222 | current[part] = {}
223 | current = current[part]
224 |
225 | current[path[-1]] = obj
226 |
227 | # @classmethod
228 | # def get_trainer_class(cls, name):
229 | # return cls.mapping["trainer_name_mapping"].get(name, None)
230 |
231 | @classmethod
232 | def get_builder_class(cls, name):
233 | return cls.mapping["builder_name_mapping"].get(name, None)
234 |
235 | @classmethod
236 | def get_model_class(cls, name):
237 | return cls.mapping["model_name_mapping"].get(name, None)
238 |
239 | @classmethod
240 | def get_task_class(cls, name):
241 | return cls.mapping["task_name_mapping"].get(name, None)
242 |
243 | @classmethod
244 | def get_processor_class(cls, name):
245 | return cls.mapping["processor_name_mapping"].get(name, None)
246 |
247 | @classmethod
248 | def get_lr_scheduler_class(cls, name):
249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250 |
251 | @classmethod
252 | def get_runner_class(cls, name):
253 | return cls.mapping["runner_name_mapping"].get(name, None)
254 |
255 | @classmethod
256 | def list_runners(cls):
257 | return sorted(cls.mapping["runner_name_mapping"].keys())
258 |
259 | @classmethod
260 | def list_models(cls):
261 | return sorted(cls.mapping["model_name_mapping"].keys())
262 |
263 | @classmethod
264 | def list_tasks(cls):
265 | return sorted(cls.mapping["task_name_mapping"].keys())
266 |
267 | @classmethod
268 | def list_processors(cls):
269 | return sorted(cls.mapping["processor_name_mapping"].keys())
270 |
271 | @classmethod
272 | def list_lr_schedulers(cls):
273 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274 |
275 | @classmethod
276 | def list_datasets(cls):
277 | return sorted(cls.mapping["builder_name_mapping"].keys())
278 |
279 | @classmethod
280 | def get_path(cls, name):
281 | return cls.mapping["paths"].get(name, None)
282 |
283 | @classmethod
284 | def get(cls, name, default=None, no_warning=False):
285 | r"""Get an item from registry with key 'name'
286 |
287 | Args:
288 | name (string): Key whose value needs to be retrieved.
289 | default: If passed and key is not in registry, default value will
290 | be returned with a warning. Default: None
291 | no_warning (bool): If passed as True, warning when key doesn't exist
292 | will not be generated. Useful for MMF's
293 | internal operations. Default: False
294 | """
295 | original_name = name
296 | name = name.split(".")
297 | value = cls.mapping["state"]
298 | for subname in name:
299 | value = value.get(subname, default)
300 | if value is default:
301 | break
302 |
303 | if (
304 | "writer" in cls.mapping["state"]
305 | and value == default
306 | and no_warning is False
307 | ):
308 | cls.mapping["state"]["writer"].warning(
309 | "Key {} is not present in registry, returning default value "
310 | "of {}".format(original_name, default)
311 | )
312 | return value
313 |
314 | @classmethod
315 | def unregister(cls, name):
316 | r"""Remove an item from registry with key 'name'
317 |
318 | Args:
319 | name: Key which needs to be removed.
320 | Usage::
321 |
322 |             from minigpt4.common.registry import registry
323 |
324 | config = registry.unregister("config")
325 | """
326 | return cls.mapping["state"].pop(name, None)
327 |
328 |
329 | registry = Registry()
330 |
--------------------------------------------------------------------------------
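
The registry is the bridge between the YAML configs and the Python classes: decorators record classes under string names, and the get_*_class helpers resolve those names at build time, after the registration imports have run (e.g. the `from minigpt4.datasets.builders import *` block in backup_app.py below). The following is a hedged usage sketch; "mini_gpt4" and "cc_align" follow this repo's naming conventions but are illustrative lookups here.

# Hedged usage sketch of the registry round trip.
from minigpt4.common.registry import registry
from minigpt4.models import *              # noqa: F401,F403 -- runs @registry.register_model decorators
from minigpt4.datasets.builders import *   # noqa: F401,F403 -- runs @registry.register_builder decorators

registry.register_path("cache_root", "/tmp/minigpt4_cache")  # normally registered once at startup
print(registry.get_path("cache_root"))

model_cls = registry.get_model_class("mini_gpt4")     # class registered via @registry.register_model
builder_cls = registry.get_builder_class("cc_align")  # builder name used in the dataset configs
print(registry.list_models(), registry.list_datasets())
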
/backup_app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import random
5 | import glob
6 | import time
7 |
8 | import numpy as np
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from PIL import Image
12 |
13 | from minigpt4.common.config import Config
14 | from minigpt4.common.dist_utils import get_rank
15 | from minigpt4.common.registry import registry
16 | from minigpt4.conversation.conversation import Chat, CONV_VISION
17 |
18 | # imports modules for registration
19 | from minigpt4.datasets.builders import *
20 | from minigpt4.models import *
21 | from minigpt4.processors import *
22 | from minigpt4.runners import *
23 | from minigpt4.tasks import *
24 |
25 | import cv2
26 | from tqdm import tqdm
27 | from tensorflow.keras.models import load_model
28 | from huggingface_hub import hf_hub_download
29 | from pathlib import Path
30 | from copy import deepcopy
31 | import keras
32 |
33 |
34 | # parsing the arguments
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(description="Combined Demo")
38 | parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
39 | parser.add_argument(
40 | "--options",
41 | nargs="+",
42 |         help="override some settings in the used config; key-value pairs "
43 |         "in xxx=yyy format will be merged into the config file (deprecated; "
44 |         "use --cfg-options instead).",
45 | )
46 | parser.add_argument("--image-folder", type=str, required=True, help="Path to the input image folder.")
47 | parser.add_argument("--model", type=str, default='llama', help="Model to be used for generation. Options: 'llama' (default), 'llama7b'")
48 | parser.add_argument("--beam-search-numbers", type=int, default=1, help="beam search numbers")
49 | parser.add_argument("--model-dir", type=str, required=True, help="Path to the model directory.")
50 | parser.add_argument("--repo-id", type=str, default=DEFAULT_WD14_TAGGER_REPO, help="Hugging Face model repository ID.")
51 | parser.add_argument("--force-download", action="store_true", help="Force download the model.")
52 | parser.add_argument("--general-threshold", type=float, default=0.5, help="Threshold for general tags.")
53 | parser.add_argument("--character-threshold", type=float, default=0.5, help="Threshold for character tags.")
54 | parser.add_argument("--remove-underscore", action="store_true", help="Remove underscores from captions.")
55 | parser.add_argument("--undesired-tags", type=str, default="", help="Comma separated list of undesired tags.")
56 | args = parser.parse_args()
57 | return args
58 |
59 |
60 | # these are functions taken from minigpt4 app.py
61 | def setup_seeds(config):
62 | seed = config.run_cfg.seed + get_rank()
63 |
64 | random.seed(seed)
65 | np.random.seed(seed)
66 | torch.manual_seed(seed)
67 |
68 | cudnn.benchmark = False
69 | cudnn.deterministic = True
70 |
71 |
72 | def describe_image(image_path, chat, chat_state, img, num_beams=1, temperature=1.0):
73 | chat_state = CONV_VISION.copy()
74 | img_list = []
75 |
76 | gr_img = Image.open(image_path)
77 | llm_message = chat.upload_img(gr_img, chat_state, img_list)
78 |
79 | chat.ask("Describe this image.", chat_state)
80 | generated_caption = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=num_beams, temperature=temperature, max_length=2000)[0]
81 |
82 | return generated_caption
83 |
84 |
85 | # these are functions taken from wd_tags.py
86 |
87 | IMAGE_SIZE = 448
88 | IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".bmp",".webp"]
89 |
90 | # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
91 | DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
92 | FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
93 | SUB_DIR = "variables"
94 | SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
95 | CSV_FILE = FILES[-1]
96 |
97 | def glob_images_pathlib(dir_path, recursive):
98 | image_paths = []
99 | if recursive:
100 | for ext in IMAGE_EXTENSIONS:
101 | image_paths += list(dir_path.rglob("*" + ext))
102 | else:
103 | for ext in IMAGE_EXTENSIONS:
104 | image_paths += list(dir_path.glob("*" + ext))
105 | image_paths = list(set(image_paths)) # Remove duplicates
106 | image_paths.sort()
107 | return image_paths
108 |
109 |
110 | def preprocess_image(image):
111 | image = np.array(image)
112 | image = image[:, :, ::-1] # RGB->BGR
113 |
114 | # pad to square
115 | size = max(image.shape[0:2])
116 | pad_x = size - image.shape[1]
117 | pad_y = size - image.shape[0]
118 | pad_l = pad_x // 2
119 | pad_t = pad_y // 2
120 | image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
121 |
122 | interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
123 | image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
124 |
125 | image = image.astype(np.float32)
126 | return image
127 |
128 |
129 | class ImageLoadingPrepDataset(torch.utils.data.Dataset):
130 | def __init__(self, image_paths):
131 | self.images = image_paths
132 |
133 | def __len__(self):
134 | return len(self.images)
135 |
136 | def __getitem__(self, idx):
137 | img_path = str(self.images[idx])
138 |
139 | try:
140 | image = Image.open(img_path).convert("RGB")
141 | image = preprocess_image(image)
142 | tensor = torch.tensor(image)
143 | except Exception as e:
144 |             print(f"Could not load image path: {img_path}, error: {e}")
145 | return None
146 |
147 | return (tensor, img_path)
148 |
149 |
150 | def collate_fn_remove_corrupted(batch):
151 |     """Collate function that removes corrupted examples from the
152 |     batch. It expects that the dataset returns 'None' when an example is corrupted.
153 | The 'None's in the batch are removed.
154 | """
155 | # Filter out all the Nones (corrupted examples)
156 | batch = list(filter(lambda x: x is not None, batch))
157 | return batch
158 |
159 | def run_batch(images, model, args):
160 | # define the tags
161 | with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
162 | reader = csv.reader(f)
163 | l = [row for row in reader]
164 | header = l[0] # tag_id, name, category, count
165 | rows = l[1:]
166 | assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
167 | general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
168 | character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
169 |
170 | undesired_tags = set(args.undesired_tags.split(","))
171 |
172 | # Process images to generate captions
173 | probs = model(np.array(images), training=False)
174 | captions = []
175 | for batch_probs in probs.numpy():
176 | tag_text = ""
177 | for i, p in enumerate(batch_probs[4:]):
178 | if i < len(general_tags) and p >= args.general_threshold:
179 | tag_name = general_tags[i]
180 | tag_name = tag_name if not args.remove_underscore or len(tag_name) <= 3 else tag_name.replace("_", " ")
181 | if tag_name not in undesired_tags:
182 | tag_text += ", " + tag_name
183 | elif i >= len(general_tags) and p >= args.character_threshold:
184 | tag_name = character_tags[i - len(general_tags)]
185 | tag_name = tag_name if not args.remove_underscore or len(tag_name) <= 3 else tag_name.replace("_", " ")
186 | if tag_name not in undesired_tags:
187 | tag_text += ", " + tag_name
188 | tag_text = tag_text[2:] if len(tag_text) > 0 else ''
189 | captions.append(tag_text)
190 | return captions
191 |
192 | def wd_pass(image_paths, model, args):
193 | # Preprocess the image
194 | captions = []
195 | for image_path in image_paths:
196 | image = Image.open(image_path)
197 | image = preprocess_image(image)
198 | captions.append(run_batch([image], model, args))
199 | return captions
200 |
201 | def main():
202 | args = parse_args()
203 |
204 | # check for the model
205 | if not os.path.exists(args.model_dir) or args.force_download:
206 | print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
207 | for file in FILES:
208 | hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True)
209 |
210 | for file in SUB_DIR_FILES:
211 | hf_hub_download(
212 | args.repo_id,
213 | file,
214 | subfolder=SUB_DIR,
215 | cache_dir=os.path.join(args.model_dir, SUB_DIR),
216 | force_download=True,
217 | )
218 |
219 | cfg = Config(args)
220 | model_config = cfg.model_cfg
221 |
222 | model_cls = registry.get_model_class(model_config.arch)
223 | model = model_cls.from_config(model_config)
224 |
225 | model = model.to(torch.device('cuda'))
226 |
227 | vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
228 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
229 | chat = Chat(model, vis_processor)
230 |
231 | chat_state = deepcopy(CONV_VISION)
232 | img_list = []
233 |
234 | image_folder = args.image_folder
235 | num_beams = args.beam_search_numbers
236 | temperature = 1.0 # default temperature
237 |
238 | image_extensions = ['jpg', 'jpeg', 'png', 'bmp']
239 | image_paths = []
240 |
241 | for ext in image_extensions:
242 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext}')))
243 | image_paths.extend(glob.glob(os.path.join(image_folder, f'*.{ext.upper()}')))
244 |
245 | if not os.path.exists("mycaptions"):
246 | os.makedirs("mycaptions")
247 |
248 | for image_path in image_paths:
249 | start_time = time.time()
250 | caption = describe_image(image_path, chat, chat_state, img_list, num_beams, temperature)
251 |
252 | with open("mycaptions/{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0]), "w") as f:
253 | f.write(caption)
254 |
255 | end_time = time.time()
256 | time_taken = end_time - start_time
257 | print(f"Caption for {os.path.basename(image_path)} saved in 'mycaptions' folder")
258 | print(f"Time taken to process caption for {os.path.basename(image_path)} is: {time_taken:.2f} s")
259 |
260 | del model # Unload pytorch model from memory
261 | torch.cuda.empty_cache()
262 |
263 | # Load Keras model
264 | keras.backend.clear_session()
265 | model = load_model(args.model_dir)
266 |
267 | wd_captions = wd_pass(image_paths, model, args)
268 |
269 | for image_path, wd_caption in zip(image_paths, wd_captions):
270 | wd_caption = wd_caption[0]
271 | with open("mycaptions/{}_caption.txt".format(os.path.splitext(os.path.basename(image_path))[0]), "a") as f:
272 | f.write(str(wd_caption))
273 |
274 | del model # Unload keras model from memory
275 | keras.backend.clear_session()
276 |
277 | if __name__ == '__main__':
278 | main()
--------------------------------------------------------------------------------
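
As a usage sketch, the script above is driven entirely from the command line; --image-folder and --model-dir are the only required arguments, and the folder names and thresholds below are placeholders.

# Hypothetical invocation sketch (equivalent to running the script from a shell); values are placeholders.
import subprocess

subprocess.run([
    "python", "backup_app.py",
    "--image-folder", "images",           # folder of images to caption
    "--model-dir", "wd14_tagger_model",   # where the wd14 tagger model is cached
    "--general-threshold", "0.35",
    "--remove-underscore",
], check=True)
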
/minigpt4/models/mini_gpt4.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import logging
8 | import random
9 | import os
10 | import torch
11 | from torch.cuda.amp import autocast as autocast
12 | import torch.nn as nn
13 |
14 | from minigpt4.common.registry import registry
15 | from minigpt4.models.blip2 import Blip2Base, disabled_train
16 | from minigpt4.models.modeling_llama import LlamaForCausalLM
17 | from transformers import LlamaTokenizer
18 |
19 |
20 | @registry.register_model("mini_gpt4")
21 | class MiniGPT4(Blip2Base):
22 | """
23 | BLIP2 GPT-LLAMA model.
24 | """
25 |
26 | PRETRAINED_MODEL_CONFIG_DICT = {
27 | "pretrain_vicuna": "configs/models/minigpt4.yaml",
28 | }
29 |
30 | def __init__(
31 | self,
32 | vit_model="eva_clip_g",
33 | q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
34 | img_size=224,
35 | drop_path_rate=0,
36 | use_grad_checkpoint=False,
37 | vit_precision="fp16",
38 | freeze_vit=True,
39 | freeze_qformer=True,
40 | num_query_token=32,
41 | llama_model="",
42 | llama_cache_dir='',
43 | prompt_path="",
44 | prompt_template="",
45 | max_txt_len=32,
46 | end_sym='\n',
47 | ):
48 | super().__init__()
49 |
50 | self.tokenizer = self.init_tokenizer()
51 |
52 | print('Loading VIT')
53 | self.visual_encoder, self.ln_vision = self.init_vision_encoder(
54 | vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
55 | )
56 | if freeze_vit:
57 | for name, param in self.visual_encoder.named_parameters():
58 | param.requires_grad = False
59 | self.visual_encoder = self.visual_encoder.eval()
60 | self.visual_encoder.train = disabled_train
61 | for name, param in self.ln_vision.named_parameters():
62 | param.requires_grad = False
63 | self.ln_vision = self.ln_vision.eval()
64 | self.ln_vision.train = disabled_train
65 | logging.info("freeze vision encoder")
66 | print('Loading VIT Done')
67 |
68 | print('Loading Q-Former')
69 | self.Qformer, self.query_tokens = self.init_Qformer(
70 | num_query_token, self.visual_encoder.num_features
71 | )
72 | self.Qformer.cls = None
73 | self.Qformer.bert.embeddings.word_embeddings = None
74 | self.Qformer.bert.embeddings.position_embeddings = None
75 | for layer in self.Qformer.bert.encoder.layer:
76 | layer.output = None
77 | layer.intermediate = None
78 | self.load_from_pretrained(url_or_filename=q_former_model)
79 |
80 | if freeze_qformer:
81 | for name, param in self.Qformer.named_parameters():
82 | param.requires_grad = False
83 | self.Qformer = self.Qformer.eval()
84 | self.Qformer.train = disabled_train
85 | self.query_tokens.requires_grad = False
86 | logging.info("freeze Qformer")
87 | print('Loading Q-Former Done')
88 |
89 | print('Loading LLAMA')
90 | self.llama_tokenizer = LlamaTokenizer.from_pretrained('camenduru/MiniGPT4-7B', use_fast=False)
91 | self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
92 |
93 | if llama_cache_dir:
94 | self.llama_model = LlamaForCausalLM.from_pretrained(
95 | 'camenduru/MiniGPT4-7B', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto"
96 | )
97 | else:
98 | self.llama_model = LlamaForCausalLM.from_pretrained(
99 | 'camenduru/MiniGPT4-7B', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto"
100 | )
101 | for name, param in self.llama_model.named_parameters():
102 | param.requires_grad = False
103 | print('Loading LLAMA Done')
104 |
105 | self.llama_proj = nn.Linear(
106 | self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
107 | )
108 | self.max_txt_len = max_txt_len
109 | self.end_sym = end_sym
110 |
111 | if prompt_path:
112 | with open(prompt_path, 'r') as f:
113 | raw_prompts = f.read().splitlines()
114 |             filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
115 |             self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
116 | print('Load {} training prompts'.format(len(self.prompt_list)))
117 | print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
118 | else:
119 | self.prompt_list = []
120 |
121 | def vit_to_cpu(self):
122 | self.ln_vision.to("cpu")
123 | self.ln_vision.float()
124 | self.visual_encoder.to("cpu")
125 | self.visual_encoder.float()
126 |
127 | def encode_img(self, image):
128 | device = image.device
129 | self.vit_to_cpu()
130 | image = image.to("cpu")
131 | with self.maybe_autocast():
132 | image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
133 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
134 |
135 | query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
136 | query_output = self.Qformer.bert(
137 | query_embeds=query_tokens,
138 | encoder_hidden_states=image_embeds,
139 | encoder_attention_mask=image_atts,
140 | return_dict=True,
141 | )
142 |
143 | inputs_llama = self.llama_proj(query_output.last_hidden_state)
144 | atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
145 | return inputs_llama, atts_llama
146 |
147 | def prompt_wrap(self, img_embeds, atts_img, prompt):
148 | if prompt:
149 | batch_size = img_embeds.shape[0]
150 |             p_before, p_after = prompt.split('<ImageHere>')
151 | p_before_tokens = self.llama_tokenizer(
152 | p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
153 | p_after_tokens = self.llama_tokenizer(
154 | p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
155 | p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
156 | p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
157 | wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
158 | wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
159 | return wrapped_img_embeds, wrapped_atts_img
160 | else:
161 | return img_embeds, atts_img
162 |
163 | def forward(self, samples):
164 | image = samples["image"]
165 | img_embeds, atts_img = self.encode_img(image)
166 | if hasattr(samples, 'question_split'): # VQA dataset
167 | print('VQA Batch')
168 |             vqa_prompt = '###Human: <Img><ImageHere></Img> '
169 | img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
170 | elif self.prompt_list:
171 | prompt = random.choice(self.prompt_list)
172 | img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
173 |
174 | self.llama_tokenizer.padding_side = "right"
175 |
176 | text = [t + self.end_sym for t in samples["text_input"]]
177 |
178 | to_regress_tokens = self.llama_tokenizer(
179 | text,
180 | return_tensors="pt",
181 | padding="longest",
182 | truncation=True,
183 | max_length=self.max_txt_len,
184 | add_special_tokens=False
185 | ).to(image.device)
186 |
187 | targets = to_regress_tokens.input_ids.masked_fill(
188 | to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
189 | )
190 |
191 | empty_targets = (
192 | torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
193 | dtype=torch.long).to(image.device).fill_(-100) # plus one for bos
194 | )
195 | targets = torch.cat([empty_targets, targets], dim=1)
196 |
197 | batch_size = img_embeds.shape[0]
198 | bos = torch.ones([batch_size, 1],
199 | dtype=to_regress_tokens.input_ids.dtype,
200 | device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
201 | bos_embeds = self.llama_model.model.embed_tokens(bos)
202 | atts_bos = atts_img[:, :1]
203 |
204 | to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
205 | inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
206 | attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
207 |
208 | with self.maybe_autocast():
209 | outputs = self.llama_model(
210 | inputs_embeds=inputs_embeds,
211 | attention_mask=attention_mask,
212 | return_dict=True,
213 | labels=targets,
214 | )
215 | loss = outputs.loss
216 |
217 | return {"loss": loss}
218 |
219 | @classmethod
220 | def from_config(cls, cfg):
221 | vit_model = cfg.get("vit_model", "eva_clip_g")
222 | q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
223 | img_size = cfg.get("image_size")
224 | num_query_token = cfg.get("num_query_token")
225 | llama_model = cfg.get("llama_model")
226 |
227 | drop_path_rate = cfg.get("drop_path_rate", 0)
228 | use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
229 | vit_precision = cfg.get("vit_precision", "fp16")
230 | freeze_vit = cfg.get("freeze_vit", True)
231 | freeze_qformer = cfg.get("freeze_qformer", True)
232 | llama_cache_dir = cfg.get("llama_cache_dir", "")
233 |
234 | prompt_path = cfg.get("prompt_path", "")
235 | prompt_template = cfg.get("prompt_template", "")
236 | max_txt_len = cfg.get("max_txt_len", 32)
237 | end_sym = cfg.get("end_sym", '\n')
238 |
239 | model = cls(
240 | vit_model=vit_model,
241 | q_former_model=q_former_model,
242 | img_size=img_size,
243 | drop_path_rate=drop_path_rate,
244 | use_grad_checkpoint=use_grad_checkpoint,
245 | vit_precision=vit_precision,
246 | freeze_vit=freeze_vit,
247 | freeze_qformer=freeze_qformer,
248 | llama_cache_dir=llama_cache_dir,
249 | num_query_token=num_query_token,
250 | llama_model=llama_model,
251 | prompt_path=prompt_path,
252 | prompt_template=prompt_template,
253 | max_txt_len=max_txt_len,
254 | end_sym=end_sym
255 | )
256 |
257 | ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
258 | if ckpt_path:
259 | print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
260 | ckpt = torch.load(ckpt_path, map_location="cpu")
261 | msg = model.load_state_dict(ckpt['model'], strict=False)
262 |
263 | return model
264 |
--------------------------------------------------------------------------------
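
from_config() above pulls its options from the model section of the YAML config (configs/models/minigpt4.yaml in the tree). The snippet below is a hedged sketch of driving it programmatically with OmegaConf; the concrete values are placeholders chosen to match the keys from_config reads, not the repo's actual config.

# Hedged sketch; values are placeholders matching the keys from_config() reads.
from omegaconf import OmegaConf

from minigpt4.models.mini_gpt4 import MiniGPT4

cfg = OmegaConf.create({
    "arch": "mini_gpt4",
    "image_size": 224,                 # passed to the constructor as img_size
    "num_query_token": 32,
    "llama_model": "",                 # the class currently hardcodes 'camenduru/MiniGPT4-7B'
    "freeze_vit": True,
    "freeze_qformer": True,
    "prompt_path": "prompts/alignment.txt",
    "prompt_template": "###Human: {} ###Assistant: ",
    "max_txt_len": 160,
    "end_sym": "###",
    "ckpt": "",                        # path to a MiniGPT-4 checkpoint, if available
})

# model = MiniGPT4.from_config(cfg)    # downloads ViT / Q-Former / LLAMA weights on first run
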
/minigpt4/processors/randaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import cv2
9 | import numpy as np
10 |
11 | import torch
12 |
13 |
14 | ## aug functions
15 | def identity_func(img):
16 | return img
17 |
18 |
19 | def autocontrast_func(img, cutoff=0):
20 | """
21 | same output as PIL.ImageOps.autocontrast
22 | """
23 | n_bins = 256
24 |
25 | def tune_channel(ch):
26 | n = ch.size
27 | cut = cutoff * n // 100
28 | if cut == 0:
29 | high, low = ch.max(), ch.min()
30 | else:
31 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
32 | low = np.argwhere(np.cumsum(hist) > cut)
33 | low = 0 if low.shape[0] == 0 else low[0]
34 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
35 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
36 | if high <= low:
37 | table = np.arange(n_bins)
38 | else:
39 | scale = (n_bins - 1) / (high - low)
40 | offset = -low * scale
41 | table = np.arange(n_bins) * scale + offset
42 | table[table < 0] = 0
43 | table[table > n_bins - 1] = n_bins - 1
44 | table = table.clip(0, 255).astype(np.uint8)
45 | return table[ch]
46 |
47 | channels = [tune_channel(ch) for ch in cv2.split(img)]
48 | out = cv2.merge(channels)
49 | return out
50 |
51 |
52 | def equalize_func(img):
53 | """
54 | same output as PIL.ImageOps.equalize
55 | PIL's implementation is different from cv2.equalize
56 | """
57 | n_bins = 256
58 |
59 | def tune_channel(ch):
60 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
61 | non_zero_hist = hist[hist != 0].reshape(-1)
62 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
63 | if step == 0:
64 | return ch
65 | n = np.empty_like(hist)
66 | n[0] = step // 2
67 | n[1:] = hist[:-1]
68 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
69 | return table[ch]
70 |
71 | channels = [tune_channel(ch) for ch in cv2.split(img)]
72 | out = cv2.merge(channels)
73 | return out
74 |
75 |
76 | def rotate_func(img, degree, fill=(0, 0, 0)):
77 | """
78 | like PIL, rotate by degree, not radians
79 | """
80 | H, W = img.shape[0], img.shape[1]
81 | center = W / 2, H / 2
82 | M = cv2.getRotationMatrix2D(center, degree, 1)
83 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
84 | return out
85 |
86 |
87 | def solarize_func(img, thresh=128):
88 | """
89 |     same output as PIL.ImageOps.solarize
90 | """
91 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
92 | table = table.clip(0, 255).astype(np.uint8)
93 | out = table[img]
94 | return out
95 |
96 |
97 | def color_func(img, factor):
98 | """
99 | same output as PIL.ImageEnhance.Color
100 | """
101 | ## implementation according to PIL definition, quite slow
102 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
103 | # out = blend(degenerate, img, factor)
104 | # M = (
105 | # np.eye(3) * factor
106 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
107 | # )[np.newaxis, np.newaxis, :]
108 | M = np.float32(
109 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
110 | ) * factor + np.float32([[0.114], [0.587], [0.299]])
111 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
112 | return out
113 |
114 |
115 | def contrast_func(img, factor):
116 | """
117 | same output as PIL.ImageEnhance.Contrast
118 | """
119 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
120 | table = (
121 | np.array([(el - mean) * factor + mean for el in range(256)])
122 | .clip(0, 255)
123 | .astype(np.uint8)
124 | )
125 | out = table[img]
126 | return out
127 |
128 |
129 | def brightness_func(img, factor):
130 | """
131 |     same output as PIL.ImageEnhance.Brightness
132 | """
133 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
134 | out = table[img]
135 | return out
136 |
137 |
138 | def sharpness_func(img, factor):
139 | """
140 |     The differences between this result and PIL are all on the 4 boundaries; the center
141 |     areas are the same.
142 | """
143 | kernel = np.ones((3, 3), dtype=np.float32)
144 | kernel[1][1] = 5
145 | kernel /= 13
146 | degenerate = cv2.filter2D(img, -1, kernel)
147 | if factor == 0.0:
148 | out = degenerate
149 | elif factor == 1.0:
150 | out = img
151 | else:
152 | out = img.astype(np.float32)
153 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
154 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
155 | out = out.astype(np.uint8)
156 | return out
157 |
158 |
159 | def shear_x_func(img, factor, fill=(0, 0, 0)):
160 | H, W = img.shape[0], img.shape[1]
161 | M = np.float32([[1, factor, 0], [0, 1, 0]])
162 | out = cv2.warpAffine(
163 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
164 | ).astype(np.uint8)
165 | return out
166 |
167 |
168 | def translate_x_func(img, offset, fill=(0, 0, 0)):
169 | """
170 | same output as PIL.Image.transform
171 | """
172 | H, W = img.shape[0], img.shape[1]
173 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
174 | out = cv2.warpAffine(
175 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
176 | ).astype(np.uint8)
177 | return out
178 |
179 |
180 | def translate_y_func(img, offset, fill=(0, 0, 0)):
181 | """
182 | same output as PIL.Image.transform
183 | """
184 | H, W = img.shape[0], img.shape[1]
185 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
186 | out = cv2.warpAffine(
187 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
188 | ).astype(np.uint8)
189 | return out
190 |
191 |
192 | def posterize_func(img, bits):
193 | """
194 | same output as PIL.ImageOps.posterize
195 | """
196 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
197 | return out
198 |
199 |
200 | def shear_y_func(img, factor, fill=(0, 0, 0)):
201 | H, W = img.shape[0], img.shape[1]
202 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
203 | out = cv2.warpAffine(
204 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
205 | ).astype(np.uint8)
206 | return out
207 |
208 |
209 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
210 | replace = np.array(replace, dtype=np.uint8)
211 | H, W = img.shape[0], img.shape[1]
212 | rh, rw = np.random.random(2)
213 | pad_size = pad_size // 2
214 | ch, cw = int(rh * H), int(rw * W)
215 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
216 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
217 | out = img.copy()
218 | out[x1:x2, y1:y2, :] = replace
219 | return out
220 |
221 |
222 | ### level to args
223 | def enhance_level_to_args(MAX_LEVEL):
224 | def level_to_args(level):
225 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
226 |
227 | return level_to_args
228 |
229 |
230 | def shear_level_to_args(MAX_LEVEL, replace_value):
231 | def level_to_args(level):
232 | level = (level / MAX_LEVEL) * 0.3
233 | if np.random.random() > 0.5:
234 | level = -level
235 | return (level, replace_value)
236 |
237 | return level_to_args
238 |
239 |
240 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
241 | def level_to_args(level):
242 | level = (level / MAX_LEVEL) * float(translate_const)
243 | if np.random.random() > 0.5:
244 | level = -level
245 | return (level, replace_value)
246 |
247 | return level_to_args
248 |
249 |
250 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
251 | def level_to_args(level):
252 | level = int((level / MAX_LEVEL) * cutout_const)
253 | return (level, replace_value)
254 |
255 | return level_to_args
256 |
257 |
258 | def solarize_level_to_args(MAX_LEVEL):
259 | def level_to_args(level):
260 | level = int((level / MAX_LEVEL) * 256)
261 | return (level,)
262 |
263 | return level_to_args
264 |
265 |
266 | def none_level_to_args(level):
267 | return ()
268 |
269 |
270 | def posterize_level_to_args(MAX_LEVEL):
271 | def level_to_args(level):
272 | level = int((level / MAX_LEVEL) * 4)
273 | return (level,)
274 |
275 | return level_to_args
276 |
277 |
278 | def rotate_level_to_args(MAX_LEVEL, replace_value):
279 | def level_to_args(level):
280 | level = (level / MAX_LEVEL) * 30
281 | if np.random.random() < 0.5:
282 | level = -level
283 | return (level, replace_value)
284 |
285 | return level_to_args
286 |
287 |
288 | func_dict = {
289 | "Identity": identity_func,
290 | "AutoContrast": autocontrast_func,
291 | "Equalize": equalize_func,
292 | "Rotate": rotate_func,
293 | "Solarize": solarize_func,
294 | "Color": color_func,
295 | "Contrast": contrast_func,
296 | "Brightness": brightness_func,
297 | "Sharpness": sharpness_func,
298 | "ShearX": shear_x_func,
299 | "TranslateX": translate_x_func,
300 | "TranslateY": translate_y_func,
301 | "Posterize": posterize_func,
302 | "ShearY": shear_y_func,
303 | }
304 |
305 | translate_const = 10
306 | MAX_LEVEL = 10
307 | replace_value = (128, 128, 128)
308 | arg_dict = {
309 | "Identity": none_level_to_args,
310 | "AutoContrast": none_level_to_args,
311 | "Equalize": none_level_to_args,
312 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
313 | "Solarize": solarize_level_to_args(MAX_LEVEL),
314 | "Color": enhance_level_to_args(MAX_LEVEL),
315 | "Contrast": enhance_level_to_args(MAX_LEVEL),
316 | "Brightness": enhance_level_to_args(MAX_LEVEL),
317 | "Sharpness": enhance_level_to_args(MAX_LEVEL),
318 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
319 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
320 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
321 | "Posterize": posterize_level_to_args(MAX_LEVEL),
322 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
323 | }
324 |
325 |
326 | class RandomAugment(object):
327 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
328 | self.N = N
329 | self.M = M
330 | self.isPIL = isPIL
331 | if augs:
332 | self.augs = augs
333 | else:
334 | self.augs = list(arg_dict.keys())
335 |
336 | def get_random_ops(self):
337 | sampled_ops = np.random.choice(self.augs, self.N)
338 | return [(op, 0.5, self.M) for op in sampled_ops]
339 |
340 | def __call__(self, img):
341 | if self.isPIL:
342 | img = np.array(img)
343 | ops = self.get_random_ops()
344 | for name, prob, level in ops:
345 | if np.random.random() > prob:
346 | continue
347 | args = arg_dict[name](level)
348 | img = func_dict[name](img, *args)
349 | return img
350 |
351 |
352 | class VideoRandomAugment(object):
353 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
354 | self.N = N
355 | self.M = M
356 | self.p = p
357 | self.tensor_in_tensor_out = tensor_in_tensor_out
358 | if augs:
359 | self.augs = augs
360 | else:
361 | self.augs = list(arg_dict.keys())
362 |
363 | def get_random_ops(self):
364 | sampled_ops = np.random.choice(self.augs, self.N, replace=False)
365 | return [(op, self.M) for op in sampled_ops]
366 |
367 | def __call__(self, frames):
368 | assert (
369 | frames.shape[-1] == 3
370 | ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
371 |
372 | if self.tensor_in_tensor_out:
373 | frames = frames.numpy().astype(np.uint8)
374 |
375 | num_frames = frames.shape[0]
376 |
377 | ops = num_frames * [self.get_random_ops()]
378 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
379 |
380 | frames = torch.stack(
381 | list(map(self._aug, frames, ops, apply_or_not)), dim=0
382 | ).float()
383 |
384 | return frames
385 |
386 | def _aug(self, img, ops, apply_or_not):
387 | for i, (name, level) in enumerate(ops):
388 | if not apply_or_not[i]:
389 | continue
390 | args = arg_dict[name](level)
391 | img = func_dict[name](img, *args)
392 | return torch.from_numpy(img)
393 |
394 |
395 | if __name__ == "__main__":
396 | a = RandomAugment()
397 | img = np.random.randn(32, 32, 3)
398 | a(img)
399 |
--------------------------------------------------------------------------------
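
RandomAugment is designed to slot into a torchvision transform pipeline before tensor conversion. The sketch below is a hypothetical pipeline: the augmentation subset and crop size are illustrative assumptions, and blip_processors.py defines the processors this repo actually uses.

# Hypothetical pipeline sketch; augmentation subset and sizes are illustrative.
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from minigpt4.processors.randaugment import RandomAugment

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    RandomAugment(2, 5, isPIL=True,
                  augs=["Identity", "AutoContrast", "Brightness", "Sharpness", "Equalize"]),
    transforms.ToTensor(),  # RandomAugment returns a uint8 numpy array, which ToTensor accepts
])

# tensor = train_transform(Image.open("images/example.png").convert("RGB"))
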
/minigpt4/common/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import io
9 | import json
10 | import logging
11 | import os
12 | import pickle
13 | import re
14 | import shutil
15 | import urllib
16 | import urllib.error
17 | import urllib.request
18 | from typing import Optional
19 | from urllib.parse import urlparse
20 |
21 | import numpy as np
22 | import pandas as pd
23 | import yaml
24 | from iopath.common.download import download
25 | from iopath.common.file_io import file_lock, g_pathmgr
26 | from minigpt4.common.registry import registry
27 | from torch.utils.model_zoo import tqdm
28 | from torchvision.datasets.utils import (
29 | check_integrity,
30 | download_file_from_google_drive,
31 | extract_archive,
32 | )
33 |
34 |
35 | def now():
36 | from datetime import datetime
37 |
38 | return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39 |
40 |
41 | def is_url(url_or_filename):
42 | parsed = urlparse(url_or_filename)
43 | return parsed.scheme in ("http", "https")
44 |
45 |
46 | def get_cache_path(rel_path):
47 | return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48 |
49 |
50 | def get_abs_path(rel_path):
51 | return os.path.join(registry.get_path("library_root"), rel_path)
52 |
53 |
54 | def load_json(filename):
55 | with open(filename, "r") as f:
56 | return json.load(f)
57 |
58 |
59 | # The following are adapted from torchvision and vissl
60 | # torchvision: https://github.com/pytorch/vision
61 | # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62 |
63 |
64 | def makedir(dir_path):
65 | """
66 | Create the directory if it does not exist.
67 | """
68 | is_success = False
69 | try:
70 | if not g_pathmgr.exists(dir_path):
71 | g_pathmgr.mkdirs(dir_path)
72 | is_success = True
73 | except BaseException:
74 | print(f"Error creating directory: {dir_path}")
75 | return is_success
76 |
77 |
78 | def get_redirected_url(url: str):
79 | """
80 |     Given a URL, returns the URL it redirects to, or the
81 |     original URL if there is no redirection.
82 | """
83 | import requests
84 |
85 | with requests.Session() as session:
86 | with session.get(url, stream=True, allow_redirects=True) as response:
87 | if response.history:
88 | return response.url
89 | else:
90 | return url
91 |
92 |
93 | def to_google_drive_download_url(view_url: str) -> str:
94 | """
95 | Utility function to transform a view URL of google drive
96 | to a download URL for google drive
97 | Example input:
98 | https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99 | Example output:
100 | https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101 | """
102 | splits = view_url.split("/")
103 | assert splits[-1] == "view"
104 | file_id = splits[-2]
105 | return f"https://drive.google.com/uc?export=download&id={file_id}"
106 |
107 |
108 | def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109 | """
110 |     Download a file from google drive.
111 |     Downloading a URL from google drive requires confirmation when
112 |     the file is too big (google drive notifies that
113 |     anti-virus checks cannot be performed on such files).
114 | """
115 | import requests
116 |
117 | with requests.Session() as session:
118 |
119 | # First get the confirmation token and append it to the URL
120 | with session.get(url, stream=True, allow_redirects=True) as response:
121 | for k, v in response.cookies.items():
122 | if k.startswith("download_warning"):
123 | url = url + "&confirm=" + v
124 |
125 | # Then download the content of the file
126 | with session.get(url, stream=True, verify=True) as response:
127 | makedir(output_path)
128 | path = os.path.join(output_path, output_file_name)
129 | total_size = int(response.headers.get("Content-length", 0))
130 | with open(path, "wb") as file:
131 | from tqdm import tqdm
132 |
133 | with tqdm(total=total_size) as progress_bar:
134 | for block in response.iter_content(
135 | chunk_size=io.DEFAULT_BUFFER_SIZE
136 | ):
137 | file.write(block)
138 | progress_bar.update(len(block))
139 |
140 |
141 | def _get_google_drive_file_id(url: str) -> Optional[str]:
142 | parts = urlparse(url)
143 |
144 | if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145 | return None
146 |
147 |     match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148 | if match is None:
149 | return None
150 |
151 | return match.group("id")
152 |
153 |
154 | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155 | with open(filename, "wb") as fh:
156 | with urllib.request.urlopen(
157 | urllib.request.Request(url, headers={"User-Agent": "vissl"})
158 | ) as response:
159 | with tqdm(total=response.length) as pbar:
160 | for chunk in iter(lambda: response.read(chunk_size), ""):
161 | if not chunk:
162 | break
163 | pbar.update(chunk_size)
164 | fh.write(chunk)
165 |
166 |
167 | def download_url(
168 | url: str,
169 | root: str,
170 | filename: Optional[str] = None,
171 | md5: Optional[str] = None,
172 | ) -> None:
173 | """Download a file from a url and place it in root.
174 | Args:
175 | url (str): URL to download file from
176 | root (str): Directory to place downloaded file in
177 | filename (str, optional): Name to save the file under.
178 | If None, use the basename of the URL.
179 | md5 (str, optional): MD5 checksum of the download. If None, do not check
180 | """
181 | root = os.path.expanduser(root)
182 | if not filename:
183 | filename = os.path.basename(url)
184 | fpath = os.path.join(root, filename)
185 |
186 | makedir(root)
187 |
188 | # check if file is already present locally
189 | if check_integrity(fpath, md5):
190 | print("Using downloaded and verified file: " + fpath)
191 | return
192 |
193 | # expand redirect chain if needed
194 | url = get_redirected_url(url)
195 |
196 | # check if file is located on Google Drive
197 | file_id = _get_google_drive_file_id(url)
198 | if file_id is not None:
199 | return download_file_from_google_drive(file_id, root, filename, md5)
200 |
201 | # download the file
202 | try:
203 | print("Downloading " + url + " to " + fpath)
204 | _urlretrieve(url, fpath)
205 | except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206 | if url[:5] == "https":
207 | url = url.replace("https:", "http:")
208 | print(
209 | "Failed download. Trying https -> http instead."
210 | " Downloading " + url + " to " + fpath
211 | )
212 | _urlretrieve(url, fpath)
213 | else:
214 | raise e
215 |
216 | # check integrity of downloaded file
217 | if not check_integrity(fpath, md5):
218 | raise RuntimeError("File not found or corrupted.")
219 |
220 |
221 | def download_and_extract_archive(
222 | url: str,
223 | download_root: str,
224 | extract_root: Optional[str] = None,
225 | filename: Optional[str] = None,
226 | md5: Optional[str] = None,
227 | remove_finished: bool = False,
228 | ) -> None:
229 | download_root = os.path.expanduser(download_root)
230 | if extract_root is None:
231 | extract_root = download_root
232 | if not filename:
233 | filename = os.path.basename(url)
234 |
235 | download_url(url, download_root, filename, md5)
236 |
237 | archive = os.path.join(download_root, filename)
238 | print("Extracting {} to {}".format(archive, extract_root))
239 | extract_archive(archive, extract_root, remove_finished)
240 |
241 |
242 | def cache_url(url: str, cache_dir: str) -> str:
243 | """
244 | This implementation downloads the remote resource and caches it locally.
245 | The resource will only be downloaded if not previously requested.
246 | """
247 | parsed_url = urlparse(url)
248 | dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249 | makedir(dirname)
250 | filename = url.split("/")[-1]
251 | cached = os.path.join(dirname, filename)
252 | with file_lock(cached):
253 | if not os.path.isfile(cached):
254 | logging.info(f"Downloading {url} to {cached} ...")
255 | cached = download(url, dirname, filename=filename)
256 | logging.info(f"URL {url} cached in {cached}")
257 | return cached
258 |
259 |
260 | # TODO (prigoyal): convert this into RAII-style API
261 | def create_file_symlink(file1, file2):
262 | """
263 |     Simply create a symlink from a given file1 to file2.
264 |     Useful during model checkpointing to symlink to the
265 |     latest successful checkpoint.
266 | """
267 | try:
268 | if g_pathmgr.exists(file2):
269 | g_pathmgr.rm(file2)
270 | g_pathmgr.symlink(file1, file2)
271 | except Exception as e:
272 | logging.info(f"Could NOT create symlink. Error: {e}")
273 |
274 |
275 | def save_file(data, filename, append_to_json=True, verbose=True):
276 | """
277 | Common i/o utility to handle saving data to various file formats.
278 | Supported:
279 |     .pkl, .pickle, .npy, .json, .yaml
280 |     Specifically for .json, users have the option to either append (default)
281 |     or rewrite by passing a Boolean value to append_to_json.
282 | """
283 | if verbose:
284 | logging.info(f"Saving data to file: {filename}")
285 | file_ext = os.path.splitext(filename)[1]
286 | if file_ext in [".pkl", ".pickle"]:
287 | with g_pathmgr.open(filename, "wb") as fopen:
288 | pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289 | elif file_ext == ".npy":
290 | with g_pathmgr.open(filename, "wb") as fopen:
291 | np.save(fopen, data)
292 | elif file_ext == ".json":
293 | if append_to_json:
294 | with g_pathmgr.open(filename, "a") as fopen:
295 | fopen.write(json.dumps(data, sort_keys=True) + "\n")
296 | fopen.flush()
297 | else:
298 | with g_pathmgr.open(filename, "w") as fopen:
299 | fopen.write(json.dumps(data, sort_keys=True) + "\n")
300 | fopen.flush()
301 | elif file_ext == ".yaml":
302 | with g_pathmgr.open(filename, "w") as fopen:
303 | dump = yaml.dump(data)
304 | fopen.write(dump)
305 | fopen.flush()
306 | else:
307 | raise Exception(f"Saving {file_ext} is not supported yet")
308 |
309 | if verbose:
310 | logging.info(f"Saved data to file: {filename}")
311 |
312 |
313 | def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314 | """
315 | Common i/o utility to handle loading data from various file formats.
316 | Supported:
317 |     .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
318 |     For the .npy files, we support reading in mmap_mode.
319 |     If reading with mmap_mode is not successful, we load the data without
320 |     mmap_mode.
321 | """
322 | if verbose:
323 | logging.info(f"Loading data from file: {filename}")
324 |
325 | file_ext = os.path.splitext(filename)[1]
326 | if file_ext == ".txt":
327 | with g_pathmgr.open(filename, "r") as fopen:
328 | data = fopen.readlines()
329 | elif file_ext in [".pkl", ".pickle"]:
330 | with g_pathmgr.open(filename, "rb") as fopen:
331 | data = pickle.load(fopen, encoding="latin1")
332 | elif file_ext == ".npy":
333 | if mmap_mode:
334 | try:
335 | with g_pathmgr.open(filename, "rb") as fopen:
336 | data = np.load(
337 | fopen,
338 | allow_pickle=allow_pickle,
339 | encoding="latin1",
340 | mmap_mode=mmap_mode,
341 | )
342 | except ValueError as e:
343 | logging.info(
344 | f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345 | )
346 | data = np.load(
347 | filename,
348 | allow_pickle=allow_pickle,
349 | encoding="latin1",
350 | mmap_mode=mmap_mode,
351 | )
352 | logging.info("Successfully loaded without g_pathmgr")
353 | except Exception:
354 | logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355 | with g_pathmgr.open(filename, "rb") as fopen:
356 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357 | else:
358 | with g_pathmgr.open(filename, "rb") as fopen:
359 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360 | elif file_ext == ".json":
361 | with g_pathmgr.open(filename, "r") as fopen:
362 | data = json.load(fopen)
363 | elif file_ext == ".yaml":
364 | with g_pathmgr.open(filename, "r") as fopen:
365 | data = yaml.load(fopen, Loader=yaml.FullLoader)
366 | elif file_ext == ".csv":
367 | with g_pathmgr.open(filename, "r") as fopen:
368 | data = pd.read_csv(fopen)
369 | else:
370 | raise Exception(f"Reading from {file_ext} is not supported yet")
371 | return data
372 |
373 |
374 | def abspath(resource_path: str):
375 | """
376 | Make a path absolute, but take into account prefixes like
377 | "http://" or "manifold://"
378 | """
379 | regex = re.compile(r"^\w+://")
380 | if regex.match(resource_path) is None:
381 | return os.path.abspath(resource_path)
382 | else:
383 | return resource_path
384 |
385 |
386 | def makedir(dir_path):
387 | """
388 | Create the directory if it does not exist.
389 | """
390 | is_success = False
391 | try:
392 | if not g_pathmgr.exists(dir_path):
393 | g_pathmgr.mkdirs(dir_path)
394 | is_success = True
395 | except BaseException:
396 | logging.info(f"Error creating directory: {dir_path}")
397 | return is_success
398 |
399 |
400 | def is_url(input_url):
401 | """
402 |     Check if an input string is a URL. Looks for http(s):// while ignoring case.
403 | """
404 | is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405 | return is_url
406 |
407 |
408 | def cleanup_dir(dir):
409 | """
410 | Utility for deleting a directory. Useful for cleaning the storage space
411 | that contains various training artifacts like checkpoints, data etc.
412 | """
413 | if os.path.exists(dir):
414 | logging.info(f"Deleting directory: {dir}")
415 | shutil.rmtree(dir)
416 | logging.info(f"Deleted contents of directory: {dir}")
417 |
418 |
419 | def get_file_size(filename):
420 | """
421 |     Given a file, get its size in MB.
422 | """
423 | size_in_mb = os.path.getsize(filename) / float(1024**2)
424 | return size_in_mb
425 |
--------------------------------------------------------------------------------
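A minimal round-trip sketch for the save_file / load_file helpers above (the directory and payload are illustrative; g_pathmgr falls back to the local filesystem for plain paths, so this runs without any extra path handlers):

    import os
    import tempfile
    from minigpt4.common.utils import save_file, load_file

    tmp = tempfile.mkdtemp()
    stats = {"epoch": 1, "loss": 0.42}

    # .json with append_to_json=True (the default) writes one JSON object per line
    save_file(stats, os.path.join(tmp, "stats.json"))
    print(load_file(os.path.join(tmp, "stats.json")))  # {'epoch': 1, 'loss': 0.42}

    # .yaml and .pkl go through the same helpers, dispatched on the file extension
    save_file(stats, os.path.join(tmp, "stats.yaml"))
    print(load_file(os.path.join(tmp, "stats.yaml")))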
/minigpt4/common/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import json
10 | from typing import Dict
11 |
12 | from omegaconf import OmegaConf
13 | from minigpt4.common.registry import registry
14 |
15 |
16 | class Config:
17 | def __init__(self, args):
18 | self.config = {}
19 |
20 | self.args = args
21 |
22 |         # Register this configuration object so it can be retrieved from the registry during setup
23 | registry.register("configuration", self)
24 |
25 | user_config = self._build_opt_list(self.args.options)
26 |
27 | config = OmegaConf.load(self.args.cfg_path)
28 |
29 | runner_config = self.build_runner_config(config)
30 | model_config = self.build_model_config(config, **user_config)
31 | dataset_config = self.build_dataset_config(config)
32 |
33 | # Validate the user-provided runner configuration
34 | # model and dataset configuration are supposed to be validated by the respective classes
35 | # [TODO] validate the model/dataset configuration
36 | # self._validate_runner_config(runner_config)
37 |
38 | # Override the default configuration with user options.
39 | self.config = OmegaConf.merge(
40 | runner_config, model_config, dataset_config, user_config
41 | )
42 |
43 | def _validate_runner_config(self, runner_config):
44 | """
45 | This method validates the configuration, such that
46 | 1) all the user specified options are valid;
47 | 2) no type mismatches between the user specified options and the config.
48 | """
49 | runner_config_validator = create_runner_config_validator()
50 | runner_config_validator.validate(runner_config)
51 |
52 | def _build_opt_list(self, opts):
53 | opts_dot_list = self._convert_to_dot_list(opts)
54 | return OmegaConf.from_dotlist(opts_dot_list)
55 |
56 | @staticmethod
57 | def build_model_config(config, **kwargs):
58 | model = config.get("model", None)
59 | assert model is not None, "Missing model configuration file."
60 |
61 | model_cls = registry.get_model_class(model.arch)
62 | assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63 |
64 | model_type = kwargs.get("model.model_type", None)
65 | if not model_type:
66 | model_type = model.get("model_type", None)
67 | # else use the model type selected by user.
68 |
69 | assert model_type is not None, "Missing model_type."
70 |
71 | model_config_path = model_cls.default_config_path(model_type=model_type)
72 |
73 | model_config = OmegaConf.create()
74 |         # hierarchy override, customized config > default config
75 | model_config = OmegaConf.merge(
76 | model_config,
77 | OmegaConf.load(model_config_path),
78 | {"model": config["model"]},
79 | )
80 |
81 | return model_config
82 |
83 | @staticmethod
84 | def build_runner_config(config):
85 | return {"run": config.run}
86 |
87 | @staticmethod
88 | def build_dataset_config(config):
89 | datasets = config.get("datasets", None)
90 | if datasets is None:
91 | raise KeyError(
92 | "Expecting 'datasets' as the root key for dataset configuration."
93 | )
94 |
95 | dataset_config = OmegaConf.create()
96 |
97 | for dataset_name in datasets:
98 | builder_cls = registry.get_builder_class(dataset_name)
99 |
100 | dataset_config_type = datasets[dataset_name].get("type", "default")
101 | dataset_config_path = builder_cls.default_config_path(
102 | type=dataset_config_type
103 | )
104 |
105 |             # hierarchy override, customized config > default config
106 | dataset_config = OmegaConf.merge(
107 | dataset_config,
108 | OmegaConf.load(dataset_config_path),
109 | {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110 | )
111 |
112 | return dataset_config
113 |
114 | def _convert_to_dot_list(self, opts):
115 | if opts is None:
116 | opts = []
117 |
118 | if len(opts) == 0:
119 | return opts
120 |
121 | has_equal = opts[0].find("=") != -1
122 |
123 | if has_equal:
124 | return opts
125 |
126 | return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127 |
128 | def get_config(self):
129 | return self.config
130 |
131 | @property
132 | def run_cfg(self):
133 | return self.config.run
134 |
135 | @property
136 | def datasets_cfg(self):
137 | return self.config.datasets
138 |
139 | @property
140 | def model_cfg(self):
141 | return self.config.model
142 |
143 | def pretty_print(self):
144 | logging.info("\n===== Running Parameters =====")
145 | logging.info(self._convert_node_to_json(self.config.run))
146 |
147 | logging.info("\n====== Dataset Attributes ======")
148 | datasets = self.config.datasets
149 |
150 | for dataset in datasets:
151 | if dataset in self.config.datasets:
152 | logging.info(f"\n======== {dataset} =======")
153 | dataset_config = self.config.datasets[dataset]
154 | logging.info(self._convert_node_to_json(dataset_config))
155 | else:
156 | logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157 |
158 | logging.info(f"\n====== Model Attributes ======")
159 | logging.info(self._convert_node_to_json(self.config.model))
160 |
161 | def _convert_node_to_json(self, node):
162 | container = OmegaConf.to_container(node, resolve=True)
163 | return json.dumps(container, indent=4, sort_keys=True)
164 |
165 | def to_dict(self):
166 | return OmegaConf.to_container(self.config)
167 |
168 |
169 | def node_to_dict(node):
170 | return OmegaConf.to_container(node)
171 |
172 |
173 | class ConfigValidator:
174 | """
175 | This is a preliminary implementation to centralize and validate the configuration.
176 | May be altered in the future.
177 |
178 |     A helper class to validate configurations from a yaml file.
179 |
180 |     This serves the following purposes:
181 |         1. Ensure all the options in the yaml are defined; raise an error if not.
182 |         2. When type mismatches are found, the validator will raise an error.
183 |         3. Provide a central place to store and display helpful messages for supported configurations.
184 |
185 | """
186 |
187 | class _Argument:
188 | def __init__(self, name, choices=None, type=None, help=None):
189 | self.name = name
190 | self.val = None
191 | self.choices = choices
192 | self.type = type
193 | self.help = help
194 |
195 | def __str__(self):
196 | s = f"{self.name}={self.val}"
197 | if self.type is not None:
198 | s += f", ({self.type})"
199 | if self.choices is not None:
200 | s += f", choices: {self.choices}"
201 | if self.help is not None:
202 | s += f", ({self.help})"
203 | return s
204 |
205 | def __init__(self, description):
206 | self.description = description
207 |
208 | self.arguments = dict()
209 |
210 | self.parsed_args = None
211 |
212 | def __getitem__(self, key):
213 | assert self.parsed_args is not None, "No arguments parsed yet."
214 |
215 | return self.parsed_args[key]
216 |
217 | def __str__(self) -> str:
218 | return self.format_help()
219 |
220 | def add_argument(self, *args, **kwargs):
221 | """
222 | Assume the first argument is the name of the argument.
223 | """
224 | self.arguments[args[0]] = self._Argument(*args, **kwargs)
225 |
226 | def validate(self, config=None):
227 | """
228 |         Validate a yaml config (dict-like) against the registered arguments.
229 | """
230 | for k, v in config.items():
231 | assert (
232 | k in self.arguments
233 |             ), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""
234 |
235 | if self.arguments[k].type is not None:
236 | try:
237 | self.arguments[k].val = self.arguments[k].type(v)
238 | except ValueError:
239 | raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240 |
241 | if self.arguments[k].choices is not None:
242 | assert (
243 | v in self.arguments[k].choices
244 | ), f"""{k} must be one of {self.arguments[k].choices}."""
245 |
246 | return config
247 |
248 | def format_arguments(self):
249 | return str([f"{k}" for k in sorted(self.arguments.keys())])
250 |
251 | def format_help(self):
252 | # description + key-value pair string for each argument
253 | help_msg = str(self.description)
254 | return help_msg + ", available arguments: " + self.format_arguments()
255 |
256 | def print_help(self):
257 | # display help message
258 | print(self.format_help())
259 |
260 |
261 | def create_runner_config_validator():
262 | validator = ConfigValidator(description="Runner configurations")
263 |
264 | validator.add_argument(
265 | "runner",
266 | type=str,
267 | choices=["runner_base", "runner_iter"],
268 |         help="""Runner to use. "runner_base" uses epoch-based training, while "runner_iter"
269 |         runs based on iterations. Default: runner_base""",
270 | )
271 |     # add arguments for training dataset ratios
272 | validator.add_argument(
273 | "train_dataset_ratios",
274 | type=Dict[str, float],
275 |         help="""Ratios of the training datasets. This is used by the iteration-based runner.
276 |     Not supported for the epoch-based runner because defining an epoch becomes tricky.
277 | Default: None""",
278 | )
279 | validator.add_argument(
280 | "max_iters",
281 | type=float,
282 | help="Maximum number of iterations to run.",
283 | )
284 | validator.add_argument(
285 | "max_epoch",
286 | type=int,
287 | help="Maximum number of epochs to run.",
288 | )
289 | # add arguments for iters_per_inner_epoch
290 | validator.add_argument(
291 | "iters_per_inner_epoch",
292 | type=float,
293 | help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294 | )
295 | lr_scheds_choices = registry.list_lr_schedulers()
296 | validator.add_argument(
297 | "lr_sched",
298 | type=str,
299 | choices=lr_scheds_choices,
300 | help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301 | )
302 | task_choices = registry.list_tasks()
303 | validator.add_argument(
304 | "task",
305 | type=str,
306 | choices=task_choices,
307 | help="Task to use, from {}".format(task_choices),
308 | )
309 | # add arguments for init_lr
310 | validator.add_argument(
311 | "init_lr",
312 | type=float,
313 | help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314 | )
315 | # add arguments for min_lr
316 | validator.add_argument(
317 | "min_lr",
318 | type=float,
319 | help="Minimum learning rate (after decay).",
320 | )
321 | # add arguments for warmup_lr
322 | validator.add_argument(
323 | "warmup_lr",
324 | type=float,
325 | help="Starting learning rate for warmup.",
326 | )
327 | # add arguments for learning rate decay rate
328 | validator.add_argument(
329 | "lr_decay_rate",
330 | type=float,
331 | help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332 | )
333 | # add arguments for weight decay
334 | validator.add_argument(
335 | "weight_decay",
336 | type=float,
337 | help="Weight decay rate.",
338 | )
339 | # add arguments for training batch size
340 | validator.add_argument(
341 | "batch_size_train",
342 | type=int,
343 | help="Training batch size.",
344 | )
345 | # add arguments for evaluation batch size
346 | validator.add_argument(
347 | "batch_size_eval",
348 | type=int,
349 | help="Evaluation batch size, including validation and testing.",
350 | )
351 | # add arguments for number of workers for data loading
352 | validator.add_argument(
353 | "num_workers",
354 | help="Number of workers for data loading.",
355 | )
356 | # add arguments for warm up steps
357 | validator.add_argument(
358 | "warmup_steps",
359 | type=int,
360 | help="Number of warmup steps. Required if a warmup schedule is used.",
361 | )
362 | # add arguments for random seed
363 | validator.add_argument(
364 | "seed",
365 | type=int,
366 | help="Random seed.",
367 | )
368 | # add arguments for output directory
369 | validator.add_argument(
370 | "output_dir",
371 | type=str,
372 | help="Output directory to save checkpoints and logs.",
373 | )
374 |     # add argument for whether to only run evaluation
375 | validator.add_argument(
376 | "evaluate",
377 | help="Whether to only evaluate the model. If true, training will not be performed.",
378 | )
379 | # add arguments for splits used for training, e.g. ["train", "val"]
380 | validator.add_argument(
381 | "train_splits",
382 | type=list,
383 | help="Splits to use for training.",
384 | )
385 | # add arguments for splits used for validation, e.g. ["val"]
386 | validator.add_argument(
387 | "valid_splits",
388 | type=list,
389 | help="Splits to use for validation. If not provided, will skip the validation.",
390 | )
391 | # add arguments for splits used for testing, e.g. ["test"]
392 | validator.add_argument(
393 | "test_splits",
394 | type=list,
395 | help="Splits to use for testing. If not provided, will skip the testing.",
396 | )
397 | # add arguments for accumulating gradient for iterations
398 | validator.add_argument(
399 | "accum_grad_iters",
400 | type=int,
401 | help="Number of iterations to accumulate gradient for.",
402 | )
403 |
404 | # ====== distributed training ======
405 | validator.add_argument(
406 | "device",
407 | type=str,
408 | choices=["cpu", "cuda"],
409 |         help="Device to use. Supports 'cuda' or 'cpu' for now.",
410 | )
411 | validator.add_argument(
412 | "world_size",
413 | type=int,
414 | help="Number of processes participating in the job.",
415 | )
416 | validator.add_argument("dist_url", type=str)
417 | validator.add_argument("distributed", type=bool)
418 |     # add argument for whether to use a distributed sampler during evaluation
419 | validator.add_argument(
420 | "use_dist_eval_sampler",
421 | type=bool,
422 | help="Whether to use distributed sampler during evaluation or not.",
423 | )
424 |
425 | # ====== task specific ======
426 | # generation task specific arguments
427 | # add arguments for maximal length of text output
428 | validator.add_argument(
429 | "max_len",
430 | type=int,
431 | help="Maximal length of text output.",
432 | )
433 | # add arguments for minimal length of text output
434 | validator.add_argument(
435 | "min_len",
436 | type=int,
437 | help="Minimal length of text output.",
438 | )
439 |     # add argument for the number of beams
440 | validator.add_argument(
441 | "num_beams",
442 | type=int,
443 | help="Number of beams used for beam search.",
444 | )
445 |
446 | # vqa task specific arguments
447 | # add arguments for number of answer candidates
448 | validator.add_argument(
449 | "num_ans_candidates",
450 | type=int,
451 | help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452 | )
453 | # add arguments for inference method
454 | validator.add_argument(
455 | "inference_method",
456 | type=str,
457 |         choices=["generate", "rank"],
458 |         help="""Inference method to use for question answering. If rank, requires an answer list.""",
459 | )
460 |
461 | # ====== model specific ======
462 | validator.add_argument(
463 | "k_test",
464 | type=int,
465 | help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466 | )
467 |
468 | return validator
469 |
--------------------------------------------------------------------------------
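A minimal sketch of how the ConfigValidator above is meant to be used; the argument names and values here are illustrative, and note that the call to _validate_runner_config in Config.__init__ is currently commented out, so the validator produced by create_runner_config_validator is not actually exercised at runtime:

    from minigpt4.common.config import ConfigValidator

    validator = ConfigValidator(description="toy runner config")
    validator.add_argument("max_epoch", type=int, help="Maximum number of epochs to run.")
    validator.add_argument("device", type=str, choices=["cpu", "cuda"], help="Device to use.")

    cfg = {"max_epoch": 3, "device": "cuda"}
    validator.validate(cfg)         # passes: known keys, castable types, allowed choices
    print(validator.format_help())  # "toy runner config, available arguments: ['device', 'max_epoch']"

    # Unknown keys fail the membership assertion:
    # validator.validate({"foo": 1})  -> AssertionError: foo is not a valid argument ...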