├── README.md
├── cdchat
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── eval
│   │   └── batch_cdchat_vqa.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── builder.py
│   │   ├── cdchat_arch.py
│   │   ├── consolidate.py
│   │   ├── language_model
│   │   │   └── cdchat_llama.py
│   │   ├── make_delta.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   └── clip_encoder.py
│   │   ├── multimodal_projector
│   │   │   └── builder.py
│   │   └── utils.py
│   ├── train
│   │   ├── cdchat_trainer.py
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── train.py
│   │   └── train_mem.py
│   └── utils.py
├── data_files
│   ├── eval_questions_levir_test.json
│   ├── eval_questions_sysu_test.json
│   ├── levir_captions.json
│   └── sysu_test_captions.json
├── images
│   ├── cdchat_annotation.png
│   ├── cdchat_arch.png
│   ├── example_01.png
│   └── example_02.png
├── pyproject.toml
└── scripts
    ├── extract_mm_projector.py
    ├── finetune_lora.sh
    ├── merge_lora_weights.py
    ├── pretrain.sh
    ├── zero2.json
    ├── zero3.json
    └── zero3_offload.json
/README.md:
--------------------------------------------------------------------------------
1 | # CDChat: A Large Multimodal Model for Remote Sensing Change Description
2 |
3 | [[Arxiv]](https://arxiv.org/abs/2409.16261)
4 |
5 |
6 | - **Test captions file of SYSU-CD is uploaded.**
7 | - **Evaluation test questions file of SYSU-CD is uploaded.**
8 | - **Train/Val/Test captions file of LEVIR-CD is uploaded.**
9 | - **Evaluation test questions file of LEVIR-CD is uploaded.**
10 |
11 | Instruction tuning files will be available soon!
12 |
13 |
14 | ### Overview
15 | CDChat is a conversational assistant for the remote sensing (RS) change description task. We annotate the SYSU-CD dataset to obtain change text and image pairs for instruction tuning of CDChat. We create change text and image pairs from two large-scale change detection datasets: [SYSU-CD](https://github.com/liumency/SYSU-CD) and [LEVIR-CD](https://github.com/Chen-Yang-Liu/LEVIR-CC-Dataset).
16 |
17 |
18 |
19 | **Annotation Tool**
20 |
21 | A custom annotation tool was used to annotate the SYSU-CD dataset, as shown below:
22 |
23 |
24 |
25 |
26 | ### Installation
27 | - **Clone this repository and navigate to cdchat folder**
28 | ```
29 | git clone https://github.com/techmn/cdchat.git
30 | cd cdchat
31 | ```
32 | - **Install Package**
33 | ```
34 | conda create -n cdchat python=3.10 -y
35 | conda activate cdchat
36 | pip install --upgrade pip # enable PEP 660 support
37 | pip install -e .
38 | ```
39 | - **Install additional packages for training cases**
40 | ```
41 | pip install ninja
42 | pip install flash-attn --no-build-isolation
43 | ```
44 |
45 | ### Train
46 | **Feature Alignment**
47 |
48 | We use the pretrained projector from LLaVA-v1.5, similar to [GeoChat](https://github.com/mbzuai-oryx/GeoChat).
49 |
50 | `--mm_projector_type mlp2x_gelu:` the two-layer MLP vision-language connector
51 |
52 | `--vision_tower openai/clip-vit-large-patch14-336:` CLIP ViT-L/14 336px
53 |
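For orientation, here is a minimal sketch of how these two flags would sit in a `scripts/pretrain.sh`-style DeepSpeed launch. The paths and the surrounding flags are illustrative placeholders, not the exact shipped script:

```
deepspeed cdchat/train/train_mem.py \
    --deepspeed ./scripts/zero2.json \
    --model_name_or_path /path/to/base/llm \
    --data_path /path/to/alignment_data.json \
    --image_folder /path/to/images \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --tune_mm_mlp_adapter True \
    --output_dir ./checkpoints/pretrain_mm_projector
```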
54 |
55 | **Instruction Tuning**
56 |
57 | - Download the cdchat_instruct_file and the image pairs. Place the image pair folders and the cdchat_instruct_file in the same folder.
58 | - Update `--data_path` and `--image_folder` in the file `finetune_lora.sh` (see the sketch below).
59 | - Start training!
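A minimal sketch of the two data flags to point at your own data inside `scripts/finetune_lora.sh` (the paths below are placeholders; the script's remaining flags stay as shipped):

```
--data_path /path/to/cdchat_instruct_file.json \
--image_folder /path/to/image_pairs/ \
```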
60 |
61 | **Note:** Our codebase is inspired by [GeoChat](https://github.com/mbzuai-oryx/GeoChat). Please refer to it for detailed instructions on installation and training.
62 |
63 | **Model Weights**
64 | Will be available soon!
65 |
66 | ### Evaluation
67 | We evaluate CDChat on the test sets of the LEVIR-CD and SYSU-CD datasets. Below is the command to evaluate CDChat on a dataset:
68 |
69 | ```
70 | python cdchat/eval/batch_cdchat_vqa.py \
71 | --model-path /path/to/model \
72 | --question-file path/to/json/file \
73 | --answers-file path/to/output/jsonl/file \
74 | --image-folder path/to/image/folder/
75 | ```
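The script expects `--image-folder` to contain `A`, `B` and `label` subfolders holding the before image, the after image and the binary change mask under the same file name. The question file is a JSON list of `img_id`/`question` records, and each answer is written as one JSON object per line (the file name and texts below are illustrative):

```
# question file: a JSON list
[{"img_id": "test_000001.png", "question": "Describe the changes between the two images."}]

# answer file: one JSON object per line
{"question_id": "Describe the changes between the two images.", "image_id": "test_000001.png", "answer": "..."}
```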
76 |
77 | Here are a few examples of responses from LMMs on the change description task:
78 |
79 |
80 |
81 |
82 | ### Citation
83 |
84 | ```
85 | @misc{cdchat_2024,
86 | title={CDChat: A Large Multimodal Model for Remote Sensing Change Description},
87 | author={Mubashir Noman and Noor Ahsan and Muzammal Naseer and Hisham Cholakkal and Rao Muhammad Anwer and Salman Khan and Fahad Shahbaz Khan},
88 | year={2024},
89 | eprint={2409.16261},
90 | archivePrefix={arXiv},
91 | primaryClass={cs.CV},
92 | url={https://arxiv.org/abs/2409.16261},
93 | }
94 | ```
95 | ### Acknowledgements
96 | Our codebase is inspired by the [GeoChat](https://github.com/mbzuai-oryx/GeoChat) repository. We thank them for releasing their valuable codebase.
97 |
--------------------------------------------------------------------------------
/cdchat/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import CDChatLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/cdchat/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/cdchat/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Tuple
4 | from PIL import Image
5 | from threading import Thread
6 |
7 | from cdchat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
8 | from cdchat.utils import disable_torch_init
9 | from cdchat.mm_utils import process_images_demo, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
10 | from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer,TextStreamer
11 | import torch
12 | import dataclasses
13 | from enum import auto, Enum
14 | from typing import List, Tuple, Any
15 |
16 |
17 | class SeparatorStyle(Enum):
18 | """Different separator style."""
19 | SINGLE = auto()
20 | TWO = auto()
21 | MPT = auto()
22 | PLAIN = auto()
23 | LLAMA_2 = auto()
24 |
25 |
26 | @dataclasses.dataclass
27 | class Conversation:
28 | """A class that keeps all conversation history."""
29 | system: str
30 | roles: List[str]
31 | messages: List[List[str]]
32 | offset: int
33 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
34 | sep: str = "###"
35 | sep2: str = None
36 | version: str = "Unknown"
37 |
38 | skip_next: bool = False
39 |
40 | def get_prompt(self):
41 | messages = self.messages
42 | if len(messages) > 0 and type(messages[0][1]) is tuple:
43 | messages = self.messages.copy()
44 | init_role, init_msg = messages[0].copy()
45 | init_msg = init_msg[0].replace("<image>", "").strip()
46 | if 'mmtag' in self.version:
47 | messages[0] = (init_role, init_msg)
48 | messages.insert(0, (self.roles[0], "<Image><image></Image>"))
49 | messages.insert(1, (self.roles[1], "Received."))
50 | else:
51 | messages[0] = (init_role, "<image>\n" + init_msg)
52 |
53 | if self.sep_style == SeparatorStyle.SINGLE:
54 | ret = self.system + self.sep
55 | for role, message in messages:
56 | if message:
57 | if type(message) is tuple:
58 | message, _, _ = message
59 | ret += role + ": " + message + self.sep
60 | else:
61 | ret += role + ":"
62 | elif self.sep_style == SeparatorStyle.TWO:
63 | seps = [self.sep, self.sep2]
64 | ret = self.system + seps[0]
65 | for i, (role, message) in enumerate(messages):
66 | if message:
67 | if type(message) is tuple:
68 | message, _, _ = message
69 | ret += role + ": " + message + seps[i % 2]
70 | else:
71 | ret += role + ":"
72 | elif self.sep_style == SeparatorStyle.MPT:
73 | ret = self.system + self.sep
74 | for role, message in messages:
75 | if message:
76 | if type(message) is tuple:
77 | message, _, _ = message
78 | ret += role + message + self.sep
79 | else:
80 | ret += role
81 | elif self.sep_style == SeparatorStyle.LLAMA_2:
82 | wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
83 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
84 | ret = ""
85 |
86 | for i, (role, message) in enumerate(messages):
87 | if i == 0:
88 | assert message, "first message should not be none"
89 | assert role == self.roles[0], "first message should come from user"
90 | if message:
91 | if type(message) is tuple:
92 | message, _, _ = message
93 | if i == 0: message = wrap_sys(self.system) + message
94 | if i % 2 == 0:
95 | message = wrap_inst(message)
96 | ret += self.sep + message
97 | else:
98 | ret += " " + message + " " + self.sep2
99 | else:
100 | ret += ""
101 | ret = ret.lstrip(self.sep)
102 | elif self.sep_style == SeparatorStyle.PLAIN:
103 | seps = [self.sep, self.sep2]
104 | ret = self.system
105 | for i, (role, message) in enumerate(messages):
106 | if message:
107 | if type(message) is tuple:
108 | message, _, _ = message
109 | ret += message + seps[i % 2]
110 | else:
111 | ret += ""
112 | else:
113 | raise ValueError(f"Invalid style: {self.sep_style}")
114 |
115 | return ret
116 |
117 | def append_message(self, role, message):
118 | self.messages.append([role, message])
119 |
120 | def get_images(self, return_pil=False):
121 | images = []
122 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
123 | if i % 2 == 0:
124 | if type(msg) is tuple:
125 | import base64
126 | from io import BytesIO
127 | from PIL import Image
128 | msg, image, image_process_mode = msg
129 | if image_process_mode == "Pad":
130 | def expand2square(pil_img, background_color=(122, 116, 104)):
131 | width, height = pil_img.size
132 | if width == height:
133 | return pil_img
134 | elif width > height:
135 | result = Image.new(pil_img.mode, (width, width), background_color)
136 | result.paste(pil_img, (0, (width - height) // 2))
137 | return result
138 | else:
139 | result = Image.new(pil_img.mode, (height, height), background_color)
140 | result.paste(pil_img, ((height - width) // 2, 0))
141 | return result
142 | image = expand2square(image)
143 | elif image_process_mode in ["Default", "Crop"]:
144 | pass
145 | elif image_process_mode == "Resize":
146 | image = image.resize((448, 448))
147 | else:
148 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
149 | max_hw, min_hw = max(image.size), min(image.size)
150 | aspect_ratio = max_hw / min_hw
151 | max_len, min_len = 800, 448
152 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
153 | longest_edge = int(shortest_edge * aspect_ratio)
154 | W, H = image.size
155 | if longest_edge != max(image.size):
156 | if H > W:
157 | H, W = longest_edge, shortest_edge
158 | else:
159 | H, W = shortest_edge, longest_edge
160 | image = image.resize((W, H))
161 | if return_pil:
162 | images.append(image)
163 | else:
164 | buffered = BytesIO()
165 | image.save(buffered, format="PNG")
166 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
167 | images.append(img_b64_str)
168 | return images
169 |
170 | def to_gradio_chatbot(self):
171 | ret = []
172 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
173 | if i % 2 == 0:
174 | if type(msg) is tuple:
175 | import base64
176 | from io import BytesIO
177 | msg, image, image_process_mode = msg
178 | max_hw, min_hw = max(image.size), min(image.size)
179 | aspect_ratio = max_hw / min_hw
180 | max_len, min_len = 800, 448
181 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
182 | longest_edge = int(shortest_edge * aspect_ratio)
183 | W, H = image.size
184 | if H > W:
185 | H, W = longest_edge, shortest_edge
186 | else:
187 | H, W = shortest_edge, longest_edge
188 | image = image.resize((W, H))
189 | buffered = BytesIO()
190 | image.save(buffered, format="JPEG")
191 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
192 | img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
193 | msg = img_str + msg.replace('<image>', '').strip()
194 | ret.append([msg, None])
195 | else:
196 | ret.append([msg, None])
197 | else:
198 | ret[-1][-1] = msg
199 | return ret
200 |
201 | def copy(self):
202 | return Conversation(
203 | system=self.system,
204 | roles=self.roles,
205 | messages=[[x, y] for x, y in self.messages],
206 | offset=self.offset,
207 | sep_style=self.sep_style,
208 | sep=self.sep,
209 | sep2=self.sep2,
210 | version=self.version)
211 |
212 | def dict(self):
213 | if len(self.get_images()) > 0:
214 | return {
215 | "system": self.system,
216 | "roles": self.roles,
217 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
218 | "offset": self.offset,
219 | "sep": self.sep,
220 | "sep2": self.sep2,
221 | }
222 | return {
223 | "system": self.system,
224 | "roles": self.roles,
225 | "messages": self.messages,
226 | "offset": self.offset,
227 | "sep": self.sep,
228 | "sep2": self.sep2,
229 | }
230 |
231 |
232 | conv_vicuna_v0 = Conversation(
233 | system="A chat between a curious human and an artificial intelligence assistant. "
234 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
235 | roles=("Human", "Assistant"),
236 | messages=(
237 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
238 | ("Assistant",
239 | "Renewable energy sources are those that can be replenished naturally in a relatively "
240 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
241 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
242 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
243 | "renewable and non-renewable energy sources:\n"
244 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
245 | "energy sources are finite and will eventually run out.\n"
246 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
247 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
248 | "and other negative effects.\n"
249 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
250 | "have lower operational costs than non-renewable sources.\n"
251 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
252 | "locations than non-renewable sources.\n"
253 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
254 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
255 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
256 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
257 | ),
258 | offset=2,
259 | sep_style=SeparatorStyle.SINGLE,
260 | sep="###",
261 | )
262 |
263 | conv_vicuna_v1 = Conversation(
264 | system="A chat between a curious user and an artificial intelligence assistant. "
265 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
266 | roles=("USER", "ASSISTANT"),
267 | version="v1",
268 | messages=(),
269 | offset=0,
270 | sep_style=SeparatorStyle.TWO,
271 | sep=" ",
272 | sep2="</s>",
273 | )
274 |
275 | conv_llama_2 = Conversation(
276 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
277 |
278 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
279 | roles=("USER", "ASSISTANT"),
280 | version="llama_v2",
281 | messages=(),
282 | offset=0,
283 | sep_style=SeparatorStyle.LLAMA_2,
284 | sep="<s>",
285 | sep2="</s>",
286 | )
287 |
288 | conv_llava_llama_2 = Conversation(
289 | system="You are a helpful language and vision assistant. "
290 | "You are able to understand the visual content that the user provides, "
291 | "and assist the user with a variety of tasks using natural language.",
292 | roles=("USER", "ASSISTANT"),
293 | version="llama_v2",
294 | messages=(),
295 | offset=0,
296 | sep_style=SeparatorStyle.LLAMA_2,
297 | sep="<s>",
298 | sep2="</s>",
299 | )
300 |
301 | conv_mpt = Conversation(
302 | system="""<|im_start|>system
303 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
304 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
305 | version="mpt",
306 | messages=(),
307 | offset=0,
308 | sep_style=SeparatorStyle.MPT,
309 | sep="<|im_end|>",
310 | )
311 |
312 | conv_llava_plain = Conversation(
313 | system="",
314 | roles=("", ""),
315 | messages=(
316 | ),
317 | offset=0,
318 | sep_style=SeparatorStyle.PLAIN,
319 | sep="\n",
320 | )
321 |
322 | conv_llava_v0 = Conversation(
323 | system="A chat between a curious human and an artificial intelligence assistant. "
324 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
325 | roles=("Human", "Assistant"),
326 | messages=(
327 | ),
328 | offset=0,
329 | sep_style=SeparatorStyle.SINGLE,
330 | sep="###",
331 | )
332 |
333 | conv_llava_v0_mmtag = Conversation(
334 | system="A chat between a curious user and an artificial intelligence assistant. "
335 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
336 | "The visual content will be provided with the following format: visual content.",
337 | roles=("Human", "Assistant"),
338 | messages=(
339 | ),
340 | offset=0,
341 | sep_style=SeparatorStyle.SINGLE,
342 | sep="###",
343 | version="v0_mmtag",
344 | )
345 |
346 | conv_llava_v1 = Conversation(
347 | system="A chat between a curious human and an artificial intelligence assistant. "
348 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
349 | roles=("USER", "ASSISTANT"),
350 | version="v1",
351 | messages=(),
352 | offset=0,
353 | sep_style=SeparatorStyle.TWO,
354 | sep=" ",
355 | sep2="</s>",
356 | )
357 |
358 | conv_llava_v1_mmtag = Conversation(
359 | system="A chat between a curious user and an artificial intelligence assistant. "
360 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
361 | "The visual content will be provided with the following format: visual content.",
362 | roles=("USER", "ASSISTANT"),
363 | messages=(),
364 | offset=0,
365 | sep_style=SeparatorStyle.TWO,
366 | sep=" ",
367 | sep2="</s>",
368 | version="v1_mmtag",
369 | )
370 |
371 | default_conversation = conv_vicuna_v0
372 | conv_templates = {
373 | "default": conv_vicuna_v0,
374 | "v0": conv_vicuna_v0,
375 | "v1": conv_vicuna_v1,
376 | "vicuna_v1": conv_vicuna_v1,
377 | "llama_2": conv_llama_2,
378 |
379 | "plain": conv_llava_plain,
380 | "v0_plain": conv_llava_plain,
381 | "llava_v0": conv_llava_v0,
382 | "v0_mmtag": conv_llava_v0_mmtag,
383 | "llava_v1": conv_llava_v1,
384 | "v1_mmtag": conv_llava_v1_mmtag,
385 | "llava_llama_2": conv_llava_llama_2,
386 |
387 | "mpt": conv_mpt,
388 | }
389 |
390 | class Chat:
391 | def __init__(self, model, image_processor,tokenizer, device='cuda:0', stopping_criteria=None):
392 | self.device = device
393 | self.model = model
394 | self.vis_processor = image_processor
395 | self.tokenizer=tokenizer
396 |
397 | # if stopping_criteria is not None:
398 | # self.stopping_criteria = stopping_criteria
399 | # else:
400 | # stop_words_ids = [torch.tensor([2]).to(self.device)]
401 | # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
402 |
403 | def ask(self, text, conv):
404 | # import pdb;pdb.set_trace()
405 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
406 | and conv.messages[-1][1][-9:] == '<image>\n': # last message is an image.
407 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
408 | else:
409 | conv.append_message(conv.roles[0], text)
410 |
411 | def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
412 | repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
413 | conv.append_message(conv.roles[1], None)
414 | prompt = conv.get_prompt()
415 | # prompt='A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: <image>\n hello ASSISTANT:'
416 | text_input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device=self.device)
417 |
418 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
419 | keywords = [stop_str]
420 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, text_input_ids)
421 | current_max_len = text_input_ids.shape[1] + max_new_tokens
422 | if current_max_len - max_length > 0:
423 | print('Warning: The number of tokens in current conversation exceeds the max length. '
424 | 'The model will not see the contexts outside the range.')
425 | begin_idx = max(0, current_max_len - max_length)
426 | embs = text_input_ids[:, begin_idx:]
427 |
428 | generation_kwargs = dict(
429 | input_ids=embs,
430 | images=img_list[0],
431 | max_new_tokens=max_new_tokens,
432 | stopping_criteria=[stopping_criteria],
433 | num_beams=num_beams,
434 | do_sample=True,
435 | min_length=min_length,
436 | top_p=top_p,
437 | use_cache=True,
438 | repetition_penalty=repetition_penalty,
439 | length_penalty=length_penalty,
440 | temperature=float(temperature),
441 | )
442 | return generation_kwargs
443 |
444 | # def answer(self, conv, img_list, **kargs):
445 | # generation_dict = self.answer_prepare(conv, img_list, **kargs)
446 | # output_token = self.model_generate(**generation_dict)[0]
447 | # output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
448 |
449 | # output_text = output_text.split('###')[0] # remove the stop sign '###'
450 | # output_text = output_text.split('Assistant:')[-1].strip()
451 |
452 | # conv.messages[-1][1] = output_text
453 | # return output_text, output_token.cpu().numpy()
454 |
455 | def stream_answer(self, conv, img_list, **kargs):
456 | generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
457 |
458 | streamer = TextIteratorStreamer(self.tokenizer,skip_prompt=True, skip_special_tokens=True)
459 | generation_kwargs['streamer'] = streamer
460 | # import pdb;pdb.set_trace()
461 | # output_ids=self.model.generate(*generation_kwargs)
462 | output=self.model_generate(kwargs=generation_kwargs)
463 | # thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
464 | # thread.start()
465 | return streamer
466 |
467 | def model_generate(self, *args, **kwargs):
468 | # for 8 bit and 16 bit compatibility
469 | with torch.inference_mode():
470 | output = self.model.generate(kwargs['kwargs']['input_ids'],
471 | images=kwargs['kwargs']['images'],
472 | do_sample=False,
473 | temperature=kwargs['kwargs']['temperature'],
474 | max_new_tokens=kwargs['kwargs']['max_new_tokens'],
475 | streamer=kwargs['kwargs']['streamer'],
476 | use_cache=kwargs['kwargs']['use_cache'],
477 | stopping_criteria=kwargs['kwargs']['stopping_criteria'])
478 | # import pdb;pdb.set_trace()
479 | # print(output)
480 | outputs = self.tokenizer.decode(output[0,kwargs['kwargs']['input_ids'].shape[1]:]).strip()
481 | # print(outputs)
482 | return output
483 |
484 | def encode_img(self, img_list):
485 |
486 | image = img_list[0]
487 | img_list.pop(0)
488 | if isinstance(image, str): # is a image path
489 | raw_image = Image.open(image).convert('RGB')
490 | image = process_images_demo([raw_image], self.vis_processor)
491 | # print("raw")
492 | # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
493 | elif isinstance(image, Image.Image):
494 | raw_image = image
495 | image = process_images_demo([raw_image], self.vis_processor )
496 | image=image.to(device=self.device,dtype=torch.float16)
497 | # print("Image")
498 | # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
499 | elif isinstance(image, torch.Tensor):
500 | if len(image.shape) == 3:
501 | image = image.unsqueeze(0)
502 | image = image.to(self.device)
503 |
504 | # image_emb, _ = self.model.encode_img(image)
505 | img_list.append(image)
506 |
507 | def upload_img(self, image, conv, img_list):
508 | conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN+'\n')
509 | img_list.append(image)
510 | msg = "Received."
511 |
512 | return msg
513 |
514 |
515 |
516 | # if __name__ == "__main__":
517 | # print(default_conversation.get_prompt())
518 |
--------------------------------------------------------------------------------
/cdchat/eval/batch_cdchat_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from cdchat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from cdchat.conversation import conv_templates, SeparatorStyle
10 | from cdchat.model.builder import load_pretrained_model
11 | from cdchat.utils import disable_torch_init
12 | from cdchat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13 |
14 | from PIL import Image
15 | import math
16 | import numpy as np
17 |
18 |
19 | def split_list(lst, n):
20 | """Split a list into n (roughly) equal-sized chunks"""
21 | chunk_size = math.ceil(len(lst) / n) # ceiling division
22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23 |
24 |
25 | def get_chunk(lst, n, k):
26 | chunks = split_list(lst, n)
27 | return chunks[k]
28 |
29 |
30 | def eval_model(args):
31 | # Model
32 | disable_torch_init()
33 | model_path = os.path.expanduser(args.model_path)
34 | model_name = get_model_name_from_path(model_path)
35 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, mm_projector_path=args.mm_projector_path)
36 |
37 | #questions=[]
38 | #questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
39 | fid = open(args.question_file, 'r')
40 | questions = json.load(fid)
41 | fid.close()
42 |
43 |
44 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
45 | answers_file = os.path.expanduser(args.answers_file)
46 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
47 |
48 | ans_file = open(answers_file, "w")
49 |
50 | for i in tqdm(range(0,len(questions),args.batch_size)):
51 | input_batch=[]
52 | input_image_batch=[]
53 | count=i
54 | image_folder_A = []
55 | image_folder_B = []
56 | image_label_list = []
57 | batch_end = min(i + args.batch_size, len(questions))
58 |
59 |
60 | for j in range(i,batch_end):
61 | #image_file=questions[j]['image']
62 | #qs=questions[j]['text']
63 | image_file = questions[j]['img_id']
64 | print(image_file)
65 | qs = questions[j]['question']
66 |
67 | if model.config.mm_use_im_start_end:
68 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
69 | else:
70 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
71 |
72 | conv = conv_templates[args.conv_mode].copy()
73 | conv.append_message(conv.roles[0], qs)
74 | conv.append_message(conv.roles[1], None)
75 | prompt = conv.get_prompt()
76 |
77 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
78 | input_batch.append(input_ids)
79 |
80 | #image = Image.open(os.path.join(args.image_folder, image_file))
81 | #image_folder.append(image)
82 | image_A = Image.open(os.path.join(args.image_folder, 'A', image_file)).convert('RGB')
83 | image_B = Image.open(os.path.join(args.image_folder, 'B', image_file)).convert('RGB')
84 | image_folder_label = os.path.join(args.image_folder, 'label')
85 | image_label = Image.open((os.path.join(image_folder_label, image_file)).strip()).convert('L')
86 | image_label = torch.from_numpy(np.array(image_label, dtype=np.float32) / 255.).unsqueeze(0) # normalize the change mask to [0, 1]
87 |
88 | image_tensor_A = image_processor.preprocess(image_A,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0]
89 | image_tensor_B = image_processor.preprocess(image_B,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0]
90 |
91 | image_tensor_A = image_tensor_A[[2,1,0], :, :] # swap channel order (RGB <-> BGR)
92 | image_tensor_B = image_tensor_B[[2,1,0], :, :]
93 |
94 | image_folder_A.append(image_tensor_A)
95 | image_folder_B.append(image_tensor_B)
96 | image_label_list.append(image_label)
97 |
98 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
99 | keywords = [stop_str]
100 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
101 |
102 | max_length = max(tensor.size(1) for tensor in input_batch)
103 |
104 | final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
105 | final_input_tensors=torch.cat(final_input_list,dim=0)
106 | #image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
107 | image_tensor_batch = {'pre': torch.stack(image_folder_A).half().cuda(),
108 | 'post': torch.stack(image_folder_B).half().cuda(),
109 | 'targets': torch.stack(image_label_list).cuda()}
110 |
111 | with torch.inference_mode():
112 | output_ids = model.generate( final_input_tensors, images=image_tensor_batch, do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
113 |
114 | input_token_len = final_input_tensors.shape[1]
115 | n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
116 | if n_diff_input_output > 0:
117 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
118 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
119 | for k in range(0,len(final_input_list)):
120 | output = outputs[k].strip()
121 | if output.endswith(stop_str):
122 | output = output[:-len(stop_str)]
123 | output = output.strip()
124 |
125 | ans_id = shortuuid.uuid()
126 |
127 | ans_file.write(json.dumps({
128 | "question_id": questions[count]["question"],
129 | "image_id": questions[count]["img_id"],
130 | "answer": output,
131 | }) + "\n")
132 | count=count+1
133 | ans_file.flush()
134 | ans_file.close()
135 |
136 |
137 | if __name__ == "__main__":
138 | parser = argparse.ArgumentParser()
139 | parser.add_argument("--model-path", type=str, default="./cdchat/cdchat/checkpoints/cdchat_lora")
140 | parser.add_argument("--model-base", type=str, default='./cdchat/llava-v1.5-7b')
141 | parser.add_argument("--mm-projector-path", type=str, default='./cdchat/cdchat/checkpoints/pretrain_mm_projector/mm_projector.bin')
142 | parser.add_argument("--image-folder", type=str, default="./dataset/cd_dataset/")
143 | parser.add_argument("--question-file", type=str, default="./dataset/LEVIR-CD-256/questions_levir_test.json")
144 | parser.add_argument("--answers-file", type=str, default="./answer_levir.jsonl")
145 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
146 | parser.add_argument("--num-chunks", type=int, default=1)
147 | parser.add_argument("--chunk-idx", type=int, default=0)
148 | parser.add_argument("--temperature", type=float, default=0.2)
149 | parser.add_argument("--top_p", type=float, default=None)
150 | parser.add_argument("--num_beams", type=int, default=1)
151 | parser.add_argument("--batch_size",type=int, default=1)
152 | args = parser.parse_args()
153 |
154 | eval_model(args)
155 |
--------------------------------------------------------------------------------
/cdchat/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 |
5 | import torch
6 | from transformers import StoppingCriteria
7 | from cdchat.constants import IMAGE_TOKEN_INDEX
8 | import numpy as np
9 |
10 | def load_image_from_base64(image):
11 | return Image.open(BytesIO(base64.b64decode(image)))
12 |
13 |
14 | def expand2square(pil_img, background_color):
15 | width, height = pil_img.size
16 | if width == height:
17 | return pil_img
18 | elif width > height:
19 | result = Image.new(pil_img.mode, (width, width), background_color)
20 | result.paste(pil_img, (0, (width - height) // 2))
21 | return result
22 | else:
23 | result = Image.new(pil_img.mode, (height, height), background_color)
24 | result.paste(pil_img, ((height - width) // 2, 0))
25 | return result
26 |
27 |
28 | def process_images(images, image_processor, model_cfg):
29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30 | new_images = []
31 | if image_aspect_ratio == 'pad':
32 | for image in images:
33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
34 | image = image_processor.preprocess(image,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448},return_tensors='pt')['pixel_values'][0]
35 | # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0]
36 |
37 | new_images.append(image)
38 | else:
39 | return image_processor(images, return_tensors='pt')['pixel_values']
40 | if all(x.shape == new_images[0].shape for x in new_images):
41 | new_images = torch.stack(new_images, dim=0)
42 | return new_images
43 |
44 | def process_images_demo(images, image_processor):
45 | new_images = []
46 | # image_aspect_ratio = 'pad'
47 | for image in images:
48 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
49 | image = image_processor.preprocess(image,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448},return_tensors='pt')['pixel_values'][0]
50 | # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0]
51 |
52 | new_images.append(image)
53 |
54 | if all(x.shape == new_images[0].shape for x in new_images):
55 | new_images = torch.stack(new_images, dim=0)
56 | return new_images
57 |
58 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
59 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
60 |
61 | def insert_separator(X, sep):
62 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
63 |
64 | input_ids = []
65 | offset = 0
66 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
67 | offset = 1
68 | input_ids.append(prompt_chunks[0][0])
69 |
70 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
71 | input_ids.extend(x[offset:])
72 |
73 | if return_tensors is not None:
74 | if return_tensors == 'pt':
75 | return torch.tensor(input_ids, dtype=torch.long)
76 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
77 | return input_ids
78 |
79 |
80 | def get_model_name_from_path(model_path):
81 | model_path = model_path.strip("/")
82 | model_paths = model_path.split("/")
83 | if model_paths[-1].startswith('checkpoint-'):
84 | return model_paths[-2] + "_" + model_paths[-1]
85 | else:
86 | return model_paths[-1]
87 |
88 |
89 |
90 |
91 | class KeywordsStoppingCriteria(StoppingCriteria):
92 | def __init__(self, keywords, tokenizer, input_ids):
93 | self.keywords = keywords
94 | self.keyword_ids = []
95 | self.max_keyword_len = 0
96 | for keyword in keywords:
97 | cur_keyword_ids = tokenizer(keyword).input_ids
98 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
99 | cur_keyword_ids = cur_keyword_ids[1:]
100 | if len(cur_keyword_ids) > self.max_keyword_len:
101 | self.max_keyword_len = len(cur_keyword_ids)
102 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
103 | self.tokenizer = tokenizer
104 | self.start_len = input_ids.shape[1]
105 |
106 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
107 | # assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
108 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
109 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
110 | for keyword_id in self.keyword_ids:
111 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
112 | return True
113 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
114 | for keyword in self.keywords:
115 | if keyword in outputs:
116 | return True
117 | return False
--------------------------------------------------------------------------------
/cdchat/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.cdchat_llama import CDChatLlamaForCausalLM, CDChatConfig
2 |
--------------------------------------------------------------------------------
/cdchat/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m cdchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from cdchat import CDChatLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = CDChatLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/cdchat/model/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import warnings
18 | import shutil
19 |
20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21 | import torch
22 | from cdchat.model import *
23 | from cdchat.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24 | from cdchat.model.multimodal_projector.builder import build_vision_projector
25 |
26 |
27 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", mm_projector_path=None):
28 | kwargs = {"device_map": device_map}
29 |
30 | if load_8bit:
31 | kwargs['load_in_8bit'] = True
32 | elif load_4bit:
33 | kwargs['load_in_4bit'] = True
34 | kwargs['quantization_config'] = BitsAndBytesConfig(
35 | load_in_4bit=True,
36 | bnb_4bit_compute_dtype=torch.float16,
37 | bnb_4bit_use_double_quant=True,
38 | bnb_4bit_quant_type='nf4'
39 | )
40 | else:
41 | kwargs['torch_dtype'] = torch.float16
42 |
43 | if 'cdchat' in model_name.lower():
44 | # Load LLaVA model
45 | if 'lora' in model_name.lower() and model_base is None:
46 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
47 | if 'lora' in model_name.lower() and model_base is not None:
48 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
49 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
50 | print('Loading CDChat from base model...')
51 | model = CDChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
52 | token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
53 | if model.lm_head.weight.shape[0] != token_num:
54 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
55 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
56 |
57 | print('Loading additional CDChat weights...')
58 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
59 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
60 | else:
61 | # this is probably from HF Hub
62 | from huggingface_hub import hf_hub_download
63 | def load_from_hf(repo_id, filename, subfolder=None):
64 | cache_file = hf_hub_download(
65 | repo_id=repo_id,
66 | filename=filename,
67 | subfolder=subfolder)
68 | return torch.load(cache_file, map_location='cpu')
69 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
70 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
71 | if any(k.startswith('model.model.') for k in non_lora_trainables):
72 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
73 | model.load_state_dict(non_lora_trainables, strict=False)
74 |
75 |
76 | from peft import PeftModel
77 | print('Loading LoRA weights...')
78 | model = PeftModel.from_pretrained(model, model_path)
79 | print('Merging LoRA weights...')
80 | model = model.merge_and_unload()
81 | print('Model is loaded...')
82 | elif model_base is not None:
83 | # this may be mm projector only
84 | print('Loading CDChat from base model...')
85 |
86 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
87 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
88 | model = CDChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
89 |
90 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
91 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
92 | model.load_state_dict(mm_projector_weights, strict=False)
93 | else:
94 | print("Loading CDChat......")
95 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
96 | model = CDChatLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
97 | else:
98 | # Load language model
99 | if model_base is not None:
100 | # PEFT model
101 | from peft import PeftModel
102 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
103 | model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
104 | print(f"Loading LoRA weights from {model_path}")
105 | model = PeftModel.from_pretrained(model, model_path)
106 | print(f"Merging weights")
107 | model = model.merge_and_unload()
108 | print('Convert to FP16...')
109 | model.to(torch.float16)
110 | else:
111 | use_fast = False
112 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
113 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
114 |
115 | image_processor = None
116 |
117 | if 'cdchat' in model_name.lower():
118 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
119 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
120 | if mm_use_im_patch_token:
121 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
122 | if mm_use_im_start_end:
123 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
124 | model.resize_token_embeddings(len(tokenizer))
125 |
126 |
127 | vision_tower = model.get_vision_tower()
128 | if not vision_tower.is_loaded:
129 | vision_tower.load_model()
130 | vision_tower.to(device=device, dtype=torch.float16)
131 | image_processor = vision_tower.image_processor
132 |
133 | ####################################
134 | model.model.mm_projector = build_vision_projector(model.model.config)
135 | if mm_projector_path is not None:
136 | mm_projector_weights = torch.load(mm_projector_path, map_location='cpu')
137 |
138 | def get_w(weights, keyword):
139 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
140 |
141 | msg = model.model.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
142 | print(msg)
143 |
144 | model.model.mm_projector.to(device=device, dtype=torch.float16)
145 | ####################################################################
146 |
147 | if hasattr(model.config, "max_sequence_length"):
148 | context_len = model.config.max_sequence_length
149 | else:
150 | context_len = 2048
151 |
152 | return tokenizer, model, image_processor, context_len
153 |
--------------------------------------------------------------------------------
/cdchat/model/cdchat_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from abc import ABC, abstractmethod
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 |
21 | from .multimodal_encoder.builder import build_vision_tower
22 | from .multimodal_projector.builder import build_vision_projector
23 |
24 | from cdchat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25 |
26 |
27 | class CDChatMetaModel:
28 |
29 | def __init__(self, config):
30 | super(CDChatMetaModel, self).__init__(config)
31 |
32 | if hasattr(config, "mm_vision_tower"):
33 | self.vision_tower = build_vision_tower(config, delay_load=True)
34 | self.mm_projector = build_vision_projector(config)
35 |
36 | def get_vision_tower(self):
37 | vision_tower = getattr(self, 'vision_tower', None)
38 | if type(vision_tower) is list:
39 | vision_tower = vision_tower[0]
40 | return vision_tower
41 |
42 | def initialize_vision_modules(self, model_args, fsdp=None):
43 | vision_tower = model_args.vision_tower
44 | mm_vision_select_layer = model_args.mm_vision_select_layer
45 | mm_vision_select_feature = model_args.mm_vision_select_feature
46 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
47 |
48 | self.config.mm_vision_tower = vision_tower
49 |
50 | if self.get_vision_tower() is None:
51 | vision_tower = build_vision_tower(model_args)
52 |
53 | if fsdp is not None and len(fsdp) > 0:
54 | self.vision_tower = [vision_tower]
55 | else:
56 | self.vision_tower = vision_tower
57 | else:
58 | if fsdp is not None and len(fsdp) > 0:
59 | vision_tower = self.vision_tower[0]
60 | else:
61 | vision_tower = self.vision_tower
62 | vision_tower.load_model()
63 |
64 | self.config.use_mm_proj = True
65 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
66 | self.config.mm_hidden_size = vision_tower.embed_dim #vision_tower.hidden_size
67 | self.config.mm_vision_select_layer = mm_vision_select_layer
68 | self.config.mm_vision_select_feature = mm_vision_select_feature
69 |
70 | if getattr(self, 'mm_projector', None) is None:
71 | self.mm_projector = build_vision_projector(self.config)
72 | # print(mm_projector)
73 |
74 |
75 | if pretrain_mm_mlp_adapter is not None:
76 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
77 |
78 | def get_w(weights, keyword):
79 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
80 |
81 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
82 |
83 |
84 |
85 |
86 | class CDChatMetaForCausalLM(ABC):
87 |
88 | @abstractmethod
89 | def get_model(self):
90 | pass
91 |
92 | def get_vision_tower(self):
93 | return self.get_model().get_vision_tower()
94 |
95 | def encode_images(self, images):
96 | #image_features = self.get_model().get_vision_tower()(images)
97 | #image_features = self.get_model().mm_projector(image_features)
98 |
99 | image_features_A = self.get_model().get_vision_tower()(images['pre'])
100 | image_features_B = self.get_model().get_vision_tower()(images['post'])
101 |
102 | image_features = torch.cat([image_features_A, image_features_B], dim=-1)
103 | proj_features = self.get_model().mm_projector(image_features)
104 |
105 | return image_features, proj_features
106 |
107 | def prepare_inputs_labels_for_multimodal(
108 | self, input_ids, attention_mask, past_key_values, labels, images
109 | ):
110 | vision_tower = self.get_vision_tower()
111 | if vision_tower is None or images is None or input_ids.shape[1] == 1:
112 | if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
113 | attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
114 | return input_ids, attention_mask, past_key_values, None, labels
115 |
116 | '''
117 | if type(images) is list or images.ndim == 5:
118 | concat_images = torch.cat([image for image in images], dim=0)
119 |
120 | #image_features = self.encode_images(concat_images)
121 | diff_features, image_features = self.encode_images(images)
122 |
123 | split_sizes = [image.shape[0] for image in images]
124 | image_features = torch.split(image_features, split_sizes, dim=0)
125 | image_features = [x.flatten(0, 1) for x in image_features]
126 | else:
127 | image_features = self.encode_images(images)
128 | '''
129 | diff_features, image_features = self.encode_images(images)
130 |
131 | new_input_embeds = []
132 | new_labels = [] if labels is not None else None
133 | cur_image_idx = 0
134 | for batch_idx, cur_input_ids in enumerate(input_ids):
135 | if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
136 | # multimodal LLM, but the current sample is not multimodal
137 | # FIXME: this is a hacky fix, for deepspeed zero3 to work
138 | half_len = cur_input_ids.shape[0] // 2
139 | cur_image_features = image_features[cur_image_idx]
140 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
141 | cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
142 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
143 | new_input_embeds.append(cur_input_embeds)
144 | if labels is not None:
145 | new_labels.append(labels[batch_idx])
146 | cur_image_idx += 1
147 | continue
148 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
149 | cur_new_input_embeds = []
150 | if labels is not None:
151 | cur_labels = labels[batch_idx]
152 | cur_new_labels = []
153 | assert cur_labels.shape == cur_input_ids.shape
154 | while image_token_indices.numel() > 0:
155 | cur_image_features = image_features[cur_image_idx]
156 | image_token_start = image_token_indices[0]
157 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
158 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
159 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
160 | cur_new_input_embeds.append(cur_image_features)
161 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
162 | if labels is not None:
163 | cur_new_labels.append(cur_labels[:image_token_start])
164 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
165 | cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
166 | cur_labels = cur_labels[image_token_start+2:]
167 | else:
168 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
169 | cur_new_input_embeds.append(cur_image_features)
170 | if labels is not None:
171 | cur_new_labels.append(cur_labels[:image_token_start])
172 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
173 | cur_labels = cur_labels[image_token_start+1:]
174 | cur_image_idx += 1
175 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
176 | cur_input_ids = cur_input_ids[image_token_start+2:]
177 | else:
178 | cur_input_ids = cur_input_ids[image_token_start+1:]
179 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
180 | if cur_input_ids.numel() > 0:
181 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
182 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
183 | else:
184 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
185 | if labels is not None:
186 | cur_new_labels.append(cur_labels)
187 | cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
188 | cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
189 | new_input_embeds.append(cur_new_input_embeds)
190 | if labels is not None:
191 | cur_new_labels = torch.cat(cur_new_labels, dim=0)
192 | new_labels.append(cur_new_labels)
193 |
194 | if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
195 | max_len = max(x.shape[0] for x in new_input_embeds)
196 |
197 | new_input_embeds_align = []
198 | for cur_new_embed in new_input_embeds:
199 | cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
200 | new_input_embeds_align.append(cur_new_embed)
201 | new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
202 |
203 | if labels is not None:
204 | new_labels_align = []
205 | _new_labels = new_labels
206 | for cur_new_label in new_labels:
207 | cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
208 | new_labels_align.append(cur_new_label)
209 | new_labels = torch.stack(new_labels_align, dim=0)
210 |
211 | if attention_mask is not None:
212 | new_attention_mask = []
213 | for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
214 | new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
215 | new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
216 | cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
217 | new_attention_mask.append(cur_new_attention_mask)
218 | attention_mask = torch.stack(new_attention_mask, dim=0)
219 | assert attention_mask.shape == new_labels.shape
220 | else:
221 | new_input_embeds = torch.stack(new_input_embeds, dim=0)
222 | if labels is not None:
223 | new_labels = torch.stack(new_labels, dim=0)
224 |
225 | if attention_mask is not None:
226 | new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
227 | attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
228 | assert attention_mask.shape == new_input_embeds.shape[:2]
229 |
230 | return None, attention_mask, past_key_values, new_input_embeds, new_labels
231 |
232 | def initialize_vision_tokenizer(self, model_args, tokenizer):
233 | if model_args.mm_use_im_patch_token:
234 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
235 | self.resize_token_embeddings(len(tokenizer))
236 |
237 | if model_args.mm_use_im_start_end:
238 | num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
239 | self.resize_token_embeddings(len(tokenizer))
240 |
241 | if num_new_tokens > 0:
242 | input_embeddings = self.get_input_embeddings().weight.data
243 | output_embeddings = self.get_output_embeddings().weight.data
244 |
245 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
246 | dim=0, keepdim=True)
247 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
248 | dim=0, keepdim=True)
249 |
250 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
251 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
252 |
253 | if model_args.tune_mm_mlp_adapter:
254 | for p in self.get_input_embeddings().parameters():
255 | p.requires_grad = True
256 | for p in self.get_output_embeddings().parameters():
257 | p.requires_grad = False
258 |
259 | if model_args.pretrain_mm_mlp_adapter:
260 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
261 |
262 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
263 | assert num_new_tokens == 2
264 | if input_embeddings.shape == embed_tokens_weight.shape:
265 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
266 | elif embed_tokens_weight.shape[0] == num_new_tokens:
267 | input_embeddings[-num_new_tokens:] = embed_tokens_weight
268 | else:
269 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
270 | elif model_args.mm_use_im_patch_token:
271 | if model_args.tune_mm_mlp_adapter:
272 | for p in self.get_input_embeddings().parameters():
273 | p.requires_grad = False
274 | for p in self.get_output_embeddings().parameters():
275 | p.requires_grad = False
276 |
--------------------------------------------------------------------------------
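
A note on `prepare_inputs_labels_for_multimodal` above: the while-loop splices the projected image features into the token-embedding sequence wherever an IMAGE_TOKEN_INDEX placeholder occurs, and fills the matching label positions with IGNORE_INDEX so the image patches never contribute to the loss; the batch is then right-padded (embeddings with zeros, labels with IGNORE_INDEX, attention mask with False). Below is a minimal, self-contained sketch of the core splice, with toy constants and a stand-in embedding table rather than the repo's API; the real method additionally handles the `mm_use_im_start_end` variant and detaches embeddings when only the adapter is tuned.

```
import torch

IMAGE_TOKEN_INDEX = -200  # placeholder id, mirroring cdchat.constants
IGNORE_INDEX = -100       # loss-mask value, mirroring cdchat.constants

def splice_image_features(input_ids, labels, image_features, embed):
    # Toy version of the splice: every IMAGE_TOKEN_INDEX in input_ids is replaced
    # by the (num_patches, dim) image features, and the matching label positions
    # are filled with IGNORE_INDEX so the image tokens carry no loss.
    new_embeds, new_labels = [], []
    while (input_ids == IMAGE_TOKEN_INDEX).any():
        start = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0][0]
        new_embeds.append(embed(input_ids[:start]))  # text before the image
        new_embeds.append(image_features)            # projected image tokens
        new_labels.append(labels[:start])
        new_labels.append(torch.full((image_features.shape[0],), IGNORE_INDEX))
        input_ids, labels = input_ids[start + 1:], labels[start + 1:]
    new_embeds.append(embed(input_ids))              # trailing text
    new_labels.append(labels)
    return torch.cat(new_embeds), torch.cat(new_labels)

embed = torch.nn.Embedding(100, 8)
ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3])
feats = torch.randn(5, 8)                            # pretend 5 image patches
emb, lab = splice_image_features(ids, ids.clone(), feats, embed)
print(emb.shape)  # torch.Size([8, 8])
print(lab)        # tensor([   1,    2, -100, -100, -100, -100, -100,    3])
```
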
/cdchat/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m cdchat.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from cdchat.model import *
10 | from cdchat.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
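
`consolidate.py` resolves the custom `cdchat` model type only because `from cdchat.model import *` runs the `AutoConfig.register`/`AutoModelForCausalLM.register` calls at the bottom of `cdchat_llama.py` (assuming `cdchat/model/__init__.py` re-exports the language model, as the wildcard import suggests). Loading a consolidated checkpoint follows the same pattern; the path below is a placeholder:

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import cdchat.model  # side effect: registers the "cdchat" model type

ckpt = "path/to/cdchat-7b_consolidate"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
```
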
/cdchat/model/language_model/cdchat_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.nn import CrossEntropyLoss
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | LlamaConfig, LlamaModel, LlamaForCausalLM
24 |
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 |
27 | from ..cdchat_arch import CDChatMetaModel, CDChatMetaForCausalLM
28 |
29 |
30 | class CDChatConfig(LlamaConfig):
31 | model_type = "cdchat"
32 |
33 |
34 | class CDChatLlamaModel(CDChatMetaModel, LlamaModel):
35 | config_class = CDChatConfig
36 |
37 | def __init__(self, config: LlamaConfig):
38 | super(CDChatLlamaModel, self).__init__(config)
39 |
40 |
41 | class CDChatLlamaForCausalLM(LlamaForCausalLM, CDChatMetaForCausalLM):
42 | config_class = CDChatConfig
43 |
44 | def __init__(self, config):
45 | super(LlamaForCausalLM, self).__init__(config)
46 | self.model = CDChatLlamaModel(config)
47 |
48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.model
55 |
56 | def forward(
57 | self,
58 | input_ids: torch.LongTensor = None,
59 | attention_mask: Optional[torch.Tensor] = None,
60 | past_key_values: Optional[List[torch.FloatTensor]] = None,
61 | inputs_embeds: Optional[torch.FloatTensor] = None,
62 | labels: Optional[torch.LongTensor] = None,
63 | use_cache: Optional[bool] = None,
64 | output_attentions: Optional[bool] = None,
65 | output_hidden_states: Optional[bool] = None,
66 | images: Optional[torch.FloatTensor] = None,
67 | return_dict: Optional[bool] = None,
68 | ) -> Union[Tuple, CausalLMOutputWithPast]:
69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70 | output_hidden_states = (
71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72 | )
73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74 |
75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
76 |
77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
78 | outputs = self.model(
79 | input_ids=input_ids,
80 | attention_mask=attention_mask,
81 | past_key_values=past_key_values,
82 | inputs_embeds=inputs_embeds,
83 | use_cache=use_cache,
84 | output_attentions=output_attentions,
85 | output_hidden_states=output_hidden_states,
86 | return_dict=return_dict
87 | )
88 |
89 | hidden_states = outputs[0]
90 | logits = self.lm_head(hidden_states)
91 |
92 | loss = None
93 | if labels is not None:
94 | # Shift so that tokens < n predict n
95 | shift_logits = logits[..., :-1, :].contiguous()
96 | shift_labels = labels[..., 1:].contiguous()
97 | # Flatten the tokens
98 | loss_fct = CrossEntropyLoss()
99 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
100 | shift_labels = shift_labels.view(-1)
101 | # Enable model/pipeline parallelism
102 | shift_labels = shift_labels.to(shift_logits.device)
103 | loss = loss_fct(shift_logits, shift_labels)
104 |
105 | if not return_dict:
106 | output = (logits,) + outputs[1:]
107 | return (loss,) + output if loss is not None else output
108 |
109 | return CausalLMOutputWithPast(
110 | loss=loss,
111 | logits=logits,
112 | past_key_values=outputs.past_key_values,
113 | hidden_states=outputs.hidden_states,
114 | attentions=outputs.attentions,
115 | )
116 |
117 | def prepare_inputs_for_generation(
118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
119 | ):
120 | if past_key_values:
121 | input_ids = input_ids[:, -1:]
122 |
123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
124 | if inputs_embeds is not None and past_key_values is None:
125 | model_inputs = {"inputs_embeds": inputs_embeds}
126 | else:
127 | model_inputs = {"input_ids": input_ids}
128 |
129 | model_inputs.update(
130 | {
131 | "past_key_values": past_key_values,
132 | "use_cache": kwargs.get("use_cache"),
133 | "attention_mask": attention_mask,
134 | "images": kwargs.get("images", None),
135 | }
136 | )
137 | return model_inputs
138 |
139 | AutoConfig.register("cdchat", CDChatConfig)
140 | AutoModelForCausalLM.register(CDChatConfig, CDChatLlamaForCausalLM)
141 |
--------------------------------------------------------------------------------
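
The loss in `CDChatLlamaForCausalLM.forward` is the standard next-token objective: logits at position t are scored against the token at t+1, and IGNORE_INDEX labels (inserted for image patches by `prepare_inputs_labels_for_multimodal`) are skipped because -100 is `CrossEntropyLoss`'s default `ignore_index`. A toy check of the shift, independent of the model:

```
import torch
from torch.nn import CrossEntropyLoss

IGNORE_INDEX = -100                # CrossEntropyLoss's default ignore_index
vocab = 10
logits = torch.randn(1, 5, vocab)  # (batch, seq, vocab)
labels = torch.tensor([[2, IGNORE_INDEX, 4, 7, 1]])

# Shift so that tokens < t predict t, as in forward()
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)  # the IGNORE_INDEX position is skipped
print(loss)
```
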
/cdchat/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m cdchat.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from cdchat.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
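
`make_delta` stores `target - base`, subtracting only the overlapping rows of the size-mismatched `embed_tokens`/`lm_head` matrices; the companion `apply_delta.py` reverses this by adding the base back. A toy round trip with illustrative shapes:

```
import torch

base = torch.randn(4, 3)                       # base embed_tokens: 4 tokens
target = torch.randn(6, 3)                     # target added 2 tokens

delta = target.clone()
delta[:base.shape[0], :base.shape[1]] -= base  # difference only the overlapping rows

restored = delta.clone()
restored[:base.shape[0], :base.shape[1]] += base
assert torch.allclose(restored, target)
```
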
/cdchat/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10 |
11 | raise ValueError(f'Unknown vision tower: {vision_tower}')
12 |
--------------------------------------------------------------------------------
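
`build_vision_tower` accepts any object that exposes an `mm_vision_tower` (or `vision_tower`) attribute naming a local path or an `openai`/`laion` checkpoint. A minimal sketch with an `argparse.Namespace` standing in for the config; the selected layer is illustrative, and the first call downloads the CLIP weights:

```
from argparse import Namespace
from cdchat.model.multimodal_encoder.builder import build_vision_tower

cfg = Namespace(
    vision_tower="openai/clip-vit-large-patch14-336",
    mm_vision_select_layer=-2,   # illustrative value; the training scripts set this flag
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg)  # loads CLIP, then interpolates the embeddings to 448 px
print(tower.hidden_size)         # 1024 for ViT-L/14
```
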
/cdchat/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from PIL import ImageFile
5 | ImageFile.LOAD_TRUNCATED_IMAGES = True
6 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
7 |
8 |
9 | class CLIPVisionTower(nn.Module):
10 | def clip_interpolate_embeddings(self, image_size=448, patch_size=14):
11 | """Interpolate the vision tower's positional embeddings in place so that a
12 | model pre-trained at one resolution can be applied to images of a different
13 | resolution (here 336 px -> 448 px at patch size 14).
14 |
15 | Args:
16 | image_size (int): Image size of the new model.
17 | patch_size (int): Patch size of the new model.
18 |
19 | The class-token embedding is left untouched; the patch embeddings are
20 | reshaped to a 2d grid, resized with bicubic interpolation, flattened
21 | back, and loaded into a new `position_embedding` module sized for the
22 | new sequence length.
23 | """
24 | # Shape of pos_embedding is (1, seq_length, hidden_dim)
25 | state_dict = self.vision_tower.vision_model.embeddings.position_embedding.state_dict()
26 | pos_embedding = state_dict['weight']
27 | pos_embedding = pos_embedding.unsqueeze(0)
28 | n, seq_length, hidden_dim = pos_embedding.shape
29 | if n != 1:
30 | raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
31 |
32 | new_seq_length = (image_size // patch_size) ** 2 + 1
33 |
34 | # Need to interpolate the weights for the position embedding.
35 | # We do this by reshaping the positions embeddings to a 2d grid, performing
36 | # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
37 | if new_seq_length != seq_length:
38 | # The class token embedding shouldn't be interpolated so we split it up.
39 | seq_length -= 1
40 | new_seq_length -= 1
41 | pos_embedding_token = pos_embedding[:, :1, :]
42 | pos_embedding_img = pos_embedding[:, 1:, :]
43 |
44 | # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
45 | pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
46 | seq_length_1d = int(math.sqrt(seq_length))
47 | torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")
48 |
49 | # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
50 | pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
51 | new_seq_length_1d = image_size // patch_size
52 |
53 | # Perform interpolation.
54 | # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
55 | new_pos_embedding_img = nn.functional.interpolate(
56 | pos_embedding_img,
57 | size=new_seq_length_1d,
58 | mode='bicubic',
59 | align_corners=True,
60 | )
61 |
62 | # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
63 | new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
64 |
65 | # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
66 | new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
67 | new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)[0]
68 | state_dict['weight'] = new_pos_embedding
69 | self.vision_tower.vision_model.embeddings.position_embedding = nn.Embedding(new_seq_length+1, hidden_dim)
70 | self.vision_tower.vision_model.embeddings.position_embedding.load_state_dict(state_dict)
71 | self.vision_tower.vision_model.embeddings.image_size = image_size
72 | self.vision_tower.vision_model.embeddings.patch_size = patch_size
73 | self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_seq_length+1).expand((1, -1))
74 |
75 | def __init__(self, vision_tower, args, delay_load=False):
76 | super().__init__()
77 |
78 | self.is_loaded = False
79 |
80 | self.vision_tower_name = vision_tower
81 | self.select_layer = args.mm_vision_select_layer
82 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
83 |
84 | if not delay_load:
85 | self.load_model()
86 | else:
87 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
88 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
89 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
90 | self.vision_tower.requires_grad_(False)
91 | self.clip_interpolate_embeddings(image_size=448, patch_size=14)
92 |
93 |
94 | def load_model(self):
95 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
96 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
97 | self.vision_tower.requires_grad_(False)
98 | self.clip_interpolate_embeddings(image_size=448, patch_size=14)
99 |
100 | self.is_loaded = True
101 | # print(self.is_loaded)
102 |
103 | def feature_select(self, image_forward_outs):
104 | image_features = image_forward_outs.hidden_states[self.select_layer]
105 | if self.select_feature == 'patch':
106 | image_features = image_features[:, 1:]
107 | elif self.select_feature == 'cls_patch':
108 | image_features = image_features
109 | else:
110 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
111 | return image_features
112 |
113 | @torch.no_grad()
114 | def forward(self, images):
115 | if type(images) is list:
116 | image_features = []
117 | for image in images:
118 | # print(image.shape)
119 | # import pdb; pdb.set_trace()
120 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
121 |
122 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
123 | # print(image_features.shape)
124 |
125 | image_features.append(image_feature)
126 | else:
127 | # print(images.shape)
128 | # import pdb; pdb.set_trace()
129 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
130 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
131 | # print(image_features.shape)
132 |
133 |
134 | return image_features
135 |
136 | @property
137 | def dummy_feature(self):
138 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
139 |
140 | @property
141 | def dtype(self):
142 | return self.vision_tower.dtype
143 |
144 | @property
145 | def device(self):
146 | return self.vision_tower.device
147 |
148 | @property
149 | def config(self):
150 | if self.is_loaded:
151 | return self.vision_tower.config
152 | else:
153 | return self.cfg_only
154 |
155 | @property
156 | def hidden_size(self):
157 | return self.config.hidden_size
158 |
159 | @property
160 | def num_patches(self):
161 | return (self.config.image_size // self.config.patch_size) ** 2
162 |
--------------------------------------------------------------------------------
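
Concretely, `clip_interpolate_embeddings` grows the ViT-L/14 position table from 24 x 24 + 1 = 577 rows at 336 px to 32 x 32 + 1 = 1025 rows at 448 px. The core resize, isolated from the class as a sketch (the class token is excluded from interpolation, as in the method):

```
import math
import torch
import torch.nn as nn

pos = torch.randn(577, 1024)                # (seq, hidden) for 336 px, patch size 14
cls_tok, grid = pos[:1], pos[1:]            # keep the class token out of the resize
side = int(math.sqrt(grid.shape[0]))        # 24
grid = grid.T.reshape(1, 1024, side, side)  # (1, hidden, 24, 24)
grid = nn.functional.interpolate(grid, size=448 // 14, mode="bicubic", align_corners=True)
grid = grid.reshape(1, 1024, -1).permute(0, 2, 1)[0]  # (1024 patches, hidden)
new_pos = torch.cat([cls_tok, grid], dim=0)
print(new_pos.shape)                        # torch.Size([1025, 1024])
```
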
/cdchat/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 |
31 |
32 | return x + self.proj(x)
33 |
34 |
35 | def build_vision_projector(config, delay_load=False, **kwargs):
36 | projector_type = getattr(config, 'mm_projector_type', 'linear')
37 |
38 | if projector_type == 'linear':
39 | return nn.Linear(config.mm_hidden_size*2, config.hidden_size)
40 |
41 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
42 | if mlp_gelu_match:
43 | mlp_depth = int(mlp_gelu_match.group(1))
44 | modules = [nn.Linear(config.mm_hidden_size*2, config.hidden_size)]
45 | for _ in range(1, mlp_depth):
46 | modules.append(nn.GELU())
47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
48 | return nn.Sequential(*modules)
49 |
50 | if projector_type == 'identity':
51 | return IdentityMap()
52 |
53 | raise ValueError(f'Unknown projector type: {projector_type}')
54 |
--------------------------------------------------------------------------------
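
With the `--mm_projector_type mlp2x_gelu` flag the regex yields a depth of 2, i.e. `Linear(2*mm_hidden_size, hidden_size) -> GELU -> Linear(hidden_size, hidden_size)`. Note the doubled input width, which presumably carries the concatenated features of the pre- and post-change images. A sketch with illustrative sizes (CLIP ViT-L hidden size 1024, LLaMA-7B hidden size 4096):

```
import torch
from types import SimpleNamespace
from cdchat.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
proj = build_vision_projector(cfg)
print(proj)                        # Linear(2048 -> 4096), GELU(), Linear(4096 -> 4096)

x = torch.randn(1, 576, 2 * 1024)  # e.g. per-patch concatenation of both temporal images
print(proj(x).shape)               # torch.Size([1, 576, 4096])
```
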
/cdchat/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/cdchat/train/cdchat_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from torch.utils.data import Sampler
5 |
6 | from transformers import Trainer
7 | from transformers.trainer import (
8 | has_length,
9 | )
10 | from typing import List, Optional
11 |
12 |
13 | def maybe_zero_3(param, ignore_status=False, name=None):
14 | from deepspeed import zero
15 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
16 | if hasattr(param, "ds_id"):
17 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
18 | if not ignore_status:
19 | print(name, 'gathering ZeRO-3 partitioned parameter')
20 | with zero.GatheredParameters([param]):
21 | param = param.data.detach().cpu().clone()
22 | else:
23 | param = param.detach().cpu().clone()
24 | return param
25 |
26 |
27 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
28 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
29 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
30 | return to_return
31 |
32 |
33 | def split_to_even_chunks(indices, lengths, num_chunks):
34 | """
35 | Split a list of indices into `num_chunks` chunks of roughly equal total length.
36 | """
37 |
38 | if len(indices) % num_chunks != 0:
39 | return [indices[i::num_chunks] for i in range(num_chunks)]
40 |
41 | num_indices_per_chunk = len(indices) // num_chunks
42 |
43 | chunks = [[] for _ in range(num_chunks)]
44 | chunks_lengths = [0 for _ in range(num_chunks)]
45 | for index in indices:
46 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
47 | chunks[shortest_chunk].append(index)
48 | chunks_lengths[shortest_chunk] += lengths[index]
49 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
50 | chunks_lengths[shortest_chunk] = float("inf")
51 |
52 | return chunks
53 |
54 |
55 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
56 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
57 | assert all(l != 0 for l in lengths), "Should not have zero length."
58 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
59 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
60 |
61 | assert len(mm_indices) > 0, "Should have at least one multimodal sample."
62 | assert len(lang_indices) > 0, "Should have at least one language sample."
63 |
64 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
65 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
66 | megabatch_size = world_size * batch_size
67 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
68 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
69 |
70 | last_mm = mm_megabatches[-1]
71 | last_lang = lang_megabatches[-1]
72 | additional_batch = last_mm + last_lang
73 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
74 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
75 | megabatches = [megabatches[i] for i in megabatch_indices]
76 |
77 | if len(additional_batch) >= megabatch_size:
78 | megabatches = [additional_batch[:megabatch_size]] + megabatches
79 | additional_batch = additional_batch[megabatch_size:]
80 |
81 | if len(additional_batch) > 0:
82 | megabatches.append(additional_batch)
83 |
84 | return [i for megabatch in megabatches for i in megabatch]
85 |
86 |
87 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
88 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
89 | indices = torch.randperm(len(lengths), generator=generator)
90 | megabatch_size = world_size * batch_size
91 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
92 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
93 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
94 |
95 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
96 |
97 |
98 | class LengthGroupedSampler(Sampler):
99 | r"""
100 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
101 | keeping a bit of randomness.
102 | """
103 |
104 | def __init__(
105 | self,
106 | batch_size: int,
107 | world_size: int,
108 | lengths: Optional[List[int]] = None,
109 | generator=None,
110 | group_by_modality: bool = False,
111 | ):
112 | if lengths is None:
113 | raise ValueError("Lengths must be provided.")
114 |
115 | self.batch_size = batch_size
116 | self.world_size = world_size
117 | self.lengths = lengths
118 | self.generator = generator
119 | self.group_by_modality = group_by_modality
120 |
121 | def __len__(self):
122 | return len(self.lengths)
123 |
124 | def __iter__(self):
125 | if self.group_by_modality:
126 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
127 | else:
128 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
129 | return iter(indices)
130 |
131 |
132 | class CDChatTrainer(Trainer):
133 |
134 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
135 | if self.train_dataset is None or not has_length(self.train_dataset):
136 | return None
137 |
138 | if self.args.group_by_modality_length:
139 | lengths = self.train_dataset.modality_lengths
140 | return LengthGroupedSampler(
141 | # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
142 | self.args.train_batch_size,
143 | world_size=self.args.world_size,
144 | lengths=lengths,
145 | group_by_modality=True,
146 | )
147 | else:
148 | return super()._get_train_sampler()
149 |
150 | def _save_checkpoint(self, model, trial, metrics=None):
151 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
152 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
153 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
154 |
155 | run_dir = self._get_output_dir(trial=trial)
156 | output_dir = os.path.join(run_dir, checkpoint_folder)
157 |
158 | # Only save Adapter
159 | keys_to_match = ['mm_projector', 'vision_resampler']
160 | if getattr(self.args, "use_im_start_end", False):
161 | keys_to_match.extend(['embed_tokens', 'embed_in'])
162 |
163 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
164 |
165 | if self.args.local_rank == 0 or self.args.local_rank == -1:
166 | self.model.config.save_pretrained(output_dir)
167 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
168 | else:
169 | super(CDChatTrainer, self)._save_checkpoint(model, trial, metrics)
170 |
171 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
172 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
173 | pass
174 | else:
175 | super(CDChatTrainer, self)._save(output_dir, state_dict)
176 |
--------------------------------------------------------------------------------
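
The sampler pipeline above has three steps: shuffle all indices, sort each `world_size * batch_size` megabatch by length (longest first), and balance each megabatch across ranks with `split_to_even_chunks` (greedy assignment to the currently shortest chunk). A quick toy run, no distributed setup required:

```
import torch
from cdchat.train.cdchat_trainer import get_length_grouped_indices, split_to_even_chunks

lengths = [5, 1, 9, 3, 7, 2, 8, 4]  # token lengths of 8 samples
g = torch.Generator().manual_seed(0)
order = get_length_grouped_indices(lengths, batch_size=2, world_size=2, generator=g)
print(order)  # a permutation in which nearby indices have similar lengths

# Balancing one length-sorted megabatch across 2 ranks by total length:
print(split_to_even_chunks([2, 6, 4, 0], lengths, 2))  # e.g. [[2, 0], [6, 4]]
```
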
/cdchat/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
97 | # requires the attention mask to be the same as the key_padding_mask
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 |
--------------------------------------------------------------------------------
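
Because `replace_llama_attn_with_flash_attn` rebinds methods on the `transformers` classes themselves, it affects every `LlamaAttention` instance in the process and should be installed before training starts; it also requires a CUDA build of flash-attn. A sketch of the intended call order (the checkpoint path is a placeholder):

```
from cdchat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()  # class-level patch: all LlamaAttention modules now use flash attention

from cdchat.model.language_model.cdchat_llama import CDChatLlamaForCausalLM

model = CDChatLlamaForCausalLM.from_pretrained("path/to/checkpoint")  # placeholder path
```
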
/cdchat/train/train.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import copy
19 | from dataclasses import dataclass, field
20 | import json
21 | import logging
22 | import pathlib
23 | from typing import Dict, Optional, Sequence, List
24 | import numpy as np
25 | import torch
26 |
27 | import transformers
28 |
29 | from cdchat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
30 | from torch.utils.data import Dataset
31 | from cdchat.train.cdchat_trainer import CDChatTrainer
32 |
33 | from cdchat import conversation as conversation_lib
34 | from cdchat.model import *
35 | from cdchat.mm_utils import tokenizer_image_token
36 |
37 | from PIL import Image
38 |
39 |
40 | local_rank = None
41 |
42 |
43 | def rank0_print(*args):
44 | if local_rank == 0:
45 | print(*args)
46 |
47 |
48 | @dataclass
49 | class ModelArguments:
50 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
51 | version: Optional[str] = field(default="v0")
52 | freeze_backbone: bool = field(default=False)
53 | tune_mm_mlp_adapter: bool = field(default=False)
54 | vision_tower: Optional[str] = field(default=None)
55 | mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
56 | pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
57 | mm_projector_type: Optional[str] = field(default='linear')
58 | mm_use_im_start_end: bool = field(default=False)
59 | mm_use_im_patch_token: bool = field(default=True)
60 | mm_vision_select_feature: Optional[str] = field(default="patch")
61 |
62 |
63 | @dataclass
64 | class DataArguments:
65 | data_path: str = field(default=None,
66 | metadata={"help": "Path to the training data."})
67 | lazy_preprocess: bool = False
68 | is_multimodal: bool = False
69 | image_folder: Optional[str] = field(default=None)
70 | image_aspect_ratio: str = 'square'
71 | image_grid_pinpoints: Optional[str] = field(default=None)
72 |
73 |
74 | @dataclass
75 | class TrainingArguments(transformers.TrainingArguments):
76 | cache_dir: Optional[str] = field(default=None)
77 | optim: str = field(default="adamw_torch")
78 | remove_unused_columns: bool = field(default=False)
79 | freeze_mm_mlp_adapter: bool = field(default=False)
80 | mpt_attn_impl: Optional[str] = field(default="triton")
81 | model_max_length: int = field(
82 | default=512,
83 | metadata={
84 | "help":
85 | "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
86 | },
87 | )
88 | double_quant: bool = field(
89 | default=True,
90 | metadata={"help": "Compress the quantization statistics through double quantization."}
91 | )
92 | quant_type: str = field(
93 | default="nf4",
94 | metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
95 | )
96 | bits: int = field(
97 | default=16,
98 | metadata={"help": "How many bits to use."}
99 | )
100 | lora_enable: bool = False
101 | lora_r: int = 64
102 | lora_alpha: int = 16
103 | lora_dropout: float = 0.05
104 | lora_weight_path: str = ""
105 | lora_bias: str = "none"
106 | group_by_modality_length: bool = field(default=False)
107 |
108 |
109 | def maybe_zero_3(param, ignore_status=False, name=None):
110 | from deepspeed import zero
111 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
112 | if hasattr(param, "ds_id"):
113 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
114 | if not ignore_status:
115 | logging.warning(f"{name}: gathering param with ds_status == ZeroParamStatus.NOT_AVAILABLE")
116 | with zero.GatheredParameters([param]):
117 | param = param.data.detach().cpu().clone()
118 | else:
119 | param = param.detach().cpu().clone()
120 | return param
121 |
122 |
123 | # Borrowed from peft.utils.get_peft_model_state_dict
124 | def get_peft_state_maybe_zero_3(named_params, bias):
125 | if bias == "none":
126 | to_return = {k: t for k, t in named_params if "lora_" in k}
127 | elif bias == "all":
128 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
129 | elif bias == "lora_only":
130 | to_return = {}
131 | maybe_lora_bias = {}
132 | lora_bias_names = set()
133 | for k, t in named_params:
134 | if "lora_" in k:
135 | to_return[k] = t
136 | bias_name = k.split("lora_")[0] + "bias"
137 | lora_bias_names.add(bias_name)
138 | elif "bias" in k:
139 | maybe_lora_bias[k] = t
140 | for k, t in maybe_lora_bias.items():
141 | if k in lora_bias_names:
142 | to_return[k] = t
143 | else:
144 | raise NotImplementedError
145 | to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
146 | return to_return
147 |
148 |
149 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
150 | to_return = {k: t for k, t in named_params if "lora_" not in k}
151 | if require_grad_only:
152 | to_return = {k: t for k, t in to_return.items() if t.requires_grad}
153 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
154 | return to_return
155 |
156 |
157 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
158 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
159 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
160 | return to_return
161 |
162 |
163 | def find_all_linear_names(model):
164 | cls = torch.nn.Linear
165 | lora_module_names = set()
166 | multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
167 | for name, module in model.named_modules():
168 | if any(mm_keyword in name for mm_keyword in multimodal_keywords):
169 | continue
170 | if isinstance(module, cls):
171 | names = name.split('.')
172 | lora_module_names.add(names[0] if len(names) == 1 else names[-1])
173 |
174 | if 'lm_head' in lora_module_names: # needed for 16-bit
175 | lora_module_names.remove('lm_head')
176 | return list(lora_module_names)
177 |
178 |
179 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
180 | output_dir: str):
181 | """Collects the state dict and dump to disk."""
182 |
183 | if getattr(trainer.args, "tune_mm_mlp_adapter", False):
184 | # Only save Adapter
185 | keys_to_match = ['mm_projector']
186 | if getattr(trainer.args, "use_im_start_end", False):
187 | keys_to_match.extend(['embed_tokens', 'embed_in'])
188 |
189 | weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
190 | trainer.model.config.save_pretrained(output_dir)
191 |
192 | current_folder = output_dir.split('/')[-1]
193 | parent_folder = os.path.dirname(output_dir)
194 | if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
195 | if current_folder.startswith('checkpoint-'):
196 | mm_projector_folder = os.path.join(parent_folder, "mm_projector")
197 | os.makedirs(mm_projector_folder, exist_ok=True)
198 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
199 | else:
200 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
201 | return
202 |
203 | if trainer.deepspeed:
204 | torch.cuda.synchronize()
205 | trainer.save_model(output_dir)
206 | return
207 |
208 | state_dict = trainer.model.state_dict()
209 | if trainer.args.should_save:
210 | cpu_state_dict = {
211 | key: value.cpu()
212 | for key, value in state_dict.items()
213 | }
214 | del state_dict
215 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
216 |
217 |
218 | def smart_tokenizer_and_embedding_resize(
219 | special_tokens_dict: Dict,
220 | tokenizer: transformers.PreTrainedTokenizer,
221 | model: transformers.PreTrainedModel,
222 | ):
223 | """Resize tokenizer and embedding.
224 |
225 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
226 | """
227 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
228 | model.resize_token_embeddings(len(tokenizer))
229 |
230 | if num_new_tokens > 0:
231 | input_embeddings = model.get_input_embeddings().weight.data
232 | output_embeddings = model.get_output_embeddings().weight.data
233 |
234 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
235 | dim=0, keepdim=True)
236 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
237 | dim=0, keepdim=True)
238 |
239 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
240 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
241 |
242 |
243 | def _tokenize_fn(strings: Sequence[str],
244 | tokenizer: transformers.PreTrainedTokenizer) -> Dict:
245 | """Tokenize a list of strings."""
246 | tokenized_list = [
247 | tokenizer(
248 | text,
249 | return_tensors="pt",
250 | padding="longest",
251 | max_length=tokenizer.model_max_length,
252 | truncation=True,
253 | ) for text in strings
254 | ]
255 | input_ids = labels = [
256 | tokenized.input_ids[0] for tokenized in tokenized_list
257 | ]
258 | input_ids_lens = labels_lens = [
259 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
260 | for tokenized in tokenized_list
261 | ]
262 | return dict(
263 | input_ids=input_ids,
264 | labels=labels,
265 | input_ids_lens=input_ids_lens,
266 | labels_lens=labels_lens,
267 | )
268 |
269 |
270 | def _mask_targets(target, tokenized_lens, speakers):
271 | # cur_idx = 0
272 | cur_idx = tokenized_lens[0]
273 | tokenized_lens = tokenized_lens[1:]
274 | target[:cur_idx] = IGNORE_INDEX
275 | for tokenized_len, speaker in zip(tokenized_lens, speakers):
276 | if speaker == "human":
277 | target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
278 | cur_idx += tokenized_len
279 |
280 |
281 | def _add_speaker_and_signal(header, source, get_conversation=True):
282 | """Add speaker and start/end signal on each round."""
283 | BEGIN_SIGNAL = "### "
284 | END_SIGNAL = "\n"
285 | conversation = header
286 | for sentence in source:
287 | from_str = sentence["from"]
288 | if from_str.lower() == "human":
289 | from_str = conversation_lib.default_conversation.roles[0]
290 | elif from_str.lower() == "gpt":
291 | from_str = conversation_lib.default_conversation.roles[1]
292 | else:
293 | from_str = 'unknown'
294 | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
295 | sentence["value"] + END_SIGNAL)
296 | if get_conversation:
297 | conversation += sentence["value"]
298 | conversation += BEGIN_SIGNAL
299 | return conversation
300 |
301 |
302 | def preprocess_multimodal(
303 | sources: Sequence[str],
304 | data_args: DataArguments
305 | ) -> Dict:
306 | is_multimodal = data_args.is_multimodal
307 | if not is_multimodal:
308 | return sources
309 |
310 | for source in sources:
311 | for sentence in source:
312 |
313 | if DEFAULT_IMAGE_TOKEN in sentence['value']:
314 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
315 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
316 | sentence['value'] = sentence['value'].strip()
317 | if "mmtag" in conversation_lib.default_conversation.version:
318 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
319 | replace_token = DEFAULT_IMAGE_TOKEN
320 | if data_args.mm_use_im_start_end:
321 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
322 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
323 |
324 | return sources
325 |
326 |
327 | def preprocess_llama_2(
328 | sources,
329 | tokenizer: transformers.PreTrainedTokenizer,
330 | has_image: bool = False
331 | ) -> Dict:
332 | conv = conversation_lib.default_conversation.copy()
333 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
334 |
335 | # Apply prompt templates
336 | conversations = []
337 | for i, source in enumerate(sources):
338 | if roles[source[0]["from"]] != conv.roles[0]:
339 | # Skip the first one if it is not from human
340 | source = source[1:]
341 |
342 | conv.messages = []
343 | for j, sentence in enumerate(source):
344 | role = roles[sentence["from"]]
345 | assert role == conv.roles[j % 2], f"{i}"
346 | conv.append_message(role, sentence["value"])
347 | conversations.append(conv.get_prompt())
348 |
349 | # Tokenize conversations
350 |
351 | if has_image:
352 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
353 | else:
354 | input_ids = tokenizer(
355 | conversations,
356 | return_tensors="pt",
357 | padding="longest",
358 | max_length=tokenizer.model_max_length,
359 | truncation=True,
360 | ).input_ids
361 |
362 | targets = input_ids.clone()
363 |
364 | assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
365 |
366 | # Mask targets
367 | sep = "[/INST] "
368 | for conversation, target in zip(conversations, targets):
369 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
370 |
371 | rounds = conversation.split(conv.sep2)
372 | cur_len = 1
373 | target[:cur_len] = IGNORE_INDEX
374 | for i, rou in enumerate(rounds):
375 | if rou == "":
376 | break
377 |
378 | parts = rou.split(sep)
379 | if len(parts) != 2:
380 | break
381 | parts[0] += sep
382 |
383 | if has_image:
384 | round_len = len(tokenizer_image_token(rou, tokenizer))
385 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
386 | else:
387 | round_len = len(tokenizer(rou).input_ids)
388 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
389 |
390 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
391 |
392 | cur_len += round_len
393 | target[cur_len:] = IGNORE_INDEX
394 |
395 | if cur_len < tokenizer.model_max_length:
396 | if cur_len != total_len:
397 | target[:] = IGNORE_INDEX
398 | print(
399 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
400 | f" (ignored)"
401 | )
402 |
403 | return dict(
404 | input_ids=input_ids,
405 | labels=targets,
406 | )
407 |
408 |
409 | def preprocess_v1(
410 | sources,
411 | tokenizer: transformers.PreTrainedTokenizer,
412 | has_image: bool = False
413 | ) -> Dict:
414 | conv = conversation_lib.default_conversation.copy()
415 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
416 |
417 | # Apply prompt templates
418 | conversations = []
419 | for i, source in enumerate(sources):
420 | if roles[source[0]["from"]] != conv.roles[0]:
421 | # Skip the first one if it is not from human
422 | source = source[1:]
423 |
424 | conv.messages = []
425 | for j, sentence in enumerate(source):
426 | role = roles[sentence["from"]]
427 | assert role == conv.roles[j % 2], f"{i}"
428 | conv.append_message(role, sentence["value"])
429 | conversations.append(conv.get_prompt())
430 |
431 | # Tokenize conversations
432 |
433 | if has_image:
434 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
435 | else:
436 | input_ids = tokenizer(
437 | conversations,
438 | return_tensors="pt",
439 | padding="longest",
440 | max_length=tokenizer.model_max_length,
441 | truncation=True,
442 | ).input_ids
443 |
444 | targets = input_ids.clone()
445 |
446 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
447 |
448 | # Mask targets
449 | sep = conv.sep + conv.roles[1] + ": "
450 | for conversation, target in zip(conversations, targets):
451 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
452 |
453 | rounds = conversation.split(conv.sep2)
454 | cur_len = 1
455 | target[:cur_len] = IGNORE_INDEX
456 | for i, rou in enumerate(rounds):
457 | if rou == "":
458 | break
459 |
460 | parts = rou.split(sep)
461 | if len(parts) != 2:
462 | break
463 | parts[0] += sep
464 |
465 | if has_image:
466 | round_len = len(tokenizer_image_token(rou, tokenizer))
467 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
468 | else:
469 | round_len = len(tokenizer(rou).input_ids)
470 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
471 |
472 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
473 |
474 | cur_len += round_len
475 | target[cur_len:] = IGNORE_INDEX
476 |
477 | if cur_len < tokenizer.model_max_length:
478 | if cur_len != total_len:
479 | target[:] = IGNORE_INDEX
480 | print(
481 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
482 | f" (ignored)"
483 | )
484 |
485 | return dict(
486 | input_ids=input_ids,
487 | labels=targets,
488 | )
489 |
490 | '''
491 | def preprocess_mpt(
492 | sources,
493 | tokenizer: transformers.PreTrainedTokenizer,
494 | ) -> Dict:
495 | conv = conversation_lib.default_conversation.copy()
496 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
497 |
498 | # Apply prompt templates
499 | conversations = []
500 | for i, source in enumerate(sources):
501 | if roles[source[0]["from"]] != conv.roles[0]:
502 | # Skip the first one if it is not from human
503 | source = source[1:]
504 |
505 | conv.messages = []
506 | for j, sentence in enumerate(source):
507 | role = roles[sentence["from"]]
508 | assert role == conv.roles[j % 2], f"{i}"
509 | conv.append_message(role, sentence["value"])
510 | conversations.append(conv.get_prompt())
511 |
512 | # Tokenize conversations
513 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
514 | targets = input_ids.clone()
515 | assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
516 |
517 | # Mask targets
518 | sep = conv.sep + conv.roles[1]
519 | for conversation, target in zip(conversations, targets):
520 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
521 |
522 | rounds = conversation.split(conv.sep)
523 | re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
524 | for conv_idx in range(3, len(rounds), 2):
525 | re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
526 | cur_len = 0
527 | target[:cur_len] = IGNORE_INDEX
528 | for i, rou in enumerate(re_rounds):
529 | if rou == "":
530 | break
531 |
532 | parts = rou.split(sep)
533 | if len(parts) != 2:
534 | break
535 | parts[0] += sep
536 | round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
537 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
538 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
539 |
540 | cur_len += round_len
541 | target[cur_len:] = IGNORE_INDEX
542 |
543 | if cur_len < tokenizer.model_max_length:
544 | if cur_len != total_len:
545 | target[:] = IGNORE_INDEX
546 | print(
547 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
548 | f" (ignored)"
549 | )
550 |
551 | return dict(
552 | input_ids=input_ids,
553 | labels=targets,
554 | )
555 | '''
556 |
557 | def preprocess_plain(
558 | sources: Sequence[str],
559 | tokenizer: transformers.PreTrainedTokenizer,
560 | ) -> Dict:
561 | # add end signal and concatenate together
562 | conversations = []
563 | for source in sources:
564 | assert len(source) == 2
565 | assert DEFAULT_IMAGE_TOKEN in source[0]['value']
566 | source[0]['value'] = DEFAULT_IMAGE_TOKEN
567 | conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
568 | conversations.append(conversation)
569 | # tokenize conversations
570 | input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
571 | targets = copy.deepcopy(input_ids)
572 | for target, source in zip(targets, sources):
573 | tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
574 | target[:tokenized_len] = IGNORE_INDEX
575 |
576 | return dict(input_ids=input_ids, labels=targets)
577 |
578 |
579 | def preprocess(
580 | sources: Sequence[str],
581 | tokenizer: transformers.PreTrainedTokenizer,
582 | has_image: bool = False
583 | ) -> Dict:
584 | """
585 | Given a list of sources, each is a conversation list. This transform:
586 | 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
587 | 2. Concatenate conversations together;
588 | 3. Tokenize the concatenated conversation;
589 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
590 | """
591 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
592 | return preprocess_plain(sources, tokenizer)
593 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
594 | return preprocess_llama_2(sources, tokenizer, has_image=has_image)
595 | if conversation_lib.default_conversation.version.startswith("v1"):
596 | return preprocess_v1(sources, tokenizer, has_image=has_image)
597 | '''
598 | if conversation_lib.default_conversation.version == "mpt":
599 | return preprocess_mpt(sources, tokenizer)
600 | '''
601 | # add end signal and concatenate together
602 | conversations = []
603 | for source in sources:
604 | header = f"{conversation_lib.default_conversation.system}\n\n"
605 | conversation = _add_speaker_and_signal(header, source)
606 | conversations.append(conversation)
607 | # tokenize conversations
608 | def get_tokenize_len(prompts):
609 | return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
610 |
611 | if has_image:
612 | input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
613 | else:
614 | conversations_tokenized = _tokenize_fn(conversations, tokenizer)
615 | input_ids = conversations_tokenized["input_ids"]
616 |
617 | targets = copy.deepcopy(input_ids)
618 | for target, source in zip(targets, sources):
619 | if has_image:
620 | tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
621 | else:
622 | tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
623 | speakers = [sentence["from"] for sentence in source]
624 | _mask_targets(target, tokenized_lens, speakers)
625 |
626 | return dict(input_ids=input_ids, labels=targets)
627 |
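# In the fallback '### ' branch above, a two-turn source is rendered roughly as
# (sketch; role names come from the active conversation template):
#   "{system}\n\n### Human: <image>\nWhat changed?\n### Assistant: Trees were removed.\n### "
# _mask_targets then sets the header and every human turn to IGNORE_INDEX.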
628 |
629 | class LazySupervisedDataset(Dataset):
630 | """Dataset for supervised fine-tuning."""
631 |
632 | def __init__(self, data_path: str,
633 | tokenizer: transformers.PreTrainedTokenizer,
634 | data_args: DataArguments):
635 | super(LazySupervisedDataset, self).__init__()
636 |         with open(data_path, "r") as f:
637 |             list_data_dict = json.load(f)
638 | rank0_print("Formatting inputs...Skip in lazy mode")
639 | self.tokenizer = tokenizer
640 | self.list_data_dict = list_data_dict
641 | self.data_args = data_args
642 |
643 | def __len__(self):
644 | return len(self.list_data_dict)
645 |
646 | @property
647 | def lengths(self):
648 | length_list = []
649 | for sample in self.list_data_dict:
650 | img_tokens = 128 if 'image' in sample else 0
651 | length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
652 | return length_list
653 |
654 | @property
655 | def modality_lengths(self):
656 | length_list = []
657 | for sample in self.list_data_dict:
658 | cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
659 | cur_len = cur_len if 'image' in sample else -cur_len
660 | length_list.append(cur_len)
661 | return length_list
662 |
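    # lengths / modality_lengths feed the trainer's length-grouped sampler;
    # the sign flip marks text-only samples so multimodal and text-only
    # examples can be grouped separately (LLaVA-style heuristic).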
663 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
664 |
667 | sources = self.list_data_dict[i]
668 | if isinstance(i, int):
669 | sources = [sources]
670 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
671 | if 'image' in sources[0]:
672 | image_file = self.list_data_dict[i]['image']
673 | image_folder = self.data_args.image_folder
674 |
675 |             # bi-temporal image pairs live in sibling folders: A (pre-change),
676 |             # B (post-change) and label (binary change masks), sharing file names
677 | image_folder_A = os.path.join(image_folder, 'A')
678 | image_folder_B = os.path.join(image_folder, 'B')
679 | image_folder_label = os.path.join(image_folder, 'label')
680 |
682 | processor = self.data_args.image_processor
683 |             image_A = Image.open(os.path.join(image_folder_A, image_file).strip()).convert('RGB')
684 |             image_B = Image.open(os.path.join(image_folder_B, image_file).strip()).convert('RGB')
685 |             image_label = Image.open(os.path.join(image_folder_label, image_file).strip()).convert('L')
686 |             # normalize the binary mask to [0, 1] and add a channel dim -> 1 x H x W
687 |             image_label = torch.from_numpy(np.array(image_label) / 255.).float().unsqueeze(0)
688 |
689 | if self.data_args.image_aspect_ratio == 'pad':
690 | def expand2square(pil_img, background_color):
691 | width, height = pil_img.size
692 | if width == height:
693 | return pil_img
694 | elif width > height:
695 | result = Image.new(pil_img.mode, (width, width), background_color)
696 | result.paste(pil_img, (0, (width - height) // 2))
697 | return result
698 | else:
699 | result = Image.new(pil_img.mode, (height, height), background_color)
700 | result.paste(pil_img, ((height - width) // 2, 0))
701 | return result
702 | #image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
703 | # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
704 | #image = processor.preprocess(image,do_resize=True,crop_size ={'height': 448, 'width': 448},size = {'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0]
705 | image_A = expand2square(image_A, tuple(int(x*255) for x in processor.image_mean))
706 | image_B = expand2square(image_B, tuple(int(x*255) for x in processor.image_mean))
707 |                 image_A = processor.preprocess(image_A, do_resize=True, crop_size={'height': 448, 'width': 448}, size={'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0]
708 |                 image_B = processor.preprocess(image_B, do_resize=True, crop_size={'height': 448, 'width': 448}, size={'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0]
709 |
710 | else:
713 |                 image_A = processor.preprocess(image_A, do_resize=True, crop_size={'height': 448, 'width': 448}, size={'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0]
714 |                 image_B = processor.preprocess(image_B, do_resize=True, crop_size={'height': 448, 'width': 448}, size={'shortest_edge': 448}, return_tensors='pt')['pixel_values'][0]
715 |
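            # NOTE: both images are preprocessed to 448x448 even though the vision
            # tower is CLIP ViT-L/14-336; following GeoChat, this assumes the tower
            # interpolates its positional embeddings for the larger input resolution.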
716 | sources = preprocess_multimodal(
717 | copy.deepcopy([e["conversations"] for e in sources]),
718 | self.data_args)
719 | else:
720 | sources = copy.deepcopy([e["conversations"] for e in sources])
721 | data_dict = preprocess(
722 | sources,
723 | self.tokenizer,
724 | has_image=('image' in self.list_data_dict[i]))
725 | if isinstance(i, int):
726 | data_dict = dict(input_ids=data_dict["input_ids"][0],
727 | labels=data_dict["labels"][0])
728 |
729 |         # image exists in the data; the RGB -> BGR channel flip is guarded so
730 |         # that text-only samples never touch the undefined image tensors
731 |         if 'image' in self.list_data_dict[i]:
732 |             image_A = image_A[[2, 1, 0], :, :]
733 |             image_B = image_B[[2, 1, 0], :, :]
734 |             data_dict['image'] = {'pre': image_A, 'post': image_B, 'targets': image_label}
735 | elif self.data_args.is_multimodal:
736 | # image does not exist in the data, but the model is multimodal
737 | crop_size = self.data_args.image_processor.crop_size
738 | data_dict['image'] = {'pre':torch.zeros(3, crop_size['height'], crop_size['width']),
739 | 'post':torch.zeros(3, crop_size['height'], crop_size['width']),
740 |                                   'targets': torch.zeros(1, crop_size['height'], crop_size['width'])}  # single-channel, matching real masks
741 | return data_dict
742 |
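# Each sample returned by __getitem__ is therefore a dict of the form (sketch):
#   {'input_ids': LongTensor[L], 'labels': LongTensor[L],
#    'image': {'pre': 3x448x448, 'post': 3x448x448, 'targets': 1xHxW mask}}
# where 'targets' is the normalized binary change mask (or a zero placeholder
# for text-only samples).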
743 |
744 | @dataclass
745 | class DataCollatorForSupervisedDataset(object):
746 | """Collate examples for supervised fine-tuning."""
747 |
748 | tokenizer: transformers.PreTrainedTokenizer
749 |
750 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
751 | input_ids, labels = tuple([instance[key] for instance in instances]
752 | for key in ("input_ids", "labels"))
753 | input_ids = torch.nn.utils.rnn.pad_sequence(
754 | input_ids,
755 | batch_first=True,
756 | padding_value=self.tokenizer.pad_token_id)
757 | labels = torch.nn.utils.rnn.pad_sequence(labels,
758 | batch_first=True,
759 | padding_value=IGNORE_INDEX)
760 | input_ids = input_ids[:, :self.tokenizer.model_max_length]
761 | labels = labels[:, :self.tokenizer.model_max_length]
762 | batch = dict(
763 | input_ids=input_ids,
764 | labels=labels,
765 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
766 | )
767 |
768 | if 'image' in instances[0]:
776 | images_A = [instance['image']['pre'] for instance in instances]
777 | images_B = [instance['image']['post'] for instance in instances]
778 | images_label = [instance['image']['targets'] for instance in instances]
779 | if all(x is not None and x.shape == images_A[0].shape for x in images_A):
780 | batch['images'] = {'pre':torch.stack(images_A), 'post':torch.stack(images_B), 'targets':torch.stack(images_label)}
781 | else:
782 | batch['images'] = {'pre':images_A, 'post':images_B, 'targets':images_label}
783 |
784 | return batch
785 |
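# The collated batch mirrors the per-sample layout (sketch): input_ids / labels
# are padded to [B, L] and truncated to model_max_length, attention_mask is
# False on pad positions, and batch['images'] holds stacked 'pre' / 'post' /
# 'targets' tensors whenever every sample in the batch shares the same shape.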
786 |
787 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
788 | data_args) -> Dict:
789 | """Make dataset and collator for supervised fine-tuning."""
790 | train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
791 | data_path=data_args.data_path,
792 | data_args=data_args)
793 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
794 | return dict(train_dataset=train_dataset,
795 | eval_dataset=None,
796 | data_collator=data_collator)
797 |
798 |
799 | def train():
800 | global local_rank
801 |
802 | parser = transformers.HfArgumentParser(
803 | (ModelArguments, DataArguments, TrainingArguments))
804 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
805 | local_rank = training_args.local_rank
806 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
807 |
808 | bnb_model_from_pretrained_args = {}
809 | if training_args.bits in [4, 8]:
810 | from transformers import BitsAndBytesConfig
811 | bnb_model_from_pretrained_args.update(dict(
812 | device_map={"": training_args.device},
813 | load_in_4bit=training_args.bits == 4,
814 | load_in_8bit=training_args.bits == 8,
815 | quantization_config=BitsAndBytesConfig(
816 | load_in_4bit=training_args.bits == 4,
817 | load_in_8bit=training_args.bits == 8,
818 | llm_int8_threshold=6.0,
819 | llm_int8_has_fp16_weight=False,
820 | bnb_4bit_compute_dtype=compute_dtype,
821 | bnb_4bit_use_double_quant=training_args.double_quant,
822 | bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
823 | )
824 | ))
825 |
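    # With --bits 4/8 the base LM is loaded through bitsandbytes quantization;
    # the LoRA adapters added below then train in 16/32-bit precision on top
    # of the frozen quantized weights (a QLoRA-style setup).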
826 | if model_args.vision_tower is not None:
840 | model = CDChatLlamaForCausalLM.from_pretrained(
841 | model_args.model_name_or_path,
842 | cache_dir=training_args.cache_dir,
843 | ignore_mismatched_sizes=True,
844 | **bnb_model_from_pretrained_args
845 | )
846 | else:
847 | model = transformers.LlamaForCausalLM.from_pretrained(
848 | model_args.model_name_or_path,
849 | cache_dir=training_args.cache_dir,
850 | **bnb_model_from_pretrained_args
851 | )
852 | model.config.use_cache = False
853 |
854 | if model_args.freeze_backbone:
855 | model.model.requires_grad_(False)
856 |
857 | if training_args.bits in [4, 8]:
858 | from peft import prepare_model_for_kbit_training
859 | model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
860 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
861 |
862 | if training_args.gradient_checkpointing:
863 | if hasattr(model, "enable_input_require_grads"):
864 | model.enable_input_require_grads()
865 | else:
866 | def make_inputs_require_grad(module, input, output):
867 | output.requires_grad_(True)
868 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
869 |
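    # With gradient checkpointing and a frozen backbone (as under LoRA), the
    # input embeddings emit tensors with requires_grad=False; the hook above
    # forces requires_grad=True so gradients can flow through checkpointed blocks.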
870 | if training_args.lora_enable:
871 | from peft import LoraConfig, get_peft_model
872 | lora_config = LoraConfig(
873 | r=training_args.lora_r,
874 | lora_alpha=training_args.lora_alpha,
875 | target_modules=find_all_linear_names(model),
876 | lora_dropout=training_args.lora_dropout,
877 | bias=training_args.lora_bias,
878 | task_type="CAUSAL_LM",
879 | )
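        # find_all_linear_names (defined earlier in this file) is assumed to
        # collect every nn.Linear of the language model, excluding the vision
        # tower and mm projector, so LoRA wraps all LM linear layers.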
880 | if training_args.bits == 16:
881 | if training_args.bf16:
882 | model.to(torch.bfloat16)
883 | if training_args.fp16:
884 | model.to(torch.float16)
885 | rank0_print("Adding LoRA adapters...")
886 | model = get_peft_model(model, lora_config)
887 |
898 | tokenizer = transformers.AutoTokenizer.from_pretrained(
899 | model_args.model_name_or_path,
900 | cache_dir=training_args.cache_dir,
901 | model_max_length=training_args.model_max_length,
902 | padding_side="right",
903 | use_fast=False,
904 | )
905 |
906 | if model_args.version == "v0":
907 | if tokenizer.pad_token is None:
908 | smart_tokenizer_and_embedding_resize(
909 | special_tokens_dict=dict(pad_token="[PAD]"),
910 | tokenizer=tokenizer,
911 | model=model,
912 | )
913 | elif model_args.version == "v0.5":
914 | tokenizer.pad_token = tokenizer.unk_token
915 | else:
916 | tokenizer.pad_token = tokenizer.unk_token
917 | if model_args.version in conversation_lib.conv_templates:
918 | conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
919 | else:
920 | conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
921 |
922 | if model_args.vision_tower is not None:
923 | model.get_model().initialize_vision_modules(
924 | model_args=model_args,
925 | fsdp=training_args.fsdp
926 | )
927 |
928 | vision_tower = model.get_vision_tower()
929 | vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
930 |
931 | data_args.image_processor = vision_tower.image_processor
932 | data_args.is_multimodal = True
933 |
934 | model.config.image_aspect_ratio = data_args.image_aspect_ratio
935 | model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
936 |
937 | model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
938 | if model_args.tune_mm_mlp_adapter:
939 | model.requires_grad_(False)
940 | for p in model.get_model().mm_projector.parameters():
941 | p.requires_grad = True
942 |
943 | model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
944 | if training_args.freeze_mm_mlp_adapter:
945 | for p in model.get_model().mm_projector.parameters():
946 | p.requires_grad = False
947 |
948 | if training_args.bits in [4, 8]:
949 | model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
950 |
951 | model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
952 | training_args.use_im_start_end = model_args.mm_use_im_start_end
953 | model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
954 | model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
955 |
956 | if training_args.bits in [4, 8]:
957 | from peft.tuners.lora import LoraLayer
958 | for name, module in model.named_modules():
959 | if isinstance(module, LoraLayer):
960 | if training_args.bf16:
961 | module = module.to(torch.bfloat16)
962 | if 'norm' in name:
963 | module = module.to(torch.float32)
964 | if 'lm_head' in name or 'embed_tokens' in name:
965 | if hasattr(module, 'weight'):
966 | if training_args.bf16 and module.weight.dtype == torch.float32:
967 | module = module.to(torch.bfloat16)
968 |
969 | data_module = make_supervised_data_module(tokenizer=tokenizer,
970 | data_args=data_args)
971 | trainer = CDChatTrainer(model=model,
972 | tokenizer=tokenizer,
973 | args=training_args,
974 | **data_module)
975 |
976 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
977 | trainer.train(resume_from_checkpoint=True)
978 | else:
979 | trainer.train()
980 | trainer.save_state()
981 |
982 | model.config.use_cache = True
983 |
984 | if training_args.lora_enable:
985 | state_dict = get_peft_state_maybe_zero_3(
986 | model.named_parameters(), training_args.lora_bias
987 | )
988 | non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
989 | model.named_parameters()
990 | )
991 | if training_args.local_rank == 0 or training_args.local_rank == -1:
992 | model.config.save_pretrained(training_args.output_dir)
993 | model.save_pretrained(training_args.output_dir, state_dict=state_dict)
994 | torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
995 | else:
996 | safe_save_model_for_hf_trainer(trainer=trainer,
997 | output_dir=training_args.output_dir)
998 |
999 |
1000 | if __name__ == "__main__":
1001 | train()
1002 |
--------------------------------------------------------------------------------
/cdchat/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from cdchat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from cdchat.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/cdchat/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import logging
4 | import logging.handlers
5 | import os
6 | import sys
7 |
8 | import requests
8 |
9 | from cdchat.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True)
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 |     Check whether the text violates the OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 |     # build the payload with json.dumps so quotes/unicode in the input are escaped
111 |     data = json.dumps({"input": text}).encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 |     except (requests.exceptions.RequestException, KeyError):
116 |         flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
/images/cdchat_annotation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/cdchat_annotation.png
--------------------------------------------------------------------------------
/images/cdchat_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/cdchat_arch.png
--------------------------------------------------------------------------------
/images/example_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/example_01.png
--------------------------------------------------------------------------------
/images/example_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/techmn/cdchat/a75bf3b9cd08bd3e4a8076c4d0f272dee0bfcbea/images/example_02.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "cdchat"
7 | version = "1.0"
8 | description = "VLM for Remote Sensing Change Description"
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 |     "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
17 |     "requests", "tokenizers>=0.12.1",
18 | "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
19 | "shortuuid", "httpx==0.24.0",
20 | "deepspeed==0.9.5",
21 | "peft==0.4.0",
22 | "transformers==4.31.0",
23 | "accelerate==0.21.0",
24 | "bitsandbytes==0.41.0",
25 | "scikit-learn==1.2.2",
26 | "sentencepiece==0.1.99",
27 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
28 | "gradio_client==0.2.9"
29 | ]
30 |
31 | [project.urls]
32 | "Homepage" = "https://github.com/techmn/cdchat"
33 |
34 | [tool.setuptools.packages.find]
35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
36 |
37 | [tool.wheel]
38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
39 |
--------------------------------------------------------------------------------
/scripts/extract_mm_projector.py:
--------------------------------------------------------------------------------
1 | """
2 | This is just a utility that I use to extract the projector for quantized models.
3 | It is NOT necessary for training or for running inference/serving demos.
4 | Use this script ONLY if you fully understand its implications.
5 | """
6 |
7 |
8 | import os
9 | import argparse
10 | import torch
11 | import json
12 | from collections import defaultdict
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights')
17 | parser.add_argument('--model-path', type=str, help='model folder')
18 | parser.add_argument('--output', type=str, help='output file')
19 | args = parser.parse_args()
20 | return args
21 |
22 |
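# Usage (hypothetical paths; adjust to your checkpoint layout):
#   python scripts/extract_mm_projector.py \
#       --model-path cdchat/checkpoints/pretrain_mm_projector \
#       --output cdchat/checkpoints/pretrain_mm_projector/mm_projector.bin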
23 | if __name__ == '__main__':
24 | args = parse_args()
25 |
26 | keys_to_match = ['mm_projector']
27 | ckpt_to_key = defaultdict(list)
28 | try:
29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
30 | for k, v in model_indices['weight_map'].items():
31 | if any(key_match in k for key_match in keys_to_match):
32 | ckpt_to_key[v].append(k)
33 | except FileNotFoundError:
34 | # Smaller models or model checkpoints saved by DeepSpeed.
35 | v = 'pytorch_model.bin'
36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
37 | if any(key_match in k for key_match in keys_to_match):
38 | ckpt_to_key[v].append(k)
39 |
40 | loaded_weights = {}
41 |
42 | for ckpt_name, weight_keys in ckpt_to_key.items():
43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
44 | for k in weight_keys:
45 | loaded_weights[k] = ckpt[k]
46 |
47 | torch.save(loaded_weights, args.output)
48 |
--------------------------------------------------------------------------------
/scripts/finetune_lora.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ################## VICUNA ##################
4 | PROMPT_VERSION=v1
5 | MODEL_VERSION="vicuna-v1.5-7b"
6 | ################## VICUNA ##################
7 |
8 | deepspeed --master_port=$((RANDOM + 10000)) --include localhost:0,1,2,3 ./cdchat/train/train_mem.py \
9 | --deepspeed ./scripts/zero2.json \
10 | --lora_enable True \
11 | --model_name_or_path ./llava-v1.5-7b \
12 | --version $PROMPT_VERSION \
13 | --data_path dataset/cdchat_instruct_file.json \
14 | --image_folder dataset/ \
15 | --vision_tower openai/clip-vit-large-patch14-336 \
16 | --mm_projector_type mlp2x_gelu \
17 | --pretrain_mm_mlp_adapter cdchat/checkpoints/pretrain_mm_projector/mm_projector.bin \
18 | --mm_vision_select_layer -2 \
19 | --mm_use_im_start_end False \
20 | --mm_use_im_patch_token False \
21 | --image_aspect_ratio pad \
22 | --bf16 True \
23 | --output_dir cdchat/checkpoints/cdchat_log_lora \
24 | --num_train_epochs 1 \
25 | --per_device_train_batch_size 16 \
26 | --per_device_eval_batch_size 4 \
27 | --gradient_accumulation_steps 1 \
28 | --evaluation_strategy "no" \
29 | --save_strategy "epoch" \
30 | --save_steps 7000 \
31 | --save_total_limit 1 \
32 | --learning_rate 2e-4 \
33 | --weight_decay 0. \
34 | --warmup_ratio 0.03 \
35 | --lr_scheduler_type "cosine" \
36 | --logging_steps 1 \
37 | --tf32 True \
38 | --model_max_length 2048 \
39 | --gradient_checkpointing True \
40 | --lazy_preprocess True \
41 | --dataloader_num_workers 16 \
42 | --report_to wandb
43 |
--------------------------------------------------------------------------------
/scripts/merge_lora_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from cdchat.model.builder import load_pretrained_model
3 | from cdchat.mm_utils import get_model_name_from_path
4 |
5 |
6 | def merge_lora(args):
7 | model_name = get_model_name_from_path(args.model_path)
8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
9 |
10 | model.save_pretrained(args.save_model_path)
11 | tokenizer.save_pretrained(args.save_model_path)
12 |
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model-path", type=str, required=True)
17 | parser.add_argument("--model-base", type=str, required=True)
18 | parser.add_argument("--save-model-path", type=str, required=True)
19 |
20 | args = parser.parse_args()
21 |
22 | merge_lora(args)
23 |
--------------------------------------------------------------------------------
/scripts/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Feature-alignment stage: pretrains the multimodal projector.
4 |
5 | # Uncomment and set the following variables correspondingly to run this script:
6 |
7 | # MODEL_VERSION=vicuna-v1-3-7b
8 | # MODEL_VERSION=llama-2-7b-chat
9 | #MODEL_VERSION="vicuna-v1.5-7b"
10 |
11 | ########### DO NOT CHANGE ###########
12 | ########### USE THIS FOR BOTH ###########
13 | PROMPT_VERSION=plain
14 | ########### DO NOT CHANGE ###########
15 |
16 | deepspeed --master_port=$((RANDOM + 10000)) --include localhost:0,1,2,3 ./cdchat/train/train_mem.py \
17 | --deepspeed ./scripts/zero2.json \
18 | --model_name_or_path ./llava-v1.5-7b \
19 | --version $PROMPT_VERSION \
20 | --data_path dataset/cdchat_instruct_file.json \
21 | --image_folder dataset/ \
22 | --vision_tower openai/clip-vit-large-patch14-336 \
23 | --mm_projector_type mlp2x_gelu \
24 | --tune_mm_mlp_adapter True \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --image_aspect_ratio pad \
29 | --bf16 True \
30 | --output_dir cdchat/checkpoints/pretrain_mm_projector \
31 | --num_train_epochs 1 \
32 | --per_device_train_batch_size 16 \
33 | --per_device_eval_batch_size 4 \
34 | --gradient_accumulation_steps 1 \
35 | --evaluation_strategy "no" \
36 | --save_strategy "epoch" \
37 | --save_steps 24000 \
38 | --save_total_limit 1 \
39 | --learning_rate 2e-3 \
40 | --weight_decay 0. \
41 | --warmup_ratio 0.03 \
42 | --lr_scheduler_type "cosine" \
43 | --logging_steps 1 \
44 | --tf32 True \
45 | --model_max_length 2048 \
46 | --gradient_checkpointing True \
47 | --dataloader_num_workers 4 \
48 | --lazy_preprocess True \
49 | --report_to wandb
50 |
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------