├── LICENSE ├── README.md ├── environment.yaml ├── eval_configs └── eval.yaml ├── generate_reports.py ├── minigpt4 ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc ├── common │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── config.cpython-39.pyc │ │ ├── dist_utils.cpython-311.pyc │ │ ├── dist_utils.cpython-39.pyc │ │ ├── logger.cpython-311.pyc │ │ ├── logger.cpython-39.pyc │ │ ├── optims.cpython-39.pyc │ │ ├── registry.cpython-311.pyc │ │ ├── registry.cpython-38.pyc │ │ ├── registry.cpython-39.pyc │ │ ├── utils.cpython-311.pyc │ │ ├── utils.cpython-38.pyc │ │ └── utils.cpython-39.pyc │ ├── config.py │ ├── dist_utils.py │ ├── gradcam.py │ ├── logger.py │ ├── optims.py │ ├── registry.py │ └── utils.py ├── configs │ ├── datasets │ │ ├── cc_sbu │ │ │ ├── align.yaml │ │ │ └── defaults.yaml │ │ ├── iuxray │ │ │ ├── align.yaml │ │ │ └── generate_then_refine.yaml │ │ ├── laion │ │ │ └── defaults.yaml │ │ └── mimic │ │ │ ├── align.yaml │ │ │ └── generate_then_refine.yaml │ ├── default.yaml │ └── models │ │ ├── minigpt4-7b.yaml │ │ └── minigpt4.yaml ├── conversation │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ └── conversation.cpython-39.pyc │ └── conversation.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ └── data_utils.cpython-39.pyc │ ├── builders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-311.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── base_dataset_builder.cpython-311.pyc │ │ │ ├── base_dataset_builder.cpython-38.pyc │ │ │ ├── base_dataset_builder.cpython-39.pyc │ │ │ ├── image_text_pair_builder.cpython-311.pyc │ │ │ └── image_text_pair_builder.cpython-39.pyc │ │ ├── 
base_dataset_builder.py │ │ └── image_text_pair_builder.py │ ├── data_utils.py │ └── datasets │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── base_dataset.cpython-311.pyc │ │ ├── base_dataset.cpython-39.pyc │ │ ├── caption_datasets.cpython-311.pyc │ │ ├── caption_datasets.cpython-39.pyc │ │ ├── cc_sbu_dataset.cpython-311.pyc │ │ ├── cc_sbu_dataset.cpython-39.pyc │ │ ├── dataloader_utils.cpython-39.pyc │ │ ├── iuxray_dataset.cpython-311.pyc │ │ ├── iuxray_dataset.cpython-39.pyc │ │ ├── laion_dataset.cpython-311.pyc │ │ ├── laion_dataset.cpython-39.pyc │ │ ├── mimic_dataset.cpython-311.pyc │ │ └── mimic_dataset.cpython-39.pyc │ │ ├── base_dataset.py │ │ ├── caption_datasets.py │ │ ├── cc_sbu_dataset.py │ │ ├── dataloader_utils.py │ │ ├── iuxray_dataset.py │ │ ├── laion_dataset.py │ │ └── mimic_dataset.py ├── models │ ├── Qformer.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── Qformer.cpython-311.pyc │ │ ├── Qformer.cpython-39.pyc │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── base_model.cpython-311.pyc │ │ ├── base_model.cpython-39.pyc │ │ ├── blip2.cpython-311.pyc │ │ ├── blip2.cpython-39.pyc │ │ ├── eva_vit.cpython-311.pyc │ │ ├── eva_vit.cpython-39.pyc │ │ ├── mini_gpt4.cpython-311.pyc │ │ ├── mini_gpt4.cpython-39.pyc │ │ ├── modeling_llama.cpython-311.pyc │ │ └── modeling_llama.cpython-39.pyc │ ├── base_model.py │ ├── blip2.py │ ├── blip2_outputs.py │ ├── eva_vit.py │ ├── mini_gpt4.py │ └── modeling_llama.py ├── output │ └── minigpt4_stage2_finetune │ │ ├── 20230706044 │ │ ├── checkpoint_0.pth │ │ ├── checkpoint_1.pth │ │ └── log.txt │ │ └── 20230706051 │ │ └── log.txt ├── processors │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── base_processor.cpython-311.pyc │ │ ├── base_processor.cpython-39.pyc │ │ ├── blip_processors.cpython-311.pyc │ │ ├── blip_processors.cpython-39.pyc │ │ ├── randaugment.cpython-311.pyc │ │ 
└── randaugment.cpython-39.pyc │ ├── base_processor.py │ ├── blip_processors.py │ └── randaugment.py ├── runners │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ └── runner_base.cpython-39.pyc │ └── runner_base.py └── tasks │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── base_task.cpython-39.pyc │ ├── image_text_pretrain.cpython-39.pyc │ └── mimic_generate_then_refine.cpython-39.pyc │ ├── base_task.py │ ├── image_text_pretrain.py │ └── mimic_generate_then_refine.py ├── prompts ├── stage1-pretraining-prompts.txt ├── stage2-generation-prompts.txt └── stage2-refinement-prompts.txt ├── train.py └── train_configs ├── stage1 ├── config.yaml └── zero.json └── stage2 ├── iuxray ├── config.yaml └── zero.json └── mimic ├── config.yaml └── zero.json /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yan Song's NLP Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

Bootstrapping Large Language Models for Radiology Report Generation

3 | 4 | The official GitHub repository of the AAAI-2024 paper ["Bootstrapping Large Language Models for Radiology Report Generation"](https://ojs.aaai.org/index.php/AAAI/article/view/29826). 5 | 6 | # Reference 7 | If our work is helpful to your research, please cite our paper: 8 | ``` latex 9 | @inproceedings{chang2024bootstrapping, 10 | author = {Chang Liu and 11 | Yuanhe Tian and 12 | Weidong Chen and 13 | Yan Song and 14 | Yongdong Zhang}, 15 | editor = {Michael J. Wooldridge and 16 | Jennifer G. Dy and 17 | Sriraam Natarajan}, 18 | title = {Bootstrapping Large Language Models for Radiology Report Generation}, 19 | booktitle = {AAAI}, 20 | pages = {18635--18643}, 21 | year = {2024}, 22 | } 23 | ``` 24 | 25 | # Getting Started 26 | 1. Before you run the code, create a virtual environment and activate it with the following commands: 27 | ```bash 28 | conda env create -f environment.yaml 29 | conda activate venv 30 | ``` 31 | 32 | 2. Once the virtual environment is created, download the LLM weights following the instructions in [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4). Once the model weights are downloaded, modify the following configuration files: 33 | - `minigpt4/configs/models/minigpt4-7b.yaml`: set line 16 to the path of the Vicuna-7B model weights. 34 | - `minigpt4/configs/models/minigpt4.yaml`: set line 16 to the path of the Vicuna-13B model weights. 35 | 36 | 3. Download the datasets from the official websites of [IU X-Ray](https://openi.nlm.nih.gov/faq#collection) and [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/). Once the datasets are ready, modify the following configuration files: 37 | - `minigpt4/configs/datasets/iuxray/align.yaml`: set line 5 to the path of the pre-training dataset. 38 | - `minigpt4/configs/datasets/iuxray/generate_then_refine.yaml`: set line 5 to the path of the IU X-Ray dataset and line 6 to the path of the public medical corpora.
39 | - `minigpt4/configs/datasets/mimic/align.yaml`: set line 5 to the path of the pre-training dataset. 40 | - `minigpt4/configs/datasets/mimic/generate_then_refine.yaml`: set line 5 to the path of the MIMIC-CXR dataset and line 6 to the path of the public medical corpora. 41 | 42 | # Training 43 | 1. **Pre-training.** We recommend following the instructions below to pre-train MiniGPT-4 on MIMIC-CXR. 44 | 45 | (1) Modify the configuration files. 46 | - `train_configs/stage1/config.yaml`: set line 12 to the path of the linear projection layer of MiniGPT-4 and line 59 to the output path. 47 | 48 | (2) Run the following command to pre-train MiniGPT-4 on MIMIC-CXR. 49 | ``` 50 | python train.py --cfg-path train_configs/stage1/config.yaml 51 | ``` 52 | 53 | To reduce memory usage, we recommend using stage 1 of the `ZeRO` optimizer (ZeRO-1). Run the following command to pre-train MiniGPT-4 on MIMIC-CXR with lower memory usage. 54 | 55 | ``` 56 | deepspeed --num_gpus NUM_GPUS --master_port MASTER_PORT train.py --cfg-path train_configs/stage1/config.yaml use_zero_optimizer --deepspeed_config train_configs/stage1/zero.json 57 | ``` 58 | 59 | You can download our pre-trained model weights from [here](https://huggingface.co/a-b-c-d-e-g/R2-LLM). 60 | 61 | 2. **Fine-tuning.** We recommend following the instructions below to fine-tune MiniGPT-4 on IU X-Ray and MIMIC-CXR. 62 | 63 | (1) Modify the configuration files. Here we take the IU X-Ray configuration as an example. 64 | - `train_configs/stage2/iuxray/config.yaml`: set line 11 to the path of the linear projection layer of MiniGPT-4 pre-trained on MIMIC-CXR and line 56 to the output path. 65 | 66 | (2) Run the following command to fine-tune MiniGPT-4. 67 | 68 | ``` 69 | python train.py --cfg-path train_configs/stage2/iuxray/config.yaml 70 | ``` 71 | 72 | Our codebase supports `ZeRO` to reduce memory usage. You can run the following command with `ZeRO`.
73 | 74 | ``` 75 | deepspeed --num_gpus NUM_GPUS --master_port MASTER_PORT train.py --cfg-path train_configs/stage2/iuxray/config.yaml use_zero_optimizer --deepspeed_config train_configs/stage2/iuxray/zero.json 76 | ``` 77 | 78 | You can download our fine-tuned model weights from [here](https://huggingface.co/a-b-c-d-e-g/R2-LLM). 79 | 80 | # Inference 81 | Run the following command to generate radiology reports. 82 | 83 | ``` 84 | python generate_reports.py \ 85 | --cfg-path eval_configs/eval.yaml \ 86 | --gpu-id GPU_ID \ 87 | --image_path IMAGE_PATH \ 88 | --annotations ANNOTATIONS_PATH_OF_IUXRAY_OR_MIMIC \ 89 | --checkpoint PATH_TO_PRETRAINED_MODEL_WEIGHTS 90 | ``` 91 | 92 | # Acknowledgement 93 | This GitHub repository is largely built on the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) repository. Thanks to the authors for their great work! -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: venv 2 | channels: 3 | - pytorch 4 | - defaults 5 | - anaconda 6 | dependencies: 7 | - python=3.9 8 | - cudatoolkit 9 | - pip 10 | - pip: 11 | - torch==2.0.0 12 | - torchaudio 13 | - torchvision 14 | - huggingface-hub==0.18.0 15 | - matplotlib==3.7.0 16 | - psutil==5.9.4 17 | - iopath 18 | - pyyaml==6.0 19 | - regex==2022.10.31 20 | - tokenizers==0.13.2 21 | - tqdm==4.64.1 22 | - transformers==4.30.0 23 | - timm==0.6.13 24 | - webdataset==0.2.48 25 | - omegaconf==2.3.0 26 | - opencv-python==4.7.0.72 27 | - decord==0.6.0 28 | - peft==0.2.0 29 | - sentence-transformers 30 | - gradio==3.47.1 31 | - accelerate==0.20.3 32 | - bitsandbytes==0.37.0 33 | - scikit-image 34 | - visual-genome 35 | - wandb -------------------------------------------------------------------------------- /eval_configs/eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch:
mini_gpt4 3 | model_type: pretrain_vicuna 4 | freeze_vit: True 5 | freeze_qformer: True 6 | max_txt_len: 100 7 | end_sym: "###" 8 | low_resource: True 9 | prompt_path: "/path/to/prompts" 10 | prompt_template: '###Human: {} ###Assistant: ' 11 | ckpt: '/path/to/linear' 12 | 13 | # lora configuartion 14 | use_lora: True 15 | lora_rank: 8 16 | lora_alpha: 32 17 | lora_dropout: 0.1 18 | 19 | datasets: 20 | mimic_generate_then_refine: 21 | vis_processor: 22 | train: 23 | name: "blip2_image_eval" 24 | image_size: 224 25 | text_processor: 26 | train: 27 | name: "blip_caption" 28 | 29 | run: 30 | task: image_text_pretrain 31 | -------------------------------------------------------------------------------- /generate_reports.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import json 5 | import random 6 | from tqdm import tqdm 7 | from PIL import Image 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | from transformers import StoppingCriteria, StoppingCriteriaList 13 | 14 | from minigpt4.common.config import Config 15 | from minigpt4.common.dist_utils import get_rank 16 | from minigpt4.common.registry import registry 17 | from minigpt4.conversation.conversation import Chat, CONV_VISION 18 | 19 | # imports modules for registration 20 | from minigpt4.datasets.builders import * 21 | from minigpt4.models import * 22 | from minigpt4.processors import * 23 | from minigpt4.runners import * 24 | from minigpt4.tasks import * 25 | 26 | from peft import LoraConfig, TaskType, get_peft_model, set_peft_model_state_dict 27 | 28 | def clean_reports(report): 29 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ 30 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ 31 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') 
\ 32 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ 33 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ 34 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ 35 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ 36 | .strip().lower().split('. ') 37 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') 38 | .replace('\\', '').replace("'", '').strip().lower()) 39 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] 40 | report = ' . '.join(tokens) + ' .' 41 | return report 42 | 43 | class StoppingCriteriaSub(StoppingCriteria): 44 | 45 | def __init__(self, stops=[], encounters=1): 46 | super().__init__() 47 | self.stops = stops 48 | 49 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 50 | for stop in self.stops: 51 | if torch.all((stop == input_ids[0][-len(stop):])).item(): 52 | return True 53 | 54 | return False 55 | 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser(description="Demo") 59 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.") 60 | parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") 61 | parser.add_argument( 62 | "--options", 63 | nargs="+", 64 | help="override some settings in the used config, the key-value pair " 65 | "in xxx=yyy format will be merged into config file (deprecate), " 66 | "change to --cfg-options instead.", 67 | ) 68 | 69 | parser.add_argument('--image_path', default='', type=str, help='path of the input image') 70 | parser.add_argument('--generation_prompts', type=str, default='prompts/stage2-generation-prompts.txt', help='path of the generation prompts for the first stage') 71 | parser.add_argument('--refinement_prompts', type=str, 
default='prompts/stage2-refinement-prompts.txt', help='path of the refinement prompts for the second stage') 72 | parser.add_argument('--annotations', type=str, default='', help='path of annotation file, to load in the GTs') 73 | parser.add_argument('--checkpoint', required=True, help='checkpoint path') 74 | parser.add_argument('--beam_size', type=int, default=1) 75 | parser.add_argument('--temperature', type=float, default=1.0) 76 | parser.add_argument('--max_txt_len', default=160, type=int) 77 | 78 | args = parser.parse_args() 79 | return args 80 | 81 | 82 | def setup_seeds(config): 83 | seed = config.run_cfg.seed + get_rank() 84 | 85 | random.seed(seed) 86 | np.random.seed(seed) 87 | torch.manual_seed(seed) 88 | 89 | cudnn.benchmark = False 90 | cudnn.deterministic = True 91 | 92 | 93 | # ======================================== 94 | # Model Initialization 95 | # ======================================== 96 | 97 | print('Initializing Chat') 98 | args = parse_args() 99 | cfg = Config(args) 100 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 101 | 102 | model_config = cfg.model_cfg 103 | model_config.device_8bit = args.gpu_id 104 | model_cls = registry.get_model_class(model_config.arch) 105 | model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) 106 | 107 | # load LoRA 108 | peft_config = LoraConfig(inference_mode=False, r=cfg.model_cfg.lora_rank, lora_alpha=cfg.model_cfg.lora_alpha, lora_dropout=cfg.model_cfg.lora_dropout) 109 | peft_model = get_peft_model(model.llama_model, peft_config=peft_config) 110 | # loading normal pytroch checkpoint 111 | if args.checkpoint.endswith('.pth'): 112 | full_state_dict = torch.load(args.checkpoint, map_location='cpu') 113 | # loading ZeRO checkpoint 114 | elif args.checkpoint.endswith('.pt'): 115 | full_state_dict = torch.load(args.checkpoint, map_location='cpu')['module'] 116 | set_peft_model_state_dict(peft_model, full_state_dict) 117 | peft_model = peft_model.to(device) 118 | print('LLaMA 
checkpoint loaded.') 119 | # load in the linear projection layer 120 | llama_proj_state_dict = {} 121 | for key, value in full_state_dict.items(): 122 | if 'llama_proj' in key: 123 | llama_proj_state_dict[key[18:]] = value 124 | model.llama_proj.load_state_dict(llama_proj_state_dict) 125 | print('Linear projection layer loaded.') 126 | 127 | vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train 128 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) 129 | print('Initialization Finished') 130 | 131 | # ======================================== 132 | # Start Testing 133 | # ======================================== 134 | 135 | 136 | # image_paths = [] 137 | # for root, dirs, files in os.walk(args.images): 138 | # for file in files: 139 | # image_paths.append(os.path.join(root, file)) 140 | 141 | # load generation prompts from local path 142 | generation_prompts = [] 143 | with open(args.generation_prompts, 'r') as f: 144 | for line in f.readlines(): 145 | generation_prompts.append(line.strip('\n')) 146 | 147 | # load refinement prompts from local path 148 | refinement_prompts = [] 149 | with open(args.refinement_prompts, 'r') as f: 150 | for line in f.readlines(): 151 | refinement_prompts.append(line.strip('\n')) 152 | 153 | final_record_message = '' 154 | with torch.no_grad(): 155 | # TODO: Start the first stage 156 | # random sample one prompt 157 | prompt = random.choice(generation_prompts) 158 | prompt = '###Human: ' + prompt + '###Assistant: ' 159 | 160 | # encode image 161 | img_list = [] 162 | raw_image = Image.open(args.image_path).convert('RGB') 163 | image = vis_processor(raw_image).unsqueeze(0).to(device) 164 | image_emb, _ = model.encode_img(image) 165 | img_list.append(image_emb) 166 | 167 | # wrap image with prompt 168 | prompt_segs = prompt.split('') 169 | seg_tokens = [ 170 | model.llama_tokenizer( 171 | seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids 172 | # 
only add bos to the first seg 173 | for i, seg in enumerate(prompt_segs) 174 | ] 175 | seg_embs = [peft_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens] 176 | mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] 177 | mixed_embs = torch.cat(mixed_embs, dim=1) 178 | 179 | # prepare other things before generate 180 | stop_words_ids = [torch.tensor([835]).to(device), torch.tensor([2277, 29937]).to(device)] # '###' can be encoded in two different ways. 181 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 182 | 183 | # generate 184 | outputs = peft_model.base_model.model.generate( 185 | inputs_embeds=mixed_embs, 186 | max_new_tokens=args.max_txt_len, 187 | stopping_criteria=stopping_criteria, 188 | num_beams=args.beam_size, 189 | do_sample=True, 190 | min_length=1, 191 | top_p=0.9, 192 | repetition_penalty=1.0, 193 | length_penalty=1, 194 | temperature=args.temperature,) 195 | 196 | output_token = outputs[0] 197 | if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it 198 | output_token = output_token[1:] 199 | if output_token[0] == 1: # some users find that there is a start token at the beginning. 
remove it 200 | output_token = output_token[1:] 201 | output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False) 202 | output_text = output_text.split('###')[0] # remove the stop sign '###' 203 | output_text = output_text.split('Assistant:')[-1].strip() 204 | generated_text = output_text 205 | 206 | # TODO: Start the second stage 207 | coarse_generated_report = output_token 208 | coarse_report_embeds = peft_model.base_model.model.model.embed_tokens(coarse_generated_report).expand(image_emb.shape[0], -1, -1) 209 | atts_report = torch.ones(coarse_report_embeds.size()[:-1], dtype=torch.long).to(device) 210 | prompt = random.choice(refinement_prompts) 211 | prompt = '###Human: ' + prompt + '###Assistant: ' 212 | 213 | # encode image 214 | img_list = [] 215 | raw_image = Image.open(args.image_path).convert('RGB') 216 | image = vis_processor(raw_image).unsqueeze(0).to(device) 217 | image_emb, _ = model.encode_img(image) 218 | img_list.append(image_emb) 219 | 220 | # the right implementation 221 | p_before, p_after_all = prompt.split('') 222 | p_mid, p_after = p_after_all.split('') 223 | p_before_tokens = model.llama_tokenizer(p_before, return_tensors="pt", add_special_tokens=True).to(device).input_ids 224 | p_mid_tokens = model.llama_tokenizer(p_mid, return_tensors="pt", add_special_tokens=False).to(device).input_ids 225 | p_after_tokens = model.llama_tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(device).input_ids 226 | 227 | # embedding 228 | p_before_embeds = peft_model.base_model.model.model.embed_tokens(p_before_tokens) 229 | p_mid_embeds = peft_model.base_model.model.model.embed_tokens(p_mid_tokens) 230 | p_after_embeds = peft_model.base_model.model.model.embed_tokens(p_after_tokens) 231 | mixed_embs = torch.cat([p_before_embeds, img_list[0], p_mid_embeds, coarse_report_embeds, p_after_embeds], dim=1) 232 | mixed_embs = torch.cat([p_mid_embeds, coarse_report_embeds, p_after_embeds], dim=1) 233 | 234 | # prepare other 
things before generate 235 | stop_words_ids = [torch.tensor([835]).to(device), torch.tensor([2277, 29937]).to(device)] # '###' can be encoded in two different ways. 236 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 237 | 238 | # generate 239 | outputs = peft_model.base_model.model.generate( 240 | inputs_embeds=mixed_embs, 241 | max_new_tokens=args.max_txt_len, 242 | stopping_criteria=stopping_criteria, 243 | num_beams=args.beam_size, 244 | do_sample=True, 245 | min_length=1, 246 | top_p=0.9, 247 | repetition_penalty=1.0, 248 | length_penalty=1, 249 | temperature=args.temperature,) 250 | 251 | output_token = outputs[0] 252 | if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it 253 | output_token = output_token[1:] 254 | if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it 255 | output_token = output_token[1:] 256 | output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False) 257 | output_text = output_text.split('###')[0] # remove the stop sign '###' 258 | output_text = output_text.split('Assistant:')[-1].strip() 259 | refined_text = output_text 260 | 261 | print('Generated report:') 262 | print(refined_text) 263 | -------------------------------------------------------------------------------- /minigpt4/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | 11 | from omegaconf import OmegaConf 12 | 13 | from minigpt4.common.registry import registry 14 | 15 | from minigpt4.datasets.builders import * 16 | from minigpt4.models import * 17 | from minigpt4.processors import * 18 | from minigpt4.tasks import * 19 | 20 | 21 | root_dir = os.path.dirname(os.path.abspath(__file__)) 22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) 23 | 24 | registry.register_path("library_root", root_dir) 25 | repo_root = os.path.join(root_dir, "..") 26 | registry.register_path("repo_root", repo_root) 27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root) 28 | registry.register_path("cache_root", cache_root) 29 | 30 | registry.register("MAX_INT", sys.maxsize) 31 | registry.register("SPLIT_NAMES", ["train", "val", "test"]) 32 | -------------------------------------------------------------------------------- /minigpt4/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /minigpt4/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /minigpt4/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__init__.py -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/config.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/dist_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/dist_utils.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/dist_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/dist_utils.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/logger.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/logger.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/optims.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/optims.cpython-39.pyc 
-------------------------------------------------------------------------------- /minigpt4/common/__pycache__/registry.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/registry.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/registry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/registry.cpython-38.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/registry.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/registry.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /minigpt4/common/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/common/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % 
torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | # use zero optimizer for distributed initialization 81 | if args.use_zero_optimizer: 82 | print("Using ZeRO optimizer distributed mode.") 83 | import deepspeed 84 | deepspeed.init_distributed( 85 | dist_backend=args.dist_backend, 86 | init_method=args.dist_url, 87 | rank=args.rank, 88 | timeout=datetime.timedelta(days=365), # allow auto-downloading and de-compressing, 89 | # config=args.deepspeed_config, 90 | ) 91 | # use pytorch distributed initialization 92 | else: 93 | print("Using PyTorch optimizer distributed mode.") 94 | torch.distributed.init_process_group( 95 | backend=args.dist_backend, 96 | init_method=args.dist_url, 97 | world_size=args.world_size, 98 | rank=args.rank, 99 | timeout=datetime.timedelta( 100 | days=365 101 | ), # allow auto-downloading and de-compressing 102 | ) 103 | torch.distributed.barrier() 104 | setup_for_distributed(args.rank == 0) 105 | 106 | 107 | def get_dist_info(): 108 | if torch.__version__ < "1.0": 109 | initialized = dist._initialized 110 | else: 111 | initialized = dist.is_initialized() 112 | if initialized: 113 | rank = dist.get_rank() 114 | world_size = dist.get_world_size() 115 | else: # non-distributed training 116 | rank = 0 117 | world_size = 1 118 | return rank, world_size 119 | 120 | 121 | def main_process(func): 122 | @functools.wraps(func) 123 | def wrapper(*args, **kwargs): 124 | rank, _ = get_dist_info() 125 | if rank == 0: 126 | return func(*args, **kwargs) 127 | 128 | return wrapper 129 | 130 | 131 | def download_cached_file(url, check_hash=True, progress=False): 132 | """ 133 | Download a file from a URL and cache it locally. 
If the file already exists, it is not downloaded again. 134 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 135 | """ 136 | 137 | def get_cached_file_path(): 138 | # a hack to sync the file path across processes 139 | parts = torch.hub.urlparse(url) 140 | filename = os.path.basename(parts.path) 141 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 142 | 143 | return cached_file 144 | 145 | if is_main_process(): 146 | timm_hub.download_cached_file(url, check_hash, progress) 147 | 148 | if is_dist_avail_and_initialized(): 149 | dist.barrier() 150 | 151 | return get_cached_file_path() 152 | -------------------------------------------------------------------------------- /minigpt4/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /minigpt4/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import time 11 | from collections import defaultdict, deque 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | from minigpt4.common import dist_utils 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 
40 | """ 41 | if not dist_utils.is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value, 79 | ) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError( 100 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 101 | ) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def global_avg(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) 113 | return self.delimiter.join(loss_str) 114 | 115 | def 
synchronize_between_processes(self): 116 | for meter in self.meters.values(): 117 | meter.synchronize_between_processes() 118 | 119 | def add_meter(self, name, meter): 120 | self.meters[name] = meter 121 | 122 | def log_every(self, iterable, print_freq, header=None): 123 | i = 0 124 | if not header: 125 | header = "" 126 | start_time = time.time() 127 | end = time.time() 128 | iter_time = SmoothedValue(fmt="{avg:.4f}") 129 | data_time = SmoothedValue(fmt="{avg:.4f}") 130 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 131 | log_msg = [ 132 | header, 133 | "[{0" + space_fmt + "}/{1}]", 134 | "eta: {eta}", 135 | "{meters}", 136 | "time: {time}", 137 | "data: {data}", 138 | ] 139 | if torch.cuda.is_available(): 140 | log_msg.append("max mem: {memory:.0f}") 141 | log_msg = self.delimiter.join(log_msg) 142 | MB = 1024.0 * 1024.0 143 | for obj in iterable: 144 | data_time.update(time.time() - end) 145 | yield obj 146 | iter_time.update(time.time() - end) 147 | if i % print_freq == 0 or i == len(iterable) - 1: 148 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 150 | if torch.cuda.is_available(): 151 | print( 152 | log_msg.format( 153 | i, 154 | len(iterable), 155 | eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), 158 | data=str(data_time), 159 | memory=torch.cuda.max_memory_allocated() / MB, 160 | ) 161 | ) 162 | else: 163 | print( 164 | log_msg.format( 165 | i, 166 | len(iterable), 167 | eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), 170 | data=str(data_time), 171 | ) 172 | ) 173 | i += 1 174 | end = time.time() 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print( 178 | "{} Total time: {} ({:.4f} s / it)".format( 179 | header, total_time_str, total_time / len(iterable) 180 | ) 181 | ) 182 | 183 | 184 | class AttrDict(dict): 185 | def __init__(self, *args, **kwargs): 186 | 
super(AttrDict, self).__init__(*args, **kwargs) 187 | self.__dict__ = self 188 | 189 | 190 | def setup_logger(): 191 | logging.basicConfig( 192 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN, 193 | format="%(asctime)s [%(levelname)s] %(message)s", 194 | handlers=[logging.StreamHandler()], 195 | ) 196 | -------------------------------------------------------------------------------- /minigpt4/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from minigpt4.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | **kwargs 25 | ): 26 | self.optimizer = optimizer 27 | 28 | self.max_epoch = max_epoch 29 | self.min_lr = min_lr 30 | 31 | self.decay_rate = decay_rate 32 | 33 | self.init_lr = init_lr 34 | self.warmup_steps = warmup_steps 35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 36 | 37 | def step(self, cur_epoch, cur_step): 38 | if cur_epoch == 0: 39 | warmup_lr_schedule( 40 | step=cur_step, 41 | optimizer=self.optimizer, 42 | max_step=self.warmup_steps, 43 | init_lr=self.warmup_start_lr, 44 | max_lr=self.init_lr, 45 | ) 46 | else: 47 | step_lr_schedule( 48 | epoch=cur_epoch, 49 | optimizer=self.optimizer, 50 | init_lr=self.init_lr, 51 | min_lr=self.min_lr, 52 | decay_rate=self.decay_rate, 53 | ) 54 | 55 | 56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 57 | class LinearWarmupCosineLRScheduler: 58 | def __init__( 59 | self, 60 | 
optimizer, 61 | max_epoch, 62 | iters_per_epoch, 63 | min_lr, 64 | init_lr, 65 | warmup_steps=0, 66 | warmup_start_lr=-1, 67 | **kwargs 68 | ): 69 | self.optimizer = optimizer 70 | 71 | self.max_epoch = max_epoch 72 | self.iters_per_epoch = iters_per_epoch 73 | self.min_lr = min_lr 74 | 75 | self.init_lr = init_lr 76 | self.warmup_steps = warmup_steps 77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 78 | 79 | def step(self, cur_epoch, cur_step): 80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 81 | if total_cur_step < self.warmup_steps: 82 | warmup_lr_schedule( 83 | step=cur_step, 84 | optimizer=self.optimizer, 85 | max_step=self.warmup_steps, 86 | init_lr=self.warmup_start_lr, 87 | max_lr=self.init_lr, 88 | ) 89 | else: 90 | cosine_lr_schedule( 91 | epoch=total_cur_step, 92 | optimizer=self.optimizer, 93 | max_epoch=self.max_epoch * self.iters_per_epoch, 94 | init_lr=self.init_lr, 95 | min_lr=self.min_lr, 96 | ) 97 | 98 | 99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 100 | """Decay the learning rate""" 101 | lr = (init_lr - min_lr) * 0.5 * ( 102 | 1.0 + math.cos(math.pi * epoch / max_epoch) 103 | ) + min_lr 104 | for param_group in optimizer.param_groups: 105 | param_group["lr"] = lr 106 | 107 | 108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 109 | """Warmup the learning rate""" 110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 111 | for param_group in optimizer.param_groups: 112 | param_group["lr"] = lr 113 | 114 | 115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 116 | """Decay the learning rate""" 117 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 118 | for param_group in optimizer.param_groups: 119 | param_group["lr"] = lr 120 | -------------------------------------------------------------------------------- /minigpt4/common/registry.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | class Registry: 10 | mapping = { 11 | "builder_name_mapping": {}, 12 | "task_name_mapping": {}, 13 | "processor_name_mapping": {}, 14 | "model_name_mapping": {}, 15 | "lr_scheduler_name_mapping": {}, 16 | "runner_name_mapping": {}, 17 | "state": {}, 18 | "paths": {}, 19 | } 20 | 21 | @classmethod 22 | def register_builder(cls, name): 23 | r"""Register a dataset builder to registry with key 'name' 24 | 25 | Args: 26 | name: Key with which the builder will be registered. 27 | 28 | Usage: 29 | 30 | from minigpt4.common.registry import registry 31 | from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder 32 | """ 33 | 34 | def wrap(builder_cls): 35 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder 36 | 37 | assert issubclass( 38 | builder_cls, BaseDatasetBuilder 39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 40 | builder_cls 41 | ) 42 | if name in cls.mapping["builder_name_mapping"]: 43 | raise KeyError( 44 | "Name '{}' already registered for {}.".format( 45 | name, cls.mapping["builder_name_mapping"][name] 46 | ) 47 | ) 48 | cls.mapping["builder_name_mapping"][name] = builder_cls 49 | return builder_cls 50 | 51 | return wrap 52 | 53 | @classmethod 54 | def register_task(cls, name): 55 | r"""Register a task to registry with key 'name' 56 | 57 | Args: 58 | name: Key with which the task will be registered. 
59 | 60 | Usage: 61 | 62 | from minigpt4.common.registry import registry 63 | """ 64 | 65 | def wrap(task_cls): 66 | from minigpt4.tasks.base_task import BaseTask 67 | 68 | assert issubclass( 69 | task_cls, BaseTask 70 | ), "All tasks must inherit BaseTask class" 71 | if name in cls.mapping["task_name_mapping"]: 72 | raise KeyError( 73 | "Name '{}' already registered for {}.".format( 74 | name, cls.mapping["task_name_mapping"][name] 75 | ) 76 | ) 77 | cls.mapping["task_name_mapping"][name] = task_cls 78 | return task_cls 79 | 80 | return wrap 81 | 82 | @classmethod 83 | def register_model(cls, name): 84 | r"""Register a task to registry with key 'name' 85 | 86 | Args: 87 | name: Key with which the task will be registered. 88 | 89 | Usage: 90 | 91 | from minigpt4.common.registry import registry 92 | """ 93 | 94 | def wrap(model_cls): 95 | from minigpt4.models import BaseModel 96 | 97 | assert issubclass( 98 | model_cls, BaseModel 99 | ), "All models must inherit BaseModel class" 100 | if name in cls.mapping["model_name_mapping"]: 101 | raise KeyError( 102 | "Name '{}' already registered for {}.".format( 103 | name, cls.mapping["model_name_mapping"][name] 104 | ) 105 | ) 106 | cls.mapping["model_name_mapping"][name] = model_cls 107 | return model_cls 108 | 109 | return wrap 110 | 111 | @classmethod 112 | def register_processor(cls, name): 113 | r"""Register a processor to registry with key 'name' 114 | 115 | Args: 116 | name: Key with which the task will be registered. 
117 | 118 | Usage: 119 | 120 | from minigpt4.common.registry import registry 121 | """ 122 | 123 | def wrap(processor_cls): 124 | from minigpt4.processors import BaseProcessor 125 | 126 | assert issubclass( 127 | processor_cls, BaseProcessor 128 | ), "All processors must inherit BaseProcessor class" 129 | if name in cls.mapping["processor_name_mapping"]: 130 | raise KeyError( 131 | "Name '{}' already registered for {}.".format( 132 | name, cls.mapping["processor_name_mapping"][name] 133 | ) 134 | ) 135 | cls.mapping["processor_name_mapping"][name] = processor_cls 136 | return processor_cls 137 | 138 | return wrap 139 | 140 | @classmethod 141 | def register_lr_scheduler(cls, name): 142 | r"""Register a model to registry with key 'name' 143 | 144 | Args: 145 | name: Key with which the task will be registered. 146 | 147 | Usage: 148 | 149 | from minigpt4.common.registry import registry 150 | """ 151 | 152 | def wrap(lr_sched_cls): 153 | if name in cls.mapping["lr_scheduler_name_mapping"]: 154 | raise KeyError( 155 | "Name '{}' already registered for {}.".format( 156 | name, cls.mapping["lr_scheduler_name_mapping"][name] 157 | ) 158 | ) 159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls 160 | return lr_sched_cls 161 | 162 | return wrap 163 | 164 | @classmethod 165 | def register_runner(cls, name): 166 | r"""Register a model to registry with key 'name' 167 | 168 | Args: 169 | name: Key with which the task will be registered. 
170 | 171 | Usage: 172 | 173 | from minigpt4.common.registry import registry 174 | """ 175 | 176 | def wrap(runner_cls): 177 | if name in cls.mapping["runner_name_mapping"]: 178 | raise KeyError( 179 | "Name '{}' already registered for {}.".format( 180 | name, cls.mapping["runner_name_mapping"][name] 181 | ) 182 | ) 183 | cls.mapping["runner_name_mapping"][name] = runner_cls 184 | return runner_cls 185 | 186 | return wrap 187 | 188 | @classmethod 189 | def register_path(cls, name, path): 190 | r"""Register a path to registry with key 'name' 191 | 192 | Args: 193 | name: Key with which the path will be registered. 194 | 195 | Usage: 196 | 197 | from minigpt4.common.registry import registry 198 | """ 199 | assert isinstance(path, str), "All path must be str." 200 | if name in cls.mapping["paths"]: 201 | raise KeyError("Name '{}' already registered.".format(name)) 202 | cls.mapping["paths"][name] = path 203 | 204 | @classmethod 205 | def register(cls, name, obj): 206 | r"""Register an item to registry with key 'name' 207 | 208 | Args: 209 | name: Key with which the item will be registered. 
210 | 211 | Usage:: 212 | 213 | from minigpt4.common.registry import registry 214 | 215 | registry.register("config", {}) 216 | """ 217 | path = name.split(".") 218 | current = cls.mapping["state"] 219 | 220 | for part in path[:-1]: 221 | if part not in current: 222 | current[part] = {} 223 | current = current[part] 224 | 225 | current[path[-1]] = obj 226 | 227 | # @classmethod 228 | # def get_trainer_class(cls, name): 229 | # return cls.mapping["trainer_name_mapping"].get(name, None) 230 | 231 | @classmethod 232 | def get_builder_class(cls, name): 233 | return cls.mapping["builder_name_mapping"].get(name, None) 234 | 235 | @classmethod 236 | def get_model_class(cls, name): 237 | return cls.mapping["model_name_mapping"].get(name, None) 238 | 239 | @classmethod 240 | def get_task_class(cls, name): 241 | return cls.mapping["task_name_mapping"].get(name, None) 242 | 243 | @classmethod 244 | def get_processor_class(cls, name): 245 | return cls.mapping["processor_name_mapping"].get(name, None) 246 | 247 | @classmethod 248 | def get_lr_scheduler_class(cls, name): 249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None) 250 | 251 | @classmethod 252 | def get_runner_class(cls, name): 253 | return cls.mapping["runner_name_mapping"].get(name, None) 254 | 255 | @classmethod 256 | def list_runners(cls): 257 | return sorted(cls.mapping["runner_name_mapping"].keys()) 258 | 259 | @classmethod 260 | def list_models(cls): 261 | return sorted(cls.mapping["model_name_mapping"].keys()) 262 | 263 | @classmethod 264 | def list_tasks(cls): 265 | return sorted(cls.mapping["task_name_mapping"].keys()) 266 | 267 | @classmethod 268 | def list_processors(cls): 269 | return sorted(cls.mapping["processor_name_mapping"].keys()) 270 | 271 | @classmethod 272 | def list_lr_schedulers(cls): 273 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) 274 | 275 | @classmethod 276 | def list_datasets(cls): 277 | return sorted(cls.mapping["builder_name_mapping"].keys()) 278 | 279 | 
@classmethod 280 | def get_path(cls, name): 281 | return cls.mapping["paths"].get(name, None) 282 | 283 | @classmethod 284 | def get(cls, name, default=None, no_warning=False): 285 | r"""Get an item from registry with key 'name' 286 | 287 | Args: 288 | name (string): Key whose value needs to be retrieved. 289 | default: If passed and key is not in registry, default value will 290 | be returned with a warning. Default: None 291 | no_warning (bool): If passed as True, warning when key doesn't exist 292 | will not be generated. Useful for MMF's 293 | internal operations. Default: False 294 | """ 295 | original_name = name 296 | name = name.split(".") 297 | value = cls.mapping["state"] 298 | for subname in name: 299 | value = value.get(subname, default) 300 | if value is default: 301 | break 302 | 303 | if ( 304 | "writer" in cls.mapping["state"] 305 | and value == default 306 | and no_warning is False 307 | ): 308 | cls.mapping["state"]["writer"].warning( 309 | "Key {} is not present in registry, returning default value " 310 | "of {}".format(original_name, default) 311 | ) 312 | return value 313 | 314 | @classmethod 315 | def unregister(cls, name): 316 | r"""Remove an item from registry with key 'name' 317 | 318 | Args: 319 | name: Key which needs to be removed. 
320 | Usage:: 321 | 322 | from mmf.common.registry import registry 323 | 324 | config = registry.unregister("config") 325 | """ 326 | return cls.mapping["state"].pop(name, None) 327 | 328 | 329 | registry = Registry() 330 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/cc_sbu/align.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu_align: 3 | data_type: images 4 | build_info: 5 | storage: 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/cc_sbu/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu: 3 | data_type: images 4 | build_info: 5 | storage: 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/iuxray/align.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | mimic_align: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/mimic 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/iuxray/generate_then_refine.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | mimic_generate_then_refine: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/iuxray 6 | unlabeled_annotation_path: /path/to/pubmed 7 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/laion/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | laion: 3 | data_type: images 4 | build_info: 5 | storage: 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/mimic/align.yaml: 
-------------------------------------------------------------------------------- 1 | datasets: 2 | mimic_align: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/mimic 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/mimic/generate_then_refine.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | mimic_generate_then_refine: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/mimic 6 | unlabeled_annotation_path: /path/to/pubmed 7 | -------------------------------------------------------------------------------- /minigpt4/configs/default.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | # For default users 3 | # cache_root: "cache" 4 | # For internal use with persistent storage 5 | cache_root: "/export/home/.cache/minigpt4" 6 | -------------------------------------------------------------------------------- /minigpt4/configs/models/minigpt4-7b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: mini_gpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | # Q-Former 13 | num_query_token: 32 14 | 15 | # Vicuna 16 | llama_model: "/path/to/vicuna-7b" 17 | 18 | # generation configs 19 | prompt: "" 20 | 21 | preprocess: 22 | vis_processor: 23 | train: 24 | name: "blip2_image_train" 25 | image_size: 224 26 | eval: 27 | name: "blip2_image_eval" 28 | image_size: 224 29 | text_processor: 30 | train: 31 | name: "blip_caption" 32 | eval: 33 | name: "blip_caption" 34 | -------------------------------------------------------------------------------- /minigpt4/configs/models/minigpt4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: mini_gpt4 3 | 4 | # vit encoder 5 
| image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | # Q-Former 13 | num_query_token: 32 14 | 15 | # Vicuna 16 | llama_model: "/path/to/vicuna-13b" 17 | 18 | # generation configs 19 | prompt: "" 20 | 21 | preprocess: 22 | vis_processor: 23 | train: 24 | name: "blip2_image_train" 25 | image_size: 224 26 | eval: 27 | name: "blip2_image_eval" 28 | image_size: 224 29 | text_processor: 30 | train: 31 | name: "blip_caption" 32 | eval: 33 | name: "blip_caption" 34 | -------------------------------------------------------------------------------- /minigpt4/conversation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/conversation/__init__.py -------------------------------------------------------------------------------- /minigpt4/conversation/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/conversation/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/conversation/__pycache__/conversation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/conversation/__pycache__/conversation.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/conversation/conversation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from PIL import Image 4 | 5 | import torch 6 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer 7 | 
from transformers import StoppingCriteria, StoppingCriteriaList 8 | 9 | import dataclasses 10 | from enum import auto, Enum 11 | from typing import List, Tuple, Any 12 | 13 | from minigpt4.common.registry import registry 14 | 15 | 16 | class SeparatorStyle(Enum): 17 | """Different separator style.""" 18 | SINGLE = auto() 19 | TWO = auto() 20 | 21 | 22 | @dataclasses.dataclass 23 | class Conversation: 24 | """A class that keeps all conversation history.""" 25 | system: str 26 | roles: List[str] 27 | messages: List[List[str]] 28 | offset: int 29 | # system_img: List[Image.Image] = [] 30 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 31 | sep: str = "###" 32 | sep2: str = None 33 | 34 | skip_next: bool = False 35 | conv_id: Any = None 36 | 37 | def get_prompt(self): 38 | if self.sep_style == SeparatorStyle.SINGLE: 39 | ret = self.system + self.sep 40 | for role, message in self.messages: 41 | if message: 42 | ret += role + ": " + message + self.sep 43 | else: 44 | ret += role + ":" 45 | return ret 46 | elif self.sep_style == SeparatorStyle.TWO: 47 | seps = [self.sep, self.sep2] 48 | ret = self.system + seps[0] 49 | for i, (role, message) in enumerate(self.messages): 50 | if message: 51 | ret += role + ": " + message + seps[i % 2] 52 | else: 53 | ret += role + ":" 54 | return ret 55 | else: 56 | raise ValueError(f"Invalid style: {self.sep_style}") 57 | 58 | def append_message(self, role, message): 59 | self.messages.append([role, message]) 60 | 61 | def to_gradio_chatbot(self): 62 | ret = [] 63 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 64 | if i % 2 == 0: 65 | ret.append([msg, None]) 66 | else: 67 | ret[-1][-1] = msg 68 | return ret 69 | 70 | def copy(self): 71 | return Conversation( 72 | system=self.system, 73 | # system_img=self.system_img, 74 | roles=self.roles, 75 | messages=[[x, y] for x, y in self.messages], 76 | offset=self.offset, 77 | sep_style=self.sep_style, 78 | sep=self.sep, 79 | sep2=self.sep2, 80 | conv_id=self.conv_id) 81 | 82 
| def dict(self): 83 | return { 84 | "system": self.system, 85 | # "system_img": self.system_img, 86 | "roles": self.roles, 87 | "messages": self.messages, 88 | "offset": self.offset, 89 | "sep": self.sep, 90 | "sep2": self.sep2, 91 | "conv_id": self.conv_id, 92 | } 93 | 94 | 95 | class StoppingCriteriaSub(StoppingCriteria): 96 | 97 | def __init__(self, stops=[], encounters=1): 98 | super().__init__() 99 | self.stops = stops 100 | 101 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 102 | for stop in self.stops: 103 | if torch.all((stop == input_ids[0][-len(stop):])).item(): 104 | return True 105 | 106 | return False 107 | 108 | 109 | CONV_VISION = Conversation( 110 | system="Give the following image: ImageContent. " 111 | "You will be able to see the image once I provide it to you. Please answer my questions.", 112 | roles=("Human", "Assistant"), 113 | messages=[], 114 | offset=2, 115 | sep_style=SeparatorStyle.SINGLE, 116 | sep="###", 117 | ) 118 | 119 | 120 | 121 | class Chat: 122 | def __init__(self, model, vis_processor, device='cuda:0'): 123 | self.device = device 124 | self.model = model 125 | self.vis_processor = vis_processor 126 | stop_words_ids = [torch.tensor([835]).to(self.device), 127 | torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. 128 | self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 129 | 130 | def ask(self, text, conv): 131 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ 132 | and conv.messages[-1][1][-6:] == '': # last message is image. 
133 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) 134 | else: 135 | conv.append_message(conv.roles[0], text) 136 | 137 | def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, 138 | repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): 139 | conv.append_message(conv.roles[1], None) 140 | embs = self.get_context_emb(conv, img_list) 141 | 142 | current_max_len = embs.shape[1] + max_new_tokens 143 | if current_max_len - max_length > 0: 144 | print('Warning: The number of tokens in current conversation exceeds the max length. ' 145 | 'The model will not see the contexts outside the range.') 146 | begin_idx = max(0, current_max_len - max_length) 147 | 148 | embs = embs[:, begin_idx:] 149 | 150 | outputs = self.model.llama_model.generate( 151 | inputs_embeds=embs, 152 | max_new_tokens=max_new_tokens, 153 | stopping_criteria=self.stopping_criteria, 154 | num_beams=num_beams, 155 | do_sample=True, 156 | min_length=min_length, 157 | top_p=top_p, 158 | repetition_penalty=repetition_penalty, 159 | length_penalty=length_penalty, 160 | temperature=temperature, 161 | ) 162 | output_token = outputs[0] 163 | if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it 164 | output_token = output_token[1:] 165 | if output_token[0] == 1: # some users find that there is a start token at the beginning. 
remove it 166 | output_token = output_token[1:] 167 | output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) 168 | output_text = output_text.split('###')[0] # remove the stop sign '###' 169 | output_text = output_text.split('Assistant:')[-1].strip() 170 | conv.messages[-1][1] = output_text 171 | return output_text, output_token.cpu().numpy() 172 | 173 | def upload_img(self, image, conv, img_list): 174 | if isinstance(image, str): # is a image path 175 | raw_image = Image.open(image).convert('RGB') 176 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 177 | elif isinstance(image, Image.Image): 178 | raw_image = image 179 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 180 | elif isinstance(image, torch.Tensor): 181 | if len(image.shape) == 3: 182 | image = image.unsqueeze(0) 183 | image = image.to(self.device) 184 | 185 | image_emb, _ = self.model.encode_img(image) 186 | img_list.append(image_emb) 187 | conv.append_message(conv.roles[0], "") 188 | msg = "Received." 189 | # self.conv.append_message(self.conv.roles[1], msg) 190 | return msg 191 | 192 | def get_context_emb(self, conv, img_list): 193 | prompt = conv.get_prompt() 194 | prompt_segs = prompt.split('') 195 | assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." 
196 | seg_tokens = [ 197 | self.model.llama_tokenizer( 198 | seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids 199 | # only add bos to the first seg 200 | for i, seg in enumerate(prompt_segs) 201 | ] 202 | seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] 203 | mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] 204 | mixed_embs = torch.cat(mixed_embs, dim=1) 205 | return mixed_embs 206 | 207 | 208 | -------------------------------------------------------------------------------- /minigpt4/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__init__.py -------------------------------------------------------------------------------- /minigpt4/datasets/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- 
/minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config 9 | from minigpt4.datasets.builders.image_text_pair_builder import ( 10 | CCSBUBuilder, 11 | LaionBuilder, 12 | CCSBUAlignBuilder 13 | ) 14 | from minigpt4.common.registry import registry 15 | 16 | __all__ = [ 17 | "CCSBUBuilder", 18 | "LaionBuilder", 19 | "CCSBUAlignBuilder" 20 | ] 21 | 22 | 23 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 24 | """ 25 | Example 26 | 27 | >>> dataset = load_dataset("coco_caption", cfg=None) 28 | >>> splits = dataset.keys() 29 | >>> print([len(dataset[split]) for split in splits]) 30 | 31 | """ 32 | if cfg_path is None: 33 | cfg = None 34 | else: 35 | cfg = load_dataset_config(cfg_path) 36 | 37 | try: 38 | builder = registry.get_builder_class(name)(cfg) 39 | except TypeError: 40 | print( 41 | f"Dataset {name} not found. Available datasets:\n" 42 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 43 | ) 44 | exit(1) 45 | 46 | if vis_path is not None: 47 | if data_type is None: 48 | # use default data type in the config 49 | data_type = builder.config.data_type 50 | 51 | assert ( 52 | data_type in builder.config.build_info 53 | ), f"Invalid data_type {data_type} for {name}." 
54 | 55 | builder.config.build_info.get(data_type).storage = vis_path 56 | 57 | dataset = builder.build_datasets() 58 | return dataset 59 | 60 | 61 | class DatasetZoo: 62 | def __init__(self) -> None: 63 | self.dataset_zoo = { 64 | k: list(v.DATASET_CONFIG_DICT.keys()) 65 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 66 | } 67 | 68 | def get_names(self): 69 | return list(self.dataset_zoo.keys()) 70 | 71 | 72 | dataset_zoo = DatasetZoo() 73 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-38.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/builders/base_dataset_builder.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This file is from 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | import shutil 12 | import warnings 13 | 14 | from omegaconf import OmegaConf 15 | import torch.distributed as dist 16 | from torchvision.datasets.utils import download_url 17 | 18 | import minigpt4.common.utils as utils 19 | from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process 20 | from minigpt4.common.registry import registry 21 | from minigpt4.processors.base_processor import BaseProcessor 22 | 23 | 24 | 25 | class BaseDatasetBuilder: 26 | train_dataset_cls, eval_dataset_cls = None, None 27 | 28 | def __init__(self, cfg=None): 29 | super().__init__() 30 | 31 | if cfg is None: 32 | # help to create datasets from default config. 33 | self.config = load_dataset_config(self.default_config_path()) 34 | elif isinstance(cfg, str): 35 | self.config = load_dataset_config(cfg) 36 | else: 37 | # when called from task.build_dataset() 38 | self.config = cfg 39 | 40 | self.data_type = self.config.data_type 41 | 42 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 43 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 44 | 45 | def build_datasets(self): 46 | # download, split, etc... 47 | # only called on 1 GPU/TPU in distributed 48 | 49 | if is_main_process(): 50 | self._download_data() 51 | 52 | if is_dist_avail_and_initialized(): 53 | dist.barrier() 54 | 55 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
56 | logging.info("Building datasets...") 57 | datasets = self.build() # dataset['train'/'val'/'test'] 58 | 59 | return datasets 60 | 61 | def build_processors(self): 62 | vis_proc_cfg = self.config.get("vis_processor") 63 | txt_proc_cfg = self.config.get("text_processor") 64 | 65 | if vis_proc_cfg is not None: 66 | vis_train_cfg = vis_proc_cfg.get("train") 67 | vis_eval_cfg = vis_proc_cfg.get("eval") 68 | 69 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) 70 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) 71 | 72 | if txt_proc_cfg is not None: 73 | txt_train_cfg = txt_proc_cfg.get("train") 74 | txt_eval_cfg = txt_proc_cfg.get("eval") 75 | 76 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) 77 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) 78 | 79 | @staticmethod 80 | def _build_proc_from_cfg(cfg): 81 | return ( 82 | registry.get_processor_class(cfg.name).from_config(cfg) 83 | if cfg is not None 84 | else None 85 | ) 86 | 87 | @classmethod 88 | def default_config_path(cls, type="default"): 89 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) 90 | 91 | def _download_data(self): 92 | self._download_ann() 93 | self._download_vis() 94 | 95 | def _download_ann(self): 96 | """ 97 | Download annotation files if necessary. 98 | All the vision-language datasets should have annotations of unified format. 99 | 100 | storage_path can be: 101 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. 102 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. 103 | 104 | Local annotation paths should be relative. 
105 | """ 106 | anns = self.config.build_info.annotations 107 | 108 | splits = anns.keys() 109 | 110 | cache_root = registry.get_path("cache_root") 111 | 112 | for split in splits: 113 | info = anns[split] 114 | 115 | urls, storage_paths = info.get("url", None), info.storage 116 | 117 | if isinstance(urls, str): 118 | urls = [urls] 119 | if isinstance(storage_paths, str): 120 | storage_paths = [storage_paths] 121 | 122 | assert len(urls) == len(storage_paths) 123 | 124 | for url_or_filename, storage_path in zip(urls, storage_paths): 125 | # if storage_path is relative, make it full by prefixing with cache_root. 126 | if not os.path.isabs(storage_path): 127 | storage_path = os.path.join(cache_root, storage_path) 128 | 129 | dirname = os.path.dirname(storage_path) 130 | if not os.path.exists(dirname): 131 | os.makedirs(dirname) 132 | 133 | if os.path.isfile(url_or_filename): 134 | src, dst = url_or_filename, storage_path 135 | if not os.path.exists(dst): 136 | shutil.copyfile(src=src, dst=dst) 137 | else: 138 | logging.info("Using existing file {}.".format(dst)) 139 | else: 140 | if os.path.isdir(storage_path): 141 | # if only dirname is provided, suffix with basename of URL. 142 | raise ValueError( 143 | "Expecting storage_path to be a file path, got directory {}".format( 144 | storage_path 145 | ) 146 | ) 147 | else: 148 | filename = os.path.basename(storage_path) 149 | 150 | download_url(url=url_or_filename, root=dirname, filename=filename) 151 | 152 | def _download_vis(self): 153 | 154 | storage_path = self.config.build_info.get(self.data_type).storage 155 | storage_path = utils.get_cache_path(storage_path) 156 | 157 | if not os.path.exists(storage_path): 158 | warnings.warn( 159 | f""" 160 | The specified path {storage_path} for visual inputs does not exist. 161 | Please provide a correct path to the visual inputs or 162 | refer to datasets/download_scripts/README.md for downloading instructions. 
163 | """ 164 | ) 165 | 166 | def build(self): 167 | """ 168 | Create by split datasets inheriting torch.utils.data.Datasets. 169 | 170 | # build() can be dataset-specific. Overwrite to customize. 171 | """ 172 | self.build_processors() 173 | 174 | build_info = self.config.build_info 175 | 176 | ann_info = build_info.annotations 177 | vis_info = build_info.get(self.data_type) 178 | 179 | datasets = dict() 180 | for split in ann_info.keys(): 181 | if split not in ["train", "val", "test"]: 182 | continue 183 | 184 | is_train = split == "train" 185 | 186 | # processors 187 | vis_processor = ( 188 | self.vis_processors["train"] 189 | if is_train 190 | else self.vis_processors["eval"] 191 | ) 192 | text_processor = ( 193 | self.text_processors["train"] 194 | if is_train 195 | else self.text_processors["eval"] 196 | ) 197 | 198 | # annotation path 199 | ann_paths = ann_info.get(split).storage 200 | if isinstance(ann_paths, str): 201 | ann_paths = [ann_paths] 202 | 203 | abs_ann_paths = [] 204 | for ann_path in ann_paths: 205 | if not os.path.isabs(ann_path): 206 | ann_path = utils.get_cache_path(ann_path) 207 | abs_ann_paths.append(ann_path) 208 | ann_paths = abs_ann_paths 209 | 210 | # visual data storage path 211 | vis_path = os.path.join(vis_info.storage, split) 212 | 213 | if not os.path.isabs(vis_path): 214 | # vis_path = os.path.join(utils.get_cache_path(), vis_path) 215 | vis_path = utils.get_cache_path(vis_path) 216 | 217 | if not os.path.exists(vis_path): 218 | warnings.warn("storage path {} does not exist.".format(vis_path)) 219 | 220 | # create datasets 221 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls 222 | datasets[split] = dataset_cls( 223 | vis_processor=vis_processor, 224 | text_processor=text_processor, 225 | ann_paths=ann_paths, 226 | vis_root=vis_path, 227 | ) 228 | 229 | return datasets 230 | 231 | 232 | def load_dataset_config(cfg_path): 233 | cfg = OmegaConf.load(cfg_path).datasets 234 | cfg = cfg[list(cfg.keys())[0]] 
235 | 236 | return cfg 237 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/image_text_pair_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from minigpt4.common.registry import registry 6 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from minigpt4.datasets.datasets.laion_dataset import LaionDataset 8 | from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset 9 | from minigpt4.datasets.datasets.mimic_dataset import MIMICDataset, MIMICGenerateThenRefineDataset 10 | from minigpt4.datasets.datasets.iuxray_dataset import IUXRayGenerateThenRefineDataset 11 | 12 | 13 | @registry.register_builder("cc_sbu") 14 | class CCSBUBuilder(BaseDatasetBuilder): 15 | train_dataset_cls = CCSBUDataset 16 | 17 | DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"} 18 | 19 | def _download_ann(self): 20 | pass 21 | 22 | def _download_vis(self): 23 | pass 24 | 25 | def build(self): 26 | self.build_processors() 27 | 28 | build_info = self.config.build_info 29 | 30 | datasets = dict() 31 | split = "train" 32 | 33 | # create datasets 34 | # [NOTE] return inner_datasets (wds.DataPipeline) 35 | dataset_cls = self.train_dataset_cls 36 | datasets[split] = dataset_cls( 37 | vis_processor=self.vis_processors[split], 38 | text_processor=self.text_processors[split], 39 | location=build_info.storage, 40 | ).inner_dataset 41 | 42 | return datasets 43 | 44 | 45 | @registry.register_builder("laion") 46 | class LaionBuilder(BaseDatasetBuilder): 47 | train_dataset_cls = LaionDataset 48 | 49 | DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} 50 | 51 | def _download_ann(self): 52 | pass 53 | 54 | def _download_vis(self): 55 | pass 56 | 57 | def build(self): 58 | self.build_processors() 59 | 60 | build_info = self.config.build_info 
61 | 62 | datasets = dict() 63 | split = "train" 64 | 65 | # create datasets 66 | # [NOTE] return inner_datasets (wds.DataPipeline) 67 | dataset_cls = self.train_dataset_cls 68 | datasets[split] = dataset_cls( 69 | vis_processor=self.vis_processors[split], 70 | text_processor=self.text_processors[split], 71 | location=build_info.storage, 72 | ).inner_dataset 73 | 74 | return datasets 75 | 76 | 77 | @registry.register_builder("cc_sbu_align") 78 | class CCSBUAlignBuilder(BaseDatasetBuilder): 79 | train_dataset_cls = CCSBUAlignDataset 80 | 81 | DATASET_CONFIG_DICT = { 82 | "default": "configs/datasets/cc_sbu/align.yaml", 83 | } 84 | 85 | def build_datasets(self): 86 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 87 | logging.info("Building datasets...") 88 | self.build_processors() 89 | 90 | build_info = self.config.build_info 91 | storage_path = build_info.storage 92 | 93 | datasets = dict() 94 | 95 | if not os.path.exists(storage_path): 96 | warnings.warn("storage path {} does not exist.".format(storage_path)) 97 | 98 | # create datasets 99 | dataset_cls = self.train_dataset_cls 100 | datasets['train'] = dataset_cls( 101 | vis_processor=self.vis_processors["train"], 102 | text_processor=self.text_processors["train"], 103 | ann_paths=[os.path.join(storage_path, 'filter_cap.json')], 104 | vis_root=os.path.join(storage_path, 'image'), 105 | ) 106 | 107 | return datasets 108 | 109 | @registry.register_builder("mimic_align") 110 | class MIMICBuilder(BaseDatasetBuilder): 111 | train_dataset_cls = MIMICDataset 112 | 113 | DATASET_CONFIG_DICT = { 114 | "default": "minigpt4/configs/datasets/mimic/align.yaml", 115 | } 116 | 117 | def build_datasets(self): 118 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
119 | logging.info("Building datasets...") 120 | self.build_processors() 121 | 122 | build_info = self.config.build_info 123 | storage_path = build_info.storage 124 | 125 | datasets = dict() 126 | 127 | if not os.path.exists(storage_path): 128 | warnings.warn("storage path {} does not exist.".format(storage_path)) 129 | 130 | # create datasets 131 | dataset_cls = self.train_dataset_cls 132 | datasets['train'] = dataset_cls( 133 | vis_processor=self.vis_processors["train"], 134 | text_processor=self.text_processors["train"], 135 | ann_path=os.path.join(storage_path, 'annotation.json'), 136 | image_root=os.path.join(storage_path, 'images'), 137 | ) 138 | 139 | return datasets 140 | 141 | @registry.register_builder("mimic_generate_then_refine") 142 | class MIMICGenerateThenRefineBuilder(BaseDatasetBuilder): 143 | train_dataset_cls = MIMICGenerateThenRefineDataset 144 | 145 | DATASET_CONFIG_DICT = { 146 | "default": "minigpt4/configs/datasets/mimic/generate_then_refine.yaml", 147 | } 148 | 149 | def build_datasets(self): 150 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
151 | logging.info("Building datasets...") 152 | self.build_processors() 153 | 154 | build_info = self.config.build_info 155 | storage_path = build_info.storage 156 | unlabeled_annotation_path = build_info.unlabeled_annotation_path 157 | 158 | datasets = dict() 159 | 160 | if not os.path.exists(storage_path): 161 | warnings.warn("storage path {} does not exist.".format(storage_path)) 162 | 163 | # create datasets 164 | dataset_cls = self.train_dataset_cls 165 | datasets['train'] = dataset_cls( 166 | vis_processor=self.vis_processors["train"], 167 | text_processor=self.text_processors["train"], 168 | ann_path=os.path.join(storage_path, 'mimic_anno_with_ref.json'), 169 | image_root=os.path.join(storage_path, 'images'), 170 | unlabeled_ann_path=os.path.join(unlabeled_annotation_path, 'annotation.json'), 171 | ) 172 | 173 | return datasets 174 | 175 | @registry.register_builder("iuxray_generate_then_refine") 176 | class IUXRayGenerateThenRefineBuilder(BaseDatasetBuilder): 177 | train_dataset_cls = IUXRayGenerateThenRefineDataset 178 | 179 | DATASET_CONFIG_DICT = { 180 | "default": "minigpt4/configs/datasets/iuxray/generate_then_refine.yaml", 181 | } 182 | 183 | def build_datasets(self): 184 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
185 | logging.info("Building datasets...") 186 | self.build_processors() 187 | 188 | build_info = self.config.build_info 189 | storage_path = build_info.storage 190 | unlabeled_annotation_path = build_info.unlabeled_annotation_path 191 | 192 | datasets = dict() 193 | 194 | if not os.path.exists(storage_path): 195 | warnings.warn("storage path {} does not exist.".format(storage_path)) 196 | 197 | # create datasets 198 | dataset_cls = self.train_dataset_cls 199 | datasets['train'] = dataset_cls( 200 | vis_processor=self.vis_processors["train"], 201 | text_processor=self.text_processors["train"], 202 | ann_path=os.path.join(storage_path, 'annotation.json'), 203 | image_root=os.path.join(storage_path, 'images'), 204 | unlabeled_ann_path=os.path.join(unlabeled_annotation_path, 'annotation.json'), 205 | ) 206 | 207 | return datasets -------------------------------------------------------------------------------- /minigpt4/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import gzip 9 | import logging 10 | import os 11 | import random as rnd 12 | import tarfile 13 | import zipfile 14 | import random 15 | from typing import List 16 | from tqdm import tqdm 17 | 18 | import decord 19 | from decord import VideoReader 20 | import webdataset as wds 21 | import numpy as np 22 | import torch 23 | from torch.utils.data.dataset import IterableDataset 24 | 25 | from minigpt4.common.registry import registry 26 | from minigpt4.datasets.datasets.base_dataset import ConcatDataset 27 | 28 | 29 | decord.bridge.set_bridge("torch") 30 | MAX_INT = registry.get("MAX_INT") 31 | 32 | 33 | class ChainDataset(wds.DataPipeline): 34 | r"""Dataset for chaining multiple :class:`DataPipeline` s. 
35 | 36 | This class is useful to assemble different existing dataset streams. The 37 | chaining operation is done on-the-fly, so concatenating large-scale 38 | datasets with this class will be efficient. 39 | 40 | Args: 41 | datasets (iterable of IterableDataset): datasets to be chained together 42 | """ 43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None: 44 | super().__init__() 45 | self.datasets = datasets 46 | self.prob = [] 47 | self.names = [] 48 | for dataset in self.datasets: 49 | if hasattr(dataset, 'name'): 50 | self.names.append(dataset.name) 51 | else: 52 | self.names.append('Unknown') 53 | if hasattr(dataset, 'sample_ratio'): 54 | self.prob.append(dataset.sample_ratio) 55 | else: 56 | self.prob.append(1) 57 | logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") 58 | 59 | def __iter__(self): 60 | datastreams = [iter(dataset) for dataset in self.datasets] 61 | while True: 62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] 63 | yield next(select_datastream) 64 | 65 | 66 | def apply_to_sample(f, sample): 67 | if len(sample) == 0: 68 | return {} 69 | 70 | def _apply(x): 71 | if torch.is_tensor(x): 72 | return f(x) 73 | elif isinstance(x, dict): 74 | return {key: _apply(value) for key, value in x.items()} 75 | elif isinstance(x, list): 76 | return [_apply(x) for x in x] 77 | else: 78 | return x 79 | 80 | return _apply(sample) 81 | 82 | 83 | def move_to_cuda(sample): 84 | def _move_to_cuda(tensor): 85 | return tensor.cuda() 86 | 87 | return apply_to_sample(_move_to_cuda, sample) 88 | 89 | 90 | def prepare_sample(samples, cuda_enabled=True): 91 | if cuda_enabled: 92 | samples = move_to_cuda(samples) 93 | 94 | # TODO fp16 support 95 | 96 | return samples 97 | 98 | 99 | def reorg_datasets_by_split(datasets): 100 | """ 101 | Organizes datasets by split. 102 | 103 | Args: 104 | datasets: dict of torch.utils.data.Dataset objects by name. 
105 | 106 | Returns: 107 | Dict of datasets by split {split_name: List[Datasets]}. 108 | """ 109 | # if len(datasets) == 1: 110 | # return datasets[list(datasets.keys())[0]] 111 | # else: 112 | reorg_datasets = dict() 113 | 114 | # reorganize by split 115 | for _, dataset in datasets.items(): 116 | for split_name, dataset_split in dataset.items(): 117 | if split_name not in reorg_datasets: 118 | reorg_datasets[split_name] = [dataset_split] 119 | else: 120 | reorg_datasets[split_name].append(dataset_split) 121 | 122 | return reorg_datasets 123 | 124 | 125 | def concat_datasets(datasets): 126 | """ 127 | Concatenates multiple datasets into a single dataset. 128 | 129 | It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support 130 | generic IterableDataset because it requires creating separate samplers. 131 | 132 | Now only supports conctenating training datasets and assuming validation and testing 133 | have only a single dataset. This is because metrics should not be computed on the concatenated 134 | datasets. 135 | 136 | Args: 137 | datasets: dict of torch.utils.data.Dataset objects by split. 138 | 139 | Returns: 140 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, 141 | "val" and "test" remain the same. 142 | 143 | If the input training datasets contain both map-style and DataPipeline datasets, returns 144 | a tuple, where the first element is a concatenated map-style dataset and the second 145 | element is a chained DataPipeline dataset. 
146 | 147 | """ 148 | # concatenate datasets in the same split 149 | for split_name in datasets: 150 | if split_name != "train": 151 | assert ( 152 | len(datasets[split_name]) == 1 153 | ), "Do not support multiple {} datasets.".format(split_name) 154 | datasets[split_name] = datasets[split_name][0] 155 | else: 156 | iterable_datasets, map_datasets = [], [] 157 | for dataset in datasets[split_name]: 158 | if isinstance(dataset, wds.DataPipeline): 159 | logging.info( 160 | "Dataset {} is IterableDataset, can't be concatenated.".format( 161 | dataset 162 | ) 163 | ) 164 | iterable_datasets.append(dataset) 165 | elif isinstance(dataset, IterableDataset): 166 | raise NotImplementedError( 167 | "Do not support concatenation of generic IterableDataset." 168 | ) 169 | else: 170 | map_datasets.append(dataset) 171 | 172 | # if len(iterable_datasets) > 0: 173 | # concatenate map-style datasets and iterable-style datasets separately 174 | if len(iterable_datasets) > 1: 175 | chained_datasets = ( 176 | ChainDataset(iterable_datasets) 177 | ) 178 | elif len(iterable_datasets) == 1: 179 | chained_datasets = iterable_datasets[0] 180 | else: 181 | chained_datasets = None 182 | 183 | concat_datasets = ( 184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None 185 | ) 186 | 187 | train_datasets = concat_datasets, chained_datasets 188 | train_datasets = tuple([x for x in train_datasets if x is not None]) 189 | train_datasets = ( 190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets 191 | ) 192 | 193 | datasets[split_name] = train_datasets 194 | 195 | return datasets 196 | 197 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__init__.py 
-------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-311.pyc 
-------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-311.pyc 
-------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/iuxray_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/datasets/datasets/__pycache__/mimic_dataset.cpython-39.pyc 
# ---- /minigpt4/datasets/datasets/base_dataset.py ----
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
from typing import Iterable

from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate


class BaseDataset(Dataset):
    """Map-style dataset whose samples come from JSON annotation files.

    Every file in ``ann_paths`` must hold a JSON object with an
    ``"annotations"`` list; the lists are concatenated into
    ``self.annotation`` and each entry is tagged with a string
    ``"instance_id"`` equal to its position. Subclasses provide
    ``__getitem__``.
    """

    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=None
    ):
        """
        vis_processor: callable applied to images by subclasses (may be None).
        text_processor: callable applied to captions by subclasses (may be None).
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (iterable of string): annotation JSON files to load. ``None``
            (the default) loads nothing and avoids a shared mutable default.
        """
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths or ():
            # ``with`` closes each handle promptly instead of leaving it to the GC.
            with open(ann_path, "r") as f:
                self.annotation.extend(json.load(f)["annotations"])

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        """Batch a list of sample dicts with PyTorch's ``default_collate``."""
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        """Replace the image and text processors after construction."""
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        """Store each annotation's index, as a string, under ``key``."""
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)


class ConcatDataset(ConcatDataset):
    """``torch.utils.data.ConcatDataset`` that additionally exposes ``collater``."""

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        """Collate samples that may come from different member datasets.

        Only keys present in *every* sample are kept, then batching is
        delegated to the first member dataset's ``collater``.
        """
        # TODO For now only supports datasets with same underlying collater implementations
        shared_keys = set.intersection(*map(set, samples)) if samples else set()
        samples_shared_keys = [
            {k: v for k, v in s.items() if k in shared_keys} for s in samples
        ]
        return self.datasets[0].collater(samples_shared_keys)


# ---- /minigpt4/datasets/datasets/caption_datasets.py ----
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from PIL import Image


class __DisplMixin:
    """Mixin adding a human-readable view of one sample."""

    def displ_item(self, index):
        """Return the file name, raw caption and processed image for ``index``."""
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )


class CaptionDataset(BaseDataset, __DisplMixin):
    """Image-caption training dataset with ids mapped to dense integers."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): annotation JSON files (see ``BaseDataset``)
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        # Map each raw image id to a dense integer in first-seen order.
        self.img_ids = {}
        for ann in self.annotation:
            self.img_ids.setdefault(ann["image_id"], len(self.img_ids))

    def __getitem__(self, index):
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        # File name is the image id left-padded with '0' to width 12.
        img_file = '{:0>12}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

        return {
            "image": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }


class CaptionEvalDataset(BaseDataset, __DisplMixin):
    """Image-only evaluation dataset; image paths come from ``ann["image"]``."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): annotation JSON files (see ``BaseDataset``)
        split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        image = self.vis_processor(image)

        return {
            "image": image,
            "image_id": ann["image_id"],
            "instance_id": ann["instance_id"],
        }


# ---- /minigpt4/datasets/datasets/cc_sbu_dataset.py ----
import os
from PIL import Image
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class CCSBUDataset(BaseDataset):
    """WebDataset stream of (jpg, json caption) pairs read from tar shards at ``location``."""

    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        # Every stage uses ``warn_and_continue`` so that a corrupt sample is
        # logged and skipped rather than aborting the stream.
        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        """Turn an ``(image, json)`` tuple into the model's input dict."""
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }


class CCSBUAlignDataset(CaptionDataset):
    """Caption dataset whose images are named ``<image_id>.jpg`` (no padding)."""

    def __getitem__(self, index):
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        image = self.vis_processor(image)
        # Unlike CaptionDataset, the caption is passed through unprocessed.
        caption = ann["caption"]

        return {
            "image": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }


# ---- /minigpt4/datasets/datasets/dataloader_utils.py ----
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch
from minigpt4.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader


class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Args:
        loaders (List[Loader]): List of Iterator loaders.
        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
    """

    def __init__(self, loaders, ratios=None):
        # assert all loaders has __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            total = sum(ratios)
            ratios = [float(ratio) / total for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        # Pick one loader at random, weighted by ``ratios``, and draw from it.
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
        return next(self.loaders[loader_idx])


class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            if isinstance(batch, tuple):
                # ``(task, batch)`` pairs are unpacked and re-yielded as a pair.
                task, batch = batch
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        """Fetch the next batch from ``it`` and start its GPU copy on ``self.stream``."""
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # Issue the host->device copy on the side stream so it can overlap
        # with work on the current stream; ``next`` waits on it before use.
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)

    def next(self, it):
        """Return the prefetched batch (None when exhausted) and prefetch the following one."""
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        # Only called for attributes missing on the wrapper; forward them to
        # the wrapped loader. ``loader`` itself must not be forwarded, or a
        # half-initialised instance (e.g. during copy/unpickling) recurses forever.
        if name == "loader":
            raise AttributeError(name)
        return getattr(self.loader, name)


def record_cuda_stream(batch):
    """Recursively mark every tensor in ``batch`` as used by the current CUDA stream.

    Lists, tuples and dict values are traversed; other objects are ignored.
    """
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, (list, tuple)):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)


class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        """Number of times the wrapped DataLoader has been exhausted and restarted."""
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # End of an epoch: bump the counter, let a sampler exposing
            # ``set_epoch`` (distributed runs only) know, then restart.
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)


# ---- /minigpt4/datasets/datasets/iuxray_dataset.py ----
import os
import json
import re
from PIL import Image
import webdataset as wds
import random
from torch.utils.data import Dataset
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset

# Characters deleted from every sentence. A raw string avoids the invalid
# ``\[`` escape warning; the pattern itself is unchanged. NOTE: ``:-\[`` is a
# character *range* (':' through '['), so ';<=>?@' are removed as well.
_SENT_STRIP_RE = re.compile(r'[.,?;*!%^&_+():-\[\]{}]')

# Removal of "1. ", "2. ", ... list numbering; shared by both cleaners.
_ENUMERATION_REPLACEMENTS = (
    ('1. ', ''),
    ('. 2. ', '. '), ('. 3. ', '. '), ('. 4. ', '. '), ('. 5. ', '. '),
    (' 2. ', '. '), (' 3. ', '. '), (' 4. ', '. '), (' 5. ', '. '),
)

# Ordered (old, new) replacements of the MIMIC-style cleaner. Pairs are
# repeated so that runs of '_', spaces and '.' collapse within a fixed
# number of passes (7, 6 and 8 respectively).
_MIMIC_STYLE_REPLACEMENTS = (
    (('\n', ' '),)
    + (('__', '_'),) * 7
    + (('  ', ' '),) * 6
    + (('..', '.'),) * 8
    + _ENUMERATION_REPLACEMENTS
)

# Ordered replacements of the IU X-Ray cleaner: only '.' runs and numbering.
_IU_XRAY_REPLACEMENTS = (('..', '.'),) * 3 + _ENUMERATION_REPLACEMENTS


def _clean_report(report, replacements):
    """Normalise a free-text report into lower-case sentences joined by ' . '.

    ``replacements`` are applied in order; the text is then stripped,
    lower-cased and split on '. '. Each sentence loses double quotes,
    slashes, backslashes and every character matched by ``_SENT_STRIP_RE``.
    The result always ends with ' .'.
    """
    for old, new in replacements:
        report = report.replace(old, new)
    tokens = [
        _SENT_STRIP_RE.sub(
            '',
            sent.replace('"', '').replace('/', '').replace('\\', '')
            .replace("'", '').strip().lower(),
        )
        for sent in report.strip().lower().split('. ')
    ]
    # NOTE(review): the previous filter ``sent_cleaner(sent) != []`` compared
    # a str with a list and was always true, so empty sentences are kept (as
    # in the MIMIC cleaners in this repository). Filter empty tokens here if
    # dropping them was the intent.
    return ' . '.join(tokens) + ' .'


class IUXRayDataset(Dataset):
    """IU X-Ray training split: the first image of each study plus its cleaned report."""

    def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None):
        """
        vis_processor: callable applied to the loaded PIL image.
        text_processor: stored on the instance; ``__getitem__`` does not use it.
        image_root (string): directory that ``image_path`` entries are relative to.
        ann_path (string): JSON file whose ``"train"`` list holds dicts with
            ``"id"``, ``"image_path"`` (a list) and ``"report"`` keys.
        """
        self.image_root = image_root
        self.ann_path = ann_path

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # load annotation file
        with open(ann_path, 'r') as f:
            self.annotations = json.load(f)
        self.train_data = self.annotations['train']

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        data_sample = self.train_data[index]
        image_path = data_sample['image_path']

        # load image (only the first view of the study is used)
        image_id = data_sample['id']
        with Image.open(os.path.join(self.image_root, image_path[0])) as img:
            image = img.convert('RGB')
        image = self.vis_processor(image)

        # load caption
        caption = self.clean_reports(data_sample['report'])

        return {"image": image,
                "text_input": caption,
                "image_id": image_id}

    def clean_reports(self, report):
        """Clean ``report`` with the MIMIC-style rules (newlines, '_', spaces, '.', numbering)."""
        return _clean_report(report, _MIMIC_STYLE_REPLACEMENTS)


class IUXRayGenerateThenRefineDataset(Dataset):
    """IU X-Ray samples with a reference report and randomly drawn unlabeled reports."""

    def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None, unlabeled_ann_path=None, retrieval_size=3):
        """
        vis_processor: callable applied to the loaded PIL image.
        text_processor: stored on the instance; ``__getitem__`` does not use it.
        image_root (string): directory that ``image_path`` entries are relative to.
        ann_path (string): JSON file whose ``"train"`` list holds dicts with
            ``"id"``, ``"image_path"``, ``"report"`` and ``"ref_report"`` keys.
        unlabeled_ann_path (string): text file with one unlabeled report per line.
        retrieval_size (int): number of unlabeled reports sampled per item.
        """
        self.image_root = image_root
        self.ann_path = ann_path
        self.retrieval_size = retrieval_size

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # load annotation file
        with open(ann_path, 'r') as f:
            self.annotations = json.load(f)
        self.train_data = self.annotations['train']

        # load unlabeled data: iterate the file object itself, one report per line
        with open(unlabeled_ann_path, 'r') as f:
            self.unlabeled_data_list = [line.strip('\n') for line in f]

        print(f"There are total {len(self.unlabeled_data_list)} unlabeled reports.")

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        data_sample = self.train_data[index]
        image_path = data_sample['image_path']

        # load image (only the first view of the study is used)
        image_id = data_sample['id']
        with Image.open(os.path.join(self.image_root, image_path[0])) as img:
            image = img.convert('RGB')
        image = self.vis_processor(image)

        # load target and reference captions
        caption = self.clean_reports(data_sample['report'])
        ref_caption = self.clean_reports(data_sample['ref_report'])

        # Fresh random draw on every access; random.sample raises ValueError
        # if fewer than ``retrieval_size`` unlabeled reports were loaded.
        unlabeled_caption = random.sample(self.unlabeled_data_list, self.retrieval_size)

        return {"image": image,
                "text_input": caption,
                "ref_caption": ref_caption,
                "unlabeled_caption": unlabeled_caption,
                "image_id": image_id}

    def clean_reports(self, report):
        """Clean ``report`` with the IU X-Ray rules.

        ``__getitem__`` relies on this name, so it must exist on this class;
        it delegates to ``clean_report_iu_xray``.
        """
        return self.clean_report_iu_xray(report)

    def clean_report_iu_xray(self, report):
        """Collapse '.' runs, drop list numbering and return ' . '-joined sentences."""
        return _clean_report(report, _IU_XRAY_REPLACEMENTS)


# ---- /minigpt4/datasets/datasets/laion_dataset.py ----
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset


class LaionDataset(BaseDataset):
    """WebDataset stream of (jpg, json caption) pairs read from tar shards at ``location``."""

    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        # Every stage uses ``warn_and_continue`` so that a corrupt sample is
        # logged and skipped rather than aborting the stream.
        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        """Turn an ``(image, json)`` tuple into the model's input dict."""
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }


# ---- /minigpt4/datasets/datasets/mimic_dataset.py ----
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | from PIL import Image 5 | import webdataset as wds 6 | import random 7 | from torch.utils.data import Dataset 8 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 9 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset 10 | 11 | 12 | class MIMICDataset(Dataset): 13 | def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None): 14 | self.image_root = image_root 15 | self.ann_path = ann_path 16 | 17 | self.vis_processor = vis_processor 18 | self.text_processor = text_processor 19 | 20 | # load annotation file 21 | with open(ann_path, 'r') as f: 22 | self.annotations = json.load(f) 23 | self.train_data = self.annotations['train'] 24 | 25 | def __len__(self): 26 | return len(self.train_data) 27 | 28 | def __getitem__(self, index): 29 | data_sample = self.train_data[index] 30 | image_path = data_sample['image_path'] 31 | 32 | # load image 33 | image_id = data_sample['id'] 34 | image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB') 35 | image = self.vis_processor(image) 36 | 37 | # load caption 38 | caption = data_sample['report'] 39 | caption = self.clean_reports(caption) 40 | 41 | return {"image": image, 42 | "text_input": caption, 43 | "image_id": image_id} 44 | 45 | def clean_reports(self, report): 46 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ 47 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ 48 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ 49 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ 50 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ 51 | .replace('. 3. ', '. ').replace('. 
4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ 52 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ 53 | .strip().lower().split('. ') 54 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') 55 | .replace('\\', '').replace("'", '').strip().lower()) 56 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] 57 | report = ' . '.join(tokens) + ' .' 58 | return report 59 | 60 | class MIMICGenerateThenRefineDataset(Dataset): 61 | def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None, unlabeled_ann_path=None, retrieval_size=3): 62 | self.image_root = image_root 63 | self.ann_path = ann_path 64 | self.retrieval_size = retrieval_size 65 | 66 | self.vis_processor = vis_processor 67 | self.text_processor = text_processor 68 | 69 | # load annotation file 70 | with open(ann_path, 'r') as f: 71 | self.annotations = json.load(f) 72 | self.train_data = self.annotations['train'] 73 | 74 | # load unlabeled data 75 | self.unlabeled_data_list = [] 76 | with open(unlabeled_ann_path, 'r') as f: 77 | for line in f.readlines: 78 | self.unlabeled_data_list.append(line.strip('\n')) 79 | 80 | import random 81 | self.unlabeled_data_list = random.sample(self.unlabeled_data_list, 3000) 82 | 83 | print(f"There are total {len(self.unlabeled_data_list)} unlabeled reports.") 84 | 85 | def __len__(self): 86 | return len(self.train_data) 87 | 88 | def __getitem__(self, index): 89 | data = self.train_data[index] 90 | data_samples = random.sample(self.train_data, self.retrieval_size - 1) 91 | image_path = data['image_path'] 92 | 93 | # load image 94 | image_id = data['id'] 95 | image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB') 96 | image = self.vis_processor(image) 97 | 98 | # load caption 99 | caption = data['report'] 100 | caption = self.clean_reports(caption) 101 | 102 | # load reference caption 103 | 
all_ref_captions = [] 104 | ref_caption = data['ref_report'] 105 | ref_caption = self.clean_reports(ref_caption) 106 | all_ref_captions.append(ref_caption) 107 | 108 | for data_sample in data_samples: 109 | ref_caption = data_sample['ref_report'] 110 | ref_caption = self.clean_reports(ref_caption) 111 | all_ref_captions.append(ref_caption) 112 | 113 | # load unlabeled caption 114 | unlabeled_caption = random.sample(self.unlabeled_data_list, self.retrieval_size) 115 | 116 | return {"image": image, 117 | "text_input": caption, 118 | "ref_caption": ref_caption, 119 | "unlabeled_caption": unlabeled_caption, 120 | "image_id": image_id} 121 | 122 | def clean_reports(self, report): 123 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ 124 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ 125 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ 126 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ 127 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ 128 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ 129 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ 130 | .strip().lower().split('. ') 131 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') 132 | .replace('\\', '').replace("'", '').strip().lower()) 133 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] 134 | report = ' . '.join(tokens) + ' .' 
135 | return report 136 | 137 | -------------------------------------------------------------------------------- /minigpt4/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import torch 10 | from omegaconf import OmegaConf 11 | 12 | from minigpt4.common.registry import registry 13 | from minigpt4.models.base_model import BaseModel 14 | from minigpt4.models.blip2 import Blip2Base 15 | from minigpt4.models.mini_gpt4 import MiniGPT4 16 | from minigpt4.processors.base_processor import BaseProcessor 17 | 18 | 19 | __all__ = [ 20 | "load_model", 21 | "BaseModel", 22 | "Blip2Base", 23 | "MiniGPT4", 24 | ] 25 | 26 | 27 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): 28 | """ 29 | Load supported models. 30 | 31 | To list all available models and types in registry: 32 | >>> from minigpt4.models import model_zoo 33 | >>> print(model_zoo) 34 | 35 | Args: 36 | name (str): name of the model. 37 | model_type (str): type of the model. 38 | is_eval (bool): whether the model is in eval mode. Default: False. 39 | device (str): device to use. Default: "cpu". 40 | checkpoint (str): path or to checkpoint. Default: None. 41 | Note that expecting the checkpoint to have the same keys in state_dict as the model. 42 | 43 | Returns: 44 | model (torch.nn.Module): model. 
45 | """ 46 | 47 | model = registry.get_model_class(name).from_pretrained(model_type=model_type) 48 | 49 | if checkpoint is not None: 50 | model.load_checkpoint(checkpoint) 51 | 52 | if is_eval: 53 | model.eval() 54 | 55 | if device == "cpu": 56 | model = model.float() 57 | 58 | return model.to(device) 59 | 60 | 61 | def load_preprocess(config): 62 | """ 63 | Load preprocessor configs and construct preprocessors. 64 | 65 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. 66 | 67 | Args: 68 | config (dict): preprocessor configs. 69 | 70 | Returns: 71 | vis_processors (dict): preprocessors for visual inputs. 72 | txt_processors (dict): preprocessors for text inputs. 73 | 74 | Key is "train" or "eval" for processors used in training and evaluation respectively. 75 | """ 76 | 77 | def _build_proc_from_cfg(cfg): 78 | return ( 79 | registry.get_processor_class(cfg.name).from_config(cfg) 80 | if cfg is not None 81 | else BaseProcessor() 82 | ) 83 | 84 | vis_processors = dict() 85 | txt_processors = dict() 86 | 87 | vis_proc_cfg = config.get("vis_processor") 88 | txt_proc_cfg = config.get("text_processor") 89 | 90 | if vis_proc_cfg is not None: 91 | vis_train_cfg = vis_proc_cfg.get("train") 92 | vis_eval_cfg = vis_proc_cfg.get("eval") 93 | else: 94 | vis_train_cfg = None 95 | vis_eval_cfg = None 96 | 97 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) 98 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) 99 | 100 | if txt_proc_cfg is not None: 101 | txt_train_cfg = txt_proc_cfg.get("train") 102 | txt_eval_cfg = txt_proc_cfg.get("eval") 103 | else: 104 | txt_train_cfg = None 105 | txt_eval_cfg = None 106 | 107 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) 108 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) 109 | 110 | return vis_processors, txt_processors 111 | 112 | 113 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): 114 | """ 115 | Load 
model and its related preprocessors. 116 | 117 | List all available models and types in registry: 118 | >>> from minigpt4.models import model_zoo 119 | >>> print(model_zoo) 120 | 121 | Args: 122 | name (str): name of the model. 123 | model_type (str): type of the model. 124 | is_eval (bool): whether the model is in eval mode. Default: False. 125 | device (str): device to use. Default: "cpu". 126 | 127 | Returns: 128 | model (torch.nn.Module): model. 129 | vis_processors (dict): preprocessors for visual inputs. 130 | txt_processors (dict): preprocessors for text inputs. 131 | """ 132 | model_cls = registry.get_model_class(name) 133 | 134 | # load model 135 | model = model_cls.from_pretrained(model_type=model_type) 136 | 137 | if is_eval: 138 | model.eval() 139 | 140 | # load preprocess 141 | cfg = OmegaConf.load(model_cls.default_config_path(model_type)) 142 | if cfg is not None: 143 | preprocess_cfg = cfg.preprocess 144 | 145 | vis_processors, txt_processors = load_preprocess(preprocess_cfg) 146 | else: 147 | vis_processors, txt_processors = None, None 148 | logging.info( 149 | f"""No default preprocess for model {name} ({model_type}). 150 | This can happen if the model is not finetuned on downstream datasets, 151 | or it is not intended for direct use without finetuning. 152 | """ 153 | ) 154 | 155 | if device == "cpu" or device == torch.device("cpu"): 156 | model = model.float() 157 | 158 | return model.to(device), vis_processors, txt_processors 159 | 160 | 161 | class ModelZoo: 162 | """ 163 | A utility class to create string representation of available model architectures and types. 
164 | 165 | >>> from minigpt4.models import model_zoo 166 | >>> # list all available models 167 | >>> print(model_zoo) 168 | >>> # show total number of models 169 | >>> print(len(model_zoo)) 170 | """ 171 | 172 | def __init__(self) -> None: 173 | self.model_zoo = { 174 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) 175 | for k, v in registry.mapping["model_name_mapping"].items() 176 | } 177 | 178 | def __str__(self) -> str: 179 | return ( 180 | "=" * 50 181 | + "\n" 182 | + f"{'Architectures':<30} {'Types'}\n" 183 | + "=" * 50 184 | + "\n" 185 | + "\n".join( 186 | [ 187 | f"{name:<30} {', '.join(types)}" 188 | for name, types in self.model_zoo.items() 189 | ] 190 | ) 191 | ) 192 | 193 | def __iter__(self): 194 | return iter(self.model_zoo.items()) 195 | 196 | def __len__(self): 197 | return sum([len(v) for v in self.model_zoo.values()]) 198 | 199 | 200 | model_zoo = ModelZoo() 201 | -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/Qformer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/Qformer.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/Qformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/Qformer.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/__init__.cpython-311.pyc 
-------------------------------------------------------------------------------- /minigpt4/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/base_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/base_model.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/base_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/base_model.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/blip2.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/blip2.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/blip2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/blip2.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/eva_vit.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/eva_vit.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/eva_vit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/eva_vit.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/mini_gpt4.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/mini_gpt4.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/mini_gpt4.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/mini_gpt4.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/modeling_llama.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/modeling_llama.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/models/__pycache__/modeling_llama.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/models/__pycache__/modeling_llama.cpython-39.pyc 
-------------------------------------------------------------------------------- /minigpt4/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized 15 | from minigpt4.common.utils import get_abs_path, is_url 16 | from omegaconf import OmegaConf 17 | 18 | 19 | class BaseModel(nn.Module): 20 | """Base class for models.""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | @property 26 | def device(self): 27 | return list(self.parameters())[0].device 28 | 29 | def load_checkpoint(self, url_or_filename): 30 | """ 31 | Load from a finetuned checkpoint. 32 | 33 | This should expect no mismatch in the model keys and the checkpoint keys. 
34 | """ 35 | 36 | if is_url(url_or_filename): 37 | cached_file = download_cached_file( 38 | url_or_filename, check_hash=False, progress=True 39 | ) 40 | checkpoint = torch.load(cached_file, map_location="cpu") 41 | elif os.path.isfile(url_or_filename): 42 | checkpoint = torch.load(url_or_filename, map_location="cpu") 43 | else: 44 | raise RuntimeError("checkpoint url or path is invalid") 45 | 46 | if "model" in checkpoint.keys(): 47 | state_dict = checkpoint["model"] 48 | else: 49 | state_dict = checkpoint 50 | 51 | msg = self.load_state_dict(state_dict, strict=False) 52 | 53 | logging.info("Missing keys {}".format(msg.missing_keys)) 54 | logging.info("load checkpoint from %s" % url_or_filename) 55 | 56 | return msg 57 | 58 | @classmethod 59 | def from_pretrained(cls, model_type): 60 | """ 61 | Build a pretrained model from default configuration file, specified by model_type. 62 | 63 | Args: 64 | - model_type (str): model type, specifying architecture and checkpoints. 65 | 66 | Returns: 67 | - model (nn.Module): pretrained or finetuned model, depending on the configuration. 68 | """ 69 | model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model 70 | model = cls.from_config(model_cfg) 71 | 72 | return model 73 | 74 | @classmethod 75 | def default_config_path(cls, model_type): 76 | assert ( 77 | model_type in cls.PRETRAINED_MODEL_CONFIG_DICT 78 | ), "Unknown model type {}".format(model_type) 79 | return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) 80 | 81 | def load_checkpoint_from_config(self, cfg, **kwargs): 82 | """ 83 | Load checkpoint as specified in the config file. 84 | 85 | If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. 86 | When loading the pretrained model, each task-specific architecture may define their 87 | own load_from_pretrained() method. 
88 | """ 89 | load_finetuned = cfg.get("load_finetuned", True) 90 | if load_finetuned: 91 | finetune_path = cfg.get("finetuned", None) 92 | assert ( 93 | finetune_path is not None 94 | ), "Found load_finetuned is True, but finetune_path is None." 95 | self.load_checkpoint(url_or_filename=finetune_path) 96 | else: 97 | # load pre-trained weights 98 | pretrain_path = cfg.get("pretrained", None) 99 | assert "Found load_finetuned is False, but pretrain_path is None." 100 | self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) 101 | 102 | def before_evaluation(self, **kwargs): 103 | pass 104 | 105 | def show_n_params(self, return_str=True): 106 | tot = 0 107 | for p in self.parameters(): 108 | w = 1 109 | for x in p.shape: 110 | w *= x 111 | tot += w 112 | if return_str: 113 | if tot >= 1e6: 114 | return "{:.1f}M".format(tot / 1e6) 115 | else: 116 | return "{:.1f}K".format(tot / 1e3) 117 | else: 118 | return tot 119 | 120 | 121 | class BaseEncoder(nn.Module): 122 | """ 123 | Base class for primitive encoders, such as ViT, TimeSformer, etc. 
124 | """ 125 | 126 | def __init__(self): 127 | super().__init__() 128 | 129 | def forward_features(self, samples, **kwargs): 130 | raise NotImplementedError 131 | 132 | @property 133 | def device(self): 134 | return list(self.parameters())[0].device 135 | 136 | 137 | class SharedQueueMixin: 138 | @torch.no_grad() 139 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): 140 | # gather keys before updating queue 141 | image_feats = concat_all_gather(image_feat) 142 | text_feats = concat_all_gather(text_feat) 143 | 144 | batch_size = image_feats.shape[0] 145 | 146 | ptr = int(self.queue_ptr) 147 | assert self.queue_size % batch_size == 0 # for simplicity 148 | 149 | # replace the keys at ptr (dequeue and enqueue) 150 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T 151 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T 152 | 153 | if idxs is not None: 154 | idxs = concat_all_gather(idxs) 155 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T 156 | 157 | ptr = (ptr + batch_size) % self.queue_size # move pointer 158 | self.queue_ptr[0] = ptr 159 | 160 | 161 | class MomentumDistilationMixin: 162 | @torch.no_grad() 163 | def copy_params(self): 164 | for model_pair in self.model_pairs: 165 | for param, param_m in zip( 166 | model_pair[0].parameters(), model_pair[1].parameters() 167 | ): 168 | param_m.data.copy_(param.data) # initialize 169 | param_m.requires_grad = False # not update by gradient 170 | 171 | @torch.no_grad() 172 | def _momentum_update(self): 173 | for model_pair in self.model_pairs: 174 | for param, param_m in zip( 175 | model_pair[0].parameters(), model_pair[1].parameters() 176 | ): 177 | param_m.data = param_m.data * self.momentum + param.data * ( 178 | 1.0 - self.momentum 179 | ) 180 | 181 | 182 | class GatherLayer(torch.autograd.Function): 183 | """ 184 | Gather tensors from all workers with support for backward propagation: 185 | This implementation does not cut the gradients as torch.distributed.all_gather does. 
186 | """ 187 | 188 | @staticmethod 189 | def forward(ctx, x): 190 | output = [ 191 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 192 | ] 193 | torch.distributed.all_gather(output, x) 194 | return tuple(output) 195 | 196 | @staticmethod 197 | def backward(ctx, *grads): 198 | all_gradients = torch.stack(grads) 199 | torch.distributed.all_reduce(all_gradients) 200 | return all_gradients[torch.distributed.get_rank()] 201 | 202 | 203 | def all_gather_with_grad(tensors): 204 | """ 205 | Performs all_gather operation on the provided tensors. 206 | Graph remains connected for backward grad computation. 207 | """ 208 | # Queue the gathered tensors 209 | world_size = torch.distributed.get_world_size() 210 | # There is no need for reduction in the single-proc case 211 | if world_size == 1: 212 | return tensors 213 | 214 | # tensor_all = GatherLayer.apply(tensors) 215 | tensor_all = GatherLayer.apply(tensors) 216 | 217 | return torch.cat(tensor_all, dim=0) 218 | 219 | 220 | @torch.no_grad() 221 | def concat_all_gather(tensor): 222 | """ 223 | Performs all_gather operation on the provided tensors. 224 | *** Warning ***: torch.distributed.all_gather has no gradient. 
225 | """ 226 | # if use distributed training 227 | if not is_dist_avail_and_initialized(): 228 | return tensor 229 | 230 | tensors_gather = [ 231 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 232 | ] 233 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 234 | 235 | output = torch.cat(tensors_gather, dim=0) 236 | return output 237 | 238 | 239 | def tile(x, dim, n_tile): 240 | init_dim = x.size(dim) 241 | repeat_idx = [1] * x.dim() 242 | repeat_idx[dim] = n_tile 243 | x = x.repeat(*(repeat_idx)) 244 | order_index = torch.LongTensor( 245 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 246 | ) 247 | return torch.index_select(x, dim, order_index.to(x.device)) 248 | -------------------------------------------------------------------------------- /minigpt4/models/blip2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import contextlib 8 | import logging 9 | import os 10 | import time 11 | import datetime 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.distributed as dist 16 | import torch.nn.functional as F 17 | 18 | import minigpt4.common.dist_utils as dist_utils 19 | from minigpt4.common.dist_utils import download_cached_file 20 | from minigpt4.common.utils import is_url 21 | from minigpt4.common.logger import MetricLogger 22 | from minigpt4.models.base_model import BaseModel 23 | from minigpt4.models.Qformer import BertConfig, BertLMHeadModel 24 | from minigpt4.models.eva_vit import create_eva_vit_g 25 | from transformers import BertTokenizer 26 | 27 | 28 | class Blip2Base(BaseModel): 29 | @classmethod 30 | def init_tokenizer(cls): 31 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 32 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 33 | return tokenizer 34 | 35 | def maybe_autocast(self, dtype=torch.float16): 36 | # if on cpu, don't use autocast 37 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 38 | enable_autocast = self.device != torch.device("cpu") 39 | 40 | if enable_autocast: 41 | return torch.cuda.amp.autocast(dtype=dtype) 42 | else: 43 | return contextlib.nullcontext() 44 | 45 | @classmethod 46 | def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): 47 | encoder_config = BertConfig.from_pretrained("bert-base-uncased") 48 | encoder_config.encoder_width = vision_width 49 | # insert cross-attention layer every other block 50 | encoder_config.add_cross_attention = True 51 | encoder_config.cross_attention_freq = cross_attention_freq 52 | encoder_config.query_length = num_query_token 53 | Qformer = BertLMHeadModel(config=encoder_config) 54 | query_tokens = nn.Parameter( 55 | torch.zeros(1, 
num_query_token, encoder_config.hidden_size) 56 | ) 57 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 58 | return Qformer, query_tokens 59 | 60 | @classmethod 61 | def init_vision_encoder( 62 | cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision 63 | ): 64 | assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" 65 | visual_encoder = create_eva_vit_g( 66 | img_size, drop_path_rate, use_grad_checkpoint, precision 67 | ) 68 | 69 | ln_vision = LayerNorm(visual_encoder.num_features) 70 | return visual_encoder, ln_vision 71 | 72 | def load_from_pretrained(self, url_or_filename): 73 | if is_url(url_or_filename): 74 | cached_file = download_cached_file( 75 | url_or_filename, check_hash=False, progress=True 76 | ) 77 | checkpoint = torch.load(cached_file, map_location="cpu") 78 | elif os.path.isfile(url_or_filename): 79 | checkpoint = torch.load(url_or_filename, map_location="cpu") 80 | else: 81 | raise RuntimeError("checkpoint url or path is invalid") 82 | 83 | state_dict = checkpoint["model"] 84 | 85 | msg = self.load_state_dict(state_dict, strict=False) 86 | 87 | # logging.info("Missing keys {}".format(msg.missing_keys)) 88 | logging.info("load checkpoint from %s" % url_or_filename) 89 | 90 | return msg 91 | 92 | 93 | def disabled_train(self, mode=True): 94 | """Overwrite model.train with this function to make sure train/eval mode 95 | does not change anymore.""" 96 | return self 97 | 98 | 99 | class LayerNorm(nn.LayerNorm): 100 | """Subclass torch's LayerNorm to handle fp16.""" 101 | 102 | def forward(self, x: torch.Tensor): 103 | orig_type = x.dtype 104 | ret = super().forward(x.type(torch.float32)) 105 | return ret.type(orig_type) 106 | 107 | 108 | def compute_sim_matrix(model, data_loader, **kwargs): 109 | k_test = kwargs.pop("k_test") 110 | 111 | metric_logger = MetricLogger(delimiter=" ") 112 | header = "Evaluation:" 113 | 114 | logging.info("Computing features for 
evaluation...") 115 | start_time = time.time() 116 | 117 | texts = data_loader.dataset.text 118 | num_text = len(texts) 119 | text_bs = 256 120 | text_ids = [] 121 | text_embeds = [] 122 | text_atts = [] 123 | for i in range(0, num_text, text_bs): 124 | text = texts[i : min(num_text, i + text_bs)] 125 | text_input = model.tokenizer( 126 | text, 127 | padding="max_length", 128 | truncation=True, 129 | max_length=35, 130 | return_tensors="pt", 131 | ).to(model.device) 132 | text_feat = model.forward_text(text_input) 133 | text_embed = F.normalize(model.text_proj(text_feat)) 134 | text_embeds.append(text_embed) 135 | text_ids.append(text_input.input_ids) 136 | text_atts.append(text_input.attention_mask) 137 | 138 | text_embeds = torch.cat(text_embeds, dim=0) 139 | text_ids = torch.cat(text_ids, dim=0) 140 | text_atts = torch.cat(text_atts, dim=0) 141 | 142 | vit_feats = [] 143 | image_embeds = [] 144 | for samples in data_loader: 145 | image = samples["image"] 146 | 147 | image = image.to(model.device) 148 | image_feat, vit_feat = model.forward_image(image) 149 | image_embed = model.vision_proj(image_feat) 150 | image_embed = F.normalize(image_embed, dim=-1) 151 | 152 | vit_feats.append(vit_feat.cpu()) 153 | image_embeds.append(image_embed) 154 | 155 | vit_feats = torch.cat(vit_feats, dim=0) 156 | image_embeds = torch.cat(image_embeds, dim=0) 157 | 158 | sims_matrix = [] 159 | for image_embed in image_embeds: 160 | sim_q2t = image_embed @ text_embeds.t() 161 | sim_i2t, _ = sim_q2t.max(0) 162 | sims_matrix.append(sim_i2t) 163 | sims_matrix = torch.stack(sims_matrix, dim=0) 164 | 165 | score_matrix_i2t = torch.full( 166 | (len(data_loader.dataset.image), len(texts)), -100.0 167 | ).to(model.device) 168 | 169 | num_tasks = dist_utils.get_world_size() 170 | rank = dist_utils.get_rank() 171 | step = sims_matrix.size(0) // num_tasks + 1 172 | start = rank * step 173 | end = min(sims_matrix.size(0), start + step) 174 | 175 | for i, sims in enumerate( 176 | 
metric_logger.log_every(sims_matrix[start:end], 50, header) 177 | ): 178 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 179 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) 180 | score = model.compute_itm( 181 | image_inputs=image_inputs, 182 | text_ids=text_ids[topk_idx], 183 | text_atts=text_atts[topk_idx], 184 | ).float() 185 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim 186 | 187 | sims_matrix = sims_matrix.t() 188 | score_matrix_t2i = torch.full( 189 | (len(texts), len(data_loader.dataset.image)), -100.0 190 | ).to(model.device) 191 | 192 | step = sims_matrix.size(0) // num_tasks + 1 193 | start = rank * step 194 | end = min(sims_matrix.size(0), start + step) 195 | 196 | for i, sims in enumerate( 197 | metric_logger.log_every(sims_matrix[start:end], 50, header) 198 | ): 199 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 200 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device) 201 | score = model.compute_itm( 202 | image_inputs=image_inputs, 203 | text_ids=text_ids[start + i].repeat(k_test, 1), 204 | text_atts=text_atts[start + i].repeat(k_test, 1), 205 | ).float() 206 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim 207 | 208 | if dist_utils.is_dist_avail_and_initialized(): 209 | dist.barrier() 210 | torch.distributed.all_reduce( 211 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM 212 | ) 213 | torch.distributed.all_reduce( 214 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM 215 | ) 216 | 217 | total_time = time.time() - start_time 218 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 219 | logging.info("Evaluation time {}".format(total_time_str)) 220 | 221 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 222 | -------------------------------------------------------------------------------- /minigpt4/models/blip2_outputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, 
salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Optional 10 | 11 | import torch 12 | from transformers.modeling_outputs import ( 13 | ModelOutput, 14 | BaseModelOutputWithPoolingAndCrossAttentions, 15 | CausalLMOutputWithCrossAttentions, 16 | ) 17 | 18 | 19 | @dataclass 20 | class BlipSimilarity(ModelOutput): 21 | sim_i2t: torch.FloatTensor = None 22 | sim_t2i: torch.FloatTensor = None 23 | 24 | sim_i2t_m: Optional[torch.FloatTensor] = None 25 | sim_t2i_m: Optional[torch.FloatTensor] = None 26 | 27 | sim_i2t_targets: Optional[torch.FloatTensor] = None 28 | sim_t2i_targets: Optional[torch.FloatTensor] = None 29 | 30 | 31 | @dataclass 32 | class BlipIntermediateOutput(ModelOutput): 33 | """ 34 | Data class for intermediate outputs of BLIP models. 35 | 36 | image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). 37 | text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). 38 | 39 | image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). 40 | text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). 41 | 42 | encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. 43 | encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. 44 | 45 | decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. 46 | decoder_labels (torch.LongTensor): labels for the captioning loss. 47 | 48 | itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). 
49 | itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) 50 | 51 | """ 52 | 53 | # uni-modal features 54 | image_embeds: torch.FloatTensor = None 55 | text_embeds: Optional[torch.FloatTensor] = None 56 | 57 | image_embeds_m: Optional[torch.FloatTensor] = None 58 | text_embeds_m: Optional[torch.FloatTensor] = None 59 | 60 | # intermediate outputs of multimodal encoder 61 | encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 62 | encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None 63 | 64 | itm_logits: Optional[torch.FloatTensor] = None 65 | itm_labels: Optional[torch.LongTensor] = None 66 | 67 | # intermediate outputs of multimodal decoder 68 | decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None 69 | decoder_labels: Optional[torch.LongTensor] = None 70 | 71 | 72 | @dataclass 73 | class BlipOutput(ModelOutput): 74 | # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. 75 | sims: Optional[BlipSimilarity] = None 76 | 77 | intermediate_output: BlipIntermediateOutput = None 78 | 79 | loss: Optional[torch.FloatTensor] = None 80 | 81 | loss_itc: Optional[torch.FloatTensor] = None 82 | 83 | loss_itm: Optional[torch.FloatTensor] = None 84 | 85 | loss_lm: Optional[torch.FloatTensor] = None 86 | 87 | 88 | @dataclass 89 | class BlipOutputFeatures(ModelOutput): 90 | """ 91 | Data class of features from BlipFeatureExtractor. 92 | 93 | Args: 94 | image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional 95 | image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional 96 | text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional 97 | text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional 98 | 99 | The first embedding or feature is for the [CLS] token. 
100 | 101 | Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. 102 | """ 103 | 104 | image_embeds: Optional[torch.FloatTensor] = None 105 | image_embeds_proj: Optional[torch.FloatTensor] = None 106 | 107 | text_embeds: Optional[torch.FloatTensor] = None 108 | text_embeds_proj: Optional[torch.FloatTensor] = None 109 | 110 | multimodal_embeds: Optional[torch.FloatTensor] = None 111 | -------------------------------------------------------------------------------- /minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_0.pth -------------------------------------------------------------------------------- /minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/output/minigpt4_stage2_finetune/20230706044/checkpoint_1.pth -------------------------------------------------------------------------------- /minigpt4/output/minigpt4_stage2_finetune/20230706044/log.txt: -------------------------------------------------------------------------------- 1 | { 2 | "run": { 3 | "task": "image_text_pretrain", 4 | "lr_sched": "linear_warmup_cosine_lr", 5 | "init_lr": 3e-05, 6 | "min_lr": 1e-05, 7 | "warmup_lr": 1e-06, 8 | "weight_decay": 0.05, 9 | "max_epoch": 5, 10 | "iters_per_epoch": 200, 11 | "batch_size_train": 1, 12 | "batch_size_eval": 1, 13 | "num_workers": 4, 14 | "warmup_steps": 200, 15 | "seed": 42, 16 | "output_dir": "output/minigpt4_stage2_finetune", 17 | "amp": true, 18 | "resume_ckpt_path": null, 19 | "evaluate": false, 20 | "train_splits": [ 21 | "train" 22 | ], 23 | "device": 
"cuda", 24 | "world_size": 1, 25 | "dist_url": "env://", 26 | "distributed": false 27 | }, 28 | "model": { 29 | "arch": "mini_gpt4", 30 | "image_size": 224, 31 | "drop_path_rate": 0, 32 | "use_grad_checkpoint": false, 33 | "vit_precision": "fp16", 34 | "freeze_vit": true, 35 | "freeze_qformer": true, 36 | "num_query_token": 32, 37 | "llama_model": "models/models-13b/vicuna_weights", 38 | "prompt": "", 39 | "model_type": "pretrain_vicuna", 40 | "max_txt_len": 160, 41 | "end_sym": "###", 42 | "prompt_path": "prompts/alignment.txt", 43 | "prompt_template": "###Human: {} ###Assistant: ", 44 | "ckpt": "models/models-13b/minigpt-4/pretrained_minigpt4.pth" 45 | }, 46 | "preprocess": { 47 | "vis_processor": { 48 | "train": { 49 | "name": "blip2_image_train", 50 | "image_size": 224 51 | }, 52 | "eval": { 53 | "name": "blip2_image_eval", 54 | "image_size": 224 55 | } 56 | }, 57 | "text_processor": { 58 | "train": { 59 | "name": "blip_caption" 60 | }, 61 | "eval": { 62 | "name": "blip_caption" 63 | } 64 | } 65 | }, 66 | "datasets": { 67 | "cc_sbu_align": { 68 | "data_type": "images", 69 | "build_info": { 70 | "storage": "/media/ubuntu/data/liuchang/workplace/code/src/MiniGPT-4/data/cc_sbu_align" 71 | }, 72 | "vis_processor": { 73 | "train": { 74 | "name": "blip2_image_train", 75 | "image_size": 224 76 | } 77 | }, 78 | "text_processor": { 79 | "train": { 80 | "name": "blip_caption" 81 | } 82 | } 83 | } 84 | } 85 | } 86 | {"train_lr": "0.000", "train_loss": "0.675"} 87 | {"train_lr": "0.000", "train_loss": "0.656"} 88 | -------------------------------------------------------------------------------- /minigpt4/output/minigpt4_stage2_finetune/20230706051/log.txt: -------------------------------------------------------------------------------- 1 | { 2 | "run": { 3 | "task": "image_text_pretrain", 4 | "lr_sched": "linear_warmup_cosine_lr", 5 | "init_lr": 3e-05, 6 | "min_lr": 1e-05, 7 | "warmup_lr": 1e-06, 8 | "weight_decay": 0.05, 9 | "max_epoch": 5, 10 | "iters_per_epoch": 200, 11 
| "batch_size_train": 1, 12 | "batch_size_eval": 1, 13 | "num_workers": 4, 14 | "warmup_steps": 200, 15 | "seed": 42, 16 | "output_dir": "output/minigpt4_stage2_finetune", 17 | "amp": true, 18 | "resume_ckpt_path": null, 19 | "evaluate": false, 20 | "train_splits": [ 21 | "train" 22 | ], 23 | "device": "cuda", 24 | "world_size": 1, 25 | "dist_url": "env://", 26 | "distributed": false 27 | }, 28 | "model": { 29 | "arch": "mini_gpt4", 30 | "image_size": 224, 31 | "drop_path_rate": 0, 32 | "use_grad_checkpoint": false, 33 | "vit_precision": "fp16", 34 | "freeze_vit": true, 35 | "freeze_qformer": true, 36 | "num_query_token": 32, 37 | "llama_model": "models/models-13b/vicuna_weights", 38 | "prompt": "", 39 | "model_type": "pretrain_vicuna", 40 | "max_txt_len": 160, 41 | "end_sym": "###", 42 | "prompt_path": "prompts/alignment.txt", 43 | "prompt_template": "###Human: {} ###Assistant: ", 44 | "ckpt": "models/models-13b/minigpt-4/pretrained_minigpt4.pth" 45 | }, 46 | "preprocess": { 47 | "vis_processor": { 48 | "train": { 49 | "name": "blip2_image_train", 50 | "image_size": 224 51 | }, 52 | "eval": { 53 | "name": "blip2_image_eval", 54 | "image_size": 224 55 | } 56 | }, 57 | "text_processor": { 58 | "train": { 59 | "name": "blip_caption" 60 | }, 61 | "eval": { 62 | "name": "blip_caption" 63 | } 64 | } 65 | }, 66 | "datasets": { 67 | "cc_sbu_align": { 68 | "data_type": "images", 69 | "build_info": { 70 | "storage": "/media/ubuntu/data/liuchang/workplace/code/src/MiniGPT-4/data/cc_sbu_align" 71 | }, 72 | "vis_processor": { 73 | "train": { 74 | "name": "blip2_image_train", 75 | "image_size": 224 76 | } 77 | }, 78 | "text_processor": { 79 | "train": { 80 | "name": "blip_caption" 81 | } 82 | } 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /minigpt4/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 
3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.processors.base_processor import BaseProcessor 9 | from minigpt4.processors.blip_processors import ( 10 | Blip2ImageTrainProcessor, 11 | Blip2ImageEvalProcessor, 12 | BlipCaptionProcessor, 13 | ) 14 | 15 | from minigpt4.common.registry import registry 16 | 17 | __all__ = [ 18 | "BaseProcessor", 19 | "Blip2ImageTrainProcessor", 20 | "Blip2ImageEvalProcessor", 21 | "BlipCaptionProcessor", 22 | ] 23 | 24 | 25 | def load_processor(name, cfg=None): 26 | """ 27 | Example 28 | 29 | >>> processor = load_processor("alpro_video_train", cfg=None) 30 | """ 31 | processor = registry.get_processor_class(name).from_config(cfg) 32 | 33 | return processor 34 | -------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/base_processor.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/base_processor.cpython-311.pyc 
-------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/base_processor.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/base_processor.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/blip_processors.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/blip_processors.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/blip_processors.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/blip_processors.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/randaugment.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/randaugment.cpython-311.pyc -------------------------------------------------------------------------------- /minigpt4/processors/__pycache__/randaugment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/processors/__pycache__/randaugment.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/processors/base_processor.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /minigpt4/processors/blip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import re 9 | 10 | from minigpt4.common.registry import registry 11 | from minigpt4.processors.base_processor import BaseProcessor 12 | from minigpt4.processors.randaugment import RandomAugment 13 | from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | class BlipImageBaseProcessor(BaseProcessor): 19 | def __init__(self, mean=None, std=None): 20 | if mean is None: 21 | mean = (0.48145466, 0.4578275, 0.40821073) 22 | if std is None: 23 | std = (0.26862954, 0.26130258, 0.27577711) 24 | 25 | self.normalize = transforms.Normalize(mean, std) 26 | 27 | 28 | @registry.register_processor("blip_caption") 29 | class BlipCaptionProcessor(BaseProcessor): 30 | def __init__(self, prompt="", max_words=50): 31 | self.prompt = prompt 32 | self.max_words = max_words 33 | 34 | def __call__(self, caption): 35 | caption = self.prompt + self.pre_caption(caption) 36 | 37 | return caption 38 | 39 | @classmethod 40 | def from_config(cls, cfg =None): 41 | if cfg is None: 42 | cfg = OmegaConf.create() 43 | 44 | prompt = cfg.get("prompt", "") 45 | max_words = cfg.get("max_words", 50) 46 | 47 | return cls(prompt=prompt, max_words=max_words) 48 | 49 | def pre_caption(self, caption): 50 | caption = re.sub( 51 | r"([.!\"()*#:;~])", 52 | " ", 53 | caption.lower(), 54 | ) 55 | caption = re.sub( 56 | r"\s{2,}", 57 | " ", 58 | caption, 59 | ) 60 | caption = caption.rstrip("\n") 61 | caption = caption.strip(" ") 62 | 63 | # truncate caption 64 | caption_words = caption.split(" ") 65 | if len(caption_words) > self.max_words: 66 | caption = " ".join(caption_words[: self.max_words]) 67 | 68 | return caption 69 | 70 | 71 | @registry.register_processor("blip2_image_train") 72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor): 
73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 74 | super().__init__(mean=mean, std=std) 75 | 76 | self.transform = transforms.Compose( 77 | [ 78 | transforms.RandomResizedCrop( 79 | image_size, 80 | scale=(min_scale, max_scale), 81 | interpolation=InterpolationMode.BICUBIC, 82 | ), 83 | transforms.ToTensor(), 84 | self.normalize, 85 | ] 86 | ) 87 | 88 | def __call__(self, item): 89 | return self.transform(item) 90 | 91 | @classmethod 92 | def from_config(cls, cfg=None): 93 | if cfg is None: 94 | cfg = OmegaConf.create() 95 | 96 | image_size = cfg.get("image_size", 224) 97 | 98 | mean = cfg.get("mean", None) 99 | std = cfg.get("std", None) 100 | 101 | min_scale = cfg.get("min_scale", 0.5) 102 | max_scale = cfg.get("max_scale", 1.0) 103 | 104 | return cls( 105 | image_size=image_size, 106 | mean=mean, 107 | std=std, 108 | min_scale=min_scale, 109 | max_scale=max_scale, 110 | ) 111 | 112 | 113 | @registry.register_processor("blip2_image_eval") 114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor): 115 | def __init__(self, image_size=224, mean=None, std=None): 116 | super().__init__(mean=mean, std=std) 117 | 118 | self.transform = transforms.Compose( 119 | [ 120 | transforms.Resize( 121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 122 | ), 123 | transforms.ToTensor(), 124 | self.normalize, 125 | ] 126 | ) 127 | 128 | def __call__(self, item): 129 | return self.transform(item) 130 | 131 | @classmethod 132 | def from_config(cls, cfg=None): 133 | if cfg is None: 134 | cfg = OmegaConf.create() 135 | 136 | image_size = cfg.get("image_size", 224) 137 | 138 | mean = cfg.get("mean", None) 139 | std = cfg.get("std", None) 140 | 141 | return cls(image_size=image_size, mean=mean, std=std) -------------------------------------------------------------------------------- /minigpt4/processors/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 
(c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | import torch 12 | 13 | 14 | ## aug functions 15 | def identity_func(img): 16 | return img 17 | 18 | 19 | def autocontrast_func(img, cutoff=0): 20 | """ 21 | same output as PIL.ImageOps.autocontrast 22 | """ 23 | n_bins = 256 24 | 25 | def tune_channel(ch): 26 | n = ch.size 27 | cut = cutoff * n // 100 28 | if cut == 0: 29 | high, low = ch.max(), ch.min() 30 | else: 31 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 32 | low = np.argwhere(np.cumsum(hist) > cut) 33 | low = 0 if low.shape[0] == 0 else low[0] 34 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 35 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 36 | if high <= low: 37 | table = np.arange(n_bins) 38 | else: 39 | scale = (n_bins - 1) / (high - low) 40 | offset = -low * scale 41 | table = np.arange(n_bins) * scale + offset 42 | table[table < 0] = 0 43 | table[table > n_bins - 1] = n_bins - 1 44 | table = table.clip(0, 255).astype(np.uint8) 45 | return table[ch] 46 | 47 | channels = [tune_channel(ch) for ch in cv2.split(img)] 48 | out = cv2.merge(channels) 49 | return out 50 | 51 | 52 | def equalize_func(img): 53 | """ 54 | same output as PIL.ImageOps.equalize 55 | PIL's implementation is different from cv2.equalize 56 | """ 57 | n_bins = 256 58 | 59 | def tune_channel(ch): 60 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 61 | non_zero_hist = hist[hist != 0].reshape(-1) 62 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 63 | if step == 0: 64 | return ch 65 | n = np.empty_like(hist) 66 | n[0] = step // 2 67 | n[1:] = hist[:-1] 68 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 69 | return table[ch] 70 | 71 | channels = [tune_channel(ch) for ch in cv2.split(img)] 72 | 
out = cv2.merge(channels) 73 | return out 74 | 75 | 76 | def rotate_func(img, degree, fill=(0, 0, 0)): 77 | """ 78 | like PIL, rotate by degree, not radians 79 | """ 80 | H, W = img.shape[0], img.shape[1] 81 | center = W / 2, H / 2 82 | M = cv2.getRotationMatrix2D(center, degree, 1) 83 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 84 | return out 85 | 86 | 87 | def solarize_func(img, thresh=128): 88 | """ 89 | same output as PIL.ImageOps.solarize: pixel values >= thresh are inverted (255 - value) 90 | """ 91 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 92 | table = table.clip(0, 255).astype(np.uint8) 93 | out = table[img] 94 | return out 95 | 96 | 97 | def color_func(img, factor): 98 | """ 99 | same output as PIL.ImageEnhance.Color 100 | """ 101 | ## implementation according to PIL definition, quite slow 102 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 103 | # out = blend(degenerate, img, factor) 104 | # M = ( 105 | # np.eye(3) * factor 106 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1.
- factor) 107 | # )[np.newaxis, np.newaxis, :] 108 | M = np.float32( 109 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] 110 | ) * factor + np.float32([[0.114], [0.587], [0.299]]) 111 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 112 | return out 113 | 114 | 115 | def contrast_func(img, factor): 116 | """ 117 | same output as PIL.ImageEnhance.Contrast 118 | """ 119 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 120 | table = ( 121 | np.array([(el - mean) * factor + mean for el in range(256)]) 122 | .clip(0, 255) 123 | .astype(np.uint8) 124 | ) 125 | out = table[img] 126 | return out 127 | 128 | 129 | def brightness_func(img, factor): 130 | """ 131 | same output as PIL.ImageEnhance.Brightness 132 | """ 133 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 134 | out = table[img] 135 | return out 136 | 137 | 138 | def sharpness_func(img, factor): 139 | """ 140 | The differences between this result and PIL are all on the 4 boundaries; the center 141 | areas are the same 142 | """ 143 | kernel = np.ones((3, 3), dtype=np.float32) 144 | kernel[1][1] = 5 145 | kernel /= 13 146 | degenerate = cv2.filter2D(img, -1, kernel) 147 | if factor == 0.0: 148 | out = degenerate 149 | elif factor == 1.0: 150 | out = img 151 | else: 152 | out = img.astype(np.float32) 153 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 154 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 155 | out = out.astype(np.uint8) 156 | return out 157 | 158 | 159 | def shear_x_func(img, factor, fill=(0, 0, 0)): 160 | H, W = img.shape[0], img.shape[1] 161 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 162 | out = cv2.warpAffine( 163 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 164 | ).astype(np.uint8) 165 | return out 166 | 167 | 168 | def translate_x_func(img, offset, fill=(0, 0, 0)): 169 | """ 170 | same output as PIL.Image.transform 171 | """ 172 | H, W =
img.shape[0], img.shape[1] 173 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 174 | out = cv2.warpAffine( 175 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 176 | ).astype(np.uint8) 177 | return out 178 | 179 | 180 | def translate_y_func(img, offset, fill=(0, 0, 0)): 181 | """ 182 | same output as PIL.Image.transform 183 | """ 184 | H, W = img.shape[0], img.shape[1] 185 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 186 | out = cv2.warpAffine( 187 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 188 | ).astype(np.uint8) 189 | return out 190 | 191 | 192 | def posterize_func(img, bits): 193 | """ 194 | same output as PIL.ImageOps.posterize 195 | """ 196 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 197 | return out 198 | 199 | 200 | def shear_y_func(img, factor, fill=(0, 0, 0)): 201 | H, W = img.shape[0], img.shape[1] 202 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 203 | out = cv2.warpAffine( 204 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 205 | ).astype(np.uint8) 206 | return out 207 | 208 | 209 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 210 | replace = np.array(replace, dtype=np.uint8) 211 | H, W = img.shape[0], img.shape[1] 212 | rh, rw = np.random.random(2) 213 | pad_size = pad_size // 2 214 | ch, cw = int(rh * H), int(rw * W) 215 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 216 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 217 | out = img.copy() 218 | out[x1:x2, y1:y2, :] = replace 219 | return out 220 | 221 | 222 | ### level to args 223 | def enhance_level_to_args(MAX_LEVEL): 224 | def level_to_args(level): 225 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 226 | 227 | return level_to_args 228 | 229 | 230 | def shear_level_to_args(MAX_LEVEL, replace_value): 231 | def level_to_args(level): 232 | level = (level / MAX_LEVEL) * 0.3 233 | if np.random.random() > 0.5: 234 | level = -level 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def 
translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 241 | def level_to_args(level): 242 | level = (level / MAX_LEVEL) * float(translate_const) 243 | if np.random.random() > 0.5: 244 | level = -level 245 | return (level, replace_value) 246 | 247 | return level_to_args 248 | 249 | 250 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 251 | def level_to_args(level): 252 | level = int((level / MAX_LEVEL) * cutout_const) 253 | return (level, replace_value) 254 | 255 | return level_to_args 256 | 257 | 258 | def solarize_level_to_args(MAX_LEVEL): 259 | def level_to_args(level): 260 | level = int((level / MAX_LEVEL) * 256) 261 | return (level,) 262 | 263 | return level_to_args 264 | 265 | 266 | def none_level_to_args(level): 267 | return () 268 | 269 | 270 | def posterize_level_to_args(MAX_LEVEL): 271 | def level_to_args(level): 272 | level = int((level / MAX_LEVEL) * 4) 273 | return (level,) 274 | 275 | return level_to_args 276 | 277 | 278 | def rotate_level_to_args(MAX_LEVEL, replace_value): 279 | def level_to_args(level): 280 | level = (level / MAX_LEVEL) * 30 281 | if np.random.random() < 0.5: 282 | level = -level 283 | return (level, replace_value) 284 | 285 | return level_to_args 286 | 287 | 288 | func_dict = { 289 | "Identity": identity_func, 290 | "AutoContrast": autocontrast_func, 291 | "Equalize": equalize_func, 292 | "Rotate": rotate_func, 293 | "Solarize": solarize_func, 294 | "Color": color_func, 295 | "Contrast": contrast_func, 296 | "Brightness": brightness_func, 297 | "Sharpness": sharpness_func, 298 | "ShearX": shear_x_func, 299 | "TranslateX": translate_x_func, 300 | "TranslateY": translate_y_func, 301 | "Posterize": posterize_func, 302 | "ShearY": shear_y_func, 303 | } 304 | 305 | translate_const = 10 306 | MAX_LEVEL = 10 307 | replace_value = (128, 128, 128) 308 | arg_dict = { 309 | "Identity": none_level_to_args, 310 | "AutoContrast": none_level_to_args, 311 | "Equalize": none_level_to_args, 312 | "Rotate": 
rotate_level_to_args(MAX_LEVEL, replace_value), 313 | "Solarize": solarize_level_to_args(MAX_LEVEL), 314 | "Color": enhance_level_to_args(MAX_LEVEL), 315 | "Contrast": enhance_level_to_args(MAX_LEVEL), 316 | "Brightness": enhance_level_to_args(MAX_LEVEL), 317 | "Sharpness": enhance_level_to_args(MAX_LEVEL), 318 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), 319 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 320 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 321 | "Posterize": posterize_level_to_args(MAX_LEVEL), 322 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), 323 | } 324 | 325 | 326 | class RandomAugment(object): 327 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 328 | self.N = N 329 | self.M = M 330 | self.isPIL = isPIL 331 | if augs: 332 | self.augs = augs 333 | else: 334 | self.augs = list(arg_dict.keys()) 335 | 336 | def get_random_ops(self): 337 | sampled_ops = np.random.choice(self.augs, self.N) 338 | return [(op, 0.5, self.M) for op in sampled_ops] 339 | 340 | def __call__(self, img): 341 | if self.isPIL: 342 | img = np.array(img) 343 | ops = self.get_random_ops() 344 | for name, prob, level in ops: 345 | if np.random.random() > prob: 346 | continue 347 | args = arg_dict[name](level) 348 | img = func_dict[name](img, *args) 349 | return img 350 | 351 | 352 | class VideoRandomAugment(object): 353 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): 354 | self.N = N 355 | self.M = M 356 | self.p = p 357 | self.tensor_in_tensor_out = tensor_in_tensor_out 358 | if augs: 359 | self.augs = augs 360 | else: 361 | self.augs = list(arg_dict.keys()) 362 | 363 | def get_random_ops(self): 364 | sampled_ops = np.random.choice(self.augs, self.N, replace=False) 365 | return [(op, self.M) for op in sampled_ops] 366 | 367 | def __call__(self, frames): 368 | assert ( 369 | frames.shape[-1] == 3 370 | ), "Expecting last dimension for 3-channels RGB 
(b, h, w, c)." 371 | 372 | if self.tensor_in_tensor_out: 373 | frames = frames.numpy().astype(np.uint8) 374 | 375 | num_frames = frames.shape[0] 376 | 377 | ops = num_frames * [self.get_random_ops()] 378 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] 379 | 380 | frames = torch.stack( 381 | list(map(self._aug, frames, ops, apply_or_not)), dim=0 382 | ).float() 383 | 384 | return frames 385 | 386 | def _aug(self, img, ops, apply_or_not): 387 | for i, (name, level) in enumerate(ops): 388 | if not apply_or_not[i]: 389 | continue 390 | args = arg_dict[name](level) 391 | img = func_dict[name](img, *args) 392 | return torch.from_numpy(img) 393 | 394 | 395 | if __name__ == "__main__": 396 | a = RandomAugment() 397 | img = np.random.randn(32, 32, 3) 398 | a(img) 399 | -------------------------------------------------------------------------------- /minigpt4/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.runners.runner_base import RunnerBase 9 | 10 | __all__ = ["RunnerBase"] 11 | -------------------------------------------------------------------------------- /minigpt4/runners/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/runners/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/runners/__pycache__/runner_base.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/runners/__pycache__/runner_base.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.common.registry import registry 9 | from minigpt4.tasks.base_task import BaseTask 10 | from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask 11 | from minigpt4.tasks.mimic_generate_then_refine import MIMICGenerateThenRefine 12 | 13 | 14 | def setup_task(cfg): 15 | assert "task" in cfg.run_cfg, "Task name must be provided." 
16 | 17 | task_name = cfg.run_cfg.task 18 | task = registry.get_task_class(task_name).setup_task(cfg=cfg) 19 | assert task is not None, "Task {} not properly registered.".format(task_name) 20 | 21 | return task 22 | 23 | 24 | __all__ = [ 25 | "BaseTask", 26 | "ImageTextPretrainTask", 27 | "MIMICGenerateThenRefine", 28 | ] 29 | -------------------------------------------------------------------------------- /minigpt4/tasks/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/tasks/__pycache__/base_task.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/base_task.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/tasks/__pycache__/mimic_generate_then_refine.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2-LLM/2167adf5bf3a204ad85e7065aad74fa1ebb9640e/minigpt4/tasks/__pycache__/mimic_generate_then_refine.cpython-39.pyc -------------------------------------------------------------------------------- /minigpt4/tasks/base_task.py: -------------------------------------------------------------------------------- 1 
| """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | import deepspeed 12 | 13 | import torch 14 | import torch.distributed as dist 15 | from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized 16 | from minigpt4.common.logger import MetricLogger, SmoothedValue 17 | from minigpt4.common.registry import registry 18 | from minigpt4.datasets.data_utils import prepare_sample 19 | 20 | 21 | class BaseTask: 22 | def __init__(self, **kwargs): 23 | super().__init__() 24 | 25 | self.inst_id_key = "instance_id" 26 | 27 | @classmethod 28 | def setup_task(cls, **kwargs): 29 | return cls() 30 | 31 | def build_model(self, cfg): 32 | model_config = cfg.model_cfg 33 | 34 | model_cls = registry.get_model_class(model_config.arch) 35 | return model_cls.from_config(model_config) 36 | 37 | def build_datasets(self, cfg): 38 | """ 39 | Build a dictionary of datasets keyed by dataset name; each value is that dataset's split dict (e.g. 'train', 'valid', 'test'). 40 | Each dataset is built by its registered builder, which may download data and annotations if they do not exist. 41 | 42 | Args: 43 | cfg (common.config.Config): run configuration whose ``datasets_cfg`` maps each dataset name to its dataset config. 44 | 45 | Returns: 46 | dict: Mapping from dataset name to the split dictionary of torch.utils.data.Dataset objects returned by its builder. 47 | """ 48 | 49 | datasets = dict() 50 | 51 | datasets_config = cfg.datasets_cfg 52 | 53 | assert len(datasets_config) > 0, "At least one dataset has to be specified."
54 | 55 | for name in datasets_config: 56 | dataset_config = datasets_config[name] 57 | 58 | builder = registry.get_builder_class(name)(dataset_config) 59 | dataset = builder.build_datasets() 60 | 61 | dataset['train'].name = name 62 | if 'sample_ratio' in dataset_config: 63 | dataset['train'].sample_ratio = dataset_config.sample_ratio 64 | 65 | datasets[name] = dataset 66 | 67 | return datasets 68 | 69 | def train_step(self, model, samples): 70 | loss = model(samples)["loss"] 71 | return loss 72 | 73 | def valid_step(self, model, samples): 74 | raise NotImplementedError 75 | 76 | def before_evaluation(self, model, dataset, **kwargs): 77 | model.before_evaluation(dataset=dataset, task_type=type(self)) 78 | 79 | def after_evaluation(self, **kwargs): 80 | pass 81 | 82 | def inference_step(self): 83 | raise NotImplementedError 84 | 85 | def evaluation(self, model, data_loader, cuda_enabled=True): 86 | metric_logger = MetricLogger(delimiter=" ") 87 | header = "Evaluation" 88 | # TODO make it configurable 89 | print_freq = 10 90 | 91 | results = [] 92 | 93 | for samples in metric_logger.log_every(data_loader, print_freq, header): 94 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 95 | 96 | eval_output = self.valid_step(model=model, samples=samples) 97 | results.extend(eval_output) 98 | 99 | if is_dist_avail_and_initialized(): 100 | dist.barrier() 101 | 102 | return results 103 | 104 | def train_epoch( 105 | self, 106 | epoch, 107 | model, 108 | data_loader, 109 | optimizer, 110 | lr_scheduler, 111 | scaler=None, 112 | cuda_enabled=False, 113 | log_freq=50, 114 | accum_grad_iters=1, 115 | use_zero_optimizer=False, 116 | ): 117 | return self._train_inner_loop( 118 | epoch=epoch, 119 | iters_per_epoch=lr_scheduler.iters_per_epoch, 120 | model=model, 121 | data_loader=data_loader, 122 | optimizer=optimizer, 123 | scaler=scaler, 124 | lr_scheduler=lr_scheduler, 125 | log_freq=log_freq, 126 | cuda_enabled=cuda_enabled, 127 | accum_grad_iters=accum_grad_iters, 
128 | use_zero_optimizer=use_zero_optimizer, 129 | ) 130 | 131 | def train_iters( 132 | self, 133 | epoch, 134 | start_iters, 135 | iters_per_inner_epoch, 136 | model, 137 | data_loader, 138 | optimizer, 139 | lr_scheduler, 140 | scaler=None, 141 | cuda_enabled=False, 142 | log_freq=50, 143 | accum_grad_iters=1, 144 | ): 145 | return self._train_inner_loop( 146 | epoch=epoch, 147 | start_iters=start_iters, 148 | iters_per_epoch=iters_per_inner_epoch, 149 | model=model, 150 | data_loader=data_loader, 151 | optimizer=optimizer, 152 | scaler=scaler, 153 | lr_scheduler=lr_scheduler, 154 | log_freq=log_freq, 155 | cuda_enabled=cuda_enabled, 156 | accum_grad_iters=accum_grad_iters, 157 | ) 158 | 159 | def _train_inner_loop( 160 | self, 161 | epoch, 162 | iters_per_epoch, 163 | model, 164 | data_loader, 165 | optimizer, 166 | lr_scheduler, 167 | scaler=None, 168 | start_iters=None, 169 | log_freq=50, 170 | cuda_enabled=False, 171 | accum_grad_iters=1, 172 | use_zero_optimizer=False, 173 | ): 174 | """ 175 | An inner training loop compatible with both epoch-based and iter-based training. 176 | 177 | When using epoch-based, training stops after one epoch; when using iter-based, 178 | training stops after #iters_per_epoch iterations. 179 | """ 180 | use_amp = scaler is not None 181 | 182 | if not hasattr(data_loader, "__next__"): 183 | # convert to iterator if not already 184 | data_loader = iter(data_loader) 185 | 186 | metric_logger = MetricLogger(delimiter=" ") 187 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 188 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) 189 | 190 | # if iter-based runner, schedule lr based on inner epoch. 
191 | logging.info( 192 | "Start training epoch {}, {} iters per inner epoch.".format( 193 | epoch, iters_per_epoch 194 | ) 195 | ) 196 | header = "Train: data epoch: [{}]".format(epoch) 197 | if start_iters is None: 198 | # epoch-based runner 199 | inner_epoch = epoch 200 | else: 201 | # In iter-based runner, we schedule the learning rate based on iterations. 202 | inner_epoch = start_iters // iters_per_epoch 203 | header = header + "; inner epoch [{}]".format(inner_epoch) 204 | 205 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): 206 | # if using iter-based runner, we stop after iters_per_epoch iterations. 207 | if i >= iters_per_epoch: 208 | break 209 | 210 | samples = next(data_loader) 211 | 212 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 213 | samples.update( 214 | { 215 | "epoch": inner_epoch, 216 | "num_iters_per_epoch": iters_per_epoch, 217 | "iters": i, 218 | } 219 | ) 220 | 221 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) 222 | 223 | with torch.cuda.amp.autocast(enabled=use_amp): 224 | loss = self.train_step(model=model, samples=samples) 225 | 226 | # after_train_step() 227 | if use_amp: 228 | scaler.scale(loss).backward() 229 | else: 230 | loss.backward() 231 | 232 | # update gradients every accum_grad_iters iterations 233 | if (i + 1) % accum_grad_iters == 0: 234 | if use_amp: 235 | scaler.step(optimizer) 236 | scaler.update() 237 | else: 238 | optimizer.step() 239 | optimizer.zero_grad() 240 | 241 | metric_logger.update(loss=loss.item()) 242 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 243 | 244 | # after train_epoch() 245 | # gather the stats from all processes 246 | metric_logger.synchronize_between_processes() 247 | logging.info("Averaged stats: " + str(metric_logger.global_avg())) 248 | return { 249 | k: "{:.3f}".format(meter.global_avg) 250 | for k, meter in metric_logger.meters.items() 251 | } 252 | 253 | @staticmethod 254 | def save_result(result, result_dir, filename, 
remove_duplicate=""): 255 | import json 256 | 257 | result_file = os.path.join( 258 | result_dir, "%s_rank%d.json" % (filename, get_rank()) 259 | ) 260 | final_result_file = os.path.join(result_dir, "%s.json" % filename) 261 | 262 | json.dump(result, open(result_file, "w")) 263 | 264 | if is_dist_avail_and_initialized(): 265 | dist.barrier() 266 | 267 | if is_main_process(): 268 | logging.warning("rank %d starts merging results." % get_rank()) 269 | # combine results from all processes 270 | result = [] 271 | 272 | for rank in range(get_world_size()): 273 | result_file = os.path.join( 274 | result_dir, "%s_rank%d.json" % (filename, rank) 275 | ) 276 | res = json.load(open(result_file, "r")) 277 | result += res 278 | 279 | if remove_duplicate: 280 | result_new = [] 281 | id_list = [] 282 | for res in result: 283 | if res[remove_duplicate] not in id_list: 284 | id_list.append(res[remove_duplicate]) 285 | result_new.append(res) 286 | result = result_new 287 | 288 | json.dump(result, open(final_result_file, "w")) 289 | print("result file saved to %s" % final_result_file) 290 | 291 | return final_result_file 292 | -------------------------------------------------------------------------------- /minigpt4/tasks/image_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.common.registry import registry 9 | from minigpt4.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("image_text_pretrain") 13 | class ImageTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /minigpt4/tasks/mimic_generate_then_refine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | import logging 8 | import torch 9 | 10 | from minigpt4.common.registry import registry 11 | from minigpt4.tasks.base_task import BaseTask 12 | from minigpt4.common.logger import MetricLogger, SmoothedValue 13 | from minigpt4.datasets.data_utils import prepare_sample 14 | 15 | 16 | @registry.register_task("mimic_generate_then_refine") 17 | class MIMICGenerateThenRefine(BaseTask): 18 | def __init__(self): 19 | super().__init__() 20 | 21 | def train_step(self, model, samples): 22 | loss = model(samples)["loss"] 23 | return loss 24 | 25 | def _train_inner_loop( 26 | self, 27 | epoch, 28 | iters_per_epoch, 29 | model, 30 | data_loader, 31 | optimizer, 32 | lr_scheduler, 33 | scaler=None, 34 | start_iters=None, 35 | log_freq=50, 36 | cuda_enabled=False, 37 | accum_grad_iters=1, 38 | use_zero_optimizer=False, 39 | ): 40 | """ 41 | An inner training loop compatible with both epoch-based and iter-based training. 
42 | 43 | When using epoch-based, training stops after one epoch; when using iter-based, 44 | training stops after #iters_per_epoch iterations. 45 | """ 46 | use_amp = scaler is not None 47 | 48 | if not hasattr(data_loader, "__next__"): 49 | # convert to iterator if not already 50 | data_loader = iter(data_loader) 51 | 52 | metric_logger = MetricLogger(delimiter=" ") 53 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 54 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) 55 | 56 | # if iter-based runner, schedule lr based on inner epoch. 57 | logging.info( 58 | "Start training epoch {}, {} iters per inner epoch.".format( 59 | epoch, iters_per_epoch 60 | ) 61 | ) 62 | header = "Train: data epoch: [{}]".format(epoch) 63 | if start_iters is None: 64 | # epoch-based runner 65 | inner_epoch = epoch 66 | else: 67 | # In iter-based runner, we schedule the learning rate based on iterations. 68 | inner_epoch = start_iters // iters_per_epoch 69 | header = header + "; inner epoch [{}]".format(inner_epoch) 70 | 71 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): 72 | # if using iter-based runner, we stop after iters_per_epoch iterations. 
73 | if i >= iters_per_epoch: 74 | break 75 | 76 | samples = next(data_loader) 77 | 78 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 79 | samples.update( 80 | { 81 | "epoch": inner_epoch, 82 | "num_iters_per_epoch": iters_per_epoch, 83 | "iters": i, 84 | } 85 | ) 86 | 87 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) 88 | 89 | with torch.cuda.amp.autocast(enabled=use_amp): 90 | loss = self.train_step(model=model, samples=samples) 91 | 92 | # after_train_step() 93 | if use_zero_optimizer: 94 | model.backward(loss) 95 | else: 96 | if use_amp: 97 | scaler.scale(loss).backward() 98 | else: 99 | loss.backward() 100 | 101 | 102 | 103 | # update gradients every accum_grad_iters iterations 104 | if (i + 1) % accum_grad_iters == 0: 105 | if use_zero_optimizer: 106 | model.step() 107 | else: 108 | if use_amp: 109 | scaler.step(optimizer) 110 | scaler.update() 111 | else: 112 | optimizer.step() 113 | optimizer.zero_grad() 114 | 115 | metric_logger.update(loss=loss.item()) 116 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 117 | 118 | # after train_epoch() 119 | # gather the stats from all processes 120 | metric_logger.synchronize_between_processes() 121 | logging.info("Averaged stats: " + str(metric_logger.global_avg())) 122 | return { 123 | k: "{:.3f}".format(meter.global_avg) 124 | for k, meter in metric_logger.meters.items() 125 | } 126 | 127 | def evaluation(self, model, data_loader, cuda_enabled=True): 128 | pass 129 | -------------------------------------------------------------------------------- /prompts/stage1-pretraining-prompts.txt: -------------------------------------------------------------------------------- 1 | Describe this image in detail. 2 | Take a look at this image and describe what you notice. 3 | Please provide a detailed description of the picture. 4 | Could you describe the contents of this image for me? 
-------------------------------------------------------------------------------- /prompts/stage2-generation-prompts.txt: -------------------------------------------------------------------------------- 1 | You are a AI radiologist assistant. Your goal is to describe the syndromes reflected in the the radiograph in details. The description should be reasonable and should not be made up. Describe as informative as possible. -------------------------------------------------------------------------------- /prompts/stage2-refinement-prompts.txt: -------------------------------------------------------------------------------- 1 | ### Human: Rewrite the sentences in the report according to the image. Delete irrelevant descriptions in the report. Supply missing descriptions in the report. Write it as informative as possible. Keep the writing style unchanged. ### Assistant 2 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import argparse 9 | import os 10 | import random 11 | import shutil 12 | 13 | import numpy as np 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | 17 | import minigpt4.tasks as tasks 18 | from minigpt4.common.config import Config 19 | from minigpt4.common.dist_utils import get_rank, init_distributed_mode 20 | from minigpt4.common.logger import setup_logger 21 | from minigpt4.common.optims import ( 22 | LinearWarmupCosineLRScheduler, 23 | LinearWarmupStepLRScheduler, 24 | ) 25 | from minigpt4.common.registry import registry 26 | from minigpt4.common.utils import now 27 | 28 | # imports modules for registration 29 | from minigpt4.datasets.builders import * 30 | from minigpt4.models import * 31 | from minigpt4.processors import * 32 | from minigpt4.runners import * 33 | from minigpt4.tasks import * 34 | 35 | 36 | def parse_args(): 37 | parser = argparse.ArgumentParser(description="Training") 38 | 39 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.") 40 | parser.add_argument( 41 | "--options", 42 | nargs="+", 43 | help="override some settings in the used config, the key-value pair " 44 | "in xxx=yyy format will be merged into config file (deprecate), " 45 | "change to --cfg-options instead.", 46 | ) 47 | 48 | # TODO: deepspeed configurations 49 | parser.add_argument('--use_zero_optimizer', action='store_true', help='use ZeRO optimizer to save GPU memory') 50 | parser.add_argument('--local_rank', default=0, type=int, help='local rank') 51 | parser.add_argument('--deepspeed_config', type=str, default='train_configs/zero_configs/stage1.json', help='path to deepspeed configuration file') 52 | parser.add_argument('--train_batch_size', type=int, default=1, help='training batch size') 53 | parser.add_argument('--train_micro_batch_size_per_gpu', 
type=int, default=1, help='batch size per GPU') 54 | 55 | args = parser.parse_args() 56 | # if 'LOCAL_RANK' not in os.environ: 57 | # os.environ['LOCAL_RANK'] = str(args.local_rank) 58 | 59 | return args 60 | 61 | 62 | def setup_seeds(config): 63 | seed = config.run_cfg.seed + get_rank() 64 | 65 | random.seed(seed) 66 | np.random.seed(seed) 67 | torch.manual_seed(seed) 68 | 69 | cudnn.benchmark = False 70 | cudnn.deterministic = True 71 | 72 | 73 | def get_runner_class(cfg): 74 | """ 75 | Get runner class from config. Default to epoch-based runner. 76 | """ 77 | runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) 78 | 79 | return runner_cls 80 | 81 | 82 | def main(): 83 | # allow auto-dl completes on main process without timeout when using NCCL backend. 84 | # os.environ["NCCL_BLOCKING_WAIT"] = "1" 85 | 86 | # set before init_distributed_mode() to ensure the same job_id shared across all ranks. 87 | job_id = now() 88 | 89 | cfg = Config(parse_args()) 90 | 91 | init_distributed_mode(cfg.run_cfg) 92 | 93 | setup_seeds(cfg) 94 | 95 | # set after init_distributed_mode() to only log on master. 
96 | setup_logger() 97 | 98 | cfg.pretty_print() 99 | 100 | task = tasks.setup_task(cfg) 101 | datasets = task.build_datasets(cfg) 102 | model = task.build_model(cfg) 103 | 104 | # TODO: define arguments, required by deepspeed 105 | args = parse_args() 106 | args.train_batch_size = cfg.run_cfg.batch_size_train 107 | args.train_micro_batch_size_per_gpu = args.train_batch_size // cfg.run_cfg.world_size 108 | 109 | 110 | runner = get_runner_class(cfg)( 111 | cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets, cmd_args=args, 112 | ) 113 | runner.train() 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /train_configs/stage1/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: mini_gpt4 3 | model_type: pretrain_vicuna 4 | freeze_vit: True 5 | freeze_qformer: True 6 | freeze_llama: False 7 | max_txt_len: 160 8 | end_sym: "###" 9 | generation_prompt_path: "prompts/stage1-pretraining-prompts.txt" 10 | refinement_prompt_path: "prompts/stage1-pretraining-prompts.txt" 11 | prompt_template: '###Human: {} ###Assistant: ' 12 | ckpt: '/path/to/linear/layer' 13 | is_pretraining: True 14 | 15 | use_contrastive_loss: False 16 | use_refinement_loss: False 17 | triplet_margin: 0.5 18 | triplet_weight: 1.0 19 | refinement_loss_weight: 1.0 20 | 21 | # lora configuartion 22 | use_lora: True # use lora for vicuna 23 | use_lora_vit_qformer: False # use lora for vision backbone 24 | lora_rank: 8 25 | lora_alpha: 32 26 | lora_dropout: 0.1 27 | 28 | # ZeRO optimizer configuration 29 | use_zero_optimizer: True 30 | deepspeed_config: "train_configs/stage1/zero.json" 31 | 32 | datasets: 33 | mimic_generate_then_refine: 34 | vis_processor: 35 | train: 36 | name: "blip2_image_train" 37 | image_size: 224 38 | text_processor: 39 | train: 40 | name: "blip_caption" 41 | 42 | run: 43 | task: mimic_generate_then_refine 44 | # 
optimizer 45 | lr_sched: "linear_warmup_cosine_lr" 46 | init_lr: 3e-5 47 | min_lr: 1e-5 48 | warmup_lr: 1e-6 49 | 50 | weight_decay: 0.05 51 | max_epoch: 10 52 | iters_per_epoch: 5000 # 200 53 | batch_size_train: 1 # total batch size, not per GPU 54 | batch_size_eval: 1 55 | num_workers: 4 56 | warmup_steps: 200 57 | 58 | seed: 42 59 | output_dir: "/path/to/output/dir" 60 | 61 | amp: True 62 | resume_ckpt_path: null 63 | 64 | evaluate: False 65 | train_splits: ["train"] 66 | 67 | device: "cuda" 68 | world_size: 2 69 | dist_url: "env://" 70 | distributed: True 71 | 72 | # ZeRO optimizer configuration 73 | use_zero_optimizer: True 74 | deepspeed_config: "train_configs/stage1/zero.json" -------------------------------------------------------------------------------- /train_configs/stage1/zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "reduce_bucket_size": 5e8 5 | }, 6 | "train_batch_size": 24 7 | } -------------------------------------------------------------------------------- /train_configs/stage2/iuxray/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: mini_gpt4 3 | model_type: pretrain_vicuna 4 | freeze_vit: True 5 | freeze_qformer: True 6 | max_txt_len: 100 7 | end_sym: "###" 8 | generation_prompt_path: "prompts/stage-2-generation-prompts.txt" 9 | refinement_prompt_path: "prompts/stage-2-refinement-prompts.txt" 10 | prompt_template: '###Human: {} ###Assistant: ' 11 | ckpt: '/path/to/linear' 12 | 13 | use_contrastive_loss: True 14 | use_refinement_loss: True 15 | triplet_margin: 0.5 16 | triplet_weight: 1.0 17 | refinement_loss_weight: 1.0 18 | 19 | # lora configuartion 20 | use_lora: True 21 | lora_rank: 32 22 | lora_alpha: 32 23 | lora_dropout: 0.1 24 | 25 | # ZeRO optimizer configuration 26 | use_zero_optimizer: True 27 | deepspeed_config: "train_configs/stage2/zero.json" 28 | 29 | datasets: 30 | 
mimic_generate_then_refine: 31 | vis_processor: 32 | train: 33 | name: "blip2_image_train" 34 | image_size: 224 35 | text_processor: 36 | train: 37 | name: "blip_caption" 38 | 39 | run: 40 | task: mimic_generate_then_refine 41 | # optimizer 42 | lr_sched: "linear_warmup_cosine_lr" 43 | init_lr: 3e-5 44 | min_lr: 1e-5 45 | warmup_lr: 1e-6 46 | 47 | weight_decay: 0.05 48 | max_epoch: 10 49 | iters_per_epoch: 1000 # 200 50 | batch_size_train: 1 # total batch size, not per GPU 51 | batch_size_eval: 1 52 | num_workers: 4 53 | warmup_steps: 200 54 | 55 | seed: 42 56 | output_dir: "/path/to/output" 57 | 58 | amp: True 59 | resume_ckpt_path: null 60 | 61 | evaluate: False 62 | train_splits: ["train"] 63 | 64 | device: "cuda" 65 | world_size: 2 66 | dist_url: "env://" 67 | distributed: True 68 | 69 | # ZeRO optimizer configuration 70 | use_zero_optimizer: True 71 | deepspeed_config: "train_configs/stage2/zero.json" -------------------------------------------------------------------------------- /train_configs/stage2/iuxray/zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "reduce_bucket_size": 5e8 5 | }, 6 | "train_batch_size": 12, 7 | "scheduler": { 8 | "type": "WarmupLR", 9 | "params": { 10 | "warmup_min_lr": 1e-6, 11 | "warmup_max_lr": 1e-5, 12 | "warmup_num_steps": 200 13 | } 14 | } 15 | } -------------------------------------------------------------------------------- /train_configs/stage2/mimic/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: mini_gpt4 3 | model_type: pretrain_vicuna 4 | freeze_vit: True 5 | freeze_qformer: True 6 | max_txt_len: 100 7 | end_sym: "###" 8 | generation_prompt_path: "/path/to/generation/prompts" 9 | refinement_prompt_path: "/path/to/refinement/prompts" 10 | prompt_template: '###Human: {} ###Assistant: ' 11 | ckpt: '/path/to/linear' 12 | 13 | use_contrastive_loss: True 14 | 
use_refinement_loss: True 15 | triplet_margin: 0.5 16 | triplet_weight: 1.0 17 | refinement_loss_weight: 1.0 18 | 19 | # lora configuartion 20 | use_lora: True 21 | lora_rank: 32 22 | lora_alpha: 32 23 | lora_dropout: 0.1 24 | 25 | # ZeRO optimizer configuration 26 | use_zero_optimizer: True 27 | deepspeed_config: "train_configs/stage2/zero.json" 28 | 29 | datasets: 30 | mimic_generate_then_refine: 31 | vis_processor: 32 | train: 33 | name: "blip2_image_train" 34 | image_size: 224 35 | text_processor: 36 | train: 37 | name: "blip_caption" 38 | 39 | run: 40 | task: mimic_generate_then_refine 41 | # optimizer 42 | lr_sched: "linear_warmup_cosine_lr" 43 | init_lr: 3e-5 44 | min_lr: 1e-5 45 | warmup_lr: 1e-6 46 | 47 | weight_decay: 0.05 48 | max_epoch: 10 49 | iters_per_epoch: 1000 # 200 50 | batch_size_train: 1 # total batch size, not per GPU 51 | batch_size_eval: 1 52 | num_workers: 4 53 | warmup_steps: 200 54 | 55 | seed: 42 56 | output_dir: "/path/to/output" 57 | 58 | amp: True 59 | resume_ckpt_path: null 60 | 61 | evaluate: False 62 | train_splits: ["train"] 63 | 64 | device: "cuda" 65 | world_size: 2 66 | dist_url: "env://" 67 | distributed: True 68 | 69 | # ZeRO optimizer configuration 70 | use_zero_optimizer: True 71 | deepspeed_config: "train_configs/stage2/zero.json" -------------------------------------------------------------------------------- /train_configs/stage2/mimic/zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "reduce_bucket_size": 5e8 5 | }, 6 | "train_batch_size": 12, 7 | "scheduler": { 8 | "type": "WarmupLR", 9 | "params": { 10 | "warmup_min_lr": 1e-6, 11 | "warmup_max_lr": 1e-5, 12 | "warmup_num_steps": 200 13 | } 14 | } 15 | } --------------------------------------------------------------------------------