├── .gitignore ├── README.md ├── app.py ├── common_utils.py ├── constaints.py ├── data └── examples.json ├── data_utils.py ├── ds_config_first_stage.json ├── ds_config_second_stage.json ├── flash_attention.py ├── inference.py ├── introduction.png ├── main.py ├── mmbench_evaluation.py ├── mme_evaluation.py ├── model.py ├── process_instruction_data.py ├── process_mim.py ├── requirements.txt ├── scripts ├── run_demo.sh ├── run_first_stage.sh ├── run_first_stage_val.sh ├── run_inference.sh ├── run_mmbench_eval.sh ├── run_mme_eval.sh ├── run_second_stage.sh └── run_second_stage_val.sh ├── stable_diffusion.py ├── templates └── index.html └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TextBind: Multi-turn Interleaved Multimodal Instruction-following 2 | 3 | ![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg) 4 | ![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg) 5 | ![Model Weight License](https://img.shields.io/badge/Model_Weight%20License-CC%20By%20NC%204.0-red.svg) 6 | 7 | 8 |

9 | 🌐 Project Page • 🤗 Online Demo • 📃 Paper • ⏬ Data • 🤖 Model 10 |

11 | 12 | 13 | **** 14 | 15 | 16 | ## Content: 17 | * 1. Introduction 18 | * 2. Build Our Demo Locally 19 | * 2.1. Environment Installation 20 | * 2.2. Prepare Vision Model 21 | * 2.3. Prepare TextBind Weights 22 | * 2.4. Running Demo 23 | * 3. Train Your Own Models Using Our TextBind Recipe 24 | * 3.1. Data Preparation 25 | * 3.2. Prepare BLIP-2 Q-Former 26 | * 3.3. Training Configurations 27 | * 3.4. Training TextBind 28 | * Usage and License Notices 29 | * Citation 30 | 31 | **** 32 | 33 | 34 | 35 | ### 1. Introduction: [Back to Top] 36 | 37 |

38 | <img src="./introduction.png">

40 | 41 | Large language models with instruction-following abilities have revolutionized the field of artificial intelligence. These models show exceptional generalizability to tackle various real-world tasks through their natural language interfaces. However, their performance heavily relies on high-quality exemplar data, which is often difficult to obtain. This challenge is further exacerbated when it comes to multimodal instruction following. We introduce TextBind, an almost annotation-free framework for empowering large language models with multi-turn interleaved multimodal instruction-following capabilities. Our approach requires only image-caption pairs and generates multi-turn multimodal instruction-response conversations from a language model. 42 | 43 | **** 44 | 45 | 46 | 47 | ### 2. Build Our Demo Locally: [Back to Top] 48 | 49 | 50 | 51 | #### 2.1. Install Environment: 52 | Install the PyTorch package with the correct CUDA version, for example: 53 | ``` 54 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia 55 | ``` 56 | 57 | Then install the remaining required packages by running 58 | ``` 59 | pip install -r requirements.txt 60 | ``` 61 | 62 | 63 | 64 | #### 2.2. Prepare Vision Model: 65 | Following BLIP-2, we use EVA-CLIP as the vision model. You can run the following commands to prepare it: 66 | ``` 67 | import torch 68 | from transformers import Blip2ForConditionalGeneration 69 | model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xxl") 70 | vision_model = model.vision_model 71 | vision_model.save_pretrained("checkpoint/blip2_vision_model") 72 | ``` 73 | 74 | 75 | 76 | #### 2.3. Prepare TextBind Weights: 77 | 78 | |**Base Language Model**|**Huggingface Weights Address**|**Maximum Sequence Length**| 79 | |:-------------:|:-------------:|:-------------:| 80 | |Llama-2-7b-chat-hf|[SihengLi/TextBind](https://huggingface.co/SihengLi/TextBind)|768| 81 | 82 | Then please put the downloaded checkpoints under the [./checkpoint/](./checkpoint/) directory. 83 | 84 | 85 | 86 | #### 2.4. Running Demo: 87 | Please set the checkpoint paths in scripts/run_demo.sh as 88 | ``` 89 | CHECKPOINT=./checkpoint/second_stage_model.pt 90 | VISION_MODEL=./checkpoint/blip2_vision_model 91 | LANGUAGE_MODEL=meta-llama/Llama-2-7b-chat-hf 92 | PROCESSOR=Salesforce/blip2-flan-t5-xxl 93 | SD_BASE=stabilityai/stable-diffusion-xl-base-1.0 94 | SD_REFINER=stabilityai/stable-diffusion-xl-refiner-1.0 95 | ``` 96 | Then you can run the demo locally as 97 | ```bash 98 | bash scripts/run_demo.sh 99 | ``` 100 | **** 101 | 102 | 103 | 104 | ### 3. Train Your Own Models Using Our TextBind Recipe: [Back to Top] 105 | 106 | **Prerequisites:** Before training the model, make sure the environment is properly installed and the vision model has been prepared. You can refer to [Here] for more information. 107 | 108 | 109 | 110 | #### 3.1. Data Preparation: 111 | 112 | **Disclaimer:** To ensure the reproducibility of our results, we have released our training dataset. The dataset must be used for research purposes only. 113 | 114 | |**Training Stage**|**Dataset Address**| 115 | |:-------------:|:-------------:| 116 | |Multimodal Alignment|[CC3M+CC12M+SBU](https://github.com/Vision-CAIR/MiniGPT-4/blob/main/dataset/README_1_STAGE.md)| 117 | |Multimodal Instruction Following|[TextBind](https://drive.google.com/drive/folders/1-SkzQRInSfrVyZeB0EZJzpCPXXwHb27W?usp=sharing)| 118 | 119 | After downloading, put the downloaded files under the [./data/](./data/) directory.
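The images referenced by the TextBind data are not bundled with the release and have to be fetched separately. The sketch below is a minimal, hypothetical downloader: it assumes the released file is ./data/textbind/train.json and that each example carries url_list/image_list fields as described below — adjust the path and the keys to the actual layout of the file you downloaded.

```
# Hypothetical download sketch -- the file path and the "url_list"/"image_list"
# keys are assumptions about the released data format, not an official script.
import json
import os
import requests

image_dir = "./data/textbind/images"
os.makedirs(image_dir, exist_ok=True)

examples = json.load(open("./data/textbind/train.json"))

for ex in examples:
    for url, name in zip(ex.get("url_list", []), ex.get("image_list", [])):
        target = os.path.join(image_dir, name)
        if os.path.exists(target):
            continue
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            with open(target, "wb") as f:
                f.write(resp.content)
        except Exception as err:
            print("Failed to download {}: {}".format(url, err))
```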
120 | 121 | For our TextBind data, you need to download the images manually using the url_list provided in the downloaded file and rename the downloaded images according to the image_list (see the sketch above). 122 | 123 | > **** The data directory should look like: 124 | 125 | . 126 | └── ./data/ 127 | └── /cc_sbu/ 128 | └── /cc_sbu_dataset/ 129 | └── {00000..01254}.tar 130 | └── /textbind/ 131 | ├── train.json 132 | └── /images/ 133 | ├── 490272.png 134 | ├── 862235.png 135 | └── ... 136 | 137 | 138 | 139 | 140 | 141 | #### 3.2. Prepare BLIP-2 Q-Former: 142 | The BLIP-2 Q-Former is used to initialize our Q-Former. To extract its weights, run: 143 | ``` 144 | import torch 145 | from transformers import Blip2ForConditionalGeneration 146 | model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xxl") 147 | 148 | state_dict = model.state_dict() 149 | state_dict = {key: value for key, value in state_dict.items() if key.split(".")[0] in ["query_tokens", "qformer"]} 150 | torch.save(state_dict, "checkpoint/blip2_qformer.pt") 151 | ``` 152 | 153 | 154 | 155 | 156 | 157 | #### 3.3. Training Configurations: 158 | 159 | The table below shows the training hyperparameters used in our experiments. The hyperparameters were selected based on the constraints of our computational resources, i.e., 8 x A100 (40G) GPUs. 160 | 161 | |**Training Stage**|**Language Model**|**Epoch**|**Batch Size**|**Learning Rate**|**Training Modules**| 162 | |:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:| 163 | |Multimodal Alignment|Llama-2-7b-chat-hf|2|256|1e-4|Q-Former, Linear| 164 | |Multimodal Instruction Following|Llama-2-7b-chat-hf|3|64|1e-5|Q-Former, Linear, LLM| 165 | 166 | 167 | 168 | 169 | 170 | 171 | #### 3.4. Training TextBind: 172 | For the multimodal alignment stage, please set the paths in scripts/run_first_stage.sh as 173 | ``` 174 | TRAIN_DATA_PATH=${your_first_stage_data_path} 175 | CHECKPOINT=./checkpoint/blip2_qformer.pt 176 | VISION_MODEL=./checkpoint/blip2_vision_model 177 | LANGUAGE_MODEL=meta-llama/Llama-2-7b-chat-hf 178 | PROCESSOR=Salesforce/blip2-flan-t5-xxl 179 | ``` 180 | 181 | then run the following command: 182 | ``` 183 | bash scripts/run_first_stage.sh 184 | ``` 185 | 186 | For the multimodal instruction tuning stage, please set the paths in scripts/run_second_stage.sh as 187 | ``` 188 | TRAIN_DATA_PATH=${your_second_stage_data_path} 189 | CHECKPOINT=${your_first_stage_model_path} 190 | VISION_MODEL=./checkpoint/blip2_vision_model 191 | LANGUAGE_MODEL=meta-llama/Llama-2-7b-chat-hf 192 | PROCESSOR=Salesforce/blip2-flan-t5-xxl 193 | ``` 194 | 195 | then run the following command: 196 | ``` 197 | bash scripts/run_second_stage.sh 198 | ``` 199 | 200 | **** 201 | 202 | 203 | 204 | ### Usage and License Notices: 205 | 206 | TextBind is intended and licensed for research use only. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes. The delta weights are also CC BY NC 4.0 (allowing only non-commercial use).
207 | 208 | 209 | **** 210 | 211 | 212 | 213 | ### Citation: 214 | 215 | If you found TextBind useful in your research or applications, please kindly cite using the following BibTeX: 216 | ``` 217 | @article{li2023textbind, 218 | title={TextBind: Multi-turn Interleaved Multimodal Instruction-following in the Wild}, 219 | author={Li, Huayang and Li, Siheng and Cai, Deng and Wang, Longyue and Liu, Lemao and Watanabe, Taro and Yang, Yujiu and Shi, Shuming}, 220 | year={2023} 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import base64 4 | import hashlib 5 | import torch 6 | from bs4 import BeautifulSoup 7 | from flask import Flask, render_template, request, jsonify, send_from_directory, Blueprint 8 | from main import parse_args 9 | from inference import MIMPipeline 10 | import logging 11 | from common_utils import FileUtils 12 | import random 13 | random.seed(10086) 14 | 15 | logging.basicConfig( 16 | level=logging.INFO, 17 | format="%(asctime)s [%(levelname)s] %(message)s", 18 | handlers=[ 19 | logging.StreamHandler() 20 | ], 21 | ) 22 | 23 | def check_url_prefix(url_prefix): 24 | url_prefix = url_prefix.strip() 25 | if url_prefix: 26 | if url_prefix.startswith("/"): 27 | return url_prefix 28 | else: 29 | return "/" + url_prefix 30 | else: 31 | return url_prefix 32 | 33 | PUNC = set([".", ",", "!"]) 34 | 35 | app = Flask(__name__) 36 | args = parse_args() 37 | url_prefix = check_url_prefix(args.url_prefix) 38 | image_remote_path="{}/images/".format(url_prefix) 39 | image_local_dir = os.path.abspath(os.path.dirname(__file__)) + image_remote_path 40 | FileUtils.check_dirs(image_local_dir) 41 | logging.info("Remote image path: {}".format(image_remote_path)) 42 | logging.info("Local image dir: {}".format(image_local_dir)) 43 | api = Blueprint('api', __name__, url_prefix=url_prefix if url_prefix else None) 44 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 45 | logging.info("Loading model for demo...") 46 | agent = MIMPipeline(args, device) 47 | engines = agent.engines 48 | examples = FileUtils.load_file(args.demo_example_path) if args.demo_example_path else [] 49 | 50 | def run_inference(data, selection): 51 | global agent 52 | return agent.run(data, selection) 53 | 54 | def process_user_input(user_input): 55 | soup = BeautifulSoup(user_input, 'html.parser') 56 | images = soup.find_all('img') 57 | 58 | for img in images: 59 | if not img['src'].startswith(image_remote_path): 60 | sha1 = hashlib.sha1() 61 | img_data = img['src'].split(',')[1] 62 | img_bytes = base64.b64decode(img_data) 63 | sha1.update(img_bytes) 64 | img_ext = img['src'].split(',')[0].split('/')[1].split(';')[0] 65 | img_filename = "{}.{}".format(sha1.hexdigest(), img_ext) 66 | image_local_path = os.path.join(image_local_dir, img_filename) 67 | if not os.path.exists(image_local_path): 68 | with open(image_local_path, 'wb') as f: 69 | f.write(img_bytes) 70 | user_input = user_input.replace(img['src'], os.path.join(image_remote_path, img_filename)) 71 | return user_input 72 | 73 | 74 | def parse_chat_history(chat_history_html): 75 | # Parse the chat history using BeautifulSoup 76 | soup = BeautifulSoup(chat_history_html, 'html.parser') 77 | messages = soup.find_all(class_='message') 78 | chat_history = [] 79 | for message in messages: 80 | role = 'user' if 'user-message' in message['class'] else 'assistant' 81 | cur_message = 
{"role": role, "data": []} 82 | 83 | if role == 'user': 84 | if message: 85 | user_input = process_user_input(str(message)) 86 | for element in BeautifulSoup(user_input, 'html.parser').recursiveChildGenerator(): 87 | if element.name == 'img': 88 | message_data = {'type': 'image', 'value': element['src']} 89 | cur_message['data'].append(message_data) 90 | elif element.name is None: 91 | message_data = {'type': 'text', 'value': element.string.strip()} 92 | cur_message['data'].append(message_data) 93 | else: 94 | for child in message.children: 95 | if child.name == 'img': 96 | message_data = { 97 | 'type': 'image', 98 | 'value': child['src'], 99 | } 100 | cur_message['data'].append(message_data) 101 | elif child.name == 'p': 102 | message_data = { 103 | 'type': 'text', 104 | 'value': child.get_text(strip=True), 105 | } 106 | cur_message['data'].append(message_data) 107 | elif child.name == 'div': 108 | for element in BeautifulSoup(str(child), 'html.parser').recursiveChildGenerator(): 109 | if element.name == 'img': 110 | message_data = {'type': 'image', 'value': element['src']} 111 | cur_message['data'].append(message_data) 112 | elif element.name is None: 113 | message_data = {'type': 'text', 'value': element.string.strip()} 114 | cur_message['data'].append(message_data) 115 | else: 116 | continue 117 | 118 | if cur_message['data']: 119 | chat_history.append(cur_message) 120 | return chat_history 121 | 122 | 123 | @api.route('/') 124 | def index(): 125 | return render_template('index.html', url_prefix=url_prefix) 126 | 127 | 128 | @api.route('/random_conversation', methods=['GET']) 129 | def random_dialogue(): 130 | # Implement a function to fetch a random dialogue history from the server 131 | if examples: 132 | ex = random.choice(examples) 133 | for item in ex['conversation']: 134 | if item['image_list'] and not item['image_list'][0].startswith(image_remote_path): 135 | item['image_list'] = ["{}/{}".format(image_remote_path, m) for m in item['image_list']] 136 | else: 137 | ex = [{'type': 'text', 'value': "No examples on server"}] 138 | logging.info("Example: \n{}".format(json.dumps(ex))) 139 | return jsonify(ex) 140 | 141 | 142 | @api.route('/images/') 143 | def serve_image(filename): 144 | return send_from_directory(image_local_dir, filename) 145 | 146 | 147 | def prepare_model_input(chat_history): 148 | data = {"conversation": []} 149 | for msg in chat_history: 150 | mesg_data, image_list = [], [] 151 | for it in msg['data']: 152 | if it['type'] == "text": 153 | mesg_data.append(it['value']) 154 | elif it['type'] == "image": 155 | mesg_data.append("") 156 | image_list.append(it['value'].split("/")[-1]) 157 | # mesg_data = " ".join(mesg_data) 158 | mesg_data = [it for it in mesg_data if it] 159 | mesg_data_str = mesg_data[0] 160 | for m in mesg_data[1:]: 161 | if m[0] in PUNC: 162 | mesg_data_str += m 163 | else: 164 | mesg_data_str += " " + m 165 | data["conversation"].append( 166 | {'role': "user" if msg['role'] == "user" else "assistant", "content": mesg_data_str, "image_list": image_list, "caption_list": []} 167 | ) 168 | return data 169 | 170 | 171 | def count_images_and_words(data): 172 | n_words = 0 173 | n_images = 0 174 | for turn in data['conversation']: 175 | n_images += len(turn['image_list']) 176 | n_words += len(turn['content'].split()) 177 | return n_words, n_images 178 | 179 | 180 | def split_model_output(gen_text, use_image_id): 181 | if not use_image_id: 182 | gen_text_splits = gen_text.split("") 183 | return gen_text_splits 184 | else: 185 | gen_text_splits, 
image_order = [], [] 186 | gen_text_words = gen_text.split() 187 | cache = [] 188 | for wi, w in enumerate(gen_text_words): 189 | if w.startswith("") 204 | if "" not in gen_text: 205 | gen_text_splits, image_order = split_model_output(gen_text, True) 206 | gen_imgs = [gen_imgs[i] for i in image_order] 207 | else: 208 | gen_text_splits = split_model_output(gen_text, False) 209 | n_splits = len(gen_text_splits) 210 | assert len(gen_imgs) == (n_splits-1) 211 | for i in range(n_splits): 212 | response.append( 213 | {"type": "text", "value": gen_text_splits[i].replace("", "").strip()} 214 | ) 215 | if i != (n_splits-1): 216 | response.append( 217 | {"type": "image", "value": "{}/{}".format(image_remote_path, gen_imgs[i])} 218 | ) 219 | return response 220 | 221 | @api.route('/engine-list', methods=['GET']) 222 | def engine_list(): 223 | global engines 224 | if not engines: 225 | el = [ 226 | {'id': '---', 'name': '---'}, 227 | ] 228 | else: 229 | el = engines 230 | return jsonify(el) 231 | 232 | @api.route('/chat', methods=['POST']) 233 | def chat(): 234 | data = request.get_json() 235 | user_input = data.get('user_input', '') 236 | nlp_engine = data.get('nlp_engine', 's2') 237 | # logging.info("Selected Engine: {}".format(nlp_engine)) 238 | chat_history_html = data.get('chat_history', '') 239 | chat_history = parse_chat_history(chat_history_html) if chat_history_html else [] 240 | 241 | # try: 242 | if user_input: 243 | # Process the user_input with your NLP engine here and get the response 244 | model_input = prepare_model_input(chat_history) 245 | logging.info("Model Input: \n{}".format(json.dumps(model_input))) 246 | n_words, n_images = count_images_and_words(model_input) 247 | logging.info("{} words and {} images in conversation".format(n_words, n_images)) 248 | if n_words >= args.safe_word_num or n_images >= args.safe_image_num: 249 | response_items = [{'type': 'text', 'value': "I'm sorry that I may not be able to continue this conversation, due to my limited GPU memory. The admin set the safe number of words and images in conversation to {} and {}, respectively. This strategy is to avoid the core dump of the GPU. If you want to experience longer conversation, please run our model and code on your more powerful GPUs!".format(args.safe_word_num, args.safe_image_num)}] 250 | else: 251 | model_output = run_inference(model_input, nlp_engine) 252 | logging.info("Model Response: \n{}".format(json.dumps(model_output))) 253 | response_items = parse_model_output(model_output) 254 | else: 255 | response_items = [{'type': 'text', 'value': "Please input some text or images."}] 256 | # except Exception as e: 257 | # logging.info("Error message:\ne") 258 | # response_items = [{'type': 'text', 'value': "Our server met some errors... 
Contact us to report the bug."}] 259 | 260 | return jsonify({'response': response_items}) 261 | 262 | 263 | app.register_blueprint(api) 264 | 265 | 266 | if __name__ == '__main__': 267 | app.run(debug=False, host='0.0.0.0', port=args.port) 268 | -------------------------------------------------------------------------------- /common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import yaml 5 | import os.path as path 6 | import tarfile 7 | from glob import glob 8 | from io import BytesIO 9 | import csv 10 | import multiprocessing as mp 11 | from PIL import Image, ImageOps 12 | import re 13 | import math 14 | import logging 15 | import constaints as C 16 | 17 | logger = logging.getLogger() 18 | logger.setLevel(logging.INFO) 19 | handler = logging.StreamHandler() 20 | handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 21 | handler.setLevel(logging.INFO) 22 | # Add auto-flushing to the StreamHandler 23 | handler.terminator = '\n' 24 | handler.flush = lambda: handler.stream.flush() 25 | logger.addHandler(handler) 26 | 27 | 28 | 29 | class FileType: 30 | PT = "pt" 31 | TXT = "txt" 32 | JSON = "json" 33 | TSV = "tsv" 34 | TAR = "tar" 35 | YAML = "yaml" 36 | CSV = "csv" 37 | ALL = ["pt", "txt", "json", "tsv", "tar", "yaml", "csv"] 38 | 39 | 40 | class FileExtensionType: 41 | ADD = "add" 42 | CHANGE = "change" 43 | 44 | 45 | LANGUAGE_LIST = ["en", "de", "zh", "vi", "fr"] 46 | 47 | 48 | class PrepUtils: 49 | 50 | @staticmethod 51 | def extract_ngrams(sent, n, return_index=False): 52 | if isinstance(sent, str): 53 | sent = sent.split() 54 | ngrams = [tuple(sent[i:i + n]) for i in range(len(sent) - n + 1)] 55 | if not return_index: 56 | return ngrams 57 | else: 58 | ids = list(range(len(sent))) 59 | ngram_ids = [tuple(ids[i:i + n]) for i in range(len(ids) - n + 1)] 60 | return ngrams, ngram_ids 61 | 62 | @staticmethod 63 | def resize_and_pad(image, size=(512, 512)): 64 | aspect_ratio = float(image.width) / float(image.height) 65 | 66 | if aspect_ratio > 1: 67 | new_width = size[0] 68 | new_height = int(size[0] / aspect_ratio) 69 | else: 70 | new_height = size[1] 71 | new_width = int(size[1] * aspect_ratio) 72 | 73 | resized_image = image.resize((new_width, new_height), Image.LANCZOS) 74 | padded_image = ImageOps.expand(resized_image, ( 75 | (size[0] - new_width) // 2, 76 | (size[1] - new_height) // 2, 77 | (size[0] - new_width + 1) // 2, 78 | (size[1] - new_height + 1) // 2), fill='white') 79 | 80 | return padded_image 81 | 82 | @staticmethod 83 | def gather_image_data(corpus): 84 | img_dict = dict() 85 | for data in corpus: 86 | for idx, url, cap in zip(data["image_idx"], data["url"], data['image']): 87 | cap = PrepUtils.clean_tag(cap) 88 | if idx not in img_dict: 89 | img_dict[idx] = (url, cap) 90 | ret = [{"image": "{}.png".format(idx), "url": v[0], "caption": v[1]} for idx, v in img_dict.items()] 91 | return ret 92 | 93 | @staticmethod 94 | def check_image_file(image_path): 95 | if not FileUtils.exists(image_path): 96 | return False 97 | else: 98 | try: 99 | _ = Image.open(image_path) 100 | return True 101 | except Exception: 102 | logger.info("Find damaged image {}".format(image_path)) 103 | return False 104 | 105 | @staticmethod 106 | def check_image_list(image_dir, image_path_list): 107 | for image_path in image_path_list: 108 | if not PrepUtils.check_image_file("{}/{}".format(image_dir, image_path)): 109 | return False 110 | return True 111 | 112 | @staticmethod 113 
| def edit_distance(s1, s2): 114 | m, n = len(s1), len(s2) 115 | dp = [[0 for _ in range(n + 1)] for _ in range(m + 1)] 116 | 117 | for i in range(m + 1): 118 | for j in range(n + 1): 119 | if i == 0: 120 | dp[i][j] = j 121 | elif j == 0: 122 | dp[i][j] = i 123 | elif s1[i - 1] == s2[j - 1]: 124 | dp[i][j] = dp[i - 1][j - 1] 125 | else: 126 | dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) 127 | 128 | return dp[m][n] 129 | 130 | @staticmethod 131 | def extract_text(text, x): 132 | pattern = f".*?" 133 | match = re.search(pattern, text) 134 | if match: 135 | original_text = match.group(0)[len(f""):-len(f"")].strip() 136 | else: 137 | original_text = None 138 | return original_text 139 | 140 | @staticmethod 141 | def extract_idx(text): 142 | image_idx_list = [] 143 | matches = re.findall(r'.*?', text) 144 | for match in matches: 145 | image_idx = re.findall(r"", match)[0] 146 | image_idx_list.append(int(image_idx)) 147 | return image_idx_list 148 | 149 | @staticmethod 150 | def sub_image_tag(text): 151 | return re.sub(r'.*?', "", text) 152 | 153 | @staticmethod 154 | def split_turn(dialogue): 155 | start_index = dialogue.index(C.HUMAN) 156 | sents = dialogue[start_index:].strip().split("\n") 157 | turns = [] 158 | for sent in sents: 159 | sent = sent.strip() 160 | if sent.startswith(C.HUMAN): 161 | turns.append(sent) 162 | elif sent.startswith(C.ASSISTANT): 163 | turns.append(sent) 164 | else: 165 | if sent: 166 | turns[-1] += "\n" + sent 167 | return turns 168 | 169 | @staticmethod 170 | def is_valid_turn(turns): 171 | for i, t in enumerate(turns): 172 | if C.HUMAN in t and C.ASSISTANT in t: 173 | return False 174 | if t.startswith(C.HUMAN) and (i % 2) == 0: 175 | continue 176 | if t.startswith(C.ASSISTANT) and (i % 2) == 1: 177 | continue 178 | return False 179 | return True 180 | 181 | @staticmethod 182 | def has_repeated_images(ex): 183 | img_tags = ["".format(i) for i in range(len(ex['image']))] 184 | response = ex['response'] 185 | for it in img_tags: 186 | n = len(StringUtils.find_all_indices(response, it)) 187 | if n > 1: 188 | return True 189 | return False 190 | 191 | @staticmethod 192 | def has_unseen_image(ex): 193 | n_image = len(ex['image']) 194 | pattern = r"" 195 | img_ids = [int(it) for it in re.findall(pattern, ex['response'])] 196 | for i in img_ids: 197 | if i >= n_image: 198 | return True 199 | return False 200 | 201 | @staticmethod 202 | def has_image(ex): 203 | pattern = r'.*?' 
204 | return bool(re.search(pattern, ex['response'])) 205 | 206 | @staticmethod 207 | def remove_non_paired_img_tag(text): 208 | img_tags = re.findall(r'', text) 209 | # Iterate through the tags and remove them if there is no paired closing tag 210 | for tag in img_tags: 211 | closing_tag = tag.replace('<', '', '', re.sub('', '', s)).strip() 220 | 221 | 222 | class StringUtils: 223 | @staticmethod 224 | def get_digit_num(n): 225 | return math.ceil(math.log10(n + 1)) 226 | 227 | @staticmethod 228 | def format_number(n, dn=None): 229 | if dn is None: 230 | dn = StringUtils.get_digit_num(n) + 3 231 | return "{:0{}}".format(n, dn) 232 | 233 | @staticmethod 234 | def camel_to_snake(name): 235 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) 236 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() 237 | 238 | @staticmethod 239 | def find_all_indices(text, substring): 240 | indices = [] 241 | index = -1 242 | while True: 243 | try: 244 | index = text.find(substring, index + 1) 245 | except ValueError: 246 | pass 247 | if index == -1: 248 | break 249 | indices.append(index) 250 | return indices 251 | 252 | 253 | class MPUtils: 254 | @staticmethod 255 | def prepare_shards(data, nproc): 256 | if (len(data) % nproc) == 0: 257 | ss = len(data) // nproc 258 | else: 259 | ss = len(data) // nproc + 1 260 | 261 | shards = [data[i*ss:(i+1)*ss] for i in range(nproc)] 262 | return shards 263 | 264 | 265 | @staticmethod 266 | def mp_func(target_func, args_list): 267 | nporc = len(args_list) 268 | processes = [] 269 | for i in range(nporc): 270 | proc = mp.Process(target=target_func, args=args_list[i]) 271 | proc.start() 272 | processes.append(proc) 273 | logging.info("Start process {}".format(i)) 274 | 275 | logging.info("Waiting for the finish of all processes") 276 | for proc in processes: 277 | proc.join() 278 | 279 | 280 | class FileUtils: 281 | @staticmethod 282 | def exists(file_path): 283 | return path.exists(file_path) 284 | 285 | @staticmethod 286 | def rename(source_fname, target_fname): 287 | if os.path.exists(source_fname): 288 | # Rename the file 289 | os.rename(source_fname, target_fname) 290 | logging.info("File renamed from {} to {}".format(source_fname, target_fname)) 291 | else: 292 | logging("The file {} does not exist".format(source_fname)) 293 | 294 | @staticmethod 295 | def is_dir(file_path): 296 | return path.isdir(file_path) 297 | 298 | @staticmethod 299 | def get_last_path(path): 300 | if not path: 301 | return path 302 | parent, last_path = os.path.split(path) 303 | if last_path: 304 | return last_path 305 | else: 306 | return FileUtils.get_last_path(parent) 307 | 308 | @staticmethod 309 | def get_dir(path): 310 | return os.path.dirname(path) 311 | 312 | @staticmethod 313 | def check_dirs(dir_path): 314 | if path.exists(dir_path): 315 | logging.info("{} already exists".format(dir_path)) 316 | else: 317 | logging.info("Making new directory {}".format(dir_path)) 318 | os.makedirs(dir_path) 319 | 320 | @staticmethod 321 | def check_basename(fpath): 322 | bname = os.path.basename(fpath) 323 | parts = bname.split(".") 324 | if len(parts) <= 1: 325 | return bname 326 | elif parts[-1] in LANGUAGE_LIST or parts[-1] in FileType.ALL: 327 | return ".".join(parts[:-1]) 328 | else: 329 | return bname 330 | 331 | @staticmethod 332 | def check_file_type(fpath): 333 | parts = fpath.split(".") 334 | ext = "" 335 | if parts: 336 | ext = parts[-1] 337 | return ext 338 | 339 | @staticmethod 340 | def data_iterator(file_pattern, file_type=None, shard_size=0): 341 | fpath_list = 
sorted(list(glob(file_pattern))) 342 | if fpath_list: 343 | logging.info("Files will be loaded in the following order:\n{}".format( 344 | "\n".join(fpath_list) 345 | )) 346 | else: 347 | logging.warning("No file found given this pattern: {}".format(file_pattern)) 348 | shard_data = [] 349 | for fpath in fpath_list: 350 | logging.info("Start to process {}".format(fpath)) 351 | loaded_data = FileUtils.load_file(fpath, file_type) 352 | if shard_size > 0: 353 | shard_data += loaded_data 354 | while len(shard_data) >= shard_size: 355 | yield shard_data[:shard_size] 356 | shard_data = shard_data[shard_size:] 357 | else: 358 | for d in loaded_data: 359 | yield d 360 | if shard_size > 0 and shard_data: 361 | yield shard_data 362 | 363 | @staticmethod 364 | def load_from_disk(fpath, file_tyle=None): 365 | return FileUtils.load_file(fpath, file_tyle) 366 | 367 | @staticmethod 368 | def load_file(fpath, file_type=None): 369 | if file_type is None: 370 | file_type = FileUtils.check_file_type(fpath) 371 | if file_type == FileType.TXT: 372 | data = [] 373 | with open(fpath, 'r') as fin: 374 | for line in fin: 375 | data.append(line.strip()) 376 | elif file_type == FileType.PT: 377 | data = torch.load(fpath) 378 | elif file_type == FileType.JSON: 379 | with open(fpath, 'r') as fin: 380 | data = json.load(fin) 381 | elif file_type == FileType.TSV: 382 | data = [] 383 | with open(fpath, 'r') as fin: 384 | for line in fin: 385 | data.append(line.strip().split("\t")) 386 | elif file_type == FileType.CSV: 387 | data = dict() 388 | with open(fpath) as fin: 389 | reader = csv.DictReader(fin) 390 | col_names = reader.fieldnames 391 | data = [col_names] 392 | for row in reader: 393 | data.append([row[cn] for cn in col_names]) 394 | elif file_type == FileType.TAR: 395 | from PIL import Image 396 | # NOTE: this is only for the caption & img data 397 | data = [] 398 | with tarfile.open(fpath, "r") as tar: 399 | txt_files = sorted([f.name for f in tar.getmembers() if f.name.endswith('.txt')]) 400 | jpg_files = [FileUtils.handle_file_extension(it, 'jpg', 'change', True) for it in txt_files] 401 | for tf, jf in zip(txt_files, jpg_files): 402 | txt_obj = tar.extractfile(tf) 403 | jpg_obj = tar.extractfile(jf) 404 | txt_data = txt_obj.read().decode("utf-8") 405 | jpg_data = Image.open(BytesIO(jpg_obj.read())) 406 | data.append((txt_data, jpg_data)) 407 | elif file_type == FileType.YAML: 408 | with open(fpath, 'r') as file: 409 | data = yaml.load(file, Loader=yaml.FullLoader) 410 | else: 411 | logging.warning("Unknown loading file type: {}".format(file_type)) 412 | if file_type in LANGUAGE_LIST: 413 | data = [] 414 | logging.info("Treat file with language suffix {} by txt".format(file_type)) 415 | with open(fpath, 'r') as fin: 416 | for line in fin: 417 | data.append(line.strip()) 418 | else: 419 | data = torch.load(fpath) 420 | logging.info("Loaded file from {}".format(fpath)) 421 | return data 422 | 423 | @staticmethod 424 | def save_file(data, fpath, file_type=None): 425 | FileUtils.save_to_disk(data, fpath, file_type) 426 | 427 | @staticmethod 428 | def save_to_disk(data, fpath, file_type=None): 429 | if file_type is None: 430 | file_type = FileUtils.check_file_type(fpath) 431 | 432 | if file_type == FileType.TXT: 433 | with open(fpath, 'w') as fout: 434 | for line in data: 435 | fout.write("{}\n".format(line.strip())) 436 | elif file_type == FileType.PT: 437 | torch.save(data, fpath) 438 | elif file_type == FileType.JSON: 439 | with open(fpath, 'w') as fout: 440 | json.dump(data, fout, indent="\t") 441 | elif 
file_type == FileType.TSV: 442 | with open(fpath, 'w') as fout: 443 | for it in data: 444 | fout.write("{}\n".format("\t".join(it).strip())) 445 | elif file_type == FileType.YAML: 446 | with open(fpath, 'w') as fout: 447 | yaml.dump(data, fout) 448 | elif file_type == FileType.CSV: 449 | with open(fpath, 'w') as fout: 450 | writer = csv.DictWriter(fout) 451 | for row in data: 452 | writer.writerow(row) 453 | else: 454 | logging.warning("Unknown saving file type: {}".format(file_type)) 455 | if file_type in LANGUAGE_LIST: 456 | logging.info("Treat file with language suffix {} by txt".format(file_type)) 457 | with open(fpath, 'w') as fout: 458 | for line in data: 459 | fout.write("{}\n".format(line.strip())) 460 | else: 461 | torch.save(data, fpath) 462 | logging.info("Save file to {}".format(fpath)) 463 | 464 | @staticmethod 465 | def handle_file_extension(file_path, new_extension, type=FileExtensionType.ADD, only_return_basename=False): 466 | from pathlib import Path 467 | # Ensure the new extension starts with a dot 468 | if not new_extension.startswith("."): 469 | new_extension = f".{new_extension}" 470 | file = Path(file_path) 471 | if type == FileExtensionType.CHANGE: 472 | new_file_name = f"{file.parent}/{file.stem}{new_extension}" 473 | elif type == FileExtensionType.ADD: 474 | new_file_name = f"{file.parent}/{file.stem}{new_extension}{file.suffix}" 475 | if only_return_basename: 476 | return os.path.basename(new_file_name) 477 | else: 478 | return new_file_name -------------------------------------------------------------------------------- /constaints.py: -------------------------------------------------------------------------------- 1 | HEADERS = { 2 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.82 Safari/537.36', 3 | 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9', 4 | 'Accept-Language': 'en-US,en;q=0.9', 5 | 'Accept-Encoding': 'gzip, deflate, br', 6 | 'Connection': 'keep-alive', 7 | 'Upgrade-Insecure-Requests': '1', 8 | } 9 | ASSISTANT = "Assistant:" 10 | HUMAN = "Human:" -------------------------------------------------------------------------------- /data/examples.json: -------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import sys 4 | import json 5 | import torch 6 | import random 7 | import logging 8 | import argparse 9 | import functools 10 | import transformers 11 | import numpy as np 12 | from copy import deepcopy 13 | from PIL import Image 14 | from io import BytesIO 15 | from functools import partial 16 | from transformers import LlamaTokenizer 17 | from torch.utils.data import Dataset 18 | from typing import Dict, Optional, Sequence, List, Tuple, Callable 19 | import webdataset as wds 20 | from transformers import Blip2Processor 21 | 22 | ImageCaptionTemplates = [ 23 | " Provide a short and precise description of the image displayed.", 24 | "Give a brief and accurate depiction of the picture shown. ", 25 | " Present a concise and clear summary of the image seen.", 26 | "Offer a short and straightforward representation of the image provided. 
", 27 | " Share a brief, yet informative account of the picture presented.", 28 | "Deliver a concise and comprehensible explanation of the image shown. ", 29 | " Express a succinct and clear narrative of the picture displayed.", 30 | "Convey a brief and unambiguous description of the image presented. ", 31 | " Present a short and coherent account of the picture shown.", 32 | "Provide a compact and lucid representation of the image displayed. ", 33 | " Give a brief and distinct explanation of the picture provided.", 34 | "Share a concise and easy-to-understand description of the image shown. ", 35 | " Offer a clear and to-the-point narrative of the picture presented.", 36 | "Express a brief and well-defined interpretation of the image provided. ", 37 | " Deliver a concise and intelligible account of the image displayed.", 38 | "Convey a short and easily grasped summary of the picture shown. ", 39 | " Present a succinct and clear-cut representation of the image presented.", 40 | "Provide a brief and sharp explanation of the picture provided. ", 41 | " Give a concise and articulate description of the image shown.", 42 | "Share a short and comprehensible narrative of the picture displayed. " 43 | ] 44 | 45 | ImageGenerationTemplates = [ 46 | [ 47 | "Now, I need you to creat an image for me, the main content is: {image_caption}, Thanks!", 48 | "My pleasure! Here you are ." 49 | ], 50 | [ 51 | "Could you create an image for me with this content: {image_caption}? Thanks a lot!", 52 | "Here's the image you requested: .", 53 | ], 54 | [ 55 | "I'd appreciate if you have an image with this description: {image_caption}.", 56 | "Happy to help! Here's your image: .", 57 | ], 58 | [ 59 | "Can you please create an image with the following caption: {image_caption}?", 60 | "Sure thing! Here's the image you asked for: .", 61 | ], 62 | [ 63 | "I was wondering if you could make an image based on this: {image_caption}.", 64 | "Absolutely! Here's the image you're looking for: .", 65 | ], 66 | [ 67 | "I need an image that represents this idea: {image_caption}. Can you help?", 68 | "Of course! Here's the image that fits your description: .", 69 | ], 70 | [ 71 | "Please create an image for me with this concept: {image_caption}.", 72 | "No problem! Here's the image you requested: .", 73 | ], 74 | [ 75 | "It would be great if you could make an image with this theme: {image_caption}.", 76 | "I'm happy to help! Here's the image based on your theme: .", 77 | ], 78 | [ 79 | "I'd love to see an image that captures this: {image_caption}.", 80 | "I've got you covered! Check out this image: .", 81 | ], 82 | [ 83 | "Can you come up with an image that has this content: {image_caption}?", 84 | "Sure! Here's an image with the content you described: .", 85 | ], 86 | [ 87 | "Please provide me with an image that showcases this: {image_caption}.", 88 | "Here's an image that showcases your request: .", 89 | ], 90 | [ 91 | "I'm looking for an image that depicts this: {image_caption}. Can you help?", 92 | "Certainly! Here's an image that depicts your request: .", 93 | ], 94 | [ 95 | "Could you generate an image based on this idea: {image_caption}?", 96 | "Here's an image generated based on your idea: .", 97 | ], 98 | [ 99 | "I'd be grateful if you could create an image with this subject: {image_caption}.", 100 | "I'm happy to assist! Here's the image with the subject you mentioned: .", 101 | ], 102 | [ 103 | "Can you design an image for me that includes this: {image_caption}?", 104 | "I'd be happy to! 
Here's the image that includes your request: .", 105 | ], 106 | [ 107 | "Please make an image for me that captures this essence: {image_caption}.", 108 | "Here's an image that captures the essence you described: .", 109 | ], 110 | [ 111 | "I need an image that conveys this message: {image_caption}. Can you create one?", 112 | "Sure! Here's an image that conveys the message you described: .", 113 | ], 114 | [ 115 | "I'd like to see an image that represents this concept: {image_caption}.", 116 | "Here's an image that represents the concept you mentioned: .", 117 | ], 118 | [ 119 | "Can you please come up with an image that illustrates this: {image_caption}?", 120 | "Here's an image that illustrates your request: .", 121 | ], 122 | [ 123 | "I'm looking for an image that embodies this idea: {image_caption}. Can you make one?", 124 | "Of course! Here's an image that embodies the idea you described: .", 125 | ], 126 | ] 127 | 128 | 129 | B_INST, E_INST = "[INST]", "[/INST]" 130 | B_SYS, E_SYS = "<>\n", "\n<>\n\n" 131 | 132 | SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"] 133 | 134 | def preprocess_llama(tokenizer, dialog, training=True): 135 | input_ids = [] 136 | labels = [] 137 | input_images = [] 138 | output_image_list = [] 139 | output_caption_list = [] 140 | 141 | unsafe = any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) 142 | if dialog[0]["role"] == "system": 143 | dialog = [ 144 | { 145 | "role": dialog[1]["role"], 146 | "content": B_SYS 147 | + dialog[0]["content"] 148 | + E_SYS 149 | + dialog[1]["content"], 150 | "image_list": dialog[0]['image_list'] + dialog[1]['image_list'], 151 | } 152 | ] + dialog[2:] 153 | assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( 154 | [msg["role"] == "assistant" for msg in dialog[1::2]] 155 | ), ( 156 | "model only supports 'system', 'user' and 'assistant' roles, " 157 | "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" 158 | ) 159 | 160 | for prompt, answer in zip(dialog[::2], dialog[1::2]): 161 | prompt_ids = tokenizer.encode( 162 | f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", 163 | add_special_tokens=False 164 | ) 165 | answer_ids = tokenizer.encode( 166 | f"{(answer['content']).strip()} {tokenizer.eos_token}", 167 | add_special_tokens=False 168 | ) 169 | input_ids = input_ids + prompt_ids + answer_ids 170 | labels = labels + [-100] * len(prompt_ids) + answer_ids 171 | input_images = input_images + prompt['image_list'] + answer['image_list'] 172 | output_image_list = output_image_list + answer['image_list'] 173 | output_caption_list = output_caption_list + answer["caption_list"] 174 | 175 | if not training: 176 | assert ( 177 | dialog[-1]["role"] == "user" 178 | ), f"Last message must be from user, got {dialog[-1]['role']}" 179 | prompt_ids = tokenizer.encode( 180 | f"{tokenizer.bos_token}{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", 181 | add_special_tokens=False 182 | ) 183 | input_ids = input_ids + prompt_ids 184 | labels = labels + [-100] * len(prompt_ids) 185 | input_images = input_images + dialog[-1]['image_list'] 186 | 187 | return input_ids, labels, input_images, output_image_list, output_caption_list 188 | 189 | def preprocess_mim(data, tokenizer, image_processor, args, training): 190 | 191 | # setting placeholder for image tokens 192 | # when generating , it means that we need to generate an image 193 | # num_query_tokens is the number of image tokens 194 | image_placeholder = "".join([""] * args.num_query_tokens) 195 | 196 | for turn in 
data["conversation"]: 197 | turn['content'] = turn['content'].replace("", image_placeholder) 198 | input_ids, labels, input_images, output_image_list, output_caption_list = preprocess_llama(tokenizer, data["conversation"], training) 199 | input_image_index = [idx for idx, value in enumerate(input_ids) if value == args.image_token_id] 200 | 201 | # image processing for both input image and output image; 202 | # for mim dataset, we need to load the image from disk 203 | if len(input_images) > 0 and isinstance(input_images[0], str): 204 | input_images = [ 205 | Image.open( 206 | os.path.join(data["image_dir"] if "image_dir" in data else args.image_dir, image) 207 | ).convert("RGB") 208 | for image in input_images 209 | ] 210 | 211 | input_images = [ 212 | image_processor( 213 | images=image, 214 | return_tensors='pt' 215 | )['pixel_values'][0] 216 | for image in input_images 217 | ] 218 | 219 | if not training: 220 | return { 221 | "input_ids": input_ids, 222 | "input_images": input_images if len(input_images) > 0 else None, 223 | "input_image_index": input_image_index if len(input_image_index) > 0 else None, 224 | } 225 | 226 | idx = 0 227 | seqlen = len(input_ids) 228 | attention_mask = [[1] * seqlen for _ in range(seqlen)] 229 | position_ids = list(range(seqlen)) 230 | cur_image_index = 0 231 | while idx < seqlen: 232 | # caption generation start 233 | if input_ids[idx] == args.image_token_id and labels[idx] != -100: 234 | 235 | caption = output_caption_list[cur_image_index] 236 | caption_ids = [args.caption_start_id] + tokenizer.encode(caption, add_special_tokens=False) + [args.caption_end_id] 237 | cur_image_index += 1 238 | 239 | input_ids.extend(caption_ids) 240 | labels[idx] = args.caption_start_id 241 | labels[idx+1:idx+args.num_query_tokens] = [-100] * (args.num_query_tokens-1) 242 | labels.extend([-100 if idx==0 else token_id for idx, token_id in enumerate(caption_ids)]) 243 | attention_mask.extend([[1] * idx + [0] * (seqlen-idx)] * len(caption_ids)) 244 | position_ids.extend(list(range(idx, idx+len(caption_ids)))) 245 | 246 | idx += args.num_query_tokens 247 | else: 248 | idx += 1 249 | 250 | if len(input_ids) > args.max_input_length: 251 | 252 | # do not truncate in the image tokens 253 | end_position = args.max_input_length 254 | while input_ids[end_position] == args.image_token_id: 255 | end_position += 1 256 | 257 | input_ids = input_ids[:end_position] 258 | labels = labels[:end_position] 259 | attention_mask = [mask[:end_position] for mask in attention_mask[:end_position]] 260 | position_ids = position_ids[:end_position] 261 | 262 | input_images = input_images[:(input_ids.count(args.image_token_id) // args.num_query_tokens)] 263 | input_image_index = [idx for idx, value in enumerate(input_ids) if value == args.image_token_id] 264 | 265 | assert len(input_ids) == len(labels) == len(attention_mask) == len(position_ids) 266 | 267 | return { 268 | "input_ids": input_ids, 269 | "labels": labels, 270 | "attention_mask": attention_mask, 271 | "position_ids": position_ids, 272 | "input_images": input_images, 273 | "input_image_index": input_image_index, 274 | } 275 | 276 | def collate_fn(batch: List[Dict[str, torch.Tensor]], args) -> Dict[str, torch.Tensor]: 277 | 278 | input_ids = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(x["input_ids"]) for x in batch], batch_first=True, padding_value=0) 279 | labels = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(x["labels"]) for x in batch], batch_first=True, padding_value=-100) 280 | position_ids = 
torch.nn.utils.rnn.pad_sequence([torch.LongTensor(x["position_ids"]) for x in batch], batch_first=True, padding_value=0) 281 | 282 | bsz, seqlen = input_ids.shape 283 | 284 | attention_mask = torch.ones(bsz, seqlen, seqlen, dtype=torch.int64) 285 | for i, x in enumerate(batch): 286 | for j, mask in enumerate(x["attention_mask"]): 287 | attention_mask[i, j, : len(mask)] = torch.LongTensor(mask) 288 | 289 | # since we really do not care the last token 290 | input_image_index = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(x["input_image_index"]) for x in batch], batch_first=True, padding_value=seqlen-1) 291 | if input_image_index.shape[-1] == 0: 292 | input_image_index = torch.nn.utils.rnn.pad_sequence([torch.LongTensor([seqlen-1] * args.num_query_tokens) for x in batch], batch_first=True, padding_value=seqlen-1) 293 | 294 | max_input_images = max(1, max( len(x["input_images"]) for x in batch)) 295 | input_images = torch.zeros(bsz, max_input_images, 3, 224, 224) 296 | for i, x in enumerate(batch): 297 | for j, image in enumerate(x["input_images"]): 298 | input_images[i, j] = image 299 | 300 | return { 301 | "input_ids": input_ids, 302 | "labels": labels, 303 | "attention_mask": attention_mask, 304 | "position_ids": position_ids, 305 | "input_images": input_images, 306 | "input_image_index": input_image_index, 307 | } 308 | 309 | 310 | class MIMDataset(Dataset): 311 | def __init__(self, args: argparse.Namespace, 312 | data_path: str, 313 | tokenizer: transformers.PreTrainedTokenizer, 314 | image_processor: transformers.CLIPImageProcessor): 315 | super(MIMDataset, self).__init__() 316 | list_data_dict = json.load(open(data_path, "r")) 317 | 318 | self.args = args 319 | self.tokenizer = tokenizer 320 | self.image_processor = image_processor 321 | self.list_data_dict = list_data_dict 322 | 323 | def __len__(self): 324 | return len(self.list_data_dict) 325 | 326 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 327 | data = deepcopy(self.list_data_dict[i]) 328 | return preprocess_mim(data, self.tokenizer, self.image_processor, self.args, True) 329 | 330 | 331 | def webdataset_map(example, args, image_processor, tokenizer): 332 | 333 | image = Image.open(BytesIO(example["image"])).convert("RGB") 334 | caption = example["caption"].decode("utf-8") 335 | key = json.loads(example["meta"])["key"] 336 | 337 | template = random.choice(ImageCaptionTemplates) 338 | 339 | data = { 340 | "conversation": [ 341 | { 342 | "role": "user", 343 | "content": template, 344 | "image_list": [image], 345 | "caption_list": [caption], 346 | }, 347 | { 348 | "role": "assistant", 349 | "content": caption, 350 | "image_list": [], 351 | "caption_list": [], 352 | } 353 | ] 354 | } 355 | 356 | return preprocess_mim(data, tokenizer, image_processor, args, True) 357 | 358 | def load_mim_dataset(args: argparse.Namespace, 359 | tokenizer: transformers.PreTrainedTokenizer, 360 | image_processor: transformers.CLIPImageProcessor, 361 | ) -> Tuple[torch.utils.data.Dataset, Callable]: 362 | 363 | dataset = MIMDataset(args, (args.train_data_path if args.train else args.val_data_path), tokenizer, image_processor) 364 | 365 | return dataset, collate_fn 366 | 367 | def load_pair_dataset(args: argparse.Namespace, 368 | tokenizer: transformers.PreTrainedTokenizer, 369 | image_processor: transformers.CLIPImageProcessor, 370 | ) -> Tuple[torch.utils.data.Dataset, Callable]: 371 | """Load dataset.""" 372 | 373 | if args.train: 374 | dataset = wds.WebDataset(args.train_data_path, resampled=True) \ 375 | .shuffle(1000) \ 376 | 
.rename(image="jpg;png", meta="json", caption="txt") \ 377 | .map(functools.partial(webdataset_map, args=args, tokenizer=tokenizer, image_processor=image_processor)) \ 378 | .with_epoch(args.with_epoch // (args.world_size * args.with_num_works) ) 379 | else: 380 | dataset = wds.WebDataset(args.val_data_path) \ 381 | .rename(image="jpg;png", meta="json", caption="txt") \ 382 | .map(functools.partial(webdataset_map, args=args, tokenizer=tokenizer, image_processor=image_processor)) 383 | 384 | return dataset, collate_fn 385 | 386 | if __name__ == "__main__": 387 | 388 | class Dummy: 389 | def __call__(self, images, return_tensors): 390 | return {'pixel_values': [torch.zeros(3, 224, 224)]} 391 | image_processor = Dummy() 392 | 393 | args = argparse.Namespace(num_query_tokens=32, max_num_images=5, language_model="../../CKPT/meta-llama/Llama-2-7b-chat-hf", train=True) 394 | tokenizer = LlamaTokenizer.from_pretrained(args.language_model) 395 | add_tokens = ["", "", ""] 396 | tokenizer.add_special_tokens(({"additional_special_tokens": add_tokens})) 397 | args.image_token_id = tokenizer.convert_tokens_to_ids("") 398 | args.caption_start_id = tokenizer.convert_tokens_to_ids("") 399 | args.caption_end_id = tokenizer.convert_tokens_to_ids("") 400 | args.num_new_tokens = len(add_tokens) 401 | # check mim dataset 402 | args.train_data_path = "./data/multi_instruct_processed_1000_54281.json" 403 | args.num_prompt_tokens = 8 404 | args.max_input_length = 768 405 | dataset, collate_fn = load_mim_dataset(args, tokenizer, image_processor) 406 | 407 | # data = dataset[1] 408 | # print (data["input_ids"]) 409 | # print (tokenizer.decode(data['input_ids'])) 410 | # print () 411 | # print (tokenizer.decode( [x if x >0 else tokenizer.unk_token_id for x in data['labels'] ]) ) 412 | 413 | # print (data["position_ids"]) 414 | # exit() 415 | 416 | corpus = json.load(open(args.train_data_path, "r")) 417 | print("Number of data:", len(corpus)) 418 | data = dataset[0] 419 | 420 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=8, collate_fn=partial(collate_fn, args=args), shuffle=True) 421 | print("Number of batch:", len(dataloader)) 422 | length = [] 423 | total_cnt = 0 424 | large_cnt = 0 425 | for idx, batch in enumerate(dataloader): 426 | length.append(batch["input_ids"].shape[-1]) 427 | print("Average length: ", np.mean(length)) 428 | 429 | total_cnt += 1 430 | if length[-1] > 600: 431 | large_cnt += 1 432 | print(total_cnt) 433 | print(large_cnt) 434 | 435 | # args.train_data_path = "./data/minigpt4/cc_sbu/cc_sbu_dataset/{00000..01254}.tar" 436 | # args.with_epoch = 1000000 437 | # args.sd_latents_dir = "./data/sd2_prompt_embeds" 438 | # args.max_input_length = 512 439 | # args.num_prompt_tokens = 8 440 | # dataset, collate_fn = load_pair_dataset(args, tokenizer, image_processor) 441 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=4, collate_fn=partial(collate_fn, args=args)) 442 | # total_num = 0 443 | # large_than_256 = 0 444 | # for batch in dataloader: 445 | # total_num += 1 446 | # if batch["input_ids"].shape[1] > 256: 447 | # large_than_256 += 1 448 | # print("Total: {}, Large than 256: {}".format(total_num, large_than_256)) -------------------------------------------------------------------------------- /ds_config_first_stage.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 16, 3 | "gradient_accumulation_steps": 2, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 
1.0, 6 | "optimizer": { 7 | "type": "Adam", 8 | "params": { 9 | "lr": 1e-4, 10 | "betas": [0.9, 0.95], 11 | "eps": 1e-8, 12 | "weight_decay": 0.001 13 | } 14 | }, 15 | "zero_optimization": { 16 | "stage": 2, 17 | "allgather_partitions": true, 18 | "allgather_bucket_size": 7e8, 19 | "overlap_comm": true, 20 | "reduce_scatter": true, 21 | "reduce_bucket_size": 7e8, 22 | "contiguous_gradients": true 23 | }, 24 | "fp16": { 25 | "enabled": false 26 | }, 27 | "bf16": { 28 | "enabled": true 29 | }, 30 | "activation_checkpointing": { 31 | "partition_activations": true, 32 | "cpu_checkpointing": true, 33 | "contiguous_memory_optimization": false, 34 | "number_checkpoints": null, 35 | "synchronize_checkpoint_boundary": false, 36 | "profile": false 37 | } 38 | } -------------------------------------------------------------------------------- /ds_config_second_stage.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "gradient_accumulation_steps": 8, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "optimizer": { 7 | "type": "Adam", 8 | "params": { 9 | "lr": 1e-5, 10 | "betas": [0.9, 0.95], 11 | "eps": 1e-8, 12 | "weight_decay": 0.001 13 | } 14 | }, 15 | "zero_optimization": { 16 | "stage": 2, 17 | "offload_optimizer": { 18 | "device": "cpu", 19 | "pin_memory": true, 20 | "buffer_count": 4, 21 | "fast_init": false 22 | }, 23 | "allgather_partitions": true, 24 | "allgather_bucket_size": 5e7, 25 | "overlap_comm": true, 26 | "reduce_scatter": true, 27 | "reduce_bucket_size": 5e7, 28 | "contiguous_gradients": true 29 | }, 30 | "fp16": { 31 | "enabled": false 32 | }, 33 | "bf16": { 34 | "enabled": true 35 | }, 36 | "activation_checkpointing": { 37 | "partition_activations": true, 38 | "cpu_checkpointing": true, 39 | "contiguous_memory_optimization": false, 40 | "number_checkpoints": null, 41 | "synchronize_checkpoint_boundary": false, 42 | "profile": false 43 | } 44 | } -------------------------------------------------------------------------------- /flash_attention.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | import math 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | from einops import rearrange 10 | from flash_attn.flash_attn_interface import ( # pip3 install "flash-attn>=2.0" 11 | flash_attn_varlen_qkvpacked_func, 12 | flash_attn_qkvpacked_func, 13 | flash_attn_func 14 | ) 15 | from flash_attn.bert_padding import unpad_input, pad_input 16 | 17 | def LlamaAttention_forward( 18 | self, 19 | hidden_states: torch.Tensor, 20 | attention_mask: Optional[torch.Tensor] = None, 21 | position_ids: Optional[torch.LongTensor] = None, 22 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 23 | output_attentions: bool = False, 24 | use_cache: bool = False, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | bsz, q_len, _ = hidden_states.size() 27 | 28 | if self.pretraining_tp > 1: 29 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp 30 | query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) 31 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 32 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 33 | 34 | query_states = [F.linear(hidden_states, 
query_slices[i]) for i in range(self.pretraining_tp)] 35 | query_states = torch.cat(query_states, dim=-1) 36 | 37 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] 38 | key_states = torch.cat(key_states, dim=-1) 39 | 40 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] 41 | value_states = torch.cat(value_states, dim=-1) 42 | 43 | else: 44 | query_states = self.q_proj(hidden_states) 45 | key_states = self.k_proj(hidden_states) 46 | value_states = self.v_proj(hidden_states) 47 | 48 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 49 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 50 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 51 | 52 | 53 | kv_seq_len = key_states.shape[-2] 54 | if past_key_value is not None: 55 | kv_seq_len += past_key_value[0].shape[-2] 56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 57 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 58 | 59 | if past_key_value is not None: 60 | # reuse k, v, self_attention 61 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 62 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 63 | 64 | past_key_value = (key_states, value_states) if use_cache else None 65 | 66 | # repeat k/v heads if n_kv_heads < n_heads 67 | key_states = repeat_kv(key_states, self.num_key_value_groups) 68 | value_states = repeat_kv(value_states, self.num_key_value_groups) 69 | 70 | # we need to set causal to True, if q_len == kv_seq_len, which means we are training. Note that causal is also True (because q_len == kv_seq_len) for the first step of generation and it is okay. 
71 | causal = q_len == kv_seq_len 72 | 73 | if causal: 74 | # transform the data into the format required by flash attention 75 | qkv = torch.stack( 76 | [query_states, key_states, value_states], dim=2 77 | ) # [bsz, nh, 3, q_len, hd] 78 | 79 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 80 | 81 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 82 | # the attention_mask should be the same as the key_padding_mask 83 | key_padding_mask = attention_mask 84 | nheads = qkv.shape[-2] 85 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 86 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 87 | x_unpad = rearrange( 88 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 89 | ) 90 | output_unpad = flash_attn_varlen_qkvpacked_func( 91 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 92 | ) 93 | attn_output = pad_input( 94 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 95 | ) 96 | else: 97 | attn_output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), dropout_p=0.0, softmax_scale=None, causal=False) 98 | attn_output = rearrange(attn_output, "b s h d -> b s (h d)") 99 | 100 | 101 | if self.pretraining_tp > 1: 102 | attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) 103 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) 104 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) 105 | else: 106 | attn_output = self.o_proj(attn_output) 107 | 108 | if not output_attentions: 109 | attn_weights = None 110 | 111 | return attn_output, attn_weights, past_key_value 112 | 113 | 114 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 115 | # requires the attention mask to be the same as the key_padding_mask 116 | def _prepare_decoder_attention_mask( 117 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 118 | ): 119 | # [bsz, seq_len] 120 | return attention_mask 121 | 122 | 123 | def replace_llama_attn_with_flash_attn(): 124 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 125 | _prepare_decoder_attention_mask 126 | ) 127 | transformers.models.llama.modeling_llama.LlamaAttention.forward = LlamaAttention_forward 128 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import copy 5 | import json 6 | import uuid 7 | import torch 8 | from PIL import Image 9 | from diffusers import StableDiffusionPipeline, DiffusionPipeline 10 | from transformers import LlamaTokenizer, Blip2Processor 11 | 12 | from model import MIMModel 13 | from data_utils import preprocess_mim 14 | from utils import parse_args, build_model_and_processor 15 | 16 | def postprocesss(s, args): 17 | 18 | s = s.replace("" * args.num_query_tokens, "") 19 | s = s.replace(" ".join([""] * args.num_query_tokens), "") 20 | s = s.replace(" ", " ") 21 | pattern = ".*?" 
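    # remove any remaining special markup matched by `pattern` so that only plain conversational
    # text is returned to the caller (the caption span emitted during generation is not shown)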
22 | s = re.sub(pattern, "", s) 23 | return s 24 | 25 | def inference( 26 | data, 27 | args, 28 | model, 29 | tokenizer, 30 | image_processor, 31 | sd_base, 32 | sd_refiner, 33 | generator, 34 | device, 35 | ): 36 | 37 | # process mim input 38 | inputs = preprocess_mim(data, tokenizer, image_processor, args, False) 39 | 40 | input_ids = torch.LongTensor(inputs["input_ids"]).unsqueeze(0).to(device) 41 | input_images = (None if inputs["input_images"] is None else torch.stack(inputs["input_images"], 0).unsqueeze(0).to(device)) 42 | input_image_index = (None if inputs["input_image_index"] is None else torch.LongTensor(inputs["input_image_index"]).unsqueeze(0).to(device)) 43 | 44 | # print("Using none cache generation") 45 | # time1 = time.time() 46 | # outputs = model.none_cache_generation( 47 | # input_ids=input_ids, 48 | # input_images=input_images, 49 | # input_image_index=input_image_index, 50 | # prompt_token_index=prompt_token_index, 51 | # image_start_id=args.image_start_id, 52 | # image_token_id=args.image_token_id, 53 | # tokenizer=tokenizer, 54 | # image_processor=image_processor, 55 | # sd_pipe=sd_pipe, 56 | # generator=generator, 57 | # generate_image=args.generate_image, 58 | # max_output_length=args.max_output_length, 59 | # ) 60 | # time2 = time.time() 61 | # print("Time: ", time2 - time1) 62 | 63 | # time1 = time.time() 64 | # print("Using cache generation") 65 | outputs = model.cache_generation( 66 | input_ids=input_ids, 67 | input_images=input_images, 68 | input_image_index=input_image_index, 69 | caption_start_id=args.caption_start_id, 70 | caption_end_id=args.caption_end_id, 71 | tokenizer=tokenizer, 72 | image_processor=image_processor, 73 | sd_base=sd_base, 74 | sd_refiner=sd_refiner, 75 | generator=generator, 76 | generate_image=args.generate_image, 77 | max_output_length=args.max_output_length, 78 | top_p=args.top_p, 79 | ) 80 | # time2 = time.time() 81 | # print("Time cost: ", time2 - time1) 82 | 83 | total_text = tokenizer.decode(outputs["sequences"][0].tolist(), skip_special_tokens=False) 84 | generation = total_text.split("[/INST]")[-1].strip() 85 | 86 | image_list = [] 87 | for image in outputs["image_list"]: 88 | name = uuid.uuid4() 89 | image.save(os.path.join(args.inference_dir, f"{name}.png")) 90 | image_list.append(f"{name}.png") 91 | 92 | data["conversation"].append( 93 | { 94 | "role": "assistant", 95 | "content": generation, 96 | "image_list": image_list, 97 | "caption_list": outputs["caption_list"], 98 | "tags": "generation" 99 | } 100 | ) 101 | 102 | for turn in data["conversation"]: 103 | turn['content'] = postprocesss(turn['content'], args) 104 | 105 | return data 106 | 107 | class MIMPipeline: 108 | def __init__(self, args, device): 109 | 110 | model, tokenizer, image_processor = build_model_and_processor(args) 111 | 112 | model = model.half().to(device) if args.fp16 else model.to(device) 113 | model.eval() 114 | 115 | if args.model_list: 116 | print("Loading engine list from {}...".format(args.model_list)) 117 | engines = json.load(open(args.model_list)) 118 | 119 | model_list = {} 120 | for m_info in engines: 121 | checkpoint = m_info['path'] 122 | model_id = m_info['id'] 123 | 124 | _model = copy.deepcopy(model) 125 | _model.load_state_dict(torch.load(checkpoint), strict=False) 126 | _model.eval() 127 | 128 | 129 | del _model.language_model 130 | del _model.vision_model 131 | 132 | model_list[model_id] = {} 133 | for module in args.save_modules: 134 | model_list[model_id][module] = getattr(_model, module).half().to(device) if args.fp16 else 
getattr(_model, module).to(device) 135 | 136 | print(f"Load checkpoint: {checkpoint}") 137 | else: 138 | print("Engine list was not provided...") 139 | model_list = {} 140 | engines = [] 141 | 142 | 143 | sd_base = DiffusionPipeline.from_pretrained( 144 | args.sd_base, torch_dtype=torch.float16, variant="fp16", use_safetensors=True 145 | ).to(device) 146 | 147 | sd_refiner = DiffusionPipeline.from_pretrained( 148 | args.sd_refiner, 149 | text_encoder_2=sd_base.text_encoder_2, 150 | vae=sd_base.vae, 151 | torch_dtype=torch.float16, 152 | use_safetensors=True, 153 | variant="fp16", 154 | ).to(device) 155 | 156 | generator = torch.Generator(device=device).manual_seed(42) 157 | 158 | if not os.path.exists(args.inference_dir): 159 | os.makedirs(args.inference_dir) 160 | 161 | self.args = args 162 | self.device = device 163 | self.model = model 164 | self.model_list = model_list 165 | self.tokenizer = tokenizer 166 | self.image_processor = image_processor 167 | self.sd_base = sd_base 168 | self.sd_refiner = sd_refiner 169 | self.generator = generator 170 | self.engines = engines 171 | 172 | def run(self, data, selection=None): 173 | if selection and len(self.model_list) > 0: 174 | for module in self.args.save_modules: 175 | setattr(self.model, module, self.model_list[selection][module]) 176 | print("Loaded module: %s from model %s" % (module, selection)) 177 | 178 | data = inference(data, self.args, self.model, self.tokenizer, self.image_processor, self.sd_base, self.sd_refiner, self.generator, self.device) 179 | 180 | return data 181 | 182 | def evaluate(args, device): 183 | agent = MIMPipeline(args, device) 184 | 185 | corpus = json.load(open(args.val_data_path, "r")) 186 | 187 | inference_results = [] 188 | for idx, data in enumerate(corpus): 189 | # if len(data["conversation"]) <= 5: 190 | # continue 191 | # data["conversation"] = data["conversation"][:5] 192 | 193 | data["conversation"] = data["conversation"][:-1] 194 | data = agent.run(data, None) 195 | 196 | # save the input image to the inference directory for checking 197 | for turn in data["conversation"][:-1]: 198 | for image_path in turn ["image_list"]: 199 | image = Image.open(os.path.join(data["image_dir"], image_path)).convert("RGB") 200 | image.save(os.path.join(args.inference_dir, image_path)) 201 | 202 | inference_results.append(data) 203 | print(data) 204 | print("=====================================") 205 | json.dump(inference_results, open(os.path.join(args.inference_dir, "inference_results.json"), "w"), indent=4) 206 | 207 | if __name__ == "__main__": 208 | 209 | args = parse_args() 210 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 211 | evaluate(args, device) 212 | -------------------------------------------------------------------------------- /introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SihengLi99/TextBind/c6254f1bcf7f971aeccd875906cf4135a6916df1/introduction.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import math 5 | import torch 6 | import wandb 7 | import shutil 8 | import argparse 9 | import deepspeed 10 | import numpy as np 11 | from tqdm import tqdm 12 | from functools import partial 13 | from torch.utils.data import DataLoader 14 | from transformers import LlamaTokenizer, Blip2Processor, get_cosine_schedule_with_warmup 15 
| from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint 16 | 17 | from model import MIMModel 18 | from data_utils import load_mim_dataset, load_pair_dataset 19 | from utils import parse_args, build_model_and_processor 20 | 21 | os.environ["WANDB_API_KEY"] = "85a3c5af1814c40a13d5d9e64783857cf260b506" 22 | os.environ["WANDB_MODE"] = "dryrun" 23 | 24 | def save_checkpoint(args, model_engine, checkpoint_dir): 25 | model_engine.save_checkpoint(checkpoint_dir) 26 | if args.local_rank in [-1, 0]: 27 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) 28 | state_dict = {key: value for key, value in state_dict.items() if key.split(".")[0] in args.save_modules} 29 | 30 | shutil.rmtree(checkpoint_dir) 31 | os.makedirs(checkpoint_dir, exist_ok=True) 32 | torch.save(state_dict, os.path.join(checkpoint_dir, 'pytorch_model.pt')) 33 | torch.distributed.barrier() 34 | 35 | def train(args): 36 | 37 | model, tokenizer, image_processor = build_model_and_processor(args) 38 | 39 | if not args.training_lm: 40 | for param in model.language_model.parameters(): 41 | param.requires_grad = False 42 | 43 | if not args.training_vm: 44 | for param in model.vision_model.parameters(): 45 | param.requires_grad = False 46 | 47 | args.world_size = int(os.environ["WORLD_SIZE"]) 48 | if args.stage == "first": 49 | train_dataset, collate_fn = load_pair_dataset(args=args, tokenizer=tokenizer, image_processor=image_processor) 50 | train_dataloader = DataLoader(train_dataset, batch_size=args.train_micro_batch_size_per_gpu, num_workers=args.with_num_works, collate_fn=partial(collate_fn, args=args), pin_memory=True) 51 | num_training_steps = args.num_epochs * args.with_epoch // ( args.world_size * args.train_micro_batch_size_per_gpu * args.gradient_accumulation_steps) 52 | lr_scheduler = partial(get_cosine_schedule_with_warmup, num_warmup_steps=args.warmup_steps, num_training_steps=num_training_steps) 53 | model_engine, _, _, _ = deepspeed.initialize(args=args, model=model, lr_scheduler=lr_scheduler) 54 | train_dataloader_len = args.with_epoch // (args.world_size * args.train_micro_batch_size_per_gpu) 55 | else: 56 | train_dataset, collate_fn = load_mim_dataset(args=args, tokenizer=tokenizer, image_processor=image_processor) 57 | num_training_steps = args.num_epochs * len(train_dataset) // (args.world_size * args.train_micro_batch_size_per_gpu * args.gradient_accumulation_steps) 58 | lr_scheduler = partial(get_cosine_schedule_with_warmup, num_warmup_steps=args.warmup_steps, num_training_steps=num_training_steps) 59 | model_engine, _, train_dataloader, _ = deepspeed.initialize(args=args, model=model, lr_scheduler=lr_scheduler, training_data=train_dataset, collate_fn=partial(collate_fn, args=args)) 60 | train_dataloader_len = len(train_dataloader) 61 | 62 | print("Total training steps: ", num_training_steps) 63 | 64 | # Training loop 65 | model_engine.train() 66 | current_step = 0 67 | total_instance = 0 68 | for epoch in tqdm(range(args.num_epochs), desc='Epoch', unit='epoch'): 69 | step_progress = tqdm(enumerate(train_dataloader), desc='Step', leave=False, unit='step', total=train_dataloader_len) 70 | for step, batch in step_progress: 71 | batch = {key: value.cuda() if torch.is_tensor(value) else value for key, value in batch.items()} 72 | 73 | bsz = batch["input_ids"].shape[0] 74 | total_instance += bsz 75 | 76 | # Compute Loss 77 | loss = model_engine(**batch) 78 | model_engine.backward(loss) 79 | model_engine.step() 80 | 81 | current_step += 1 82 | 
step_progress.set_description(f'Epoch {epoch} Step {current_step} - Loss: {loss:.4f}') 83 | wandb.log({"loss": loss}) 84 | 85 | if current_step % args.save_per_steps == 0: 86 | checkpoint_dir = os.path.join(args.save_dir, f'checkpoint_step{current_step // args.gradient_accumulation_steps}') 87 | save_checkpoint(args, model_engine, checkpoint_dir) 88 | print (f"{epoch+1} finished: {total_instance} instances") 89 | checkpoint_dir = os.path.join(args.save_dir, f'checkpoint_epoch{epoch+1}_step{current_step // args.gradient_accumulation_steps}') 90 | save_checkpoint(args, model_engine, checkpoint_dir) 91 | 92 | @torch.no_grad() 93 | def validation(args): 94 | 95 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 96 | 97 | model, tokenizer, image_processor = build_model_and_processor(args) 98 | model = model.half().to(device) if args.fp16 else model.to(device) 99 | model.eval() 100 | 101 | if args.stage == "first": 102 | val_dataset, collate_fn = load_pair_dataset(args=args, tokenizer=tokenizer, image_processor=image_processor) 103 | val_dataloader = DataLoader(val_dataset, batch_size=args.train_micro_batch_size_per_gpu, num_workers=args.with_num_works, collate_fn=partial(collate_fn, args=args), pin_memory=True, shuffle=False) 104 | val_dataloader_len = args.with_epoch // args.train_micro_batch_size_per_gpu 105 | else: 106 | val_dataset, collate_fn = load_mim_dataset(args=args, tokenizer=tokenizer, image_processor=image_processor) 107 | val_dataloader = DataLoader(val_dataset, batch_size=args.train_micro_batch_size_per_gpu, num_workers=args.with_num_works, collate_fn=partial(collate_fn, args=args), pin_memory=True, shuffle=False) 108 | val_dataloader_len = len(val_dataloader) 109 | 110 | loss = [] 111 | total_instance = 0 112 | step_progress = tqdm(enumerate(val_dataloader), desc='Step', leave=False, unit='step', total=val_dataloader_len) 113 | for step, batch in step_progress: 114 | batch = {key: value.cuda() if torch.is_tensor(value) else value for key, value in batch.items()} 115 | 116 | bsz = batch["input_ids"].shape[0] 117 | total_instance += bsz 118 | 119 | # Compute Loss 120 | _loss = model(**batch) 121 | loss.append(_loss.item()) 122 | 123 | step_progress.set_description(f'Step {step} - Loss: {_loss:.4f}') 124 | 125 | print (f"Validation finished: {total_instance} instances, {len(loss)} batches") 126 | print (f"Validation loss: {np.mean(loss)}") 127 | 128 | def main(): 129 | args = parse_args() 130 | 131 | config = vars(args) 132 | deepspeed_config = json.load(open(args.deepspeed_config)) 133 | config.update(deepspeed_config) 134 | for key, value in deepspeed_config.items(): 135 | setattr(args, key, value) 136 | 137 | args.save_dir = os.path.join("checkpoint", args.project_name) 138 | if args.local_rank in [0, -1] and args.train and not os.path.exists(args.save_dir): 139 | os.makedirs(args.save_dir) 140 | 141 | if args.train: 142 | wandb.init( 143 | project=args.project_name, 144 | group="ddp", 145 | config=config, 146 | dir=args.save_dir 147 | ) 148 | 149 | # from flash_attention import replace_llama_attn_with_flash_attn 150 | # replace_llama_attn_with_flash_attn() 151 | 152 | train(args) 153 | wandb.finish() 154 | else: 155 | validation(args) 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /mmbench_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import torch 4 | import base64 5 | import 
random 6 | import pandas as pd 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from torch.utils.data import Dataset 10 | 11 | from utils import parse_args 12 | from inference import MIMPipeline 13 | 14 | def decode_base64_to_image(base64_string): 15 | image_data = base64.b64decode(base64_string) 16 | image = Image.open(io.BytesIO(image_data)) 17 | return image 18 | 19 | class MMBenchDataset(Dataset): 20 | def __init__(self, 21 | data_file, 22 | sys_prompt='There are several options:'): 23 | self.df = pd.read_csv(data_file, sep='\t') 24 | self.sys_prompt = sys_prompt 25 | 26 | def __len__(self): 27 | return len(self.df) 28 | 29 | def __getitem__(self, idx): 30 | index = self.df.iloc[idx]['index'] 31 | image = self.df.iloc[idx]['image'] 32 | image = decode_base64_to_image(image) 33 | question = self.df.iloc[idx]['question'] 34 | answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[0].keys() else None 35 | catetory = self.df.iloc[idx]['category'] 36 | l2_catetory = self.df.iloc[idx]['l2-category'] 37 | 38 | option_candidate = ['A', 'B', 'C', 'D', 'E'] 39 | options = { 40 | cand: self.load_from_df(idx, cand) 41 | for cand in option_candidate 42 | if self.load_from_df(idx, cand) is not None 43 | } 44 | options_prompt = f'{self.sys_prompt}\n' 45 | for key, item in options.items(): 46 | options_prompt += f'{key}. {item}\n' 47 | 48 | hint = self.load_from_df(idx, 'hint') 49 | data = { 50 | 'img': image, 51 | 'question': question, 52 | 'answer': answer, 53 | 'options': options_prompt, 54 | 'category': catetory, 55 | 'l2-category': l2_catetory, 56 | 'options_dict': options, 57 | 'index': index, 58 | 'context': hint, 59 | } 60 | return data 61 | def load_from_df(self, idx, key): 62 | if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]): 63 | return self.df.iloc[idx][key] 64 | else: 65 | return None 66 | 67 | def evaluate(args, agent): 68 | 69 | dataset = MMBenchDataset("./data/mmbench_test_20230712.tsv") 70 | results = [] 71 | for data in tqdm(dataset): 72 | 73 | if data['context'] is not None: 74 | prompt = data['context'] + ' ' + data['question'] + ' ' + data['options'] 75 | else: 76 | prompt = data['question'] + ' ' + data['options'] 77 | 78 | inference_data = { 79 | "conversation": [ 80 | { 81 | "role": "user", 82 | "content": f" {prompt}", 83 | "image_list": [data["img"]], 84 | "caption_list": [], 85 | } 86 | ], 87 | } 88 | 89 | inference_data = agent.run(inference_data) 90 | prediction = inference_data["conversation"][-1]["content"].replace("", "").strip() 91 | # print(f"prediction: ", prediction) 92 | 93 | options = data["options"].split("\n") 94 | results.append({ 95 | "question": data["question"], 96 | "A": options[0] if len(options) > 0 else "", 97 | "B": options[1] if len(options) > 1 else "", 98 | "C": options[2] if len(options) > 2 else "", 99 | "D": options[3] if len(options) > 3 else "", 100 | "prediction": prediction, 101 | "category": data["category"], 102 | "l2-category": data["l2-category"], 103 | "index": data["index"], 104 | }) 105 | 106 | if not os.path.exists(args.inference_dir): 107 | os.makedirs(args.inference_dir) 108 | 109 | df = pd.DataFrame(results) 110 | df.to_excel(os.path.join(args.inference_dir, "submission.xlsx"), index=False) 111 | 112 | if __name__ == "__main__": 113 | 114 | args = parse_args() 115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 116 | agent = MIMPipeline(args, device) 117 | 118 | evaluate(args, agent) 119 | -------------------------------------------------------------------------------- 
/mme_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import torch 5 | import wandb 6 | import argparse 7 | import deepspeed 8 | import numpy as np 9 | from tqdm import tqdm 10 | from PIL import Image 11 | from accelerate import infer_auto_device_map, dispatch_model 12 | from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline 13 | from transformers import AutoTokenizer, CLIPImageProcessor, CLIPModel, CLIPProcessor, LlamaTokenizer, Blip2Processor 14 | 15 | from utils import parse_args 16 | from inference import MIMPipeline 17 | 18 | def load_mme(): 19 | 20 | total_corpus = {} 21 | total_num = 0 22 | for dir_name in os.listdir("./MME_Benchmark_release_version"): 23 | if dir_name in ["readme.txt", "eval_tool"]: 24 | continue 25 | 26 | image_dir = os.path.join("./MME_Benchmark_release_version", dir_name) 27 | if dir_name == "artwork": 28 | image_dir = os.path.join(image_dir, "./images/toy_dataset") 29 | elif os.path.exists(os.path.join(image_dir, "images")): 30 | image_dir = os.path.join(image_dir, "images") 31 | 32 | question_dir = os.path.join("./MME_Benchmark_release_version", dir_name) 33 | if os.path.exists(os.path.join(question_dir, "questions_answers_YN")): 34 | question_dir = os.path.join(question_dir, "questions_answers_YN") 35 | 36 | corpus = [] 37 | for file in os.listdir(question_dir): 38 | if ".txt" in file: 39 | with open(os.path.join(question_dir, file), "r") as f: 40 | questions_answers = [line.split("\t") for line in f.readlines()] 41 | for question, answer in questions_answers: 42 | 43 | corpus.append({ 44 | "image_id": file.split(".txt")[0], 45 | "image_dir": image_dir, 46 | "image": file.replace(".txt", ".jpg") if os.path.exists(os.path.join(image_dir, file.replace(".txt", ".jpg"))) else file.replace(".txt", ".png"), 47 | "question": question.strip(), 48 | "answer": answer.strip() 49 | }) 50 | 51 | print(f"dir_name: {dir_name}, corpus length: {len(corpus)}") 52 | total_corpus[dir_name] = corpus 53 | total_num += len(corpus) 54 | 55 | print(f"total corpus length: {len(total_corpus)}") 56 | print(f"total num: {total_num}") 57 | 58 | return total_corpus 59 | 60 | def evaluate(args, agent): 61 | 62 | corpus = load_mme() 63 | for id1, subject_name in enumerate(corpus): 64 | print(f"subject_name: {subject_name}") 65 | for id2, data in enumerate(tqdm(corpus[subject_name])): 66 | image = data["image"] 67 | text = data["question"] 68 | args.image_dir = data["image_dir"] 69 | 70 | inference_data = { 71 | "conversation": [ 72 | { 73 | "role": "user", 74 | "content": f" {text}", 75 | "image_list": [image], 76 | "caption_list": [], 77 | } 78 | ], 79 | "image_dir": data["image_dir"] 80 | } 81 | 82 | inference_data = agent.run(inference_data) 83 | data["prediction"] = inference_data["conversation"][-1]["content"] 84 | print(f"generation: ", data["prediction"]) 85 | 86 | if not os.path.exists(args.inference_dir): 87 | os.makedirs(args.inference_dir) 88 | 89 | for subject_name in corpus: 90 | 91 | with open(os.path.join("./MME_Benchmark_release_version/eval_tool/Your_Results", f"{subject_name}.txt"), "r") as f: 92 | groundtruth = f.readlines() 93 | 94 | predictions = corpus[subject_name] 95 | with open(os.path.join(args.inference_dir, f"{subject_name}.txt"), "w") as f: 96 | for line in groundtruth: 97 | line = line.strip() 98 | image_name, question, groundtruth_answer = line.split("\t") 99 | 100 | for data in predictions: 101 | if data["image"] == image_name and data["question"] 
== question: 102 | prediction_answer = data["prediction"] 103 | 104 | f.write(f"{image_name}\t{question}\t{groundtruth_answer}\t{prediction_answer}\n") 105 | 106 | if __name__ == "__main__": 107 | 108 | args = parse_args() 109 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 110 | agent = MIMPipeline(args, device) 111 | 112 | evaluate(args, agent) 113 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from typing import Dict, Optional, Sequence, List, Tuple, Union, Any 7 | from torch.nn import CrossEntropyLoss 8 | from transformers import CLIPVisionModel, Blip2QFormerConfig, Blip2Config, Blip2VisionModel, Blip2QFormerModel, LlamaForCausalLM 9 | from transformers.modeling_outputs import ModelOutput 10 | from transformers.models.llama.modeling_llama import _make_causal_mask 11 | from stable_diffusion import decode_with_sdxl 12 | from einops import rearrange 13 | 14 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 15 | """ 16 | Expands attention_mask from `[bsz, tgt_seq_len, src_seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 17 | """ 18 | bsz, tgt_len, src_len = mask.size() 19 | 20 | expanded_mask = mask[:, None, :, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 21 | 22 | inverted_mask = 1.0 - expanded_mask 23 | 24 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 25 | 26 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 27 | # create causal mask 28 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 29 | combined_attention_mask = None 30 | if input_shape[-1] > 1: 31 | combined_attention_mask = _make_causal_mask( 32 | input_shape, 33 | inputs_embeds.dtype, 34 | device=inputs_embeds.device, 35 | past_key_values_length=past_key_values_length, 36 | ) 37 | 38 | if attention_mask is not None: 39 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 40 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 41 | inputs_embeds.device 42 | ) 43 | combined_attention_mask = ( 44 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 45 | ) 46 | 47 | return combined_attention_mask 48 | 49 | class MIMOutputWithPast(ModelOutput): 50 | 51 | logits: torch.FloatTensor = None 52 | past_key_values: Tuple[Tuple[torch.FloatTensor]] = None 53 | 54 | class MIMModel(nn.Module): 55 | 56 | def __init__(self, args): 57 | super(MIMModel, self).__init__() 58 | 59 | self.args = args 60 | 61 | self.language_model = LlamaForCausalLM.from_pretrained(args.language_model, low_cpu_mem_usage=True) 62 | vision_model_class = Blip2VisionModel if "blip2" in args.vision_model else CLIPVisionModel 63 | self.vision_model = vision_model_class.from_pretrained(args.vision_model, low_cpu_mem_usage=True) 64 | 65 | # vision --> text qformer 66 | qformer_config = Blip2QFormerConfig( 67 | hidden_size=args.qformer_hidden_size, 68 | intermediate_size=args.qformer_intermediate_size, 69 | num_hidden_layers=args.num_qformer_hidden_layers, 70 | num_attention_heads=args.num_qformer_attention_heads, 71 | encoder_hidden_size=self.vision_model.config.hidden_size, 72 | ) 73 | self.query_tokens = nn.Parameter(torch.zeros(1, 
args.num_query_tokens, args.qformer_hidden_size)) 74 | self.qformer = Blip2QFormerModel(qformer_config) 75 | self.qformer_projection = nn.Linear(qformer_config.hidden_size, self.language_model.config.hidden_size) 76 | 77 | # support multiple caption generation 78 | if args.train or args.compute_loss: 79 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 80 | 81 | def get_token_embeds( 82 | self, 83 | input_ids, 84 | ): 85 | token_embeds = self.language_model.get_input_embeddings()(input_ids) 86 | 87 | return token_embeds 88 | 89 | def get_image_embeds( 90 | self, 91 | input_images, 92 | dtype=None, 93 | ): 94 | # image_embeds: num_input_images x 3 x 224 x 224 95 | num_input_images = input_images.shape[0] 96 | image_embeds = self.vision_model( 97 | pixel_values=input_images if dtype is None else input_images.to(dtype), 98 | ).last_hidden_state 99 | 100 | # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention 101 | query_tokens = self.query_tokens.expand(num_input_images, -1, -1) 102 | image_embeds = self.qformer( 103 | query_embeds=query_tokens, 104 | encoder_hidden_states=image_embeds, 105 | ).last_hidden_state 106 | # [num_input_images, num_query_tokens, language_model_hidden_size] 107 | image_embeds = self.qformer_projection(image_embeds) 108 | 109 | return image_embeds 110 | 111 | def get_inputs_embeds( 112 | self, 113 | input_ids, 114 | input_images=None, 115 | input_image_index=None, 116 | ): 117 | bsz = input_ids.shape[0] 118 | 119 | text_embeds = self.get_token_embeds(input_ids) 120 | 121 | if input_images is None: 122 | inputs_embeds = text_embeds 123 | 124 | else: 125 | # step 1: forward the images through the vision encoder 126 | num_input_image = input_images.shape[1] 127 | input_images = rearrange(input_images, "bsz n a b c -> (bsz n) a b c") 128 | 129 | image_embeds = self.get_image_embeds(input_images, text_embeds.dtype) 130 | 131 | # [bsz*num_input_image, ...] --> [bsz, num_input_image, ...] 
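            # every image contributes num_query_tokens embeddings; input_image_index stores the positions
            # of the corresponding placeholder tokens in the text sequence, and the scatter below overwrites
            # those slots of text_embeds with the projected Q-Former outputs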
132 | image_embeds = rearrange(image_embeds, "(bsz n) a b -> bsz n a b", n=num_input_image) 133 | 134 | # update image embeds 135 | inputs_embeds = torch.scatter( 136 | input=text_embeds, dim=1, 137 | index=input_image_index.unsqueeze(-1).expand(-1, -1, text_embeds.shape[-1]), 138 | src=image_embeds.view(bsz, -1, image_embeds.shape[-1])) 139 | 140 | return inputs_embeds 141 | 142 | def forward( 143 | self, 144 | input_ids=None, 145 | attention_mask=None, 146 | position_ids=None, 147 | input_images=None, 148 | input_image_index=None, 149 | inputs_embeds=None, 150 | labels=None, 151 | past_key_values=None, 152 | use_cache=False, 153 | ): 154 | 155 | if inputs_embeds is None: 156 | inputs_embeds = self.get_inputs_embeds( 157 | input_ids=input_ids, 158 | input_images=input_images, 159 | input_image_index=input_image_index, 160 | ) 161 | 162 | # step 4: forward the input through the language model 163 | outputs = self.language_model( 164 | inputs_embeds=inputs_embeds, 165 | attention_mask=attention_mask, 166 | position_ids=position_ids, 167 | labels=labels, 168 | past_key_values=past_key_values, 169 | output_hidden_states=True, 170 | use_cache=use_cache, 171 | ) 172 | 173 | if not self.args.train and not self.args.compute_loss: 174 | return MIMOutputWithPast( 175 | logits=outputs.logits, 176 | past_key_values=outputs.past_key_values, 177 | ) 178 | 179 | return outputs.loss 180 | 181 | 182 | def prepare_inputs_for_generation( 183 | self, 184 | input_ids=None, 185 | input_images=None, 186 | input_image_index=None, 187 | past_key_values=None, 188 | ): 189 | if past_key_values: 190 | if input_images is not None: 191 | inputs_embeds = self.get_image_embeds(input_images=input_images) 192 | else: 193 | inputs_embeds = self.get_token_embeds(input_ids=input_ids[:, -1:]) 194 | else: 195 | inputs_embeds = self.get_inputs_embeds( 196 | input_ids=input_ids, 197 | input_images=input_images, 198 | input_image_index=input_image_index, 199 | ) 200 | 201 | return { 202 | "inputs_embeds": inputs_embeds, 203 | "past_key_values": past_key_values, 204 | } 205 | 206 | def _update_model_kwargs_for_generation( 207 | self, 208 | outputs: ModelOutput, 209 | model_kwargs: Dict[str, Any], 210 | ) -> Dict[str, Any]: 211 | # update past_key_values 212 | model_kwargs["past_key_values"] = outputs.past_key_values 213 | return model_kwargs 214 | 215 | @torch.no_grad() 216 | def cache_generation( 217 | self, 218 | input_ids, 219 | tokenizer, 220 | image_processor, 221 | input_images=None, 222 | input_image_index=None, 223 | caption_start_id=None, 224 | caption_end_id=None, 225 | sd_base=None, 226 | sd_refiner=None, 227 | generator=None, 228 | generate_image=False, 229 | max_output_length=256, 230 | top_p=None, 231 | temperature=1.0, 232 | **model_kwargs, 233 | ): 234 | 235 | image_list = [] 236 | caption_list = [] 237 | for _ in range(max_output_length): 238 | 239 | model_inputs = self.prepare_inputs_for_generation(input_ids, input_images, input_image_index, **model_kwargs) 240 | outputs = self(**model_inputs, use_cache=True) 241 | next_token_scores = outputs.logits[:, -1, :] 242 | 243 | if top_p is not None: 244 | # top-p sampling 245 | next_token_scores = next_token_scores / temperature 246 | next_token_scores = topp_logits_filter(next_token_scores, top_p) 247 | probs = torch.nn.functional.softmax(next_token_scores, dim=-1) 248 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 249 | else: 250 | next_tokens = next_token_scores.argmax(-1) 251 | 252 | # print("Next token:", tokenizer.decode(next_tokens.item())) 
253 | # print("Next token score:", next_token_scores[0, next_tokens.item()].item()) 254 | 255 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 256 | 257 | model_kwargs = self._update_model_kwargs_for_generation( 258 | outputs, model_kwargs 259 | ) 260 | 261 | if next_tokens.item() == tokenizer.eos_token_id: 262 | break 263 | 264 | if next_tokens.item() == caption_start_id and generate_image: 265 | cached_past_key_values = outputs.past_key_values 266 | caption_start_idx = input_ids.shape[-1] 267 | 268 | if next_tokens.item() == caption_end_id and generate_image: 269 | caption = tokenizer.decode(input_ids[0][caption_start_idx:-1]).strip() 270 | image = decode_with_sdxl(caption, sd_base, sd_refiner, generator) 271 | image_list.append(image) 272 | caption_list.append(caption) 273 | 274 | # [1, 3, 224, 224] 275 | input_images = image_processor(images=image, return_tensors='pt')['pixel_values'] 276 | input_images = input_images.to(input_ids.device).to(torch.half if self.args.fp16 else torch.float) 277 | 278 | # update past_key_values 279 | model_kwargs["past_key_values"] = cached_past_key_values 280 | else: 281 | input_images = None 282 | 283 | return { 284 | "sequences": input_ids, 285 | "image_list": image_list, 286 | "caption_list": caption_list 287 | } 288 | 289 | def topp_logits_filter(scores, p): 290 | sorted_logits, sorted_indices = torch.sort(scores, descending=False) 291 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 292 | 293 | # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) 294 | sorted_indices_to_remove = cumulative_probs <= (1 - p) 295 | # Keep at least 1 token 296 | sorted_indices_to_remove[..., -1 :] = 0 297 | 298 | # scatter sorted tensors to original indexing 299 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 300 | scores = scores.masked_fill(indices_to_remove, -float("Inf")) 301 | return scores 302 | -------------------------------------------------------------------------------- /process_instruction_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import copy 5 | from datasets import load_dataset 6 | from io import BytesIO 7 | from base64 import b64decode 8 | from PIL import Image 9 | 10 | 11 | def process_m3it(): 12 | proxies = { 13 | "http": "http://127.0.0.1:8118", 14 | "https": "http://127.0.0.1:8118", 15 | } 16 | ds_name = "coco" # change the dataset name here 17 | dataset = load_dataset("MMInstruction/M3IT", ds_name, cache_dir="/apdcephfs/share_733425/jcykcai/sihengli/Dataset") 18 | 19 | # for train_instance in dataset['train']: 20 | # instruction = train_instance["instruction"] # str 21 | # inputs = train_instance["inputs"] # str 22 | # outputs = train_instance["outputs"] # str 23 | # image_base64_str_list = train_instance["image_base64_str"] # str (base64) 24 | # image_0 = Image.open(BytesIO(b64decode(image_base64_str_list[0]))) 25 | 26 | def process_llava(): 27 | 28 | corpus = json.load(open("./data/llava_instruct_150k.json")) 29 | new_corpus = [] 30 | 31 | for id1, data in enumerate(corpus): 32 | new_data = {"conversation": [], "image_dir": "./data/train2017"} 33 | for id2, turn in enumerate(data["conversations"]): 34 | new_turn = {} 35 | if turn["from"] == "human": 36 | new_turn["role"] = "user" 37 | else: 38 | new_turn["role"] = "assistant" 39 | new_turn["content"] = turn["value"] 40 | if "" in turn["value"]: 41 | new_turn["image_list"] = [data["image"]] 
42 | else: 43 | new_turn["image_list"] = [] 44 | 45 | new_turn["caption_list"] = [] 46 | 47 | new_data["conversation"].append(new_turn) 48 | new_corpus.append(new_data) 49 | 50 | print("Number of data: {}".format(len(new_corpus))) 51 | json.dump(new_corpus, open("./data/llava_processed.json", "w"), indent=4) 52 | 53 | 54 | def process_alpaca_gpt4(): 55 | 56 | PROMPT_DICT = { 57 | "prompt_input": ( 58 | "Below is an instruction that describes a task, paired with an input that provides further context. " 59 | "Write a response that appropriately completes the request.\n\n" 60 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 61 | ), 62 | "prompt_no_input": ( 63 | "Below is an instruction that describes a task. " 64 | "Write a response that appropriately completes the request.\n\n" 65 | "### Instruction:\n{instruction}\n\n### Response:" 66 | ), 67 | } 68 | corpus = json.load(open("./data/alpaca_gpt4_data.json")) 69 | new_corpus = [] 70 | 71 | for idx, data in enumerate(corpus): 72 | 73 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 74 | 75 | user_content = prompt_input.format_map(data) if data.get("input", "") != "" else prompt_no_input.format_map(data) 76 | user_turn = {"role": "user", "content": user_content, "image_list": [], "caption_list": []} 77 | 78 | assistant_turn = {"role": "assistant", "content": data["output"], "image_list": [], "caption_list": []} 79 | 80 | new_data = {"conversation": [user_turn, assistant_turn]} 81 | new_corpus.append(new_data) 82 | 83 | print("Number of data: {}".format(len(new_corpus))) 84 | json.dump(new_corpus, open("./data/alpaca_gpt4_processed.json", "w"), indent=4) 85 | 86 | 87 | def process_minigpt4(): 88 | 89 | corpus = json.load(open("./data/cc_sbu_align/filter_cap.json"))["annotations"] 90 | new_corpus = [] 91 | for idx, data in enumerate(corpus): 92 | 93 | user_content = "\nDescribe this image in detail. Give as many details as possible. Say everything you see." 94 | user_turn = {"role": "user", "content": user_content, "image_list": [data["image_id"] + ".jpg"], "caption_list": []} 95 | 96 | assistant_turn = {"role": "assistant", "content": data["caption"], "image_list": [], "caption_list": []} 97 | 98 | new_data = {"conversation": [user_turn, assistant_turn], "image_dir": "./data/minigpt4_images"} 99 | new_corpus.append(new_data) 100 | 101 | print("Number of data: {}".format(len(new_corpus))) 102 | json.dump(new_corpus, open("./data/minigpt4_processed.json", "w"), indent=4) 103 | 104 | 105 | def process_platypus(): 106 | 107 | PROMPT_DICT = { 108 | "prompt_input": ( 109 | "Below is an instruction that describes a task, paired with an input that provides further context. " 110 | "Write a response that appropriately completes the request.\n\n" 111 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 112 | ), 113 | "prompt_no_input": ( 114 | "Below is an instruction that describes a task. 
" 115 | "Write a response that appropriately completes the request.\n\n" 116 | "### Instruction:\n{instruction}\n\n### Response:" 117 | ), 118 | } 119 | 120 | dataset = load_dataset("garage-bAInd/Open-Platypus")["train"] 121 | new_corpus = [] 122 | 123 | for idx, data in enumerate(dataset): 124 | 125 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 126 | 127 | user_content = prompt_input.format_map(data) if data.get("input", "") != "" else prompt_no_input.format_map(data) 128 | user_turn = {"role": "user", "content": user_content.replace("", ""), "image_list": [], "caption_list": []} 129 | 130 | assistant_turn = {"role": "assistant", "content": data["output"].replace("", ""), "image_list": [], "caption_list": []} 131 | 132 | new_data = {"conversation": [user_turn, assistant_turn]} 133 | new_corpus.append(new_data) 134 | 135 | print("Number of data: {}".format(len(new_corpus))) 136 | json.dump(new_corpus, open("./data/platypus_processed.json", "w"), indent=4) 137 | 138 | 139 | def process_multi_instruct(): 140 | 141 | corpus = [] 142 | per_dataset_num = 5000 143 | for path in [f"./data/train_{per_dataset_num}.jsonl", f"./data/test_{per_dataset_num}.jsonl"]: 144 | with open(path, 'r') as f: 145 | for line in f: 146 | json_obj = json.loads(line) 147 | corpus.append(json_obj) 148 | 149 | new_corpus = [] 150 | for idx, data in enumerate(corpus): 151 | 152 | user_turn = {"role": "user", "content": "\n"+data["prompt"], "image_list": [data["image_path"]], "caption_list": []} 153 | 154 | assistant_turn = {"role": "assistant", "content": data["target"], "image_list": [], "caption_list": []} 155 | 156 | new_data = {"conversation": [user_turn, assistant_turn], "image_dir": "./data"} 157 | new_corpus.append(new_data) 158 | 159 | print("Number of data: {}".format(len(new_corpus))) 160 | json.dump(new_corpus, open(f"./data/multi_instruct_processed_{per_dataset_num}_{len(new_corpus)}.json", "w"), indent=4) 161 | 162 | def process_shikra(): 163 | 164 | corpus = [] 165 | for path in ["./data/GPT4GEN_BoxCoT_test.jsonl", "./data/GPT4GEN_BoxCoT_train.jsonl", "./data/GPT4GEN_RD_BoxCoT_train.jsonl"]: 166 | with open(path, 'r') as f: 167 | for line in f: 168 | json_obj = json.loads(line) 169 | corpus.append(json_obj) 170 | 171 | print(len(corpus)) 172 | 173 | new_corpus = [] 174 | for idx, data in enumerate(corpus): 175 | 176 | user_turn = {"role": "user", "content": "\n"+data["question"].replace("", "").replace("", "").replace(" ", "").replace(" ,", ",").replace(" .", "."), "image_list": [data["img_path"]], "caption_list": []} 177 | 178 | assistant_turn = {"role": "assistant", "content": data["cot_with_ans"].replace("", "").replace("", "").replace(" ", "").replace(" ,", ",").replace(" .", "."), "image_list": [], "caption_list": []} 179 | 180 | new_data = {"conversation": [user_turn, assistant_turn], "image_dir": "./data/flickr30k-images"} 181 | new_corpus.append(new_data) 182 | 183 | print("Number of data: {}".format(len(new_corpus))) 184 | json.dump(new_corpus, open("./data/shikra_processed.json", "w"), indent=4) 185 | 186 | def process_mim(): 187 | 188 | corpus = json.load(open("./data/train.s2.v4.clean.reform.train.json", "r")) 189 | for idx, data in enumerate(corpus): 190 | data["image_dir"] = "./data/mim_images" 191 | data["conversation"] = data["conversation"][:-1] 192 | 193 | for turn in data["conversation"]: 194 | for idx, caption in enumerate(turn["caption_list"]): 195 | turn["caption_list"][idx] = re.sub(r'|<\/img\d+>', '', caption).strip() 196 | 197 | 
print("Number of data: {}".format(len(corpus))) 198 | json.dump(corpus, open("./data/mim_processed.json", "w"), indent=4) 199 | 200 | 201 | def blender(): 202 | 203 | num_llava = 0 204 | multi_instruct_data = "./data/multi_instruct_processed_2000_105745.json" 205 | 206 | corpus1 = [] 207 | corpus2 = [] 208 | corpus3 = [] 209 | corpus4 = [] 210 | corpus5 = [] 211 | corpus6 = [] 212 | 213 | corpus1 = json.load(open("./data/platypus_processed.json", "r")) 214 | corpus2 = json.load(open("./data/minigpt4_processed.json", "r")) 215 | corpus3 = json.load(open("./data/shikra_processed.json", "r")) 216 | corpus4 = json.load(open("./data/llava_processed.json", "r")) 217 | corpus5 = json.load(open(multi_instruct_data, "r")) 218 | corpus6 = json.load(open("./data/mim_processed.json", "r")) 219 | print("platypus: ", len(corpus1)) 220 | print("minigpt4: ", len(corpus2)) 221 | print("shikra: ", len(corpus3)) 222 | print("llava: ", len(corpus4)) 223 | print("multi_instruct: ", len(corpus5)) 224 | print("mim: ", len(corpus6)) 225 | 226 | final_data = f"./data/platypus_{len(corpus1)}_minigpt4_{len(corpus2)}_shikra_{len(corpus3)}_llava_{len(corpus4)}_multi_instruct_{len(corpus5)}_mim_{len(corpus6)}.json" 227 | 228 | corpus = corpus1 + corpus2 + corpus3 + corpus4 + corpus5 + corpus6 229 | print("final: ", len(corpus)) 230 | json.dump(corpus, open(final_data, "w"), indent=4) 231 | 232 | 233 | if __name__ == "__main__": 234 | 235 | # process_llava() 236 | # process_alpaca_gpt4() 237 | # process_minigpt4() 238 | # process_platypus() 239 | # process_multi_instruct() 240 | # process_shikra() 241 | # process_mim() 242 | 243 | blender() -------------------------------------------------------------------------------- /process_mim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import requests 5 | from PIL import Image, ImageOps 6 | from io import BytesIO 7 | import fire 8 | from common_utils import logger, FileUtils, MPUtils 9 | from common_utils import PrepUtils 10 | import constaints as C 11 | 12 | 13 | def clean_mim_data(data_path, mim_edit_dist=0.1, min_turn_num=3): 14 | cleaned = [] 15 | data = FileUtils.load_file(data_path) 16 | logger.info("{} exemples loaded".format(len(data))) 17 | visited = set() 18 | stats = { 19 | "repeat_image": 0, "unseen_image": 0, "no_image": 0, "repeat_response": 0, 20 | "invalid_turns": 0, "wrong_role": 0, "changed_caption": 0, "too_short": 0 21 | } 22 | for ex in data: 23 | if PrepUtils.has_repeated_images(ex): 24 | stats['repeat_image'] += 1 25 | continue 26 | if PrepUtils.has_unseen_image(ex): 27 | stats['unseen_image'] += 1 28 | continue 29 | if not PrepUtils.has_image(ex): 30 | stats['no_image'] += 1 31 | continue 32 | response, caption_list = ex['response'], ex['image'] 33 | if response in visited: 34 | stats['repeat_response'] += 1 35 | continue 36 | else: 37 | visited.add(response) 38 | try: 39 | turns = PrepUtils.split_turn(response) 40 | except ValueError: 41 | stats['wrong_role'] += 1 42 | continue 43 | if not PrepUtils.is_valid_turn(turns): 44 | stats['invalid_turns'] += 1 45 | continue 46 | if len(turns) < min_turn_num: 47 | stats['too_short'] += 1 48 | continue 49 | turns = [PrepUtils.remove_non_paired_img_tag(t) for t in turns] 50 | new_turns, is_valid = [], True 51 | for sent in turns: 52 | caption_matched, has_caption = True, False 53 | for i, _ in enumerate(caption_list): 54 | gen_cap = PrepUtils.extract_text(sent, i) 55 | if gen_cap is None: 56 | continue 57 | has_caption = True 
58 | clean_cap = PrepUtils.clean_tag(caption_list[i]) 59 | ed_score = PrepUtils.edit_distance(gen_cap, clean_cap) / len(clean_cap) 60 | if ed_score > mim_edit_dist: 61 | caption_matched = False 62 | if has_caption and not caption_matched: 63 | is_valid = False 64 | stats['changed_caption'] += 1 65 | break 66 | new_turns.append(sent) 67 | if is_valid: 68 | ex['response'] = "\n\n".join(new_turns) 69 | cleaned.append(ex) 70 | logger.info("{} exemples left after cleaning".format(len(cleaned))) 71 | FileUtils.save_file(cleaned, FileUtils.handle_file_extension(data_path, "clean", "add") , 'json') 72 | stats = ["{}\t{:.2f}%\t{}/{}".format(k, v / len(data) * 100, v, len(data)) for k, v in stats.items()] 73 | logger.info("Statistical results: \n{}".format("\n".join(stats))) 74 | 75 | 76 | def prepare_model_input(data_path, image_dir="./data/mim_images", force_downloading=False, nproc=16): 77 | corpus = FileUtils.load_file(data_path, 'json') 78 | logger.info("Loaded {} instances".format(len(corpus))) 79 | img_data = PrepUtils.gather_image_data(corpus) 80 | faied_image_urls = [] 81 | if force_downloading: 82 | FileUtils.check_dirs(image_dir) 83 | img_data_shards = MPUtils.prepare_shards(img_data, nproc) 84 | args_list = [(img_data_shards[i], image_dir, i) for i in range(nproc)] 85 | MPUtils.mp_func(download_images, args_list) 86 | logger.info("Finished downloading") 87 | for proc_id in range(nproc): 88 | fpath = image_dir + "/failed.proc{}.txt".format(proc_id) 89 | if FileUtils.exists(fpath): 90 | faied_image_urls += FileUtils.load_file(fpath) 91 | faied_image_urls = set(faied_image_urls) 92 | new_corpus, stats = [], {"incomplete_image": 0} 93 | for idx, data in enumerate(corpus): 94 | image_path_list = ["{}.png".format(mi) for mi in data["image_idx"]] 95 | has_failed_images = False 96 | if faied_image_urls: 97 | for u in data["url"]: 98 | if u in faied_image_urls: 99 | has_failed_images = True 100 | break 101 | if has_failed_images: 102 | stats['incomplete_image'] += 1 103 | continue 104 | else: 105 | if not PrepUtils.check_image_list(image_dir, image_path_list): 106 | stats['incomplete_image'] += 1 107 | continue 108 | conversation = [] 109 | for turn in data["response"].split("\n\n"): 110 | image_ids = PrepUtils.extract_idx(turn) 111 | image_list = [image_path_list[j] for j in image_ids] 112 | url_list = [data["url"][j] for j in image_ids] 113 | caption_list = [PrepUtils.clean_tag(data["image"][j]) for j in image_ids] 114 | turn = PrepUtils.sub_image_tag(turn) 115 | if turn.startswith(C.ASSISTANT): 116 | conversation.append({"role": "assistant", "content": turn[len(C.ASSISTANT):].strip(), "image_list":image_list, "caption_list": caption_list, "url_list": url_list}) 117 | elif turn.startswith(C.HUMAN): 118 | conversation.append({"role": "user", "content": turn[len(C.HUMAN):].strip(), "image_list":image_list, "caption_list": caption_list, "url_list": url_list}) 119 | new_corpus.append({"conversation": conversation, 'image_dir': image_dir}) 120 | logger.info("{} instances left after cleaning".format(len(new_corpus))) 121 | FileUtils.save_file(new_corpus, FileUtils.handle_file_extension(data_path, "reform", "add"), 'json') 122 | FileUtils.save_file(img_data, FileUtils.handle_file_extension(data_path, "img-cap", "add"), 'json') 123 | stats = ["{}\t{:.2f}%\t{}/{}".format(k, v / len(corpus) * 100, v, len(corpus)) for k, v in stats.items()] 124 | logger.info("Statistical results: \n{}".format("\n".join(stats))) 125 | 126 | 127 | def get_image_caption(corpus_path, save_path): 128 | logger.info("Start 
processing...") 129 | corpus = json.load(open(corpus_path, "r")) 130 | logger.info(f"load {len(corpus)} instances") 131 | new_corpus = [] 132 | for data in corpus: 133 | for image_idx, caption in zip(data["image_idx"], data["image"]): 134 | caption = re.findall(r'(.*?)<\/img\d+>', caption)[0] 135 | new_corpus.append({ 136 | "image": f"{image_idx}.png", 137 | "caption": caption.strip() 138 | }) 139 | json.dump(new_corpus, open(save_path, "w"), indent=4) 140 | 141 | 142 | def download_images(image_data, image_dir, proc_id, max_try_num=5, image_size=512): 143 | failed_data = [] 144 | if isinstance(image_data, str): 145 | image_data = FileUtils.load_file(image_data) 146 | FileUtils.check_dirs(image_dir) 147 | logger.info("Proc-{} | Downloading images for shard with {} examples".format(proc_id, len(image_data))) 148 | for idx in range(len(image_data)): 149 | image_basename, url = image_data[idx]['image'], image_data[idx]['url'] 150 | image_path = "{}/{}".format(image_dir, image_basename) 151 | if not PrepUtils.check_image_file(image_path): 152 | logger.info("Proc-{} | Downloading from {} for {}".format(proc_id, url, image_basename)) 153 | try_num = 0 154 | while try_num < max_try_num: 155 | try: 156 | response = requests.get(url, headers=C.HEADERS) 157 | image = Image.open(BytesIO(response.content)).convert("RGB") 158 | image = PrepUtils.resize_and_pad(image, (image_size, image_size)) 159 | image.save(image_path) 160 | break 161 | except: 162 | try_num += 1 163 | if try_num >= max_try_num: 164 | failed_data.append(url) 165 | logger.info("Failed to download {}".format(url)) 166 | FileUtils.save_file(failed_data, image_dir + "/failed.proc{}.txt".format(proc_id)) 167 | 168 | 169 | def split_data(data_path, save_prefix, valid_num=100, test_num=100): 170 | import random 171 | random.seed(10086) 172 | data = FileUtils.load_file(data_path) 173 | ids = list(range(len(data))) 174 | random.shuffle(ids) 175 | data = [data[i] for i in ids] 176 | valid = data[:valid_num] 177 | test = data[valid_num:valid_num+test_num] 178 | train = data[valid_num+test_num:] 179 | FileUtils.save_file(train, save_prefix + ".train.json") 180 | FileUtils.save_file(valid, save_prefix + ".valid.json") 181 | FileUtils.save_file(test, save_prefix + ".test.json") 182 | 183 | 184 | def data_statistics(data_path): 185 | from sacremoses import MosesTokenizer 186 | from collections import Counter 187 | from tqdm import tqdm 188 | tokenizer = MosesTokenizer(lang='en') 189 | mean = lambda x: sum(x) / len(x) 190 | 191 | def compute_div_score(ngram_counters): 192 | div_score_turns = [] 193 | for k, cs in ngram_counters.items(): 194 | div_score = 0 195 | for n in range(2, 5): 196 | total_num = sum(cs[n].values()) 197 | unique_num = len(cs[n]) 198 | div_score += unique_num / total_num 199 | div_score_turns.append(div_score) 200 | return mean(div_score_turns) 201 | 202 | def traditional_statistics(data): 203 | total_ex_num = len(data) 204 | conversation_lens, instruct_lens, response_lens = [], [], [] 205 | conversation_image_nums, instruct_image_nums, response_image_nums = [], [], [] 206 | turn_nums = [] 207 | for ex in data: 208 | conversation = ex['conversation'] 209 | turn_nums.append(len(conversation) / 2) 210 | ci, ct, ri, rt, ii, it = 0, 0, 0, 0, 0, 0 211 | for c in conversation: 212 | if c['role'] == "user": 213 | ci += len(c['image_list']) 214 | ii += len(c['image_list']) 215 | content = tokenizer.tokenize(c['content'], escape=False) 216 | ct += len(content) 217 | it += len(content) 218 | elif c['role'] == "assistant": 219 | ci += 
len(c['image_list']) 220 | ri += len(c['image_list']) 221 | content = tokenizer.tokenize(c['content'], escape=False) 222 | ct += len(content) 223 | rt += len(content) 224 | else: 225 | raise ValueError(c['role']) 226 | conversation_lens.append(ct) 227 | conversation_image_nums.append(ci) 228 | instruct_lens.append(it) 229 | instruct_image_nums.append(ii) 230 | response_image_nums.append(ri) 231 | response_lens.append(rt) 232 | logger.info("total_ex_num: {}".format(total_ex_num)) 233 | logger.info("turn_nums: {}".format(mean(turn_nums))) 234 | logger.info("conversation_lens: {}".format(mean(conversation_lens))) 235 | logger.info("instruct_lens: {}".format(mean(instruct_lens))) 236 | logger.info("response_lens: {}".format(mean(response_lens))) 237 | logger.info("conversation_image_nums: {}".format(mean(conversation_image_nums))) 238 | logger.info("instruct_image_nums: {}".format(mean(instruct_image_nums))) 239 | logger.info("response_image_nums: {}".format(mean(response_image_nums))) 240 | 241 | def image_diversity(data): 242 | user_image_nums, assist_image_nums = dict(), dict() 243 | for ex in data: 244 | conversation = ex['conversation'] 245 | for cidx, c in enumerate(conversation): 246 | if c['role'] == "user": 247 | cidx = cidx // 2 248 | if cidx in user_image_nums: 249 | user_image_nums[cidx].append(len(c['image_list'])) 250 | else: 251 | user_image_nums[cidx] = [len(c['image_list'])] 252 | elif c['role'] == "assistant": 253 | cidx = cidx // 2 254 | if cidx in assist_image_nums: 255 | assist_image_nums[cidx].append(len(c['image_list'])) 256 | else: 257 | assist_image_nums[cidx] = [len(c['image_list'])] 258 | else: 259 | raise ValueError(c['role']) 260 | for i in range(len(user_image_nums)): 261 | logger.info("turn: {}\tuser_image_num: {}".format(i, mean(user_image_nums[i]))) 262 | for i in range(len(assist_image_nums)): 263 | logger.info("turn: {}\tassistant_image_num: {}".format(i, mean(assist_image_nums[i]))) 264 | 265 | def text_diversity(data): 266 | ngram_counters = dict() 267 | user_ngram_counters = dict() 268 | assitant_ngram_counters = dict() 269 | for ex in tqdm(data): 270 | conversation = ex['conversation'] 271 | for cidx, c in enumerate(conversation): 272 | cidx = cidx // 2 273 | if cidx not in ngram_counters: 274 | ngram_counters[cidx] = {2: Counter(), 3: Counter(), 4: Counter()} 275 | if cidx not in user_ngram_counters: 276 | user_ngram_counters[cidx] = {2: Counter(), 3: Counter(), 4: Counter()} 277 | if cidx not in assitant_ngram_counters: 278 | assitant_ngram_counters[cidx] = {2: Counter(), 3: Counter(), 4: Counter()} 279 | content = tokenizer.tokenize(c['content'], escape=False) 280 | if c['role'] == "user": 281 | for n in range(2, 5): 282 | nragms = PrepUtils.extract_ngrams(content, n=n) 283 | user_ngram_counters[cidx][n].update(nragms) 284 | ngram_counters[cidx][n].update(nragms) 285 | elif c['role'] == "assistant": 286 | for n in range(2, 5): 287 | nragms = PrepUtils.extract_ngrams(content, n=n) 288 | assitant_ngram_counters[cidx][n].update(nragms) 289 | ngram_counters[cidx][n].update(nragms) 290 | logger.info("Turns: {}".format(list(ngram_counters.keys()))) 291 | logger.info("Overall Div score: {}".format(compute_div_score(ngram_counters))) 292 | logger.info("User Div score: {}".format(compute_div_score(user_ngram_counters))) 293 | logger.info("Assistant Div score: {}".format(compute_div_score(assitant_ngram_counters))) 294 | 295 | data = FileUtils.load_file(data_path) 296 | logger.info("----------------------- traditional_statistics -----------------------") 297 | # 
traditional_statistics(data) 298 | logger.info("----------------------- text_diversity -----------------------") 299 | text_diversity(data) 300 | logger.info("----------------------- image_diversity -----------------------") 301 | # image_diversity(data) 302 | 303 | 304 | def analyze_human_annotation(data_path="./annotation.csv", has_header=True): 305 | from collections import Counter 306 | ann = FileUtils.load_file(data_path) 307 | if has_header: 308 | ann = ann[1:] 309 | counters = {"quality": Counter(), "character": Counter(), "error": Counter()} 310 | n = 0 311 | for row in ann: 312 | quality = row[2] 313 | if quality: 314 | n += 1 315 | counters['quality'].update([quality]) 316 | if quality == "Poor": 317 | counters['error'].update([it.strip() for it in row[4].split(',')]) 318 | else: 319 | counters['character'].update([it.strip() for it in row[3].split(',')]) 320 | for k, v in counters.items(): 321 | for label, freq in v.most_common(): 322 | logger.info("{} | {}: {}".format(k, label, freq / n)) 323 | 324 | 325 | if __name__ == "__main__": 326 | fire.Fire({ 327 | "prepare_model_input": prepare_model_input, 328 | "clean_mim_data": clean_mim_data, 329 | "split_data": split_data, 330 | "download_images": download_images, 331 | "data_statistics": data_statistics, 332 | "analyze_human_annotation": analyze_human_annotation 333 | }) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.22.0 2 | beautifulsoup4==4.12.2 3 | datasets==2.14.5 4 | deepspeed==0.10.0 5 | diffusers==0.20.0 6 | einops==0.6.1 7 | flask==2.3.2 8 | sentencepiece==0.1.99 9 | transformers==4.31.0 10 | wandb==0.15.8 11 | webdataset==0.2.48 -------------------------------------------------------------------------------- /scripts/run_demo.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | URL_PREFIX="" 4 | IMG_DIR="./${URL_PREFIX}/images" 5 | mkdir -p $IMG_DIR 6 | 7 | CHECKPOINT=./checkpoint/second_stage_model.pt 8 | VISION_MODEL=./checkpoint/blip2_vision_model 9 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 10 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 11 | SD_BASE=../../CKPT/stabilityai/stable-diffusion-xl-base-1.0 12 | SD_REFINER=../../CKPT/stabilityai/stable-diffusion-xl-refiner-1.0 13 | 14 | python app.py \ 15 | --fp16 \ 16 | --generate_image \ 17 | --max_output_length 256 \ 18 | --checkpoint $CHECKPOINT \ 19 | --inference_dir $IMG_DIR \ 20 | --vision_model $VISION_MODEL \ 21 | --language_model $LANGUAGE_MODEL \ 22 | --processor $PROCESSOR \ 23 | --sd_base $SD_BASE \ 24 | --sd_refiner $SD_REFINER \ 25 | --num_query_tokens 32 \ 26 | --num_qformer_hidden_layers 12 \ 27 | --num_qformer_attention_heads 12 \ 28 | --qformer_hidden_size 768 \ 29 | --qformer_intermediate_size 3072 \ 30 | --demo_example_path data/examples.json \ 31 | --url_prefix "${URL_PREFIX}" \ 32 | --safe_image_num 4 \ 33 | --safe_word_num 650 \ 34 | 35 | 36 | -------------------------------------------------------------------------------- /scripts/run_first_stage.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG=INFO 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_IB_SL=3 5 | export NCCL_NET_GDR_READ=1 6 | 7 | SAVE_MODULES="query_tokens qformer qformer_projection" 8 | 
TRAIN_DATA_PATH="../MIM/data/minigpt4/cc_sbu/cc_sbu_dataset/{00000..01254}.tar" 9 | CHECKPOINT=./checkpoint/blip2_vision.pt 10 | VISION_MODEL=./checkpoint/blip2_vision_model 11 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 12 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 13 | 14 | deepspeed \ 15 | --include localhost:0,1,2,3,4,5,6,7 \ 16 | main.py \ 17 | --train \ 18 | --stage first \ 19 | --project_name first_stage_1e-4_256_12_12_768_3072_blip2_cosine8 \ 20 | --deepspeed_config ./ds_config_first_stage.json \ 21 | --with_epoch 10000000 \ 22 | --num_epochs 8 \ 23 | --warmup_steps 2000 \ 24 | --with_num_works 4 \ 25 | --save_modules $SAVE_MODULES \ 26 | --train_data_path $TRAIN_DATA_PATH \ 27 | --checkpoint $CHECKPOINT \ 28 | --vision_model $VISION_MODEL \ 29 | --language_model $LANGUAGE_MODEL \ 30 | --processor $PROCESSOR \ 31 | --max_input_length 256 \ 32 | --num_query_tokens 32 \ 33 | --num_qformer_hidden_layers 12 \ 34 | --num_qformer_attention_heads 12 \ 35 | --qformer_hidden_size 768 \ 36 | --qformer_intermediate_size 3072 -------------------------------------------------------------------------------- /scripts/run_first_stage_val.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG=INFO 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_IB_SL=3 5 | export NCCL_NET_GDR_READ=1 6 | export CUDA_VISIBLE_DEVICES=0 7 | 8 | VAL_DATA_PATH=../MIM/data/minigpt4/cc_sbu/cc_sbu_dataset/01255.tar 9 | CHECKPOINT=./checkpoint/first_stage_1e-4_256_12_12_768_3072_blip2_cosine2/checkpoint_epoch2_step78128/pytorch_model.pt 10 | VISION_MODEL=./checkpoint/blip2_vision_model 11 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 12 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 13 | 14 | python main.py \ 15 | --use_causal_mask \ 16 | --fp16 \ 17 | --compute_loss \ 18 | --stage first \ 19 | --deepspeed_config ./ds_config_first_stage.json \ 20 | --with_epoch 5000 \ 21 | --with_num_works 4 \ 22 | --val_data_path $VAL_DATA_PATH \ 23 | --checkpoint $CHECKPOINT \ 24 | --vision_model $VISION_MODEL \ 25 | --language_model $LANGUAGE_MODEL \ 26 | --processor $PROCESSOR \ 27 | --max_input_length 768 \ 28 | --num_query_tokens 32 \ 29 | --num_qformer_hidden_layers 12 \ 30 | --num_qformer_attention_heads 12 \ 31 | --qformer_hidden_size 768 \ 32 | --qformer_intermediate_size 3072 -------------------------------------------------------------------------------- /scripts/run_inference.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | 4 | VAL_DATA_PATH=./good_cases/inference.json \ 5 | INFERENCE_DIR=./inference_results 6 | CHECKPOINT=./checkpoint/second_stage_model.pt 7 | VISION_MODEL=./checkpoint/blip2_vision_model 8 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 9 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 10 | SD_BASE=../../CKPT/stabilityai/stable-diffusion-xl-base-1.0 11 | SD_REFINER=../../CKPT/stabilityai/stable-diffusion-xl-refiner-1.0 12 | 13 | python inference.py \ 14 | --fp16 \ 15 | --generate_image \ 16 | --max_output_length 320 \ 17 | --val_data_path $VAL_DATA_PATH \ 18 | --inference_dir $INFERENCE_DIR \ 19 | --checkpoint $CHECKPOINT \ 20 | --vision_model $VISION_MODEL \ 21 | --language_model $LANGUAGE_MODEL \ 22 | --processor $PROCESSOR \ 23 | --sd_base $SD_BASE \ 24 | --sd_refiner $SD_REFINER \ 25 | --num_query_tokens 32 \ 26 | --num_qformer_hidden_layers 12 \ 27 | --num_qformer_attention_heads 12 \ 28 | 
--qformer_hidden_size 768 \ 29 | --qformer_intermediate_size 3072 30 | -------------------------------------------------------------------------------- /scripts/run_mmbench_eval.sh: -------------------------------------------------------------------------------- 1 | 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | INFERENCE_DIR=./MMBench_Results/second_stage_1e-5_12_12_768_3072_blip2_cosine2_tlm_mim_21k 5 | CHECKPOINT=./checkpoint/second_stage_1e-5_12_12_768_3072_blip2_cosine2_tlm_mim_21k/checkpoint_epoch3_step1014/pytorch_model.pt 6 | VISION_MODEL=./checkpoint/blip2_vision_model 7 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 8 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 9 | SD_BASE=../../CKPT/stabilityai/stable-diffusion-xl-base-1.0 10 | SD_REFINER=../../CKPT/stabilityai/stable-diffusion-xl-refiner-1.0 11 | 12 | python mmbench_evaluation.py \ 13 | --fp16 \ 14 | --max_output_length 256 \ 15 | --inference_dir $INFERENCE_DIR \ 16 | --checkpoint $CHECKPOINT \ 17 | --vision_model $VISION_MODEL \ 18 | --language_model $LANGUAGE_MODEL \ 19 | --processor $PROCESSOR \ 20 | --sd_base $SD_BASE \ 21 | --sd_refiner $SD_REFINER \ 22 | --num_query_tokens 32 \ 23 | --num_qformer_hidden_layers 12 \ 24 | --num_qformer_attention_heads 12 \ 25 | --qformer_hidden_size 768 \ 26 | --qformer_intermediate_size 3072 -------------------------------------------------------------------------------- /scripts/run_mme_eval.sh: -------------------------------------------------------------------------------- 1 | 2 | export CUDA_VISIBLE_DEVICES=4 3 | 4 | INFERENCE_DIR=./MME_Benchmark_release_version/eval_tool/second_stage_1e-4_12_12_768_3072_blip2_cosine2_multi_22k_768 5 | CHECKPOINT=./checkpoint/second_stage_1e-5_12_12_768_3072_blip2_cosine2_tlm_mim21k_768/checkpoint_epoch3_step1014/pytorch_model.pt 6 | VISION_MODEL=./checkpoint/blip2_vision_model 7 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 8 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 9 | SD_BASE=../../CKPT/stabilityai/stable-diffusion-xl-base-1.0 10 | SD_REFINER=../../CKPT/stabilityai/stable-diffusion-xl-refiner-1.0 11 | 12 | python mme_evaluation.py \ 13 | --fp16 \ 14 | --max_output_length 4 \ 15 | --inference_dir $INFERENCE_DIR \ 16 | --checkpoint $CHECKPOINT \ 17 | --vision_model $VISION_MODEL \ 18 | --language_model $LANGUAGE_MODEL \ 19 | --processor $PROCESSOR \ 20 | --sd_base $SD_BASE \ 21 | --sd_refiner $SD_REFINER \ 22 | --num_query_tokens 32 \ 23 | --num_qformer_hidden_layers 12 \ 24 | --num_qformer_attention_heads 12 \ 25 | --qformer_hidden_size 768 \ 26 | --qformer_intermediate_size 3072 -------------------------------------------------------------------------------- /scripts/run_second_stage.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG=INFO 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_IB_SL=3 5 | export NCCL_NET_GDR_READ=1 6 | 7 | TRAIN_DATA_PATH=./data/platypus_24926_minigpt4_3439_shikra_7081_llava_157712_multi_instruct_105745_mim_21629.json 8 | SAVE_MODULES="query_tokens qformer qformer_projection language_model" 9 | CHECKPOINT=./checkpoint/first_stage_1e-4_256_12_12_768_3072_blip2_cosine2/checkpoint_epoch2_step78128/pytorch_model.pt 10 | VISION_MODEL=./checkpoint/blip2_vision_model 11 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 12 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 13 | 14 | deepspeed \ 15 | --include localhost:0,1,2,3,4,5,6,7 \ 16 | main.py \ 17 | --train \ 18 | --stage second \ 19 | 
--training_lm \ 20 | --deepspeed_config ./ds_config_second_stage.json \ 21 | --project_name second_stage_1e-5_12_12_768_3072_blip2_cosine2_tlm_plat25k_mini3k_shik7k_llav158k_mult106k_mim22k_768 \ 22 | --num_epochs 3 \ 23 | --warmup_steps 100 \ 24 | --train_data_path $TRAIN_DATA_PATH \ 25 | --save_modules $SAVE_MODULES \ 26 | --checkpoint $CHECKPOINT \ 27 | --vision_model $VISION_MODEL \ 28 | --language_model $LANGUAGE_MODEL \ 29 | --processor $PROCESSOR \ 30 | --max_input_length 768 \ 31 | --num_query_tokens 32 \ 32 | --num_qformer_hidden_layers 12 \ 33 | --num_qformer_attention_heads 12 \ 34 | --qformer_hidden_size 768 \ 35 | --qformer_intermediate_size 3072 \ 36 | -------------------------------------------------------------------------------- /scripts/run_second_stage_val.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG=INFO 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_IB_SL=3 5 | export NCCL_NET_GDR_READ=1 6 | export CUDA_VISIBLE_DEVICES=0 7 | 8 | VAL_DATA_PATH=./data/train.s2.v3.proc.valid.json 9 | CHECKPOINT=./checkpoint/second_stage_1e-5_12_12_768_3072_blip2_cosine2_tlm_mim21k_768/checkpoint_epoch3_step1014/pytorch_model.pt 10 | VISION_MODEL=./checkpoint/blip2_vision_model 11 | LANGUAGE_MODEL=../../CKPT/meta-llama/Llama-2-7b-chat-hf 12 | PROCESSOR=../../CKPT/Salesforce/blip2-flan-t5-xxl 13 | 14 | python main.py \ 15 | --use_causal_mask \ 16 | --fp16 \ 17 | --compute_loss \ 18 | --stage second \ 19 | --deepspeed_config ./ds_config_second_stage.json \ 20 | --with_num_works 4 \ 21 | --val_data_path $VAL_DATA_PATH \ 22 | --checkpoint $CHECKPOINT \ 23 | --vision_model $VISION_MODEL \ 24 | --language_model $LANGUAGE_MODEL \ 25 | --processor $PROCESSOR \ 26 | --sd_base $SD_BASE \ 27 | --sd_refiner $SD_REFINER \ 28 | --num_query_tokens 32 \ 29 | --num_qformer_hidden_layers 12 \ 30 | --num_qformer_attention_heads 12 \ 31 | --qformer_hidden_size 768 \ 32 | --qformer_intermediate_size 3072 33 | 34 | -------------------------------------------------------------------------------- /stable_diffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import torch 5 | import random 6 | import argparse 7 | import numpy as np 8 | import webdataset as wds 9 | import torch.distributed as dist 10 | from tqdm import tqdm 11 | from PIL import Image 12 | from io import BytesIO 13 | from transformers import Blip2Processor 14 | from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, DiffusionPipeline, AutoencoderKL, StableDiffusionXLImg2ImgPipeline, DPMSolverMultistepScheduler 15 | from data_utils import ImageCaptionTemplates, ImageGenerationTemplates 16 | 17 | 18 | 19 | def get_sd_latents(sd_pipe, generator, image): 20 | with torch.no_grad(): 21 | latents = sd_pipe(prompt="", image=image, generator=generator, strength=0.0, output_type="latent").images 22 | return latents 23 | 24 | def decode_with_latents(sd_pipe, latents): 25 | with torch.no_grad(): 26 | image = sd_pipe.vae.decode(latents / sd_pipe.vae.config.scaling_factor, return_dict=False)[0] 27 | image = sd_pipe.watermark.apply_watermark(image) 28 | image = sd_pipe.image_processor.postprocess(image, output_type="pil")[0] 29 | return image 30 | 31 | def webdataset_map(example): 32 | 33 | image = Image.open(BytesIO(example["image"])).convert("RGB") 34 | caption = example["caption"].decode("utf-8") 35 | key = json.loads(example["meta"])["key"] 36 | 37 | return { 38 | 
"image": image, 39 | "caption": caption, 40 | "key": key 41 | } 42 | 43 | def webdataset_collate_fn(batch): 44 | 45 | images = [example["image"] for example in batch] 46 | captions = [example["caption"] for example in batch] 47 | keys = [example["key"] for example in batch] 48 | 49 | return { 50 | "image": images, 51 | "caption": captions, 52 | "key": keys 53 | } 54 | 55 | def generate_val_generation_dataset(): 56 | 57 | urls = "/apdcephfs/share_733425/jcykcai/sihengli/MIM/mim/data/minigpt4/cc_sbu/cc_sbu_dataset/01255.tar" 58 | dataset = wds.WebDataset(urls).rename(image="jpg;png", meta="json", caption="txt").map(webdataset_map) 59 | 60 | image_dir = "./data/val_image_generation_images" 61 | if not os.path.exists(image_dir): 62 | os.makedirs(image_dir) 63 | target_path = "./data/val_image_generation.json" 64 | corpus = [] 65 | for data in dataset: 66 | caption = data["caption"] 67 | image = data["image"] 68 | template = random.choice(ImageGenerationTemplates) 69 | input_text = f"Human: {template[0].format(image_caption=caption)}\n" 70 | output_text = f"Assistant: {template[1]}" 71 | 72 | cur_idx = len(corpus) 73 | image_path = os.path.join(image_dir, f"{cur_idx}.png") 74 | image.save(image_path) 75 | corpus.append({ 76 | "input": input_text, 77 | "output": output_text, 78 | "key": data["key"], 79 | "caption": caption, 80 | "input_image_list": [], 81 | "output_image_list": [image_path] 82 | }) 83 | if len(corpus) == 100: 84 | break 85 | json.dump(corpus, open(target_path, "w"), indent=4) 86 | 87 | def generate_val_caption_dataset(): 88 | 89 | urls = "/apdcephfs/share_733425/jcykcai/sihengli/MIM/mim/data/minigpt4/cc_sbu/cc_sbu_dataset/01255.tar" 90 | dataset = wds.WebDataset(urls).rename(image="jpg;png", meta="json", caption="txt").map(webdataset_map) 91 | 92 | image_dir = "./data/val_image_caption_images" 93 | if not os.path.exists(image_dir): 94 | os.makedirs(image_dir) 95 | target_path = "./data/val_image_caption.json" 96 | corpus = [] 97 | for data in dataset: 98 | caption = data["caption"] 99 | image = data["image"] 100 | template = random.choice(ImageCaptionTemplates) 101 | input_text = f"Human: {template}\n" 102 | output_text = f"Assistant: {caption}" 103 | 104 | cur_idx = len(corpus) 105 | image_path = os.path.join(image_dir, f"{cur_idx}.png") 106 | image.save(image_path) 107 | corpus.append({ 108 | "input": input_text, 109 | "output": output_text, 110 | "key": data["key"], 111 | "caption": caption, 112 | "input_image_list": [image_path], 113 | "output_image_list": [] 114 | }) 115 | if len(corpus) == 100: 116 | break 117 | json.dump(corpus, open(target_path, "w"), indent=4) 118 | 119 | 120 | def generate_sd2_prompt_embeds(): 121 | 122 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 123 | 124 | ckpt = "/apdcephfs/share_733425/jcykcai/sihengli/CKPT/stabilityai/stable-diffusion-2-1" 125 | sd_pipe = StableDiffusionPipeline.from_pretrained(ckpt, use_safetensors=True, variant="fp16") 126 | sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) 127 | sd_pipe = sd_pipe.to("cuda") 128 | 129 | urls = "/apdcephfs/share_733425/jcykcai/sihengli/MIM/MIM/data/minigpt4/cc_sbu/cc_sbu_dataset/{01001..01200}.tar" 130 | dataset = wds.WebDataset(urls).rename(image="jpg;png", meta="json", caption="txt").map(webdataset_map) 131 | 132 | data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=256, collate_fn=webdataset_collate_fn, num_workers=4, pin_memory=True, drop_last=False) 133 | 134 | save_dir = "./data/sd2_prompt_embeds" 135 | if not 
os.path.exists(save_dir): 136 | os.makedirs(save_dir) 137 | for idx, batch in tqdm(enumerate(data_loader)): 138 | 139 | with torch.no_grad(): 140 | prompt_embeds = sd_pipe._encode_prompt(prompt=batch["caption"], device=device, num_images_per_prompt=1, do_classifier_free_guidance=False) 141 | for id1 in range(prompt_embeds.shape[0]): 142 | np.save(os.path.join(save_dir, "{}.npy".format(batch['key'][id1])), prompt_embeds[id1].cpu().numpy()) 143 | 144 | def generate_sd2_prompt_embeds_mim(): 145 | 146 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 147 | 148 | ckpt = "/apdcephfs/share_733425/jcykcai/sihengli/CKPT/stabilityai/stable-diffusion-2-1" 149 | sd_pipe = StableDiffusionPipeline.from_pretrained(ckpt, use_safetensors=True, variant="fp16") 150 | sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) 151 | sd_pipe = sd_pipe.to("cuda") 152 | 153 | save_dir = "./data/sd2_prompt_embeds_mim" 154 | if not os.path.exists(save_dir): 155 | os.makedirs(save_dir) 156 | # corpus = json.load(open("./data/train.s2.v3.img-cap.json"))[25000:] 157 | corpus = json.load(open("./data/stage1_sample.json")) 158 | print(len(corpus)) 159 | for idx, data in tqdm(enumerate(corpus)): 160 | 161 | with torch.no_grad(): 162 | prompt_embeds = sd_pipe._encode_prompt(prompt=data["caption"], device=device, num_images_per_prompt=1, do_classifier_free_guidance=False) 163 | # image = sd_pipe(prompt_embeds=prompt_embeds, generator=generator).images[0] 164 | # print(prompt_embeds.shape) 165 | # image = sd_pipe(prompt=caption, generator=generator).images[0] 166 | np.save(os.path.join(save_dir, "{}.npy".format(data["image"])), prompt_embeds[0].cpu().numpy()) 167 | 168 | def decode_with_sd2_prompt_embeds(sd_pipe, generator, prompt_embeds): 169 | 170 | with torch.no_grad(): 171 | image = sd_pipe(prompt_embeds=prompt_embeds, generator=generator).images[0] 172 | return image 173 | 174 | def sample_stage1_data(): 175 | urls = "/apdcephfs/share_733425/jcykcai/sihengli/MIM/MIM/data/minigpt4/cc_sbu/cc_sbu_dataset/{00000..01200}.tar" 176 | dataset = wds.WebDataset(urls).rename(image="jpg;png", meta="json", caption="txt").map(webdataset_map) 177 | 178 | sample_num = 20000 179 | corpus = [] 180 | if not os.path.exists("./data/stage1_sample"): 181 | os.makedirs("./data/stage1_sample") 182 | for data in dataset: 183 | print(data["key"]) 184 | image = data["image"] 185 | caption = data["caption"] 186 | key = data["key"] 187 | image_name = f"stage1_{key}.png" 188 | image.save(os.path.join("./data/stage1_sample", image_name)) 189 | image.save(os.path.join("./data/mim_images", image_name)) 190 | 191 | template = random.choice(ImageGenerationTemplates) 192 | data = { 193 | "conversation": [ 194 | { 195 | "role": "user", 196 | "content": template[0].format(image_caption=caption), 197 | "image_list": [] 198 | }, 199 | { 200 | "role": "assistant", 201 | "content": template[1], 202 | "image_list": [image_name], 203 | } 204 | ], 205 | "image": image_name, 206 | "caption": caption, 207 | } 208 | corpus.append(data) 209 | if len(corpus) == sample_num: 210 | break 211 | print(len(corpus)) 212 | json.dump(corpus, open("./data/stage1_sample.json", "w"), indent=4) 213 | 214 | def blender_stage1_and_stage2(): 215 | 216 | corpus1 = json.load(open("./data/stage1_sample.json")) 217 | corpus2 = json.load(open("./data/train_mim.json")) 218 | corpus = corpus1 + corpus2 219 | 220 | random.shuffle(corpus) 221 | print(len(corpus)) 222 | json.dump(corpus, open("./data/train_mim_blender.json", "w"), indent=4) 223 | 
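# decode_with_sdxl (below) follows the two-stage SDXL base + refiner recipe:
# the base pipeline runs the first portion of the denoising schedule
# (denoising_end=high_noise_frac) and returns latents instead of a decoded image,
# then the refiner completes the remaining steps (denoising_start=high_noise_frac)
# on those latents and decodes the final PIL image.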
224 | 225 | def decode_with_sdxl(prompt, base, refiner, generator): 226 | n_steps = 40 227 | high_noise_frac = 0.8 228 | 229 | # run both experts 230 | image = base( 231 | prompt=prompt, 232 | generator=generator, 233 | num_inference_steps=n_steps, 234 | denoising_end=high_noise_frac, 235 | output_type="latent", 236 | ).images 237 | image = refiner( 238 | prompt=prompt, 239 | generator=generator, 240 | num_inference_steps=n_steps, 241 | denoising_start=high_noise_frac, 242 | image=image, 243 | ).images[0] 244 | 245 | return image 246 | 247 | if __name__ == "__main__": 248 | 249 | # generate_sd2_prompt_embeds() 250 | # generate_val_generation_dataset() 251 | # generate_val_caption_dataset() 252 | # generate_sd2_prompt_embeds_mim() 253 | # sample_stage1_data() 254 | # blender_stage1_and_stage2() 255 | 256 | # generate_sd2_prompt_embeds_mim() 257 | 258 | device = "cuda:1" 259 | ckpt = "/apdcephfs/share_733425/jcykcai/sihengli/CKPT/stabilityai/stable-diffusion-xl-base-1.0" 260 | base = DiffusionPipeline.from_pretrained( 261 | ckpt, torch_dtype=torch.float16, variant="fp16", use_safetensors=True 262 | ) 263 | base.to(device) 264 | 265 | ckpt = "/apdcephfs/share_733425/jcykcai/sihengli/CKPT/stabilityai/stable-diffusion-xl-refiner-1.0" 266 | refiner = DiffusionPipeline.from_pretrained( 267 | ckpt, 268 | text_encoder_2=base.text_encoder_2, 269 | vae=base.vae, 270 | torch_dtype=torch.float16, 271 | use_safetensors=True, 272 | variant="fp16", 273 | ) 274 | refiner.to(device) 275 | 276 | generator = torch.Generator(device="cuda").manual_seed(53) 277 | 278 | prompt = "Hi, I'm looking for inspiration on how to relax outdoors while still being productive." 279 | 280 | image = decode_with_sdxl(prompt, base, refiner, generator) 281 | 282 | image.save("./data/relax_outdoors.png") -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | Mim 10 | 182 | 183 | 184 |
185 | Instructions
186 | What it is: This is the demo of TextBind. The demo supports interleaved text and images in a multi-turn conversation, and it can also generate appropriate images without being given an explicit description.
187 | How to use:
190 | 1. Send: Click to send the content of the input box to the model. You can provide interleaved text and images in the input box.
192 | 2. Upload IMG: Click to upload an image from your local device into the input box. The image will appear at the position of your cursor.
194 | 3. Example: Click to show a random conversation example.
196 | Tips: (1) To start a new conversation, refresh the webpage with Ctrl+R (or Cmd+R). (2) Uploading large images (>1 MB) may fail, so please keep the image size small. (3) Our server handles requests in FIFO order, so the waiting time may be long when there are many users.
218 | 380 | 381 | 382 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import torch 5 | import deepspeed 6 | import transformers 7 | import argparse 8 | import numpy as np 9 | import webdataset as wds 10 | from tqdm import tqdm 11 | from transformers import PreTrainedTokenizer, PreTrainedModel, LlamaTokenizer, Blip2Processor 12 | from typing import List, Dict, Sequence 13 | from PIL import Image 14 | from diffusers import StableDiffusionPipeline 15 | from transformers import Blip2ForConditionalGeneration, Blip2VisionModel 16 | from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint 17 | 18 | from model import MIMModel 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--project_name', type=str, default='mim', help='Wandb project name') 23 | parser.add_argument('--deepspeed_config', type=str, default='deepspeed_config.json', help='DeepSpeed configuration file') 24 | parser.add_argument('--local_rank', type=int, default=-1, help='Local rank for distributed training (-1: not distributed)') 25 | parser.add_argument('--train_data_path', type=str, default='data/train.json', help='Path to training data') 26 | parser.add_argument('--val_data_path', type=str, default='data/val.json', help='Path to validation data') 27 | parser.add_argument('--image_dir', type=str, default=None, help='Path to image directory') 28 | parser.add_argument('--inference_dir', type=str, default=None) 29 | parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoints') 30 | parser.add_argument('--save_modules', type=str, nargs='+', default=['model'], help='State keys to save') 31 | 32 | # training parameters 33 | parser.add_argument('--train', action='store_true', help='Train model') 34 | parser.add_argument('--compute_loss', action='store_true', help='Compute loss') 35 | parser.add_argument('--stage', type=str, help='Training stage') 36 | parser.add_argument('--warmup_steps', type=int, default=1000, help='Warmup steps') 37 | parser.add_argument('--num_epochs', type=int, default=None, help='Number of training epochs') 38 | parser.add_argument('--max_input_length', type=int, default=512, help='Maximum input length') 39 | parser.add_argument('--max_num_images', type=int, default=32, help='Maximum number of images') 40 | parser.add_argument('--with_epoch', type=int, default=0, help='for wds') 41 | parser.add_argument('--with_num_works', type=int, default=1, help="for wds") 42 | parser.add_argument('--save_per_steps', type=int, default=1000000, help='Save model per number of steps') 43 | parser.add_argument('--training_lm', action='store_true', default=False, help='Train language model') 44 | parser.add_argument('--training_vm', action='store_true', default=False, help='Train language model') 45 | 46 | # inference parameters 47 | parser.add_argument('--max_output_length', type=int, default=256, help='Maximum generation length') 48 | parser.add_argument('--generate_image', action='store_true', help='Generate image') 49 | parser.add_argument('--top_p', type=float, default=None, help='Top p') 50 | 51 | # model parameters 52 | parser.add_argument('--fp16', action='store_true', help='Use fp16') 53 | parser.add_argument('--bf16', action='store_true', help='Use bf16') 54 | parser.add_argument('--vision_model', type=str, default='openai/clip-vit-base-patch32', 
help='Vision model') 55 | parser.add_argument('--language_model', type=str, default='openai/clip-vit-base-patch32', help='Language model') 56 | parser.add_argument('--sd_base', type=str, default="runwayml/stable-diffusion-v1-5", help='Stable Diffusion base model') 57 | parser.add_argument('--sd_refiner', type=str, default="runwayml/stable-diffusion-v1-5", help='Stable Diffusion refiner model') 58 | parser.add_argument('--processor', type=str, default='clip', help='Processor') 59 | parser.add_argument('--num_query_tokens', type=int, default=32, help='Number of query tokens') 60 | parser.add_argument('--num_qformer_attention_heads', type=int, default=16, help='Number of Q-Former attention heads') 61 | parser.add_argument('--num_qformer_hidden_layers', type=int, default=12, help='Number of Q-Former hidden layers') 62 | parser.add_argument('--qformer_hidden_size', type=int, default=1024, help='Q-Former hidden size') 63 | parser.add_argument('--qformer_intermediate_size', type=int, default=1408, help='Q-Former intermediate size') 64 | 65 | # demo parameters 66 | parser.add_argument('--port', default=8081, help='Port to run the demo') 67 | parser.add_argument('--model_list', type=str, default="", help='Path to the model list info') 68 | parser.add_argument('--demo_example_path', type=str, default="", help='Path to the example data') 69 | parser.add_argument('--url_prefix', type=str, default="", help='Prefix added to the URL') 70 | parser.add_argument('--safe_image_num', type=int, default=16, help='Maximum number of images appearing in a conversation') 71 | parser.add_argument('--safe_word_num', type=int, default=768, help='Maximum number of words appearing in a conversation') 72 | 73 | return parser.parse_args() 74 | 75 | def build_model_and_processor(args): 76 | 77 | tokenizer = LlamaTokenizer.from_pretrained(args.language_model) 78 | add_tokens = ["", "", ""] 79 | tokenizer.add_special_tokens({"additional_special_tokens": add_tokens}) 80 | args.image_token_id = tokenizer.convert_tokens_to_ids("") 81 | args.caption_start_id = tokenizer.convert_tokens_to_ids("") 82 | args.caption_end_id = tokenizer.convert_tokens_to_ids("") 83 | args.num_new_tokens = len(add_tokens) 84 | image_processor = Blip2Processor.from_pretrained(args.processor) 85 | model = MIMModel(args) 86 | model.language_model.resize_token_embeddings(len(tokenizer)) 87 | 88 | if args.checkpoint: 89 | state_dict = torch.load(args.checkpoint) 90 | model.load_state_dict(state_dict, strict=False) 91 | print("Loaded checkpoint from %s" % args.checkpoint) 92 | print("Loaded modules: %s" % set([key.split(".")[0] for key in state_dict.keys()])) 93 | 94 | return model, tokenizer, image_processor 95 | 96 | 97 | def smart_tokenizer_and_embedding_resize( 98 | additional_tokens: List[str], 99 | tokenizer: transformers.PreTrainedTokenizer, 100 | model: transformers.PreTrainedModel, 101 | ): 102 | """Resize tokenizer and embedding. 103 | 104 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
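New rows in the input and output embedding matrices are initialized to the mean of the existing embeddings rather than left at their random initialization.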
105 | """ 106 | num_new_tokens = tokenizer.add_tokens(additional_tokens) 107 | model.resize_token_embeddings(len(tokenizer)) 108 | 109 | if num_new_tokens > 0: 110 | input_embeddings = model.get_input_embeddings().weight.data 111 | output_embeddings = model.get_output_embeddings().weight.data 112 | 113 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 114 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 115 | 116 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 117 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 118 | 119 | 120 | def compute_clip_score( 121 | model: transformers.CLIPModel, 122 | processor: transformers.CLIPProcessor, 123 | image: Image, 124 | caption: str, 125 | ): 126 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 127 | processed_input = processor(text=[caption], images=[image.cpu()], return_tensors="pt", padding=True) 128 | 129 | img_features = model.get_image_features(processed_input["pixel_values"].to(device)) 130 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) 131 | 132 | txt_features = model.get_text_features( 133 | processed_input["input_ids"][:, :77].to(device), processed_input["attention_mask"][:, :77].to(device) 134 | ) 135 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True) 136 | 137 | # cosine similarity between feature vectors 138 | score = (img_features * txt_features).sum(axis=-1).item() 139 | 140 | return score 141 | 142 | 143 | if __name__ == "__main__": 144 | 145 | import torch 146 | from transformers import Blip2ForConditionalGeneration 147 | model = Blip2ForConditionalGeneration.from_pretrained("../../CKPT/Salesforce/blip2-flan-t5-xxl") 148 | vision_model = model.vision_model 149 | vision_model.save_pretrained("checkpoint/blip2_vision_model") 150 | 151 | state_dict = model.state_dict() 152 | state_dict = {key: value for key, value in state_dict.items() if key.split(".")[0] in ["query_tokens", "qformer"]} 153 | torch.save(state_dict, "checkpoint/blip2_qformer.pt") 154 | 155 | --------------------------------------------------------------------------------