├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── chat.py ├── imgs ├── blackpink.jpg ├── camera_lens.jpg ├── car_speed.jpg ├── dog_with_horn.jpg ├── example1.jpg ├── example2.jpg ├── fig_overview.jpg ├── jackma.jpg ├── obama.jpg ├── stand_higher.jpg ├── table1.jpg ├── teaser.jpg ├── trump.jpg └── wash_hands.jpg ├── merge_lora_weights_and_save_hf_model.py ├── model ├── LISA.py ├── llava │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── mm_utils.py │ ├── model │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model │ │ │ ├── llava_llama.py │ │ │ ├── llava_mpt.py │ │ │ └── mpt │ │ │ │ ├── adapt_tokenizer.py │ │ │ │ ├── attention.py │ │ │ │ ├── blocks.py │ │ │ │ ├── configuration_mpt.py │ │ │ │ ├── custom_embedding.py │ │ │ │ ├── flash_attn_triton.py │ │ │ │ ├── hf_prefixlm_converter.py │ │ │ │ ├── meta_init_context.py │ │ │ │ ├── modeling_mpt.py │ │ │ │ ├── norm.py │ │ │ │ └── param_init_fns.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ └── utils.py │ ├── train │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── train.py │ │ └── train_mem.py │ └── utils.py └── segment_anything │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py │ ├── predictor.py │ └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── requirements.txt ├── train_ds.py ├── utils ├── ade20k_classes.json ├── cocostuff_classes.txt ├── conversation.py ├── data_processing.py ├── dataset.py ├── grefcoco.py ├── grefer.py ├── reason_seg_dataset.py ├── refer.py ├── refer_seg_dataset.py ├── sem_seg_dataset.py ├── utils.py └── vqa_dataset.py └── vis_output ├── blackpink.jpg ├── camera_lens.jpg ├── dog_with_horn.jpg ├── example1_mask_0.jpg ├── example1_masked_img_0.jpg ├── example2_mask_0.jpg ├── example2_masked_img_0.jpg ├── jackma.jpg ├── obama.jpg ├── stand_higher.jpg ├── trump.jpg └── wash_hands.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | runs/ 3 | .vscode/ 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor 10 | 11 | from model.LISA import LISAForCausalLM 12 | from model.llava import conversation as conversation_lib 13 | from model.llava.mm_utils import tokenizer_image_token 14 | from model.segment_anything.utils.transforms import ResizeLongestSide 15 | from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, 16 | DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX) 17 | 18 | 19 | def parse_args(args): 20 | parser = argparse.ArgumentParser(description="LISA chat") 21 | parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1") 22 | parser.add_argument("--vis_save_path", default="./vis_output", type=str) 23 | parser.add_argument( 24 | "--precision", 25 | default="bf16", 26 | type=str, 27 | choices=["fp32", "bf16", "fp16"], 28 | help="precision for inference", 29 | ) 30 | parser.add_argument("--image_size", default=1024, type=int, help="image size") 31 | parser.add_argument("--model_max_length", default=512, type=int) 32 | parser.add_argument("--lora_r", default=8, type=int) 33 | parser.add_argument( 34 | "--vision-tower", default="openai/clip-vit-large-patch14", type=str 35 | ) 36 | parser.add_argument("--local-rank", default=0, type=int, help="node rank") 37 | parser.add_argument("--load_in_8bit", action="store_true", default=False) 38 | parser.add_argument("--load_in_4bit", action="store_true", default=False) 39 | parser.add_argument("--use_mm_start_end", action="store_true", default=True) 40 | parser.add_argument( 41 | "--conv_type", 42 | default="llava_v1", 43 | type=str, 44 | choices=["llava_v1", "llava_llama_2"], 45 | ) 46 | return parser.parse_args(args) 47 | 48 | 49 | def preprocess( 50 | x, 51 | pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), 52 | pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), 53 | img_size=1024, 54 | ) -> torch.Tensor: 55 | """Normalize pixel values and pad to a square input.""" 56 | # Normalize colors 57 | x = (x - pixel_mean) / pixel_std 58 | # Pad 59 | h, w = x.shape[-2:] 60 | padh = img_size - h 61 | padw = img_size - w 62 | x = F.pad(x, (0, padw, 0, padh)) 63 | return x 64 | 65 | 66 | def main(args): 67 | args = parse_args(args) 68 | os.makedirs(args.vis_save_path, exist_ok=True) 69 | 70 | # Create model 71 | tokenizer = AutoTokenizer.from_pretrained( 72 | args.version, 73 | cache_dir=None, 74 | model_max_length=args.model_max_length, 75 | padding_side="right", 76 | use_fast=False, 77 | ) 78 | tokenizer.pad_token = tokenizer.unk_token 79 | args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] 80 | 81 | 82 | torch_dtype = torch.float32 83 | if args.precision == "bf16": 84 | torch_dtype = torch.bfloat16 85 | elif args.precision == "fp16": 86 | torch_dtype = torch.half 87 | 88 | kwargs = {"torch_dtype": torch_dtype} 89 | if args.load_in_4bit: 90 | kwargs.update( 91 | { 92 | "torch_dtype": torch.half, 93 | "load_in_4bit": True, 94 | "quantization_config": BitsAndBytesConfig( 95 | load_in_4bit=True, 96 | bnb_4bit_compute_dtype=torch.float16, 97 | bnb_4bit_use_double_quant=True, 98 | bnb_4bit_quant_type="nf4", 99 | llm_int8_skip_modules=["visual_model"], 100 | ), 101 | } 102 | ) 103 | elif 
args.load_in_8bit: 104 | kwargs.update( 105 | { 106 | "torch_dtype": torch.half, 107 | "quantization_config": BitsAndBytesConfig( 108 | llm_int8_skip_modules=["visual_model"], 109 | load_in_8bit=True, 110 | ), 111 | } 112 | ) 113 | 114 | model = LISAForCausalLM.from_pretrained( 115 | args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs 116 | ) 117 | 118 | model.config.eos_token_id = tokenizer.eos_token_id 119 | model.config.bos_token_id = tokenizer.bos_token_id 120 | model.config.pad_token_id = tokenizer.pad_token_id 121 | 122 | model.get_model().initialize_vision_modules(model.get_model().config) 123 | vision_tower = model.get_model().get_vision_tower() 124 | vision_tower.to(dtype=torch_dtype) 125 | 126 | if args.precision == "bf16": 127 | model = model.bfloat16().cuda() 128 | elif ( 129 | args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit) 130 | ): 131 | vision_tower = model.get_model().get_vision_tower() 132 | model.model.vision_tower = None 133 | import deepspeed 134 | 135 | model_engine = deepspeed.init_inference( 136 | model=model, 137 | dtype=torch.half, 138 | replace_with_kernel_inject=True, 139 | replace_method="auto", 140 | ) 141 | model = model_engine.module 142 | model.model.vision_tower = vision_tower.half().cuda() 143 | elif args.precision == "fp32": 144 | model = model.float().cuda() 145 | 146 | vision_tower = model.get_model().get_vision_tower() 147 | vision_tower.to(device=args.local_rank) 148 | 149 | clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower) 150 | transform = ResizeLongestSide(args.image_size) 151 | 152 | model.eval() 153 | 154 | while True: 155 | conv = conversation_lib.conv_templates[args.conv_type].copy() 156 | conv.messages = [] 157 | 158 | prompt = input("Please input your prompt: ") 159 | prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt 160 | if args.use_mm_start_end: 161 | replace_token = ( 162 | DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 163 | ) 164 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 165 | 166 | conv.append_message(conv.roles[0], prompt) 167 | conv.append_message(conv.roles[1], "") 168 | prompt = conv.get_prompt() 169 | 170 | image_path = input("Please input the image path: ") 171 | if not os.path.exists(image_path): 172 | print("File not found in {}".format(image_path)) 173 | continue 174 | 175 | image_np = cv2.imread(image_path) 176 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 177 | original_size_list = [image_np.shape[:2]] 178 | 179 | image_clip = ( 180 | clip_image_processor.preprocess(image_np, return_tensors="pt")[ 181 | "pixel_values" 182 | ][0] 183 | .unsqueeze(0) 184 | .cuda() 185 | ) 186 | if args.precision == "bf16": 187 | image_clip = image_clip.bfloat16() 188 | elif args.precision == "fp16": 189 | image_clip = image_clip.half() 190 | else: 191 | image_clip = image_clip.float() 192 | 193 | image = transform.apply_image(image_np) 194 | resize_list = [image.shape[:2]] 195 | 196 | image = ( 197 | preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous()) 198 | .unsqueeze(0) 199 | .cuda() 200 | ) 201 | if args.precision == "bf16": 202 | image = image.bfloat16() 203 | elif args.precision == "fp16": 204 | image = image.half() 205 | else: 206 | image = image.float() 207 | 208 | input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt") 209 | input_ids = input_ids.unsqueeze(0).cuda() 210 | 211 | output_ids, pred_masks = model.evaluate( 212 | 
image_clip, 213 | image, 214 | input_ids, 215 | resize_list, 216 | original_size_list, 217 | max_new_tokens=512, 218 | tokenizer=tokenizer, 219 | ) 220 | output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX] 221 | 222 | text_output = tokenizer.decode(output_ids, skip_special_tokens=False) 223 | text_output = text_output.replace("\n", "").replace(" ", " ") 224 | print("text_output: ", text_output) 225 | 226 | for i, pred_mask in enumerate(pred_masks): 227 | if pred_mask.shape[0] == 0: 228 | continue 229 | 230 | pred_mask = pred_mask.detach().cpu().numpy()[0] 231 | pred_mask = pred_mask > 0 232 | 233 | save_path = "{}/{}_mask_{}.jpg".format( 234 | args.vis_save_path, image_path.split("/")[-1].split(".")[0], i 235 | ) 236 | cv2.imwrite(save_path, pred_mask * 100) 237 | print("{} has been saved.".format(save_path)) 238 | 239 | save_path = "{}/{}_masked_img_{}.jpg".format( 240 | args.vis_save_path, image_path.split("/")[-1].split(".")[0], i 241 | ) 242 | save_img = image_np.copy() 243 | save_img[pred_mask] = ( 244 | image_np * 0.5 245 | + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5 246 | )[pred_mask] 247 | save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR) 248 | cv2.imwrite(save_path, save_img) 249 | print("{} has been saved.".format(save_path)) 250 | 251 | 252 | if __name__ == "__main__": 253 | main(sys.argv[1:]) 254 | -------------------------------------------------------------------------------- /imgs/blackpink.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/blackpink.jpg -------------------------------------------------------------------------------- /imgs/camera_lens.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/camera_lens.jpg -------------------------------------------------------------------------------- /imgs/car_speed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/car_speed.jpg -------------------------------------------------------------------------------- /imgs/dog_with_horn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/dog_with_horn.jpg -------------------------------------------------------------------------------- /imgs/example1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/example1.jpg -------------------------------------------------------------------------------- /imgs/example2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/example2.jpg -------------------------------------------------------------------------------- /imgs/fig_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/fig_overview.jpg -------------------------------------------------------------------------------- 
/imgs/jackma.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/jackma.jpg -------------------------------------------------------------------------------- /imgs/obama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/obama.jpg -------------------------------------------------------------------------------- /imgs/stand_higher.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/stand_higher.jpg -------------------------------------------------------------------------------- /imgs/table1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/table1.jpg -------------------------------------------------------------------------------- /imgs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/teaser.jpg -------------------------------------------------------------------------------- /imgs/trump.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/trump.jpg -------------------------------------------------------------------------------- /imgs/wash_hands.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/imgs/wash_hands.jpg -------------------------------------------------------------------------------- /merge_lora_weights_and_save_hf_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import sys 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import transformers 11 | from peft import LoraConfig, get_peft_model 12 | from transformers import AutoTokenizer 13 | 14 | from model.LISA import LISAForCausalLM 15 | from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN 16 | 17 | 18 | def parse_args(args): 19 | parser = argparse.ArgumentParser( 20 | description="merge lora weights and save model with hf format" 21 | ) 22 | parser.add_argument( 23 | "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview" 24 | ) 25 | parser.add_argument("--vis_save_path", default="./vis_output", type=str) 26 | parser.add_argument( 27 | "--precision", 28 | default="bf16", 29 | type=str, 30 | choices=["fp32", "bf16", "fp16"], 31 | help="precision for inference", 32 | ) 33 | parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str) 34 | parser.add_argument("--out_dim", default=256, type=int) 35 | parser.add_argument("--image_size", default=1024, type=int, help="image size") 36 | parser.add_argument("--model_max_length", default=512, type=int) 37 | parser.add_argument( 38 | "--vision-tower", default="openai/clip-vit-large-patch14", type=str 39 | ) 40 | parser.add_argument("--lora_r", default=8, type=int) 41 | 
parser.add_argument("--lora_alpha", default=16, type=int) 42 | parser.add_argument("--lora_dropout", default=0.05, type=float) 43 | parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str) 44 | parser.add_argument("--local-rank", default=0, type=int, help="node rank") 45 | parser.add_argument("--train_mask_decoder", action="store_true", default=True) 46 | parser.add_argument("--use_mm_start_end", action="store_true", default=True) 47 | parser.add_argument( 48 | "--conv_type", 49 | default="llava_v1", 50 | type=str, 51 | choices=["llava_v1", "llava_llama_2"], 52 | ) 53 | parser.add_argument("--weight", default="", type=str, required=True) 54 | parser.add_argument("--save_path", default="./lisa_model", type=str, required=True) 55 | return parser.parse_args(args) 56 | 57 | 58 | def main(args): 59 | args = parse_args(args) 60 | os.makedirs(args.vis_save_path, exist_ok=True) 61 | 62 | # Create model 63 | tokenizer = transformers.AutoTokenizer.from_pretrained( 64 | args.version, 65 | cache_dir=None, 66 | model_max_length=args.model_max_length, 67 | padding_side="right", 68 | use_fast=False, 69 | ) 70 | tokenizer.pad_token = tokenizer.unk_token 71 | num_added_tokens = tokenizer.add_tokens("[SEG]") 72 | args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] 73 | 74 | if args.use_mm_start_end: 75 | tokenizer.add_tokens( 76 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True 77 | ) 78 | 79 | model_args = { 80 | "train_mask_decoder": args.train_mask_decoder, 81 | "out_dim": args.out_dim, 82 | "seg_token_idx": args.seg_token_idx, 83 | "vision_tower": args.vision_tower, 84 | } 85 | 86 | torch_dtype = torch.float32 87 | if args.precision == "bf16": 88 | torch_dtype = torch.bfloat16 89 | elif args.precision == "fp16": 90 | torch_dtype = torch.half 91 | model = LISAForCausalLM.from_pretrained( 92 | args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args 93 | ) 94 | model.config.eos_token_id = tokenizer.eos_token_id 95 | model.config.bos_token_id = tokenizer.bos_token_id 96 | model.config.pad_token_id = tokenizer.pad_token_id 97 | 98 | model.get_model().initialize_vision_modules(model.get_model().config) 99 | vision_tower = model.get_model().get_vision_tower() 100 | vision_tower.to(dtype=torch_dtype) 101 | model.get_model().initialize_lisa_modules(model.get_model().config) 102 | 103 | lora_r = args.lora_r 104 | if lora_r > 0: 105 | 106 | def find_linear_layers(model, lora_target_modules): 107 | cls = torch.nn.Linear 108 | lora_module_names = set() 109 | for name, module in model.named_modules(): 110 | if ( 111 | isinstance(module, cls) 112 | and all( 113 | [ 114 | x not in name 115 | for x in [ 116 | "visual_model", 117 | "vision_tower", 118 | "mm_projector", 119 | "text_hidden_fcs", 120 | ] 121 | ] 122 | ) 123 | and any([x in name for x in lora_target_modules]) 124 | ): 125 | lora_module_names.add(name) 126 | return sorted(list(lora_module_names)) 127 | 128 | lora_alpha = args.lora_alpha 129 | lora_dropout = args.lora_dropout 130 | lora_target_modules = find_linear_layers( 131 | model, args.lora_target_modules.split(",") 132 | ) 133 | lora_config = LoraConfig( 134 | r=lora_r, 135 | lora_alpha=lora_alpha, 136 | target_modules=lora_target_modules, 137 | lora_dropout=lora_dropout, 138 | bias="none", 139 | task_type="CAUSAL_LM", 140 | ) 141 | model = get_peft_model(model, lora_config) 142 | model.print_trainable_parameters() 143 | 144 | model.resize_token_embeddings(len(tokenizer)) 145 | 146 | state_dict = 
torch.load(args.weight, map_location="cpu")
147 |     model.load_state_dict(state_dict, strict=True)
148 | 
149 |     model = model.merge_and_unload()
150 |     state_dict = {}
151 |     for k, v in model.state_dict().items():
152 |         if "vision_tower" not in k:
153 |             state_dict[k] = v
154 |     model.save_pretrained(args.save_path, state_dict=state_dict)
155 |     tokenizer.save_pretrained(args.save_path)
156 | 
157 | 
158 | if __name__ == "__main__":
159 |     main(sys.argv[1:])
160 | 
--------------------------------------------------------------------------------
/model/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 | 
--------------------------------------------------------------------------------
/model/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 | 
4 | LOGDIR = "."
5 | 
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | 
--------------------------------------------------------------------------------
/model/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from io import BytesIO
3 | 
4 | import torch
5 | from PIL import Image
6 | from transformers import StoppingCriteria
7 | 
8 | from .constants import IMAGE_TOKEN_INDEX
9 | 
10 | 
11 | def load_image_from_base64(image):
12 |     return Image.open(BytesIO(base64.b64decode(image)))
13 | 
14 | 
15 | def process_images(images, image_processor, model_cfg):
16 |     return image_processor(images, return_tensors="pt")["pixel_values"]
17 | 
18 | 
19 | def tokenizer_image_token(
20 |     prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
21 | ):
22 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
23 | 
24 |     def insert_separator(X, sep):
25 |         return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
26 | 
27 |     input_ids = []
28 |     offset = 0
29 |     if (
30 |         len(prompt_chunks) > 0
31 |         and len(prompt_chunks[0]) > 0
32 |         and prompt_chunks[0][0] == tokenizer.bos_token_id
33 |     ):
34 |         offset = 1
35 |         input_ids.append(prompt_chunks[0][0])
36 | 
37 |     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
38 |         input_ids.extend(x[offset:])
39 | 
40 |     if return_tensors is not None:
41 |         if return_tensors == "pt":
42 |             return torch.tensor(input_ids, dtype=torch.long)
43 |         raise ValueError(f"Unsupported tensor type: {return_tensors}")
44 |     return input_ids
45 | 
46 | 
47 | def get_model_name_from_path(model_path):
48 |     model_path = model_path.strip("/")
49 |     model_paths = model_path.split("/")
50 |     if model_paths[-1].startswith("checkpoint-"):
51 |         return model_paths[-2] + "_" + model_paths[-1]
52 |     else:
53 |         return model_paths[-1]
54 | 
55 | 
56 | class KeywordsStoppingCriteria(StoppingCriteria):
57 |     def __init__(self, keywords, tokenizer, input_ids):
58 |         self.keywords = keywords
59 |         self.keyword_ids = []
60 |         for keyword in keywords:
61 |             cur_keyword_ids = tokenizer(keyword).input_ids
62 |             if (
63 |                 len(cur_keyword_ids) > 1
64 |                 and cur_keyword_ids[0] == tokenizer.bos_token_id
65 |             ):
66 |                 cur_keyword_ids = cur_keyword_ids[1:]
67 |             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
68 |         self.tokenizer = tokenizer
69 |         self.start_len = input_ids.shape[1]
70 | 
71
| def __call__( 72 | self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 73 | ) -> bool: 74 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 75 | offset = min(output_ids.shape[1] - self.start_len, 3) 76 | self.keyword_ids = [ 77 | keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids 78 | ] 79 | for keyword_id in self.keyword_ids: 80 | if output_ids[0, -keyword_id.shape[0] :] == keyword_id: 81 | return True 82 | outputs = self.tokenizer.batch_decode( 83 | output_ids[:, -offset:], skip_special_tokens=True 84 | )[0] 85 | for keyword in self.keywords: 86 | if keyword in outputs: 87 | return True 88 | return False 89 | -------------------------------------------------------------------------------- /model/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM 2 | from .language_model.llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM 3 | -------------------------------------------------------------------------------- /model/llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava import LlavaLlamaForCausalLM 9 | from tqdm import tqdm 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 17 | ) 18 | 19 | print("Loading delta") 20 | delta = LlavaLlamaForCausalLM.from_pretrained( 21 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 22 | ) 23 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 24 | 25 | print("Applying delta") 26 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 27 | if name not in base.state_dict(): 28 | assert name in [ 29 | "model.mm_projector.weight", 30 | "model.mm_projector.bias", 31 | ], f"{name} not in base model" 32 | continue 33 | if param.data.shape == base.state_dict()[name].shape: 34 | param.data += base.state_dict()[name] 35 | else: 36 | assert name in [ 37 | "model.embed_tokens.weight", 38 | "lm_head.weight", 39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 40 | bparam = base.state_dict()[name] 41 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 42 | 43 | print("Saving target model") 44 | delta.save_pretrained(target_model_path) 45 | delta_tokenizer.save_pretrained(target_model_path) 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--base-model-path", type=str, required=True) 51 | parser.add_argument("--target-model-path", type=str, required=True) 52 | parser.add_argument("--delta-path", type=str, required=True) 53 | 54 | args = parser.parse_args() 55 | 56 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 57 | -------------------------------------------------------------------------------- /model/llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache 
License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import shutil 18 | 19 | import torch 20 | from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, 21 | DEFAULT_IMAGE_PATCH_TOKEN) 22 | from llava.model import * 23 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, 24 | BitsAndBytesConfig) 25 | 26 | 27 | def load_pretrained_model( 28 | model_path, 29 | model_base, 30 | model_name, 31 | load_8bit=False, 32 | load_4bit=False, 33 | device_map="auto", 34 | ): 35 | kwargs = {"device_map": device_map} 36 | 37 | if load_8bit: 38 | kwargs["load_in_8bit"] = True 39 | elif load_4bit: 40 | kwargs["load_in_4bit"] = True 41 | kwargs["quantization_config"] = BitsAndBytesConfig( 42 | load_in_4bit=True, 43 | bnb_4bit_compute_dtype=torch.float16, 44 | bnb_4bit_use_double_quant=True, 45 | bnb_4bit_quant_type="nf4", 46 | ) 47 | else: 48 | kwargs["torch_dtype"] = torch.float16 49 | 50 | if "llava" in model_name.lower(): 51 | # Load LLaVA model 52 | if "lora" in model_name.lower() and model_base is not None: 53 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 54 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 55 | print("Loading LLaVA from base model...") 56 | model = LlavaLlamaForCausalLM.from_pretrained( 57 | model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs 58 | ) 59 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 60 | if model.lm_head.weight.shape[0] != token_num: 61 | model.lm_head.weight = torch.nn.Parameter( 62 | torch.empty( 63 | token_num, tokem_dim, device=model.device, dtype=model.dtype 64 | ) 65 | ) 66 | model.model.embed_tokens.weight = torch.nn.Parameter( 67 | torch.empty( 68 | token_num, tokem_dim, device=model.device, dtype=model.dtype 69 | ) 70 | ) 71 | 72 | print("Loading additional LLaVA weights...") 73 | if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): 74 | non_lora_trainables = torch.load( 75 | os.path.join(model_path, "non_lora_trainables.bin"), 76 | map_location="cpu", 77 | ) 78 | else: 79 | # this is probably from HF Hub 80 | from huggingface_hub import hf_hub_download 81 | 82 | def load_from_hf(repo_id, filename, subfolder=None): 83 | cache_file = hf_hub_download( 84 | repo_id=repo_id, filename=filename, subfolder=subfolder 85 | ) 86 | return torch.load(cache_file, map_location="cpu") 87 | 88 | non_lora_trainables = load_from_hf( 89 | model_path, "non_lora_trainables.bin" 90 | ) 91 | non_lora_trainables = { 92 | (k[11:] if k.startswith("base_model.") else k): v 93 | for k, v in non_lora_trainables.items() 94 | } 95 | if any(k.startswith("model.model.") for k in non_lora_trainables): 96 | non_lora_trainables = { 97 | (k[6:] if k.startswith("model.") else k): v 98 | for k, v in non_lora_trainables.items() 99 | } 100 | model.load_state_dict(non_lora_trainables, strict=False) 101 | 102 | from peft import PeftModel 103 | 104 | print("Loading LoRA weights...") 105 | model = PeftModel.from_pretrained(model, 
model_path) 106 | print("Merging LoRA weights...") 107 | model = model.merge_and_unload() 108 | print("Model is loaded...") 109 | elif model_base is not None: 110 | # this may be mm projector only 111 | print("Loading LLaVA from base model...") 112 | if "mpt" in model_name.lower(): 113 | if not os.path.isfile(os.path.join(model_path, "configuration_mpt.py")): 114 | shutil.copyfile( 115 | os.path.join(model_base, "configuration_mpt.py"), 116 | os.path.join(model_path, "configuration_mpt.py"), 117 | ) 118 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 119 | cfg_pretrained = AutoConfig.from_pretrained( 120 | model_path, trust_remote_code=True 121 | ) 122 | model = LlavaMPTForCausalLM.from_pretrained( 123 | model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs 124 | ) 125 | else: 126 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 127 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 128 | model = LlavaLlamaForCausalLM.from_pretrained( 129 | model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs 130 | ) 131 | 132 | mm_projector_weights = torch.load( 133 | os.path.join(model_path, "mm_projector.bin"), map_location="cpu" 134 | ) 135 | mm_projector_weights = { 136 | k: v.to(torch.float16) for k, v in mm_projector_weights.items() 137 | } 138 | model.load_state_dict(mm_projector_weights, strict=False) 139 | else: 140 | if "mpt" in model_name.lower(): 141 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 142 | model = LlavaMPTForCausalLM.from_pretrained( 143 | model_path, low_cpu_mem_usage=True, **kwargs 144 | ) 145 | else: 146 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 147 | model = LlavaLlamaForCausalLM.from_pretrained( 148 | model_path, low_cpu_mem_usage=True, **kwargs 149 | ) 150 | else: 151 | # Load language model 152 | if model_base is not None: 153 | # PEFT model 154 | from peft import PeftModel 155 | 156 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 157 | model = AutoModelForCausalLM.from_pretrained( 158 | model_base, 159 | torch_dtype=torch.float16, 160 | low_cpu_mem_usage=True, 161 | device_map="auto", 162 | ) 163 | print(f"Loading LoRA weights from {model_path}") 164 | model = PeftModel.from_pretrained(model, model_path) 165 | print(f"Merging weights") 166 | model = model.merge_and_unload() 167 | print("Convert to FP16...") 168 | model.to(torch.float16) 169 | else: 170 | use_fast = False 171 | if "mpt" in model_name.lower(): 172 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 173 | model = AutoModelForCausalLM.from_pretrained( 174 | model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs 175 | ) 176 | else: 177 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 178 | model = AutoModelForCausalLM.from_pretrained( 179 | model_path, low_cpu_mem_usage=True, **kwargs 180 | ) 181 | 182 | image_processor = None 183 | 184 | if "llava" in model_name.lower(): 185 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 186 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 187 | if mm_use_im_patch_token: 188 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 189 | if mm_use_im_start_end: 190 | tokenizer.add_tokens( 191 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True 192 | ) 193 | model.resize_token_embeddings(len(tokenizer)) 194 | 195 | vision_tower = model.get_vision_tower() 196 | if not 
vision_tower.is_loaded: 197 | vision_tower.load_model() 198 | vision_tower.to(device="cuda", dtype=torch.float16) 199 | image_processor = vision_tower.image_processor 200 | 201 | if hasattr(model.config, "max_sequence_length"): 202 | context_len = model.config.max_sequence_length 203 | else: 204 | context_len = 2048 205 | 206 | return tokenizer, model, image_processor, context_len 207 | -------------------------------------------------------------------------------- /model/llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava.model import * 9 | from llava.model.utils import auto_upgrade 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained( 17 | src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 18 | ) 19 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 20 | src_model.save_pretrained(dst_path) 21 | src_tokenizer.save_pretrained(dst_path) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--src", type=str, required=True) 27 | parser.add_argument("--dst", type=str, required=True) 28 | 29 | args = parser.parse_args() 30 | 31 | consolidate_ckpt(args.src, args.dst) 32 | -------------------------------------------------------------------------------- /model/llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | from transformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig, 22 | LlamaForCausalLM, LlamaModel) 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | 25 | from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel 26 | 27 | 28 | class LlavaConfig(LlamaConfig): 29 | model_type = "llava" 30 | 31 | 32 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 33 | config_class = LlavaConfig 34 | 35 | def __init__(self, config: LlamaConfig): 36 | super(LlavaLlamaModel, self).__init__(config) 37 | 38 | 39 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 40 | config_class = LlavaConfig 41 | 42 | def __init__(self, config): 43 | super(LlamaForCausalLM, self).__init__(config) 44 | 45 | self.model = LlavaLlamaModel(config) 46 | 47 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 48 | 49 | # Initialize weights and apply final processing 50 | self.post_init() 51 | 52 | def get_model(self): 53 | return self.model 54 | 55 | def forward( 56 | self, 57 | input_ids: torch.LongTensor = None, 58 | attention_mask: Optional[torch.Tensor] = None, 59 | past_key_values: Optional[List[torch.FloatTensor]] = None, 60 | inputs_embeds: Optional[torch.FloatTensor] = None, 61 | labels: Optional[torch.LongTensor] = None, 62 | use_cache: Optional[bool] = None, 63 | output_attentions: Optional[bool] = None, 64 | output_hidden_states: Optional[bool] = None, 65 | images: Optional[torch.FloatTensor] = None, 66 | return_dict: Optional[bool] = None, 67 | ) -> Union[Tuple, CausalLMOutputWithPast]: 68 | output_attentions = ( 69 | output_attentions 70 | if output_attentions is not None 71 | else self.config.output_attentions 72 | ) 73 | output_hidden_states = ( 74 | output_hidden_states 75 | if output_hidden_states is not None 76 | else self.config.output_hidden_states 77 | ) 78 | return_dict = ( 79 | return_dict if return_dict is not None else self.config.use_return_dict 80 | ) 81 | 82 | ( 83 | input_ids, 84 | attention_mask, 85 | past_key_values, 86 | inputs_embeds, 87 | labels, 88 | ) = self.prepare_inputs_labels_for_multimodal( 89 | input_ids, attention_mask, past_key_values, labels, images 90 | ) 91 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 92 | 93 | outputs = self.model( 94 | input_ids=input_ids, 95 | attention_mask=attention_mask, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict, 102 | ) 103 | 104 | hidden_states = outputs[0] 105 | logits = self.lm_head(hidden_states) 106 | 107 | loss = None 108 | if labels is not None: 109 | # Shift so that tokens < n predict n 110 | shift_logits = logits[..., :-1, :].contiguous() 111 | shift_labels = labels[..., 1:].contiguous() 112 | # Flatten the tokens 113 | loss_fct = CrossEntropyLoss() 114 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 115 | shift_labels = shift_labels.view(-1) 116 | # Enable model/pipeline parallelism 117 | shift_labels = shift_labels.to(shift_logits.device) 118 | loss = loss_fct(shift_logits, shift_labels) 119 | 120 | if not return_dict: 121 | output = (logits,) + outputs[1:] 122 | return (loss,) + output if loss is not None else output 123 | 124 | if self.training: 125 | output_hidden_states = 
outputs.hidden_states 126 | else: 127 | output_hidden_states = hidden_states 128 | 129 | return CausalLMOutputWithPast( 130 | loss=loss, 131 | logits=logits, 132 | past_key_values=outputs.past_key_values, 133 | hidden_states=output_hidden_states, # outputs.hidden_states, 134 | attentions=outputs.attentions, 135 | ) 136 | 137 | def prepare_inputs_for_generation( 138 | self, 139 | input_ids, 140 | past_key_values=None, 141 | attention_mask=None, 142 | inputs_embeds=None, 143 | images=None, 144 | **kwargs 145 | ): 146 | if past_key_values: 147 | input_ids = input_ids[:, -1:] 148 | 149 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 150 | if inputs_embeds is not None and past_key_values is None: 151 | model_inputs = {"inputs_embeds": inputs_embeds} 152 | else: 153 | model_inputs = {"input_ids": input_ids} 154 | 155 | model_inputs.update( 156 | { 157 | "past_key_values": past_key_values, 158 | "use_cache": kwargs.get("use_cache"), 159 | "attention_mask": attention_mask, 160 | "images": images, 161 | } 162 | ) 163 | return model_inputs 164 | 165 | 166 | AutoConfig.register("llava", LlavaConfig) 167 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 168 | -------------------------------------------------------------------------------- /model/llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import math 17 | import warnings 18 | from typing import List, Optional, Tuple 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | from transformers import AutoConfig, AutoModelForCausalLM 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | 25 | from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | 28 | 29 | class LlavaMPTConfig(MPTConfig): 30 | model_type = "llava_mpt" 31 | 32 | 33 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 34 | config_class = LlavaMPTConfig 35 | 36 | def __init__(self, config: MPTConfig): 37 | config.hidden_size = config.d_model 38 | super(LlavaMPTModel, self).__init__(config) 39 | 40 | def embed_tokens(self, x): 41 | return self.wte(x) 42 | 43 | 44 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 45 | config_class = LlavaMPTConfig 46 | supports_gradient_checkpointing = True 47 | 48 | def __init__(self, config): 49 | super(MPTForCausalLM, self).__init__(config) 50 | 51 | if not config.tie_word_embeddings: 52 | raise ValueError("MPTForCausalLM only supports tied word embeddings") 53 | self.transformer = LlavaMPTModel(config) 54 | self.logit_scale = None 55 | if config.logit_scale is not None: 56 | logit_scale = config.logit_scale 57 | if isinstance(logit_scale, str): 58 | if logit_scale == "inv_sqrt_d_model": 59 | logit_scale = 1 / math.sqrt(config.d_model) 60 | else: 61 | raise ValueError( 62 | f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." 63 | ) 64 | self.logit_scale = logit_scale 65 | 66 | def get_model(self): 67 | return self.transformer 68 | 69 | def _set_gradient_checkpointing(self, module, value=False): 70 | if isinstance(module, LlavaMPTModel): 71 | module.gradient_checkpointing = value 72 | 73 | def forward( 74 | self, 75 | input_ids: torch.LongTensor, 76 | past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, 77 | attention_mask: Optional[torch.ByteTensor] = None, 78 | prefix_mask: Optional[torch.ByteTensor] = None, 79 | sequence_id: Optional[torch.LongTensor] = None, 80 | labels: Optional[torch.LongTensor] = None, 81 | return_dict: Optional[bool] = None, 82 | output_attentions: Optional[bool] = None, 83 | output_hidden_states: Optional[bool] = None, 84 | use_cache: Optional[bool] = None, 85 | images=None, 86 | ): 87 | return_dict = ( 88 | return_dict if return_dict is not None else self.config.return_dict 89 | ) 90 | use_cache = use_cache if use_cache is not None else self.config.use_cache 91 | 92 | ( 93 | input_ids, 94 | attention_mask, 95 | past_key_values, 96 | inputs_embeds, 97 | labels, 98 | ) = self.prepare_inputs_labels_for_multimodal( 99 | input_ids, attention_mask, past_key_values, labels, images 100 | ) 101 | outputs = self.transformer( 102 | input_ids=input_ids, 103 | inputs_embeds=inputs_embeds, 104 | past_key_values=past_key_values, 105 | attention_mask=attention_mask, 106 | prefix_mask=prefix_mask, 107 | sequence_id=sequence_id, 108 | return_dict=return_dict, 109 | output_attentions=output_attentions, 110 | output_hidden_states=output_hidden_states, 111 | use_cache=use_cache, 112 | ) 113 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 114 | logits = F.linear( 115 | outputs.last_hidden_state.to(self.transformer.wte.weight.device), 116 | self.transformer.wte.weight, 117 | ) 118 | if self.logit_scale is not None: 119 | if self.logit_scale == 0: 120 | 
warnings.warn( 121 | f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." 122 | ) 123 | logits *= self.logit_scale 124 | loss = None 125 | if labels is not None: 126 | labels = torch.roll(labels, shifts=-1) 127 | labels[:, -1] = -100 128 | loss = F.cross_entropy( 129 | logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) 130 | ) 131 | return CausalLMOutputWithPast( 132 | loss=loss, 133 | logits=logits, 134 | past_key_values=outputs.past_key_values, 135 | hidden_states=outputs.hidden_states, 136 | ) 137 | 138 | def prepare_inputs_for_generation( 139 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 140 | ): 141 | if inputs_embeds is not None: 142 | raise NotImplementedError("inputs_embeds is not implemented for MPT yet") 143 | attention_mask = kwargs["attention_mask"].bool() 144 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 145 | raise NotImplementedError( 146 | "MPT does not support generation with right padding." 147 | ) 148 | if self.transformer.attn_uses_sequence_id and self.training: 149 | sequence_id = torch.zeros_like(input_ids[:1]) 150 | else: 151 | sequence_id = None 152 | if past_key_values is not None: 153 | input_ids = input_ids[:, -1].unsqueeze(-1) 154 | if self.transformer.prefix_lm: 155 | prefix_mask = torch.ones_like(attention_mask) 156 | if kwargs.get("use_cache") == False: 157 | raise NotImplementedError( 158 | "MPT with prefix_lm=True does not support use_cache=False." 159 | ) 160 | else: 161 | prefix_mask = None 162 | return { 163 | "input_ids": input_ids, 164 | "attention_mask": attention_mask, 165 | "prefix_mask": prefix_mask, 166 | "sequence_id": sequence_id, 167 | "past_key_values": past_key_values, 168 | "use_cache": kwargs.get("use_cache", True), 169 | "images": kwargs.get("images", None), 170 | } 171 | 172 | 173 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 174 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 175 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from transformers import (AutoTokenizer, PreTrainedTokenizer, 4 | PreTrainedTokenizerFast) 5 | 6 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 7 | NUM_SENTINEL_TOKENS: int = 100 8 | 9 | 10 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 11 | """Adds sentinel tokens and padding token (if missing). 12 | 13 | Expands the tokenizer vocabulary to include sentinel tokens 14 | used in mixture-of-denoiser tasks as well as a padding token. 15 | 16 | All added tokens are added as special tokens. No tokens are 17 | added if sentinel tokens and padding token already exist. 18 | """ 19 | sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)] 20 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 21 | if tokenizer.pad_token is None: 22 | tokenizer.add_tokens("<pad>", special_tokens=True) 23 | tokenizer.pad_token = "<pad>" 24 | assert tokenizer.pad_token_id is not None 25 | sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]) 26 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 27 | tokenizer.sentinel_token_ids = _sentinel_token_ids 28 | 29 | 30 | class AutoTokenizerForMOD(AutoTokenizer): 31 | """AutoTokenizer + Adaptation for MOD.
32 | 33 | A simple wrapper around AutoTokenizer to make instantiating 34 | an MOD-adapted tokenizer a bit easier. 35 | 36 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 37 | a padding token, and a property to get the token ids of the 38 | sentinel tokens. 39 | """ 40 | 41 | @classmethod 42 | def from_pretrained(cls, *args, **kwargs): 43 | """See `AutoTokenizer.from_pretrained` docstring.""" 44 | tokenizer = super().from_pretrained(*args, **kwargs) 45 | adapt_tokenizer_for_denoising(tokenizer) 46 | return tokenizer 47 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .attention import ATTN_CLASS_REGISTRY 8 | from .norm import NORM_CLASS_REGISTRY 9 | 10 | 11 | class MPTMLP(nn.Module): 12 | def __init__( 13 | self, d_model: int, expansion_ratio: int, device: Optional[str] = None 14 | ): 15 | super().__init__() 16 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 17 | self.act = nn.GELU(approximate="none") 18 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 19 | self.down_proj._is_residual = True 20 | 21 | def forward(self, x): 22 | return self.down_proj(self.act(self.up_proj(x))) 23 | 24 | 25 | class MPTBlock(nn.Module): 26 | def __init__( 27 | self, 28 | d_model: int, 29 | n_heads: int, 30 | expansion_ratio: int, 31 | attn_config: Dict = { 32 | "attn_type": "multihead_attention", 33 | "attn_pdrop": 0.0, 34 | "attn_impl": "triton", 35 | "qk_ln": False, 36 | "clip_qkv": None, 37 | "softmax_scale": None, 38 | "prefix_lm": False, 39 | "attn_uses_sequence_id": False, 40 | "alibi": False, 41 | "alibi_bias_max": 8, 42 | }, 43 | resid_pdrop: float = 0.0, 44 | norm_type: str = "low_precision_layernorm", 45 | verbose: int = 0, 46 | device: Optional[str] = None, 47 | **kwargs 48 | ): 49 | del kwargs 50 | super().__init__() 51 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 52 | attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]] 53 | self.norm_1 = norm_class(d_model, device=device) 54 | self.attn = attn_class( 55 | attn_impl=attn_config["attn_impl"], 56 | clip_qkv=attn_config["clip_qkv"], 57 | qk_ln=attn_config["qk_ln"], 58 | softmax_scale=attn_config["softmax_scale"], 59 | attn_pdrop=attn_config["attn_pdrop"], 60 | d_model=d_model, 61 | n_heads=n_heads, 62 | verbose=verbose, 63 | device=device, 64 | ) 65 | self.norm_2 = norm_class(d_model, device=device) 66 | self.ffn = MPTMLP( 67 | d_model=d_model, expansion_ratio=expansion_ratio, device=device 68 | ) 69 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 70 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 71 | 72 | def forward( 73 | self, 74 | x: torch.Tensor, 75 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 76 | attn_bias: Optional[torch.Tensor] = None, 77 | attention_mask: Optional[torch.ByteTensor] = None, 78 | is_causal: bool = True, 79 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 80 | a = self.norm_1(x) 81 | (b, attn_weights, past_key_value) = self.attn( 82 | a, 83 | past_key_value=past_key_value, 84 | attn_bias=attn_bias, 85 | attention_mask=attention_mask, 86 | is_causal=is_causal, 87 | ) 88 | x = x + self.resid_attn_dropout(b) 89 | m = self.norm_2(x) 90 | n = self.ffn(m) 91 | x = x + self.resid_ffn_dropout(n) 92 | return (x,
attn_weights, past_key_value) 93 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | 4 | from transformers import PretrainedConfig 5 | 6 | attn_config_defaults: Dict = { 7 | "attn_type": "multihead_attention", 8 | "attn_pdrop": 0.0, 9 | "attn_impl": "triton", 10 | "qk_ln": False, 11 | "clip_qkv": None, 12 | "softmax_scale": None, 13 | "prefix_lm": False, 14 | "attn_uses_sequence_id": False, 15 | "alibi": False, 16 | "alibi_bias_max": 8, 17 | } 18 | init_config_defaults: Dict = { 19 | "name": "kaiming_normal_", 20 | "fan_mode": "fan_in", 21 | "init_nonlinearity": "relu", 22 | "init_div_is_residual": True, 23 | "emb_init_std": None, 24 | "emb_init_uniform_lim": None, 25 | "init_std": None, 26 | "init_gain": 0.0, 27 | } 28 | 29 | 30 | class MPTConfig(PretrainedConfig): 31 | model_type = "mpt" 32 | 33 | def __init__( 34 | self, 35 | d_model: int = 2048, 36 | n_heads: int = 16, 37 | n_layers: int = 24, 38 | expansion_ratio: int = 4, 39 | max_seq_len: int = 2048, 40 | vocab_size: int = 50368, 41 | resid_pdrop: float = 0.0, 42 | emb_pdrop: float = 0.0, 43 | learned_pos_emb: bool = True, 44 | attn_config: Dict = attn_config_defaults, 45 | init_device: str = "cpu", 46 | logit_scale: Optional[Union[float, str]] = None, 47 | no_bias: bool = False, 48 | verbose: int = 0, 49 | embedding_fraction: float = 1.0, 50 | norm_type: str = "low_precision_layernorm", 51 | use_cache: bool = False, 52 | init_config: Dict = init_config_defaults, 53 | **kwargs, 54 | ): 55 | """The MPT configuration class. 56 | 57 | Args: 58 | d_model (int): The size of the embedding dimension of the model. 59 | n_heads (int): The number of attention heads. 60 | n_layers (int): The number of layers in the model. 61 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 62 | max_seq_len (int): The maximum sequence length of the model. 63 | vocab_size (int): The size of the vocabulary. 64 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 65 | emb_pdrop (float): The dropout probability for the embedding layer. 66 | learned_pos_emb (bool): Whether to use learned positional embeddings 67 | attn_config (Dict): A dictionary used to configure the model's attention module: 68 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention 69 | attn_pdrop (float): The dropout probability for the attention layers. 70 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 71 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 72 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 73 | this value. 74 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 75 | use the default scale of ``1/sqrt(d_keys)``. 76 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 77 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 78 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 
79 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 80 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 81 | which sub-sequence each token belongs to. 82 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 83 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 84 | alibi_bias_max (int): The maximum value of the alibi bias. 85 | init_device (str): The device to use for parameter initialization. 86 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 87 | no_bias (bool): Whether to use bias in all layers. 88 | verbose (int): The verbosity level. 0 is silent. 89 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 90 | norm_type (str): choose type of norm to use 91 | multiquery_attention (bool): Whether to use multiquery attention implementation. 92 | use_cache (bool): Whether or not the model should return the last key/values attentions 93 | init_config (Dict): A dictionary used to configure the model initialization: 94 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 95 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 96 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 97 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 98 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 99 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 100 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 101 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 102 | if using the baseline_ parameter initialization scheme. 103 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 104 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 105 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
106 | --- 107 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 108 | """ 109 | self.d_model = d_model 110 | self.n_heads = n_heads 111 | self.n_layers = n_layers 112 | self.expansion_ratio = expansion_ratio 113 | self.max_seq_len = max_seq_len 114 | self.vocab_size = vocab_size 115 | self.resid_pdrop = resid_pdrop 116 | self.emb_pdrop = emb_pdrop 117 | self.learned_pos_emb = learned_pos_emb 118 | self.attn_config = attn_config 119 | self.init_device = init_device 120 | self.logit_scale = logit_scale 121 | self.no_bias = no_bias 122 | self.verbose = verbose 123 | self.embedding_fraction = embedding_fraction 124 | self.norm_type = norm_type 125 | self.use_cache = use_cache 126 | self.init_config = init_config 127 | if "name" in kwargs: 128 | del kwargs["name"] 129 | if "loss_fn" in kwargs: 130 | del kwargs["loss_fn"] 131 | super().__init__(**kwargs) 132 | self._validate_config() 133 | 134 | def _set_config_defaults(self, config, config_defaults): 135 | for k, v in config_defaults.items(): 136 | if k not in config: 137 | config[k] = v 138 | return config 139 | 140 | def _validate_config(self): 141 | self.attn_config = self._set_config_defaults( 142 | self.attn_config, attn_config_defaults 143 | ) 144 | self.init_config = self._set_config_defaults( 145 | self.init_config, init_config_defaults 146 | ) 147 | if self.d_model % self.n_heads != 0: 148 | raise ValueError("d_model must be divisible by n_heads") 149 | if any( 150 | ( 151 | prob < 0 or prob > 1 152 | for prob in [ 153 | self.attn_config["attn_pdrop"], 154 | self.resid_pdrop, 155 | self.emb_pdrop, 156 | ] 157 | ) 158 | ): 159 | raise ValueError( 160 | "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" 161 | ) 162 | if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]: 163 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 164 | if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [ 165 | "torch", 166 | "triton", 167 | ]: 168 | raise NotImplementedError( 169 | "prefix_lm only implemented with torch and triton attention." 170 | ) 171 | if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [ 172 | "torch", 173 | "triton", 174 | ]: 175 | raise NotImplementedError( 176 | "alibi only implemented with torch and triton attention." 177 | ) 178 | if self.attn_config["attn_uses_sequence_id"] and self.attn_config[ 179 | "attn_impl" 180 | ] not in ["torch", "triton"]: 181 | raise NotImplementedError( 182 | "attn_uses_sequence_id only implemented with torch and triton attention." 183 | ) 184 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 185 | raise ValueError( 186 | "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!" 187 | ) 188 | if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model": 189 | raise ValueError( 190 | f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." 191 | ) 192 | if self.init_config.get("name", None) is None: 193 | raise ValueError( 194 | f"self.init_config={self.init_config!r} 'name' needs to be set." 195 | ) 196 | if not self.learned_pos_emb and (not self.attn_config["alibi"]): 197 | raise ValueError( 198 | f"Positional information must be provided to the model using either learned_pos_emb or alibi." 
199 | ) 200 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | 7 | class SharedEmbedding(nn.Embedding): 8 | def forward(self, input: Tensor, unembed: bool = False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) 12 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | @contextmanager 8 | def init_empty_weights(include_buffers: bool = False): 9 | """Meta initialization context manager. 10 | 11 | A context manager under which models are initialized with all parameters 12 | on the meta device, therefore creating an empty model. Useful when just 13 | initializing the model would blow the available RAM. 14 | 15 | Args: 16 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 17 | not to also put all buffers on the meta device while initializing. 18 | 19 | Example: 20 | ```python 21 | import torch.nn as nn 22 | 23 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 24 | with init_empty_weights(): 25 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 26 | ``` 27 | 28 | 29 | 30 | Any model created under this context manager has no weights. As such you can't do something like 31 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 32 | 33 | 34 | """ 35 | with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f: 36 | yield f 37 | 38 | 39 | @contextmanager 40 | def init_on_device(device: torch.device, include_buffers: bool = False): 41 | """Device initialization context manager. 42 | 43 | A context manager under which models are initialized with all parameters 44 | on the specified device. 45 | 46 | Args: 47 | device (`torch.device`): Device to initialize all parameters on. 48 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 49 | not to also put all buffers on the meta device while initializing. 
50 | 51 | Example: 52 | ```python 53 | import torch.nn as nn 54 | 55 | with init_on_device(device=torch.device("cuda")): 56 | tst = nn.Liner(100, 100) # on `cuda` device 57 | ``` 58 | """ 59 | old_register_parameter = nn.Module.register_parameter 60 | if include_buffers: 61 | old_register_buffer = nn.Module.register_buffer 62 | 63 | def register_empty_parameter(module, name, param): 64 | old_register_parameter(module, name, param) 65 | if param is not None: 66 | param_cls = type(module._parameters[name]) 67 | kwargs = module._parameters[name].__dict__ 68 | module._parameters[name] = param_cls( 69 | module._parameters[name].to(device), **kwargs 70 | ) 71 | 72 | def register_empty_buffer(module, name, buffer): 73 | old_register_buffer(module, name, buffer) 74 | if buffer is not None: 75 | module._buffers[name] = module._buffers[name].to(device) 76 | 77 | if include_buffers: 78 | tensor_constructors_to_patch = { 79 | torch_function_name: getattr(torch, torch_function_name) 80 | for torch_function_name in ["empty", "zeros", "ones", "full"] 81 | } 82 | else: 83 | tensor_constructors_to_patch = {} 84 | 85 | def patch_tensor_constructor(fn): 86 | def wrapper(*args, **kwargs): 87 | kwargs["device"] = device 88 | return fn(*args, **kwargs) 89 | 90 | return wrapper 91 | 92 | try: 93 | nn.Module.register_parameter = register_empty_parameter 94 | if include_buffers: 95 | nn.Module.register_buffer = register_empty_buffer 96 | for torch_function_name in tensor_constructors_to_patch.keys(): 97 | setattr( 98 | torch, 99 | torch_function_name, 100 | patch_tensor_constructor(getattr(torch, torch_function_name)), 101 | ) 102 | yield 103 | finally: 104 | nn.Module.register_parameter = old_register_parameter 105 | if include_buffers: 106 | nn.Module.register_buffer = old_register_buffer 107 | for ( 108 | torch_function_name, 109 | old_torch_function, 110 | ) in tensor_constructors_to_patch.items(): 111 | setattr(torch, torch_function_name, old_torch_function) 112 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _cast_if_autocast_enabled(tensor): 5 | if torch.is_autocast_enabled(): 6 | if tensor.device.type == "cuda": 7 | dtype = torch.get_autocast_gpu_dtype() 8 | elif tensor.device.type == "cpu": 9 | dtype = torch.get_autocast_cpu_dtype() 10 | else: 11 | raise NotImplementedError() 12 | return tensor.to(dtype=dtype) 13 | return tensor 14 | 15 | 16 | class LPLayerNorm(torch.nn.LayerNorm): 17 | def __init__( 18 | self, 19 | normalized_shape, 20 | eps=1e-05, 21 | elementwise_affine=True, 22 | device=None, 23 | dtype=None, 24 | ): 25 | super().__init__( 26 | normalized_shape=normalized_shape, 27 | eps=eps, 28 | elementwise_affine=elementwise_affine, 29 | device=device, 30 | dtype=dtype, 31 | ) 32 | 33 | def forward(self, x): 34 | module_device = x.device 35 | downcast_x = _cast_if_autocast_enabled(x) 36 | downcast_weight = ( 37 | _cast_if_autocast_enabled(self.weight) 38 | if self.weight is not None 39 | else self.weight 40 | ) 41 | downcast_bias = ( 42 | _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 43 | ) 44 | with torch.autocast(enabled=False, device_type=module_device.type): 45 | return torch.nn.functional.layer_norm( 46 | downcast_x, 47 | self.normalized_shape, 48 | downcast_weight, 49 | downcast_bias, 50 | self.eps, 51 | ) 52 | 53 | 54 | def rms_norm(x, weight=None, eps=1e-05): 55 | 
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 56 | if weight is not None: 57 | return output * weight 58 | return output 59 | 60 | 61 | class RMSNorm(torch.nn.Module): 62 | def __init__( 63 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None 64 | ): 65 | super().__init__() 66 | self.eps = eps 67 | if weight: 68 | self.weight = torch.nn.Parameter( 69 | torch.ones(normalized_shape, dtype=dtype, device=device) 70 | ) 71 | else: 72 | self.register_parameter("weight", None) 73 | 74 | def forward(self, x): 75 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 76 | 77 | 78 | class LPRMSNorm(RMSNorm): 79 | def __init__( 80 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None 81 | ): 82 | super().__init__( 83 | normalized_shape=normalized_shape, 84 | eps=eps, 85 | weight=weight, 86 | dtype=dtype, 87 | device=device, 88 | ) 89 | 90 | def forward(self, x): 91 | downcast_x = _cast_if_autocast_enabled(x) 92 | downcast_weight = ( 93 | _cast_if_autocast_enabled(self.weight) 94 | if self.weight is not None 95 | else self.weight 96 | ) 97 | with torch.autocast(enabled=False, device_type=x.device.type): 98 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 99 | 100 | 101 | NORM_CLASS_REGISTRY = { 102 | "layernorm": torch.nn.LayerNorm, 103 | "low_precision_layernorm": LPLayerNorm, 104 | "rmsnorm": RMSNorm, 105 | "low_precision_rmsnorm": LPRMSNorm, 106 | } 107 | -------------------------------------------------------------------------------- /model/llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava.model.utils import auto_upgrade 9 | from tqdm import tqdm 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 17 | ) 18 | 19 | print("Loading target model") 20 | auto_upgrade(target_model_path) 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 23 | ) 24 | 25 | print("Calculating delta") 26 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 27 | if name not in base.state_dict(): 28 | assert name in [ 29 | "model.mm_projector.weight", 30 | "model.mm_projector.bias", 31 | ], f"{name} not in base model" 32 | continue 33 | if param.data.shape == base.state_dict()[name].shape: 34 | param.data -= base.state_dict()[name] 35 | else: 36 | assert name in [ 37 | "model.embed_tokens.weight", 38 | "lm_head.weight", 39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 40 | bparam = base.state_dict()[name] 41 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 42 | 43 | print("Saving delta") 44 | if hub_repo_id: 45 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 46 | else: 47 | kwargs = {} 48 | target.save_pretrained(delta_path, **kwargs) 49 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 50 | target_tokenizer.save_pretrained(delta_path, **kwargs) 51 | 52 | 53 | if __name__ 
== "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--base-model-path", type=str, required=True) 56 | parser.add_argument("--target-model-path", type=str, required=True) 57 | parser.add_argument("--delta-path", type=str, required=True) 58 | parser.add_argument("--hub-repo-id", type=str, default=None) 59 | args = parser.parse_args() 60 | 61 | make_delta( 62 | args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id 63 | ) 64 | -------------------------------------------------------------------------------- /model/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | 3 | 4 | def build_vision_tower(vision_tower_cfg, **kwargs): 5 | vision_tower = getattr( 6 | vision_tower_cfg, 7 | "mm_vision_tower", 8 | getattr(vision_tower_cfg, "vision_tower", None), 9 | ) 10 | if ( 11 | vision_tower.startswith("openai") 12 | or vision_tower.startswith("laion") 13 | or "clip" in vision_tower 14 | ): 15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 16 | 17 | raise ValueError(f"Unknown vision tower: {vision_tower}") 18 | -------------------------------------------------------------------------------- /model/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel 4 | 5 | 6 | class CLIPVisionTower(nn.Module): 7 | def __init__(self, vision_tower, args, delay_load=False): 8 | super().__init__() 9 | 10 | self.is_loaded = False 11 | 12 | self.vision_tower_name = vision_tower 13 | self.select_layer = args.mm_vision_select_layer 14 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 15 | 16 | if not delay_load: 17 | self.load_model() 18 | else: 19 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 20 | 21 | def load_model(self): 22 | self.image_processor = CLIPImageProcessor.from_pretrained( 23 | self.vision_tower_name 24 | ) 25 | self.vision_tower = CLIPVisionModel.from_pretrained( 26 | self.vision_tower_name, low_cpu_mem_usage=True 27 | ) 28 | self.vision_tower.requires_grad_(False) 29 | self.is_loaded = True 30 | 31 | def feature_select(self, image_forward_outs): 32 | image_features = image_forward_outs.hidden_states[self.select_layer] 33 | if self.select_feature == "patch": 34 | image_features = image_features[:, 1:] 35 | elif self.select_feature == "cls_patch": 36 | image_features = image_features 37 | else: 38 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 39 | return image_features 40 | 41 | @torch.no_grad() 42 | def forward(self, images): 43 | if type(images) is list: 44 | image_features = [] 45 | for image in images: 46 | image_forward_out = self.vision_tower( 47 | image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 48 | output_hidden_states=True, 49 | ) 50 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 51 | image_features.append(image_feature) 52 | else: 53 | image_forward_outs = self.vision_tower( 54 | images.to(device=self.device, dtype=self.dtype), 55 | output_hidden_states=True, 56 | ) 57 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 58 | 59 | torch.cuda.empty_cache() 60 | return image_features 61 | 62 | @property 63 | def dummy_feature(self): 64 | return torch.zeros(1, 
self.hidden_size, device=self.device, dtype=self.dtype) 65 | 66 | @property 67 | def dtype(self): 68 | return self.vision_tower.dtype 69 | 70 | @property 71 | def device(self): 72 | return self.vision_tower.device 73 | 74 | @property 75 | def config(self): 76 | if self.is_loaded: 77 | return self.vision_tower.config 78 | else: 79 | return self.cfg_only 80 | 81 | @property 82 | def hidden_size(self): 83 | return self.config.hidden_size 84 | 85 | @property 86 | def num_patches(self): 87 | return (self.config.image_size // self.config.patch_size) ** 2 88 | -------------------------------------------------------------------------------- /model/llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print( 9 | "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." 10 | ) 11 | print( 12 | "You must upgrade the checkpoint to the new code base (this can be done automatically)." 13 | ) 14 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 15 | if confirm.lower() in ["y", "yes"]: 16 | print("Upgrading checkpoint...") 17 | assert len(cfg.architectures) == 1 18 | setattr(cfg.__class__, "model_type", "llava") 19 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 20 | cfg.save_pretrained(config) 21 | print("Checkpoint upgraded.") 22 | else: 23 | print("Checkpoint upgrade aborted.") 24 | exit(1) 25 | -------------------------------------------------------------------------------- /model/llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | import transformers 6 | from einops import rearrange 7 | from torch import nn 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | try: 11 | from flash_attn.flash_attn_interface import \ 12 | flash_attn_unpadded_qkvpacked_func 13 | except ImportError: 14 | from flash_attn.flash_attn_interface import ( 15 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func, 16 | ) 17 | 18 | from flash_attn.bert_padding import pad_input, unpad_input 19 | 20 | 21 | def forward( 22 | self, 23 | hidden_states: torch.Tensor, 24 | attention_mask: Optional[torch.Tensor] = None, 25 | position_ids: Optional[torch.Tensor] = None, 26 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 27 | output_attentions: bool = False, 28 | use_cache: bool = False, 29 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 30 | """Input shape: Batch x Time x Channel 31 | 32 | attention_mask: [bsz, q_len] 33 | """ 34 | bsz, q_len, _ = hidden_states.size() 35 | 36 | query_states = ( 37 | self.q_proj(hidden_states) 38 | .view(bsz, q_len, self.num_heads, self.head_dim) 39 | .transpose(1, 2) 40 | ) 41 | key_states = ( 42 | self.k_proj(hidden_states) 43 | .view(bsz, q_len, self.num_heads, self.head_dim) 44 | .transpose(1, 2) 45 | ) 46 | value_states = ( 47 | self.v_proj(hidden_states) 48 | .view(bsz, q_len, self.num_heads, self.head_dim) 49 | .transpose(1, 2) 50 | ) 51 | # [bsz, q_len, nh, hd] 52 | # [bsz, nh, q_len, hd] 53 | 54 | kv_seq_len = key_states.shape[-2] 55 | assert past_key_value is None, "past_key_value is not supported" 56 | 57 | 
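# The next two steps mirror the stock LlamaAttention.forward: rotary position
# embeddings are applied to the query/key projections, and then q, k and v are
# stacked and transposed into a single (bsz, q_len, 3, n_heads, head_dim) tensor
# so they can be handed to FlashAttention's unpadded (varlen) QKV-packed kernel
# in place of the vanilla softmax attention.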
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 58 | query_states, key_states = apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | assert not output_attentions, "output_attentions is not supported" 63 | assert not use_cache, "use_cache is not supported" 64 | 65 | # Flash attention codes from 66 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 67 | 68 | # transform the data into the format required by flash attention 69 | qkv = torch.stack( 70 | [query_states, key_states, value_states], dim=2 71 | ) # [bsz, nh, 3, q_len, hd] 72 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 73 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 74 | # the attention_mask should be the same as the key_padding_mask 75 | key_padding_mask = attention_mask 76 | 77 | if key_padding_mask is None: 78 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 79 | max_s = q_len 80 | cu_q_lens = torch.arange( 81 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 82 | ) 83 | output = flash_attn_unpadded_qkvpacked_func( 84 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 85 | ) 86 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 87 | else: 88 | nheads = qkv.shape[-2] 89 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 90 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 91 | x_unpad = rearrange( 92 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 93 | ) 94 | output_unpad = flash_attn_unpadded_qkvpacked_func( 95 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 96 | ) 97 | output = rearrange( 98 | pad_input( 99 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 100 | ), 101 | "b s (h d) -> b s h d", 102 | h=nheads, 103 | ) 104 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 105 | 106 | 107 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 108 | # requires the attention mask to be the same as the key_padding_mask 109 | def _prepare_decoder_attention_mask( 110 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 111 | ): 112 | # [bsz, seq_len] 113 | return attention_mask 114 | 115 | 116 | def replace_llama_attn_with_flash_attn(): 117 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 118 | if cuda_major < 8: 119 | logging.warning( 120 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
121 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 122 | ) 123 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 124 | _prepare_decoder_attention_mask 125 | ) 126 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 127 | -------------------------------------------------------------------------------- /model/llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import torch 5 | from transformers import Trainer 6 | 7 | 8 | def maybe_zero_3(param, ignore_status=False, name=None): 9 | from deepspeed import zero 10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 11 | 12 | if hasattr(param, "ds_id"): 13 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 14 | if not ignore_status: 15 | print(name, "no ignore status") 16 | with zero.GatheredParameters([param]): 17 | param = param.data.detach().cpu().clone() 18 | else: 19 | param = param.detach().cpu().clone() 20 | return param 21 | 22 | 23 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 24 | to_return = { 25 | k: t 26 | for k, t in named_params 27 | if any(key_match in k for key_match in keys_to_match) 28 | } 29 | to_return = { 30 | k: maybe_zero_3(v, ignore_status=True, name=k).cpu() 31 | for k, v in to_return.items() 32 | } 33 | return to_return 34 | 35 | 36 | class LLaVATrainer(Trainer): 37 | def _save_checkpoint(self, model, trial, metrics=None): 38 | if getattr(self.args, "tune_mm_mlp_adapter", False): 39 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 40 | 41 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 42 | 43 | run_dir = self._get_output_dir(trial=trial) 44 | output_dir = os.path.join(run_dir, checkpoint_folder) 45 | 46 | # Only save Adapter 47 | keys_to_match = ["mm_projector"] 48 | if getattr(self.args, "use_im_start_end", False): 49 | keys_to_match.extend(["embed_tokens", "embed_in"]) 50 | 51 | weight_to_save = get_mm_adapter_state_maybe_zero_3( 52 | self.model.named_parameters(), keys_to_match 53 | ) 54 | 55 | if self.args.local_rank == 0 or self.args.local_rank == -1: 56 | self.model.config.save_pretrained(output_dir) 57 | torch.save( 58 | weight_to_save, os.path.join(output_dir, f"mm_projector.bin") 59 | ) 60 | else: 61 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 62 | 63 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 64 | if getattr(self.args, "tune_mm_mlp_adapter", False): 65 | pass 66 | else: 67 | super(LLaVATrainer, self)._save(output_dir, state_dict) 68 | -------------------------------------------------------------------------------- /model/llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 
6 | from llava.train.llama_flash_attn_monkey_patch import \ 7 | replace_llama_attn_with_flash_attn 8 | 9 | replace_llama_attn_with_flash_attn() 10 | 11 | from llava.train.train import train 12 | 13 | if __name__ == "__main__": 14 | train() 15 | -------------------------------------------------------------------------------- /model/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | from llava.constants import LOGDIR 9 | 10 | server_error_msg = ( 11 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | ) 13 | moderation_msg = ( 14 | "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 15 | ) 16 | 17 | handler = None 18 | 19 | 20 | def build_logger(logger_name, logger_filename): 21 | global handler 22 | 23 | formatter = logging.Formatter( 24 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 25 | datefmt="%Y-%m-%d %H:%M:%S", 26 | ) 27 | 28 | # Set the format of root handlers 29 | if not logging.getLogger().handlers: 30 | logging.basicConfig(level=logging.INFO) 31 | logging.getLogger().handlers[0].setFormatter(formatter) 32 | 33 | # Redirect stdout and stderr to loggers 34 | stdout_logger = logging.getLogger("stdout") 35 | stdout_logger.setLevel(logging.INFO) 36 | sl = StreamToLogger(stdout_logger, logging.INFO) 37 | sys.stdout = sl 38 | 39 | stderr_logger = logging.getLogger("stderr") 40 | stderr_logger.setLevel(logging.ERROR) 41 | sl = StreamToLogger(stderr_logger, logging.ERROR) 42 | sys.stderr = sl 43 | 44 | # Get logger 45 | logger = logging.getLogger(logger_name) 46 | logger.setLevel(logging.INFO) 47 | 48 | # Add a file handler for all loggers 49 | if handler is None: 50 | os.makedirs(LOGDIR, exist_ok=True) 51 | filename = os.path.join(LOGDIR, logger_filename) 52 | handler = logging.handlers.TimedRotatingFileHandler( 53 | filename, when="D", utc=True 54 | ) 55 | handler.setFormatter(formatter) 56 | 57 | for name, item in logging.root.manager.loggerDict.items(): 58 | if isinstance(item, logging.Logger): 59 | item.addHandler(handler) 60 | 61 | return logger 62 | 63 | 64 | class StreamToLogger(object): 65 | """ 66 | Fake file-like stream object that redirects writes to a logger instance. 67 | """ 68 | 69 | def __init__(self, logger, log_level=logging.INFO): 70 | self.terminal = sys.stdout 71 | self.logger = logger 72 | self.log_level = log_level 73 | self.linebuf = "" 74 | 75 | def __getattr__(self, attr): 76 | return getattr(self.terminal, attr) 77 | 78 | def write(self, buf): 79 | temp_linebuf = self.linebuf + buf 80 | self.linebuf = "" 81 | for line in temp_linebuf.splitlines(True): 82 | # From the io.TextIOWrapper docs: 83 | # On output, if newline is None, any '\n' characters written 84 | # are translated to the system default line separator. 85 | # By default sys.stdout.write() expects '\n' newlines and then 86 | # translates them so this is still cross platform. 87 | if line[-1] == "\n": 88 | self.logger.log(self.log_level, line.rstrip()) 89 | else: 90 | self.linebuf += line 91 | 92 | def flush(self): 93 | if self.linebuf != "": 94 | self.logger.log(self.log_level, self.linebuf.rstrip()) 95 | self.linebuf = "" 96 | 97 | 98 | def disable_torch_init(): 99 | """ 100 | Disable the redundant torch default initialization to accelerate model creation. 
101 | """ 102 | import torch 103 | 104 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 105 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 106 | 107 | 108 | def violates_moderation(text): 109 | """ 110 | Check whether the text violates OpenAI moderation API. 111 | """ 112 | url = "https://api.openai.com/v1/moderations" 113 | headers = { 114 | "Content-Type": "application/json", 115 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], 116 | } 117 | text = text.replace("\n", "") 118 | data = "{" + '"input": ' + f'"{text}"' + "}" 119 | data = data.encode("utf-8") 120 | try: 121 | ret = requests.post(url, headers=headers, data=data, timeout=5) 122 | flagged = ret.json()["results"][0]["flagged"] 123 | except requests.exceptions.RequestException as e: 124 | flagged = False 125 | except KeyError as e: 126 | flagged = False 127 | 128 | return flagged 129 | 130 | 131 | def pretty_print_semaphore(semaphore): 132 | if semaphore is None: 133 | return "None" 134 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 135 | -------------------------------------------------------------------------------- /model/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .automatic_mask_generator import SamAutomaticMaskGenerator 8 | from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, 9 | build_sam_vit_l, sam_model_registry) 10 | from .predictor import SamPredictor 11 | -------------------------------------------------------------------------------- /model/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from functools import partial 8 | 9 | import torch 10 | 11 | from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, 12 | TwoWayTransformer) 13 | 14 | 15 | def build_sam_vit_h(checkpoint=None): 16 | return _build_sam( 17 | encoder_embed_dim=1280, 18 | encoder_depth=32, 19 | encoder_num_heads=16, 20 | encoder_global_attn_indexes=[7, 15, 23, 31], 21 | checkpoint=checkpoint, 22 | ) 23 | 24 | 25 | build_sam = build_sam_vit_h 26 | 27 | 28 | def build_sam_vit_l(checkpoint=None): 29 | return _build_sam( 30 | encoder_embed_dim=1024, 31 | encoder_depth=24, 32 | encoder_num_heads=16, 33 | encoder_global_attn_indexes=[5, 11, 17, 23], 34 | checkpoint=checkpoint, 35 | ) 36 | 37 | 38 | def build_sam_vit_b(checkpoint=None): 39 | return _build_sam( 40 | encoder_embed_dim=768, 41 | encoder_depth=12, 42 | encoder_num_heads=12, 43 | encoder_global_attn_indexes=[2, 5, 8, 11], 44 | checkpoint=checkpoint, 45 | ) 46 | 47 | 48 | sam_model_registry = { 49 | "default": build_sam_vit_h, 50 | "vit_h": build_sam_vit_h, 51 | "vit_l": build_sam_vit_l, 52 | "vit_b": build_sam_vit_b, 53 | } 54 | 55 | 56 | def _build_sam( 57 | encoder_embed_dim, 58 | encoder_depth, 59 | encoder_num_heads, 60 | encoder_global_attn_indexes, 61 | checkpoint=None, 62 | ): 63 | prompt_embed_dim = 256 64 | image_size = 1024 65 | vit_patch_size = 16 66 | image_embedding_size = image_size // vit_patch_size 67 | sam = Sam( 68 | image_encoder=ImageEncoderViT( 69 | depth=encoder_depth, 70 | embed_dim=encoder_embed_dim, 71 | img_size=image_size, 72 | mlp_ratio=4, 73 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 74 | num_heads=encoder_num_heads, 75 | patch_size=vit_patch_size, 76 | qkv_bias=True, 77 | use_rel_pos=True, 78 | global_attn_indexes=encoder_global_attn_indexes, 79 | window_size=14, 80 | out_chans=prompt_embed_dim, 81 | ), 82 | prompt_encoder=PromptEncoder( 83 | embed_dim=prompt_embed_dim, 84 | image_embedding_size=(image_embedding_size, image_embedding_size), 85 | input_image_size=(image_size, image_size), 86 | mask_in_chans=16, 87 | ), 88 | mask_decoder=MaskDecoder( 89 | num_multimask_outputs=3, 90 | transformer=TwoWayTransformer( 91 | depth=2, 92 | embedding_dim=prompt_embed_dim, 93 | mlp_dim=2048, 94 | num_heads=8, 95 | ), 96 | transformer_dim=prompt_embed_dim, 97 | iou_head_depth=3, 98 | iou_head_hidden_dim=256, 99 | ), 100 | pixel_mean=[123.675, 116.28, 103.53], 101 | pixel_std=[58.395, 57.12, 57.375], 102 | ) 103 | sam.eval() 104 | if checkpoint is not None: 105 | with open(checkpoint, "rb") as f: 106 | state_dict = torch.load(f) 107 | sam.load_state_dict(state_dict, strict=False) 108 | return sam 109 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .image_encoder import ImageEncoderViT 8 | from .mask_decoder import MaskDecoder 9 | from .prompt_encoder import PromptEncoder 10 | from .sam import Sam 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Type 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d( 55 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 56 | ), 57 | LayerNorm2d(transformer_dim // 4), 58 | activation(), 59 | nn.ConvTranspose2d( 60 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 61 | ), 62 | activation(), 63 | ) 64 | self.output_hypernetworks_mlps = nn.ModuleList( 65 | [ 66 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 67 | for i in range(self.num_mask_tokens) 68 | ] 69 | ) 70 | 71 | self.iou_prediction_head = MLP( 72 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 73 | ) 74 | 75 | def forward( 76 | self, 77 | image_embeddings: torch.Tensor, 78 | image_pe: torch.Tensor, 79 | sparse_prompt_embeddings: torch.Tensor, 80 | dense_prompt_embeddings: torch.Tensor, 81 | multimask_output: bool, 82 | ) -> Tuple[torch.Tensor, torch.Tensor]: 83 | """ 84 | Predict masks given image and prompt embeddings. 85 | 86 | Arguments: 87 | image_embeddings (torch.Tensor): the embeddings from the image encoder 88 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 89 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 90 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 91 | multimask_output (bool): Whether to return multiple masks or a single 92 | mask. 93 | 94 | Returns: 95 | torch.Tensor: batched predicted masks 96 | torch.Tensor: batched predictions of mask quality 97 | """ 98 | masks, iou_pred = self.predict_masks( 99 | image_embeddings=image_embeddings, 100 | image_pe=image_pe, 101 | sparse_prompt_embeddings=sparse_prompt_embeddings, 102 | dense_prompt_embeddings=dense_prompt_embeddings, 103 | ) 104 | 105 | # Select the correct mask or masks for output 106 | if multimask_output: 107 | mask_slice = slice(1, None) 108 | else: 109 | mask_slice = slice(0, 1) 110 | masks = masks[:, mask_slice, :, :] 111 | iou_pred = iou_pred[:, mask_slice] 112 | 113 | # Prepare output 114 | return masks, iou_pred 115 | 116 | def predict_masks( 117 | self, 118 | image_embeddings: torch.Tensor, 119 | image_pe: torch.Tensor, 120 | sparse_prompt_embeddings: torch.Tensor, 121 | dense_prompt_embeddings: torch.Tensor, 122 | ) -> Tuple[torch.Tensor, torch.Tensor]: 123 | """Predicts masks. 
See 'forward' for more details.""" 124 | # Concatenate output tokens 125 | output_tokens = torch.cat( 126 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 127 | ) 128 | output_tokens = output_tokens.unsqueeze(0).expand( 129 | sparse_prompt_embeddings.size(0), -1, -1 130 | ) 131 | 132 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 133 | 134 | # image_embeddings: [1, C, H, W], tokens: [B, N, C] 135 | # dense_prompt_embeddings: [B, C, H, W] 136 | # Expand per-image data in batch direction to be per-mask 137 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 138 | src = src + dense_prompt_embeddings 139 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 140 | b, c, h, w = src.shape 141 | 142 | # Run the transformer 143 | hs, src = self.transformer(src, pos_src, tokens) 144 | iou_token_out = hs[:, 0, :] 145 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 146 | 147 | # Upscale mask embeddings and predict masks using the mask tokens 148 | src = src.transpose(1, 2).view(b, c, h, w) 149 | upscaled_embedding = self.output_upscaling(src) 150 | hyper_in_list: List[torch.Tensor] = [] 151 | for i in range(self.num_mask_tokens): 152 | hyper_in_list.append( 153 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 154 | ) 155 | hyper_in = torch.stack(hyper_in_list, dim=1) 156 | b, c, h, w = upscaled_embedding.shape 157 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view( 158 | b, self.num_mask_tokens, h, w 159 | ) 160 | 161 | # Generate mask quality predictions 162 | iou_pred = self.iou_prediction_head(iou_token_out) 163 | 164 | return masks, iou_pred 165 | 166 | 167 | # Lightly adapted from 168 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 169 | class MLP(nn.Module): 170 | def __init__( 171 | self, 172 | input_dim: int, 173 | hidden_dim: int, 174 | output_dim: int, 175 | num_layers: int, 176 | sigmoid_output: bool = False, 177 | ) -> None: 178 | super().__init__() 179 | self.num_layers = num_layers 180 | h = [hidden_dim] * (num_layers - 1) 181 | self.layers = nn.ModuleList( 182 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 183 | ) 184 | self.sigmoid_output = sigmoid_output 185 | 186 | def forward(self, x): 187 | for i, layer in enumerate(self.layers): 188 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 189 | if self.sigmoid_output: 190 | x = F.sigmoid(x) 191 | return x 192 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Optional, Tuple, Type 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 
27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [ 47 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 48 | ] 49 | self.point_embeddings = nn.ModuleList(point_embeddings) 50 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 51 | 52 | self.mask_input_size = ( 53 | 4 * image_embedding_size[0], 54 | 4 * image_embedding_size[1], 55 | ) 56 | self.mask_downscaling = nn.Sequential( 57 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 58 | LayerNorm2d(mask_in_chans // 4), 59 | activation(), 60 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 61 | LayerNorm2d(mask_in_chans), 62 | activation(), 63 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 64 | ) 65 | self.no_mask_embed = nn.Embedding(1, embed_dim) 66 | 67 | def get_dense_pe(self) -> torch.Tensor: 68 | """ 69 | Returns the positional encoding used to encode point prompts, 70 | applied to a dense set of points the shape of the image encoding. 71 | 72 | Returns: 73 | torch.Tensor: Positional encoding with shape 74 | 1x(embed_dim)x(embedding_h)x(embedding_w) 75 | """ 76 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 77 | 78 | def _embed_points( 79 | self, 80 | points: torch.Tensor, 81 | labels: torch.Tensor, 82 | pad: bool, 83 | ) -> torch.Tensor: 84 | """Embeds point prompts.""" 85 | points = points + 0.5 # Shift to center of pixel 86 | if pad: 87 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 88 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 89 | points = torch.cat([points, padding_point], dim=1) 90 | labels = torch.cat([labels, padding_label], dim=1) 91 | point_embedding = self.pe_layer.forward_with_coords( 92 | points, self.input_image_size 93 | ) 94 | point_embedding[labels == -1] = 0.0 95 | point_embedding[labels == -1] += self.not_a_point_embed.weight 96 | point_embedding[labels == 0] += self.point_embeddings[0].weight 97 | point_embedding[labels == 1] += self.point_embeddings[1].weight 98 | return point_embedding 99 | 100 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 101 | """Embeds box prompts.""" 102 | boxes = boxes + 0.5 # Shift to center of pixel 103 | coords = boxes.reshape(-1, 2, 2) 104 | corner_embedding = self.pe_layer.forward_with_coords( 105 | coords, self.input_image_size 106 | ) 107 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 108 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 109 | return corner_embedding 110 | 111 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 112 | """Embeds mask inputs.""" 113 | mask_embedding = self.mask_downscaling(masks) 114 | return mask_embedding 115 | 116 | def _get_batch_size( 117 | self, 118 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 119 | boxes: 
Optional[torch.Tensor], 120 | masks: Optional[torch.Tensor], 121 | text_embeds: Optional[torch.Tensor], 122 | ) -> int: 123 | """ 124 | Gets the batch size of the output given the batch size of the input prompts. 125 | """ 126 | if points is not None: 127 | return points[0].shape[0] 128 | elif boxes is not None: 129 | return boxes.shape[0] 130 | elif masks is not None: 131 | return masks.shape[0] 132 | elif text_embeds is not None: 133 | return text_embeds.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | text_embeds: Optional[torch.Tensor], 146 | ) -> Tuple[torch.Tensor, torch.Tensor]: 147 | """ 148 | Embeds different types of prompts, returning both sparse and dense 149 | embeddings. 150 | 151 | Arguments: 152 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 153 | and labels to embed. 154 | boxes (torch.Tensor or none): boxes to embed 155 | masks (torch.Tensor or none): masks to embed 156 | 157 | Returns: 158 | torch.Tensor: sparse embeddings for the points and boxes, with shape 159 | BxNx(embed_dim), where N is determined by the number of input points 160 | and boxes. 161 | torch.Tensor: dense embeddings for the masks, in the shape 162 | Bx(embed_dim)x(embed_H)x(embed_W) 163 | """ 164 | bs = self._get_batch_size(points, boxes, masks, text_embeds) 165 | sparse_embeddings = torch.empty( 166 | (bs, 0, self.embed_dim), device=self._get_device() 167 | ) 168 | if points is not None: 169 | coords, labels = points 170 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 171 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 172 | if boxes is not None: 173 | box_embeddings = self._embed_boxes(boxes) 174 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 175 | 176 | if text_embeds is not None: 177 | sparse_embeddings = torch.cat([sparse_embeddings, text_embeds], dim=1) 178 | 179 | if masks is not None: 180 | dense_embeddings = self._embed_masks(masks) 181 | else: 182 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 183 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 184 | ) 185 | 186 | return sparse_embeddings, dense_embeddings 187 | 188 | 189 | class PositionEmbeddingRandom(nn.Module): 190 | """ 191 | Positional encoding using random spatial frequencies. 192 | """ 193 | 194 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 195 | super().__init__() 196 | if scale is None or scale <= 0.0: 197 | scale = 1.0 198 | self.register_buffer( 199 | "positional_encoding_gaussian_matrix", 200 | scale * torch.randn((2, num_pos_feats)), 201 | ) 202 | 203 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 204 | """Positionally encode points that are normalized to [0,1].""" 205 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 206 | coords = 2 * coords - 1 207 | 208 | if coords.dtype != self.positional_encoding_gaussian_matrix.dtype: 209 | coords = coords.to(self.positional_encoding_gaussian_matrix.dtype) 210 | 211 | coords = coords @ self.positional_encoding_gaussian_matrix 212 | coords = 2 * np.pi * coords 213 | # outputs d_1 x ... 
x d_n x C shape 214 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 215 | 216 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 217 | """Generate positional encoding for a grid of the specified size.""" 218 | h, w = size 219 | device: Any = self.positional_encoding_gaussian_matrix.device 220 | grid = torch.ones( 221 | (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype 222 | ) 223 | y_embed = grid.cumsum(dim=0) - 0.5 224 | x_embed = grid.cumsum(dim=1) - 0.5 225 | y_embed = y_embed / h 226 | x_embed = x_embed / w 227 | 228 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 229 | return pe.permute(2, 0, 1) # C x H x W 230 | 231 | def forward_with_coords( 232 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 233 | ) -> torch.Tensor: 234 | """Positionally encode points that are not normalized to [0,1].""" 235 | coords = coords_input.clone() 236 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 237 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 238 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 239 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, List, Tuple 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer( 47 | "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False 48 | ) 49 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 50 | 51 | @property 52 | def device(self) -> Any: 53 | return self.pixel_mean.device 54 | 55 | @torch.no_grad() 56 | def forward( 57 | self, 58 | batched_input: List[Dict[str, Any]], 59 | multimask_output: bool, 60 | ) -> List[Dict[str, torch.Tensor]]: 61 | """ 62 | Predicts masks end-to-end from provided images and prompts. 63 | If prompts are not known in advance, using SamPredictor is 64 | recommended over calling the model directly. 
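A hedged usage sketch for the PromptEncoder above; the embedding sizes are the ones SAM's ViT models typically use and are assumed here rather than read from a config:

import torch
from model.segment_anything.modeling.prompt_encoder import PromptEncoder

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
points = (torch.tensor([[[512.0, 512.0]]]), torch.tensor([[1.0]]))  # one foreground click (label 1)
sparse, dense = pe(points=points, boxes=None, masks=None, text_embeds=None)
print(sparse.shape, dense.shape)  # (1, 2, 256) with the padding point, (1, 256, 64, 64) from no_mask_embed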
65 | 66 | Arguments: 67 | batched_input (list(dict)): A list over input images, each a 68 | dictionary with the following keys. A prompt key can be 69 | excluded if it is not present. 70 | 'image': The image as a torch tensor in 3xHxW format, 71 | already transformed for input to the model. 72 | 'original_size': (tuple(int, int)) The original size of 73 | the image before transformation, as (H, W). 74 | 'point_coords': (torch.Tensor) Batched point prompts for 75 | this image, with shape BxNx2. Already transformed to the 76 | input frame of the model. 77 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 78 | with shape BxN. 79 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 80 | Already transformed to the input frame of the model. 81 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 82 | in the form Bx1xHxW. 83 | multimask_output (bool): Whether the model should predict multiple 84 | disambiguating masks, or return a single mask. 85 | 86 | Returns: 87 | (list(dict)): A list over input images, where each element is 88 | as dictionary with the following keys. 89 | 'masks': (torch.Tensor) Batched binary mask predictions, 90 | with shape BxCxHxW, where B is the number of input prompts, 91 | C is determined by multimask_output, and (H, W) is the 92 | original size of the image. 93 | 'iou_predictions': (torch.Tensor) The model's predictions 94 | of mask quality, in shape BxC. 95 | 'low_res_logits': (torch.Tensor) Low resolution logits with 96 | shape BxCxHxW, where H=W=256. Can be passed as mask input 97 | to subsequent iterations of prediction. 98 | """ 99 | input_images = torch.stack( 100 | [self.preprocess(x["image"]) for x in batched_input], dim=0 101 | ) 102 | image_embeddings = self.image_encoder(input_images) 103 | 104 | outputs = [] 105 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 106 | if "point_coords" in image_record: 107 | points = (image_record["point_coords"], image_record["point_labels"]) 108 | else: 109 | points = None 110 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 111 | points=points, 112 | boxes=image_record.get("boxes", None), 113 | masks=image_record.get("mask_inputs", None), 114 | ) 115 | low_res_masks, iou_predictions = self.mask_decoder( 116 | image_embeddings=curr_embedding.unsqueeze(0), 117 | image_pe=self.prompt_encoder.get_dense_pe(), 118 | sparse_prompt_embeddings=sparse_embeddings, 119 | dense_prompt_embeddings=dense_embeddings, 120 | multimask_output=multimask_output, 121 | ) 122 | masks = self.postprocess_masks( 123 | low_res_masks, 124 | input_size=image_record["image"].shape[-2:], 125 | original_size=image_record["original_size"], 126 | ) 127 | masks = masks > self.mask_threshold 128 | outputs.append( 129 | { 130 | "masks": masks, 131 | "iou_predictions": iou_predictions, 132 | "low_res_logits": low_res_masks, 133 | } 134 | ) 135 | return outputs 136 | 137 | def postprocess_masks( 138 | self, 139 | masks: torch.Tensor, 140 | input_size: Tuple[int, ...], 141 | original_size: Tuple[int, ...], 142 | ) -> torch.Tensor: 143 | """ 144 | Remove padding and upscale masks to the original image size. 145 | 146 | Arguments: 147 | masks (torch.Tensor): Batched masks from the mask_decoder, 148 | in BxCxHxW format. 149 | input_size (tuple(int, int)): The size of the image input to the 150 | model, in (H, W) format. Used to remove padding. 151 | original_size (tuple(int, int)): The original size of the image 152 | before resizing for input to the model, in (H, W) format. 
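The batched_input format documented above is easiest to see as a literal dictionary. A sketch with placeholder tensors; building the actual Sam instance (normally done via model/segment_anything/build_sam.py) is left out, so the model call is only indicated in a comment:

import torch

image = torch.randint(0, 255, (3, 1024, 683), dtype=torch.uint8).float()  # long side already resized to 1024
batched_input = [
    {
        "image": image,                                      # 3xHxW, in the model's input frame
        "original_size": (1500, 1000),                       # (H, W) before ResizeLongestSide
        "point_coords": torch.tensor([[[400.0, 600.0]]]),    # BxNx2, also in the resized frame
        "point_labels": torch.tensor([[1]]),                 # BxN, 1 = foreground click
    }
]
# outputs = sam_model(batched_input, multimask_output=True)
# outputs[0]["masks"].shape -> (1, 3, 1500, 1000): three candidate masks at the original resolution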
153 | 154 | Returns: 155 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 156 | is given by original_size. 157 | """ 158 | 159 | dtype = masks.dtype 160 | 161 | masks = F.interpolate( 162 | masks.float(), 163 | (self.image_encoder.img_size, self.image_encoder.img_size), 164 | mode="bilinear", 165 | align_corners=False, 166 | ) 167 | # masks = masks.to(dtype) 168 | masks = masks[..., : input_size[0], : input_size[1]] 169 | masks = F.interpolate( 170 | masks, original_size, mode="bilinear", align_corners=False 171 | ) 172 | return masks 173 | 174 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 175 | """Normalize pixel values and pad to a square input.""" 176 | # Normalize colors 177 | x = (x - self.pixel_mean) / self.pixel_std 178 | 179 | # Pad 180 | h, w = x.shape[-2:] 181 | padh = self.image_encoder.img_size - h 182 | padw = self.image_encoder.img_size - w 183 | x = F.pad(x, (0, padw, 0, padh)) 184 | return x 185 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple, Type 9 | 10 | import torch 11 | from torch import Tensor, nn 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
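A shape-only sketch of the TwoWayTransformer above, with the depth and widths SAM typically uses (assumed, not taken from this repository's configs):

import torch
from model.segment_anything.modeling.transformer import TwoWayTransformer

twt = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W from the image encoder
image_pe = torch.randn(1, 256, 64, 64)         # dense positional encoding, same shape
point_embedding = torch.randn(1, 7, 256)       # B x N_points x C (e.g. IoU token + mask tokens + prompts)
queries, keys = twt(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)  # (1, 7, 256) processed point tokens, (1, 4096, 256) processed image tokens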
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert ( 202 | self.internal_dim % num_heads == 0 203 | ), "num_heads must divide embedding_dim." 204 | 205 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 207 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 208 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 209 | 210 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 211 | b, n, c = x.shape 212 | x = x.reshape(b, n, num_heads, c // num_heads) 213 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 214 | 215 | def _recombine_heads(self, x: Tensor) -> Tensor: 216 | b, n_heads, n_tokens, c_per_head = x.shape 217 | x = x.transpose(1, 2) 218 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 219 | 220 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 221 | # Input projections 222 | q = self.q_proj(q) 223 | k = self.k_proj(k) 224 | v = self.v_proj(v) 225 | 226 | # Separate into heads 227 | q = self._separate_heads(q, self.num_heads) 228 | k = self._separate_heads(k, self.num_heads) 229 | v = self._separate_heads(v, self.num_heads) 230 | 231 | # Attention 232 | _, _, _, c_per_head = q.shape 233 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 234 | attn = attn / math.sqrt(c_per_head) 235 | attn = torch.softmax(attn, dim=-1) 236 | 237 | # Get output 238 | out = attn @ v 239 | out = self._recombine_heads(out) 240 | out = self.out_proj(out) 241 | 242 | return out 243 | -------------------------------------------------------------------------------- /model/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /model/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
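A small sketch of the Attention layer defined above, showing how downsample_rate shrinks the per-head width; the sizes are illustrative:

import torch
from model.segment_anything.modeling.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
q = torch.randn(2, 5, 256)        # B x N_tokens x C (queries)
k = v = torch.randn(2, 4096, 256) # keys/values from the flattened image embedding
out = attn(q=q, k=k, v=v)
print(out.shape)  # (2, 5, 256); internally each head works at 256 // 2 // 8 = 16 dims before out_proj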
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points( 52 | self, point_coords: torch.Tensor, point_labels: torch.Tensor 53 | ) -> torch.Tensor: 54 | point_coords = point_coords + 0.5 55 | point_coords = point_coords / self.img_size 56 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 57 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 58 | 59 | point_embedding = point_embedding * (point_labels != -1) 60 | point_embedding = ( 61 | point_embedding 62 | + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1) 63 | ) 64 | 65 | for i in range(self.model.prompt_encoder.num_point_embeddings): 66 | point_embedding = ( 67 | point_embedding 68 | + self.model.prompt_encoder.point_embeddings[i].weight 69 | * (point_labels == i) 70 | ) 71 | 72 | return point_embedding 73 | 74 | def _embed_masks( 75 | self, input_mask: torch.Tensor, has_mask_input: torch.Tensor 76 | ) -> torch.Tensor: 77 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling( 78 | input_mask 79 | ) 80 | mask_embedding = mask_embedding + ( 81 | 1 - has_mask_input 82 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 83 | return mask_embedding 84 | 85 | def mask_postprocessing( 86 | self, masks: torch.Tensor, orig_im_size: torch.Tensor 87 | ) -> torch.Tensor: 88 | masks = F.interpolate( 89 | masks, 90 | size=(self.img_size, self.img_size), 91 | mode="bilinear", 92 | align_corners=False, 93 | ) 94 | 95 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to( 96 | torch.int64 97 | ) 98 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 99 | 100 | orig_im_size = orig_im_size.to(torch.int64) 101 | h, w = orig_im_size[0], orig_im_size[1] 102 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 103 | return masks 104 | 105 | def select_masks( 106 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 107 | ) -> Tuple[torch.Tensor, torch.Tensor]: 108 | # Determine if we should return the multiclick mask or not from the number of points. 109 | # The reweighting is used to avoid control flow. 
110 | score_reweight = torch.tensor( 111 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 112 | ).to(iou_preds.device) 113 | score = iou_preds + (num_points - 2.5) * score_reweight 114 | best_idx = torch.argmax(score, dim=1) 115 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 116 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 117 | 118 | return masks, iou_preds 119 | 120 | @torch.no_grad() 121 | def forward( 122 | self, 123 | image_embeddings: torch.Tensor, 124 | point_coords: torch.Tensor, 125 | point_labels: torch.Tensor, 126 | mask_input: torch.Tensor, 127 | has_mask_input: torch.Tensor, 128 | orig_im_size: torch.Tensor, 129 | ): 130 | sparse_embedding = self._embed_points(point_coords, point_labels) 131 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 132 | 133 | masks, scores = self.model.mask_decoder.predict_masks( 134 | image_embeddings=image_embeddings, 135 | image_pe=self.model.prompt_encoder.get_dense_pe(), 136 | sparse_prompt_embeddings=sparse_embedding, 137 | dense_prompt_embeddings=dense_embedding, 138 | ) 139 | 140 | if self.use_stability_score: 141 | scores = calculate_stability_score( 142 | masks, self.model.mask_threshold, self.stability_score_offset 143 | ) 144 | 145 | if self.return_single_mask: 146 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 147 | 148 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 149 | 150 | if self.return_extra_metrics: 151 | stability_scores = calculate_stability_score( 152 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 153 | ) 154 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 155 | return upscaled_masks, scores, stability_scores, areas, masks 156 | 157 | return upscaled_masks, scores, masks 158 | -------------------------------------------------------------------------------- /model/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from copy import deepcopy 8 | from typing import Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn import functional as F 13 | from torchvision.transforms.functional import resize # type: ignore 14 | from torchvision.transforms.functional import to_pil_image 15 | 16 | 17 | class ResizeLongestSide: 18 | """ 19 | Resizes images to the longest side 'target_length', as well as provides 20 | methods for resizing coordinates and boxes. Provides methods for 21 | transforming both numpy array and batched torch tensors. 22 | """ 23 | 24 | def __init__(self, target_length: int) -> None: 25 | self.target_length = target_length 26 | 27 | def apply_image(self, image: np.ndarray) -> np.ndarray: 28 | """ 29 | Expects a numpy array with shape HxWxC in uint8 format. 30 | """ 31 | target_size = self.get_preprocess_shape( 32 | image.shape[0], image.shape[1], self.target_length 33 | ) 34 | return np.array(resize(to_pil_image(image), target_size)) 35 | 36 | def apply_coords( 37 | self, coords: np.ndarray, original_size: Tuple[int, ...] 38 | ) -> np.ndarray: 39 | """ 40 | Expects a numpy array of length 2 in the final dimension. Requires the 41 | original image size in (H, W) format. 
42 | """ 43 | old_h, old_w = original_size 44 | new_h, new_w = self.get_preprocess_shape( 45 | original_size[0], original_size[1], self.target_length 46 | ) 47 | coords = deepcopy(coords).astype(float) 48 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 49 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 50 | return coords 51 | 52 | def apply_boxes( 53 | self, boxes: np.ndarray, original_size: Tuple[int, ...] 54 | ) -> np.ndarray: 55 | """ 56 | Expects a numpy array shape Bx4. Requires the original image size 57 | in (H, W) format. 58 | """ 59 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 60 | return boxes.reshape(-1, 4) 61 | 62 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Expects batched images with shape BxCxHxW and float format. This 65 | transformation may not exactly match apply_image. apply_image is 66 | the transformation expected by the model. 67 | """ 68 | # Expects an image in BCHW format. May not exactly match apply_image. 69 | target_size = self.get_preprocess_shape( 70 | image.shape[0], image.shape[1], self.target_length 71 | ) 72 | return F.interpolate( 73 | image, target_size, mode="bilinear", align_corners=False, antialias=True 74 | ) 75 | 76 | def apply_coords_torch( 77 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 78 | ) -> torch.Tensor: 79 | """ 80 | Expects a torch tensor with length 2 in the last dimension. Requires the 81 | original image size in (H, W) format. 82 | """ 83 | old_h, old_w = original_size 84 | new_h, new_w = self.get_preprocess_shape( 85 | original_size[0], original_size[1], self.target_length 86 | ) 87 | coords = deepcopy(coords).to(torch.float) 88 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 89 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 90 | return coords 91 | 92 | def apply_boxes_torch( 93 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 94 | ) -> torch.Tensor: 95 | """ 96 | Expects a torch tensor with shape Bx4. Requires the original image 97 | size in (H, W) format. 98 | """ 99 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 100 | return boxes.reshape(-1, 4) 101 | 102 | @staticmethod 103 | def get_preprocess_shape( 104 | oldh: int, oldw: int, long_side_length: int 105 | ) -> Tuple[int, int]: 106 | """ 107 | Compute the output size given input size and target long side length. 
108 | """ 109 | scale = long_side_length * 1.0 / max(oldh, oldw) 110 | newh, neww = oldh * scale, oldw * scale 111 | neww = int(neww + 0.5) 112 | newh = int(newh + 0.5) 113 | return (newh, neww) 114 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==1.13.1 3 | torchvision==0.14.1 4 | packaging 5 | sentencepiece 6 | peft==0.4.0 7 | einops==0.4.1 8 | fastapi==0.100.1 9 | gradio==3.39.0 10 | markdown2==2.4.10 11 | numpy==1.24.2 12 | openai==0.27.8 13 | opencv_python==4.8.0.74 14 | Pillow==9.4.0 15 | pycocotools==2.0.6 16 | ray==2.6.1 17 | Requests==2.31.0 18 | shortuuid==1.0.11 19 | tqdm==4.64.1 20 | transformers==4.31.0 21 | uvicorn==0.23.2 22 | scipy==1.11.2 23 | bitsandbytes==0.41.1 -------------------------------------------------------------------------------- /utils/ade20k_classes.json: -------------------------------------------------------------------------------- 1 | [ 2 | "wall", "building", "sky", "floor", "tree", "ceiling", "road", 3 | "bed", "windowpane", "grass", "cabinet", "sidewalk", 4 | "person", "earth", "door", "table", "mountain", "plant", 5 | "curtain", "chair", "car", "water", "painting", "sofa", 6 | "shelf", "house", "sea", "mirror", "rug", "field", "armchair", 7 | "seat", "fence", "desk", "rock", "wardrobe", "lamp", 8 | "bathtub", "railing", "cushion", "base", "box", "column", 9 | "signboard", "chest of drawers", "counter", "sand", "sink", 10 | "skyscraper", "fireplace", "refrigerator", "grandstand", 11 | "path", "stairs", "runway", "case", "pool table", "pillow", 12 | "screen door", "stairway", "river", "bridge", "bookcase", 13 | "blind", "coffee table", "toilet", "flower", "book", "hill", 14 | "bench", "countertop", "stove", "palm", "kitchen island", 15 | "computer", "swivel chair", "boat", "bar", "arcade machine", 16 | "hovel", "bus", "towel", "light", "truck", "tower", 17 | "chandelier", "awning", "streetlight", "booth", 18 | "television receiver", "airplane", "dirt track", "apparel", 19 | "pole", "land", "bannister", "escalator", "ottoman", "bottle", 20 | "buffet", "poster", "stage", "van", "ship", "fountain", 21 | "conveyer belt", "canopy", "washer", "plaything", 22 | "swimming pool", "stool", "barrel", "basket", "waterfall", 23 | "tent", "bag", "minibike", "cradle", "oven", "ball", "food", 24 | "step", "tank", "trade name", "microwave", "pot", "animal", 25 | "bicycle", "lake", "dishwasher", "screen", "blanket", 26 | "sculpture", "hood", "sconce", "vase", "traffic light", 27 | "tray", "ashcan", "fan", "pier", "crt screen", "plate", 28 | "monitor", "bulletin board", "shower", "radiator", "glass", 29 | "clock", "flag" 30 | ] -------------------------------------------------------------------------------- /utils/cocostuff_classes.txt: -------------------------------------------------------------------------------- 1 | 0: unlabeled 2 | 1: person 3 | 2: bicycle 4 | 3: car 5 | 4: motorcycle 6 | 5: airplane 7 | 6: bus 8 | 7: train 9 | 8: truck 10 | 9: boat 11 | 10: traffic light 12 | 11: fire hydrant 13 | 12: street sign 14 | 13: stop sign 15 | 14: parking meter 16 | 15: bench 17 | 16: bird 18 | 17: cat 19 | 18: dog 20 | 19: horse 21 | 20: sheep 22 | 21: cow 23 | 22: elephant 24 | 23: bear 25 | 24: zebra 26 | 25: giraffe 27 | 26: hat 28 | 27: backpack 29 | 28: umbrella 30 | 29: shoe 31 | 30: eye glasses 32 | 31: handbag 33 | 32: tie 34 | 33: suitcase 35 | 34: frisbee 36 
| 35: skis 37 | 36: snowboard 38 | 37: sports ball 39 | 38: kite 40 | 39: baseball bat 41 | 40: baseball glove 42 | 41: skateboard 43 | 42: surfboard 44 | 43: tennis racket 45 | 44: bottle 46 | 45: plate 47 | 46: wine glass 48 | 47: cup 49 | 48: fork 50 | 49: knife 51 | 50: spoon 52 | 51: bowl 53 | 52: banana 54 | 53: apple 55 | 54: sandwich 56 | 55: orange 57 | 56: broccoli 58 | 57: carrot 59 | 58: hot dog 60 | 59: pizza 61 | 60: donut 62 | 61: cake 63 | 62: chair 64 | 63: couch 65 | 64: potted plant 66 | 65: bed 67 | 66: mirror 68 | 67: dining table 69 | 68: window 70 | 69: desk 71 | 70: toilet 72 | 71: door 73 | 72: tv 74 | 73: laptop 75 | 74: mouse 76 | 75: remote 77 | 76: keyboard 78 | 77: cell phone 79 | 78: microwave 80 | 79: oven 81 | 80: toaster 82 | 81: sink 83 | 82: refrigerator 84 | 83: blender 85 | 84: book 86 | 85: clock 87 | 86: vase 88 | 87: scissors 89 | 88: teddy bear 90 | 89: hair drier 91 | 90: toothbrush 92 | 91: hair brush 93 | 92: banner 94 | 93: blanket 95 | 94: branch 96 | 95: bridge 97 | 96: building-other 98 | 97: bush 99 | 98: cabinet 100 | 99: cage 101 | 100: cardboard 102 | 101: carpet 103 | 102: ceiling-other 104 | 103: ceiling-tile 105 | 104: cloth 106 | 105: clothes 107 | 106: clouds 108 | 107: counter 109 | 108: cupboard 110 | 109: curtain 111 | 110: desk-stuff 112 | 111: dirt 113 | 112: door-stuff 114 | 113: fence 115 | 114: floor-marble 116 | 115: floor-other 117 | 116: floor-stone 118 | 117: floor-tile 119 | 118: floor-wood 120 | 119: flower 121 | 120: fog 122 | 121: food-other 123 | 122: fruit 124 | 123: furniture-other 125 | 124: grass 126 | 125: gravel 127 | 126: ground-other 128 | 127: hill 129 | 128: house 130 | 129: leaves 131 | 130: light 132 | 131: mat 133 | 132: metal 134 | 133: mirror-stuff 135 | 134: moss 136 | 135: mountain 137 | 136: mud 138 | 137: napkin 139 | 138: net 140 | 139: paper 141 | 140: pavement 142 | 141: pillow 143 | 142: plant-other 144 | 143: plastic 145 | 144: platform 146 | 145: playingfield 147 | 146: railing 148 | 147: railroad 149 | 148: river 150 | 149: road 151 | 150: rock 152 | 151: roof 153 | 152: rug 154 | 153: salad 155 | 154: sand 156 | 155: sea 157 | 156: shelf 158 | 157: sky 159 | 158: skyscraper 160 | 159: snow 161 | 160: solid-other 162 | 161: stairs 163 | 162: stone 164 | 163: straw 165 | 164: structural-other 166 | 165: table 167 | 166: tent 168 | 167: textile-other 169 | 168: towel 170 | 169: tree 171 | 170: vegetable 172 | 171: wall-brick 173 | 172: wall-concrete 174 | 173: wall-other 175 | 174: wall-panel 176 | 175: wall-stone 177 | 176: wall-tile 178 | 177: wall-wood 179 | 178: water-other 180 | 179: waterdrops 181 | 180: window-blind 182 | 181: window-other 183 | 182: wood 184 | -------------------------------------------------------------------------------- /utils/conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conversation prompt templates. 
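A quick illustration of what the ADD_COLON_TWO rendering below produces for the Vicuna v1.1 template (assuming sep=" " and sep2="</s>" as defined further down); the expected output is shown in the trailing comments:

from utils.conversation import conv_templates

conv = conv_templates["vicuna_v1.1"].copy()
conv.append_message(conv.roles[0], "Describe the image.")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# A chat between a curious user and an artificial intelligence assistant. The assistant
# gives helpful, detailed, and polite answers to the user's questions. USER: Describe
# the image. ASSISTANT: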
3 | """ 4 | 5 | import dataclasses 6 | from enum import Enum, auto 7 | from typing import Any, List 8 | 9 | 10 | class SeparatorStyle(Enum): 11 | """Different separator style.""" 12 | 13 | ADD_COLON_SINGLE = auto() 14 | ADD_COLON_TWO = auto() 15 | NO_COLON_SINGLE = auto() 16 | BAIZE = auto() 17 | DOLLY = auto() 18 | RWKV = auto() 19 | 20 | 21 | @dataclasses.dataclass 22 | class Conversation: 23 | """A class that keeps all conversation history.""" 24 | 25 | # System prompts 26 | system: str 27 | # Two roles 28 | roles: List[str] 29 | # All messages 30 | messages: List[List[str]] 31 | # Offset of few shot examples 32 | offset: int 33 | # Separator 34 | sep_style: SeparatorStyle 35 | sep: str 36 | sep2: str = None 37 | # Stop criteria (the default one is EOS token) 38 | stop_str: str = None 39 | # Stops generation if meeting any token in this list 40 | stop_token_ids: List[int] = None 41 | 42 | # Used for the state in the gradio servers. 43 | # TODO(lmzheng): refactor this 44 | conv_id: Any = None 45 | skip_next: bool = False 46 | model_name: str = None 47 | 48 | def get_prompt(self): 49 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 50 | ret = self.system + self.sep 51 | for role, message in self.messages: 52 | if message: 53 | ret += role + ": " + message + self.sep 54 | else: 55 | ret += role + ":" 56 | return ret 57 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: 58 | seps = [self.sep, self.sep2] 59 | ret = self.system + seps[0] 60 | for i, (role, message) in enumerate(self.messages): 61 | if message: 62 | ret += role + ": " + message + seps[i % 2] 63 | else: 64 | ret += role + ":" 65 | return ret 66 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 67 | ret = self.system 68 | for role, message in self.messages: 69 | if message: 70 | ret += role + message + self.sep 71 | else: 72 | ret += role 73 | return ret 74 | elif self.sep_style == SeparatorStyle.BAIZE: 75 | ret = self.system + "\n" 76 | for role, message in self.messages: 77 | if message: 78 | ret += role + message + "\n" 79 | else: 80 | ret += role 81 | return ret 82 | elif self.sep_style == SeparatorStyle.DOLLY: 83 | seps = [self.sep, self.sep2] 84 | ret = self.system 85 | for i, (role, message) in enumerate(self.messages): 86 | if message: 87 | ret += role + ":\n" + message + seps[i % 2] 88 | if i % 2 == 1: 89 | ret += "\n\n" 90 | else: 91 | ret += role + ":\n" 92 | return ret 93 | elif self.sep_style == SeparatorStyle.RWKV: 94 | ret = self.system 95 | for i, (role, message) in enumerate(self.messages): 96 | if message: 97 | ret += ( 98 | role 99 | + ": " 100 | + message.replace("\r\n", "\n").replace("\n\n", "\n") 101 | ) 102 | ret += "\n\n" 103 | else: 104 | ret += role + ":" 105 | return ret 106 | else: 107 | raise ValueError(f"Invalid style: {self.sep_style}") 108 | 109 | def append_message(self, role, message): 110 | self.messages.append([role, message]) 111 | 112 | def to_gradio_chatbot(self): 113 | ret = [] 114 | for i, (role, msg) in enumerate(self.messages[self.offset :]): 115 | if i % 2 == 0: 116 | ret.append([msg, None]) 117 | else: 118 | ret[-1][-1] = msg 119 | return ret 120 | 121 | def copy(self): 122 | return Conversation( 123 | system=self.system, 124 | roles=self.roles, 125 | messages=[[x, y] for x, y in self.messages], 126 | offset=self.offset, 127 | sep_style=self.sep_style, 128 | sep=self.sep, 129 | sep2=self.sep2, 130 | stop_str=self.stop_str, 131 | stop_token_ids=self.stop_token_ids, 132 | conv_id=self.conv_id, 133 | model_name=self.model_name, 134 | ) 135 | 136 | def dict(self): 137 | 
return { 138 | "system": self.system, 139 | "roles": self.roles, 140 | "messages": self.messages, 141 | "offset": self.offset, 142 | "conv_id": self.conv_id, 143 | "model_name": self.model_name, 144 | } 145 | 146 | 147 | # A template with one conversation example 148 | conv_one_shot = Conversation( 149 | system="A chat between a curious human and an artificial intelligence assistant. " 150 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 151 | roles=("Human", "Assistant"), 152 | messages=( 153 | ( 154 | "Human", 155 | "What are the key differences between renewable and non-renewable energy sources?", 156 | ), 157 | ( 158 | "Assistant", 159 | "Renewable energy sources are those that can be replenished naturally in a relatively " 160 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 161 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 162 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 163 | "renewable and non-renewable energy sources:\n" 164 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 165 | "energy sources are finite and will eventually run out.\n" 166 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 167 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 168 | "and other negative effects.\n" 169 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 170 | "have lower operational costs than non-renewable sources.\n" 171 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 172 | "locations than non-renewable sources.\n" 173 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 174 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 175 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 176 | "non-renewable sources are not, and their depletion can lead to economic and social instability.", 177 | ), 178 | ), 179 | offset=2, 180 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 181 | sep="\n### ", 182 | stop_str="###", 183 | ) 184 | 185 | 186 | # Vicuna v1.1 template 187 | conv_vicuna_v1_1 = Conversation( 188 | system="A chat between a curious user and an artificial intelligence assistant. " 189 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 190 | roles=("USER", "ASSISTANT"), 191 | messages=(), 192 | offset=0, 193 | sep_style=SeparatorStyle.ADD_COLON_TWO, 194 | sep=" ", 195 | sep2="</s>", 196 | ) 197 | 198 | # Koala default template 199 | conv_koala_v1 = Conversation( 200 | system="BEGINNING OF CONVERSATION:", 201 | roles=("USER", "GPT"), 202 | messages=(), 203 | offset=0, 204 | sep_style=SeparatorStyle.ADD_COLON_TWO, 205 | sep=" ", 206 | sep2="</s>", 207 | ) 208 | 209 | # Dolly V2 default template 210 | conv_dolly = Conversation( 211 | system="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n", 212 | roles=("### Instruction", "### Response"), 213 | messages=(), 214 | offset=0, 215 | sep_style=SeparatorStyle.DOLLY, 216 | sep="\n\n", 217 | sep2="### End", 218 | ) 219 | 220 | # OpenAssistant Pythia default template 221 | conv_oasst = Conversation( 222 | system="", 223 | roles=("<|prompter|>", "<|assistant|>"), 224 | messages=(), 225 | offset=0, 226 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 227 | sep="<|endoftext|>", 228 | ) 229 | 230 | # StableLM Alpha default template 231 | conv_stablelm = Conversation( 232 | system="""<|SYSTEM|># StableLM Tuned (Alpha version) 233 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. 234 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 235 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 236 | - StableLM will refuse to participate in anything that could harm a human. 237 | """, 238 | roles=("<|USER|>", "<|ASSISTANT|>"), 239 | messages=(), 240 | offset=0, 241 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 242 | sep="", 243 | stop_token_ids=[50278, 50279, 50277, 1, 0], 244 | ) 245 | 246 | # Baize default template 247 | conv_baize = Conversation( 248 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. 
Complete the transcript in exactly that format.", 249 | roles=("[|Human|]", "[|AI|]"), 250 | messages=( 251 | ("[|Human|]", "Hello!"), 252 | ("[|AI|]", "Hi!"), 253 | ), 254 | offset=2, 255 | sep_style=SeparatorStyle.BAIZE, 256 | sep="[|Human|]", 257 | stop_str="[|Human|]", 258 | ) 259 | 260 | # RWKV-4-Raven default template 261 | conv_rwkv = Conversation( 262 | system="", 263 | roles=("Bob", "Alice"), 264 | messages=(), 265 | offset=0, 266 | sep_style=SeparatorStyle.RWKV, 267 | sep="", 268 | stop_str="\n\n", 269 | ) 270 | 271 | conv_templates = { 272 | "baize": conv_baize, 273 | "conv_one_shot": conv_one_shot, 274 | "dolly": conv_dolly, 275 | "koala_v1": conv_koala_v1, 276 | "oasst": conv_oasst, 277 | "stablelm": conv_stablelm, 278 | "vicuna_v1.1": conv_vicuna_v1_1, 279 | "rwkv": conv_rwkv, 280 | } 281 | 282 | 283 | def get_default_conv_template(model_name): 284 | model_name = model_name.lower() 285 | if "vicuna" in model_name or "output" in model_name: 286 | return conv_vicuna_v1_1 287 | elif "koala" in model_name: 288 | return conv_koala_v1 289 | elif "dolly-v2" in model_name: 290 | return conv_dolly 291 | elif "oasst" in model_name and "pythia" in model_name: 292 | return conv_oasst 293 | elif "baize" in model_name: 294 | return conv_baize 295 | elif "stablelm" in model_name: 296 | return conv_stablelm 297 | elif "rwkv-4" in model_name: 298 | return conv_rwkv 299 | return conv_one_shot 300 | 301 | 302 | if __name__ == "__main__": 303 | conv = conv_templates["vicuna_v1.1"].copy() 304 | conv.append_message(conv.roles[0], "Hello!") 305 | conv.append_message(conv.roles[1], "Hi!") 306 | conv.append_message(conv.roles[0], "How are you?") 307 | conv.append_message(conv.roles[1], None) 308 | print(conv.get_prompt()) 309 | -------------------------------------------------------------------------------- /utils/data_processing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | 9 | def get_mask_from_json(json_path, img): 10 | try: 11 | with open(json_path, "r") as r: 12 | anno = json.loads(r.read()) 13 | except: 14 | with open(json_path, "r", encoding="cp1252") as r: 15 | anno = json.loads(r.read()) 16 | 17 | inform = anno["shapes"] 18 | comments = anno["text"] 19 | is_sentence = anno["is_sentence"] 20 | 21 | height, width = img.shape[:2] 22 | 23 | ### sort polies by area 24 | area_list = [] 25 | valid_poly_list = [] 26 | for i in inform: 27 | label_id = i["label"] 28 | points = i["points"] 29 | if "flag" == label_id.lower(): ## meaningless deprecated annotations 30 | continue 31 | 32 | tmp_mask = np.zeros((height, width), dtype=np.uint8) 33 | cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1) 34 | cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1) 35 | tmp_area = tmp_mask.sum() 36 | 37 | area_list.append(tmp_area) 38 | valid_poly_list.append(i) 39 | 40 | ### ground-truth mask 41 | sort_index = np.argsort(area_list)[::-1].astype(np.int32) 42 | sort_index = list(sort_index) 43 | sort_inform = [] 44 | for s_idx in sort_index: 45 | sort_inform.append(valid_poly_list[s_idx]) 46 | 47 | mask = np.zeros((height, width), dtype=np.uint8) 48 | for i in sort_inform: 49 | label_id = i["label"] 50 | points = i["points"] 51 | 52 | if "ignore" in label_id.lower(): 53 | label_value = 255 # ignored during evaluation 54 | else: 55 | label_value = 1 # target 56 | 57 | cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1) 58 | 
cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value) 59 | 60 | return mask, comments, is_sentence 61 | 62 | 63 | if __name__ == "__main__": 64 | data_dir = "./train" 65 | vis_dir = "./vis" 66 | 67 | if not os.path.exists(vis_dir): 68 | os.makedirs(vis_dir) 69 | 70 | json_path_list = sorted(glob.glob(data_dir + "/*.json")) 71 | for json_path in json_path_list: 72 | img_path = json_path.replace(".json", ".jpg") 73 | img = cv2.imread(img_path)[:, :, ::-1] 74 | 75 | # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluation. 76 | mask, comments, is_sentence = get_mask_from_json(json_path, img) 77 | 78 | ## visualization. Green for target, and red for ignore. 79 | valid_mask = (mask == 1).astype(np.float32)[:, :, None] 80 | ignore_mask = (mask == 255).astype(np.float32)[:, :, None] 81 | vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + ( 82 | (np.array([0, 255, 0]) * 0.6 + img * 0.4) * valid_mask 83 | + (np.array([255, 0, 0]) * 0.6 + img * 0.4) * ignore_mask 84 | ) 85 | vis_img = np.concatenate([img, vis_img], 1) 86 | vis_path = os.path.join( 87 | vis_dir, json_path.split("/")[-1].replace(".json", ".jpg") 88 | ) 89 | cv2.imwrite(vis_path, vis_img[:, :, ::-1]) 90 | print("Visualization has been saved to: ", vis_path) 91 | -------------------------------------------------------------------------------- /utils/grefcoco.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import copy 3 | import io 4 | import logging 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import pycocotools.mask as mask_util 10 | from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes 11 | from detectron2.utils.file_io import PathManager 12 | from fvcore.common.timer import Timer 13 | from PIL import Image 14 | 15 | """ 16 | This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format". 
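A minimal sketch of the ReasonSeg-style annotation that get_mask_from_json above consumes; the field names come from the parser, while the label strings, polygon points, and query text are invented for illustration:

import json

import numpy as np
from utils.data_processing import get_mask_from_json

anno = {
    "text": ["the object used to tell the time"],
    "is_sentence": True,
    "shapes": [
        {"label": "target", "points": [[10, 10], [60, 10], [60, 40], [10, 40]]},
        {"label": "ignore_region", "points": [[70, 70], [90, 70], [90, 90], [70, 90]]},
    ],
}
with open("/tmp/example.json", "w") as f:
    json.dump(anno, f)

img = np.zeros((100, 100, 3), dtype=np.uint8)
mask, comments, is_sentence = get_mask_from_json("/tmp/example.json", img)
print(mask.shape, np.unique(mask))  # (100, 100) with values {0, 1, 255}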
17 | """ 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | __all__ = ["load_refcoco_json"] 23 | 24 | 25 | def load_grefcoco_json( 26 | refer_root, 27 | dataset_name, 28 | splitby, 29 | split, 30 | image_root, 31 | extra_annotation_keys=None, 32 | extra_refer_keys=None, 33 | ): 34 | if dataset_name == "refcocop": 35 | dataset_name = "refcoco+" 36 | if dataset_name == "refcoco" or dataset_name == "refcoco+": 37 | splitby == "unc" 38 | if dataset_name == "refcocog": 39 | assert splitby == "umd" or splitby == "google" 40 | 41 | dataset_id = "_".join([dataset_name, splitby, split]) 42 | 43 | from .grefer import G_REFER 44 | 45 | logger.info("Loading dataset {} ({}-{}) ...".format(dataset_name, splitby, split)) 46 | logger.info("Refcoco root: {}".format(refer_root)) 47 | timer = Timer() 48 | refer_root = PathManager.get_local_path(refer_root) 49 | with contextlib.redirect_stdout(io.StringIO()): 50 | refer_api = G_REFER(data_root=refer_root, dataset=dataset_name, splitBy=splitby) 51 | if timer.seconds() > 1: 52 | logger.info( 53 | "Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds()) 54 | ) 55 | 56 | ref_ids = refer_api.getRefIds(split=split) 57 | img_ids = refer_api.getImgIds(ref_ids) 58 | refs = refer_api.loadRefs(ref_ids) 59 | imgs = [refer_api.loadImgs(ref["image_id"])[0] for ref in refs] 60 | anns = [refer_api.loadAnns(ref["ann_id"]) for ref in refs] 61 | imgs_refs_anns = list(zip(imgs, refs, anns)) 62 | 63 | logger.info( 64 | "Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format( 65 | len(img_ids), len(ref_ids), dataset_id 66 | ) 67 | ) 68 | 69 | dataset_dicts = [] 70 | 71 | ann_keys = ["iscrowd", "bbox", "category_id"] + (extra_annotation_keys or []) 72 | ref_keys = ["raw", "sent_id"] + (extra_refer_keys or []) 73 | 74 | ann_lib = {} 75 | 76 | NT_count = 0 77 | MT_count = 0 78 | 79 | for img_dict, ref_dict, anno_dicts in imgs_refs_anns: 80 | record = {} 81 | record["source"] = "grefcoco" 82 | record["file_name"] = os.path.join(image_root, img_dict["file_name"]) 83 | record["height"] = img_dict["height"] 84 | record["width"] = img_dict["width"] 85 | image_id = record["image_id"] = img_dict["id"] 86 | 87 | # Check that information of image, ann and ref match each other 88 | # This fails only when the data parsing logic or the annotation file is buggy. 
89 | assert ref_dict["image_id"] == image_id 90 | assert ref_dict["split"] == split 91 | if not isinstance(ref_dict["ann_id"], list): 92 | ref_dict["ann_id"] = [ref_dict["ann_id"]] 93 | 94 | # No target samples 95 | if None in anno_dicts: 96 | assert anno_dicts == [None] 97 | assert ref_dict["ann_id"] == [-1] 98 | record["empty"] = True 99 | obj = {key: None for key in ann_keys if key in ann_keys} 100 | obj["bbox_mode"] = BoxMode.XYWH_ABS 101 | obj["empty"] = True 102 | obj = [obj] 103 | 104 | # Multi target samples 105 | else: 106 | record["empty"] = False 107 | obj = [] 108 | for anno_dict in anno_dicts: 109 | ann_id = anno_dict["id"] 110 | if anno_dict["iscrowd"]: 111 | continue 112 | assert anno_dict["image_id"] == image_id 113 | assert ann_id in ref_dict["ann_id"] 114 | 115 | if ann_id in ann_lib: 116 | ann = ann_lib[ann_id] 117 | else: 118 | ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict} 119 | ann["bbox_mode"] = BoxMode.XYWH_ABS 120 | ann["empty"] = False 121 | 122 | segm = anno_dict.get("segmentation", None) 123 | assert segm # either list[list[float]] or dict(RLE) 124 | if isinstance(segm, dict): 125 | if isinstance(segm["counts"], list): 126 | # convert to compressed RLE 127 | segm = mask_util.frPyObjects(segm, *segm["size"]) 128 | else: 129 | # filter out invalid polygons (< 3 points) 130 | segm = [ 131 | poly 132 | for poly in segm 133 | if len(poly) % 2 == 0 and len(poly) >= 6 134 | ] 135 | if len(segm) == 0: 136 | num_instances_without_valid_segmentation += 1 137 | continue # ignore this instance 138 | ann["segmentation"] = segm 139 | ann_lib[ann_id] = ann 140 | 141 | obj.append(ann) 142 | 143 | record["annotations"] = obj 144 | 145 | # Process referring expressions 146 | sents = ref_dict["sentences"] 147 | for sent in sents: 148 | ref_record = record.copy() 149 | ref = {key: sent[key] for key in ref_keys if key in sent} 150 | ref["ref_id"] = ref_dict["ref_id"] 151 | ref_record["sentence"] = ref 152 | dataset_dicts.append(ref_record) 153 | # if ref_record['empty']: 154 | # NT_count += 1 155 | # else: 156 | # MT_count += 1 157 | 158 | # logger.info("NT samples: %d, MT samples: %d", NT_count, MT_count) 159 | 160 | # Debug mode 161 | # return dataset_dicts[:100] 162 | 163 | return dataset_dicts 164 | 165 | 166 | if __name__ == "__main__": 167 | """ 168 | Test the COCO json dataset loader. 
169 | 170 | Usage: 171 | python -m detectron2.data.datasets.coco \ 172 | path/to/json path/to/image_root dataset_name 173 | 174 | "dataset_name" can be "coco_2014_minival_100", or other 175 | pre-registered ones 176 | """ 177 | import sys 178 | 179 | import detectron2.data.datasets # noqa # add pre-defined metadata 180 | from detectron2.utils.logger import setup_logger 181 | from detectron2.utils.visualizer import Visualizer 182 | 183 | REFCOCO_PATH = "/mnt/lustre/hhding/code/ReLA/datasets" 184 | COCO_TRAIN_2014_IMAGE_ROOT = "/mnt/lustre/hhding/code/ReLA/datasets/images" 185 | REFCOCO_DATASET = "grefcoco" 186 | REFCOCO_SPLITBY = "unc" 187 | REFCOCO_SPLIT = "train" 188 | 189 | logger = setup_logger(name=__name__) 190 | 191 | dicts = load_grefcoco_json( 192 | REFCOCO_PATH, 193 | REFCOCO_DATASET, 194 | REFCOCO_SPLITBY, 195 | REFCOCO_SPLIT, 196 | COCO_TRAIN_2014_IMAGE_ROOT, 197 | ) 198 | logger.info("Done loading {} samples.".format(len(dicts))) 199 | -------------------------------------------------------------------------------- /utils/reason_seg_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import random 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from transformers import CLIPImageProcessor 11 | 12 | from model.llava import conversation as conversation_lib 13 | from model.segment_anything.utils.transforms import ResizeLongestSide 14 | 15 | from .data_processing import get_mask_from_json 16 | from .utils import (ANSWER_LIST, DEFAULT_IMAGE_TOKEN, 17 | EXPLANATORY_QUESTION_LIST, LONG_QUESTION_LIST, 18 | SHORT_QUESTION_LIST) 19 | 20 | 21 | class ReasonSegDataset(torch.utils.data.Dataset): 22 | pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) 23 | pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) 24 | img_size = 1024 25 | ignore_label = 255 26 | 27 | def __init__( 28 | self, 29 | base_image_dir, 30 | tokenizer, 31 | vision_tower, 32 | samples_per_epoch=500 * 8 * 2 * 10, 33 | precision: str = "fp32", 34 | image_size: int = 224, 35 | num_classes_per_sample: int = 3, 36 | exclude_val=False, 37 | reason_seg_data="ReasonSeg|train", 38 | explanatory=0.1, 39 | ): 40 | self.exclude_val = exclude_val 41 | self.reason_seg_data = reason_seg_data 42 | self.samples_per_epoch = samples_per_epoch 43 | self.explanatory = explanatory 44 | self.num_classes_per_sample = num_classes_per_sample 45 | 46 | self.base_image_dir = base_image_dir 47 | self.image_size = image_size 48 | self.tokenizer = tokenizer 49 | self.precision = precision 50 | self.transform = ResizeLongestSide(image_size) 51 | self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) 52 | 53 | self.short_question_list = SHORT_QUESTION_LIST 54 | self.long_question_list = LONG_QUESTION_LIST 55 | self.answer_list = ANSWER_LIST 56 | 57 | reason_seg_data, splits = reason_seg_data.split("|") 58 | splits = splits.split("_") 59 | images = [] 60 | for split in splits: 61 | images_split = glob.glob( 62 | os.path.join( 63 | base_image_dir, "reason_seg", reason_seg_data, split, "*.jpg" 64 | ) 65 | ) 66 | images.extend(images_split) 67 | jsons = [path.replace(".jpg", ".json") for path in images] 68 | self.reason_seg_data = (images, jsons) 69 | 70 | print("number of reason_seg samples: ", len(images)) 71 | 72 | if explanatory != -1: 73 | self.explanatory_question_list = EXPLANATORY_QUESTION_LIST 74 | self.img_to_explanation = {} 75 | with open( 76 | 
os.path.join( 77 | base_image_dir, 78 | "reason_seg", 79 | reason_seg_data, 80 | "explanatory", 81 | "train.json", 82 | ) 83 | ) as f: 84 | items = json.load(f) 85 | for item in items: 86 | img_name = item["image"] 87 | self.img_to_explanation[img_name] = { 88 | "query": item["query"], 89 | "outputs": item["outputs"], 90 | } 91 | 92 | print("len(self.img_to_explanation): ", len(self.img_to_explanation)) 93 | 94 | def __len__(self): 95 | return self.samples_per_epoch 96 | 97 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 98 | """Normalize pixel values and pad to a square input.""" 99 | # Normalize colors 100 | x = (x - self.pixel_mean) / self.pixel_std 101 | 102 | # Pad 103 | h, w = x.shape[-2:] 104 | padh = self.img_size - h 105 | padw = self.img_size - w 106 | x = F.pad(x, (0, padw, 0, padh)) 107 | return x 108 | 109 | def __getitem__(self, idx): 110 | images, jsons = self.reason_seg_data 111 | idx = random.randint(0, len(images) - 1) 112 | image_path = images[idx] 113 | json_path = jsons[idx] 114 | 115 | image = cv2.imread(image_path) 116 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 117 | ori_size = image.shape[:2] 118 | # preprocess image for clip 119 | image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[ 120 | "pixel_values" 121 | ][0] 122 | 123 | mask, sents, is_sentence = get_mask_from_json(json_path, image) 124 | if len(sents) >= self.num_classes_per_sample: 125 | sampled_inds = np.random.choice( 126 | list(range(len(sents))), size=self.num_classes_per_sample, replace=False 127 | ) 128 | else: 129 | sampled_inds = list(range(len(sents))) 130 | sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist() 131 | sampled_masks = [ 132 | (mask == 1).astype(np.float32) for _ in range(len(sampled_inds)) 133 | ] 134 | 135 | image = self.transform.apply_image(image) # preprocess image for sam 136 | resize = image.shape[:2] 137 | 138 | image_name = image_path.split("/")[-1] 139 | if self.explanatory != -1 and image_name in self.img_to_explanation: 140 | if random.random() < self.explanatory: 141 | choice = 2 142 | else: 143 | choice = random.randint(0, 1) 144 | 145 | questions = [] 146 | answers = [] 147 | for text in sampled_sents: 148 | if is_sentence: 149 | question_template = random.choice(self.long_question_list) 150 | questions.append(question_template.format(sent=text)) 151 | else: 152 | question_template = random.choice(self.short_question_list) 153 | questions.append(question_template.format(class_name=text.lower())) 154 | 155 | # add explanation if applicable 156 | img_name = image_path.split("/")[-1] 157 | if self.explanatory != -1 and img_name in self.img_to_explanation: 158 | if choice == 0: # [SEG] token 159 | answers.append(random.choice(self.answer_list)) 160 | elif choice == 1: # [SEG] token + text answer 161 | image_name = image_path.split("/")[-1] 162 | answer = self.img_to_explanation[image_name]["outputs"] 163 | answer = random.choice(self.answer_list) + " {}".format(answer) 164 | questions[-1] = ( 165 | DEFAULT_IMAGE_TOKEN 166 | + "\n" 167 | + text 168 | + " {}".format(random.choice(self.explanatory_question_list)) 169 | ) 170 | answers.append(answer) 171 | elif choice == 2: # vanilla text answer 172 | image_name = image_path.split("/")[-1] 173 | answer = self.img_to_explanation[image_name]["outputs"] 174 | questions[-1] = DEFAULT_IMAGE_TOKEN + "\n" + text 175 | answers.append(answer) 176 | else: 177 | raise ValueError("Not implemented yet.") 178 | else: 179 | answers.append(random.choice(self.answer_list)) 180 | 181 | 
conversations = [] 182 | conv = conversation_lib.default_conversation.copy() 183 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 184 | 185 | i = 0 186 | while i < len(questions): 187 | conv.messages = [] 188 | conv.append_message(conv.roles[0], questions[i]) 189 | conv.append_message(conv.roles[1], answers[i]) 190 | conversations.append(conv.get_prompt()) 191 | i += 1 192 | 193 | image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous()) 194 | 195 | image_name = image_path.split("/")[-1] 196 | if ( 197 | self.explanatory != -1 198 | and image_name in self.img_to_explanation 199 | and choice == 2 200 | ): 201 | masks = torch.rand(0, *ori_size) 202 | label = torch.ones(ori_size) * self.ignore_label 203 | else: 204 | masks = np.stack(sampled_masks, axis=0) 205 | masks = torch.from_numpy(masks) 206 | label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label 207 | 208 | return ( 209 | image_path, 210 | image, 211 | image_clip, 212 | conversations, 213 | masks, 214 | label, 215 | resize, 216 | questions, 217 | sampled_sents, 218 | ) 219 | -------------------------------------------------------------------------------- /utils/refer_seg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from pycocotools import mask 9 | from transformers import CLIPImageProcessor 10 | 11 | from model.llava import conversation as conversation_lib 12 | from model.segment_anything.utils.transforms import ResizeLongestSide 13 | 14 | from .grefer import G_REFER 15 | from .refer import REFER 16 | from .utils import ANSWER_LIST, SHORT_QUESTION_LIST 17 | 18 | 19 | class ReferSegDataset(torch.utils.data.Dataset): 20 | pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) 21 | pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) 22 | img_size = 1024 23 | ignore_label = 255 24 | 25 | def __init__( 26 | self, 27 | base_image_dir, 28 | tokenizer, 29 | vision_tower, 30 | samples_per_epoch=500 * 8 * 2 * 10, 31 | precision: str = "fp32", 32 | image_size: int = 224, 33 | num_classes_per_sample: int = 3, 34 | exclude_val=False, 35 | refer_seg_data="refclef||refcoco||refcoco+||refcocog", 36 | ): 37 | self.exclude_val = exclude_val 38 | self.samples_per_epoch = samples_per_epoch 39 | self.num_classes_per_sample = num_classes_per_sample 40 | 41 | self.base_image_dir = base_image_dir 42 | self.image_size = image_size 43 | self.tokenizer = tokenizer 44 | self.precision = precision 45 | self.transform = ResizeLongestSide(image_size) 46 | self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) 47 | 48 | self.short_question_list = SHORT_QUESTION_LIST 49 | self.answer_list = ANSWER_LIST 50 | 51 | DATA_DIR = os.path.join(base_image_dir, "refer_seg") 52 | self.refer_seg_ds_list = refer_seg_data.split( 53 | "||" 54 | ) # ['refclef', 'refcoco', 'refcoco+', 'refcocog'] 55 | self.refer_seg_data = {} 56 | for ds in self.refer_seg_ds_list: 57 | if ds == "refcocog": 58 | splitBy = "umd" 59 | else: 60 | splitBy = "unc" 61 | 62 | if ds == "grefcoco": 63 | refer_api = G_REFER(DATA_DIR, ds, splitBy) 64 | else: 65 | refer_api = REFER(DATA_DIR, ds, splitBy) 66 | ref_ids_train = refer_api.getRefIds(split="train") 67 | images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train) 68 | refs_train = refer_api.loadRefs(ref_ids=ref_ids_train) 69 | 70 | refer_seg_ds = {} 71 | refer_seg_ds["images"] = [] 72 
| loaded_images = refer_api.loadImgs(image_ids=images_ids_train) 73 | 74 | for item in loaded_images: 75 | item = item.copy() 76 | if ds == "refclef": 77 | item["file_name"] = os.path.join( 78 | DATA_DIR, "images/saiapr_tc-12", item["file_name"] 79 | ) 80 | else: 81 | item["file_name"] = os.path.join( 82 | DATA_DIR, "images/mscoco/images/train2014", item["file_name"] 83 | ) 84 | refer_seg_ds["images"].append(item) 85 | refer_seg_ds["annotations"] = refer_api.Anns # anns_train 86 | 87 | print( 88 | "dataset {} (refs {}) (train split) has {} images and {} annotations.".format( 89 | ds, 90 | splitBy, 91 | len(refer_seg_ds["images"]), 92 | len(refer_seg_ds["annotations"]), 93 | ) 94 | ) 95 | 96 | img2refs = {} 97 | for ref in refs_train: 98 | image_id = ref["image_id"] 99 | img2refs[image_id] = img2refs.get(image_id, []) + [ 100 | ref, 101 | ] 102 | refer_seg_ds["img2refs"] = img2refs 103 | self.refer_seg_data[ds] = refer_seg_ds 104 | 105 | def __len__(self): 106 | return self.samples_per_epoch 107 | 108 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 109 | """Normalize pixel values and pad to a square input.""" 110 | # Normalize colors 111 | x = (x - self.pixel_mean) / self.pixel_std 112 | 113 | # Pad 114 | h, w = x.shape[-2:] 115 | padh = self.img_size - h 116 | padw = self.img_size - w 117 | x = F.pad(x, (0, padw, 0, padh)) 118 | return x 119 | 120 | def __getitem__(self, idx): 121 | ds = random.randint(0, len(self.refer_seg_ds_list) - 1) 122 | ds = self.refer_seg_ds_list[ds] 123 | refer_seg_ds = self.refer_seg_data[ds] 124 | images = refer_seg_ds["images"] 125 | annotations = refer_seg_ds["annotations"] 126 | img2refs = refer_seg_ds["img2refs"] 127 | idx = random.randint(0, len(images) - 1) 128 | image_info = images[idx] 129 | image_path = image_info["file_name"] 130 | image_id = image_info["id"] 131 | refs = img2refs[image_id] 132 | if len(refs) == 0: 133 | return self.__getitem__(0) 134 | 135 | sents = [] 136 | ann_ids = [] 137 | for ref in refs: 138 | for sent in ref["sentences"]: 139 | text = sent["sent"] 140 | sents.append(text) 141 | ann_ids.append(ref["ann_id"]) 142 | if len(sents) >= self.num_classes_per_sample: 143 | sampled_inds = np.random.choice( 144 | list(range(len(sents))), size=self.num_classes_per_sample, replace=False 145 | ) 146 | else: 147 | sampled_inds = list(range(len(sents))) 148 | sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist() 149 | # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist() 150 | sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds] 151 | sampled_classes = sampled_sents 152 | image = cv2.imread(image_path) 153 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 154 | 155 | # preprocess image for clip 156 | image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[ 157 | "pixel_values" 158 | ][0] 159 | 160 | image = self.transform.apply_image(image) # preprocess image for sam 161 | resize = image.shape[:2] 162 | 163 | questions = [] 164 | answers = [] 165 | for text in sampled_classes: 166 | text = text.strip() 167 | assert len(text.split("||")) == 1 168 | question_template = random.choice(self.short_question_list) 169 | questions.append(question_template.format(class_name=text.lower())) 170 | answers.append(random.choice(self.answer_list)) 171 | 172 | conversations = [] 173 | conv = conversation_lib.default_conversation.copy() 174 | 175 | i = 0 176 | while i < len(questions): 177 | conv.messages = [] 178 | conv.append_message(conv.roles[0], questions[i]) 179 | 
conv.append_message(conv.roles[1], answers[i])
180 |             conversations.append(conv.get_prompt())
181 |             i += 1
182 | 
183 |         image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
184 | 
185 |         flag = False
186 |         masks = []
187 |         for ann_id in sampled_ann_ids:
188 |             if isinstance(ann_id, list):
189 |                 flag = True
190 |                 if -1 in ann_id:
191 |                     assert len(ann_id) == 1
192 |                     m = np.zeros((image_info["height"], image_info["width"])).astype(
193 |                         np.uint8
194 |                     )
195 |                 else:
196 |                     m_final = np.zeros(
197 |                         (image_info["height"], image_info["width"])
198 |                     ).astype(np.uint8)
199 |                     for ann_id_i in ann_id:
200 |                         ann = annotations[ann_id_i]
201 | 
202 |                         if len(ann["segmentation"]) == 0:
203 |                             m = np.zeros(
204 |                                 (image_info["height"], image_info["width"])
205 |                             ).astype(np.uint8)
206 |                         else:
207 |                             if isinstance(ann["segmentation"][0], list):  # polygon
208 |                                 rle = mask.frPyObjects(
209 |                                     ann["segmentation"],
210 |                                     image_info["height"],
211 |                                     image_info["width"],
212 |                                 )
213 |                             else:
214 |                                 rle = ann["segmentation"]
215 |                                 for i in range(len(rle)):
216 |                                     if not isinstance(rle[i]["counts"], bytes):
217 |                                         rle[i]["counts"] = rle[i]["counts"].encode()
218 |                             m = mask.decode(rle)
219 |                             m = np.sum(
220 |                                 m, axis=2
221 |                             )  # sometimes there are multiple binary maps (corresponding to multiple segments)
222 |                             m = m.astype(np.uint8)  # convert to np.uint8
223 |                         m_final = m_final | m
224 |                     m = m_final
225 |                 masks.append(m)
226 |                 continue
227 | 
228 |             ann = annotations[ann_id]
229 | 
230 |             if len(ann["segmentation"]) == 0:
231 |                 m = np.zeros((image_info["height"], image_info["width"])).astype(
232 |                     np.uint8
233 |                 )
234 |                 masks.append(m)
235 |                 continue
236 | 
237 |             if isinstance(ann["segmentation"][0], list):  # polygon
238 |                 rle = mask.frPyObjects(
239 |                     ann["segmentation"], image_info["height"], image_info["width"]
240 |                 )
241 |             else:
242 |                 rle = ann["segmentation"]
243 |                 for i in range(len(rle)):
244 |                     if not isinstance(rle[i]["counts"], bytes):
245 |                         rle[i]["counts"] = rle[i]["counts"].encode()
246 |             m = mask.decode(rle)
247 |             m = np.sum(
248 |                 m, axis=2
249 |             )  # sometimes there are multiple binary maps (corresponding to multiple segments)
250 |             m = m.astype(np.uint8)  # convert to np.uint8
251 |             masks.append(m)
252 | 
253 |         masks = np.stack(masks, axis=0)
254 | 
255 |         # if ds == 'grefcoco' and flag:
256 |         #     import shutil
257 |         #     image_name = image_path.split("/")[-1]
258 |         #     save_dir = os.path.join("/group/30042/xlai/LISA_refactor_final/debug", image_name.split(".")[0])
259 |         #     os.makedirs(save_dir, exist_ok=True)
260 |         #     shutil.copy(image_path, save_dir)
261 |         #     for i in range(masks.shape[0]):
262 |         #         cv2.imwrite(os.path.join(save_dir, "{}_{}_{}.jpg".format(image_name, i, sampled_classes[i])), masks[i].astype(np.int32) * 100)
263 | 
264 |         masks = torch.from_numpy(masks)
265 |         label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
266 | 
267 |         return (
268 |             image_path,
269 |             image,
270 |             image_clip,
271 |             conversations,
272 |             masks,
273 |             label,
274 |             resize,
275 |             questions,
276 |             sampled_classes,
277 |         )
278 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.distributed as dist
6 | 
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | 
14 | SHORT_QUESTION_LIST = [
15 | DEFAULT_IMAGE_TOKEN + "\n" + "Can you segment the {class_name} in this image?", 16 | DEFAULT_IMAGE_TOKEN + "\n" + "Please segment the {class_name} in this image.", 17 | DEFAULT_IMAGE_TOKEN 18 | + "\n" 19 | + "What is {class_name} in this image? Please respond with segmentation mask.", 20 | DEFAULT_IMAGE_TOKEN 21 | + "\n" 22 | + "What is {class_name} in this image? Please output segmentation mask.", 23 | ] 24 | 25 | LONG_QUESTION_LIST = [ 26 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please respond with segmentation mask.", 27 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please output segmentation mask.", 28 | ] 29 | 30 | EXPLANATORY_QUESTION_LIST = [ 31 | "Please output segmentation mask and explain why.", 32 | "Please output segmentation mask and explain the reason.", 33 | "Please output segmentation mask and give some explanation.", 34 | ] 35 | 36 | ANSWER_LIST = [ 37 | "It is [SEG].", 38 | "Sure, [SEG].", 39 | "Sure, it is [SEG].", 40 | "Sure, the segmentation result is [SEG].", 41 | "[SEG].", 42 | ] 43 | 44 | 45 | class Summary(Enum): 46 | NONE = 0 47 | AVERAGE = 1 48 | SUM = 2 49 | COUNT = 3 50 | 51 | 52 | class AverageMeter(object): 53 | """Computes and stores the average and current value""" 54 | 55 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 56 | self.name = name 57 | self.fmt = fmt 58 | self.summary_type = summary_type 59 | self.reset() 60 | 61 | def reset(self): 62 | self.val = 0 63 | self.avg = 0 64 | self.sum = 0 65 | self.count = 0 66 | 67 | def update(self, val, n=1): 68 | self.val = val 69 | self.sum += val * n 70 | self.count += n 71 | self.avg = self.sum / self.count 72 | 73 | def all_reduce(self): 74 | device = "cuda" if torch.cuda.is_available() else "cpu" 75 | if isinstance(self.sum, np.ndarray): 76 | total = torch.tensor( 77 | self.sum.tolist() 78 | + [ 79 | self.count, 80 | ], 81 | dtype=torch.float32, 82 | device=device, 83 | ) 84 | else: 85 | total = torch.tensor( 86 | [self.sum, self.count], dtype=torch.float32, device=device 87 | ) 88 | 89 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 90 | if total.shape[0] > 2: 91 | self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item() 92 | else: 93 | self.sum, self.count = total.tolist() 94 | self.avg = self.sum / (self.count + 1e-5) 95 | 96 | def __str__(self): 97 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 98 | return fmtstr.format(**self.__dict__) 99 | 100 | def summary(self): 101 | fmtstr = "" 102 | if self.summary_type is Summary.NONE: 103 | fmtstr = "" 104 | elif self.summary_type is Summary.AVERAGE: 105 | fmtstr = "{name} {avg:.3f}" 106 | elif self.summary_type is Summary.SUM: 107 | fmtstr = "{name} {sum:.3f}" 108 | elif self.summary_type is Summary.COUNT: 109 | fmtstr = "{name} {count:.3f}" 110 | else: 111 | raise ValueError("invalid summary type %r" % self.summary_type) 112 | 113 | return fmtstr.format(**self.__dict__) 114 | 115 | 116 | def intersectionAndUnionGPU(output, target, K, ignore_index=255): 117 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
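    # Pixels whose target equals ignore_index are copied into the prediction as well,
    # so they fall outside the [0, K - 1] histogram range (assuming ignore_index >= K)
    # and drop out of the intersection, prediction, and target counts alike.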
118 |     assert output.dim() in [1, 2, 3]
119 |     assert output.shape == target.shape
120 |     output = output.view(-1)
121 |     target = target.view(-1)
122 |     output[target == ignore_index] = ignore_index
123 |     intersection = output[output == target]
124 |     area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
125 |     area_output = torch.histc(output, bins=K, min=0, max=K - 1)
126 |     area_target = torch.histc(target, bins=K, min=0, max=K - 1)
127 |     area_union = area_output + area_target - area_intersection
128 |     return area_intersection, area_union, area_target
129 | 
130 | 
131 | class ProgressMeter(object):
132 |     def __init__(self, num_batches, meters, prefix=""):
133 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
134 |         self.meters = meters
135 |         self.prefix = prefix
136 | 
137 |     def display(self, batch):
138 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
139 |         entries += [str(meter) for meter in self.meters]
140 |         print("\t".join(entries))
141 | 
142 |     def display_summary(self):
143 |         entries = [" *"]
144 |         entries += [meter.summary() for meter in self.meters]
145 |         print(" ".join(entries))
146 | 
147 |     def _get_batch_fmtstr(self, num_batches):
148 |         num_digits = len(str(num_batches // 1))
149 |         fmt = "{:" + str(num_digits) + "d}"
150 |         return "[" + fmt + "/" + fmt.format(num_batches) + "]"
151 | 
152 | 
153 | def dict_to_cuda(input_dict):
154 |     for k, v in input_dict.items():
155 |         if isinstance(input_dict[k], torch.Tensor):
156 |             input_dict[k] = v.cuda(non_blocking=True)
157 |         elif (
158 |             isinstance(input_dict[k], list)
159 |             and len(input_dict[k]) > 0
160 |             and isinstance(input_dict[k][0], torch.Tensor)
161 |         ):
162 |             input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
163 |     return input_dict
164 | 
--------------------------------------------------------------------------------
/utils/vqa_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | 
5 | import cv2
6 | import torch
7 | import torch.nn.functional as F
8 | from transformers import CLIPImageProcessor
9 | 
10 | from model.llava import conversation as conversation_lib
11 | from model.segment_anything.utils.transforms import ResizeLongestSide
12 | 
13 | from .utils import DEFAULT_IMAGE_TOKEN
14 | 
15 | 
16 | def preprocess_multimodal(source, mm_use_im_start_end):
17 |     for sentence in source:
18 |         if DEFAULT_IMAGE_TOKEN in sentence["value"]:
19 |             sentence["value"] = (
20 |                 sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
21 |             )
22 |             sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
23 |             sentence["value"] = sentence["value"].strip()
24 |             if "mmtag" in conversation_lib.default_conversation.version:
25 |                 sentence["value"] = sentence["value"].replace(
26 |                     DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
27 |                 )
28 |     return source
29 | 
30 | 
31 | class VQADataset(torch.utils.data.Dataset):
32 |     pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
33 |     pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
34 |     img_size = 1024
35 |     ignore_label = 255
36 | 
37 |     def __init__(
38 |         self,
39 |         base_image_dir,
40 |         tokenizer,
41 |         vision_tower,
42 |         samples_per_epoch=500 * 8 * 2 * 10,
43 |         precision: str = "fp32",
44 |         image_size: int = 224,
45 |         num_classes_per_sample: int = 3,
46 |         exclude_val=False,
47 |         vqa_data="llava_instruct_150k",
48 |     ):
49 |         self.exclude_val = exclude_val
50 |         self.samples_per_epoch = samples_per_epoch
51 |         self.num_classes_per_sample = num_classes_per_sample
52 | 
53 |         self.base_image_dir = base_image_dir
54 |         self.image_size = image_size
55 |         self.tokenizer = tokenizer
56 |         self.precision = precision
57 |         self.transform = ResizeLongestSide(image_size)
58 |         self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
59 | 
60 |         DATA_DIR = os.path.join(base_image_dir, "llava_dataset")
61 |         self.vqa_image_root = os.path.join(base_image_dir, "coco/train2017")
62 |         with open(os.path.join(DATA_DIR, "{}.json".format(vqa_data))) as f:
63 |             vqa_data = json.load(f)
64 |         self.vqa_data = vqa_data
65 | 
66 |         print("vqa_data: ", len(self.vqa_data))
67 | 
68 |     def __len__(self):
69 |         return self.samples_per_epoch
70 | 
71 |     def preprocess(self, x: torch.Tensor) -> torch.Tensor:
72 |         """Normalize pixel values and pad to a square input."""
73 |         # Normalize colors
74 |         x = (x - self.pixel_mean) / self.pixel_std
75 | 
76 |         # Pad
77 |         h, w = x.shape[-2:]
78 |         padh = self.img_size - h
79 |         padw = self.img_size - w
80 |         x = F.pad(x, (0, padw, 0, padh))
81 |         return x
82 | 
83 |     def __getitem__(self, idx):
84 |         idx = random.randint(0, len(self.vqa_data) - 1)
85 |         item = self.vqa_data[idx]
86 |         image_path = os.path.join(self.vqa_image_root, item["image"])
87 |         image = cv2.imread(image_path)
88 |         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
89 |         ori_size = image.shape[:2]
90 |         image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
91 |             "pixel_values"
92 |         ][
93 |             0
94 |         ]  # preprocess image for clip
95 | 
96 |         image = self.transform.apply_image(image)  # preprocess image for sam
97 |         resize = image.shape[:2]
98 | 
99 |         conv = conversation_lib.default_conversation.copy()
100 |         source = item["conversations"]
101 |         source = preprocess_multimodal(
102 |             source,
103 |             mm_use_im_start_end=conv.sep_style == conversation_lib.SeparatorStyle.TWO,
104 |         )
105 |         roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
106 |         conversations = []
107 |         if roles[source[0]["from"]] != conv.roles[0]:
108 |             # Skip the first one if it is not from human
109 |             source = source[1:]
110 |         conv.messages = []
111 |         for j, sentence in enumerate(source):
112 |             role = roles[sentence["from"]]
113 |             assert role == conv.roles[j % 2], f"{j}"
114 |             conv.append_message(role, sentence["value"])
115 |         conversations.append(conv.get_prompt())
116 | 
117 |         questions = conversations
118 |         sampled_classes = conversations
119 | 
120 |         image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
121 | 
122 |         masks = torch.rand(0, *ori_size)
123 |         label = torch.ones(ori_size) * self.ignore_label
124 | 
125 |         return (
126 |             image_path,
127 |             image,
128 |             image_clip,
129 |             conversations,
130 |             masks,
131 |             label,
132 |             resize,
133 |             questions,
134 |             sampled_classes,
135 |         )
136 | 
--------------------------------------------------------------------------------
/vis_output/blackpink.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/blackpink.jpg
--------------------------------------------------------------------------------
/vis_output/camera_lens.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/camera_lens.jpg
--------------------------------------------------------------------------------
/vis_output/dog_with_horn.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/dog_with_horn.jpg -------------------------------------------------------------------------------- /vis_output/example1_mask_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/example1_mask_0.jpg -------------------------------------------------------------------------------- /vis_output/example1_masked_img_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/example1_masked_img_0.jpg -------------------------------------------------------------------------------- /vis_output/example2_mask_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/example2_mask_0.jpg -------------------------------------------------------------------------------- /vis_output/example2_masked_img_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/example2_masked_img_0.jpg -------------------------------------------------------------------------------- /vis_output/jackma.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/jackma.jpg -------------------------------------------------------------------------------- /vis_output/obama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/obama.jpg -------------------------------------------------------------------------------- /vis_output/stand_higher.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/stand_higher.jpg -------------------------------------------------------------------------------- /vis_output/trump.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/trump.jpg -------------------------------------------------------------------------------- /vis_output/wash_hands.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LISA/3cb2d4301f1af4691bd4f3938335ef06e76f155a/vis_output/wash_hands.jpg --------------------------------------------------------------------------------