├── README.md ├── cdchat ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ └── batch_cdchat_vqa.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── cdchat_arch.py │ ├── consolidate.py │ ├── language_model │ │ └── cdchat_llama.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── train │ ├── cdchat_trainer.py │ ├── llama_flash_attn_monkey_patch.py │ ├── train.py │ └── train_mem.py └── utils.py ├── data_files ├── eval_questions_levir_test.json ├── eval_questions_sysu_test.json ├── levir_captions.json └── sysu_test_captions.json ├── images ├── cdchat_annotation.png ├── cdchat_arch.png ├── example_01.png └── example_02.png ├── pyproject.toml └── scripts ├── extract_mm_projector.py ├── finetune_lora.sh ├── merge_lora_weights.py ├── pretrain.sh ├── zero2.json ├── zero3.json └── zero3_offload.json /README.md: -------------------------------------------------------------------------------- 1 | # CDChat: A Large Multimodal Model for Remote Sensing Change Description 2 | 3 | [[Arxiv]](https://arxiv.org/abs/2409.16261) 4 | 5 | 6 | - **Test Captions file of SYSU-CD is uploaded.** 7 | - **Evaluation test questions file of SYSU-CD is uploaded** 8 | - **Train/Val/Test Captions file of LEVIR-CD is uploaded** 9 | - **Evaluation test questions file of LEVIR-CD is uploaded** 10 | 11 | Instruction tuning files will be available soon! 12 | 13 | 14 | ### Overview 15 | CDChat is a conversational assistant for the remote sensing (RS) change description task. We annotate the SYSU-CD dataset to obtain the change text and image pairs for instruction tuning of CDChat. We create change text and image pairs from two large-scale change detection datasets: [SYSU-CD](https://github.com/liumency/SYSU-CD) and [LEVIR-CD](https://github.com/Chen-Yang-Liu/LEVIR-CC-Dataset).
16 | 17 | image 18 | 19 | **Annotation Tool** 20 | 21 | A custom annotation tool was used to annotate the SYSU-CD dataset, as shown below: 22 | 23 | image 24 | 25 | 26 | ### Installation 27 | - **Clone this repository and navigate to the cdchat folder** 28 | ``` 29 | git clone https://github.com/techmn/cdchat.git 30 | cd cdchat 31 | ``` 32 | - **Install Package** 33 | ``` 34 | conda create -n cdchat python=3.10 -y 35 | conda activate cdchat 36 | pip install --upgrade pip # enable PEP 660 support 37 | pip install -e . 38 | ``` 39 | - **Install additional packages for training** 40 | ``` 41 | pip install ninja 42 | pip install flash-attn --no-build-isolation 43 | ``` 44 | 45 | ### Train 46 | **Feature Alignment** 47 | 48 | We use the pretrained projector from LLaVA-v1.5, similar to [GeoChat](https://github.com/mbzuai-oryx/GeoChat). 49 | 50 | `--mm_projector_type mlp2x_gelu:` the two-layer MLP vision-language connector 51 | 52 | `--vision_tower openai/clip-vit-large-patch14-336:` CLIP ViT-L/14 336px 53 | 54 | 55 | **Instruction Tuning** 56 | 57 | - Download the cdchat_instruct_file and the image pairs. Place the image-pair folders and the cdchat_instruct_file in the same folder. 58 | - Update `--data_path` and `--image_folder` in the file `finetune_lora.sh` 59 | - Start training! 60 | 61 | **Note** Our codebase is inspired by [GeoChat](https://github.com/mbzuai-oryx/GeoChat). Please refer to it for detailed instructions on installation and training. 62 | 63 | **Model Weights** 64 | will be available soon! 65 | 66 | ### Evaluation 67 | We evaluate CDChat on the test sets of the LEVIR-CD and SYSU-CD datasets.
 Below is the command to evaluate CDChat on a dataset (the image folder must contain `A`, `B` and `label` sub-folders with the pre-change images, post-change images and label maps, matched by file name): 68 | 69 | ``` 70 | python cdchat/eval/batch_cdchat_vqa.py \ 71 | --model-path /path/to/model \ 72 | --question-file path/to/json/file \ 73 | --answers-file path/to/output/jsonl/file \ 74 | --image-folder path/to/image/folder/ 75 | ``` 76 | 77 | Here are a few example responses from LMMs for the change description task: 78 | 79 | image 80 | image 81 | 82 | ### Citation 83 | 84 | ``` 85 | @misc{cdchat_2024, 86 | title={CDChat: A Large Multimodal Model for Remote Sensing Change Description}, 87 | author={Mubashir Noman and Noor Ahsan and Muzammal Naseer and Hisham Cholakkal and Rao Muhammad Anwer and Salman Khan and Fahad Shahbaz Khan}, 88 | year={2024}, 89 | eprint={2409.16261}, 90 | archivePrefix={arXiv}, 91 | primaryClass={cs.CV}, 92 | url={https://arxiv.org/abs/2409.16261}, 93 | } 94 | ``` 95 | ### Acknowledgements 96 | Our codebase is inspired by the [GeoChat](https://github.com/mbzuai-oryx/GeoChat) repository. We thank them for releasing their valuable codebase. 97 | -------------------------------------------------------------------------------- /cdchat/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import CDChatLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /cdchat/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "."
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | -------------------------------------------------------------------------------- /cdchat/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | from PIL import Image 5 | from threading import Thread 6 | 7 | from cdchat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 8 | from cdchat.utils import disable_torch_init 9 | from cdchat.mm_utils import process_images_demo, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 10 | from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer,TextStreamer 11 | import torch 12 | import dataclasses 13 | from enum import auto, Enum 14 | from typing import List, Tuple, Any 15 | 16 | 17 | class SeparatorStyle(Enum): 18 | """Different separator style.""" 19 | SINGLE = auto() 20 | TWO = auto() 21 | MPT = auto() 22 | PLAIN = auto() 23 | LLAMA_2 = auto() 24 | 25 | 26 | @dataclasses.dataclass 27 | class Conversation: 28 | """A class that keeps all conversation history.""" 29 | system: str 30 | roles: List[str] 31 | messages: List[List[str]] 32 | offset: int 33 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 34 | sep: str = "###" 35 | sep2: str = None 36 | version: str = "Unknown" 37 | 38 | skip_next: bool = False 39 | 40 | def get_prompt(self): 41 | messages = self.messages 42 | if len(messages) > 0 and type(messages[0][1]) is tuple: 43 | messages = self.messages.copy() 44 | init_role, init_msg = messages[0].copy() 45 | init_msg = init_msg[0].replace("", "").strip() 46 | if 'mmtag' in self.version: 47 | messages[0] = (init_role, init_msg) 48 | messages.insert(0, (self.roles[0], "")) 49 | 
messages.insert(1, (self.roles[1], "Received.")) 50 | else: 51 | messages[0] = (init_role, "\n" + init_msg) 52 | 53 | if self.sep_style == SeparatorStyle.SINGLE: 54 | ret = self.system + self.sep 55 | for role, message in messages: 56 | if message: 57 | if type(message) is tuple: 58 | message, _, _ = message 59 | ret += role + ": " + message + self.sep 60 | else: 61 | ret += role + ":" 62 | elif self.sep_style == SeparatorStyle.TWO: 63 | seps = [self.sep, self.sep2] 64 | ret = self.system + seps[0] 65 | for i, (role, message) in enumerate(messages): 66 | if message: 67 | if type(message) is tuple: 68 | message, _, _ = message 69 | ret += role + ": " + message + seps[i % 2] 70 | else: 71 | ret += role + ":" 72 | elif self.sep_style == SeparatorStyle.MPT: 73 | ret = self.system + self.sep 74 | for role, message in messages: 75 | if message: 76 | if type(message) is tuple: 77 | message, _, _ = message 78 | ret += role + message + self.sep 79 | else: 80 | ret += role 81 | elif self.sep_style == SeparatorStyle.LLAMA_2: 82 | wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" 83 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]" 84 | ret = "" 85 | 86 | for i, (role, message) in enumerate(messages): 87 | if i == 0: 88 | assert message, "first message should not be none" 89 | assert role == self.roles[0], "first message should come from user" 90 | if message: 91 | if type(message) is tuple: 92 | message, _, _ = message 93 | if i == 0: message = wrap_sys(self.system) + message 94 | if i % 2 == 0: 95 | message = wrap_inst(message) 96 | ret += self.sep + message 97 | else: 98 | ret += " " + message + " " + self.sep2 99 | else: 100 | ret += "" 101 | ret = ret.lstrip(self.sep) 102 | elif self.sep_style == SeparatorStyle.PLAIN: 103 | seps = [self.sep, self.sep2] 104 | ret = self.system 105 | for i, (role, message) in enumerate(messages): 106 | if message: 107 | if type(message) is tuple: 108 | message, _, _ = message 109 | ret += message + seps[i % 2] 110 | else: 111 | ret += "" 112 | 
else: 113 | raise ValueError(f"Invalid style: {self.sep_style}") 114 | 115 | return ret 116 | 117 | def append_message(self, role, message): 118 | self.messages.append([role, message]) 119 | 120 | def get_images(self, return_pil=False): 121 | images = [] 122 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 123 | if i % 2 == 0: 124 | if type(msg) is tuple: 125 | import base64 126 | from io import BytesIO 127 | from PIL import Image 128 | msg, image, image_process_mode = msg 129 | if image_process_mode == "Pad": 130 | def expand2square(pil_img, background_color=(122, 116, 104)): 131 | width, height = pil_img.size 132 | if width == height: 133 | return pil_img 134 | elif width > height: 135 | result = Image.new(pil_img.mode, (width, width), background_color) 136 | result.paste(pil_img, (0, (width - height) // 2)) 137 | return result 138 | else: 139 | result = Image.new(pil_img.mode, (height, height), background_color) 140 | result.paste(pil_img, ((height - width) // 2, 0)) 141 | return result 142 | image = expand2square(image) 143 | elif image_process_mode in ["Default", "Crop"]: 144 | pass 145 | elif image_process_mode == "Resize": 146 | image = image.resize((448, 448)) 147 | else: 148 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 149 | max_hw, min_hw = max(image.size), min(image.size) 150 | aspect_ratio = max_hw / min_hw 151 | max_len, min_len = 800, 448 152 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 153 | longest_edge = int(shortest_edge * aspect_ratio) 154 | W, H = image.size 155 | if longest_edge != max(image.size): 156 | if H > W: 157 | H, W = longest_edge, shortest_edge 158 | else: 159 | H, W = shortest_edge, longest_edge 160 | image = image.resize((W, H)) 161 | if return_pil: 162 | images.append(image) 163 | else: 164 | buffered = BytesIO() 165 | image.save(buffered, format="PNG") 166 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 167 | images.append(img_b64_str) 168 | return 
images 169 | 170 | def to_gradio_chatbot(self): 171 | ret = [] 172 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 173 | if i % 2 == 0: 174 | if type(msg) is tuple: 175 | import base64 176 | from io import BytesIO 177 | msg, image, image_process_mode = msg 178 | max_hw, min_hw = max(image.size), min(image.size) 179 | aspect_ratio = max_hw / min_hw 180 | max_len, min_len = 800, 448 181 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 182 | longest_edge = int(shortest_edge * aspect_ratio) 183 | W, H = image.size 184 | if H > W: 185 | H, W = longest_edge, shortest_edge 186 | else: 187 | H, W = shortest_edge, longest_edge 188 | image = image.resize((W, H)) 189 | buffered = BytesIO() 190 | image.save(buffered, format="JPEG") 191 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 192 | img_str = f'user upload image' 193 | msg = img_str + msg.replace('', '').strip() 194 | ret.append([msg, None]) 195 | else: 196 | ret.append([msg, None]) 197 | else: 198 | ret[-1][-1] = msg 199 | return ret 200 | 201 | def copy(self): 202 | return Conversation( 203 | system=self.system, 204 | roles=self.roles, 205 | messages=[[x, y] for x, y in self.messages], 206 | offset=self.offset, 207 | sep_style=self.sep_style, 208 | sep=self.sep, 209 | sep2=self.sep2, 210 | version=self.version) 211 | 212 | def dict(self): 213 | if len(self.get_images()) > 0: 214 | return { 215 | "system": self.system, 216 | "roles": self.roles, 217 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 218 | "offset": self.offset, 219 | "sep": self.sep, 220 | "sep2": self.sep2, 221 | } 222 | return { 223 | "system": self.system, 224 | "roles": self.roles, 225 | "messages": self.messages, 226 | "offset": self.offset, 227 | "sep": self.sep, 228 | "sep2": self.sep2, 229 | } 230 | 231 | 232 | conv_vicuna_v0 = Conversation( 233 | system="A chat between a curious human and an artificial intelligence assistant. 
" 234 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 235 | roles=("Human", "Assistant"), 236 | messages=( 237 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 238 | ("Assistant", 239 | "Renewable energy sources are those that can be replenished naturally in a relatively " 240 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 241 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 242 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 243 | "renewable and non-renewable energy sources:\n" 244 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 245 | "energy sources are finite and will eventually run out.\n" 246 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 247 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 248 | "and other negative effects.\n" 249 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 250 | "have lower operational costs than non-renewable sources.\n" 251 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 252 | "locations than non-renewable sources.\n" 253 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 254 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 255 | "6. 
Sustainability: Renewable energy sources are more sustainable over the long term, while " 256 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 257 | ), 258 | offset=2, 259 | sep_style=SeparatorStyle.SINGLE, 260 | sep="###", 261 | ) 262 | 263 | conv_vicuna_v1 = Conversation( 264 | system="A chat between a curious user and an artificial intelligence assistant. " 265 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 266 | roles=("USER", "ASSISTANT"), 267 | version="v1", 268 | messages=(), 269 | offset=0, 270 | sep_style=SeparatorStyle.TWO, 271 | sep=" ", 272 | sep2="", 273 | ) 274 | 275 | conv_llama_2 = Conversation( 276 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 277 | 278 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", 279 | roles=("USER", "ASSISTANT"), 280 | version="llama_v2", 281 | messages=(), 282 | offset=0, 283 | sep_style=SeparatorStyle.LLAMA_2, 284 | sep="", 285 | sep2="", 286 | ) 287 | 288 | conv_llava_llama_2 = Conversation( 289 | system="You are a helpful language and vision assistant. 
" 290 | "You are able to understand the visual content that the user provides, " 291 | "and assist the user with a variety of tasks using natural language.", 292 | roles=("USER", "ASSISTANT"), 293 | version="llama_v2", 294 | messages=(), 295 | offset=0, 296 | sep_style=SeparatorStyle.LLAMA_2, 297 | sep="", 298 | sep2="", 299 | ) 300 | 301 | conv_mpt = Conversation( 302 | system="""<|im_start|>system 303 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 304 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 305 | version="mpt", 306 | messages=(), 307 | offset=0, 308 | sep_style=SeparatorStyle.MPT, 309 | sep="<|im_end|>", 310 | ) 311 | 312 | conv_llava_plain = Conversation( 313 | system="", 314 | roles=("", ""), 315 | messages=( 316 | ), 317 | offset=0, 318 | sep_style=SeparatorStyle.PLAIN, 319 | sep="\n", 320 | ) 321 | 322 | conv_llava_v0 = Conversation( 323 | system="A chat between a curious human and an artificial intelligence assistant. " 324 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 325 | roles=("Human", "Assistant"), 326 | messages=( 327 | ), 328 | offset=0, 329 | sep_style=SeparatorStyle.SINGLE, 330 | sep="###", 331 | ) 332 | 333 | conv_llava_v0_mmtag = Conversation( 334 | system="A chat between a curious user and an artificial intelligence assistant. " 335 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 336 | "The visual content will be provided with the following format: visual content.", 337 | roles=("Human", "Assistant"), 338 | messages=( 339 | ), 340 | offset=0, 341 | sep_style=SeparatorStyle.SINGLE, 342 | sep="###", 343 | version="v0_mmtag", 344 | ) 345 | 346 | conv_llava_v1 = Conversation( 347 | system="A chat between a curious human and an artificial intelligence assistant. 
" 348 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 349 | roles=("USER", "ASSISTANT"), 350 | version="v1", 351 | messages=(), 352 | offset=0, 353 | sep_style=SeparatorStyle.TWO, 354 | sep=" ", 355 | sep2="", 356 | ) 357 | 358 | conv_llava_v1_mmtag = Conversation( 359 | system="A chat between a curious user and an artificial intelligence assistant. " 360 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 361 | "The visual content will be provided with the following format: visual content.", 362 | roles=("USER", "ASSISTANT"), 363 | messages=(), 364 | offset=0, 365 | sep_style=SeparatorStyle.TWO, 366 | sep=" ", 367 | sep2="", 368 | version="v1_mmtag", 369 | ) 370 | 371 | default_conversation = conv_vicuna_v0 372 | conv_templates = { 373 | "default": conv_vicuna_v0, 374 | "v0": conv_vicuna_v0, 375 | "v1": conv_vicuna_v1, 376 | "vicuna_v1": conv_vicuna_v1, 377 | "llama_2": conv_llama_2, 378 | 379 | "plain": conv_llava_plain, 380 | "v0_plain": conv_llava_plain, 381 | "llava_v0": conv_llava_v0, 382 | "v0_mmtag": conv_llava_v0_mmtag, 383 | "llava_v1": conv_llava_v1, 384 | "v1_mmtag": conv_llava_v1_mmtag, 385 | "llava_llama_2": conv_llava_llama_2, 386 | 387 | "mpt": conv_mpt, 388 | } 389 | 390 | class Chat: 391 | def __init__(self, model, image_processor,tokenizer, device='cuda:0', stopping_criteria=None): 392 | self.device = device 393 | self.model = model 394 | self.vis_processor = image_processor 395 | self.tokenizer=tokenizer 396 | 397 | # if stopping_criteria is not None: 398 | # self.stopping_criteria = stopping_criteria 399 | # else: 400 | # stop_words_ids = [torch.tensor([2]).to(self.device)] 401 | # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 402 | 403 | def ask(self, text, conv): 404 | # import pdb;pdb.set_trace() 405 | if len(conv.messages) > 0 and conv.messages[-1][0] == 
conv.roles[0] \ 406 | and conv.messages[-1][1][-9:] == '\n': # last message is image. 407 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) 408 | else: 409 | conv.append_message(conv.roles[0], text) 410 | 411 | def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, 412 | repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000): 413 | conv.append_message(conv.roles[1], None) 414 | prompt = conv.get_prompt() 415 | # prompt='A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: \n hello ASSISTANT:' 416 | text_input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device=self.device) 417 | 418 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 419 | keywords = [stop_str] 420 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, text_input_ids) 421 | current_max_len = text_input_ids.shape[1] + max_new_tokens 422 | if current_max_len - max_length > 0: 423 | print('Warning: The number of tokens in current conversation exceeds the max length. 
' 424 | 'The model will not see the contexts outside the range.') 425 | begin_idx = max(0, current_max_len - max_length) 426 | embs = text_input_ids[:, begin_idx:] 427 | 428 | generation_kwargs = dict( 429 | input_ids=embs, 430 | images=img_list[0], 431 | max_new_tokens=max_new_tokens, 432 | stopping_criteria=[stopping_criteria], 433 | num_beams=num_beams, 434 | do_sample=True, 435 | min_length=min_length, 436 | top_p=top_p, 437 | use_cache=True, 438 | repetition_penalty=repetition_penalty, 439 | length_penalty=length_penalty, 440 | temperature=float(temperature), 441 | ) 442 | return generation_kwargs 443 | 444 | # def answer(self, conv, img_list, **kargs): 445 | # generation_dict = self.answer_prepare(conv, img_list, **kargs) 446 | # output_token = self.model_generate(**generation_dict)[0] 447 | # output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) 448 | 449 | # output_text = output_text.split('###')[0] # remove the stop sign '###' 450 | # output_text = output_text.split('Assistant:')[-1].strip() 451 | 452 | # conv.messages[-1][1] = output_text 453 | # return output_text, output_token.cpu().numpy() 454 | 455 | def stream_answer(self, conv, img_list, **kargs): 456 | generation_kwargs = self.answer_prepare(conv, img_list, **kargs) 457 | 458 | streamer = TextIteratorStreamer(self.tokenizer,skip_prompt=True, skip_special_tokens=True) 459 | generation_kwargs['streamer'] = streamer 460 | # import pdb;pdb.set_trace() 461 | # output_ids=self.model.generate(*generation_kwargs) 462 | output=self.model_generate(kwargs=generation_kwargs) 463 | # thread = Thread(target=self.model_generate, kwargs=generation_kwargs) 464 | # thread.start() 465 | return streamer 466 | 467 | def model_generate(self, *args, **kwargs): 468 | # for 8 bit and 16 bit compatibility 469 | with torch.inference_mode(): 470 | output = self.model.generate(kwargs['kwargs']['input_ids'], 471 | images=kwargs['kwargs']['images'], 472 | do_sample=False, 473 | 
temperature=kwargs['kwargs']['temperature'], 474 | max_new_tokens=kwargs['kwargs']['max_new_tokens'], 475 | streamer=kwargs['kwargs']['streamer'], 476 | use_cache=kwargs['kwargs']['use_cache'], 477 | stopping_criteria=kwargs['kwargs']['stopping_criteria']) 478 | # import pdb;pdb.set_trace() 479 | # print(output) 480 | outputs = self.tokenizer.decode(output[0,kwargs['kwargs']['input_ids'].shape[1]:]).strip() 481 | # print(outputs) 482 | return output 483 | 484 | def encode_img(self, img_list): 485 | 486 | image = img_list[0] 487 | img_list.pop(0) 488 | if isinstance(image, str): # is a image path 489 | raw_image = Image.open(image).convert('RGB') 490 | image = process_images_demo([raw_image], self.vis_processor) 491 | # print("raw") 492 | # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 493 | elif isinstance(image, Image.Image): 494 | raw_image = image 495 | image = process_images_demo([raw_image], self.vis_processor ) 496 | image=image.to(device=self.device,dtype=torch.float16) 497 | # print("Image") 498 | # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 499 | elif isinstance(image, torch.Tensor): 500 | if len(image.shape) == 3: 501 | image = image.unsqueeze(0) 502 | image = image.to(self.device) 503 | 504 | # image_emb, _ = self.model.encode_img(image) 505 | img_list.append(image) 506 | 507 | def upload_img(self, image, conv, img_list): 508 | conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN+'\n') 509 | img_list.append(image) 510 | msg = "Received." 
511 | 512 | return msg 513 | 514 | 515 | 516 | # if __name__ == "__main__": 517 | # print(default_conversation.get_prompt()) 518 | -------------------------------------------------------------------------------- /cdchat/eval/batch_cdchat_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from cdchat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from cdchat.conversation import conv_templates, SeparatorStyle 10 | from cdchat.model.builder import load_pretrained_model 11 | from cdchat.utils import disable_torch_init 12 | from cdchat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | import numpy as np 17 | 18 | 19 | def split_list(lst, n): 20 | """Split a list into n (roughly) equal-sized chunks""" 21 | chunk_size = math.ceil(len(lst) / n) # integer division 22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 23 | 24 | 25 | def get_chunk(lst, n, k): 26 | chunks = split_list(lst, n) 27 | return chunks[k] 28 | 29 | 30 | def eval_model(args): 31 | # Model 32 | disable_torch_init() 33 | model_path = os.path.expanduser(args.model_path) 34 | model_name = get_model_name_from_path(model_path) 35 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, mm_projector_path=args.mm_projector_path) 36 | 37 | #questions=[] 38 | #questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 39 | fid = open(args.question_file, 'r') 40 | questions = json.load(fid) 41 | fid.close() 42 | 43 | 44 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 45 | answers_file = os.path.expanduser(args.answers_file) 46 | os.makedirs(os.path.dirname(answers_file), 
exist_ok=True) 47 | 48 | ans_file = open(answers_file, "w") 49 | 50 | for i in tqdm(range(0,len(questions),args.batch_size)): 51 | input_batch=[] 52 | input_image_batch=[] 53 | count=i 54 | image_folder_A = [] 55 | image_folder_B = [] 56 | image_label_list = [] 57 | batch_end = min(i + args.batch_size, len(questions)) 58 | 59 | 60 | for j in range(i,batch_end): 61 | #image_file=questions[j]['image'] 62 | #qs=questions[j]['text'] 63 | image_file = questions[j]['img_id'] 64 | print(image_file) 65 | qs = questions[j]['question'] 66 | 67 | if model.config.mm_use_im_start_end: 68 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 69 | else: 70 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 71 | 72 | conv = conv_templates[args.conv_mode].copy() 73 | conv.append_message(conv.roles[0], qs) 74 | conv.append_message(conv.roles[1], None) 75 | prompt = conv.get_prompt() 76 | 77 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 78 | input_batch.append(input_ids) 79 | 80 | #image = Image.open(os.path.join(args.image_folder, image_file)) 81 | #image_folder.append(image) 82 | image_A = Image.open(os.path.join(args.image_folder, 'A', image_file)).convert('RGB') 83 | image_B = Image.open(os.path.join(args.image_folder, 'B', image_file)).convert('RGB') 84 | image_folder_label = os.path.join(args.image_folder, 'label') 85 | image_label = Image.open((os.path.join(image_folder_label, image_file)).strip()).convert('L') 86 | image_label = (torch.Tensor(1)*np.array(image_label)/255.).unsqueeze(0) 87 | 88 | image_tensor_A = image_processor.preprocess(image_A,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 89 | image_tensor_B = image_processor.preprocess(image_B,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 90 | 91 | image_tensor_A = image_tensor_A[[2,1,0], :, :] 92 | 
image_tensor_B = image_tensor_B[[2,1,0], :, :] 93 | 94 | image_folder_A.append(image_tensor_A) 95 | image_folder_B.append(image_tensor_B) 96 | image_label_list.append(image_label) 97 | 98 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 99 | keywords = [stop_str] 100 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 101 | 102 | max_length = max(tensor.size(1) for tensor in input_batch) 103 | 104 | final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch] 105 | final_input_tensors=torch.cat(final_input_list,dim=0) 106 | #image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values'] 107 | image_tensor_batch = {'pre': torch.stack(image_folder_A).half().cuda(), 108 | 'post': torch.stack(image_folder_B).half().cuda(), 109 | 'targets': torch.stack(image_label_list).cuda()} 110 | 111 | with torch.inference_mode(): 112 | output_ids = model.generate( final_input_tensors, images=image_tensor_batch, do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True) 113 | 114 | input_token_len = final_input_tensors.shape[1] 115 | n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item() 116 | if n_diff_input_output > 0: 117 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 118 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True) 119 | for k in range(0,len(final_input_list)): 120 | output = outputs[k].strip() 121 | if output.endswith(stop_str): 122 | output = output[:-len(stop_str)] 123 | output = output.strip() 124 | 125 | ans_id = shortuuid.uuid() 126 | 127 | ans_file.write(json.dumps({ 128 | "question_id": 
questions[count]["question"], 129 | "image_id": questions[count]["img_id"], 130 | "answer": output, 131 | }) + "\n") 132 | count=count+1 133 | ans_file.flush() 134 | ans_file.close() 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument("--model-path", type=str, default="./cdchat/cdchat/checkpoints/cdchat_lora") 140 | parser.add_argument("--model-base", type=str, default='./cdchat/llava-v1.5-7b') 141 | parser.add_argument("--mm-projector-path", type=str, default='./cdchat/cdchat/checkpoints/pretrain_mm_projector/mm_projector.bin') 142 | parser.add_argument("--image-folder", type=str, default="./dataset/cd_dataset/") 143 | parser.add_argument("--question-file", type=str, default="./dataset/LEVIR-CD-256/questions_levir_test.json") 144 | parser.add_argument("--answers-file", type=str, default="./answer_levir.jsonl") 145 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 146 | parser.add_argument("--num-chunks", type=int, default=1) 147 | parser.add_argument("--chunk-idx", type=int, default=0) 148 | parser.add_argument("--temperature", type=float, default=0.2) 149 | parser.add_argument("--top_p", type=float, default=None) 150 | parser.add_argument("--num_beams", type=int, default=1) 151 | parser.add_argument("--batch_size",type=int, default=1) 152 | args = parser.parse_args() 153 | 154 | eval_model(args) 155 | -------------------------------------------------------------------------------- /cdchat/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from cdchat.constants import IMAGE_TOKEN_INDEX 8 | import numpy as np 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | 
if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448},return_tensors='pt')['pixel_values'][0] 35 | # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0] 36 | 37 | new_images.append(image) 38 | else: 39 | return image_processor(images, return_tensors='pt')['pixel_values'] 40 | if all(x.shape == new_images[0].shape for x in new_images): 41 | new_images = torch.stack(new_images, dim=0) 42 | return new_images 43 | 44 | def process_images_demo(images, image_processor): 45 | new_images = [] 46 | # image_aspect_ratio = 'pad' 47 | for image in images: 48 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 49 | image = image_processor.preprocess(image,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448},return_tensors='pt')['pixel_values'][0] 50 | # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0] 51 | 52 | new_images.append(image) 53 | 54 | if all(x.shape == new_images[0].shape for x in new_images): 55 | new_images = torch.stack(new_images, dim=0) 56 | return new_images 57 | 58 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 59 | prompt_chunks = 
[tokenizer(chunk).input_ids for chunk in prompt.split('')] 60 | 61 | def insert_separator(X, sep): 62 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 63 | 64 | input_ids = [] 65 | offset = 0 66 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 67 | offset = 1 68 | input_ids.append(prompt_chunks[0][0]) 69 | 70 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 71 | input_ids.extend(x[offset:]) 72 | 73 | if return_tensors is not None: 74 | if return_tensors == 'pt': 75 | return torch.tensor(input_ids, dtype=torch.long) 76 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 77 | return input_ids 78 | 79 | 80 | def get_model_name_from_path(model_path): 81 | model_path = model_path.strip("/") 82 | model_paths = model_path.split("/") 83 | if model_paths[-1].startswith('checkpoint-'): 84 | return model_paths[-2] + "_" + model_paths[-1] 85 | else: 86 | return model_paths[-1] 87 | 88 | 89 | 90 | 91 | class KeywordsStoppingCriteria(StoppingCriteria): 92 | def __init__(self, keywords, tokenizer, input_ids): 93 | self.keywords = keywords 94 | self.keyword_ids = [] 95 | self.max_keyword_len = 0 96 | for keyword in keywords: 97 | cur_keyword_ids = tokenizer(keyword).input_ids 98 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 99 | cur_keyword_ids = cur_keyword_ids[1:] 100 | if len(cur_keyword_ids) > self.max_keyword_len: 101 | self.max_keyword_len = len(cur_keyword_ids) 102 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 103 | self.tokenizer = tokenizer 104 | self.start_len = input_ids.shape[1] 105 | 106 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 107 | # assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 108 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 109 | self.keyword_ids = 
[keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 110 | for keyword_id in self.keyword_ids: 111 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 112 | return True 113 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 114 | flag=False 115 | for output in outputs: 116 | 117 | for keyword in self.keywords: 118 | if keyword in output: 119 | flag=True 120 | return flag 121 | return flag -------------------------------------------------------------------------------- /cdchat/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.cdchat_llama import CDChatLlamaForCausalLM, CDChatConfig 2 | -------------------------------------------------------------------------------- /cdchat/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from cdchat import CDChatLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 
27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /cdchat/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from cdchat.model import * 23 | from cdchat.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | from cdchat.model.multimodal_projector.builder import build_vision_projector 25 | 26 | 27 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", mm_projector_path=None): 28 | kwargs = {"device_map": device_map} 29 | 30 | if load_8bit: 31 | kwargs['load_in_8bit'] = True 32 | elif load_4bit: 33 | kwargs['load_in_4bit'] = True 34 | kwargs['quantization_config'] = BitsAndBytesConfig( 35 | load_in_4bit=True, 36 | bnb_4bit_compute_dtype=torch.float16, 37 | bnb_4bit_use_double_quant=True, 38 | bnb_4bit_quant_type='nf4' 39 | ) 40 | else: 41 | kwargs['torch_dtype'] = torch.float16 42 | 43 | if 'cdchat' in model_name.lower(): 44 | # Load LLaVA model 45 | if 'lora' in model_name.lower() and model_base is None: 46 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 47 | if 'lora' in model_name.lower() and model_base is not None: 48 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 49 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 50 | print('Loading CDChat from base model...') 51 | model = CDChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 52 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 53 | if model.lm_head.weight.shape[0] != token_num: 54 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 55 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 56 | 57 | print('Loading additional CDChat weights...') 58 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 59 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 60 | else: 61 | # this is probably from HF Hub 62 | from huggingface_hub import hf_hub_download 63 | def load_from_hf(repo_id, filename, subfolder=None): 64 | cache_file = hf_hub_download( 65 | repo_id=repo_id, 66 | filename=filename, 67 | subfolder=subfolder) 68 | return torch.load(cache_file, map_location='cpu') 69 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 70 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 71 | if any(k.startswith('model.model.') for k in non_lora_trainables): 72 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 73 | model.load_state_dict(non_lora_trainables, strict=False) 74 | 75 | 76 | from peft import PeftModel 77 | print('Loading LoRA weights...') 78 | model = 
PeftModel.from_pretrained(model, model_path) 79 | print('Merging LoRA weights...') 80 | model = model.merge_and_unload() 81 | print('Model is loaded...') 82 | elif model_base is not None: 83 | # this may be mm projector only 84 | print('Loading CDChat from base model...') 85 | 86 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 87 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 88 | model = CDChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 89 | 90 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 91 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 92 | model.load_state_dict(mm_projector_weights, strict=False) 93 | else: 94 | print("Loading CDChat......") 95 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 96 | model = CDChatLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 97 | else: 98 | # Load language model 99 | if model_base is not None: 100 | # PEFT model 101 | from peft import PeftModel 102 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 103 | model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") 104 | print(f"Loading LoRA weights from {model_path}") 105 | model = PeftModel.from_pretrained(model, model_path) 106 | print(f"Merging weights") 107 | model = model.merge_and_unload() 108 | print('Convert to FP16...') 109 | model.to(torch.float16) 110 | else: 111 | use_fast = False 112 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 113 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 114 | 115 | image_processor = None 116 | 117 | if 'cdchat' in model_name.lower(): 118 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 119 | mm_use_im_patch_token 
= getattr(model.config, "mm_use_im_patch_token", True) 120 | if mm_use_im_patch_token: 121 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 122 | if mm_use_im_start_end: 123 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 124 | model.resize_token_embeddings(len(tokenizer)) 125 | 126 | 127 | vision_tower = model.get_vision_tower() 128 | if not vision_tower.is_loaded: 129 | vision_tower.load_model() 130 | vision_tower.to(device=device, dtype=torch.float16) 131 | image_processor = vision_tower.image_processor 132 | 133 | #################################### 134 | model.model.mm_projector = build_vision_projector(model.model.config) 135 | if mm_projector_path is not None: 136 | mm_projector_weights = torch.load(mm_projector_path, map_location='cpu') 137 | 138 | def get_w(weights, keyword): 139 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 140 | 141 | msg = model.model.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False) 142 | print(msg) 143 | 144 | model.model.mm_projector.to(device=device, dtype=torch.float16) 145 | #################################################################### 146 | 147 | if hasattr(model.config, "max_sequence_length"): 148 | context_len = model.config.max_sequence_length 149 | else: 150 | context_len = 2048 151 | 152 | return tokenizer, model, image_processor, context_len 153 | -------------------------------------------------------------------------------- /cdchat/model/cdchat_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | 21 | from .multimodal_encoder.builder import build_vision_tower 22 | from .multimodal_projector.builder import build_vision_projector 23 | 24 | from cdchat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | 26 | 27 | class CDChatMetaModel: 28 | 29 | def __init__(self, config): 30 | super(CDChatMetaModel, self).__init__(config) 31 | 32 | if hasattr(config, "mm_vision_tower"): 33 | self.vision_tower = build_vision_tower(config, delay_load=True) 34 | self.mm_projector = build_vision_projector(config) 35 | 36 | def get_vision_tower(self): 37 | vision_tower = getattr(self, 'vision_tower', None) 38 | if type(vision_tower) is list: 39 | vision_tower = vision_tower[0] 40 | return vision_tower 41 | 42 | def initialize_vision_modules(self, model_args, fsdp=None): 43 | vision_tower = model_args.vision_tower 44 | mm_vision_select_layer = model_args.mm_vision_select_layer 45 | mm_vision_select_feature = model_args.mm_vision_select_feature 46 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter 47 | 48 | self.config.mm_vision_tower = vision_tower 49 | 50 | if self.get_vision_tower() is None: 51 | vision_tower = build_vision_tower(model_args) 52 | 53 | if fsdp is not None and len(fsdp) > 0: 54 | self.vision_tower = [vision_tower] 55 | else: 56 | self.vision_tower = vision_tower 57 | else: 58 | if fsdp is not None and len(fsdp) > 0: 59 
| vision_tower = self.vision_tower[0] 60 | else: 61 | vision_tower = self.vision_tower 62 | vision_tower.load_model() 63 | 64 | self.config.use_mm_proj = True 65 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') 66 | self.config.mm_hidden_size = vision_tower.embed_dim #vision_tower.hidden_size 67 | self.config.mm_vision_select_layer = mm_vision_select_layer 68 | self.config.mm_vision_select_feature = mm_vision_select_feature 69 | 70 | if getattr(self, 'mm_projector', None) is None: 71 | self.mm_projector = build_vision_projector(self.config) 72 | # print(mm_projector) 73 | 74 | 75 | if pretrain_mm_mlp_adapter is not None: 76 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 77 | 78 | def get_w(weights, keyword): 79 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 80 | 81 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) 82 | 83 | 84 | 85 | 86 | class CDChatMetaForCausalLM(ABC): 87 | 88 | @abstractmethod 89 | def get_model(self): 90 | pass 91 | 92 | def get_vision_tower(self): 93 | return self.get_model().get_vision_tower() 94 | 95 | def encode_images(self, images): 96 | #image_features = self.get_model().get_vision_tower()(images) 97 | #image_features = self.get_model().mm_projector(image_features) 98 | 99 | image_features_A = self.get_model().get_vision_tower()(images['pre']) 100 | image_features_B = self.get_model().get_vision_tower()(images['post']) 101 | 102 | image_features = torch.cat([image_features_A, image_features_B], dim=-1) 103 | proj_features = self.get_model().mm_projector(image_features) 104 | 105 | return image_features, proj_features 106 | 107 | def prepare_inputs_labels_for_multimodal( 108 | self, input_ids, attention_mask, past_key_values, labels, images 109 | ): 110 | vision_tower = self.get_vision_tower() 111 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 112 | if past_key_values is not 
None and vision_tower is not None and images is not None and input_ids.shape[1] == 1: 113 | attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) 114 | return input_ids, attention_mask, past_key_values, None, labels 115 | 116 | ''' 117 | if type(images) is list or images.ndim == 5: 118 | concat_images = torch.cat([image for image in images], dim=0) 119 | 120 | #image_features = self.encode_images(concat_images) 121 | diff_features, image_features = self.encode_images(images) 122 | 123 | split_sizes = [image.shape[0] for image in images] 124 | image_features = torch.split(image_features, split_sizes, dim=0) 125 | image_features = [x.flatten(0, 1) for x in image_features] 126 | else: 127 | image_features = self.encode_images(images) 128 | ''' 129 | diff_features, image_features = self.encode_images(images) 130 | 131 | new_input_embeds = [] 132 | new_labels = [] if labels is not None else None 133 | cur_image_idx = 0 134 | for batch_idx, cur_input_ids in enumerate(input_ids): 135 | if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: 136 | # multimodal LLM, but the current sample is not multimodal 137 | # FIXME: this is a hacky fix, for deepspeed zero3 to work 138 | half_len = cur_input_ids.shape[0] // 2 139 | cur_image_features = image_features[cur_image_idx] 140 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len]) 141 | cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:]) 142 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0) 143 | new_input_embeds.append(cur_input_embeds) 144 | if labels is not None: 145 | new_labels.append(labels[batch_idx]) 146 | cur_image_idx += 1 147 | continue 148 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] 149 | cur_new_input_embeds = [] 150 | if labels is not None: 151 | cur_labels = labels[batch_idx] 152 | 
cur_new_labels = [] 153 | assert cur_labels.shape == cur_input_ids.shape 154 | while image_token_indices.numel() > 0: 155 | cur_image_features = image_features[cur_image_idx] 156 | image_token_start = image_token_indices[0] 157 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 158 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach()) 159 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start])) 160 | cur_new_input_embeds.append(cur_image_features) 161 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2])) 162 | if labels is not None: 163 | cur_new_labels.append(cur_labels[:image_token_start]) 164 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) 165 | cur_new_labels.append(cur_labels[image_token_start:image_token_start+1]) 166 | cur_labels = cur_labels[image_token_start+2:] 167 | else: 168 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start])) 169 | cur_new_input_embeds.append(cur_image_features) 170 | if labels is not None: 171 | cur_new_labels.append(cur_labels[:image_token_start]) 172 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) 173 | cur_labels = cur_labels[image_token_start+1:] 174 | cur_image_idx += 1 175 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 176 | cur_input_ids = cur_input_ids[image_token_start+2:] 177 | else: 178 | cur_input_ids = cur_input_ids[image_token_start+1:] 179 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] 180 | if cur_input_ids.numel() > 0: 181 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 
'mm_use_im_start_end', False): 182 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach()) 183 | else: 184 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) 185 | if labels is not None: 186 | cur_new_labels.append(cur_labels) 187 | cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] 188 | cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) 189 | new_input_embeds.append(cur_new_input_embeds) 190 | if labels is not None: 191 | cur_new_labels = torch.cat(cur_new_labels, dim=0) 192 | new_labels.append(cur_new_labels) 193 | 194 | if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): 195 | max_len = max(x.shape[0] for x in new_input_embeds) 196 | 197 | new_input_embeds_align = [] 198 | for cur_new_embed in new_input_embeds: 199 | cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) 200 | new_input_embeds_align.append(cur_new_embed) 201 | new_input_embeds = torch.stack(new_input_embeds_align, dim=0) 202 | 203 | if labels is not None: 204 | new_labels_align = [] 205 | _new_labels = new_labels 206 | for cur_new_label in new_labels: 207 | cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) 208 | new_labels_align.append(cur_new_label) 209 | new_labels = torch.stack(new_labels_align, dim=0) 210 | 211 | if attention_mask is not None: 212 | new_attention_mask = [] 213 | for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): 214 | new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) 215 | new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), 
False, dtype=attention_mask.dtype, device=attention_mask.device) 216 | cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) 217 | new_attention_mask.append(cur_new_attention_mask) 218 | attention_mask = torch.stack(new_attention_mask, dim=0) 219 | assert attention_mask.shape == new_labels.shape 220 | else: 221 | new_input_embeds = torch.stack(new_input_embeds, dim=0) 222 | if labels is not None: 223 | new_labels = torch.stack(new_labels, dim=0) 224 | 225 | if attention_mask is not None: 226 | new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) 227 | attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) 228 | assert attention_mask.shape == new_input_embeds.shape[:2] 229 | 230 | return None, attention_mask, past_key_values, new_input_embeds, new_labels 231 | 232 | def initialize_vision_tokenizer(self, model_args, tokenizer): 233 | if model_args.mm_use_im_patch_token: 234 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 235 | self.resize_token_embeddings(len(tokenizer)) 236 | 237 | if model_args.mm_use_im_start_end: 238 | num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 239 | self.resize_token_embeddings(len(tokenizer)) 240 | 241 | if num_new_tokens > 0: 242 | input_embeddings = self.get_input_embeddings().weight.data 243 | output_embeddings = self.get_output_embeddings().weight.data 244 | 245 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 246 | dim=0, keepdim=True) 247 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 248 | dim=0, keepdim=True) 249 | 250 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 251 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 252 | 253 | if model_args.tune_mm_mlp_adapter: 254 | for p in 
self.get_input_embeddings().parameters(): 255 | p.requires_grad = True 256 | for p in self.get_output_embeddings().parameters(): 257 | p.requires_grad = False 258 | 259 | if model_args.pretrain_mm_mlp_adapter: 260 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') 261 | print(mm_projector_weights) 262 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] 263 | assert num_new_tokens == 2 264 | if input_embeddings.shape == embed_tokens_weight.shape: 265 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] 266 | elif embed_tokens_weight.shape[0] == num_new_tokens: 267 | input_embeddings[-num_new_tokens:] = embed_tokens_weight 268 | else: 269 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") 270 | elif model_args.mm_use_im_patch_token: 271 | if model_args.tune_mm_mlp_adapter: 272 | for p in self.get_input_embeddings().parameters(): 273 | p.requires_grad = False 274 | for p in self.get_output_embeddings().parameters(): 275 | p.requires_grad = False 276 | -------------------------------------------------------------------------------- /cdchat/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from cdchat.model import * 10 | from cdchat.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 
18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /cdchat/model/language_model/cdchat_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from ..cdchat_arch import CDChatMetaModel, CDChatMetaForCausalLM


class CDChatConfig(LlamaConfig):
    """LLaMA configuration registered under the ``"cdchat"`` model type."""

    model_type = "cdchat"


class CDChatLlamaModel(CDChatMetaModel, LlamaModel):
    """LLaMA decoder backbone combined with the CDChat multimodal mixin."""

    config_class = CDChatConfig

    def __init__(self, config: LlamaConfig):
        super(CDChatLlamaModel, self).__init__(config)


class CDChatLlamaForCausalLM(LlamaForCausalLM, CDChatMetaForCausalLM):
    """Causal-LM head on top of :class:`CDChatLlamaModel`."""

    config_class = CDChatConfig

    def __init__(self, config):
        # Start the MRO lookup *after* LlamaForCausalLM so its __init__ is
        # skipped; the multimodal backbone is installed as ``self.model`` here.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = CDChatLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (``self.model``)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder on text plus images; compute the LM loss when ``labels`` is given.

        ``prepare_inputs_labels_for_multimodal`` (defined in ``cdchat_arch``)
        merges ``images`` into the sequence and returns possibly rewritten
        ``input_ids``, ``attention_mask``, ``past_key_values``,
        ``inputs_embeds`` and ``labels``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build per-step model inputs for ``generate()``, forwarding ``images`` from kwargs."""
        if past_key_values:
            # With a cache, only the newest token needs to be fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

AutoConfig.register("cdchat", CDChatConfig)
AutoModelForCausalLM.register(CDChatConfig, CDChatLlamaForCausalLM)

# ------------------------------------------------------------------------------
# /cdchat/model/make_delta.py
# ------------------------------------------------------------------------------
"""
Usage:
python3 -m cdchat.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from cdchat.model.utils import auto_upgrade


def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    """Save ``target - base`` weights (and the target tokenizer) to ``delta_path``.

    Parameters absent from the base model must be the multimodal projector.
    Embedding / LM-head matrices may be larger in the target (e.g. after adding
    tokens); only their overlapping region is differenced. When
    ``hub_repo_id`` is set, the result is also pushed to that Hub repository.
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    # Build the base state dict once: state_dict() constructs a fresh mapping
    # of every parameter on each call.
    base_state = base.state_dict()
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base_state:
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        bparam = base_state[name]
        # state_dict tensors share storage with the model, so these in-place
        # subtractions turn ``target`` itself into the delta that is saved below.
        if param.data.shape == bparam.shape:
            param.data -= bparam
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {bparam.shape}'
            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    parser.add_argument("--hub-repo-id", type=str, default=None)
    args = parser.parse_args()

    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)

# ------------------------------------------------------------------------------
# /cdchat/model/multimodal_encoder/builder.py
# ------------------------------------------------------------------------------
import os
from .clip_encoder import CLIPVisionTower


def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the CLIP vision tower named by ``mm_vision_tower`` / ``vision_tower``.

    Accepts an existing local path or a name starting with ``openai`` or
    ``laion``; anything else (including no tower configured) raises ValueError.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    if vision_tower is None:
        # os.path.exists(None) would raise an unhelpful TypeError.
        raise ValueError(f'Unknown vision tower: {vision_tower}')
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith(("openai", "laion")):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')

# ------------------------------------------------------------------------------
# /cdchat/model/multimodal_encoder/clip_encoder.py
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder whose position embeddings are resized to 448px input."""

    def clip_interpolate_embeddings(self, image_size=448, patch_size= 14):
        """Resize the loaded CLIP position embeddings, in place, for a new input resolution.

        The patch-grid part of the embedding table is bicubically interpolated
        from its original square grid to ``image_size // patch_size`` per side;
        the class-token embedding is kept unchanged. The embedding module,
        ``image_size``, ``patch_size`` and ``position_ids`` of
        ``self.vision_tower.vision_model.embeddings`` are replaced. Nothing
        happens when the sequence length already matches.

        Args:
            image_size (int): Image size of the new model.
            patch_size (int): Patch size of the new model.

        Returns:
            None
        """
        # Shape of pos_embedding is (1, seq_length, hidden_dim)
        state_dict = self.vision_tower.vision_model.embeddings.position_embedding.state_dict()
        pos_embedding = state_dict['weight']
        pos_embedding = pos_embedding.unsqueeze(0)
        n, seq_length, hidden_dim = pos_embedding.shape
        if n != 1:
            raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

        new_seq_length = (image_size // patch_size) ** 2 + 1

        # Need to interpolate the weights for the position embedding.
        # We do this by reshaping the positions embeddings to a 2d grid, performing
        # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
        if new_seq_length != seq_length:
            # The class token embedding shouldn't be interpolated so we split it up.
            seq_length -= 1
            new_seq_length -= 1
            pos_embedding_token = pos_embedding[:, :1, :]
            pos_embedding_img = pos_embedding[:, 1:, :]

            # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
            pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
            seq_length_1d = int(math.sqrt(seq_length))
            torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

            # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
            pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
            new_seq_length_1d = image_size // patch_size

            # Perform interpolation.
            # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
            new_pos_embedding_img = nn.functional.interpolate(
                pos_embedding_img,
                size=new_seq_length_1d,
                mode='bicubic',
                align_corners=True,
            )

            # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
            new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

            # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
            new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
            new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)[0]
            state_dict['weight'] = new_pos_embedding
            # +1 restores the class-token slot removed above.
            self.vision_tower.vision_model.embeddings.position_embedding = nn.Embedding(new_seq_length+1, hidden_dim)
            self.vision_tower.vision_model.embeddings.position_embedding.load_state_dict(state_dict)
            self.vision_tower.vision_model.embeddings.image_size = image_size
            self.vision_tower.vision_model.embeddings.patch_size = patch_size
            self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_seq_length+1).expand((1, -1))

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            # NOTE(review): this "delayed" branch still loads the full model but
            # leaves is_loaded False, so `config` reports cfg_only (original
            # CLIP resolution). Confirm this is intended.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
            self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
            self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
            self.vision_tower.requires_grad_(False)
            self.clip_interpolate_embeddings(image_size=448, patch_size=14)


    def load_model(self):
        """Load the frozen CLIP model and processor, and resize embeddings to 448px."""
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)
        self.clip_interpolate_embeddings(image_size=448, patch_size=14)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick hidden states of ``select_layer``; 'patch' drops the leading CLS token."""
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of single images, into selected CLIP features."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        # NOTE(review): clip_interpolate_embeddings updates embeddings.image_size,
        # not config.image_size, so this may still report the original grid.
        return (self.config.image_size // self.config.patch_size) ** 2

# ------------------------------------------------------------------------------
# /cdchat/model/multimodal_projector/builder.py
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    """Projector that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    """Pre-LayerNorm residual MLP block: ``norm(x) + MLP(norm(x))``."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the vision-to-language projector selected by ``config.mm_projector_type``.

    Supported types are ``'linear'``, ``'mlp<N>x_gelu'`` and ``'identity'``.
    The first layer takes ``2 * mm_hidden_size`` inputs; presumably the
    features of the two images are concatenated upstream (TODO confirm in
    cdchat_arch).
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size*2, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size*2, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')

# ------------------------------------------------------------------------------
# /cdchat/model/utils.py
# ------------------------------------------------------------------------------
import sys

from transformers import AutoConfig


def auto_upgrade(config):
    """Interactively rewrite an old LLaVA-v0 checkpoint config to the new model type.

    ``config`` is a checkpoint path. The upgrade only triggers when the path
    contains 'llava' but the stored model_type does not. Declining aborts the
    process with exit status 1.
    """
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and 'llava' not in cfg.model_type:
        assert cfg.model_type == 'llama'
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            # sys.exit, not the site-injected exit(), which is missing under -S
            # or in frozen executables.
            sys.exit(1)

# ------------------------------------------------------------------------------
# /cdchat/train/cdchat_trainer.py
# ------------------------------------------------------------------------------
import os
import torch

from torch.utils.data import Sampler

from transformers import Trainer
from transformers.trainer import (
    has_length,
)
from typing import List, Optional


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first if DeepSpeed ZeRO-3 partitioned it."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, 'no ignore status')
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of parameters whose name contains any of ``keys_to_match``."""
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `num_chunks` chunks of roughly equal total length.

    Greedy: each index goes to the chunk with the smallest running length, and
    a chunk stops accepting indices once it holds len(indices) // num_chunks.
    If the indices do not divide evenly, fall back to simple striding.
    """

    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            # Full chunk: make sure min() never picks it again.
            chunks_lengths[shortest_chunk] = float("inf")

    return chunks


def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    """Length-grouped shuffle that keeps multimodal and language-only samples in separate megabatches.

    ``lengths`` encodes modality by sign: positive for multimodal samples,
    negative for language-only ones (the magnitude is the length). When every
    sample has the same modality, this falls back to plain length grouping.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # Single modality: nothing to separate, and the zip(*...) unpacking
        # below would fail on an empty group.
        return get_length_grouped_indices([abs(l) for l in lengths], batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    # Pass the caller's generator through so seeded sampling is reproducible.
    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=generator)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=generator)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    # The (possibly short) last megabatch of each modality is set aside and
    # the two leftovers are combined into one extra batch.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) >= megabatch_size:
        megabatches = [additional_batch[:megabatch_size]] + megabatches
        additional_batch = additional_batch[megabatch_size:]

    if len(additional_batch) > 0:
        megabatches.append(additional_batch)

    return [i for megabatch in megabatches for i in megabatch]


def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    """Shuffle indices, sort each megabatch by length (descending), then split it across ranks."""
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        if self.group_by_modality:
            indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        else:
            indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(indices)


class CDChatTrainer(Trainer):
    """HF Trainer with modality-length sampling and projector-only checkpointing."""

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
                self.args.train_batch_size,
                world_size=self.args.world_size,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def _save_checkpoint(self, model, trial, metrics=None):
        # When only the projector is tuned, checkpoints contain just its weights plus the config.
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ['mm_projector', 'vision_resampler']
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(['embed_tokens', 'embed_in'])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        else:
            super(CDChatTrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # Projector-only runs save through _save_checkpoint / safe_save instead.
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            pass
        else:
            super(CDChatTrainer, self)._save(output_dir, state_dict)

# ------------------------------------------------------------------------------
# /cdchat/train/llama_flash_attn_monkey_patch.py
# ------------------------------------------------------------------------------
from typing import Optional, Tuple
import warnings

import torch

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Drop-in ``LlamaAttention.forward`` that runs causal attention through FlashAttention.

    ``attention_mask`` is treated as a (batch, seq) key-padding mask (see
    ``_prepare_decoder_attention_mask`` below). Attention weights are never
    returned.
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )  # shape: (b, num_heads, s, head_dim)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if past_key_value is not None:
        # reuse k, v
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Transform the data into the format required by flash attention.
    # NOTE: torch.stack needs q/k/v of equal length, so a non-empty KV cache
    # (incremental decoding) is not supported by this patch.
    qkv = torch.stack([query_states, key_states, value_states], dim=2)
    qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        max_s = q_len
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = output.view(bsz, q_len, -1)
    else:
        qkv = qkv.reshape(bsz, q_len, -1)
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    """Monkey-patch transformers' LLaMA attention and mask preparation to use FlashAttention."""
    cuda_major, _ = torch.cuda.get_device_capability()
    if cuda_major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

# ------------------------------------------------------------------------------
# /cdchat/train/train.py
# ------------------------------------------------------------------------------
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import numpy as np
import torch

import transformers

from cdchat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from cdchat.train.cdchat_trainer import CDChatTrainer

from cdchat import conversation as conversation_lib
from cdchat.model import *
from cdchat.mm_utils import tokenizer_image_token

from PIL import Image


local_rank = None


def rank0_print(*args):
    """Print only on the process whose module-level ``local_rank`` is 0."""
    if local_rank == 0:
        print(*args)


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=False)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    mm_projector_type: Optional[str] = field(default='linear')
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_vision_select_feature: Optional[str] = field(default="patch")


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    group_by_modality_length: bool = field(default=False)


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first if DeepSpeed ZeRO-3 partitioned it."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA weights (plus biases per ``bias``) as CPU tensors.

    ``bias`` is one of "none" (LoRA weights only), "all" (also every bias) or
    "lora_only" (also the biases of modules that carry LoRA weights); any
    other value raises NotImplementedError.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the candidate biases that belong to a LoRA-adapted module.
        # (Iterate key/value pairs and test each bias's own name.)
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect CPU copies of non-LoRA parameters (by default only trainable ones)."""
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of parameters whose name contains any of ``keys_to_match``."""
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(model):
    """Return the leaf names of ``nn.Linear`` modules eligible for LoRA.

    Multimodal modules (projector, vision tower, resampler) and ``lm_head``
    are excluded.
    """
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dump to disk."""

    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        # Only save Adapter
        keys_to_match = ['mm_projector']
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in'])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        current_folder = output_dir.split('/')[-1]
        parent_folder = os.path.dirname(output_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith('checkpoint-'):
                # Intermediate checkpoints go to a sibling mm_projector/ folder.
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Newly added token rows of the input and output embeddings are initialized
    to the mean of the pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers):
    """Set ``target`` to IGNORE_INDEX over the header and over human turns, in place.

    ``tokenized_lens[0]`` is the header length; the rest pair with
    ``speakers``. For human turns the first two tokens of the turn are left
    unmasked.
    """
    # cur_idx = 0
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
| conversation = header 286 | for sentence in source: 287 | from_str = sentence["from"] 288 | if from_str.lower() == "human": 289 | from_str = conversation_lib.default_conversation.roles[0] 290 | elif from_str.lower() == "gpt": 291 | from_str = conversation_lib.default_conversation.roles[1] 292 | else: 293 | from_str = 'unknown' 294 | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + 295 | sentence["value"] + END_SIGNAL) 296 | if get_conversation: 297 | conversation += sentence["value"] 298 | conversation += BEGIN_SIGNAL 299 | return conversation 300 | 301 | 302 | def preprocess_multimodal( 303 | sources: Sequence[str], 304 | data_args: DataArguments 305 | ) -> Dict: 306 | is_multimodal = data_args.is_multimodal 307 | if not is_multimodal: 308 | return sources 309 | 310 | for source in sources: 311 | for sentence in source: 312 | 313 | if DEFAULT_IMAGE_TOKEN in sentence['value']: 314 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() 315 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] 316 | sentence['value'] = sentence['value'].strip() 317 | if "mmtag" in conversation_lib.default_conversation.version: 318 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') 319 | replace_token = DEFAULT_IMAGE_TOKEN 320 | if data_args.mm_use_im_start_end: 321 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 322 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) 323 | 324 | return sources 325 | 326 | 327 | def preprocess_llama_2( 328 | sources, 329 | tokenizer: transformers.PreTrainedTokenizer, 330 | has_image: bool = False 331 | ) -> Dict: 332 | conv = conversation_lib.default_conversation.copy() 333 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 334 | 335 | # Apply prompt templates 336 | conversations = [] 337 | for i, source in enumerate(sources): 338 | if roles[source[0]["from"]] != conv.roles[0]: 339 | # 
Skip the first one if it is not from human 340 | source = source[1:] 341 | 342 | conv.messages = [] 343 | for j, sentence in enumerate(source): 344 | role = roles[sentence["from"]] 345 | assert role == conv.roles[j % 2], f"{i}" 346 | conv.append_message(role, sentence["value"]) 347 | conversations.append(conv.get_prompt()) 348 | 349 | # Tokenize conversations 350 | 351 | if has_image: 352 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 353 | else: 354 | input_ids = tokenizer( 355 | conversations, 356 | return_tensors="pt", 357 | padding="longest", 358 | max_length=tokenizer.model_max_length, 359 | truncation=True, 360 | ).input_ids 361 | 362 | targets = input_ids.clone() 363 | 364 | assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 365 | 366 | # Mask targets 367 | sep = "[/INST] " 368 | for conversation, target in zip(conversations, targets): 369 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 370 | 371 | rounds = conversation.split(conv.sep2) 372 | cur_len = 1 373 | target[:cur_len] = IGNORE_INDEX 374 | for i, rou in enumerate(rounds): 375 | if rou == "": 376 | break 377 | 378 | parts = rou.split(sep) 379 | if len(parts) != 2: 380 | break 381 | parts[0] += sep 382 | 383 | if has_image: 384 | round_len = len(tokenizer_image_token(rou, tokenizer)) 385 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 386 | else: 387 | round_len = len(tokenizer(rou).input_ids) 388 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 389 | 390 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 391 | 392 | cur_len += round_len 393 | target[cur_len:] = IGNORE_INDEX 394 | 395 | if cur_len < tokenizer.model_max_length: 396 | if cur_len != total_len: 397 | target[:] = IGNORE_INDEX 398 | print( 399 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
400 | f" (ignored)" 401 | ) 402 | 403 | return dict( 404 | input_ids=input_ids, 405 | labels=targets, 406 | ) 407 | 408 | 409 | def preprocess_v1( 410 | sources, 411 | tokenizer: transformers.PreTrainedTokenizer, 412 | has_image: bool = False 413 | ) -> Dict: 414 | conv = conversation_lib.default_conversation.copy() 415 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 416 | 417 | # Apply prompt templates 418 | conversations = [] 419 | for i, source in enumerate(sources): 420 | if roles[source[0]["from"]] != conv.roles[0]: 421 | # Skip the first one if it is not from human 422 | source = source[1:] 423 | 424 | conv.messages = [] 425 | for j, sentence in enumerate(source): 426 | role = roles[sentence["from"]] 427 | assert role == conv.roles[j % 2], f"{i}" 428 | conv.append_message(role, sentence["value"]) 429 | conversations.append(conv.get_prompt()) 430 | 431 | # Tokenize conversations 432 | 433 | if has_image: 434 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 435 | else: 436 | input_ids = tokenizer( 437 | conversations, 438 | return_tensors="pt", 439 | padding="longest", 440 | max_length=tokenizer.model_max_length, 441 | truncation=True, 442 | ).input_ids 443 | 444 | targets = input_ids.clone() 445 | 446 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 447 | 448 | # Mask targets 449 | sep = conv.sep + conv.roles[1] + ": " 450 | for conversation, target in zip(conversations, targets): 451 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 452 | 453 | rounds = conversation.split(conv.sep2) 454 | cur_len = 1 455 | target[:cur_len] = IGNORE_INDEX 456 | for i, rou in enumerate(rounds): 457 | if rou == "": 458 | break 459 | 460 | parts = rou.split(sep) 461 | if len(parts) != 2: 462 | break 463 | parts[0] += sep 464 | 465 | if has_image: 466 | round_len = len(tokenizer_image_token(rou, tokenizer)) 467 | instruction_len = len(tokenizer_image_token(parts[0], 
tokenizer)) - 2 468 | else: 469 | round_len = len(tokenizer(rou).input_ids) 470 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 471 | 472 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 473 | 474 | cur_len += round_len 475 | target[cur_len:] = IGNORE_INDEX 476 | 477 | if cur_len < tokenizer.model_max_length: 478 | if cur_len != total_len: 479 | target[:] = IGNORE_INDEX 480 | print( 481 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 482 | f" (ignored)" 483 | ) 484 | 485 | return dict( 486 | input_ids=input_ids, 487 | labels=targets, 488 | ) 489 | 490 | ''' 491 | def preprocess_mpt( 492 | sources, 493 | tokenizer: transformers.PreTrainedTokenizer, 494 | ) -> Dict: 495 | conv = conversation_lib.default_conversation.copy() 496 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 497 | 498 | # Apply prompt templates 499 | conversations = [] 500 | for i, source in enumerate(sources): 501 | if roles[source[0]["from"]] != conv.roles[0]: 502 | # Skip the first one if it is not from human 503 | source = source[1:] 504 | 505 | conv.messages = [] 506 | for j, sentence in enumerate(source): 507 | role = roles[sentence["from"]] 508 | assert role == conv.roles[j % 2], f"{i}" 509 | conv.append_message(role, sentence["value"]) 510 | conversations.append(conv.get_prompt()) 511 | 512 | # Tokenize conversations 513 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 514 | targets = input_ids.clone() 515 | assert conv.sep_style == conversation_lib.SeparatorStyle.MPT 516 | 517 | # Mask targets 518 | sep = conv.sep + conv.roles[1] 519 | for conversation, target in zip(conversations, targets): 520 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 521 | 522 | rounds = conversation.split(conv.sep) 523 | re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt 524 | for conv_idx in range(3, len(rounds), 2): 525 | 
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt 526 | cur_len = 0 527 | target[:cur_len] = IGNORE_INDEX 528 | for i, rou in enumerate(re_rounds): 529 | if rou == "": 530 | break 531 | 532 | parts = rou.split(sep) 533 | if len(parts) != 2: 534 | break 535 | parts[0] += sep 536 | round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer)) 537 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) 538 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 539 | 540 | cur_len += round_len 541 | target[cur_len:] = IGNORE_INDEX 542 | 543 | if cur_len < tokenizer.model_max_length: 544 | if cur_len != total_len: 545 | target[:] = IGNORE_INDEX 546 | print( 547 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 548 | f" (ignored)" 549 | ) 550 | 551 | return dict( 552 | input_ids=input_ids, 553 | labels=targets, 554 | ) 555 | ''' 556 | 557 | def preprocess_plain( 558 | sources: Sequence[str], 559 | tokenizer: transformers.PreTrainedTokenizer, 560 | ) -> Dict: 561 | # add end signal and concatenate together 562 | conversations = [] 563 | for source in sources: 564 | assert len(source) == 2 565 | assert DEFAULT_IMAGE_TOKEN in source[0]['value'] 566 | source[0]['value'] = DEFAULT_IMAGE_TOKEN 567 | conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep 568 | conversations.append(conversation) 569 | # tokenize conversations 570 | input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] 571 | targets = copy.deepcopy(input_ids) 572 | for target, source in zip(targets, sources): 573 | tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) 574 | target[:tokenized_len] = IGNORE_INDEX 575 | 576 | return dict(input_ids=input_ids, labels=targets) 577 | 578 | 579 | def preprocess( 580 | sources: Sequence[str], 581 | tokenizer: transformers.PreTrainedTokenizer, 582 | has_image: 
bool = False 583 | ) -> Dict: 584 | """ 585 | Given a list of sources, each is a conversation list. This transform: 586 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 587 | 2. Concatenate conversations together; 588 | 3. Tokenize the concatenated conversation; 589 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 590 | """ 591 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: 592 | return preprocess_plain(sources, tokenizer) 593 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: 594 | return preprocess_llama_2(sources, tokenizer, has_image=has_image) 595 | if conversation_lib.default_conversation.version.startswith("v1"): 596 | return preprocess_v1(sources, tokenizer, has_image=has_image) 597 | ''' 598 | if conversation_lib.default_conversation.version == "mpt": 599 | return preprocess_mpt(sources, tokenizer) 600 | ''' 601 | # add end signal and concatenate together 602 | conversations = [] 603 | for source in sources: 604 | header = f"{conversation_lib.default_conversation.system}\n\n" 605 | conversation = _add_speaker_and_signal(header, source) 606 | conversations.append(conversation) 607 | # tokenize conversations 608 | def get_tokenize_len(prompts): 609 | return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] 610 | 611 | if has_image: 612 | input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] 613 | else: 614 | conversations_tokenized = _tokenize_fn(conversations, tokenizer) 615 | input_ids = conversations_tokenized["input_ids"] 616 | 617 | targets = copy.deepcopy(input_ids) 618 | for target, source in zip(targets, sources): 619 | if has_image: 620 | tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) 621 | else: 622 | tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] 623 | 
speakers = [sentence["from"] for sentence in source] 624 | _mask_targets(target, tokenized_lens, speakers) 625 | 626 | return dict(input_ids=input_ids, labels=targets) 627 | 628 | 629 | class LazySupervisedDataset(Dataset): 630 | """Dataset for supervised fine-tuning.""" 631 | 632 | def __init__(self, data_path: str, 633 | tokenizer: transformers.PreTrainedTokenizer, 634 | data_args: DataArguments): 635 | super(LazySupervisedDataset, self).__init__() 636 | list_data_dict = json.load(open(data_path, "r")) 637 | 638 | rank0_print("Formatting inputs...Skip in lazy mode") 639 | self.tokenizer = tokenizer 640 | self.list_data_dict = list_data_dict 641 | self.data_args = data_args 642 | 643 | def __len__(self): 644 | return len(self.list_data_dict) 645 | 646 | @property 647 | def lengths(self): 648 | length_list = [] 649 | for sample in self.list_data_dict: 650 | img_tokens = 128 if 'image' in sample else 0 651 | length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) 652 | return length_list 653 | 654 | @property 655 | def modality_lengths(self): 656 | length_list = [] 657 | for sample in self.list_data_dict: 658 | cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) 659 | cur_len = cur_len if 'image' in sample else -cur_len 660 | length_list.append(cur_len) 661 | return length_list 662 | 663 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 664 | 665 | #print('************************ ', self.data_args.image_aspect_ratio) 666 | 667 | sources = self.list_data_dict[i] 668 | if isinstance(i, int): 669 | sources = [sources] 670 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 671 | if 'image' in sources[0]: 672 | image_file = self.list_data_dict[i]['image'] 673 | image_folder = self.data_args.image_folder 674 | 675 | #print('************************* Image Path: ', image_file) 676 | #################################################### 677 | image_folder_A = 
os.path.join(image_folder, 'A') 678 | image_folder_B = os.path.join(image_folder, 'B') 679 | image_folder_label = os.path.join(image_folder, 'label') 680 | 681 | #################################################### 682 | processor = self.data_args.image_processor 683 | #image = Image.open((os.path.join(image_folder, image_file)).strip()).convert('RGB') 684 | image_A = Image.open((os.path.join(image_folder_A, image_file)).strip()).convert('RGB') #.resize((336,336), resample=Image.Resampling.BILINEAR) 685 | image_B = Image.open((os.path.join(image_folder_B, image_file)).strip()).convert('RGB') #.resize((336,336), resample=Image.Resampling.BILINEAR) 686 | image_label = Image.open((os.path.join(image_folder_label, image_file)).strip()).convert('L') 687 | image_label = (torch.Tensor(1)*np.array(image_label)/255.).unsqueeze(0) 688 | 689 | if self.data_args.image_aspect_ratio == 'pad': 690 | def expand2square(pil_img, background_color): 691 | width, height = pil_img.size 692 | if width == height: 693 | return pil_img 694 | elif width > height: 695 | result = Image.new(pil_img.mode, (width, width), background_color) 696 | result.paste(pil_img, (0, (width - height) // 2)) 697 | return result 698 | else: 699 | result = Image.new(pil_img.mode, (height, height), background_color) 700 | result.paste(pil_img, ((height - width) // 2, 0)) 701 | return result 702 | #image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) 703 | # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 704 | #image = processor.preprocess(image,do_resize=True,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 705 | image_A = expand2square(image_A, tuple(int(x*255) for x in processor.image_mean)) 706 | image_B = expand2square(image_B, tuple(int(x*255) for x in processor.image_mean)) 707 | image_A = processor.preprocess(image_A, do_resize=True, crop_size ={'height': 448, 'width': 448},size = 
{'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 708 | image_B = processor.preprocess(image_B, do_resize=True, crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 709 | 710 | else: 711 | # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 712 | #image = processor.preprocess(image,do_resize=True,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 713 | image_A = processor.preprocess(image_A,do_resize=True,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 714 | image_B = processor.preprocess(image_B,do_resize=True,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0] 715 | 716 | sources = preprocess_multimodal( 717 | copy.deepcopy([e["conversations"] for e in sources]), 718 | self.data_args) 719 | else: 720 | sources = copy.deepcopy([e["conversations"] for e in sources]) 721 | data_dict = preprocess( 722 | sources, 723 | self.tokenizer, 724 | has_image=('image' in self.list_data_dict[i])) 725 | if isinstance(i, int): 726 | data_dict = dict(input_ids=data_dict["input_ids"][0], 727 | labels=data_dict["labels"][0]) 728 | 729 | image_A = image_A[[2,1,0], :, :] 730 | image_B = image_B[[2,1,0], :, :] 731 | #print(image_A.shape) 732 | # image exist in the data 733 | if 'image' in self.list_data_dict[i]: 734 | data_dict['image'] = {'pre':image_A, 'post':image_B, 'targets': image_label} #image 735 | elif self.data_args.is_multimodal: 736 | # image does not exist in the data, but the model is multimodal 737 | crop_size = self.data_args.image_processor.crop_size 738 | data_dict['image'] = {'pre':torch.zeros(3, crop_size['height'], crop_size['width']), 739 | 'post':torch.zeros(3, crop_size['height'], crop_size['width']), 740 | 'targets':torch.zeros(3, crop_size['height'], crop_size['width'])} 
#torch.zeros(3, crop_size['height'], crop_size['width']) 741 | return data_dict 742 | 743 | 744 | @dataclass 745 | class DataCollatorForSupervisedDataset(object): 746 | """Collate examples for supervised fine-tuning.""" 747 | 748 | tokenizer: transformers.PreTrainedTokenizer 749 | 750 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 751 | input_ids, labels = tuple([instance[key] for instance in instances] 752 | for key in ("input_ids", "labels")) 753 | input_ids = torch.nn.utils.rnn.pad_sequence( 754 | input_ids, 755 | batch_first=True, 756 | padding_value=self.tokenizer.pad_token_id) 757 | labels = torch.nn.utils.rnn.pad_sequence(labels, 758 | batch_first=True, 759 | padding_value=IGNORE_INDEX) 760 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 761 | labels = labels[:, :self.tokenizer.model_max_length] 762 | batch = dict( 763 | input_ids=input_ids, 764 | labels=labels, 765 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 766 | ) 767 | 768 | if 'image' in instances[0]: 769 | ''' 770 | images = [instance['image'] for instance in instances] 771 | if all(x is not None and x.shape == images[0].shape for x in images): 772 | batch['images'] = torch.stack(images) 773 | else: 774 | batch['images'] = images 775 | ''' 776 | images_A = [instance['image']['pre'] for instance in instances] 777 | images_B = [instance['image']['post'] for instance in instances] 778 | images_label = [instance['image']['targets'] for instance in instances] 779 | if all(x is not None and x.shape == images_A[0].shape for x in images_A): 780 | batch['images'] = {'pre':torch.stack(images_A), 'post':torch.stack(images_B), 'targets':torch.stack(images_label)} 781 | else: 782 | batch['images'] = {'pre':images_A, 'post':images_B, 'targets':images_label} 783 | 784 | return batch 785 | 786 | 787 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, 788 | data_args) -> Dict: 789 | """Make dataset and collator for supervised 
fine-tuning.""" 790 | train_dataset = LazySupervisedDataset(tokenizer=tokenizer, 791 | data_path=data_args.data_path, 792 | data_args=data_args) 793 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 794 | return dict(train_dataset=train_dataset, 795 | eval_dataset=None, 796 | data_collator=data_collator) 797 | 798 | 799 | def train(): 800 | global local_rank 801 | 802 | parser = transformers.HfArgumentParser( 803 | (ModelArguments, DataArguments, TrainingArguments)) 804 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 805 | local_rank = training_args.local_rank 806 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) 807 | 808 | bnb_model_from_pretrained_args = {} 809 | if training_args.bits in [4, 8]: 810 | from transformers import BitsAndBytesConfig 811 | bnb_model_from_pretrained_args.update(dict( 812 | device_map={"": training_args.device}, 813 | load_in_4bit=training_args.bits == 4, 814 | load_in_8bit=training_args.bits == 8, 815 | quantization_config=BitsAndBytesConfig( 816 | load_in_4bit=training_args.bits == 4, 817 | load_in_8bit=training_args.bits == 8, 818 | llm_int8_threshold=6.0, 819 | llm_int8_has_fp16_weight=False, 820 | bnb_4bit_compute_dtype=compute_dtype, 821 | bnb_4bit_use_double_quant=training_args.double_quant, 822 | bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} 823 | ) 824 | )) 825 | 826 | if model_args.vision_tower is not None: 827 | ''' 828 | if 'mpt' in model_args.model_name_or_path: 829 | config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 830 | config.attn_config['attn_impl'] = training_args.mpt_attn_impl 831 | model = CDChatMPTForCausalLM.from_pretrained( 832 | model_args.model_name_or_path, 833 | config=config, 834 | cache_dir=training_args.cache_dir, 835 | ignore_mismatched_sizes=True, 836 | **bnb_model_from_pretrained_args 837 | ) 838 | else: 839 | ''' 
840 | model = CDChatLlamaForCausalLM.from_pretrained( 841 | model_args.model_name_or_path, 842 | cache_dir=training_args.cache_dir, 843 | ignore_mismatched_sizes=True, 844 | **bnb_model_from_pretrained_args 845 | ) 846 | else: 847 | model = transformers.LlamaForCausalLM.from_pretrained( 848 | model_args.model_name_or_path, 849 | cache_dir=training_args.cache_dir, 850 | **bnb_model_from_pretrained_args 851 | ) 852 | model.config.use_cache = False 853 | 854 | if model_args.freeze_backbone: 855 | model.model.requires_grad_(False) 856 | 857 | if training_args.bits in [4, 8]: 858 | from peft import prepare_model_for_kbit_training 859 | model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) 860 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) 861 | 862 | if training_args.gradient_checkpointing: 863 | if hasattr(model, "enable_input_require_grads"): 864 | model.enable_input_require_grads() 865 | else: 866 | def make_inputs_require_grad(module, input, output): 867 | output.requires_grad_(True) 868 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 869 | 870 | if training_args.lora_enable: 871 | from peft import LoraConfig, get_peft_model 872 | lora_config = LoraConfig( 873 | r=training_args.lora_r, 874 | lora_alpha=training_args.lora_alpha, 875 | target_modules=find_all_linear_names(model), 876 | lora_dropout=training_args.lora_dropout, 877 | bias=training_args.lora_bias, 878 | task_type="CAUSAL_LM", 879 | ) 880 | if training_args.bits == 16: 881 | if training_args.bf16: 882 | model.to(torch.bfloat16) 883 | if training_args.fp16: 884 | model.to(torch.float16) 885 | rank0_print("Adding LoRA adapters...") 886 | model = get_peft_model(model, lora_config) 887 | 888 | ''' 889 | if 'mpt' in model_args.model_name_or_path: 890 | tokenizer = transformers.AutoTokenizer.from_pretrained( 891 | model_args.model_name_or_path, 
892 | cache_dir=training_args.cache_dir, 893 | model_max_length=training_args.model_max_length, 894 | padding_side="right" 895 | ) 896 | else: 897 | ''' 898 | tokenizer = transformers.AutoTokenizer.from_pretrained( 899 | model_args.model_name_or_path, 900 | cache_dir=training_args.cache_dir, 901 | model_max_length=training_args.model_max_length, 902 | padding_side="right", 903 | use_fast=False, 904 | ) 905 | 906 | if model_args.version == "v0": 907 | if tokenizer.pad_token is None: 908 | smart_tokenizer_and_embedding_resize( 909 | special_tokens_dict=dict(pad_token="[PAD]"), 910 | tokenizer=tokenizer, 911 | model=model, 912 | ) 913 | elif model_args.version == "v0.5": 914 | tokenizer.pad_token = tokenizer.unk_token 915 | else: 916 | tokenizer.pad_token = tokenizer.unk_token 917 | if model_args.version in conversation_lib.conv_templates: 918 | conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] 919 | else: 920 | conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] 921 | 922 | if model_args.vision_tower is not None: 923 | model.get_model().initialize_vision_modules( 924 | model_args=model_args, 925 | fsdp=training_args.fsdp 926 | ) 927 | 928 | vision_tower = model.get_vision_tower() 929 | vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) 930 | 931 | data_args.image_processor = vision_tower.image_processor 932 | data_args.is_multimodal = True 933 | 934 | model.config.image_aspect_ratio = data_args.image_aspect_ratio 935 | model.config.image_grid_pinpoints = data_args.image_grid_pinpoints 936 | 937 | model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter 938 | if model_args.tune_mm_mlp_adapter: 939 | model.requires_grad_(False) 940 | for p in model.get_model().mm_projector.parameters(): 941 | p.requires_grad = True 942 | 943 | model.config.freeze_mm_mlp_adapter = 
training_args.freeze_mm_mlp_adapter 944 | if training_args.freeze_mm_mlp_adapter: 945 | for p in model.get_model().mm_projector.parameters(): 946 | p.requires_grad = False 947 | 948 | if training_args.bits in [4, 8]: 949 | model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) 950 | 951 | model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end 952 | training_args.use_im_start_end = model_args.mm_use_im_start_end 953 | model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token 954 | model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) 955 | 956 | if training_args.bits in [4, 8]: 957 | from peft.tuners.lora import LoraLayer 958 | for name, module in model.named_modules(): 959 | if isinstance(module, LoraLayer): 960 | if training_args.bf16: 961 | module = module.to(torch.bfloat16) 962 | if 'norm' in name: 963 | module = module.to(torch.float32) 964 | if 'lm_head' in name or 'embed_tokens' in name: 965 | if hasattr(module, 'weight'): 966 | if training_args.bf16 and module.weight.dtype == torch.float32: 967 | module = module.to(torch.bfloat16) 968 | 969 | data_module = make_supervised_data_module(tokenizer=tokenizer, 970 | data_args=data_args) 971 | trainer = CDChatTrainer(model=model, 972 | tokenizer=tokenizer, 973 | args=training_args, 974 | **data_module) 975 | 976 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 977 | trainer.train(resume_from_checkpoint=True) 978 | else: 979 | trainer.train() 980 | trainer.save_state() 981 | 982 | model.config.use_cache = True 983 | 984 | if training_args.lora_enable: 985 | state_dict = get_peft_state_maybe_zero_3( 986 | model.named_parameters(), training_args.lora_bias 987 | ) 988 | non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( 989 | model.named_parameters() 990 | ) 991 | if training_args.local_rank == 0 or training_args.local_rank == -1: 992 | model.config.save_pretrained(training_args.output_dir) 
993 | model.save_pretrained(training_args.output_dir, state_dict=state_dict) 994 | torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) 995 | else: 996 | safe_save_model_for_hf_trainer(trainer=trainer, 997 | output_dir=training_args.output_dir) 998 | 999 | 1000 | if __name__ == "__main__": 1001 | train() 1002 | -------------------------------------------------------------------------------- /cdchat/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from cdchat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from cdchat.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /cdchat/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from cdchat.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 
63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /images/cdchat_annotation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/cdchat_annotation.png -------------------------------------------------------------------------------- /images/cdchat_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/cdchat_arch.png -------------------------------------------------------------------------------- /images/example_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/example_01.png -------------------------------------------------------------------------------- /images/example_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/example_02.png 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "cdchat"
version = "1.0"
description = "VLM for Remote Sensing Change Description"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
]
# einops and sentencepiece were listed twice (unpinned and pinned); only the
# pinned entries are kept.
dependencies = [
    "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
    "requests", "tokenizers>=0.12.1",
    "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
    "shortuuid", "httpx==0.24.0",
    "deepspeed==0.9.5",
    "peft==0.4.0",
    "transformers==4.31.0",
    "accelerate==0.21.0",
    "bitsandbytes==0.41.0",
    "scikit-learn==1.2.2",
    "sentencepiece==0.1.99",
    "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
    "gradio_client==0.2.9"
]

[project.urls]
"Homepage" = "https://github.com/techmn/cdchat"

[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]

[tool.wheel]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
--------------------------------------------------------------------------------
/scripts/extract_mm_projector.py:
--------------------------------------------------------------------------------
"""
This is just a utility that I use to extract the projector for quantized models.
It is NOT necessary at all to train, or run inference/serve demos.
Use this script ONLY if you fully understand its implications.
"""


import os
import argparse
import torch
import json
from collections import defaultdict


def parse_args():
    """Parse ``--model-path`` (checkpoint folder) and ``--output`` (target file)."""
    parser = argparse.ArgumentParser(description='Extract MMProjector weights')
    parser.add_argument('--model-path', type=str, required=True, help='model folder')
    parser.add_argument('--output', type=str, required=True, help='output file')
    args = parser.parse_args()
    return args


def _is_projector_key(key, keys_to_match=('mm_projector',)):
    """Return True if the state-dict ``key`` contains any of ``keys_to_match``."""
    return any(key_match in key for key_match in keys_to_match)


if __name__ == '__main__':
    args = parse_args()

    loaded_weights = {}
    index_path = os.path.join(args.model_path, 'pytorch_model.bin.index.json')
    try:
        # `with` closes the index file (previously json.load(open(...)) leaked it).
        with open(index_path) as f:
            model_indices = json.load(f)
    except FileNotFoundError:
        model_indices = None

    if model_indices is None:
        # Smaller models or model checkpoints saved by DeepSpeed: a single
        # pytorch_model.bin. Load it once (it used to be loaded twice) and keep
        # only the projector tensors.
        ckpt = torch.load(os.path.join(args.model_path, 'pytorch_model.bin'), map_location='cpu')
        loaded_weights = {k: v for k, v in ckpt.items() if _is_projector_key(k)}
    else:
        # Sharded checkpoint: group the projector keys by the shard that holds
        # them so each needed shard is loaded exactly once.
        ckpt_to_key = defaultdict(list)
        for k, v in model_indices['weight_map'].items():
            if _is_projector_key(k):
                ckpt_to_key[v].append(k)
        for ckpt_name, weight_keys in ckpt_to_key.items():
            ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
            for k in weight_keys:
                loaded_weights[k] = ckpt[k]

    torch.save(loaded_weights, args.output)
--------------------------------------------------------------------------------
/scripts/finetune_lora.sh:
--------------------------------------------------------------------------------
#!/bin/bash

################## VICUNA ##################
PROMPT_VERSION=v1
MODEL_VERSION="vicuna-v1.5-7b"
################## VICUNA ##################

# NOTE: with --save_strategy "epoch" a checkpoint is written once per epoch;
# --save_steps only takes effect with --save_strategy "steps".
deepspeed --master_port=$((RANDOM + 10000)) --include localhost:0,1,2,3 ./cdchat/train/train_mem.py \
    --deepspeed ./scripts/zero2.json \
    --lora_enable True \
    --model_name_or_path ./llava-v1.5-7b \
    --version $PROMPT_VERSION \
    --data_path dataset/cdchat_instruct_file.json \
    --image_folder dataset/ \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --pretrain_mm_mlp_adapter cdchat/checkpoints/pretrain_mm_projector/mm_projector.bin \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --bf16 True \
    --output_dir cdchat/checkpoints/cdchat_log_lora \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "epoch" \
    --save_steps 7000 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --dataloader_num_workers 16 \
    --report_to wandb

--------------------------------------------------------------------------------
/scripts/merge_lora_weights.py:
--------------------------------------------------------------------------------
import argparse
from cdchat.model.builder import load_pretrained_model
from cdchat.mm_utils import get_model_name_from_path


def merge_lora(args):
    """Load ``args.model_path`` on top of ``args.model_base`` (CPU) and save it.

    The model is obtained through ``load_pretrained_model``; the model and its
    tokenizer are then written to ``args.save_model_path``.
    """
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')

    model.save_pretrained(args.save_model_path)
    tokenizer.save_pretrained(args.save_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, required=True)
    parser.add_argument("--save-model-path", type=str, required=True)

    args = parser.parse_args()

    merge_lora(args)
--------------------------------------------------------------------------------
/scripts/pretrain.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!

# Uncomment and set the following variables correspondingly to run this script:

# MODEL_VERSION=vicuna-v1-3-7b
# MODEL_VERSION=llama-2-7b-chat
#MODEL_VERSION="vicuna-v1.5-7b"

########### DO NOT CHANGE ###########
########### USE THIS FOR BOTH ###########
PROMPT_VERSION=plain
########### DO NOT CHANGE ###########

# NOTE: with --save_strategy "epoch" a checkpoint is written once per epoch;
# --save_steps only takes effect with --save_strategy "steps".
deepspeed --master_port=$((RANDOM + 10000)) --include localhost:0,1,2,3 ./cdchat/train/train_mem.py \
    --deepspeed ./scripts/zero2.json \
    --model_name_or_path ./llava-v1.5-7b \
    --version $PROMPT_VERSION \
    --data_path dataset/cdchat_instruct_file.json \
    --image_folder dataset/ \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --tune_mm_mlp_adapter True \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --bf16 True \
    --output_dir cdchat/checkpoints/pretrain_mm_projector \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "epoch" \
    --save_steps 24000 \
    --save_total_limit 1 \
    --learning_rate 2e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto"
    }
}
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}
-------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } --------------------------------------------------------------------------------