├── LICENSE
├── README.md
├── environment.yaml
├── eval_configs
│   └── eval.yaml
├── generate_reports.py
├── minigpt4
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-311.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   └── __init__.cpython-39.pyc
│   ├── common
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── config.cpython-39.pyc
│   │   │   ├── dist_utils.cpython-311.pyc
│   │   │   ├── dist_utils.cpython-39.pyc
│   │   │   ├── logger.cpython-311.pyc
│   │   │   ├── logger.cpython-39.pyc
│   │   │   ├── optims.cpython-39.pyc
│   │   │   ├── registry.cpython-311.pyc
│   │   │   ├── registry.cpython-38.pyc
│   │   │   ├── registry.cpython-39.pyc
│   │   │   ├── utils.cpython-311.pyc
│   │   │   ├── utils.cpython-38.pyc
│   │   │   └── utils.cpython-39.pyc
│   │   ├── config.py
│   │   ├── dist_utils.py
│   │   ├── gradcam.py
│   │   ├── logger.py
│   │   ├── optims.py
│   │   ├── registry.py
│   │   └── utils.py
│   ├── configs
│   │   ├── datasets
│   │   │   ├── cc_sbu
│   │   │   │   ├── align.yaml
│   │   │   │   └── defaults.yaml
│   │   │   ├── iuxray
│   │   │   │   ├── align.yaml
│   │   │   │   └── generate_then_refine.yaml
│   │   │   ├── laion
│   │   │   │   └── defaults.yaml
│   │   │   └── mimic
│   │   │       ├── align.yaml
│   │   │       └── generate_then_refine.yaml
│   │   ├── default.yaml
│   │   └── models
│   │       ├── minigpt4-7b.yaml
│   │       └── minigpt4.yaml
│   ├── conversation
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   └── conversation.cpython-39.pyc
│   │   └── conversation.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   └── data_utils.cpython-39.pyc
│   │   ├── builders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-311.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── base_dataset_builder.cpython-311.pyc
│   │   │   │   ├── base_dataset_builder.cpython-38.pyc
│   │   │   │   ├── base_dataset_builder.cpython-39.pyc
│   │   │   │   ├── image_text_pair_builder.cpython-311.pyc
│   │   │   │   └── image_text_pair_builder.cpython-39.pyc
│   │   │   ├── base_dataset_builder.py
│   │   │   └── image_text_pair_builder.py
│   │   ├── data_utils.py
│   │   └── datasets
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-311.pyc
│   │       │   ├── __init__.cpython-39.pyc
│   │       │   ├── base_dataset.cpython-311.pyc
│   │       │   ├── base_dataset.cpython-39.pyc
│   │       │   ├── caption_datasets.cpython-311.pyc
│   │       │   ├── caption_datasets.cpython-39.pyc
│   │       │   ├── cc_sbu_dataset.cpython-311.pyc
│   │       │   ├── cc_sbu_dataset.cpython-39.pyc
│   │       │   ├── dataloader_utils.cpython-39.pyc
│   │       │   ├── iuxray_dataset.cpython-311.pyc
│   │       │   ├── iuxray_dataset.cpython-39.pyc
│   │       │   ├── laion_dataset.cpython-311.pyc
│   │       │   ├── laion_dataset.cpython-39.pyc
│   │       │   ├── mimic_dataset.cpython-311.pyc
│   │       │   └── mimic_dataset.cpython-39.pyc
│   │       ├── base_dataset.py
│   │       ├── caption_datasets.py
│   │       ├── cc_sbu_dataset.py
│   │       ├── dataloader_utils.py
│   │       ├── iuxray_dataset.py
│   │       ├── laion_dataset.py
│   │       └── mimic_dataset.py
│   ├── models
│   │   ├── Qformer.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── Qformer.cpython-311.pyc
│   │   │   ├── Qformer.cpython-39.pyc
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── base_model.cpython-311.pyc
│   │   │   ├── base_model.cpython-39.pyc
│   │   │   ├── blip2.cpython-311.pyc
│   │   │   ├── blip2.cpython-39.pyc
│   │   │   ├── eva_vit.cpython-311.pyc
│   │   │   ├── eva_vit.cpython-39.pyc
│   │   │   ├── mini_gpt4.cpython-311.pyc
│   │   │   ├── mini_gpt4.cpython-39.pyc
│   │   │   ├── modeling_llama.cpython-311.pyc
│   │   │   └── modeling_llama.cpython-39.pyc
│   │   ├── base_model.py
│   │   ├── blip2.py
│   │   ├── blip2_outputs.py
│   │   ├── eva_vit.py
│   │   ├── mini_gpt4.py
│   │   └── modeling_llama.py
│   ├── output
│   │   └── minigpt4_stage2_finetune
│   │       ├── 20230706044
│   │       │   ├── checkpoint_0.pth
│   │       │   ├── checkpoint_1.pth
│   │       │   └── log.txt
│   │       └── 20230706051
│   │           └── log.txt
│   ├── processors
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── base_processor.cpython-311.pyc
│   │   │   ├── base_processor.cpython-39.pyc
│   │   │   ├── blip_processors.cpython-311.pyc
│   │   │   ├── blip_processors.cpython-39.pyc
│   │   │   ├── randaugment.cpython-311.pyc
│   │   │   └── randaugment.cpython-39.pyc
│   │   ├── base_processor.py
│   │   ├── blip_processors.py
│   │   └── randaugment.py
│   ├── runners
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   └── runner_base.cpython-39.pyc
│   │   └── runner_base.py
│   └── tasks
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-39.pyc
│       │   ├── base_task.cpython-39.pyc
│       │   ├── image_text_pretrain.cpython-39.pyc
│       │   └── mimic_generate_then_refine.cpython-39.pyc
│       ├── base_task.py
│       ├── image_text_pretrain.py
│       └── mimic_generate_then_refine.py
├── prompts
│   ├── stage1-pretraining-prompts.txt
│   ├── stage2-generation-prompts.txt
│   └── stage2-refinement-prompts.txt
├── train.py
└── train_configs
    ├── stage1
    │   ├── config.yaml
    │   └── zero.json
    └── stage2
        ├── iuxray
        │   ├── config.yaml
        │   └── zero.json
        └── mimic
            ├── config.yaml
            └── zero.json
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yan Song's NLP Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Bootstrapping Large Language Models for Radiology Report Generation
3 |
4 | The official GitHub repository of the AAAI-2024 paper ["Bootstrapping Large Language Models for Radiology Report Generation"](https://ojs.aaai.org/index.php/AAAI/article/view/29826).
5 |
6 | # Reference
7 | If our work is helpful to your research, please cite our paper:
8 | ``` latex
9 | @inproceedings{chang2024bootstrapping,
10 | author = {Chang Liu and
11 | Yuanhe Tian and
12 | Weidong Chen and
13 | Yan Song and
14 | Yongdong Zhang},
15 | editor = {Michael J. Wooldridge and
16 | Jennifer G. Dy and
17 | Sriraam Natarajan},
18 | title = {Bootstrapping Large Language Models for Radiology Report Generation},
19 | booktitle = {AAAI},
20 | pages = {18635--18643},
21 | year = {2024},
22 | }
23 | ```
24 |
25 | # Getting Started
26 | 1. Before running the code, create and activate a conda environment with the following commands:
27 | ```bash
28 | conda env create -f environment.yaml
29 | conda activate venv
30 | ```
31 |
32 | 2. Once the virtual environment is created, download the LLM model weights following the instructions in [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4). Once the model weights are downloaded, modify the following configuration files (an example is shown after the list):
33 | - `minigpt4/configs/models/minigpt4-7b.yaml`: line 16 with the path to the Vicuna-7B model weights.
34 | - `minigpt4/configs/models/minigpt4.yaml`: line 16 with the path to the Vicuna-13B model weights.
35 |
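For reference, line 16 of `minigpt4/configs/models/minigpt4-7b.yaml` (inside the `model:` block) should point to your local copy of the Vicuna weights; the path below is only a placeholder:

```yaml
# Vicuna
llama_model: "/path/to/vicuna-7b"  # line 16: replace with the local Vicuna-7B directory
```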
36 | 3. Download the datasets from the official websites of [IU X-Ray](https://openi.nlm.nih.gov/faq#collection) and [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/). Once the datasets are ready, modify the following configuration files (an example is shown after the list):
37 | - `minigpt4/configs/datasets/iuxray/align.yaml`: line 5 with the path of the pre-training dataset.
38 | - `minigpt4/configs/datasets/iuxray/generate_then_refine.yaml`: line 5 with the path of the IU X-Ray dataset, line 6 with the path of the public medical corpora.
39 | - `minigpt4/configs/datasets/mimic/align.yaml`: line 5 with the path of the pre-training dataset.
40 | - `minigpt4/configs/datasets/mimic/generate_then_refine.yaml`: line 5 with the path of the MIMIC-CXR dataset, line 6 with the path of the public medical corpora.
41 |
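For example, `minigpt4/configs/datasets/iuxray/generate_then_refine.yaml` looks like this after editing; both paths are placeholders to be replaced with your local copies:

```yaml
datasets:
  mimic_generate_then_refine:
    data_type: images
    build_info:
      storage: /path/to/iuxray                    # line 5: IU X-Ray dataset
      unlabeled_annotation_path: /path/to/pubmed  # line 6: public medical corpora
```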
42 | # Training
43 | 1. **Pre-training.** We recommend following the instructions below to pre-train MiniGPT-4 on MIMIC-CXR.
44 |
45 | (1) Modify the configuration files.
46 | - `train_configs/stage1/config.yaml`: line 12 with the path of the linear projection layer of MiniGPT-4, line 59 with the output path.
47 |
48 | (2) Run the following command to pre-train MiniGPT-4 on MIMIC-CXR.
49 | ```
50 | python train.py --cfg-path train_configs/stage1/config.yaml
51 | ```
52 |
53 | If you need to reduce memory usage, we recommend the stage-1 strategy of the `ZeRO` optimizer. Run the following command to pre-train MiniGPT-4 on MIMIC-CXR with lower memory usage (a sketch of a stage-1 `ZeRO` configuration is shown after the command).
54 |
55 | ```
56 | deepspeed --num_gpus NUM_GPUS --master_port MASTER_PORT train.py --cfg-path train_configs/stage1/config.yaml --use_zero_optimizer --deepspeed_config train_configs/stage1/zero.json
57 | ```
58 |
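For reference, a stage-1 `ZeRO` configuration typically looks like the minimal sketch below; this is only an illustration, the actual `train_configs/stage1/zero.json` shipped with this repository may differ, and the batch-size values are placeholders:

```json
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "fp16": { "enabled": true },
  "zero_optimization": { "stage": 1 }
}
```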
59 | You can download our pre-trained model weights from [here](https://huggingface.co/a-b-c-d-e-g/R2-LLM).
60 |
61 | 2. **Fine-tuning.** We recommend following the instructions below to fine-tune MiniGPT-4 on IU X-Ray and MIMIC-CXR.
62 |
63 | (1) Modify the configuration files. Here, we take the IU X-Ray configuration as an example.
64 | - `train_configs/stage2/iuxray/config.yaml`: line 11 with the path of the linear projection layer of MiniGPT-4 pre-trained on MIMIC-CXR, line 56 with the output path.
65 |
66 | (2) Run the following command to fine-tune MiniGPT-4.
67 |
68 | ```
69 | python train.py --cfg-path train_configs/stage2/iuxray/config.yaml
70 | ```
71 |
72 | Our codebase also supports `ZeRO` to reduce memory usage. You can run the following command with `ZeRO`.
73 |
74 | ```
75 | deepspeed --num_gpus NUM_GPUS --master_port MASTER_PORT train.py --cfg-path train_configs/stage2/iuxray/config.yaml --use_zero_optimizer --deepspeed_config train_configs/stage2/iuxray/zero.json
76 | ```
77 |
78 | You can download our fine-tuned model weights from [here](https://huggingface.co/a-b-c-d-e-g/R2-LLM).
79 |
80 | # Inference
81 | Run the following command to generate radiology reports.
82 |
83 | ```
84 | python generate_reports.py \
85 |     --cfg-path eval_configs/eval.yaml \
86 |     --gpu-id GPU_ID \
87 |     --image_path IMAGE_PATH \
88 |     --annotations ANNOTATIONS_PATH_OF_IUXRAY_OR_MIMIC \
89 |     --checkpoint PATH_TO_PRETRAINED_MODEL_WEIGHTS
90 | ```
91 |
92 | # Acknowledgement
93 | This repository is largely built upon the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) repository. Thanks to the authors for their great work!
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: venv
2 | channels:
3 | - pytorch
4 | - defaults
5 | - anaconda
6 | dependencies:
7 | - python=3.9
8 | - cudatoolkit
9 | - pip
10 | - pip:
11 | - torch==2.0.0
12 | - torchaudio
13 | - torchvision
14 | - huggingface-hub==0.18.0
15 | - matplotlib==3.7.0
16 | - psutil==5.9.4
17 | - iopath
18 | - pyyaml==6.0
19 | - regex==2022.10.31
20 | - tokenizers==0.13.2
21 | - tqdm==4.64.1
22 | - transformers==4.30.0
23 | - timm==0.6.13
24 | - webdataset==0.2.48
25 | - omegaconf==2.3.0
26 | - opencv-python==4.7.0.72
27 | - decord==0.6.0
28 | - peft==0.2.0
29 | - sentence-transformers
30 | - gradio==3.47.1
31 | - accelerate==0.20.3
32 | - bitsandbytes==0.37.0
33 | - scikit-image
34 | - visual-genome
35 | - wandb
--------------------------------------------------------------------------------
/eval_configs/eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 | model_type: pretrain_vicuna
4 | freeze_vit: True
5 | freeze_qformer: True
6 | max_txt_len: 100
7 | end_sym: "###"
8 | low_resource: True
9 | prompt_path: "/path/to/prompts"
10 | prompt_template: '###Human: {} ###Assistant: '
11 | ckpt: '/path/to/linear'
12 |
13 | # lora configuration
14 | use_lora: True
15 | lora_rank: 8
16 | lora_alpha: 32
17 | lora_dropout: 0.1
18 |
19 | datasets:
20 | mimic_generate_then_refine:
21 | vis_processor:
22 | train:
23 | name: "blip2_image_eval"
24 | image_size: 224
25 | text_processor:
26 | train:
27 | name: "blip_caption"
28 |
29 | run:
30 | task: image_text_pretrain
31 |
--------------------------------------------------------------------------------
/generate_reports.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | import json
5 | import random
6 | from tqdm import tqdm
7 | from PIL import Image
8 |
9 | import numpy as np
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from transformers import StoppingCriteria, StoppingCriteriaList
13 |
14 | from minigpt4.common.config import Config
15 | from minigpt4.common.dist_utils import get_rank
16 | from minigpt4.common.registry import registry
17 | from minigpt4.conversation.conversation import Chat, CONV_VISION
18 |
19 | # imports modules for registration
20 | from minigpt4.datasets.builders import *
21 | from minigpt4.models import *
22 | from minigpt4.processors import *
23 | from minigpt4.runners import *
24 | from minigpt4.tasks import *
25 |
26 | from peft import LoraConfig, TaskType, get_peft_model, set_peft_model_state_dict
27 |
28 | def clean_reports(report):
29 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
30 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \
31 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \
32 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
33 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
34 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
35 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
36 | .strip().lower().split('. ')
37 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
38 | .replace('\\', '').replace("'", '').strip().lower())
39 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
40 | report = ' . '.join(tokens) + ' .'
41 | return report
42 |
43 | class StoppingCriteriaSub(StoppingCriteria):
44 |
45 | def __init__(self, stops=[], encounters=1):
46 | super().__init__()
47 | self.stops = stops
48 |
49 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
50 | for stop in self.stops:
51 | if torch.all((stop == input_ids[0][-len(stop):])).item():
52 | return True
53 |
54 | return False
55 |
56 |
57 | def parse_args():
58 | parser = argparse.ArgumentParser(description="Demo")
59 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
60 | parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
61 | parser.add_argument(
62 | "--options",
63 | nargs="+",
64 | help="override some settings in the used config, the key-value pair "
65 | "in xxx=yyy format will be merged into config file (deprecate), "
66 | "change to --cfg-options instead.",
67 | )
68 |
69 | parser.add_argument('--image_path', default='', type=str, help='path of the input image')
70 | parser.add_argument('--generation_prompts', type=str, default='prompts/stage2-generation-prompts.txt', help='path of the generation prompts for the first stage')
71 | parser.add_argument('--refinement_prompts', type=str, default='prompts/stage2-refinement-prompts.txt', help='path of the refinement prompts for the second stage')
72 | parser.add_argument('--annotations', type=str, default='', help='path of annotation file, to load in the GTs')
73 | parser.add_argument('--checkpoint', required=True, help='checkpoint path')
74 | parser.add_argument('--beam_size', type=int, default=1)
75 | parser.add_argument('--temperature', type=float, default=1.0)
76 | parser.add_argument('--max_txt_len', default=160, type=int)
77 |
78 | args = parser.parse_args()
79 | return args
80 |
81 |
82 | def setup_seeds(config):
83 | seed = config.run_cfg.seed + get_rank()
84 |
85 | random.seed(seed)
86 | np.random.seed(seed)
87 | torch.manual_seed(seed)
88 |
89 | cudnn.benchmark = False
90 | cudnn.deterministic = True
91 |
92 |
93 | # ========================================
94 | # Model Initialization
95 | # ========================================
96 |
97 | print('Initializing Chat')
98 | args = parse_args()
99 | cfg = Config(args)
100 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
101 |
102 | model_config = cfg.model_cfg
103 | model_config.device_8bit = args.gpu_id
104 | model_cls = registry.get_model_class(model_config.arch)
105 | model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
106 |
107 | # load LoRA
108 | peft_config = LoraConfig(inference_mode=False, r=cfg.model_cfg.lora_rank, lora_alpha=cfg.model_cfg.lora_alpha, lora_dropout=cfg.model_cfg.lora_dropout)
109 | peft_model = get_peft_model(model.llama_model, peft_config=peft_config)
110 | # loading a regular pytorch checkpoint
111 | if args.checkpoint.endswith('.pth'):
112 | full_state_dict = torch.load(args.checkpoint, map_location='cpu')
113 | # loading ZeRO checkpoint
114 | elif args.checkpoint.endswith('.pt'):
115 | full_state_dict = torch.load(args.checkpoint, map_location='cpu')['module']
116 | set_peft_model_state_dict(peft_model, full_state_dict)
117 | peft_model = peft_model.to(device)
118 | print('LLaMA checkpoint loaded.')
119 | # load in the linear projection layer
120 | llama_proj_state_dict = {}
121 | for key, value in full_state_dict.items():
122 | if 'llama_proj' in key:
123 | llama_proj_state_dict[key[18:]] = value
124 | model.llama_proj.load_state_dict(llama_proj_state_dict)
125 | print('Linear projection layer loaded.')
126 |
127 | vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
128 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
129 | print('Initialization Finished')
130 |
131 | # ========================================
132 | # Start Testing
133 | # ========================================
134 |
135 |
136 | # image_paths = []
137 | # for root, dirs, files in os.walk(args.images):
138 | # for file in files:
139 | # image_paths.append(os.path.join(root, file))
140 |
141 | # load generation prompts from local path
142 | generation_prompts = []
143 | with open(args.generation_prompts, 'r') as f:
144 | for line in f.readlines():
145 | generation_prompts.append(line.strip('\n'))
146 |
147 | # load refinement prompts from local path
148 | refinement_prompts = []
149 | with open(args.refinement_prompts, 'r') as f:
150 | for line in f.readlines():
151 | refinement_prompts.append(line.strip('\n'))
152 |
153 | final_record_message = ''
154 | with torch.no_grad():
155 | # Stage 1: generate a coarse report
156 | # random sample one prompt
157 | prompt = random.choice(generation_prompts)
158 | prompt = '###Human: ' + prompt + '###Assistant: '
159 |
160 | # encode image
161 | img_list = []
162 | raw_image = Image.open(args.image_path).convert('RGB')
163 | image = vis_processor(raw_image).unsqueeze(0).to(device)
164 | image_emb, _ = model.encode_img(image)
165 | img_list.append(image_emb)
166 |
167 | # wrap image with prompt
168 | prompt_segs = prompt.split('<ImageHere>')  # split the prompt at the image placeholder token (MiniGPT-4 convention)
169 | seg_tokens = [
170 | model.llama_tokenizer(
171 | seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
172 | # only add bos to the first seg
173 | for i, seg in enumerate(prompt_segs)
174 | ]
175 | seg_embs = [peft_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
176 | mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
177 | mixed_embs = torch.cat(mixed_embs, dim=1)
178 |
179 | # prepare other things before generate
180 | stop_words_ids = [torch.tensor([835]).to(device), torch.tensor([2277, 29937]).to(device)] # '###' can be encoded in two different ways.
181 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
182 |
183 | # generate
184 | outputs = peft_model.base_model.model.generate(
185 | inputs_embeds=mixed_embs,
186 | max_new_tokens=args.max_txt_len,
187 | stopping_criteria=stopping_criteria,
188 | num_beams=args.beam_size,
189 | do_sample=True,
190 | min_length=1,
191 | top_p=0.9,
192 | repetition_penalty=1.0,
193 | length_penalty=1,
194 | temperature=args.temperature,)
195 |
196 | output_token = outputs[0]
197 | if output_token[0] == 0:  # the model might output an unknown token at the beginning; remove it
198 | output_token = output_token[1:]
199 | if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
200 | output_token = output_token[1:]
201 | output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
202 | output_text = output_text.split('###')[0] # remove the stop sign '###'
203 | output_text = output_text.split('Assistant:')[-1].strip()
204 | generated_text = output_text
205 |
206 | # Stage 2: refine the coarse report
207 | coarse_generated_report = output_token
208 | coarse_report_embeds = peft_model.base_model.model.model.embed_tokens(coarse_generated_report).expand(image_emb.shape[0], -1, -1)
209 | atts_report = torch.ones(coarse_report_embeds.size()[:-1], dtype=torch.long).to(device)
210 | prompt = random.choice(refinement_prompts)
211 | prompt = '###Human: ' + prompt + '###Assistant: '
212 |
213 | # encode image
214 | img_list = []
215 | raw_image = Image.open(args.image_path).convert('RGB')
216 | image = vis_processor(raw_image).unsqueeze(0).to(device)
217 | image_emb, _ = model.encode_img(image)
218 | img_list.append(image_emb)
219 |
220 | # wrap the image and the coarse report with the refinement prompt
221 | p_before, p_after_all = prompt.split('<ImageHere>')  # split at the image placeholder token
222 | p_mid, p_after = p_after_all.split('<ReportHere>')  # split at the placeholder for the coarse report (assumed token name)
223 | p_before_tokens = model.llama_tokenizer(p_before, return_tensors="pt", add_special_tokens=True).to(device).input_ids
224 | p_mid_tokens = model.llama_tokenizer(p_mid, return_tensors="pt", add_special_tokens=False).to(device).input_ids
225 | p_after_tokens = model.llama_tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(device).input_ids
226 |
227 | # embedding
228 | p_before_embeds = peft_model.base_model.model.model.embed_tokens(p_before_tokens)
229 | p_mid_embeds = peft_model.base_model.model.model.embed_tokens(p_mid_tokens)
230 | p_after_embeds = peft_model.base_model.model.model.embed_tokens(p_after_tokens)
231 | mixed_embs = torch.cat([p_before_embeds, img_list[0], p_mid_embeds, coarse_report_embeds, p_after_embeds], dim=1)
232 | mixed_embs = torch.cat([p_mid_embeds, coarse_report_embeds, p_after_embeds], dim=1)
233 |
234 | # prepare other things before generate
235 | stop_words_ids = [torch.tensor([835]).to(device), torch.tensor([2277, 29937]).to(device)] # '###' can be encoded in two different ways.
236 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
237 |
238 | # generate
239 | outputs = peft_model.base_model.model.generate(
240 | inputs_embeds=mixed_embs,
241 | max_new_tokens=args.max_txt_len,
242 | stopping_criteria=stopping_criteria,
243 | num_beams=args.beam_size,
244 | do_sample=True,
245 | min_length=1,
246 | top_p=0.9,
247 | repetition_penalty=1.0,
248 | length_penalty=1,
249 | temperature=args.temperature,)
250 |
251 | output_token = outputs[0]
252 | if output_token[0] == 0:  # the model might output an unknown token at the beginning; remove it
253 | output_token = output_token[1:]
254 | if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
255 | output_token = output_token[1:]
256 | output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
257 | output_text = output_text.split('###')[0] # remove the stop sign '###'
258 | output_text = output_text.split('Assistant:')[-1].strip()
259 | refined_text = output_text
260 |
261 | print('Generated report:')
262 | print(refined_text)
263 |
--------------------------------------------------------------------------------
/minigpt4/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | import sys
10 |
11 | from omegaconf import OmegaConf
12 |
13 | from minigpt4.common.registry import registry
14 |
15 | from minigpt4.datasets.builders import *
16 | from minigpt4.models import *
17 | from minigpt4.processors import *
18 | from minigpt4.tasks import *
19 |
20 |
21 | root_dir = os.path.dirname(os.path.abspath(__file__))
22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23 |
24 | registry.register_path("library_root", root_dir)
25 | repo_root = os.path.join(root_dir, "..")
26 | registry.register_path("repo_root", repo_root)
27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28 | registry.register_path("cache_root", cache_root)
29 |
30 | registry.register("MAX_INT", sys.maxsize)
31 | registry.register("SPLIT_NAMES", ["train", "val", "test"])
32 |
--------------------------------------------------------------------------------
/minigpt4/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/minigpt4/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/minigpt4/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__init__.py
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/config.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/dist_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/dist_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/dist_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/dist_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/logger.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/logger.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/logger.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/logger.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/optims.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/optims.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/registry.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/registry.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/registry.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/registry.cpython-38.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/registry.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/registry.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/minigpt4/common/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/common/dist_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import functools
10 | import os
11 |
12 | import torch
13 | import torch.distributed as dist
14 | import timm.models.hub as timm_hub
15 |
16 |
17 | def setup_for_distributed(is_master):
18 | """
19 | This function disables printing when not in master process
20 | """
21 | import builtins as __builtin__
22 |
23 | builtin_print = __builtin__.print
24 |
25 | def print(*args, **kwargs):
26 | force = kwargs.pop("force", False)
27 | if is_master or force:
28 | builtin_print(*args, **kwargs)
29 |
30 | __builtin__.print = print
31 |
32 |
33 | def is_dist_avail_and_initialized():
34 | if not dist.is_available():
35 | return False
36 | if not dist.is_initialized():
37 | return False
38 | return True
39 |
40 |
41 | def get_world_size():
42 | if not is_dist_avail_and_initialized():
43 | return 1
44 | return dist.get_world_size()
45 |
46 |
47 | def get_rank():
48 | if not is_dist_avail_and_initialized():
49 | return 0
50 | return dist.get_rank()
51 |
52 |
53 | def is_main_process():
54 | return get_rank() == 0
55 |
56 |
57 | def init_distributed_mode(args):
58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59 | args.rank = int(os.environ["RANK"])
60 | args.world_size = int(os.environ["WORLD_SIZE"])
61 | args.gpu = int(os.environ["LOCAL_RANK"])
62 | elif "SLURM_PROCID" in os.environ:
63 | args.rank = int(os.environ["SLURM_PROCID"])
64 | args.gpu = args.rank % torch.cuda.device_count()
65 | else:
66 | print("Not using distributed mode")
67 | args.distributed = False
68 | return
69 |
70 | args.distributed = True
71 |
72 | torch.cuda.set_device(args.gpu)
73 | args.dist_backend = "nccl"
74 | print(
75 | "| distributed init (rank {}, world {}): {}".format(
76 | args.rank, args.world_size, args.dist_url
77 | ),
78 | flush=True,
79 | )
80 | # use zero optimizer for distributed initialization
81 | if args.use_zero_optimizer:
82 | print("Using ZeRO optimizer distributed mode.")
83 | import deepspeed
84 | deepspeed.init_distributed(
85 | dist_backend=args.dist_backend,
86 | init_method=args.dist_url,
87 | rank=args.rank,
88 | timeout=datetime.timedelta(days=365), # allow auto-downloading and de-compressing,
89 | # config=args.deepspeed_config,
90 | )
91 | # use pytorch distributed initialization
92 | else:
93 | print("Using PyTorch optimizer distributed mode.")
94 | torch.distributed.init_process_group(
95 | backend=args.dist_backend,
96 | init_method=args.dist_url,
97 | world_size=args.world_size,
98 | rank=args.rank,
99 | timeout=datetime.timedelta(
100 | days=365
101 | ), # allow auto-downloading and de-compressing
102 | )
103 | torch.distributed.barrier()
104 | setup_for_distributed(args.rank == 0)
105 |
106 |
107 | def get_dist_info():
108 | if torch.__version__ < "1.0":
109 | initialized = dist._initialized
110 | else:
111 | initialized = dist.is_initialized()
112 | if initialized:
113 | rank = dist.get_rank()
114 | world_size = dist.get_world_size()
115 | else: # non-distributed training
116 | rank = 0
117 | world_size = 1
118 | return rank, world_size
119 |
120 |
121 | def main_process(func):
122 | @functools.wraps(func)
123 | def wrapper(*args, **kwargs):
124 | rank, _ = get_dist_info()
125 | if rank == 0:
126 | return func(*args, **kwargs)
127 |
128 | return wrapper
129 |
130 |
131 | def download_cached_file(url, check_hash=True, progress=False):
132 | """
133 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
134 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
135 | """
136 |
137 | def get_cached_file_path():
138 | # a hack to sync the file path across processes
139 | parts = torch.hub.urlparse(url)
140 | filename = os.path.basename(parts.path)
141 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
142 |
143 | return cached_file
144 |
145 | if is_main_process():
146 | timm_hub.download_cached_file(url, check_hash, progress)
147 |
148 | if is_dist_avail_and_initialized():
149 | dist.barrier()
150 |
151 | return get_cached_file_path()
152 |
--------------------------------------------------------------------------------
/minigpt4/common/gradcam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from scipy.ndimage import filters
4 | from skimage import transform as skimage_transform
5 |
6 |
7 | def getAttMap(img, attMap, blur=True, overlap=True):
8 | attMap -= attMap.min()
9 | if attMap.max() > 0:
10 | attMap /= attMap.max()
11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12 | if blur:
13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14 | attMap -= attMap.min()
15 | attMap /= attMap.max()
16 | cmap = plt.get_cmap("jet")
17 | attMapV = cmap(attMap)
18 | attMapV = np.delete(attMapV, 3, 2)
19 | if overlap:
20 | attMap = (
21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23 | )
24 | return attMap
25 |
--------------------------------------------------------------------------------
/minigpt4/common/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import datetime
9 | import logging
10 | import time
11 | from collections import defaultdict, deque
12 |
13 | import torch
14 | import torch.distributed as dist
15 |
16 | from minigpt4.common import dist_utils
17 |
18 |
19 | class SmoothedValue(object):
20 | """Track a series of values and provide access to smoothed values over a
21 | window or the global series average.
22 | """
23 |
24 | def __init__(self, window_size=20, fmt=None):
25 | if fmt is None:
26 | fmt = "{median:.4f} ({global_avg:.4f})"
27 | self.deque = deque(maxlen=window_size)
28 | self.total = 0.0
29 | self.count = 0
30 | self.fmt = fmt
31 |
32 | def update(self, value, n=1):
33 | self.deque.append(value)
34 | self.count += n
35 | self.total += value * n
36 |
37 | def synchronize_between_processes(self):
38 | """
39 | Warning: does not synchronize the deque!
40 | """
41 | if not dist_utils.is_dist_avail_and_initialized():
42 | return
43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median,
75 | avg=self.avg,
76 | global_avg=self.global_avg,
77 | max=self.max,
78 | value=self.value,
79 | )
80 |
81 |
82 | class MetricLogger(object):
83 | def __init__(self, delimiter="\t"):
84 | self.meters = defaultdict(SmoothedValue)
85 | self.delimiter = delimiter
86 |
87 | def update(self, **kwargs):
88 | for k, v in kwargs.items():
89 | if isinstance(v, torch.Tensor):
90 | v = v.item()
91 | assert isinstance(v, (float, int))
92 | self.meters[k].update(v)
93 |
94 | def __getattr__(self, attr):
95 | if attr in self.meters:
96 | return self.meters[attr]
97 | if attr in self.__dict__:
98 | return self.__dict__[attr]
99 | raise AttributeError(
100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101 | )
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append("{}: {}".format(name, str(meter)))
107 | return self.delimiter.join(loss_str)
108 |
109 | def global_avg(self):
110 | loss_str = []
111 | for name, meter in self.meters.items():
112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113 | return self.delimiter.join(loss_str)
114 |
115 | def synchronize_between_processes(self):
116 | for meter in self.meters.values():
117 | meter.synchronize_between_processes()
118 |
119 | def add_meter(self, name, meter):
120 | self.meters[name] = meter
121 |
122 | def log_every(self, iterable, print_freq, header=None):
123 | i = 0
124 | if not header:
125 | header = ""
126 | start_time = time.time()
127 | end = time.time()
128 | iter_time = SmoothedValue(fmt="{avg:.4f}")
129 | data_time = SmoothedValue(fmt="{avg:.4f}")
130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131 | log_msg = [
132 | header,
133 | "[{0" + space_fmt + "}/{1}]",
134 | "eta: {eta}",
135 | "{meters}",
136 | "time: {time}",
137 | "data: {data}",
138 | ]
139 | if torch.cuda.is_available():
140 | log_msg.append("max mem: {memory:.0f}")
141 | log_msg = self.delimiter.join(log_msg)
142 | MB = 1024.0 * 1024.0
143 | for obj in iterable:
144 | data_time.update(time.time() - end)
145 | yield obj
146 | iter_time.update(time.time() - end)
147 | if i % print_freq == 0 or i == len(iterable) - 1:
148 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150 | if torch.cuda.is_available():
151 | print(
152 | log_msg.format(
153 | i,
154 | len(iterable),
155 | eta=eta_string,
156 | meters=str(self),
157 | time=str(iter_time),
158 | data=str(data_time),
159 | memory=torch.cuda.max_memory_allocated() / MB,
160 | )
161 | )
162 | else:
163 | print(
164 | log_msg.format(
165 | i,
166 | len(iterable),
167 | eta=eta_string,
168 | meters=str(self),
169 | time=str(iter_time),
170 | data=str(data_time),
171 | )
172 | )
173 | i += 1
174 | end = time.time()
175 | total_time = time.time() - start_time
176 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177 | print(
178 | "{} Total time: {} ({:.4f} s / it)".format(
179 | header, total_time_str, total_time / len(iterable)
180 | )
181 | )
182 |
183 |
184 | class AttrDict(dict):
185 | def __init__(self, *args, **kwargs):
186 | super(AttrDict, self).__init__(*args, **kwargs)
187 | self.__dict__ = self
188 |
189 |
190 | def setup_logger():
191 | logging.basicConfig(
192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193 | format="%(asctime)s [%(levelname)s] %(message)s",
194 | handlers=[logging.StreamHandler()],
195 | )
196 |
--------------------------------------------------------------------------------
/minigpt4/common/optims.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import math
9 |
10 | from minigpt4.common.registry import registry
11 |
12 |
13 | @registry.register_lr_scheduler("linear_warmup_step_lr")
14 | class LinearWarmupStepLRScheduler:
15 | def __init__(
16 | self,
17 | optimizer,
18 | max_epoch,
19 | min_lr,
20 | init_lr,
21 | decay_rate=1,
22 | warmup_start_lr=-1,
23 | warmup_steps=0,
24 | **kwargs
25 | ):
26 | self.optimizer = optimizer
27 |
28 | self.max_epoch = max_epoch
29 | self.min_lr = min_lr
30 |
31 | self.decay_rate = decay_rate
32 |
33 | self.init_lr = init_lr
34 | self.warmup_steps = warmup_steps
35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36 |
37 | def step(self, cur_epoch, cur_step):
38 | if cur_epoch == 0:
39 | warmup_lr_schedule(
40 | step=cur_step,
41 | optimizer=self.optimizer,
42 | max_step=self.warmup_steps,
43 | init_lr=self.warmup_start_lr,
44 | max_lr=self.init_lr,
45 | )
46 | else:
47 | step_lr_schedule(
48 | epoch=cur_epoch,
49 | optimizer=self.optimizer,
50 | init_lr=self.init_lr,
51 | min_lr=self.min_lr,
52 | decay_rate=self.decay_rate,
53 | )
54 |
55 |
56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57 | class LinearWarmupCosineLRScheduler:
58 | def __init__(
59 | self,
60 | optimizer,
61 | max_epoch,
62 | iters_per_epoch,
63 | min_lr,
64 | init_lr,
65 | warmup_steps=0,
66 | warmup_start_lr=-1,
67 | **kwargs
68 | ):
69 | self.optimizer = optimizer
70 |
71 | self.max_epoch = max_epoch
72 | self.iters_per_epoch = iters_per_epoch
73 | self.min_lr = min_lr
74 |
75 | self.init_lr = init_lr
76 | self.warmup_steps = warmup_steps
77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78 |
79 | def step(self, cur_epoch, cur_step):
80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81 | if total_cur_step < self.warmup_steps:
82 | warmup_lr_schedule(
83 | step=cur_step,
84 | optimizer=self.optimizer,
85 | max_step=self.warmup_steps,
86 | init_lr=self.warmup_start_lr,
87 | max_lr=self.init_lr,
88 | )
89 | else:
90 | cosine_lr_schedule(
91 | epoch=total_cur_step,
92 | optimizer=self.optimizer,
93 | max_epoch=self.max_epoch * self.iters_per_epoch,
94 | init_lr=self.init_lr,
95 | min_lr=self.min_lr,
96 | )
97 |
98 |
99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100 | """Decay the learning rate"""
101 | lr = (init_lr - min_lr) * 0.5 * (
102 | 1.0 + math.cos(math.pi * epoch / max_epoch)
103 | ) + min_lr
104 | for param_group in optimizer.param_groups:
105 | param_group["lr"] = lr
106 |
107 |
108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109 | """Warmup the learning rate"""
110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111 | for param_group in optimizer.param_groups:
112 | param_group["lr"] = lr
113 |
114 |
115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116 | """Decay the learning rate"""
117 | lr = max(min_lr, init_lr * (decay_rate**epoch))
118 | for param_group in optimizer.param_groups:
119 | param_group["lr"] = lr
120 |
--------------------------------------------------------------------------------
/minigpt4/common/registry.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 |
9 | class Registry:
10 | mapping = {
11 | "builder_name_mapping": {},
12 | "task_name_mapping": {},
13 | "processor_name_mapping": {},
14 | "model_name_mapping": {},
15 | "lr_scheduler_name_mapping": {},
16 | "runner_name_mapping": {},
17 | "state": {},
18 | "paths": {},
19 | }
20 |
21 | @classmethod
22 | def register_builder(cls, name):
23 | r"""Register a dataset builder to registry with key 'name'
24 |
25 | Args:
26 | name: Key with which the builder will be registered.
27 |
28 | Usage:
29 |
30 | from minigpt4.common.registry import registry
31 | from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
32 | """
33 |
34 | def wrap(builder_cls):
35 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36 |
37 | assert issubclass(
38 | builder_cls, BaseDatasetBuilder
39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40 | builder_cls
41 | )
42 | if name in cls.mapping["builder_name_mapping"]:
43 | raise KeyError(
44 | "Name '{}' already registered for {}.".format(
45 | name, cls.mapping["builder_name_mapping"][name]
46 | )
47 | )
48 | cls.mapping["builder_name_mapping"][name] = builder_cls
49 | return builder_cls
50 |
51 | return wrap
52 |
53 | @classmethod
54 | def register_task(cls, name):
55 | r"""Register a task to registry with key 'name'
56 |
57 | Args:
58 | name: Key with which the task will be registered.
59 |
60 | Usage:
61 |
62 | from minigpt4.common.registry import registry
63 | """
64 |
65 | def wrap(task_cls):
66 | from minigpt4.tasks.base_task import BaseTask
67 |
68 | assert issubclass(
69 | task_cls, BaseTask
70 | ), "All tasks must inherit BaseTask class"
71 | if name in cls.mapping["task_name_mapping"]:
72 | raise KeyError(
73 | "Name '{}' already registered for {}.".format(
74 | name, cls.mapping["task_name_mapping"][name]
75 | )
76 | )
77 | cls.mapping["task_name_mapping"][name] = task_cls
78 | return task_cls
79 |
80 | return wrap
81 |
82 | @classmethod
83 | def register_model(cls, name):
84 | r"""Register a model to registry with key 'name'
85 |
86 | Args:
87 | name: Key with which the model will be registered.
88 |
89 | Usage:
90 |
91 | from minigpt4.common.registry import registry
92 | """
93 |
94 | def wrap(model_cls):
95 | from minigpt4.models import BaseModel
96 |
97 | assert issubclass(
98 | model_cls, BaseModel
99 | ), "All models must inherit BaseModel class"
100 | if name in cls.mapping["model_name_mapping"]:
101 | raise KeyError(
102 | "Name '{}' already registered for {}.".format(
103 | name, cls.mapping["model_name_mapping"][name]
104 | )
105 | )
106 | cls.mapping["model_name_mapping"][name] = model_cls
107 | return model_cls
108 |
109 | return wrap
110 |
111 | @classmethod
112 | def register_processor(cls, name):
113 | r"""Register a processor to registry with key 'name'
114 |
115 | Args:
116 | name: Key with which the processor will be registered.
117 |
118 | Usage:
119 |
120 | from minigpt4.common.registry import registry
121 | """
122 |
123 | def wrap(processor_cls):
124 | from minigpt4.processors import BaseProcessor
125 |
126 | assert issubclass(
127 | processor_cls, BaseProcessor
128 | ), "All processors must inherit BaseProcessor class"
129 | if name in cls.mapping["processor_name_mapping"]:
130 | raise KeyError(
131 | "Name '{}' already registered for {}.".format(
132 | name, cls.mapping["processor_name_mapping"][name]
133 | )
134 | )
135 | cls.mapping["processor_name_mapping"][name] = processor_cls
136 | return processor_cls
137 |
138 | return wrap
139 |
140 | @classmethod
141 | def register_lr_scheduler(cls, name):
142 | r"""Register an lr scheduler to registry with key 'name'
143 |
144 | Args:
145 | name: Key with which the lr scheduler will be registered.
146 |
147 | Usage:
148 |
149 | from minigpt4.common.registry import registry
150 | """
151 |
152 | def wrap(lr_sched_cls):
153 | if name in cls.mapping["lr_scheduler_name_mapping"]:
154 | raise KeyError(
155 | "Name '{}' already registered for {}.".format(
156 | name, cls.mapping["lr_scheduler_name_mapping"][name]
157 | )
158 | )
159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160 | return lr_sched_cls
161 |
162 | return wrap
163 |
164 | @classmethod
165 | def register_runner(cls, name):
166 | r"""Register a runner to registry with key 'name'
167 |
168 | Args:
169 | name: Key with which the runner will be registered.
170 |
171 | Usage:
172 |
173 | from minigpt4.common.registry import registry
174 | """
175 |
176 | def wrap(runner_cls):
177 | if name in cls.mapping["runner_name_mapping"]:
178 | raise KeyError(
179 | "Name '{}' already registered for {}.".format(
180 | name, cls.mapping["runner_name_mapping"][name]
181 | )
182 | )
183 | cls.mapping["runner_name_mapping"][name] = runner_cls
184 | return runner_cls
185 |
186 | return wrap
187 |
188 | @classmethod
189 | def register_path(cls, name, path):
190 | r"""Register a path to registry with key 'name'
191 |
192 | Args:
193 | name: Key with which the path will be registered.
194 |
195 | Usage:
196 |
197 | from minigpt4.common.registry import registry
198 | """
199 | assert isinstance(path, str), "All path must be str."
200 | if name in cls.mapping["paths"]:
201 | raise KeyError("Name '{}' already registered.".format(name))
202 | cls.mapping["paths"][name] = path
203 |
204 | @classmethod
205 | def register(cls, name, obj):
206 | r"""Register an item to registry with key 'name'
207 |
208 | Args:
209 | name: Key with which the item will be registered.
210 |
211 | Usage::
212 |
213 | from minigpt4.common.registry import registry
214 |
215 | registry.register("config", {})
216 | """
217 | path = name.split(".")
218 | current = cls.mapping["state"]
219 |
220 | for part in path[:-1]:
221 | if part not in current:
222 | current[part] = {}
223 | current = current[part]
224 |
225 | current[path[-1]] = obj
226 |
227 | # @classmethod
228 | # def get_trainer_class(cls, name):
229 | # return cls.mapping["trainer_name_mapping"].get(name, None)
230 |
231 | @classmethod
232 | def get_builder_class(cls, name):
233 | return cls.mapping["builder_name_mapping"].get(name, None)
234 |
235 | @classmethod
236 | def get_model_class(cls, name):
237 | return cls.mapping["model_name_mapping"].get(name, None)
238 |
239 | @classmethod
240 | def get_task_class(cls, name):
241 | return cls.mapping["task_name_mapping"].get(name, None)
242 |
243 | @classmethod
244 | def get_processor_class(cls, name):
245 | return cls.mapping["processor_name_mapping"].get(name, None)
246 |
247 | @classmethod
248 | def get_lr_scheduler_class(cls, name):
249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250 |
251 | @classmethod
252 | def get_runner_class(cls, name):
253 | return cls.mapping["runner_name_mapping"].get(name, None)
254 |
255 | @classmethod
256 | def list_runners(cls):
257 | return sorted(cls.mapping["runner_name_mapping"].keys())
258 |
259 | @classmethod
260 | def list_models(cls):
261 | return sorted(cls.mapping["model_name_mapping"].keys())
262 |
263 | @classmethod
264 | def list_tasks(cls):
265 | return sorted(cls.mapping["task_name_mapping"].keys())
266 |
267 | @classmethod
268 | def list_processors(cls):
269 | return sorted(cls.mapping["processor_name_mapping"].keys())
270 |
271 | @classmethod
272 | def list_lr_schedulers(cls):
273 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274 |
275 | @classmethod
276 | def list_datasets(cls):
277 | return sorted(cls.mapping["builder_name_mapping"].keys())
278 |
279 | @classmethod
280 | def get_path(cls, name):
281 | return cls.mapping["paths"].get(name, None)
282 |
283 | @classmethod
284 | def get(cls, name, default=None, no_warning=False):
285 | r"""Get an item from registry with key 'name'
286 |
287 | Args:
288 | name (string): Key whose value needs to be retrieved.
289 | default: If passed and key is not in registry, default value will
290 | be returned with a warning. Default: None
291 | no_warning (bool): If passed as True, warning when key doesn't exist
292 | will not be generated. Useful for MMF's
293 | internal operations. Default: False
294 | """
295 | original_name = name
296 | name = name.split(".")
297 | value = cls.mapping["state"]
298 | for subname in name:
299 | value = value.get(subname, default)
300 | if value is default:
301 | break
302 |
303 | if (
304 | "writer" in cls.mapping["state"]
305 | and value == default
306 | and no_warning is False
307 | ):
308 | cls.mapping["state"]["writer"].warning(
309 | "Key {} is not present in registry, returning default value "
310 | "of {}".format(original_name, default)
311 | )
312 | return value
313 |
314 | @classmethod
315 | def unregister(cls, name):
316 | r"""Remove an item from registry with key 'name'
317 |
318 | Args:
319 | name: Key which needs to be removed.
320 | Usage::
321 |
322 | from mmf.common.registry import registry
323 |
324 | config = registry.unregister("config")
325 | """
326 | return cls.mapping["state"].pop(name, None)
327 |
328 |
329 | registry = Registry()
330 |
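A minimal usage sketch of the dotted-key register/get pair above (illustrative only; the keys used here are arbitrary examples, not ones the repository defines):

from minigpt4.common.registry import registry

# Dotted names are stored as nested dicts under mapping["state"].
registry.register("demo.root", "/tmp/minigpt4")          # creates {"demo": {"root": "/tmp/minigpt4"}}
assert registry.get("demo.root") == "/tmp/minigpt4"

# Missing keys fall back to the given default; no_warning=True suppresses the log message.
assert registry.get("demo.missing", default=None, no_warning=True) is None
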
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/cc_sbu/align.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu_align:
3 | data_type: images
4 | build_info:
5 | storage:
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/cc_sbu/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu:
3 | data_type: images
4 | build_info:
5 | storage:
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/iuxray/align.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | iuxray_align:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/iuxray
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/iuxray/generate_then_refine.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | iuxray_generate_then_refine:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/iuxray
6 | unlabeled_annotation_path: /path/to/pubmed
7 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/laion/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | laion:
3 | data_type: images
4 | build_info:
5 | storage:
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/mimic/align.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | mimic_align:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/mimic
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/mimic/generate_then_refine.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | mimic_generate_then_refine:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/mimic
6 | unlabeled_annotation_path: /path/to/pubmed
7 |
--------------------------------------------------------------------------------
/minigpt4/configs/default.yaml:
--------------------------------------------------------------------------------
1 | env:
2 | # For default users
3 | # cache_root: "cache"
4 | # For internal use with persistent storage
5 | cache_root: "/export/home/.cache/minigpt4"
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4-7b.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # Vicuna
16 | llama_model: "/path/to/vicuna-7b"
17 |
18 | # generation configs
19 | prompt: ""
20 |
21 | preprocess:
22 | vis_processor:
23 | train:
24 | name: "blip2_image_train"
25 | image_size: 224
26 | eval:
27 | name: "blip2_image_eval"
28 | image_size: 224
29 | text_processor:
30 | train:
31 | name: "blip_caption"
32 | eval:
33 | name: "blip_caption"
34 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # Vicuna
16 | llama_model: "/path/to/vicuna-13b"
17 |
18 | # generation configs
19 | prompt: ""
20 |
21 | preprocess:
22 | vis_processor:
23 | train:
24 | name: "blip2_image_train"
25 | image_size: 224
26 | eval:
27 | name: "blip2_image_eval"
28 | image_size: 224
29 | text_processor:
30 | train:
31 | name: "blip_caption"
32 | eval:
33 | name: "blip_caption"
34 |
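The two model YAMLs above are plain OmegaConf files; a quick way to inspect one (illustrative sketch only; `llama_model` is a placeholder that must be replaced with a local Vicuna checkpoint path):

from omegaconf import OmegaConf

cfg = OmegaConf.load("minigpt4/configs/models/minigpt4.yaml")
print(cfg.model.arch)                           # mini_gpt4
print(cfg.model.llama_model)                    # "/path/to/vicuna-13b" -- replace with a real checkpoint
print(cfg.preprocess.vis_processor.train.name)  # blip2_image_train
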
--------------------------------------------------------------------------------
/minigpt4/conversation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/conversation/__init__.py
--------------------------------------------------------------------------------
/minigpt4/conversation/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/conversation/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/conversation/__pycache__/conversation.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/conversation/__pycache__/conversation.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/conversation/conversation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from PIL import Image
4 |
5 | import torch
6 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7 | from transformers import StoppingCriteria, StoppingCriteriaList
8 |
9 | import dataclasses
10 | from enum import auto, Enum
11 | from typing import List, Tuple, Any
12 |
13 | from minigpt4.common.registry import registry
14 |
15 |
16 | class SeparatorStyle(Enum):
17 | """Different separator style."""
18 | SINGLE = auto()
19 | TWO = auto()
20 |
21 |
22 | @dataclasses.dataclass
23 | class Conversation:
24 | """A class that keeps all conversation history."""
25 | system: str
26 | roles: List[str]
27 | messages: List[List[str]]
28 | offset: int
29 | # system_img: List[Image.Image] = []
30 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31 | sep: str = "###"
32 | sep2: str = None
33 |
34 | skip_next: bool = False
35 | conv_id: Any = None
36 |
37 | def get_prompt(self):
38 | if self.sep_style == SeparatorStyle.SINGLE:
39 | ret = self.system + self.sep
40 | for role, message in self.messages:
41 | if message:
42 | ret += role + ": " + message + self.sep
43 | else:
44 | ret += role + ":"
45 | return ret
46 | elif self.sep_style == SeparatorStyle.TWO:
47 | seps = [self.sep, self.sep2]
48 | ret = self.system + seps[0]
49 | for i, (role, message) in enumerate(self.messages):
50 | if message:
51 | ret += role + ": " + message + seps[i % 2]
52 | else:
53 | ret += role + ":"
54 | return ret
55 | else:
56 | raise ValueError(f"Invalid style: {self.sep_style}")
57 |
58 | def append_message(self, role, message):
59 | self.messages.append([role, message])
60 |
61 | def to_gradio_chatbot(self):
62 | ret = []
63 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
64 | if i % 2 == 0:
65 | ret.append([msg, None])
66 | else:
67 | ret[-1][-1] = msg
68 | return ret
69 |
70 | def copy(self):
71 | return Conversation(
72 | system=self.system,
73 | # system_img=self.system_img,
74 | roles=self.roles,
75 | messages=[[x, y] for x, y in self.messages],
76 | offset=self.offset,
77 | sep_style=self.sep_style,
78 | sep=self.sep,
79 | sep2=self.sep2,
80 | conv_id=self.conv_id)
81 |
82 | def dict(self):
83 | return {
84 | "system": self.system,
85 | # "system_img": self.system_img,
86 | "roles": self.roles,
87 | "messages": self.messages,
88 | "offset": self.offset,
89 | "sep": self.sep,
90 | "sep2": self.sep2,
91 | "conv_id": self.conv_id,
92 | }
93 |
94 |
95 | class StoppingCriteriaSub(StoppingCriteria):
96 |
97 | def __init__(self, stops=[], encounters=1):
98 | super().__init__()
99 | self.stops = stops
100 |
101 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
102 | for stop in self.stops:
103 | if torch.all((stop == input_ids[0][-len(stop):])).item():
104 | return True
105 |
106 | return False
107 |
108 |
109 | CONV_VISION = Conversation(
110 | system="Give the following image:
ImageContent. "
111 | "You will be able to see the image once I provide it to you. Please answer my questions.",
112 | roles=("Human", "Assistant"),
113 | messages=[],
114 | offset=2,
115 | sep_style=SeparatorStyle.SINGLE,
116 | sep="###",
117 | )
118 |
119 |
120 |
121 | class Chat:
122 | def __init__(self, model, vis_processor, device='cuda:0'):
123 | self.device = device
124 | self.model = model
125 | self.vis_processor = vis_processor
126 | stop_words_ids = [torch.tensor([835]).to(self.device),
127 | torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
128 | self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
129 |
130 | def ask(self, text, conv):
131 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
132 | and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
133 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
134 | else:
135 | conv.append_message(conv.roles[0], text)
136 |
137 | def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
138 | repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
139 | conv.append_message(conv.roles[1], None)
140 | embs = self.get_context_emb(conv, img_list)
141 |
142 | current_max_len = embs.shape[1] + max_new_tokens
143 | if current_max_len - max_length > 0:
144 | print('Warning: The number of tokens in current conversation exceeds the max length. '
145 | 'The model will not see the contexts outside the range.')
146 | begin_idx = max(0, current_max_len - max_length)
147 |
148 | embs = embs[:, begin_idx:]
149 |
150 | outputs = self.model.llama_model.generate(
151 | inputs_embeds=embs,
152 | max_new_tokens=max_new_tokens,
153 | stopping_criteria=self.stopping_criteria,
154 | num_beams=num_beams,
155 | do_sample=True,
156 | min_length=min_length,
157 | top_p=top_p,
158 | repetition_penalty=repetition_penalty,
159 | length_penalty=length_penalty,
160 | temperature=temperature,
161 | )
162 | output_token = outputs[0]
163 | if output_token[0] == 0:  # the model might output an unknown token at the beginning; remove it
164 | output_token = output_token[1:]
165 | if output_token[0] == 1:  # some users find that there is a start (BOS) token at the beginning; remove it
166 | output_token = output_token[1:]
167 | output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
168 | output_text = output_text.split('###')[0] # remove the stop sign '###'
169 | output_text = output_text.split('Assistant:')[-1].strip()
170 | conv.messages[-1][1] = output_text
171 | return output_text, output_token.cpu().numpy()
172 |
173 | def upload_img(self, image, conv, img_list):
174 | if isinstance(image, str):  # is an image path
175 | raw_image = Image.open(image).convert('RGB')
176 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
177 | elif isinstance(image, Image.Image):
178 | raw_image = image
179 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
180 | elif isinstance(image, torch.Tensor):
181 | if len(image.shape) == 3:
182 | image = image.unsqueeze(0)
183 | image = image.to(self.device)
184 |
185 | image_emb, _ = self.model.encode_img(image)
186 | img_list.append(image_emb)
187 | conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
188 | msg = "Received."
189 | # self.conv.append_message(self.conv.roles[1], msg)
190 | return msg
191 |
192 | def get_context_emb(self, conv, img_list):
193 | prompt = conv.get_prompt()
194 | prompt_segs = prompt.split('<ImageHere>')
195 | assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
196 | seg_tokens = [
197 | self.model.llama_tokenizer(
198 | seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
199 | # only add bos to the first seg
200 | for i, seg in enumerate(prompt_segs)
201 | ]
202 | seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
203 | mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
204 | mixed_embs = torch.cat(mixed_embs, dim=1)
205 | return mixed_embs
206 |
207 |
208 |
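A minimal sketch of how the Chat wrapper above is meant to be driven (illustrative only; `model` and `vis_processor` stand for an already-initialized MiniGPT-4 model and its image processor built elsewhere in the repository, and the image path is hypothetical):

from minigpt4.conversation.conversation import Chat, CONV_VISION

chat = Chat(model, vis_processor, device="cuda:0")

conv = CONV_VISION.copy()          # fresh conversation state per dialogue
img_list = []
chat.upload_img("example_xray.png", conv, img_list)    # hypothetical image path
chat.ask("Describe the findings in this image.", conv)
answer, _ = chat.answer(conv, img_list, max_new_tokens=300)
print(answer)
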
--------------------------------------------------------------------------------
/minigpt4/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__init__.py
--------------------------------------------------------------------------------
/minigpt4/datasets/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
9 | from minigpt4.datasets.builders.image_text_pair_builder import (
10 | CCSBUBuilder,
11 | LaionBuilder,
12 | CCSBUAlignBuilder
13 | )
14 | from minigpt4.common.registry import registry
15 |
16 | __all__ = [
17 | "CCSBUBuilder",
18 | "LaionBuilder",
19 | "CCSBUAlignBuilder"
20 | ]
21 |
22 |
23 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
24 | """
25 | Example
26 |
27 | >>> dataset = load_dataset("cc_sbu_align", cfg_path=None)
28 | >>> splits = dataset.keys()
29 | >>> print([len(dataset[split]) for split in splits])
30 |
31 | """
32 | if cfg_path is None:
33 | cfg = None
34 | else:
35 | cfg = load_dataset_config(cfg_path)
36 |
37 | try:
38 | builder = registry.get_builder_class(name)(cfg)
39 | except TypeError:
40 | print(
41 | f"Dataset {name} not found. Available datasets:\n"
42 | + ", ".join([str(k) for k in dataset_zoo.get_names()])
43 | )
44 | exit(1)
45 |
46 | if vis_path is not None:
47 | if data_type is None:
48 | # use default data type in the config
49 | data_type = builder.config.data_type
50 |
51 | assert (
52 | data_type in builder.config.build_info
53 | ), f"Invalid data_type {data_type} for {name}."
54 |
55 | builder.config.build_info.get(data_type).storage = vis_path
56 |
57 | dataset = builder.build_datasets()
58 | return dataset
59 |
60 |
61 | class DatasetZoo:
62 | def __init__(self) -> None:
63 | self.dataset_zoo = {
64 | k: list(v.DATASET_CONFIG_DICT.keys())
65 | for k, v in sorted(registry.mapping["builder_name_mapping"].items())
66 | }
67 |
68 | def get_names(self):
69 | return list(self.dataset_zoo.keys())
70 |
71 |
72 | dataset_zoo = DatasetZoo()
73 |
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-38.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/base_dataset_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is from
3 | Copyright (c) 2022, salesforce.com, inc.
4 | All rights reserved.
5 | SPDX-License-Identifier: BSD-3-Clause
6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | """
8 |
9 | import logging
10 | import os
11 | import shutil
12 | import warnings
13 |
14 | from omegaconf import OmegaConf
15 | import torch.distributed as dist
16 | from torchvision.datasets.utils import download_url
17 |
18 | import minigpt4.common.utils as utils
19 | from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20 | from minigpt4.common.registry import registry
21 | from minigpt4.processors.base_processor import BaseProcessor
22 |
23 |
24 |
25 | class BaseDatasetBuilder:
26 | train_dataset_cls, eval_dataset_cls = None, None
27 |
28 | def __init__(self, cfg=None):
29 | super().__init__()
30 |
31 | if cfg is None:
32 | # help to create datasets from default config.
33 | self.config = load_dataset_config(self.default_config_path())
34 | elif isinstance(cfg, str):
35 | self.config = load_dataset_config(cfg)
36 | else:
37 | # when called from task.build_dataset()
38 | self.config = cfg
39 |
40 | self.data_type = self.config.data_type
41 |
42 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
44 |
45 | def build_datasets(self):
46 | # download, split, etc...
47 | # only called on 1 GPU/TPU in distributed
48 |
49 | if is_main_process():
50 | self._download_data()
51 |
52 | if is_dist_avail_and_initialized():
53 | dist.barrier()
54 |
55 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
56 | logging.info("Building datasets...")
57 | datasets = self.build() # dataset['train'/'val'/'test']
58 |
59 | return datasets
60 |
61 | def build_processors(self):
62 | vis_proc_cfg = self.config.get("vis_processor")
63 | txt_proc_cfg = self.config.get("text_processor")
64 |
65 | if vis_proc_cfg is not None:
66 | vis_train_cfg = vis_proc_cfg.get("train")
67 | vis_eval_cfg = vis_proc_cfg.get("eval")
68 |
69 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
70 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
71 |
72 | if txt_proc_cfg is not None:
73 | txt_train_cfg = txt_proc_cfg.get("train")
74 | txt_eval_cfg = txt_proc_cfg.get("eval")
75 |
76 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
77 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
78 |
79 | @staticmethod
80 | def _build_proc_from_cfg(cfg):
81 | return (
82 | registry.get_processor_class(cfg.name).from_config(cfg)
83 | if cfg is not None
84 | else None
85 | )
86 |
87 | @classmethod
88 | def default_config_path(cls, type="default"):
89 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
90 |
91 | def _download_data(self):
92 | self._download_ann()
93 | self._download_vis()
94 |
95 | def _download_ann(self):
96 | """
97 | Download annotation files if necessary.
98 | All the vision-language datasets should have annotations in a unified format.
99 |
100 | storage_path can be:
101 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
102 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
103 |
104 | Local annotation paths should be relative.
105 | """
106 | anns = self.config.build_info.annotations
107 |
108 | splits = anns.keys()
109 |
110 | cache_root = registry.get_path("cache_root")
111 |
112 | for split in splits:
113 | info = anns[split]
114 |
115 | urls, storage_paths = info.get("url", None), info.storage
116 |
117 | if isinstance(urls, str):
118 | urls = [urls]
119 | if isinstance(storage_paths, str):
120 | storage_paths = [storage_paths]
121 |
122 | assert len(urls) == len(storage_paths)
123 |
124 | for url_or_filename, storage_path in zip(urls, storage_paths):
125 | # if storage_path is relative, make it full by prefixing with cache_root.
126 | if not os.path.isabs(storage_path):
127 | storage_path = os.path.join(cache_root, storage_path)
128 |
129 | dirname = os.path.dirname(storage_path)
130 | if not os.path.exists(dirname):
131 | os.makedirs(dirname)
132 |
133 | if os.path.isfile(url_or_filename):
134 | src, dst = url_or_filename, storage_path
135 | if not os.path.exists(dst):
136 | shutil.copyfile(src=src, dst=dst)
137 | else:
138 | logging.info("Using existing file {}.".format(dst))
139 | else:
140 | if os.path.isdir(storage_path):
141 | # if only dirname is provided, suffix with basename of URL.
142 | raise ValueError(
143 | "Expecting storage_path to be a file path, got directory {}".format(
144 | storage_path
145 | )
146 | )
147 | else:
148 | filename = os.path.basename(storage_path)
149 |
150 | download_url(url=url_or_filename, root=dirname, filename=filename)
151 |
152 | def _download_vis(self):
153 |
154 | storage_path = self.config.build_info.get(self.data_type).storage
155 | storage_path = utils.get_cache_path(storage_path)
156 |
157 | if not os.path.exists(storage_path):
158 | warnings.warn(
159 | f"""
160 | The specified path {storage_path} for visual inputs does not exist.
161 | Please provide a correct path to the visual inputs or
162 | refer to datasets/download_scripts/README.md for downloading instructions.
163 | """
164 | )
165 |
166 | def build(self):
167 | """
168 | Create per-split datasets inheriting from torch.utils.data.Dataset.
169 |
170 | # build() can be dataset-specific. Overwrite to customize.
171 | """
172 | self.build_processors()
173 |
174 | build_info = self.config.build_info
175 |
176 | ann_info = build_info.annotations
177 | vis_info = build_info.get(self.data_type)
178 |
179 | datasets = dict()
180 | for split in ann_info.keys():
181 | if split not in ["train", "val", "test"]:
182 | continue
183 |
184 | is_train = split == "train"
185 |
186 | # processors
187 | vis_processor = (
188 | self.vis_processors["train"]
189 | if is_train
190 | else self.vis_processors["eval"]
191 | )
192 | text_processor = (
193 | self.text_processors["train"]
194 | if is_train
195 | else self.text_processors["eval"]
196 | )
197 |
198 | # annotation path
199 | ann_paths = ann_info.get(split).storage
200 | if isinstance(ann_paths, str):
201 | ann_paths = [ann_paths]
202 |
203 | abs_ann_paths = []
204 | for ann_path in ann_paths:
205 | if not os.path.isabs(ann_path):
206 | ann_path = utils.get_cache_path(ann_path)
207 | abs_ann_paths.append(ann_path)
208 | ann_paths = abs_ann_paths
209 |
210 | # visual data storage path
211 | vis_path = os.path.join(vis_info.storage, split)
212 |
213 | if not os.path.isabs(vis_path):
214 | # vis_path = os.path.join(utils.get_cache_path(), vis_path)
215 | vis_path = utils.get_cache_path(vis_path)
216 |
217 | if not os.path.exists(vis_path):
218 | warnings.warn("storage path {} does not exist.".format(vis_path))
219 |
220 | # create datasets
221 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
222 | datasets[split] = dataset_cls(
223 | vis_processor=vis_processor,
224 | text_processor=text_processor,
225 | ann_paths=ann_paths,
226 | vis_root=vis_path,
227 | )
228 |
229 | return datasets
230 |
231 |
232 | def load_dataset_config(cfg_path):
233 | cfg = OmegaConf.load(cfg_path).datasets
234 | cfg = cfg[list(cfg.keys())[0]]
235 |
236 | return cfg
237 |
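An illustrative sketch of the config shape that BaseDatasetBuilder.build() above traverses (the paths below are placeholders, not real data):

from omegaconf import OmegaConf

# Mirrors the keys that build() reads: data_type, build_info.annotations.<split>.storage,
# and build_info.<data_type>.storage.
cfg = OmegaConf.create({
    "data_type": "images",
    "build_info": {
        "annotations": {"train": {"storage": "example/annotations/train.json"}},
        "images": {"storage": "example/images"},
    },
})
print(cfg.build_info.annotations.train.storage)   # example/annotations/train.json
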
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/image_text_pair_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 |
5 | from minigpt4.common.registry import registry
6 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7 | from minigpt4.datasets.datasets.laion_dataset import LaionDataset
8 | from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
9 | from minigpt4.datasets.datasets.mimic_dataset import MIMICDataset, MIMICGenerateThenRefineDataset
10 | from minigpt4.datasets.datasets.iuxray_dataset import IUXRayGenerateThenRefineDataset
11 |
12 |
13 | @registry.register_builder("cc_sbu")
14 | class CCSBUBuilder(BaseDatasetBuilder):
15 | train_dataset_cls = CCSBUDataset
16 |
17 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
18 |
19 | def _download_ann(self):
20 | pass
21 |
22 | def _download_vis(self):
23 | pass
24 |
25 | def build(self):
26 | self.build_processors()
27 |
28 | build_info = self.config.build_info
29 |
30 | datasets = dict()
31 | split = "train"
32 |
33 | # create datasets
34 | # [NOTE] return inner_datasets (wds.DataPipeline)
35 | dataset_cls = self.train_dataset_cls
36 | datasets[split] = dataset_cls(
37 | vis_processor=self.vis_processors[split],
38 | text_processor=self.text_processors[split],
39 | location=build_info.storage,
40 | ).inner_dataset
41 |
42 | return datasets
43 |
44 |
45 | @registry.register_builder("laion")
46 | class LaionBuilder(BaseDatasetBuilder):
47 | train_dataset_cls = LaionDataset
48 |
49 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
50 |
51 | def _download_ann(self):
52 | pass
53 |
54 | def _download_vis(self):
55 | pass
56 |
57 | def build(self):
58 | self.build_processors()
59 |
60 | build_info = self.config.build_info
61 |
62 | datasets = dict()
63 | split = "train"
64 |
65 | # create datasets
66 | # [NOTE] return inner_datasets (wds.DataPipeline)
67 | dataset_cls = self.train_dataset_cls
68 | datasets[split] = dataset_cls(
69 | vis_processor=self.vis_processors[split],
70 | text_processor=self.text_processors[split],
71 | location=build_info.storage,
72 | ).inner_dataset
73 |
74 | return datasets
75 |
76 |
77 | @registry.register_builder("cc_sbu_align")
78 | class CCSBUAlignBuilder(BaseDatasetBuilder):
79 | train_dataset_cls = CCSBUAlignDataset
80 |
81 | DATASET_CONFIG_DICT = {
82 | "default": "configs/datasets/cc_sbu/align.yaml",
83 | }
84 |
85 | def build_datasets(self):
86 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
87 | logging.info("Building datasets...")
88 | self.build_processors()
89 |
90 | build_info = self.config.build_info
91 | storage_path = build_info.storage
92 |
93 | datasets = dict()
94 |
95 | if not os.path.exists(storage_path):
96 | warnings.warn("storage path {} does not exist.".format(storage_path))
97 |
98 | # create datasets
99 | dataset_cls = self.train_dataset_cls
100 | datasets['train'] = dataset_cls(
101 | vis_processor=self.vis_processors["train"],
102 | text_processor=self.text_processors["train"],
103 | ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
104 | vis_root=os.path.join(storage_path, 'image'),
105 | )
106 |
107 | return datasets
108 |
109 | @registry.register_builder("mimic_align")
110 | class MIMICBuilder(BaseDatasetBuilder):
111 | train_dataset_cls = MIMICDataset
112 |
113 | DATASET_CONFIG_DICT = {
114 | "default": "minigpt4/configs/datasets/mimic/align.yaml",
115 | }
116 |
117 | def build_datasets(self):
118 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
119 | logging.info("Building datasets...")
120 | self.build_processors()
121 |
122 | build_info = self.config.build_info
123 | storage_path = build_info.storage
124 |
125 | datasets = dict()
126 |
127 | if not os.path.exists(storage_path):
128 | warnings.warn("storage path {} does not exist.".format(storage_path))
129 |
130 | # create datasets
131 | dataset_cls = self.train_dataset_cls
132 | datasets['train'] = dataset_cls(
133 | vis_processor=self.vis_processors["train"],
134 | text_processor=self.text_processors["train"],
135 | ann_path=os.path.join(storage_path, 'annotation.json'),
136 | image_root=os.path.join(storage_path, 'images'),
137 | )
138 |
139 | return datasets
140 |
141 | @registry.register_builder("mimic_generate_then_refine")
142 | class MIMICGenerateThenRefineBuilder(BaseDatasetBuilder):
143 | train_dataset_cls = MIMICGenerateThenRefineDataset
144 |
145 | DATASET_CONFIG_DICT = {
146 | "default": "minigpt4/configs/datasets/mimic/generate_then_refine.yaml",
147 | }
148 |
149 | def build_datasets(self):
150 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
151 | logging.info("Building datasets...")
152 | self.build_processors()
153 |
154 | build_info = self.config.build_info
155 | storage_path = build_info.storage
156 | unlabeled_annotation_path = build_info.unlabeled_annotation_path
157 |
158 | datasets = dict()
159 |
160 | if not os.path.exists(storage_path):
161 | warnings.warn("storage path {} does not exist.".format(storage_path))
162 |
163 | # create datasets
164 | dataset_cls = self.train_dataset_cls
165 | datasets['train'] = dataset_cls(
166 | vis_processor=self.vis_processors["train"],
167 | text_processor=self.text_processors["train"],
168 | ann_path=os.path.join(storage_path, 'mimic_anno_with_ref.json'),
169 | image_root=os.path.join(storage_path, 'images'),
170 | unlabeled_ann_path=os.path.join(unlabeled_annotation_path, 'annotation.json'),
171 | )
172 |
173 | return datasets
174 |
175 | @registry.register_builder("iuxray_generate_then_refine")
176 | class IUXRayGenerateThenRefineBuilder(BaseDatasetBuilder):
177 | train_dataset_cls = IUXRayGenerateThenRefineDataset
178 |
179 | DATASET_CONFIG_DICT = {
180 | "default": "minigpt4/configs/datasets/iuxray/generate_then_refine.yaml",
181 | }
182 |
183 | def build_datasets(self):
184 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
185 | logging.info("Building datasets...")
186 | self.build_processors()
187 |
188 | build_info = self.config.build_info
189 | storage_path = build_info.storage
190 | unlabeled_annotation_path = build_info.unlabeled_annotation_path
191 |
192 | datasets = dict()
193 |
194 | if not os.path.exists(storage_path):
195 | warnings.warn("storage path {} does not exist.".format(storage_path))
196 |
197 | # create datasets
198 | dataset_cls = self.train_dataset_cls
199 | datasets['train'] = dataset_cls(
200 | vis_processor=self.vis_processors["train"],
201 | text_processor=self.text_processors["train"],
202 | ann_path=os.path.join(storage_path, 'annotation.json'),
203 | image_root=os.path.join(storage_path, 'images'),
204 | unlabeled_ann_path=os.path.join(unlabeled_annotation_path, 'annotation.json'),
205 | )
206 |
207 | return datasets
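Since every builder above registers itself via @registry.register_builder, it can be looked up by name; an illustrative sketch (it will only build successfully once the storage paths in the corresponding YAML point at real data):

from minigpt4.common.registry import registry
import minigpt4.datasets.builders  # importing runs the @registry.register_builder decorators

builder_cls = registry.get_builder_class("mimic_generate_then_refine")
builder = builder_cls()              # no cfg given -> falls back to the default YAML in DATASET_CONFIG_DICT
datasets = builder.build_datasets()  # {"train": MIMICGenerateThenRefineDataset(...)}
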
--------------------------------------------------------------------------------
/minigpt4/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import gzip
9 | import logging
10 | import os
11 | import random as rnd
12 | import tarfile
13 | import zipfile
14 | import random
15 | from typing import List
16 | from tqdm import tqdm
17 |
18 | import decord
19 | from decord import VideoReader
20 | import webdataset as wds
21 | import numpy as np
22 | import torch
23 | from torch.utils.data.dataset import IterableDataset
24 |
25 | from minigpt4.common.registry import registry
26 | from minigpt4.datasets.datasets.base_dataset import ConcatDataset
27 |
28 |
29 | decord.bridge.set_bridge("torch")
30 | MAX_INT = registry.get("MAX_INT")
31 |
32 |
33 | class ChainDataset(wds.DataPipeline):
34 | r"""Dataset for chaining multiple :class:`DataPipeline` s.
35 |
36 | This class is useful to assemble different existing dataset streams. The
37 | chaining operation is done on-the-fly, so concatenating large-scale
38 | datasets with this class will be efficient.
39 |
40 | Args:
41 | datasets (iterable of IterableDataset): datasets to be chained together
42 | """
43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44 | super().__init__()
45 | self.datasets = datasets
46 | self.prob = []
47 | self.names = []
48 | for dataset in self.datasets:
49 | if hasattr(dataset, 'name'):
50 | self.names.append(dataset.name)
51 | else:
52 | self.names.append('Unknown')
53 | if hasattr(dataset, 'sample_ratio'):
54 | self.prob.append(dataset.sample_ratio)
55 | else:
56 | self.prob.append(1)
57 | logging.info("One of the data pipelines does not define a sample ratio; it is set to 1 automatically.")
58 |
59 | def __iter__(self):
60 | datastreams = [iter(dataset) for dataset in self.datasets]
61 | while True:
62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63 | yield next(select_datastream)
64 |
65 |
66 | def apply_to_sample(f, sample):
67 | if len(sample) == 0:
68 | return {}
69 |
70 | def _apply(x):
71 | if torch.is_tensor(x):
72 | return f(x)
73 | elif isinstance(x, dict):
74 | return {key: _apply(value) for key, value in x.items()}
75 | elif isinstance(x, list):
76 | return [_apply(x) for x in x]
77 | else:
78 | return x
79 |
80 | return _apply(sample)
81 |
82 |
83 | def move_to_cuda(sample):
84 | def _move_to_cuda(tensor):
85 | return tensor.cuda()
86 |
87 | return apply_to_sample(_move_to_cuda, sample)
88 |
89 |
90 | def prepare_sample(samples, cuda_enabled=True):
91 | if cuda_enabled:
92 | samples = move_to_cuda(samples)
93 |
94 | # TODO fp16 support
95 |
96 | return samples
97 |
98 |
99 | def reorg_datasets_by_split(datasets):
100 | """
101 | Organizes datasets by split.
102 |
103 | Args:
104 | datasets: dict of torch.utils.data.Dataset objects by name.
105 |
106 | Returns:
107 | Dict of datasets by split {split_name: List[Datasets]}.
108 | """
109 | # if len(datasets) == 1:
110 | # return datasets[list(datasets.keys())[0]]
111 | # else:
112 | reorg_datasets = dict()
113 |
114 | # reorganize by split
115 | for _, dataset in datasets.items():
116 | for split_name, dataset_split in dataset.items():
117 | if split_name not in reorg_datasets:
118 | reorg_datasets[split_name] = [dataset_split]
119 | else:
120 | reorg_datasets[split_name].append(dataset_split)
121 |
122 | return reorg_datasets
123 |
124 |
125 | def concat_datasets(datasets):
126 | """
127 | Concatenates multiple datasets into a single dataset.
128 |
129 | It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not
130 | support generic IterableDataset because that would require creating separate samplers.
131 |
132 | For now, only training datasets are concatenated; validation and testing splits are
133 | assumed to contain a single dataset each, because metrics should not be computed on
134 | concatenated datasets.
135 |
136 | Args:
137 | datasets: dict of torch.utils.data.Dataset objects by split.
138 |
139 | Returns:
140 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141 | "val" and "test" remain the same.
142 |
143 | If the input training datasets contain both map-style and DataPipeline datasets, returns
144 | a tuple, where the first element is a concatenated map-style dataset and the second
145 | element is a chained DataPipeline dataset.
146 |
147 | """
148 | # concatenate datasets in the same split
149 | for split_name in datasets:
150 | if split_name != "train":
151 | assert (
152 | len(datasets[split_name]) == 1
153 | ), "Do not support multiple {} datasets.".format(split_name)
154 | datasets[split_name] = datasets[split_name][0]
155 | else:
156 | iterable_datasets, map_datasets = [], []
157 | for dataset in datasets[split_name]:
158 | if isinstance(dataset, wds.DataPipeline):
159 | logging.info(
160 | "Dataset {} is IterableDataset, can't be concatenated.".format(
161 | dataset
162 | )
163 | )
164 | iterable_datasets.append(dataset)
165 | elif isinstance(dataset, IterableDataset):
166 | raise NotImplementedError(
167 | "Do not support concatenation of generic IterableDataset."
168 | )
169 | else:
170 | map_datasets.append(dataset)
171 |
172 | # if len(iterable_datasets) > 0:
173 | # concatenate map-style datasets and iterable-style datasets separately
174 | if len(iterable_datasets) > 1:
175 | chained_datasets = (
176 | ChainDataset(iterable_datasets)
177 | )
178 | elif len(iterable_datasets) == 1:
179 | chained_datasets = iterable_datasets[0]
180 | else:
181 | chained_datasets = None
182 |
183 | concat_datasets = (
184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185 | )
186 |
187 | train_datasets = concat_datasets, chained_datasets
188 | train_datasets = tuple([x for x in train_datasets if x is not None])
189 | train_datasets = (
190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets
191 | )
192 |
193 | datasets[split_name] = train_datasets
194 |
195 | return datasets
196 |
197 |
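An illustrative sketch of how reorg_datasets_by_split and concat_datasets above compose (the dataset dicts are assumed to come from builders; the names are arbitrary):

from minigpt4.datasets.data_utils import reorg_datasets_by_split, concat_datasets

# Each builder returns {split: Dataset}; group several of them by name first.
by_name = {
    "mimic_align": mimic_datasets,    # assumed: {"train": <map-style Dataset>}
    "cc_sbu_align": cc_sbu_datasets,  # assumed: {"train": <map-style Dataset>}
}
by_split = reorg_datasets_by_split(by_name)   # {"train": [ds_a, ds_b]}
by_split = concat_datasets(by_split)          # {"train": ConcatDataset([ds_a, ds_b])}
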
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__init__.py
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import json
9 | from typing import Iterable
10 |
11 | from torch.utils.data import Dataset, ConcatDataset
12 | from torch.utils.data.dataloader import default_collate
13 |
14 |
15 | class BaseDataset(Dataset):
16 | def __init__(
17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
18 | ):
19 | """
20 | vis_root (string): Root directory of images (e.g. coco/images/)
21 | ann_paths (list): paths to the annotation files
22 | """
23 | self.vis_root = vis_root
24 |
25 | self.annotation = []
26 | for ann_path in ann_paths:
27 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28 |
29 | self.vis_processor = vis_processor
30 | self.text_processor = text_processor
31 |
32 | self._add_instance_ids()
33 |
34 | def __len__(self):
35 | return len(self.annotation)
36 |
37 | def collater(self, samples):
38 | return default_collate(samples)
39 |
40 | def set_processors(self, vis_processor, text_processor):
41 | self.vis_processor = vis_processor
42 | self.text_processor = text_processor
43 |
44 | def _add_instance_ids(self, key="instance_id"):
45 | for idx, ann in enumerate(self.annotation):
46 | ann[key] = str(idx)
47 |
48 |
49 | class ConcatDataset(ConcatDataset):
50 | def __init__(self, datasets: Iterable[Dataset]) -> None:
51 | super().__init__(datasets)
52 |
53 | def collater(self, samples):
54 | # TODO For now only supports datasets with same underlying collater implementations
55 |
56 | all_keys = set()
57 | for s in samples:
58 | all_keys.update(s)
59 |
60 | shared_keys = all_keys
61 | for s in samples:
62 | shared_keys = shared_keys & set(s.keys())
63 |
64 | samples_shared_keys = []
65 | for s in samples:
66 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67 |
68 | return self.datasets[0].collater(samples_shared_keys)
69 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/caption_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | from collections import OrderedDict
10 |
11 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
12 | from PIL import Image
13 |
14 |
15 | class __DisplMixin:
16 | def displ_item(self, index):
17 | sample, ann = self.__getitem__(index), self.annotation[index]
18 |
19 | return OrderedDict(
20 | {
21 | "file": ann["image"],
22 | "caption": ann["caption"],
23 | "image": sample["image"],
24 | }
25 | )
26 |
27 |
28 | class CaptionDataset(BaseDataset, __DisplMixin):
29 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
30 | """
31 | vis_root (string): Root directory of images (e.g. coco/images/)
32 | ann_paths (list): paths to the annotation files
33 | """
34 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
35 |
36 | self.img_ids = {}
37 | n = 0
38 | for ann in self.annotation:
39 | img_id = ann["image_id"]
40 | if img_id not in self.img_ids.keys():
41 | self.img_ids[img_id] = n
42 | n += 1
43 |
44 | def __getitem__(self, index):
45 |
46 | # TODO this assumes image input, not general enough
47 | ann = self.annotation[index]
48 |
49 | img_file = '{:0>12}.jpg'.format(ann["image_id"])
50 | image_path = os.path.join(self.vis_root, img_file)
51 | image = Image.open(image_path).convert("RGB")
52 |
53 | image = self.vis_processor(image)
54 | caption = self.text_processor(ann["caption"])
55 |
56 | return {
57 | "image": image,
58 | "text_input": caption,
59 | "image_id": self.img_ids[ann["image_id"]],
60 | }
61 |
62 |
63 | class CaptionEvalDataset(BaseDataset, __DisplMixin):
64 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
65 | """
66 | vis_root (string): Root directory of images (e.g. coco/images/)
67 | ann_paths (list): paths to the annotation files
68 | split (string): val or test
69 | """
70 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
71 |
72 | def __getitem__(self, index):
73 |
74 | ann = self.annotation[index]
75 |
76 | image_path = os.path.join(self.vis_root, ann["image"])
77 | image = Image.open(image_path).convert("RGB")
78 |
79 | image = self.vis_processor(image)
80 |
81 | return {
82 | "image": image,
83 | "image_id": ann["image_id"],
84 | "instance_id": ann["instance_id"],
85 | }
86 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/cc_sbu_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import webdataset as wds
4 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
5 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
6 |
7 |
8 | class CCSBUDataset(BaseDataset):
9 | def __init__(self, vis_processor, text_processor, location):
10 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
11 |
12 | self.inner_dataset = wds.DataPipeline(
13 | wds.ResampledShards(location),
14 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
15 | wds.shuffle(1000, handler=wds.warn_and_continue),
16 | wds.decode("pilrgb", handler=wds.warn_and_continue),
17 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
18 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
19 | wds.map(self.to_dict, handler=wds.warn_and_continue),
20 | )
21 |
22 | def to_dict(self, sample):
23 | return {
24 | "image": sample[0],
25 | "text_input": self.text_processor(sample[1]["caption"]),
26 | }
27 |
28 |
29 | class CCSBUAlignDataset(CaptionDataset):
30 |
31 | def __getitem__(self, index):
32 |
33 | # TODO this assumes image input, not general enough
34 | ann = self.annotation[index]
35 |
36 | img_file = '{}.jpg'.format(ann["image_id"])
37 | image_path = os.path.join(self.vis_root, img_file)
38 | image = Image.open(image_path).convert("RGB")
39 |
40 | image = self.vis_processor(image)
41 | caption = ann["caption"]
42 |
43 | return {
44 | "image": image,
45 | "text_input": caption,
46 | "image_id": self.img_ids[ann["image_id"]],
47 | }
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/dataloader_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import time
9 | import random
10 | import torch
11 | from minigpt4.datasets.data_utils import move_to_cuda
12 | from torch.utils.data import DataLoader
13 |
14 |
15 | class MultiIterLoader:
16 | """
17 | A simple wrapper for iterating over multiple iterators.
18 |
19 | Args:
20 | loaders (List[Loader]): List of Iterator loaders.
21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22 | """
23 |
24 | def __init__(self, loaders, ratios=None):
25 | # assert all loaders has __next__ method
26 | for loader in loaders:
27 | assert hasattr(
28 | loader, "__next__"
29 | ), "Loader {} has no __next__ method.".format(loader)
30 |
31 | if ratios is None:
32 | ratios = [1.0] * len(loaders)
33 | else:
34 | assert len(ratios) == len(loaders)
35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36 |
37 | self.loaders = loaders
38 | self.ratios = ratios
39 |
40 | def __next__(self):
41 | # random sample from each loader by ratio
42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43 | return next(self.loaders[loader_idx])
44 |
45 |
46 | class PrefetchLoader(object):
47 | """
48 | Modified from https://github.com/ChenRocks/UNITER.
49 |
50 | overlap compute and cuda data transfer
51 | (copied and then modified from nvidia apex)
52 | """
53 |
54 | def __init__(self, loader):
55 | self.loader = loader
56 | self.stream = torch.cuda.Stream()
57 |
58 | def __iter__(self):
59 | loader_it = iter(self.loader)
60 | self.preload(loader_it)
61 | batch = self.next(loader_it)
62 | while batch is not None:
63 | is_tuple = isinstance(batch, tuple)
64 | if is_tuple:
65 | task, batch = batch
66 |
67 | if is_tuple:
68 | yield task, batch
69 | else:
70 | yield batch
71 | batch = self.next(loader_it)
72 |
73 | def __len__(self):
74 | return len(self.loader)
75 |
76 | def preload(self, it):
77 | try:
78 | self.batch = next(it)
79 | except StopIteration:
80 | self.batch = None
81 | return
82 | # if record_stream() doesn't work, another option is to make sure
83 | # device inputs are created on the main stream.
84 | # self.next_input_gpu = torch.empty_like(self.next_input,
85 | # device='cuda')
86 | # self.next_target_gpu = torch.empty_like(self.next_target,
87 | # device='cuda')
88 | # Need to make sure the memory allocated for next_* is not still in use
89 | # by the main stream at the time we start copying to next_*:
90 | # self.stream.wait_stream(torch.cuda.current_stream())
91 | with torch.cuda.stream(self.stream):
92 | self.batch = move_to_cuda(self.batch)
93 | # more code for the alternative if record_stream() doesn't work:
94 | # copy_ will record the use of the pinned source tensor in this
95 | # side stream.
96 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98 | # self.next_input = self.next_input_gpu
99 | # self.next_target = self.next_target_gpu
100 |
101 | def next(self, it):
102 | torch.cuda.current_stream().wait_stream(self.stream)
103 | batch = self.batch
104 | if batch is not None:
105 | record_cuda_stream(batch)
106 | self.preload(it)
107 | return batch
108 |
109 | def __getattr__(self, name):
110 | method = self.loader.__getattribute__(name)
111 | return method
112 |
113 |
114 | def record_cuda_stream(batch):
115 | if isinstance(batch, torch.Tensor):
116 | batch.record_stream(torch.cuda.current_stream())
117 | elif isinstance(batch, list) or isinstance(batch, tuple):
118 | for t in batch:
119 | record_cuda_stream(t)
120 | elif isinstance(batch, dict):
121 | for t in batch.values():
122 | record_cuda_stream(t)
123 | else:
124 | pass
125 |
126 |
127 | class IterLoader:
128 | """
129 | A wrapper to convert DataLoader as an infinite iterator.
130 |
131 | Modified from:
132 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133 | """
134 |
135 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136 | self._dataloader = dataloader
137 | self.iter_loader = iter(self._dataloader)
138 | self._use_distributed = use_distributed
139 | self._epoch = 0
140 |
141 | @property
142 | def epoch(self) -> int:
143 | return self._epoch
144 |
145 | def __next__(self):
146 | try:
147 | data = next(self.iter_loader)
148 | except StopIteration:
149 | self._epoch += 1
150 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151 | self._dataloader.sampler.set_epoch(self._epoch)
152 | time.sleep(2) # Prevent possible deadlock during epoch transition
153 | self.iter_loader = iter(self._dataloader)
154 | data = next(self.iter_loader)
155 |
156 | return data
157 |
158 | def __iter__(self):
159 | return self
160 |
161 | def __len__(self):
162 | return len(self._dataloader)
163 |
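
A minimal sketch of how these wrappers compose; the TensorDataset inputs and the 2:1 sampling ratio are illustrative assumptions, not values used in this repository:

import torch
from torch.utils.data import DataLoader, TensorDataset
from minigpt4.datasets.datasets.dataloader_utils import MultiIterLoader, IterLoader

# Two toy loaders standing in for the real image-report dataloaders.
loader_a = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=2)
loader_b = DataLoader(TensorDataset(torch.randn(4, 3)), batch_size=2)

# IterLoader turns each finite DataLoader into an infinite iterator;
# MultiIterLoader then draws from them with a 2:1 sampling ratio.
multi_loader = MultiIterLoader(
    loaders=[IterLoader(loader_a), IterLoader(loader_b)],
    ratios=[2, 1],
)

# PrefetchLoader(loader) would additionally overlap host-to-GPU copies with
# compute, but it needs a CUDA device, so it is left out of this CPU sketch.
for _ in range(5):
    batch = next(multi_loader)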
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/iuxray_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | from PIL import Image
5 | import webdataset as wds
6 | import random
7 | from torch.utils.data import Dataset
8 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
9 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
10 |
11 |
12 | class IUXRayDataset(Dataset):
13 | def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None):
14 | self.image_root = image_root
15 | self.ann_path = ann_path
16 |
17 | self.vis_processor = vis_processor
18 | self.text_processor = text_processor
19 |
20 | # load annotation file
21 | with open(ann_path, 'r') as f:
22 | self.annotations = json.load(f)
23 | self.train_data = self.annotations['train']
24 |
25 | def __len__(self):
26 | return len(self.train_data)
27 |
28 | def __getitem__(self, index):
29 | data_sample = self.train_data[index]
30 | image_path = data_sample['image_path']
31 |
32 | # load image
33 | image_id = data_sample['id']
34 | image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB')
35 | image = self.vis_processor(image)
36 |
37 | # load caption
38 | caption = data_sample['report']
39 | caption = self.clean_reports(caption)
40 |
41 | return {"image": image,
42 | "text_input": caption,
43 | "image_id": image_id}
44 |
45 | def clean_reports(self, report):
46 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
47 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \
48 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \
49 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
50 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
51 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
52 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
53 | .strip().lower().split('. ')
54 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
55 | .replace('\\', '').replace("'", '').strip().lower())
56 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
57 | report = ' . '.join(tokens) + ' .'
58 | return report
59 |
60 | class IUXRayGenerateThenRefineDataset(Dataset):
61 | def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None, unlabeled_ann_path=None, retrieval_size=3):
62 | self.image_root = image_root
63 | self.ann_path = ann_path
64 | self.retrieval_size = retrieval_size
65 |
66 | self.vis_processor = vis_processor
67 | self.text_processor = text_processor
68 |
69 | # load annotation file
70 | with open(ann_path, 'r') as f:
71 | self.annotations = json.load(f)
72 | self.train_data = self.annotations['train']
73 |
74 | # load unlabeled data
75 | self.unlabeled_data_list = []
76 | with open(unlabeled_ann_path, 'r') as f:
77 | for line in f.readlines():
78 | self.unlabeled_data_list.append(line.strip('\n'))
79 |
80 | print(f"There are {len(self.unlabeled_data_list)} unlabeled reports in total.")
81 |
82 | def __len__(self):
83 | return len(self.train_data)
84 |
85 | def __getitem__(self, index):
86 | data_sample = self.train_data[index]
87 | image_path = data_sample['image_path']
88 |
89 | # load image
90 | image_id = data_sample['id']
91 | image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB')
92 | image = self.vis_processor(image)
93 |
94 | # load caption
95 | caption = data_sample['report']
96 | caption = self.clean_reports(caption)
97 |
98 | # load reference caption
99 | ref_caption = data_sample['ref_report']
100 | ref_caption = self.clean_reports(ref_caption)
101 |
102 | # load unlabeled caption
103 | unlabeled_caption = random.sample(self.unlabeled_data_list, self.retrieval_size)
104 |
105 | return {"image": image,
106 | "text_input": caption,
107 | "ref_caption": ref_caption,
108 | "unlabeled_caption": unlabeled_caption,
109 | "image_id": image_id}
110 |
111 | def clean_reports(self, report):
112 | report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
113 | .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
114 | .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
115 | .strip().lower().split('. ')
116 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
117 | replace('\\', '').replace("'", '').strip().lower())
118 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
119 | report = ' . '.join(tokens) + ' .'
120 | return report
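
A minimal usage sketch for IUXRayDataset, assuming an annotation JSON in the layout the class reads (a 'train' split whose entries carry 'id', 'image_path', and 'report'); the data paths are placeholders:

from torch.utils.data import DataLoader
from minigpt4.datasets.datasets.iuxray_dataset import IUXRayDataset
from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor, BlipCaptionProcessor

dataset = IUXRayDataset(
    vis_processor=Blip2ImageTrainProcessor(image_size=224),
    text_processor=BlipCaptionProcessor(),
    image_root="data/iu_xray/images",         # placeholder path
    ann_path="data/iu_xray/annotation.json",  # placeholder path
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
sample = dataset[0]  # {"image": tensor, "text_input": cleaned report, "image_id": id}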
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/laion_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import webdataset as wds
9 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
10 |
11 |
12 | class LaionDataset(BaseDataset):
13 | def __init__(self, vis_processor, text_processor, location):
14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15 |
16 | self.inner_dataset = wds.DataPipeline(
17 | wds.ResampledShards(location),
18 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
19 | wds.shuffle(1000, handler=wds.warn_and_continue),
20 | wds.decode("pilrgb", handler=wds.warn_and_continue),
21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23 | wds.map(self.to_dict, handler=wds.warn_and_continue),
24 | )
25 |
26 | def to_dict(self, sample):
27 | return {
28 | "image": sample[0],
29 | "text_input": self.text_processor(sample[1]["caption"]),
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/mimic_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | from PIL import Image
5 | import webdataset as wds
6 | import random
7 | from torch.utils.data import Dataset
8 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
9 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
10 |
11 |
12 | class MIMICDataset(Dataset):
13 | def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None):
14 | self.image_root = image_root
15 | self.ann_path = ann_path
16 |
17 | self.vis_processor = vis_processor
18 | self.text_processor = text_processor
19 |
20 | # load annotation file
21 | with open(ann_path, 'r') as f:
22 | self.annotations = json.load(f)
23 | self.train_data = self.annotations['train']
24 |
25 | def __len__(self):
26 | return len(self.train_data)
27 |
28 | def __getitem__(self, index):
29 | data_sample = self.train_data[index]
30 | image_path = data_sample['image_path']
31 |
32 | # load image
33 | image_id = data_sample['id']
34 | image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB')
35 | image = self.vis_processor(image)
36 |
37 | # load caption
38 | caption = data_sample['report']
39 | caption = self.clean_reports(caption)
40 |
41 | return {"image": image,
42 | "text_input": caption,
43 | "image_id": image_id}
44 |
45 | def clean_reports(self, report):
46 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
47 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \
48 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \
49 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
50 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
51 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
52 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
53 | .strip().lower().split('. ')
54 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
55 | .replace('\\', '').replace("'", '').strip().lower())
56 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
57 | report = ' . '.join(tokens) + ' .'
58 | return report
59 |
60 | class MIMICGenerateThenRefineDataset(Dataset):
61 | def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None, unlabeled_ann_path=None, retrieval_size=3):
62 | self.image_root = image_root
63 | self.ann_path = ann_path
64 | self.retrieval_size = retrieval_size
65 |
66 | self.vis_processor = vis_processor
67 | self.text_processor = text_processor
68 |
69 | # load annotation file
70 | with open(ann_path, 'r') as f:
71 | self.annotations = json.load(f)
72 | self.train_data = self.annotations['train']
73 |
74 | # load unlabeled data
75 | self.unlabeled_data_list = []
76 | with open(unlabeled_ann_path, 'r') as f:
77 | for line in f.readlines():
78 | self.unlabeled_data_list.append(line.strip('\n'))
79 |
80 | import random
81 | self.unlabeled_data_list = random.sample(self.unlabeled_data_list, 3000)
82 |
83 | print(f"There are {len(self.unlabeled_data_list)} unlabeled reports in total.")
84 |
85 | def __len__(self):
86 | return len(self.train_data)
87 |
88 | def __getitem__(self, index):
89 | data = self.train_data[index]
90 | data_samples = random.sample(self.train_data, self.retrieval_size - 1)
91 | image_path = data['image_path']
92 |
93 | # load image
94 | image_id = data['id']
95 | image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB')
96 | image = self.vis_processor(image)
97 |
98 | # load caption
99 | caption = data['report']
100 | caption = self.clean_reports(caption)
101 |
102 | # load reference caption
103 | all_ref_captions = []
104 | ref_caption = data['ref_report']
105 | ref_caption = self.clean_reports(ref_caption)
106 | all_ref_captions.append(ref_caption)
107 |
108 | for data_sample in data_samples:
109 | ref_caption = data_sample['ref_report']
110 | ref_caption = self.clean_reports(ref_caption)
111 | all_ref_captions.append(ref_caption)
112 |
113 | # load unlabeled caption
114 | unlabeled_caption = random.sample(self.unlabeled_data_list, self.retrieval_size)
115 |
116 | return {"image": image,
117 | "text_input": caption,
118 | "ref_caption": ref_caption,
119 | "unlabeled_caption": unlabeled_caption,
120 | "image_id": image_id}
121 |
122 | def clean_reports(self, report):
123 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
124 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \
125 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \
126 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
127 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
128 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
129 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
130 | .strip().lower().split('. ')
131 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
132 | .replace('\\', '').replace("'", '').strip().lower())
133 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
134 | report = ' . '.join(tokens) + ' .'
135 | return report
136 |
137 |
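
A minimal sketch for the generate-then-refine variant, assuming an annotation JSON whose 'train' entries also carry a 'ref_report' field, plus a plain-text file with one unlabeled report per line (at least 3000 of them, since the constructor subsamples that many); all paths are placeholders:

from minigpt4.datasets.datasets.mimic_dataset import MIMICGenerateThenRefineDataset
from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor, BlipCaptionProcessor

dataset = MIMICGenerateThenRefineDataset(
    vis_processor=Blip2ImageTrainProcessor(image_size=224),
    text_processor=BlipCaptionProcessor(),
    image_root="data/mimic_cxr/images",                 # placeholder path
    ann_path="data/mimic_cxr/annotation.json",          # placeholder path
    unlabeled_ann_path="data/mimic_cxr/unlabeled.txt",  # placeholder path
    retrieval_size=3,
)
sample = dataset[0]
# keys: "image", "text_input", "ref_caption", "unlabeled_caption", "image_id"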
--------------------------------------------------------------------------------
/minigpt4/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import torch
10 | from omegaconf import OmegaConf
11 |
12 | from minigpt4.common.registry import registry
13 | from minigpt4.models.base_model import BaseModel
14 | from minigpt4.models.blip2 import Blip2Base
15 | from minigpt4.models.mini_gpt4 import MiniGPT4
16 | from minigpt4.processors.base_processor import BaseProcessor
17 |
18 |
19 | __all__ = [
20 | "load_model",
21 | "BaseModel",
22 | "Blip2Base",
23 | "MiniGPT4",
24 | ]
25 |
26 |
27 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
28 | """
29 | Load supported models.
30 |
31 | To list all available models and types in registry:
32 | >>> from minigpt4.models import model_zoo
33 | >>> print(model_zoo)
34 |
35 | Args:
36 | name (str): name of the model.
37 | model_type (str): type of the model.
38 | is_eval (bool): whether the model is in eval mode. Default: False.
39 | device (str): device to use. Default: "cpu".
40 | checkpoint (str): path or URL to the checkpoint. Default: None.
41 | Note that the checkpoint is expected to have the same keys in its state_dict as the model.
42 |
43 | Returns:
44 | model (torch.nn.Module): model.
45 | """
46 |
47 | model = registry.get_model_class(name).from_pretrained(model_type=model_type)
48 |
49 | if checkpoint is not None:
50 | model.load_checkpoint(checkpoint)
51 |
52 | if is_eval:
53 | model.eval()
54 |
55 | if device == "cpu":
56 | model = model.float()
57 |
58 | return model.to(device)
59 |
60 |
61 | def load_preprocess(config):
62 | """
63 | Load preprocessor configs and construct preprocessors.
64 |
65 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
66 |
67 | Args:
68 | config (dict): preprocessor configs.
69 |
70 | Returns:
71 | vis_processors (dict): preprocessors for visual inputs.
72 | txt_processors (dict): preprocessors for text inputs.
73 |
74 | Key is "train" or "eval" for processors used in training and evaluation respectively.
75 | """
76 |
77 | def _build_proc_from_cfg(cfg):
78 | return (
79 | registry.get_processor_class(cfg.name).from_config(cfg)
80 | if cfg is not None
81 | else BaseProcessor()
82 | )
83 |
84 | vis_processors = dict()
85 | txt_processors = dict()
86 |
87 | vis_proc_cfg = config.get("vis_processor")
88 | txt_proc_cfg = config.get("text_processor")
89 |
90 | if vis_proc_cfg is not None:
91 | vis_train_cfg = vis_proc_cfg.get("train")
92 | vis_eval_cfg = vis_proc_cfg.get("eval")
93 | else:
94 | vis_train_cfg = None
95 | vis_eval_cfg = None
96 |
97 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
98 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
99 |
100 | if txt_proc_cfg is not None:
101 | txt_train_cfg = txt_proc_cfg.get("train")
102 | txt_eval_cfg = txt_proc_cfg.get("eval")
103 | else:
104 | txt_train_cfg = None
105 | txt_eval_cfg = None
106 |
107 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
108 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
109 |
110 | return vis_processors, txt_processors
111 |
112 |
113 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
114 | """
115 | Load model and its related preprocessors.
116 |
117 | List all available models and types in registry:
118 | >>> from minigpt4.models import model_zoo
119 | >>> print(model_zoo)
120 |
121 | Args:
122 | name (str): name of the model.
123 | model_type (str): type of the model.
124 | is_eval (bool): whether the model is in eval mode. Default: False.
125 | device (str): device to use. Default: "cpu".
126 |
127 | Returns:
128 | model (torch.nn.Module): model.
129 | vis_processors (dict): preprocessors for visual inputs.
130 | txt_processors (dict): preprocessors for text inputs.
131 | """
132 | model_cls = registry.get_model_class(name)
133 |
134 | # load model
135 | model = model_cls.from_pretrained(model_type=model_type)
136 |
137 | if is_eval:
138 | model.eval()
139 |
140 | # load preprocess
141 | cfg = OmegaConf.load(model_cls.default_config_path(model_type))
142 | if cfg is not None:
143 | preprocess_cfg = cfg.preprocess
144 |
145 | vis_processors, txt_processors = load_preprocess(preprocess_cfg)
146 | else:
147 | vis_processors, txt_processors = None, None
148 | logging.info(
149 | f"""No default preprocess for model {name} ({model_type}).
150 | This can happen if the model is not finetuned on downstream datasets,
151 | or it is not intended for direct use without finetuning.
152 | """
153 | )
154 |
155 | if device == "cpu" or device == torch.device("cpu"):
156 | model = model.float()
157 |
158 | return model.to(device), vis_processors, txt_processors
159 |
160 |
161 | class ModelZoo:
162 | """
163 | A utility class to create string representation of available model architectures and types.
164 |
165 | >>> from minigpt4.models import model_zoo
166 | >>> # list all available models
167 | >>> print(model_zoo)
168 | >>> # show total number of models
169 | >>> print(len(model_zoo))
170 | """
171 |
172 | def __init__(self) -> None:
173 | self.model_zoo = {
174 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
175 | for k, v in registry.mapping["model_name_mapping"].items()
176 | }
177 |
178 | def __str__(self) -> str:
179 | return (
180 | "=" * 50
181 | + "\n"
182 | + f"{'Architectures':<30} {'Types'}\n"
183 | + "=" * 50
184 | + "\n"
185 | + "\n".join(
186 | [
187 | f"{name:<30} {', '.join(types)}"
188 | for name, types in self.model_zoo.items()
189 | ]
190 | )
191 | )
192 |
193 | def __iter__(self):
194 | return iter(self.model_zoo.items())
195 |
196 | def __len__(self):
197 | return sum([len(v) for v in self.model_zoo.values()])
198 |
199 |
200 | model_zoo = ModelZoo()
201 |
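
A short usage sketch of the helpers above; the "mini_gpt4" / "pretrain_vicuna" pair is taken from the training logs in this repository and assumed to be registered in the model's PRETRAINED_MODEL_CONFIG_DICT:

from minigpt4.models import model_zoo, load_model_and_preprocess

print(model_zoo)  # table of registered architectures and their model types

model, vis_processors, txt_processors = load_model_and_preprocess(
    name="mini_gpt4",
    model_type="pretrain_vicuna",
    is_eval=True,
    device="cpu",  # use "cuda" when a GPU is available
)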
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/Qformer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/Qformer.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/Qformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/Qformer.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/base_model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/base_model.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/base_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/base_model.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/blip2.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/blip2.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/blip2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/blip2.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/eva_vit.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/eva_vit.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/eva_vit.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/eva_vit.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/mini_gpt4.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/mini_gpt4.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/mini_gpt4.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/mini_gpt4.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/modeling_llama.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/modeling_llama.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/models/__pycache__/modeling_llama.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/modeling_llama.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/models/base_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import os
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
15 | from minigpt4.common.utils import get_abs_path, is_url
16 | from omegaconf import OmegaConf
17 |
18 |
19 | class BaseModel(nn.Module):
20 | """Base class for models."""
21 |
22 | def __init__(self):
23 | super().__init__()
24 |
25 | @property
26 | def device(self):
27 | return list(self.parameters())[0].device
28 |
29 | def load_checkpoint(self, url_or_filename):
30 | """
31 | Load from a finetuned checkpoint.
32 |
33 | This should expect no mismatch in the model keys and the checkpoint keys.
34 | """
35 |
36 | if is_url(url_or_filename):
37 | cached_file = download_cached_file(
38 | url_or_filename, check_hash=False, progress=True
39 | )
40 | checkpoint = torch.load(cached_file, map_location="cpu")
41 | elif os.path.isfile(url_or_filename):
42 | checkpoint = torch.load(url_or_filename, map_location="cpu")
43 | else:
44 | raise RuntimeError("checkpoint url or path is invalid")
45 |
46 | if "model" in checkpoint.keys():
47 | state_dict = checkpoint["model"]
48 | else:
49 | state_dict = checkpoint
50 |
51 | msg = self.load_state_dict(state_dict, strict=False)
52 |
53 | logging.info("Missing keys {}".format(msg.missing_keys))
54 | logging.info("load checkpoint from %s" % url_or_filename)
55 |
56 | return msg
57 |
58 | @classmethod
59 | def from_pretrained(cls, model_type):
60 | """
61 | Build a pretrained model from default configuration file, specified by model_type.
62 |
63 | Args:
64 | - model_type (str): model type, specifying architecture and checkpoints.
65 |
66 | Returns:
67 | - model (nn.Module): pretrained or finetuned model, depending on the configuration.
68 | """
69 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
70 | model = cls.from_config(model_cfg)
71 |
72 | return model
73 |
74 | @classmethod
75 | def default_config_path(cls, model_type):
76 | assert (
77 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
78 | ), "Unknown model type {}".format(model_type)
79 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
80 |
81 | def load_checkpoint_from_config(self, cfg, **kwargs):
82 | """
83 | Load checkpoint as specified in the config file.
84 |
85 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
86 | When loading the pretrained model, each task-specific architecture may define its
87 | own load_from_pretrained() method.
88 | """
89 | load_finetuned = cfg.get("load_finetuned", True)
90 | if load_finetuned:
91 | finetune_path = cfg.get("finetuned", None)
92 | assert (
93 | finetune_path is not None
94 | ), "Found load_finetuned is True, but finetune_path is None."
95 | self.load_checkpoint(url_or_filename=finetune_path)
96 | else:
97 | # load pre-trained weights
98 | pretrain_path = cfg.get("pretrained", None)
99 | assert "Found load_finetuned is False, but pretrain_path is None."
100 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
101 |
102 | def before_evaluation(self, **kwargs):
103 | pass
104 |
105 | def show_n_params(self, return_str=True):
106 | tot = 0
107 | for p in self.parameters():
108 | w = 1
109 | for x in p.shape:
110 | w *= x
111 | tot += w
112 | if return_str:
113 | if tot >= 1e6:
114 | return "{:.1f}M".format(tot / 1e6)
115 | else:
116 | return "{:.1f}K".format(tot / 1e3)
117 | else:
118 | return tot
119 |
120 |
121 | class BaseEncoder(nn.Module):
122 | """
123 | Base class for primitive encoders, such as ViT, TimeSformer, etc.
124 | """
125 |
126 | def __init__(self):
127 | super().__init__()
128 |
129 | def forward_features(self, samples, **kwargs):
130 | raise NotImplementedError
131 |
132 | @property
133 | def device(self):
134 | return list(self.parameters())[0].device
135 |
136 |
137 | class SharedQueueMixin:
138 | @torch.no_grad()
139 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
140 | # gather keys before updating queue
141 | image_feats = concat_all_gather(image_feat)
142 | text_feats = concat_all_gather(text_feat)
143 |
144 | batch_size = image_feats.shape[0]
145 |
146 | ptr = int(self.queue_ptr)
147 | assert self.queue_size % batch_size == 0 # for simplicity
148 |
149 | # replace the keys at ptr (dequeue and enqueue)
150 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
151 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
152 |
153 | if idxs is not None:
154 | idxs = concat_all_gather(idxs)
155 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
156 |
157 | ptr = (ptr + batch_size) % self.queue_size # move pointer
158 | self.queue_ptr[0] = ptr
159 |
160 |
161 | class MomentumDistilationMixin:
162 | @torch.no_grad()
163 | def copy_params(self):
164 | for model_pair in self.model_pairs:
165 | for param, param_m in zip(
166 | model_pair[0].parameters(), model_pair[1].parameters()
167 | ):
168 | param_m.data.copy_(param.data) # initialize
169 | param_m.requires_grad = False # not update by gradient
170 |
171 | @torch.no_grad()
172 | def _momentum_update(self):
173 | for model_pair in self.model_pairs:
174 | for param, param_m in zip(
175 | model_pair[0].parameters(), model_pair[1].parameters()
176 | ):
177 | param_m.data = param_m.data * self.momentum + param.data * (
178 | 1.0 - self.momentum
179 | )
180 |
181 |
182 | class GatherLayer(torch.autograd.Function):
183 | """
184 | Gather tensors from all workers with support for backward propagation:
185 | This implementation does not cut the gradients as torch.distributed.all_gather does.
186 | """
187 |
188 | @staticmethod
189 | def forward(ctx, x):
190 | output = [
191 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
192 | ]
193 | torch.distributed.all_gather(output, x)
194 | return tuple(output)
195 |
196 | @staticmethod
197 | def backward(ctx, *grads):
198 | all_gradients = torch.stack(grads)
199 | torch.distributed.all_reduce(all_gradients)
200 | return all_gradients[torch.distributed.get_rank()]
201 |
202 |
203 | def all_gather_with_grad(tensors):
204 | """
205 | Performs all_gather operation on the provided tensors.
206 | Graph remains connected for backward grad computation.
207 | """
208 | # Queue the gathered tensors
209 | world_size = torch.distributed.get_world_size()
210 | # There is no need for reduction in the single-proc case
211 | if world_size == 1:
212 | return tensors
213 |
214 | # tensor_all = GatherLayer.apply(tensors)
215 | tensor_all = GatherLayer.apply(tensors)
216 |
217 | return torch.cat(tensor_all, dim=0)
218 |
219 |
220 | @torch.no_grad()
221 | def concat_all_gather(tensor):
222 | """
223 | Performs all_gather operation on the provided tensors.
224 | *** Warning ***: torch.distributed.all_gather has no gradient.
225 | """
226 | # no gathering is needed when not running distributed training
227 | if not is_dist_avail_and_initialized():
228 | return tensor
229 |
230 | tensors_gather = [
231 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
232 | ]
233 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
234 |
235 | output = torch.cat(tensors_gather, dim=0)
236 | return output
237 |
238 |
239 | def tile(x, dim, n_tile):
240 | init_dim = x.size(dim)
241 | repeat_idx = [1] * x.dim()
242 | repeat_idx[dim] = n_tile
243 | x = x.repeat(*(repeat_idx))
244 | order_index = torch.LongTensor(
245 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
246 | )
247 | return torch.index_select(x, dim, order_index.to(x.device))
248 |
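
A worked example of tile(), which repeats each slice along the given dimension n_tile times consecutively (the same effect as torch.repeat_interleave along that dimension):

import torch
from minigpt4.models.base_model import tile

x = torch.tensor([[1, 2], [3, 4]])
print(tile(x, dim=0, n_tile=3))
# tensor([[1, 2],
#         [1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4],
#         [3, 4]])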
--------------------------------------------------------------------------------
/minigpt4/models/blip2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2023, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import contextlib
8 | import logging
9 | import os
10 | import time
11 | import datetime
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.distributed as dist
16 | import torch.nn.functional as F
17 |
18 | import minigpt4.common.dist_utils as dist_utils
19 | from minigpt4.common.dist_utils import download_cached_file
20 | from minigpt4.common.utils import is_url
21 | from minigpt4.common.logger import MetricLogger
22 | from minigpt4.models.base_model import BaseModel
23 | from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
24 | from minigpt4.models.eva_vit import create_eva_vit_g
25 | from transformers import BertTokenizer
26 |
27 |
28 | class Blip2Base(BaseModel):
29 | @classmethod
30 | def init_tokenizer(cls):
31 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
32 | tokenizer.add_special_tokens({"bos_token": "[DEC]"})
33 | return tokenizer
34 |
35 | def maybe_autocast(self, dtype=torch.float16):
36 | # if on cpu, don't use autocast
37 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
38 | enable_autocast = self.device != torch.device("cpu")
39 |
40 | if enable_autocast:
41 | return torch.cuda.amp.autocast(dtype=dtype)
42 | else:
43 | return contextlib.nullcontext()
44 |
45 | @classmethod
46 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
47 | encoder_config = BertConfig.from_pretrained("bert-base-uncased")
48 | encoder_config.encoder_width = vision_width
49 | # insert cross-attention layer every other block
50 | encoder_config.add_cross_attention = True
51 | encoder_config.cross_attention_freq = cross_attention_freq
52 | encoder_config.query_length = num_query_token
53 | Qformer = BertLMHeadModel(config=encoder_config)
54 | query_tokens = nn.Parameter(
55 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
56 | )
57 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
58 | return Qformer, query_tokens
59 |
60 | @classmethod
61 | def init_vision_encoder(
62 | cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
63 | ):
64 | assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
65 | visual_encoder = create_eva_vit_g(
66 | img_size, drop_path_rate, use_grad_checkpoint, precision
67 | )
68 |
69 | ln_vision = LayerNorm(visual_encoder.num_features)
70 | return visual_encoder, ln_vision
71 |
72 | def load_from_pretrained(self, url_or_filename):
73 | if is_url(url_or_filename):
74 | cached_file = download_cached_file(
75 | url_or_filename, check_hash=False, progress=True
76 | )
77 | checkpoint = torch.load(cached_file, map_location="cpu")
78 | elif os.path.isfile(url_or_filename):
79 | checkpoint = torch.load(url_or_filename, map_location="cpu")
80 | else:
81 | raise RuntimeError("checkpoint url or path is invalid")
82 |
83 | state_dict = checkpoint["model"]
84 |
85 | msg = self.load_state_dict(state_dict, strict=False)
86 |
87 | # logging.info("Missing keys {}".format(msg.missing_keys))
88 | logging.info("load checkpoint from %s" % url_or_filename)
89 |
90 | return msg
91 |
92 |
93 | def disabled_train(self, mode=True):
94 | """Overwrite model.train with this function to make sure train/eval mode
95 | does not change anymore."""
96 | return self
97 |
98 |
99 | class LayerNorm(nn.LayerNorm):
100 | """Subclass torch's LayerNorm to handle fp16."""
101 |
102 | def forward(self, x: torch.Tensor):
103 | orig_type = x.dtype
104 | ret = super().forward(x.type(torch.float32))
105 | return ret.type(orig_type)
106 |
107 |
108 | def compute_sim_matrix(model, data_loader, **kwargs):
109 | k_test = kwargs.pop("k_test")
110 |
111 | metric_logger = MetricLogger(delimiter=" ")
112 | header = "Evaluation:"
113 |
114 | logging.info("Computing features for evaluation...")
115 | start_time = time.time()
116 |
117 | texts = data_loader.dataset.text
118 | num_text = len(texts)
119 | text_bs = 256
120 | text_ids = []
121 | text_embeds = []
122 | text_atts = []
123 | for i in range(0, num_text, text_bs):
124 | text = texts[i : min(num_text, i + text_bs)]
125 | text_input = model.tokenizer(
126 | text,
127 | padding="max_length",
128 | truncation=True,
129 | max_length=35,
130 | return_tensors="pt",
131 | ).to(model.device)
132 | text_feat = model.forward_text(text_input)
133 | text_embed = F.normalize(model.text_proj(text_feat))
134 | text_embeds.append(text_embed)
135 | text_ids.append(text_input.input_ids)
136 | text_atts.append(text_input.attention_mask)
137 |
138 | text_embeds = torch.cat(text_embeds, dim=0)
139 | text_ids = torch.cat(text_ids, dim=0)
140 | text_atts = torch.cat(text_atts, dim=0)
141 |
142 | vit_feats = []
143 | image_embeds = []
144 | for samples in data_loader:
145 | image = samples["image"]
146 |
147 | image = image.to(model.device)
148 | image_feat, vit_feat = model.forward_image(image)
149 | image_embed = model.vision_proj(image_feat)
150 | image_embed = F.normalize(image_embed, dim=-1)
151 |
152 | vit_feats.append(vit_feat.cpu())
153 | image_embeds.append(image_embed)
154 |
155 | vit_feats = torch.cat(vit_feats, dim=0)
156 | image_embeds = torch.cat(image_embeds, dim=0)
157 |
158 | sims_matrix = []
159 | for image_embed in image_embeds:
160 | sim_q2t = image_embed @ text_embeds.t()
161 | sim_i2t, _ = sim_q2t.max(0)
162 | sims_matrix.append(sim_i2t)
163 | sims_matrix = torch.stack(sims_matrix, dim=0)
164 |
165 | score_matrix_i2t = torch.full(
166 | (len(data_loader.dataset.image), len(texts)), -100.0
167 | ).to(model.device)
168 |
169 | num_tasks = dist_utils.get_world_size()
170 | rank = dist_utils.get_rank()
171 | step = sims_matrix.size(0) // num_tasks + 1
172 | start = rank * step
173 | end = min(sims_matrix.size(0), start + step)
174 |
175 | for i, sims in enumerate(
176 | metric_logger.log_every(sims_matrix[start:end], 50, header)
177 | ):
178 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
179 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
180 | score = model.compute_itm(
181 | image_inputs=image_inputs,
182 | text_ids=text_ids[topk_idx],
183 | text_atts=text_atts[topk_idx],
184 | ).float()
185 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim
186 |
187 | sims_matrix = sims_matrix.t()
188 | score_matrix_t2i = torch.full(
189 | (len(texts), len(data_loader.dataset.image)), -100.0
190 | ).to(model.device)
191 |
192 | step = sims_matrix.size(0) // num_tasks + 1
193 | start = rank * step
194 | end = min(sims_matrix.size(0), start + step)
195 |
196 | for i, sims in enumerate(
197 | metric_logger.log_every(sims_matrix[start:end], 50, header)
198 | ):
199 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
200 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
201 | score = model.compute_itm(
202 | image_inputs=image_inputs,
203 | text_ids=text_ids[start + i].repeat(k_test, 1),
204 | text_atts=text_atts[start + i].repeat(k_test, 1),
205 | ).float()
206 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim
207 |
208 | if dist_utils.is_dist_avail_and_initialized():
209 | dist.barrier()
210 | torch.distributed.all_reduce(
211 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
212 | )
213 | torch.distributed.all_reduce(
214 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
215 | )
216 |
217 | total_time = time.time() - start_time
218 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
219 | logging.info("Evaluation time {}".format(total_time_str))
220 |
221 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
222 |
--------------------------------------------------------------------------------
/minigpt4/models/blip2_outputs.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from dataclasses import dataclass
9 | from typing import Optional
10 |
11 | import torch
12 | from transformers.modeling_outputs import (
13 | ModelOutput,
14 | BaseModelOutputWithPoolingAndCrossAttentions,
15 | CausalLMOutputWithCrossAttentions,
16 | )
17 |
18 |
19 | @dataclass
20 | class BlipSimilarity(ModelOutput):
21 | sim_i2t: torch.FloatTensor = None
22 | sim_t2i: torch.FloatTensor = None
23 |
24 | sim_i2t_m: Optional[torch.FloatTensor] = None
25 | sim_t2i_m: Optional[torch.FloatTensor] = None
26 |
27 | sim_i2t_targets: Optional[torch.FloatTensor] = None
28 | sim_t2i_targets: Optional[torch.FloatTensor] = None
29 |
30 |
31 | @dataclass
32 | class BlipIntermediateOutput(ModelOutput):
33 | """
34 | Data class for intermediate outputs of BLIP models.
35 |
36 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38 |
39 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41 |
42 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44 |
45 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46 | decoder_labels (torch.LongTensor): labels for the captioning loss.
47 |
48 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50 |
51 | """
52 |
53 | # uni-modal features
54 | image_embeds: torch.FloatTensor = None
55 | text_embeds: Optional[torch.FloatTensor] = None
56 |
57 | image_embeds_m: Optional[torch.FloatTensor] = None
58 | text_embeds_m: Optional[torch.FloatTensor] = None
59 |
60 | # intermediate outputs of multimodal encoder
61 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63 |
64 | itm_logits: Optional[torch.FloatTensor] = None
65 | itm_labels: Optional[torch.LongTensor] = None
66 |
67 | # intermediate outputs of multimodal decoder
68 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69 | decoder_labels: Optional[torch.LongTensor] = None
70 |
71 |
72 | @dataclass
73 | class BlipOutput(ModelOutput):
74 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75 | sims: Optional[BlipSimilarity] = None
76 |
77 | intermediate_output: BlipIntermediateOutput = None
78 |
79 | loss: Optional[torch.FloatTensor] = None
80 |
81 | loss_itc: Optional[torch.FloatTensor] = None
82 |
83 | loss_itm: Optional[torch.FloatTensor] = None
84 |
85 | loss_lm: Optional[torch.FloatTensor] = None
86 |
87 |
88 | @dataclass
89 | class BlipOutputFeatures(ModelOutput):
90 | """
91 | Data class of features from BlipFeatureExtractor.
92 |
93 | Args:
94 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
95 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
96 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
97 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
98 |
99 | The first embedding or feature is for the [CLS] token.
100 |
101 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
102 | """
103 |
104 | image_embeds: Optional[torch.FloatTensor] = None
105 | image_embeds_proj: Optional[torch.FloatTensor] = None
106 |
107 | text_embeds: Optional[torch.FloatTensor] = None
108 | text_embeds_proj: Optional[torch.FloatTensor] = None
109 |
110 | multimodal_embeds: Optional[torch.FloatTensor] = None
111 |
--------------------------------------------------------------------------------
/minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_0.pth
--------------------------------------------------------------------------------
/minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_1.pth
--------------------------------------------------------------------------------
/minigpt4/output/minigpt4_stage2_finetune/20230706044/log.txt:
--------------------------------------------------------------------------------
1 | {
2 | "run": {
3 | "task": "image_text_pretrain",
4 | "lr_sched": "linear_warmup_cosine_lr",
5 | "init_lr": 3e-05,
6 | "min_lr": 1e-05,
7 | "warmup_lr": 1e-06,
8 | "weight_decay": 0.05,
9 | "max_epoch": 5,
10 | "iters_per_epoch": 200,
11 | "batch_size_train": 1,
12 | "batch_size_eval": 1,
13 | "num_workers": 4,
14 | "warmup_steps": 200,
15 | "seed": 42,
16 | "output_dir": "output/minigpt4_stage2_finetune",
17 | "amp": true,
18 | "resume_ckpt_path": null,
19 | "evaluate": false,
20 | "train_splits": [
21 | "train"
22 | ],
23 | "device": "cuda",
24 | "world_size": 1,
25 | "dist_url": "env://",
26 | "distributed": false
27 | },
28 | "model": {
29 | "arch": "mini_gpt4",
30 | "image_size": 224,
31 | "drop_path_rate": 0,
32 | "use_grad_checkpoint": false,
33 | "vit_precision": "fp16",
34 | "freeze_vit": true,
35 | "freeze_qformer": true,
36 | "num_query_token": 32,
37 | "llama_model": "models/models-13b/vicuna_weights",
38 | "prompt": "",
39 | "model_type": "pretrain_vicuna",
40 | "max_txt_len": 160,
41 | "end_sym": "###",
42 | "prompt_path": "prompts/alignment.txt",
43 | "prompt_template": "###Human: {} ###Assistant: ",
44 | "ckpt": "models/models-13b/minigpt-4/pretrained_minigpt4.pth"
45 | },
46 | "preprocess": {
47 | "vis_processor": {
48 | "train": {
49 | "name": "blip2_image_train",
50 | "image_size": 224
51 | },
52 | "eval": {
53 | "name": "blip2_image_eval",
54 | "image_size": 224
55 | }
56 | },
57 | "text_processor": {
58 | "train": {
59 | "name": "blip_caption"
60 | },
61 | "eval": {
62 | "name": "blip_caption"
63 | }
64 | }
65 | },
66 | "datasets": {
67 | "cc_sbu_align": {
68 | "data_type": "images",
69 | "build_info": {
70 | "storage": "/media/ubuntu/data/liuchang/workplace/code/src/MiniGPT-4/data/cc_sbu_align"
71 | },
72 | "vis_processor": {
73 | "train": {
74 | "name": "blip2_image_train",
75 | "image_size": 224
76 | }
77 | },
78 | "text_processor": {
79 | "train": {
80 | "name": "blip_caption"
81 | }
82 | }
83 | }
84 | }
85 | }
86 | {"train_lr": "0.000", "train_loss": "0.675"}
87 | {"train_lr": "0.000", "train_loss": "0.656"}
88 |
--------------------------------------------------------------------------------
/minigpt4/output/minigpt4_stage2_finetune/20230706051/log.txt:
--------------------------------------------------------------------------------
1 | {
2 | "run": {
3 | "task": "image_text_pretrain",
4 | "lr_sched": "linear_warmup_cosine_lr",
5 | "init_lr": 3e-05,
6 | "min_lr": 1e-05,
7 | "warmup_lr": 1e-06,
8 | "weight_decay": 0.05,
9 | "max_epoch": 5,
10 | "iters_per_epoch": 200,
11 | "batch_size_train": 1,
12 | "batch_size_eval": 1,
13 | "num_workers": 4,
14 | "warmup_steps": 200,
15 | "seed": 42,
16 | "output_dir": "output/minigpt4_stage2_finetune",
17 | "amp": true,
18 | "resume_ckpt_path": null,
19 | "evaluate": false,
20 | "train_splits": [
21 | "train"
22 | ],
23 | "device": "cuda",
24 | "world_size": 1,
25 | "dist_url": "env://",
26 | "distributed": false
27 | },
28 | "model": {
29 | "arch": "mini_gpt4",
30 | "image_size": 224,
31 | "drop_path_rate": 0,
32 | "use_grad_checkpoint": false,
33 | "vit_precision": "fp16",
34 | "freeze_vit": true,
35 | "freeze_qformer": true,
36 | "num_query_token": 32,
37 | "llama_model": "models/models-13b/vicuna_weights",
38 | "prompt": "",
39 | "model_type": "pretrain_vicuna",
40 | "max_txt_len": 160,
41 | "end_sym": "###",
42 | "prompt_path": "prompts/alignment.txt",
43 | "prompt_template": "###Human: {} ###Assistant: ",
44 | "ckpt": "models/models-13b/minigpt-4/pretrained_minigpt4.pth"
45 | },
46 | "preprocess": {
47 | "vis_processor": {
48 | "train": {
49 | "name": "blip2_image_train",
50 | "image_size": 224
51 | },
52 | "eval": {
53 | "name": "blip2_image_eval",
54 | "image_size": 224
55 | }
56 | },
57 | "text_processor": {
58 | "train": {
59 | "name": "blip_caption"
60 | },
61 | "eval": {
62 | "name": "blip_caption"
63 | }
64 | }
65 | },
66 | "datasets": {
67 | "cc_sbu_align": {
68 | "data_type": "images",
69 | "build_info": {
70 | "storage": "/media/ubuntu/data/liuchang/workplace/code/src/MiniGPT-4/data/cc_sbu_align"
71 | },
72 | "vis_processor": {
73 | "train": {
74 | "name": "blip2_image_train",
75 | "image_size": 224
76 | }
77 | },
78 | "text_processor": {
79 | "train": {
80 | "name": "blip_caption"
81 | }
82 | }
83 | }
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/minigpt4/processors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.processors.base_processor import BaseProcessor
9 | from minigpt4.processors.blip_processors import (
10 | Blip2ImageTrainProcessor,
11 | Blip2ImageEvalProcessor,
12 | BlipCaptionProcessor,
13 | )
14 |
15 | from minigpt4.common.registry import registry
16 |
17 | __all__ = [
18 | "BaseProcessor",
19 | "Blip2ImageTrainProcessor",
20 | "Blip2ImageEvalProcessor",
21 | "BlipCaptionProcessor",
22 | ]
23 |
24 |
25 | def load_processor(name, cfg=None):
26 | """
27 | Example
28 |
29 | >>> processor = load_processor("alpro_video_train", cfg=None)
30 | """
31 | processor = registry.get_processor_class(name).from_config(cfg)
32 |
33 | return processor
34 |
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/base_processor.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/base_processor.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/base_processor.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/base_processor.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/blip_processors.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/blip_processors.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/blip_processors.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/blip_processors.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/randaugment.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/randaugment.cpython-311.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/__pycache__/randaugment.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/randaugment.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from omegaconf import OmegaConf
9 |
10 |
11 | class BaseProcessor:
12 | def __init__(self):
13 | self.transform = lambda x: x
14 | return
15 |
16 | def __call__(self, item):
17 | return self.transform(item)
18 |
19 | @classmethod
20 | def from_config(cls, cfg=None):
21 | return cls()
22 |
23 | def build(self, **kwargs):
24 | cfg = OmegaConf.create(kwargs)
25 |
26 | return self.from_config(cfg)
27 |
--------------------------------------------------------------------------------
/minigpt4/processors/blip_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import re
9 |
10 | from minigpt4.common.registry import registry
11 | from minigpt4.processors.base_processor import BaseProcessor
12 | from minigpt4.processors.randaugment import RandomAugment
13 | from omegaconf import OmegaConf
14 | from torchvision import transforms
15 | from torchvision.transforms.functional import InterpolationMode
16 |
17 |
18 | class BlipImageBaseProcessor(BaseProcessor):
19 | def __init__(self, mean=None, std=None):
20 | if mean is None:
21 | mean = (0.48145466, 0.4578275, 0.40821073)
22 | if std is None:
23 | std = (0.26862954, 0.26130258, 0.27577711)
24 |
25 | self.normalize = transforms.Normalize(mean, std)
26 |
27 |
28 | @registry.register_processor("blip_caption")
29 | class BlipCaptionProcessor(BaseProcessor):
30 | def __init__(self, prompt="", max_words=50):
31 | self.prompt = prompt
32 | self.max_words = max_words
33 |
34 | def __call__(self, caption):
35 | caption = self.prompt + self.pre_caption(caption)
36 |
37 | return caption
38 |
39 | @classmethod
40 |     def from_config(cls, cfg=None):
41 | if cfg is None:
42 | cfg = OmegaConf.create()
43 |
44 | prompt = cfg.get("prompt", "")
45 | max_words = cfg.get("max_words", 50)
46 |
47 | return cls(prompt=prompt, max_words=max_words)
48 |
49 | def pre_caption(self, caption):
50 | caption = re.sub(
51 | r"([.!\"()*#:;~])",
52 | " ",
53 | caption.lower(),
54 | )
55 | caption = re.sub(
56 | r"\s{2,}",
57 | " ",
58 | caption,
59 | )
60 | caption = caption.rstrip("\n")
61 | caption = caption.strip(" ")
62 |
63 | # truncate caption
64 | caption_words = caption.split(" ")
65 | if len(caption_words) > self.max_words:
66 | caption = " ".join(caption_words[: self.max_words])
67 |
68 | return caption
69 |
70 |
71 | @registry.register_processor("blip2_image_train")
72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
74 | super().__init__(mean=mean, std=std)
75 |
76 | self.transform = transforms.Compose(
77 | [
78 | transforms.RandomResizedCrop(
79 | image_size,
80 | scale=(min_scale, max_scale),
81 | interpolation=InterpolationMode.BICUBIC,
82 | ),
83 | transforms.ToTensor(),
84 | self.normalize,
85 | ]
86 | )
87 |
88 | def __call__(self, item):
89 | return self.transform(item)
90 |
91 | @classmethod
92 | def from_config(cls, cfg=None):
93 | if cfg is None:
94 | cfg = OmegaConf.create()
95 |
96 | image_size = cfg.get("image_size", 224)
97 |
98 | mean = cfg.get("mean", None)
99 | std = cfg.get("std", None)
100 |
101 | min_scale = cfg.get("min_scale", 0.5)
102 | max_scale = cfg.get("max_scale", 1.0)
103 |
104 | return cls(
105 | image_size=image_size,
106 | mean=mean,
107 | std=std,
108 | min_scale=min_scale,
109 | max_scale=max_scale,
110 | )
111 |
112 |
113 | @registry.register_processor("blip2_image_eval")
114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
115 | def __init__(self, image_size=224, mean=None, std=None):
116 | super().__init__(mean=mean, std=std)
117 |
118 | self.transform = transforms.Compose(
119 | [
120 | transforms.Resize(
121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
122 | ),
123 | transforms.ToTensor(),
124 | self.normalize,
125 | ]
126 | )
127 |
128 | def __call__(self, item):
129 | return self.transform(item)
130 |
131 | @classmethod
132 | def from_config(cls, cfg=None):
133 | if cfg is None:
134 | cfg = OmegaConf.create()
135 |
136 | image_size = cfg.get("image_size", 224)
137 |
138 | mean = cfg.get("mean", None)
139 | std = cfg.get("std", None)
140 |
141 | return cls(image_size=image_size, mean=mean, std=std)
--------------------------------------------------------------------------------
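As a usage sketch (not a repo file), the processors registered above are normally resolved through the registry when a dataset builder reads its vis_processor/text_processor config blocks. The snippet below builds blip2_image_train directly; it assumes registry.get_processor_class from minigpt4/common/registry.py and that importing minigpt4.processors runs the register_processor decorators.

    # Illustrative only: construct the train image processor from a config node.
    from omegaconf import OmegaConf
    from PIL import Image

    from minigpt4.common.registry import registry
    import minigpt4.processors  # noqa: F401  (runs the @registry.register_processor decorators)

    cfg = OmegaConf.create({"name": "blip2_image_train", "image_size": 224})
    processor = registry.get_processor_class(cfg.name).from_config(cfg)

    img = Image.new("RGB", (512, 512))
    tensor = processor(img)      # RandomResizedCrop -> ToTensor -> Normalize
    print(tensor.shape)          # torch.Size([3, 224, 224])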
/minigpt4/processors/randaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import cv2
9 | import numpy as np
10 |
11 | import torch
12 |
13 |
14 | ## aug functions
15 | def identity_func(img):
16 | return img
17 |
18 |
19 | def autocontrast_func(img, cutoff=0):
20 | """
21 | same output as PIL.ImageOps.autocontrast
22 | """
23 | n_bins = 256
24 |
25 | def tune_channel(ch):
26 | n = ch.size
27 | cut = cutoff * n // 100
28 | if cut == 0:
29 | high, low = ch.max(), ch.min()
30 | else:
31 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
32 | low = np.argwhere(np.cumsum(hist) > cut)
33 | low = 0 if low.shape[0] == 0 else low[0]
34 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
35 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
36 | if high <= low:
37 | table = np.arange(n_bins)
38 | else:
39 | scale = (n_bins - 1) / (high - low)
40 | offset = -low * scale
41 | table = np.arange(n_bins) * scale + offset
42 | table[table < 0] = 0
43 | table[table > n_bins - 1] = n_bins - 1
44 | table = table.clip(0, 255).astype(np.uint8)
45 | return table[ch]
46 |
47 | channels = [tune_channel(ch) for ch in cv2.split(img)]
48 | out = cv2.merge(channels)
49 | return out
50 |
51 |
52 | def equalize_func(img):
53 | """
54 | same output as PIL.ImageOps.equalize
55 | PIL's implementation is different from cv2.equalize
56 | """
57 | n_bins = 256
58 |
59 | def tune_channel(ch):
60 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
61 | non_zero_hist = hist[hist != 0].reshape(-1)
62 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
63 | if step == 0:
64 | return ch
65 | n = np.empty_like(hist)
66 | n[0] = step // 2
67 | n[1:] = hist[:-1]
68 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
69 | return table[ch]
70 |
71 | channels = [tune_channel(ch) for ch in cv2.split(img)]
72 | out = cv2.merge(channels)
73 | return out
74 |
75 |
76 | def rotate_func(img, degree, fill=(0, 0, 0)):
77 | """
78 | like PIL, rotate by degree, not radians
79 | """
80 | H, W = img.shape[0], img.shape[1]
81 | center = W / 2, H / 2
82 | M = cv2.getRotationMatrix2D(center, degree, 1)
83 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
84 | return out
85 |
86 |
87 | def solarize_func(img, thresh=128):
88 | """
89 |     same output as PIL.ImageOps.solarize
90 | """
91 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
92 | table = table.clip(0, 255).astype(np.uint8)
93 | out = table[img]
94 | return out
95 |
96 |
97 | def color_func(img, factor):
98 | """
99 | same output as PIL.ImageEnhance.Color
100 | """
101 | ## implementation according to PIL definition, quite slow
102 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
103 | # out = blend(degenerate, img, factor)
104 | # M = (
105 | # np.eye(3) * factor
106 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
107 | # )[np.newaxis, np.newaxis, :]
108 | M = np.float32(
109 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
110 | ) * factor + np.float32([[0.114], [0.587], [0.299]])
111 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
112 | return out
113 |
114 |
115 | def contrast_func(img, factor):
116 | """
117 | same output as PIL.ImageEnhance.Contrast
118 | """
119 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
120 | table = (
121 | np.array([(el - mean) * factor + mean for el in range(256)])
122 | .clip(0, 255)
123 | .astype(np.uint8)
124 | )
125 | out = table[img]
126 | return out
127 |
128 |
129 | def brightness_func(img, factor):
130 | """
131 |     same output as PIL.ImageEnhance.Brightness
132 | """
133 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
134 | out = table[img]
135 | return out
136 |
137 |
138 | def sharpness_func(img, factor):
139 | """
140 |     The differences between this result and PIL's are all on the 4 boundaries; the center
141 |     areas are the same
142 | """
143 | kernel = np.ones((3, 3), dtype=np.float32)
144 | kernel[1][1] = 5
145 | kernel /= 13
146 | degenerate = cv2.filter2D(img, -1, kernel)
147 | if factor == 0.0:
148 | out = degenerate
149 | elif factor == 1.0:
150 | out = img
151 | else:
152 | out = img.astype(np.float32)
153 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
154 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
155 | out = out.astype(np.uint8)
156 | return out
157 |
158 |
159 | def shear_x_func(img, factor, fill=(0, 0, 0)):
160 | H, W = img.shape[0], img.shape[1]
161 | M = np.float32([[1, factor, 0], [0, 1, 0]])
162 | out = cv2.warpAffine(
163 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
164 | ).astype(np.uint8)
165 | return out
166 |
167 |
168 | def translate_x_func(img, offset, fill=(0, 0, 0)):
169 | """
170 | same output as PIL.Image.transform
171 | """
172 | H, W = img.shape[0], img.shape[1]
173 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
174 | out = cv2.warpAffine(
175 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
176 | ).astype(np.uint8)
177 | return out
178 |
179 |
180 | def translate_y_func(img, offset, fill=(0, 0, 0)):
181 | """
182 | same output as PIL.Image.transform
183 | """
184 | H, W = img.shape[0], img.shape[1]
185 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
186 | out = cv2.warpAffine(
187 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
188 | ).astype(np.uint8)
189 | return out
190 |
191 |
192 | def posterize_func(img, bits):
193 | """
194 | same output as PIL.ImageOps.posterize
195 | """
196 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
197 | return out
198 |
199 |
200 | def shear_y_func(img, factor, fill=(0, 0, 0)):
201 | H, W = img.shape[0], img.shape[1]
202 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
203 | out = cv2.warpAffine(
204 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
205 | ).astype(np.uint8)
206 | return out
207 |
208 |
209 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
210 | replace = np.array(replace, dtype=np.uint8)
211 | H, W = img.shape[0], img.shape[1]
212 | rh, rw = np.random.random(2)
213 | pad_size = pad_size // 2
214 | ch, cw = int(rh * H), int(rw * W)
215 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
216 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
217 | out = img.copy()
218 | out[x1:x2, y1:y2, :] = replace
219 | return out
220 |
221 |
222 | ### level to args
223 | def enhance_level_to_args(MAX_LEVEL):
224 | def level_to_args(level):
225 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
226 |
227 | return level_to_args
228 |
229 |
230 | def shear_level_to_args(MAX_LEVEL, replace_value):
231 | def level_to_args(level):
232 | level = (level / MAX_LEVEL) * 0.3
233 | if np.random.random() > 0.5:
234 | level = -level
235 | return (level, replace_value)
236 |
237 | return level_to_args
238 |
239 |
240 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
241 | def level_to_args(level):
242 | level = (level / MAX_LEVEL) * float(translate_const)
243 | if np.random.random() > 0.5:
244 | level = -level
245 | return (level, replace_value)
246 |
247 | return level_to_args
248 |
249 |
250 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
251 | def level_to_args(level):
252 | level = int((level / MAX_LEVEL) * cutout_const)
253 | return (level, replace_value)
254 |
255 | return level_to_args
256 |
257 |
258 | def solarize_level_to_args(MAX_LEVEL):
259 | def level_to_args(level):
260 | level = int((level / MAX_LEVEL) * 256)
261 | return (level,)
262 |
263 | return level_to_args
264 |
265 |
266 | def none_level_to_args(level):
267 | return ()
268 |
269 |
270 | def posterize_level_to_args(MAX_LEVEL):
271 | def level_to_args(level):
272 | level = int((level / MAX_LEVEL) * 4)
273 | return (level,)
274 |
275 | return level_to_args
276 |
277 |
278 | def rotate_level_to_args(MAX_LEVEL, replace_value):
279 | def level_to_args(level):
280 | level = (level / MAX_LEVEL) * 30
281 | if np.random.random() < 0.5:
282 | level = -level
283 | return (level, replace_value)
284 |
285 | return level_to_args
286 |
287 |
288 | func_dict = {
289 | "Identity": identity_func,
290 | "AutoContrast": autocontrast_func,
291 | "Equalize": equalize_func,
292 | "Rotate": rotate_func,
293 | "Solarize": solarize_func,
294 | "Color": color_func,
295 | "Contrast": contrast_func,
296 | "Brightness": brightness_func,
297 | "Sharpness": sharpness_func,
298 | "ShearX": shear_x_func,
299 | "TranslateX": translate_x_func,
300 | "TranslateY": translate_y_func,
301 | "Posterize": posterize_func,
302 | "ShearY": shear_y_func,
303 | }
304 |
305 | translate_const = 10
306 | MAX_LEVEL = 10
307 | replace_value = (128, 128, 128)
308 | arg_dict = {
309 | "Identity": none_level_to_args,
310 | "AutoContrast": none_level_to_args,
311 | "Equalize": none_level_to_args,
312 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
313 | "Solarize": solarize_level_to_args(MAX_LEVEL),
314 | "Color": enhance_level_to_args(MAX_LEVEL),
315 | "Contrast": enhance_level_to_args(MAX_LEVEL),
316 | "Brightness": enhance_level_to_args(MAX_LEVEL),
317 | "Sharpness": enhance_level_to_args(MAX_LEVEL),
318 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
319 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
320 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
321 | "Posterize": posterize_level_to_args(MAX_LEVEL),
322 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
323 | }
324 |
325 |
326 | class RandomAugment(object):
327 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
328 | self.N = N
329 | self.M = M
330 | self.isPIL = isPIL
331 | if augs:
332 | self.augs = augs
333 | else:
334 | self.augs = list(arg_dict.keys())
335 |
336 | def get_random_ops(self):
337 | sampled_ops = np.random.choice(self.augs, self.N)
338 | return [(op, 0.5, self.M) for op in sampled_ops]
339 |
340 | def __call__(self, img):
341 | if self.isPIL:
342 | img = np.array(img)
343 | ops = self.get_random_ops()
344 | for name, prob, level in ops:
345 | if np.random.random() > prob:
346 | continue
347 | args = arg_dict[name](level)
348 | img = func_dict[name](img, *args)
349 | return img
350 |
351 |
352 | class VideoRandomAugment(object):
353 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
354 | self.N = N
355 | self.M = M
356 | self.p = p
357 | self.tensor_in_tensor_out = tensor_in_tensor_out
358 | if augs:
359 | self.augs = augs
360 | else:
361 | self.augs = list(arg_dict.keys())
362 |
363 | def get_random_ops(self):
364 | sampled_ops = np.random.choice(self.augs, self.N, replace=False)
365 | return [(op, self.M) for op in sampled_ops]
366 |
367 | def __call__(self, frames):
368 | assert (
369 | frames.shape[-1] == 3
370 | ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
371 |
372 | if self.tensor_in_tensor_out:
373 | frames = frames.numpy().astype(np.uint8)
374 |
375 | num_frames = frames.shape[0]
376 |
377 | ops = num_frames * [self.get_random_ops()]
378 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
379 |
380 | frames = torch.stack(
381 | list(map(self._aug, frames, ops, apply_or_not)), dim=0
382 | ).float()
383 |
384 | return frames
385 |
386 | def _aug(self, img, ops, apply_or_not):
387 | for i, (name, level) in enumerate(ops):
388 | if not apply_or_not[i]:
389 | continue
390 | args = arg_dict[name](level)
391 | img = func_dict[name](img, *args)
392 | return torch.from_numpy(img)
393 |
394 |
395 | if __name__ == "__main__":
396 | a = RandomAugment()
397 | img = np.random.randn(32, 32, 3)
398 | a(img)
399 |
--------------------------------------------------------------------------------
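A minimal usage sketch (not a repo file): RandomAugment works on uint8 NumPy arrays through the cv2-based functions above, so isPIL=True is needed when it sits before ToTensor in a torchvision pipeline. The op names and magnitudes below are example choices, not values used by this repo.

    # Illustrative pipeline around RandomAugment.
    import numpy as np
    from PIL import Image
    from torchvision import transforms

    from minigpt4.processors.randaugment import RandomAugment

    augment = RandomAugment(
        N=2, M=5, isPIL=True,
        augs=["Identity", "AutoContrast", "Equalize", "Brightness", "Sharpness"],
    )
    pipeline = transforms.Compose([augment, transforms.ToTensor()])

    img = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
    out = pipeline(img)          # each sampled op fires with probability 0.5 at level M
    print(out.shape)             # torch.Size([3, 224, 224])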
/minigpt4/runners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.runners.runner_base import RunnerBase
9 |
10 | __all__ = ["RunnerBase"]
11 |
--------------------------------------------------------------------------------
/minigpt4/runners/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/runners/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/runners/__pycache__/runner_base.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/runners/__pycache__/runner_base.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.common.registry import registry
9 | from minigpt4.tasks.base_task import BaseTask
10 | from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask
11 | from minigpt4.tasks.mimic_generate_then_refine import MIMICGenerateThenRefine
12 |
13 |
14 | def setup_task(cfg):
15 | assert "task" in cfg.run_cfg, "Task name must be provided."
16 |
17 | task_name = cfg.run_cfg.task
18 | task = registry.get_task_class(task_name).setup_task(cfg=cfg)
19 | assert task is not None, "Task {} not properly registered.".format(task_name)
20 |
21 | return task
22 |
23 |
24 | __all__ = [
25 | "BaseTask",
26 | "ImageTextPretrainTask",
27 | "MIMICGenerateThenRefine",
28 | ]
29 |
--------------------------------------------------------------------------------
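A minimal sketch (not a repo file) of how setup_task resolves the task name; a bare OmegaConf node stands in here for the full common.config.Config object, since setup_task only reads cfg.run_cfg.task.

    # Illustrative only: resolve a task class through the registry.
    from omegaconf import OmegaConf

    import minigpt4.tasks as tasks

    cfg = OmegaConf.create({"run_cfg": {"task": "mimic_generate_then_refine"}})
    task = tasks.setup_task(cfg)     # registry.get_task_class(name).setup_task(cfg=cfg)
    print(type(task).__name__)       # MIMICGenerateThenRefine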
/minigpt4/tasks/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/tasks/__pycache__/base_task.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/base_task.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/tasks/__pycache__/mimic_generate_then_refine.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/mimic_generate_then_refine.cpython-39.pyc
--------------------------------------------------------------------------------
/minigpt4/tasks/base_task.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import os
10 |
11 | import deepspeed
12 |
13 | import torch
14 | import torch.distributed as dist
15 | from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
16 | from minigpt4.common.logger import MetricLogger, SmoothedValue
17 | from minigpt4.common.registry import registry
18 | from minigpt4.datasets.data_utils import prepare_sample
19 |
20 |
21 | class BaseTask:
22 | def __init__(self, **kwargs):
23 | super().__init__()
24 |
25 | self.inst_id_key = "instance_id"
26 |
27 | @classmethod
28 | def setup_task(cls, **kwargs):
29 | return cls()
30 |
31 | def build_model(self, cfg):
32 | model_config = cfg.model_cfg
33 |
34 | model_cls = registry.get_model_class(model_config.arch)
35 | return model_cls.from_config(model_config)
36 |
37 | def build_datasets(self, cfg):
38 | """
39 | Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
40 | Download dataset and annotations automatically if not exist.
41 |
42 | Args:
43 |             cfg (common.config.Config): global configuration whose datasets_cfg section lists the datasets to build.
44 |
45 | Returns:
46 | dict: Dictionary of torch.utils.data.Dataset objects by split.
47 | """
48 |
49 | datasets = dict()
50 |
51 | datasets_config = cfg.datasets_cfg
52 |
53 | assert len(datasets_config) > 0, "At least one dataset has to be specified."
54 |
55 | for name in datasets_config:
56 | dataset_config = datasets_config[name]
57 |
58 | builder = registry.get_builder_class(name)(dataset_config)
59 | dataset = builder.build_datasets()
60 |
61 | dataset['train'].name = name
62 | if 'sample_ratio' in dataset_config:
63 | dataset['train'].sample_ratio = dataset_config.sample_ratio
64 |
65 | datasets[name] = dataset
66 |
67 | return datasets
68 |
69 | def train_step(self, model, samples):
70 | loss = model(samples)["loss"]
71 | return loss
72 |
73 | def valid_step(self, model, samples):
74 | raise NotImplementedError
75 |
76 | def before_evaluation(self, model, dataset, **kwargs):
77 | model.before_evaluation(dataset=dataset, task_type=type(self))
78 |
79 | def after_evaluation(self, **kwargs):
80 | pass
81 |
82 | def inference_step(self):
83 | raise NotImplementedError
84 |
85 | def evaluation(self, model, data_loader, cuda_enabled=True):
86 | metric_logger = MetricLogger(delimiter=" ")
87 | header = "Evaluation"
88 | # TODO make it configurable
89 | print_freq = 10
90 |
91 | results = []
92 |
93 | for samples in metric_logger.log_every(data_loader, print_freq, header):
94 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
95 |
96 | eval_output = self.valid_step(model=model, samples=samples)
97 | results.extend(eval_output)
98 |
99 | if is_dist_avail_and_initialized():
100 | dist.barrier()
101 |
102 | return results
103 |
104 | def train_epoch(
105 | self,
106 | epoch,
107 | model,
108 | data_loader,
109 | optimizer,
110 | lr_scheduler,
111 | scaler=None,
112 | cuda_enabled=False,
113 | log_freq=50,
114 | accum_grad_iters=1,
115 | use_zero_optimizer=False,
116 | ):
117 | return self._train_inner_loop(
118 | epoch=epoch,
119 | iters_per_epoch=lr_scheduler.iters_per_epoch,
120 | model=model,
121 | data_loader=data_loader,
122 | optimizer=optimizer,
123 | scaler=scaler,
124 | lr_scheduler=lr_scheduler,
125 | log_freq=log_freq,
126 | cuda_enabled=cuda_enabled,
127 | accum_grad_iters=accum_grad_iters,
128 | use_zero_optimizer=use_zero_optimizer,
129 | )
130 |
131 | def train_iters(
132 | self,
133 | epoch,
134 | start_iters,
135 | iters_per_inner_epoch,
136 | model,
137 | data_loader,
138 | optimizer,
139 | lr_scheduler,
140 | scaler=None,
141 | cuda_enabled=False,
142 | log_freq=50,
143 | accum_grad_iters=1,
144 | ):
145 | return self._train_inner_loop(
146 | epoch=epoch,
147 | start_iters=start_iters,
148 | iters_per_epoch=iters_per_inner_epoch,
149 | model=model,
150 | data_loader=data_loader,
151 | optimizer=optimizer,
152 | scaler=scaler,
153 | lr_scheduler=lr_scheduler,
154 | log_freq=log_freq,
155 | cuda_enabled=cuda_enabled,
156 | accum_grad_iters=accum_grad_iters,
157 | )
158 |
159 | def _train_inner_loop(
160 | self,
161 | epoch,
162 | iters_per_epoch,
163 | model,
164 | data_loader,
165 | optimizer,
166 | lr_scheduler,
167 | scaler=None,
168 | start_iters=None,
169 | log_freq=50,
170 | cuda_enabled=False,
171 | accum_grad_iters=1,
172 | use_zero_optimizer=False,
173 | ):
174 | """
175 | An inner training loop compatible with both epoch-based and iter-based training.
176 |
177 | When using epoch-based, training stops after one epoch; when using iter-based,
178 | training stops after #iters_per_epoch iterations.
179 | """
180 | use_amp = scaler is not None
181 |
182 | if not hasattr(data_loader, "__next__"):
183 | # convert to iterator if not already
184 | data_loader = iter(data_loader)
185 |
186 | metric_logger = MetricLogger(delimiter=" ")
187 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
188 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
189 |
190 | # if iter-based runner, schedule lr based on inner epoch.
191 | logging.info(
192 | "Start training epoch {}, {} iters per inner epoch.".format(
193 | epoch, iters_per_epoch
194 | )
195 | )
196 | header = "Train: data epoch: [{}]".format(epoch)
197 | if start_iters is None:
198 | # epoch-based runner
199 | inner_epoch = epoch
200 | else:
201 | # In iter-based runner, we schedule the learning rate based on iterations.
202 | inner_epoch = start_iters // iters_per_epoch
203 | header = header + "; inner epoch [{}]".format(inner_epoch)
204 |
205 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
206 | # if using iter-based runner, we stop after iters_per_epoch iterations.
207 | if i >= iters_per_epoch:
208 | break
209 |
210 | samples = next(data_loader)
211 |
212 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
213 | samples.update(
214 | {
215 | "epoch": inner_epoch,
216 | "num_iters_per_epoch": iters_per_epoch,
217 | "iters": i,
218 | }
219 | )
220 |
221 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
222 |
223 | with torch.cuda.amp.autocast(enabled=use_amp):
224 | loss = self.train_step(model=model, samples=samples)
225 |
226 | # after_train_step()
227 | if use_amp:
228 | scaler.scale(loss).backward()
229 | else:
230 | loss.backward()
231 |
232 | # update gradients every accum_grad_iters iterations
233 | if (i + 1) % accum_grad_iters == 0:
234 | if use_amp:
235 | scaler.step(optimizer)
236 | scaler.update()
237 | else:
238 | optimizer.step()
239 | optimizer.zero_grad()
240 |
241 | metric_logger.update(loss=loss.item())
242 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
243 |
244 | # after train_epoch()
245 | # gather the stats from all processes
246 | metric_logger.synchronize_between_processes()
247 | logging.info("Averaged stats: " + str(metric_logger.global_avg()))
248 | return {
249 | k: "{:.3f}".format(meter.global_avg)
250 | for k, meter in metric_logger.meters.items()
251 | }
252 |
253 | @staticmethod
254 | def save_result(result, result_dir, filename, remove_duplicate=""):
255 | import json
256 |
257 | result_file = os.path.join(
258 | result_dir, "%s_rank%d.json" % (filename, get_rank())
259 | )
260 | final_result_file = os.path.join(result_dir, "%s.json" % filename)
261 |
262 | json.dump(result, open(result_file, "w"))
263 |
264 | if is_dist_avail_and_initialized():
265 | dist.barrier()
266 |
267 | if is_main_process():
268 | logging.warning("rank %d starts merging results." % get_rank())
269 | # combine results from all processes
270 | result = []
271 |
272 | for rank in range(get_world_size()):
273 | result_file = os.path.join(
274 | result_dir, "%s_rank%d.json" % (filename, rank)
275 | )
276 | res = json.load(open(result_file, "r"))
277 | result += res
278 |
279 | if remove_duplicate:
280 | result_new = []
281 | id_list = []
282 | for res in result:
283 | if res[remove_duplicate] not in id_list:
284 | id_list.append(res[remove_duplicate])
285 | result_new.append(res)
286 | result = result_new
287 |
288 | json.dump(result, open(final_result_file, "w"))
289 | print("result file saved to %s" % final_result_file)
290 |
291 | return final_result_file
292 |
--------------------------------------------------------------------------------
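The accumulation logic in _train_inner_loop calls backward() on every iteration but only steps and zeroes the optimizer every accum_grad_iters iterations, so the effective batch is the per-step batch times accum_grad_iters. A self-contained toy sketch of the same pattern (not repo code):

    # Toy illustration of the gradient-accumulation pattern used above.
    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_grad_iters = 4

    for i in range(8):
        x, y = torch.randn(2, 4), torch.randn(2, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()                          # gradients accumulate across iterations
        if (i + 1) % accum_grad_iters == 0:      # update weights every accum_grad_iters steps
            optimizer.step()
            optimizer.zero_grad()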
/minigpt4/tasks/image_text_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.common.registry import registry
9 | from minigpt4.tasks.base_task import BaseTask
10 |
11 |
12 | @registry.register_task("image_text_pretrain")
13 | class ImageTextPretrainTask(BaseTask):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def evaluation(self, model, data_loader, cuda_enabled=True):
18 | pass
19 |
--------------------------------------------------------------------------------
/minigpt4/tasks/mimic_generate_then_refine.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 | import logging
8 | import torch
9 |
10 | from minigpt4.common.registry import registry
11 | from minigpt4.tasks.base_task import BaseTask
12 | from minigpt4.common.logger import MetricLogger, SmoothedValue
13 | from minigpt4.datasets.data_utils import prepare_sample
14 |
15 |
16 | @registry.register_task("mimic_generate_then_refine")
17 | class MIMICGenerateThenRefine(BaseTask):
18 | def __init__(self):
19 | super().__init__()
20 |
21 | def train_step(self, model, samples):
22 | loss = model(samples)["loss"]
23 | return loss
24 |
25 | def _train_inner_loop(
26 | self,
27 | epoch,
28 | iters_per_epoch,
29 | model,
30 | data_loader,
31 | optimizer,
32 | lr_scheduler,
33 | scaler=None,
34 | start_iters=None,
35 | log_freq=50,
36 | cuda_enabled=False,
37 | accum_grad_iters=1,
38 | use_zero_optimizer=False,
39 | ):
40 | """
41 | An inner training loop compatible with both epoch-based and iter-based training.
42 |
43 | When using epoch-based, training stops after one epoch; when using iter-based,
44 | training stops after #iters_per_epoch iterations.
45 | """
46 | use_amp = scaler is not None
47 |
48 | if not hasattr(data_loader, "__next__"):
49 | # convert to iterator if not already
50 | data_loader = iter(data_loader)
51 |
52 | metric_logger = MetricLogger(delimiter=" ")
53 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
54 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
55 |
56 | # if iter-based runner, schedule lr based on inner epoch.
57 | logging.info(
58 | "Start training epoch {}, {} iters per inner epoch.".format(
59 | epoch, iters_per_epoch
60 | )
61 | )
62 | header = "Train: data epoch: [{}]".format(epoch)
63 | if start_iters is None:
64 | # epoch-based runner
65 | inner_epoch = epoch
66 | else:
67 | # In iter-based runner, we schedule the learning rate based on iterations.
68 | inner_epoch = start_iters // iters_per_epoch
69 | header = header + "; inner epoch [{}]".format(inner_epoch)
70 |
71 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
72 | # if using iter-based runner, we stop after iters_per_epoch iterations.
73 | if i >= iters_per_epoch:
74 | break
75 |
76 | samples = next(data_loader)
77 |
78 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
79 | samples.update(
80 | {
81 | "epoch": inner_epoch,
82 | "num_iters_per_epoch": iters_per_epoch,
83 | "iters": i,
84 | }
85 | )
86 |
87 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
88 |
89 | with torch.cuda.amp.autocast(enabled=use_amp):
90 | loss = self.train_step(model=model, samples=samples)
91 |
92 | # after_train_step()
93 | if use_zero_optimizer:
94 | model.backward(loss)
95 | else:
96 | if use_amp:
97 | scaler.scale(loss).backward()
98 | else:
99 | loss.backward()
100 |
101 |
102 |
103 | # update gradients every accum_grad_iters iterations
104 | if (i + 1) % accum_grad_iters == 0:
105 | if use_zero_optimizer:
106 | model.step()
107 | else:
108 | if use_amp:
109 | scaler.step(optimizer)
110 | scaler.update()
111 | else:
112 | optimizer.step()
113 | optimizer.zero_grad()
114 |
115 | metric_logger.update(loss=loss.item())
116 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
117 |
118 | # after train_epoch()
119 | # gather the stats from all processes
120 | metric_logger.synchronize_between_processes()
121 | logging.info("Averaged stats: " + str(metric_logger.global_avg()))
122 | return {
123 | k: "{:.3f}".format(meter.global_avg)
124 | for k, meter in metric_logger.meters.items()
125 | }
126 |
127 | def evaluation(self, model, data_loader, cuda_enabled=True):
128 | pass
129 |
--------------------------------------------------------------------------------
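When use_zero_optimizer is set, the model passed into this loop is expected to be a DeepSpeed engine, so model.backward(loss) and model.step() replace the scaler/optimizer calls of the plain branch. A hedged sketch of creating such an engine from the shipped ZeRO config (not repo code; a toy module stands in for the real model, and it assumes the process is started with a distributed launcher such as deepspeed or torchrun):

    # Illustrative only; the real engine would wrap the MiniGPT-4 model built by the task.
    import torch
    import deepspeed

    model = torch.nn.Linear(8, 2)                 # stand-in for task.build_model(cfg)
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config="train_configs/stage1/zero.json",  # ZeRO stage-1 config from this repo
    )
    # inside the loop: engine.backward(loss) then engine.step()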
/prompts/stage1-pretraining-prompts.txt:
--------------------------------------------------------------------------------
1 | Describe this image in detail.
2 | Take a look at this image and describe what you notice.
3 | Please provide a detailed description of the picture.
4 | Could you describe the contents of this image for me?
--------------------------------------------------------------------------------
/prompts/stage2-generation-prompts.txt:
--------------------------------------------------------------------------------
1 | You are an AI radiologist assistant. Your goal is to describe the syndromes reflected in the radiograph in detail. The description should be reasonable and should not be made up. Describe as informatively as possible.
--------------------------------------------------------------------------------
/prompts/stage2-refinement-prompts.txt:
--------------------------------------------------------------------------------
1 | ### Human: Rewrite the sentences in the report according to the image. Delete irrelevant descriptions in the report. Supply missing descriptions in the report. Write it as informatively as possible. Keep the writing style unchanged. ### Assistant
2 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import argparse
9 | import os
10 | import random
11 | import shutil
12 |
13 | import numpy as np
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 |
17 | import minigpt4.tasks as tasks
18 | from minigpt4.common.config import Config
19 | from minigpt4.common.dist_utils import get_rank, init_distributed_mode
20 | from minigpt4.common.logger import setup_logger
21 | from minigpt4.common.optims import (
22 | LinearWarmupCosineLRScheduler,
23 | LinearWarmupStepLRScheduler,
24 | )
25 | from minigpt4.common.registry import registry
26 | from minigpt4.common.utils import now
27 |
28 | # imports modules for registration
29 | from minigpt4.datasets.builders import *
30 | from minigpt4.models import *
31 | from minigpt4.processors import *
32 | from minigpt4.runners import *
33 | from minigpt4.tasks import *
34 |
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(description="Training")
38 |
39 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
40 | parser.add_argument(
41 | "--options",
42 | nargs="+",
43 |         help="override some settings in the used config; key-value pairs "
44 |         "in xxx=yyy format are merged into the config file (deprecated; "
45 |         "use --cfg-options instead).",
46 | )
47 |
48 | # TODO: deepspeed configurations
49 | parser.add_argument('--use_zero_optimizer', action='store_true', help='use ZeRO optimizer to save GPU memory')
50 | parser.add_argument('--local_rank', default=0, type=int, help='local rank')
51 | parser.add_argument('--deepspeed_config', type=str, default='train_configs/zero_configs/stage1.json', help='path to deepspeed configuration file')
52 | parser.add_argument('--train_batch_size', type=int, default=1, help='training batch size')
53 | parser.add_argument('--train_micro_batch_size_per_gpu', type=int, default=1, help='batch size per GPU')
54 |
55 | args = parser.parse_args()
56 | # if 'LOCAL_RANK' not in os.environ:
57 | # os.environ['LOCAL_RANK'] = str(args.local_rank)
58 |
59 | return args
60 |
61 |
62 | def setup_seeds(config):
63 | seed = config.run_cfg.seed + get_rank()
64 |
65 | random.seed(seed)
66 | np.random.seed(seed)
67 | torch.manual_seed(seed)
68 |
69 | cudnn.benchmark = False
70 | cudnn.deterministic = True
71 |
72 |
73 | def get_runner_class(cfg):
74 | """
75 | Get runner class from config. Default to epoch-based runner.
76 | """
77 | runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
78 |
79 | return runner_cls
80 |
81 |
82 | def main():
83 |     # allow auto-dl to complete on the main process without timeout when using the NCCL backend.
84 | # os.environ["NCCL_BLOCKING_WAIT"] = "1"
85 |
86 | # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
87 | job_id = now()
88 |
89 | cfg = Config(parse_args())
90 |
91 | init_distributed_mode(cfg.run_cfg)
92 |
93 | setup_seeds(cfg)
94 |
95 | # set after init_distributed_mode() to only log on master.
96 | setup_logger()
97 |
98 | cfg.pretty_print()
99 |
100 | task = tasks.setup_task(cfg)
101 | datasets = task.build_datasets(cfg)
102 | model = task.build_model(cfg)
103 |
104 | # TODO: define arguments, required by deepspeed
105 | args = parse_args()
106 | args.train_batch_size = cfg.run_cfg.batch_size_train
107 | args.train_micro_batch_size_per_gpu = args.train_batch_size // cfg.run_cfg.world_size
108 |
109 |
110 | runner = get_runner_class(cfg)(
111 | cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets, cmd_args=args,
112 | )
113 | runner.train()
114 |
115 |
116 | if __name__ == "__main__":
117 | main()
118 |
--------------------------------------------------------------------------------
/train_configs/stage1/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 | model_type: pretrain_vicuna
4 | freeze_vit: True
5 | freeze_qformer: True
6 | freeze_llama: False
7 | max_txt_len: 160
8 | end_sym: "###"
9 | generation_prompt_path: "prompts/stage1-pretraining-prompts.txt"
10 | refinement_prompt_path: "prompts/stage1-pretraining-prompts.txt"
11 | prompt_template: '###Human: {} ###Assistant: '
12 | ckpt: '/path/to/linear/layer'
13 | is_pretraining: True
14 |
15 | use_contrastive_loss: False
16 | use_refinement_loss: False
17 | triplet_margin: 0.5
18 | triplet_weight: 1.0
19 | refinement_loss_weight: 1.0
20 |
21 |   # lora configuration
22 | use_lora: True # use lora for vicuna
23 | use_lora_vit_qformer: False # use lora for vision backbone
24 | lora_rank: 8
25 | lora_alpha: 32
26 | lora_dropout: 0.1
27 |
28 | # ZeRO optimizer configuration
29 | use_zero_optimizer: True
30 | deepspeed_config: "train_configs/stage1/zero.json"
31 |
32 | datasets:
33 | mimic_generate_then_refine:
34 | vis_processor:
35 | train:
36 | name: "blip2_image_train"
37 | image_size: 224
38 | text_processor:
39 | train:
40 | name: "blip_caption"
41 |
42 | run:
43 | task: mimic_generate_then_refine
44 | # optimizer
45 | lr_sched: "linear_warmup_cosine_lr"
46 | init_lr: 3e-5
47 | min_lr: 1e-5
48 | warmup_lr: 1e-6
49 |
50 | weight_decay: 0.05
51 | max_epoch: 10
52 | iters_per_epoch: 5000 # 200
53 | batch_size_train: 1 # total batch size, not per GPU
54 | batch_size_eval: 1
55 | num_workers: 4
56 | warmup_steps: 200
57 |
58 | seed: 42
59 | output_dir: "/path/to/output/dir"
60 |
61 | amp: True
62 | resume_ckpt_path: null
63 |
64 | evaluate: False
65 | train_splits: ["train"]
66 |
67 | device: "cuda"
68 | world_size: 2
69 | dist_url: "env://"
70 | distributed: True
71 |
72 | # ZeRO optimizer configuration
73 | use_zero_optimizer: True
74 | deepspeed_config: "train_configs/stage1/zero.json"
--------------------------------------------------------------------------------
/train_configs/stage1/zero.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "reduce_bucket_size": 5e8
5 | },
6 | "train_batch_size": 24
7 | }
--------------------------------------------------------------------------------
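For reference, DeepSpeed validates that train_batch_size equals train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size. A worked example with illustrative numbers (the per-GPU and accumulation values are assumptions, not taken from this repo):

    # Illustrative arithmetic only.
    train_micro_batch_size_per_gpu = 12
    gradient_accumulation_steps = 1
    world_size = 2                   # matches run.world_size in train_configs/stage1/config.yaml

    train_batch_size = (
        train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
    )
    print(train_batch_size)          # 24, the value set in zero.json above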
/train_configs/stage2/iuxray/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 | model_type: pretrain_vicuna
4 | freeze_vit: True
5 | freeze_qformer: True
6 | max_txt_len: 100
7 | end_sym: "###"
8 |   generation_prompt_path: "prompts/stage2-generation-prompts.txt"
9 |   refinement_prompt_path: "prompts/stage2-refinement-prompts.txt"
10 | prompt_template: '###Human: {} ###Assistant: '
11 | ckpt: '/path/to/linear'
12 |
13 | use_contrastive_loss: True
14 | use_refinement_loss: True
15 | triplet_margin: 0.5
16 | triplet_weight: 1.0
17 | refinement_loss_weight: 1.0
18 |
19 |   # lora configuration
20 | use_lora: True
21 | lora_rank: 32
22 | lora_alpha: 32
23 | lora_dropout: 0.1
24 |
25 | # ZeRO optimizer configuration
26 | use_zero_optimizer: True
27 | deepspeed_config: "train_configs/stage2/zero.json"
28 |
29 | datasets:
30 | mimic_generate_then_refine:
31 | vis_processor:
32 | train:
33 | name: "blip2_image_train"
34 | image_size: 224
35 | text_processor:
36 | train:
37 | name: "blip_caption"
38 |
39 | run:
40 | task: mimic_generate_then_refine
41 | # optimizer
42 | lr_sched: "linear_warmup_cosine_lr"
43 | init_lr: 3e-5
44 | min_lr: 1e-5
45 | warmup_lr: 1e-6
46 |
47 | weight_decay: 0.05
48 | max_epoch: 10
49 | iters_per_epoch: 1000 # 200
50 | batch_size_train: 1 # total batch size, not per GPU
51 | batch_size_eval: 1
52 | num_workers: 4
53 | warmup_steps: 200
54 |
55 | seed: 42
56 | output_dir: "/path/to/output"
57 |
58 | amp: True
59 | resume_ckpt_path: null
60 |
61 | evaluate: False
62 | train_splits: ["train"]
63 |
64 | device: "cuda"
65 | world_size: 2
66 | dist_url: "env://"
67 | distributed: True
68 |
69 | # ZeRO optimizer configuration
70 | use_zero_optimizer: True
71 | deepspeed_config: "train_configs/stage2/zero.json"
--------------------------------------------------------------------------------
/train_configs/stage2/iuxray/zero.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "reduce_bucket_size": 5e8
5 | },
6 | "train_batch_size": 12,
7 | "scheduler": {
8 | "type": "WarmupLR",
9 | "params": {
10 | "warmup_min_lr": 1e-6,
11 | "warmup_max_lr": 1e-5,
12 | "warmup_num_steps": 200
13 | }
14 | }
15 | }
--------------------------------------------------------------------------------
/train_configs/stage2/mimic/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: mini_gpt4
3 | model_type: pretrain_vicuna
4 | freeze_vit: True
5 | freeze_qformer: True
6 | max_txt_len: 100
7 | end_sym: "###"
8 | generation_prompt_path: "/path/to/generation/prompts"
9 | refinement_prompt_path: "/path/to/refinement/prompts"
10 | prompt_template: '###Human: {} ###Assistant: '
11 | ckpt: '/path/to/linear'
12 |
13 | use_contrastive_loss: True
14 | use_refinement_loss: True
15 | triplet_margin: 0.5
16 | triplet_weight: 1.0
17 | refinement_loss_weight: 1.0
18 |
19 |   # lora configuration
20 | use_lora: True
21 | lora_rank: 32
22 | lora_alpha: 32
23 | lora_dropout: 0.1
24 |
25 | # ZeRO optimizer configuration
26 | use_zero_optimizer: True
27 | deepspeed_config: "train_configs/stage2/zero.json"
28 |
29 | datasets:
30 | mimic_generate_then_refine:
31 | vis_processor:
32 | train:
33 | name: "blip2_image_train"
34 | image_size: 224
35 | text_processor:
36 | train:
37 | name: "blip_caption"
38 |
39 | run:
40 | task: mimic_generate_then_refine
41 | # optimizer
42 | lr_sched: "linear_warmup_cosine_lr"
43 | init_lr: 3e-5
44 | min_lr: 1e-5
45 | warmup_lr: 1e-6
46 |
47 | weight_decay: 0.05
48 | max_epoch: 10
49 | iters_per_epoch: 1000 # 200
50 | batch_size_train: 1 # total batch size, not per GPU
51 | batch_size_eval: 1
52 | num_workers: 4
53 | warmup_steps: 200
54 |
55 | seed: 42
56 | output_dir: "/path/to/output"
57 |
58 | amp: True
59 | resume_ckpt_path: null
60 |
61 | evaluate: False
62 | train_splits: ["train"]
63 |
64 | device: "cuda"
65 | world_size: 2
66 | dist_url: "env://"
67 | distributed: True
68 |
69 | # ZeRO optimizer configuration
70 | use_zero_optimizer: True
71 | deepspeed_config: "train_configs/stage2/zero.json"
--------------------------------------------------------------------------------
/train_configs/stage2/mimic/zero.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "reduce_bucket_size": 5e8
5 | },
6 | "train_batch_size": 12,
7 | "scheduler": {
8 | "type": "WarmupLR",
9 | "params": {
10 | "warmup_min_lr": 1e-6,
11 | "warmup_max_lr": 1e-5,
12 | "warmup_num_steps": 200
13 | }
14 | }
15 | }
--------------------------------------------------------------------------------