├── README.md ├── cocoagent_overview.png ├── data ├── fetch_dataset_for_t5_blipv2.py ├── fetch_dataset_for_t5_blipv2_google_pre10.py ├── fetch_dataset_for_t5_blipv2_single.py └── metagui.txt ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── action_matching.py │ ├── action_type.py │ ├── eval_aitw.py │ ├── eval_aitw_cot.py │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── generate_webpage_data_from_table.py │ ├── model_aitw.py │ ├── model_aitw_1102.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ └── utils_data_for_owl.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── builder.cpython-310.pyc │ │ └── llava_arch.cpython-310.pyc │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── __pycache__ │ │ │ ├── llava_llama.cpython-310.pyc │ │ │ └── llava_mpt.cpython-310.pyc │ │ ├── llava_llama.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── __pycache__ │ │ │ ├── adapt_tokenizer.cpython-310.pyc │ │ │ ├── attention.cpython-310.pyc │ │ │ ├── blocks.cpython-310.pyc │ │ │ ├── configuration_mpt.cpython-310.pyc │ │ │ ├── custom_embedding.cpython-310.pyc │ │ │ ├── flash_attn_triton.cpython-310.pyc │ │ │ ├── hf_prefixlm_converter.cpython-310.pyc │ │ │ ├── meta_init_context.cpython-310.pyc │ │ │ ├── modeling_mpt.cpython-310.pyc │ │ │ ├── norm.cpython-310.pyc │ │ │ └── param_init_fns.cpython-310.pyc │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── __pycache__ │ │ │ ├── builder.cpython-310.pyc │ │ │ └── clip_encoder.cpython-310.pyc │ │ ├── builder.py │ │ └── clip_encoder.py │ └── utils.py ├── train │ ├── __pycache__ │ │ ├── llama_flash_attn_monkey_patch.cpython-310.pyc │ │ ├── llava_trainer.cpython-310.pyc │ │ └── train.cpython-310.pyc │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ └── train_mem.py └── utils.py ├── models └── README.md ├── run_eg ├── eval.sh ├── prepare_data.sh └── train.sh └── scripts ├── action_matching.py ├── action_type.py ├── convert_sqa_to_llava.py ├── convert_sqa_to_llava_base_prompt.py ├── covert_aitw_to_llava.py ├── covert_aitw_to_llavacot_hist_fullset.py ├── covert_aitw_to_mmicl.py ├── merge_lora_weights.py ├── utils_data_for_owl_1029.py └── utils_data_for_owl_cot_hist.py /README.md: -------------------------------------------------------------------------------- 1 | # CoCo-Agent 2 | 3 | **CoCo-Agent: A Comprehensive Cognitive LLM Agent for Smartphone GUI Automation** [[paper]](https://arxiv.org/abs/2402.11941) 4 | 5 | Accepted by ACL’24 Findings. 6 | Code for my CoCo-Agent. 7 | 8 | Large language models (LLMs) have shown remarkable potential as human-like autonomous language agents to interact with real-world environments, especially for graphical user interface (GUI) automation. However, those GUI agents require comprehensive cognition ability including exhaustive perception and reliable action response. 
We propose Comprehensive Cognitive LLM Agent, CoCo-Agent, with two novel approaches, comprehensive environment perception (CEP) and conditional action prediction (CAP), to systematically improve the GUI automation performance. First, CEP facilitates the GUI perception through different aspects and granularity, including screenshots and complementary detailed layouts for the visual channel and historical actions for the textual channel. Second, CAP decomposes the action prediction into sub-problems: action type prediction and action target conditioned on the action type. With our technical design, our agent achieves new state-of-the-art performance on AITW and META-GUI benchmarks, showing promising abilities in realistic scenarios. 9 | 10 | ![](cocoagent_overview.png) 11 | 12 | This repo will be improved continually. 13 | 14 | ``` 15 | @article{ma2024comprehensive, 16 | title={Comprehensive Cognitive LLM Agent for Smartphone GUI Automation}, 17 | author={Ma, Xinbei and Zhang, Zhuosheng and Zhao, Hai}, 18 | journal={arXiv preprint arXiv:2402.11941}, 19 | year={2024} 20 | } 21 | ``` 22 | -------------------------------------------------------------------------------- /cocoagent_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/cocoagent_overview.png -------------------------------------------------------------------------------- /data/fetch_dataset_for_t5_blipv2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./google-research') 3 | from android_in_the_wild import visualization_utils, action_type, action_matching 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | from tqdm import tqdm 8 | import json 9 | import jax.numpy as jnp 10 | import argparse 11 | import pickle 12 | import torch 13 | import timm 14 | from timm.data import resolve_data_config 15 | from timm.data.transforms_factory import create_transform 16 | import os 17 | import tensorflow as tf 18 | from PIL import Image 19 | from transformers import AutoProcessor, Blip2Model 20 | os.environ['CUDA_VISIBLE_DEVICES'] = '5' 21 | 22 | # Get the list of available physical devices 23 | # physical_devices = tf.config.list_physical_devices('GPU') 24 | # # Disable GPU support by setting the visible devices to only include the CPU 25 | # tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU') 26 | 27 | # os.environ["CUDA_VISIBLE_DEVICES"]='0' 28 | # dataset_name = 'general' #@param ["general", "google_apps", "install", "single", "web_shopping"] 29 | # data_split = "general_texts_splits.json" 30 | 31 | # device = "cuda" if torch.cuda.is_available() else "cpu" 32 | # model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) 33 | # model.to(device) 34 | # processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") 35 | 36 | dataset_directories = { 37 | 'general': 'gs://gresearch/android-in-the-wild/general/*', 38 | 'google_apps': 'gs://gresearch/android-in-the-wild/google_apps/*', 39 | 'install': 'gs://gresearch/android-in-the-wild/install/*', 40 | 'single': 'gs://gresearch/android-in-the-wild/single/*', 41 | 'web_shopping': 'gs://gresearch/android-in-the-wild/web_shopping/*', 42 | } 43 | 44 | def _decode_image( 45 | example, 46 | image_height, 47 | image_width, 48 | image_channels, 49 | ): 50 | """Decodes image from example and reshapes. 
51 | 52 | Args: 53 | example: Example which contains encoded image. 54 | image_height: The height of the raw image. 55 | image_width: The width of the raw image. 56 | image_channels: The number of channels in the raw image. 57 | 58 | Returns: 59 | Decoded and reshaped image tensor. 60 | """ 61 | image = tf.io.decode_raw( 62 | example.features.feature['image/encoded'].bytes_list.value[0], 63 | out_type=tf.uint8, 64 | ) 65 | 66 | height = tf.cast(image_height, tf.int32) 67 | width = tf.cast(image_width, tf.int32) 68 | n_channels = tf.cast(image_channels, tf.int32) 69 | 70 | return tf.reshape(image, (height, width, n_channels)) 71 | 72 | def parse_episode( 73 | episode, 74 | ep_id, 75 | get_images = False, 76 | get_annotations = False, 77 | get_actions = False, 78 | output_dir = '.' 79 | ): 80 | parsed_episode = [] 81 | for i, ex in enumerate(episode): 82 | goal = ex.features.feature['goal_info'].bytes_list.value[0].decode('utf-8') 83 | step_id = ex.features.feature['step_id'].int64_list.value[0] 84 | # episode_id = ex.features.feature['episode_id'].bytes_list.value[0].decode('utf-8') 85 | output_ep = { 86 | "goal": goal, 87 | "step_id": step_id 88 | } 89 | 90 | image_height = ex.features.feature['image/height'].int64_list.value[0] 91 | image_width = ex.features.feature['image/width'].int64_list.value[0] 92 | image_channels = ex.features.feature['image/channels'].int64_list.value[0] 93 | if get_images: 94 | # image = _decode_image(ex, image_height, image_width, image_channels) 95 | # image = image.numpy() 96 | # image = Image.fromarray(image).convert('RGB') 97 | 98 | # with torch.no_grad(): 99 | # inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) 100 | # image_features = model.get_image_features(**inputs).pooler_output[0] 101 | # image_features = image_features.detach().cpu() 102 | # output_ep["image"] = image_features 103 | image = _decode_image(ex, image_height, image_width, image_channels) 104 | image = image.numpy() 105 | image = Image.fromarray(image).convert('RGB') 106 | # print(output_dir, ep_id, step_id) 107 | image_path = os.path.join(output_dir, str(ep_id) +'/'+str(step_id)+'.png') 108 | if not os.path.exists(os.path.join(output_dir, str(ep_id))): 109 | os.mkdir(os.path.join(output_dir, str(ep_id))) 110 | image.save(image_path) 111 | output_ep["image"] = image_path 112 | 113 | if get_annotations: 114 | flattened_positions = np.array( 115 | ex.features.feature['image/ui_annotations_positions'].float_list.value 116 | ) 117 | ui_text = ex.features.feature['image/ui_annotations_text'].bytes_list.value 118 | ui_text = [value.decode('utf-8') for value in ui_text] 119 | ui_type = ex.features.feature['image/ui_annotations_ui_types'].bytes_list.value 120 | ui_type = [value.decode('utf-8') for value in ui_type] 121 | 122 | positions = np.reshape(flattened_positions, (-1, 4)) #(y, x, height, width) 123 | output_ep["ui_positions"] = positions 124 | output_ep["ui_text"] = ui_text 125 | output_ep["ui_type"] = ui_type 126 | 127 | if get_actions: 128 | touch_y, touch_x = ex.features.feature['results/yx_touch'].float_list.value 129 | lift_y, lift_x = ex.features.feature['results/yx_lift'].float_list.value 130 | ex_action_type = ex.features.feature['results/action_type'].int64_list.value[0] 131 | 132 | ex_action_type = action_type.ActionType(ex_action_type).name 133 | 134 | type_text = (ex.features.feature['results/type_action'].bytes_list.value[0].decode('utf-8')) 135 | 136 | output_ep["result_touch_yx"] = [touch_y, touch_x] 137 | output_ep["result_lift_yx"] = [lift_y, 
lift_x] 138 | output_ep["result_action"] = [ex_action_type, type_text] 139 | 140 | parsed_episode.append(output_ep) 141 | return parsed_episode 142 | 143 | _SWIPE_DISTANCE_THRESHOLD = 0.04 144 | def is_tap_action(normalized_start_yx, normalized_end_yx): 145 | distance = jnp.linalg.norm( 146 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 147 | return distance <= _SWIPE_DISTANCE_THRESHOLD 148 | 149 | def _check_tap_actions_match( 150 | tap_yx, 151 | annotation_positions, 152 | annotation_width_augment_fraction, 153 | annotation_height_augment_fraction, 154 | ): 155 | """Determines if two tap actions are the same.""" 156 | resized_annotation_positions = action_matching._resize_annotation_bounding_boxes( 157 | annotation_positions, 158 | annotation_width_augment_fraction, 159 | annotation_height_augment_fraction, 160 | ) 161 | # Check if the ground truth tap action falls in an annotation's bounding box. 162 | tap_in_box = action_matching._yx_in_bounding_boxes(tap_yx, resized_annotation_positions) 163 | return tap_in_box 164 | 165 | def _check_drag_actions_match( 166 | drag_touch_yx, 167 | drag_lift_yx, 168 | ): 169 | """Determines if two drag actions are the same.""" 170 | # Store drag deltas (the change in the y and x coordinates from touch to 171 | # lift), magnitudes, and the index of the main axis, which is the axis with 172 | # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and 173 | # ending at (0.3, 0.5) has a main axis index of 1). 174 | drag_1_deltas = drag_lift_yx - drag_touch_yx 175 | drag_1_magnitudes = jnp.abs(drag_1_deltas) 176 | drag_1_main_axis = np.argmax(drag_1_magnitudes) 177 | 178 | # y axis 179 | if drag_1_main_axis == 0: 180 | if drag_1_deltas[0] < 0: 181 | scroll = "up" 182 | else: 183 | scroll = "down" 184 | elif drag_1_main_axis == 1: 185 | if drag_1_deltas[1] < 0: 186 | scroll = "left" 187 | else: 188 | scroll = "right" 189 | 190 | return scroll 191 | 192 | def fetch_episode(dataset_name, data_split, get_images, get_annotations, get_actions, output_dir): 193 | filenames = tf.io.gfile.glob(dataset_directories[dataset_name]) 194 | dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP').as_numpy_iterator() 195 | 196 | with open (data_split, "r") as rp: 197 | split_data = json.load(rp) 198 | train_data = split_data["train"] 199 | val_data = split_data["validation"] 200 | test_data = split_data["test"] 201 | print(f"train_data size: {len(train_data)}, val_data size: {len(val_data)}, test_data size: {len(test_data)}") 202 | print(train_data[0], val_data[0], test_data[0]) 203 | 204 | all_parsed_episode = { 205 | "train": [], 206 | "val": [], 207 | "test": [], 208 | } 209 | total_screens = { 210 | "train": 0, 211 | "val": 0, 212 | "test": 0, 213 | } 214 | 215 | episode = [] 216 | episode_id = None 217 | 218 | for d in tqdm(dataset): 219 | ex = tf.train.Example() 220 | ex.ParseFromString(d) 221 | ep_id = ex.features.feature['episode_id'].bytes_list.value[0].decode('utf-8') 222 | if '.' 
in ep_id: 223 | ep_id = ep_id.split(".")[0] 224 | # if (ep_id not in train_data) & (ep_id not in test_data): 225 | # continue 226 | if episode_id is None: 227 | episode_id = ep_id 228 | episode.append(ex) 229 | elif ep_id == episode_id: 230 | episode.append(ex) 231 | else: 232 | # save data 233 | try: 234 | # here is a bug: ep_id is the new episode, this should be episode_id 235 | output = parse_episode(episode, ep_id, get_images=get_images, get_annotations=get_annotations, get_actions=get_actions, output_dir = output_dir) 236 | except Exception as exc: 237 | print(exc) 238 | # bad data point; init a new episode 239 | episode_id = ep_id 240 | episode = [ex] 241 | 242 | if int(episode_id) in train_data: 243 | curr_split = "train" 244 | elif int(episode_id) in val_data: 245 | curr_split = "val" 246 | elif int(episode_id) in test_data: 247 | curr_split = "test" 248 | else: 249 | print("error episode") 250 | # print(all_parsed_episode) 251 | all_parsed_episode[curr_split].append({"episode_id":episode_id, "data":output}) 252 | total_screens[curr_split] += len(episode) 253 | # init a new episode 254 | episode_id = ep_id 255 | episode = [ex] 256 | # last episode 257 | if len(episode) > 0: 258 | # save data 259 | output = parse_episode(episode, ep_id, get_images=get_images, get_annotations=get_annotations, get_actions=get_actions, output_dir = output_dir) 260 | if episode_id in train_data: 261 | curr_split = "train" 262 | elif episode_id in val_data: 263 | curr_split = "val" 264 | elif episode_id in test_data: 265 | curr_split = "test" 266 | else: 267 | assert "error episode" 268 | 269 | all_parsed_episode[curr_split].append({"episode_id":episode_id, "data":output}) 270 | total_screens[curr_split] += len(episode) 271 | 272 | print(len(all_parsed_episode["train"]), total_screens["train"], len(all_parsed_episode["val"]), total_screens["val"], len(all_parsed_episode["test"]), total_screens["test"]) 273 | return all_parsed_episode 274 | 275 | def parse_args(): 276 | parser = argparse.ArgumentParser() 277 | parser.add_argument('--dataset', type=str, default='general') 278 | parser.add_argument("--split_file", type=str, default="dataset/general_texts_splits.json") 279 | parser.add_argument('--output_dir', type=str, default='dataset/t5/general_parsed_episode_t5_clip') 280 | # parser.add_argument('--get_images', action='store_true') 281 | # parser.add_argument('--get_annotations', action='store_true') 282 | # parser.add_argument('--get_actions', action='store_true') 283 | 284 | parser.add_argument('--get_images', default=True, action='store_true') 285 | parser.add_argument('--get_annotations', default=True, action='store_true') 286 | parser.add_argument('--get_actions', default=True, action='store_true') 287 | 288 | args = parser.parse_args() 289 | return args 290 | 291 | if __name__ == '__main__': 292 | 293 | args = parse_args() 294 | print('====Input Arguments====') 295 | print(json.dumps(vars(args), indent=2, sort_keys=False)) 296 | 297 | all_parsed_episode = fetch_episode(args.dataset, args.split_file, args.get_images, args.get_annotations, args.get_actions, args.output_dir) 298 | 299 | with open(f"{args.output_dir}_train.obj", "wb") as wp: 300 | pickle.dump(all_parsed_episode["train"],wp) 301 | with open(f"{args.output_dir}_val.obj", "wb") as wp: 302 | pickle.dump(all_parsed_episode["val"],wp) 303 | with open(f"{args.output_dir}_test.obj", "wb") as wp: 304 | pickle.dump(all_parsed_episode["test"],wp) 305 | 306 | # python fetch_dataset_for_t5_blipv2.py --split_file "dataset/splits/standard.json" 
--output_dir "dataset/owl/general_parsed_episode_owl" 307 | # python fetch_dataset_for_t5_blipv2.py --split_file "dataset/splits/standard.json" --output_dir "dataset/owl/install_parsed_episode_owl" --dataset install 308 | # python fetch_dataset_for_t5_blipv2.py --split_file "dataset/splits/standard.json" --output_dir "dataset/owl/google_apps_parsed_episode_owl" --dataset google_apps 309 | # CUDA_VISIBLE_DEVICES=5 python fetch_dataset_for_t5_blipv2.py --split_file "dataset/splits/standard.json" --output_dir "dataset/owl/single_parsed_episode_owl" --dataset single ---------failed, the episode id is not inttt 310 | # python fetch_dataset_for_t5_blipv2.py --split_file "dataset/splits/standard.json" --output_dir "dataset/owl/web_shopping_parsed_episode_owl" --dataset web_shopping 311 | 312 | -------------------------------------------------------------------------------- /data/fetch_dataset_for_t5_blipv2_google_pre10.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./google-research') 3 | from android_in_the_wild import visualization_utils, action_type, action_matching 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | from tqdm import tqdm 8 | import json 9 | import jax.numpy as jnp 10 | import argparse 11 | import pickle 12 | import torch 13 | import timm 14 | from timm.data import resolve_data_config 15 | from timm.data.transforms_factory import create_transform 16 | import os 17 | import tensorflow as tf 18 | from PIL import Image 19 | from transformers import AutoProcessor, Blip2Model 20 | import random 21 | os.environ['CUDA_VISIBLE_DEVICES'] = '5' 22 | 23 | # Get the list of available physical devices 24 | # physical_devices = tf.config.list_physical_devices('GPU') 25 | # # Disable GPU support by setting the visible devices to only include the CPU 26 | # tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU') 27 | 28 | # os.environ["CUDA_VISIBLE_DEVICES"]='0' 29 | # dataset_name = 'general' #@param ["general", "google_apps", "install", "single", "web_shopping"] 30 | # data_split = "general_texts_splits.json" 31 | 32 | # device = "cuda" if torch.cuda.is_available() else "cpu" 33 | # model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) 34 | # model.to(device) 35 | # processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") 36 | 37 | dataset_directories = { 38 | 'general': 'gs://gresearch/android-in-the-wild/general/*', 39 | 'google_apps': 'gs://gresearch/android-in-the-wild/google_apps/*', 40 | 'install': 'gs://gresearch/android-in-the-wild/install/*', 41 | 'single': 'gs://gresearch/android-in-the-wild/single/*', 42 | 'web_shopping': 'gs://gresearch/android-in-the-wild/web_shopping/*', 43 | } 44 | 45 | def _decode_image( 46 | example, 47 | image_height, 48 | image_width, 49 | image_channels, 50 | ): 51 | """Decodes image from example and reshapes. 52 | 53 | Args: 54 | example: Example which contains encoded image. 55 | image_height: The height of the raw image. 56 | image_width: The width of the raw image. 57 | image_channels: The number of channels in the raw image. 58 | 59 | Returns: 60 | Decoded and reshaped image tensor. 
61 | """ 62 | image = tf.io.decode_raw( 63 | example.features.feature['image/encoded'].bytes_list.value[0], 64 | out_type=tf.uint8, 65 | ) 66 | 67 | height = tf.cast(image_height, tf.int32) 68 | width = tf.cast(image_width, tf.int32) 69 | n_channels = tf.cast(image_channels, tf.int32) 70 | 71 | return tf.reshape(image, (height, width, n_channels)) 72 | 73 | def parse_episode( 74 | episode, 75 | ep_id, 76 | get_images = False, 77 | get_annotations = False, 78 | get_actions = False, 79 | output_dir = '.' 80 | ): 81 | parsed_episode = [] 82 | for i, ex in enumerate(episode): 83 | goal = ex.features.feature['goal_info'].bytes_list.value[0].decode('utf-8') 84 | step_id = ex.features.feature['step_id'].int64_list.value[0] 85 | # episode_id = ex.features.feature['episode_id'].bytes_list.value[0].decode('utf-8') 86 | output_ep = { 87 | "goal": goal, 88 | "step_id": step_id 89 | } 90 | 91 | image_height = ex.features.feature['image/height'].int64_list.value[0] 92 | image_width = ex.features.feature['image/width'].int64_list.value[0] 93 | image_channels = ex.features.feature['image/channels'].int64_list.value[0] 94 | if get_images: 95 | # image = _decode_image(ex, image_height, image_width, image_channels) 96 | # image = image.numpy() 97 | # image = Image.fromarray(image).convert('RGB') 98 | 99 | # with torch.no_grad(): 100 | # inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) 101 | # image_features = model.get_image_features(**inputs).pooler_output[0] 102 | # image_features = image_features.detach().cpu() 103 | # output_ep["image"] = image_features 104 | image = _decode_image(ex, image_height, image_width, image_channels) 105 | image = image.numpy() 106 | image = Image.fromarray(image).convert('RGB') 107 | # print(output_dir, ep_id, step_id) 108 | image_path = os.path.join(output_dir, str(ep_id) +'/'+str(step_id)+'.png') 109 | if not os.path.exists(os.path.join(output_dir, str(ep_id))): 110 | os.mkdir(os.path.join(output_dir, str(ep_id))) 111 | image.save(image_path) 112 | output_ep["image"] = image_path 113 | 114 | if get_annotations: 115 | flattened_positions = np.array( 116 | ex.features.feature['image/ui_annotations_positions'].float_list.value 117 | ) 118 | ui_text = ex.features.feature['image/ui_annotations_text'].bytes_list.value 119 | ui_text = [value.decode('utf-8') for value in ui_text] 120 | ui_type = ex.features.feature['image/ui_annotations_ui_types'].bytes_list.value 121 | ui_type = [value.decode('utf-8') for value in ui_type] 122 | 123 | positions = np.reshape(flattened_positions, (-1, 4)) #(y, x, height, width) 124 | output_ep["ui_positions"] = positions 125 | output_ep["ui_text"] = ui_text 126 | output_ep["ui_type"] = ui_type 127 | 128 | if get_actions: 129 | touch_y, touch_x = ex.features.feature['results/yx_touch'].float_list.value 130 | lift_y, lift_x = ex.features.feature['results/yx_lift'].float_list.value 131 | ex_action_type = ex.features.feature['results/action_type'].int64_list.value[0] 132 | 133 | ex_action_type = action_type.ActionType(ex_action_type).name 134 | 135 | type_text = (ex.features.feature['results/type_action'].bytes_list.value[0].decode('utf-8')) 136 | 137 | output_ep["result_touch_yx"] = [touch_y, touch_x] 138 | output_ep["result_lift_yx"] = [lift_y, lift_x] 139 | output_ep["result_action"] = [ex_action_type, type_text] 140 | 141 | parsed_episode.append(output_ep) 142 | return parsed_episode 143 | 144 | _SWIPE_DISTANCE_THRESHOLD = 0.04 145 | def is_tap_action(normalized_start_yx, normalized_end_yx): 146 | distance = 
jnp.linalg.norm( 147 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 148 | return distance <= _SWIPE_DISTANCE_THRESHOLD 149 | 150 | def _check_tap_actions_match( 151 | tap_yx, 152 | annotation_positions, 153 | annotation_width_augment_fraction, 154 | annotation_height_augment_fraction, 155 | ): 156 | """Determines if two tap actions are the same.""" 157 | resized_annotation_positions = action_matching._resize_annotation_bounding_boxes( 158 | annotation_positions, 159 | annotation_width_augment_fraction, 160 | annotation_height_augment_fraction, 161 | ) 162 | # Check if the ground truth tap action falls in an annotation's bounding box. 163 | tap_in_box = action_matching._yx_in_bounding_boxes(tap_yx, resized_annotation_positions) 164 | return tap_in_box 165 | 166 | def _check_drag_actions_match( 167 | drag_touch_yx, 168 | drag_lift_yx, 169 | ): 170 | """Determines if two drag actions are the same.""" 171 | # Store drag deltas (the change in the y and x coordinates from touch to 172 | # lift), magnitudes, and the index of the main axis, which is the axis with 173 | # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and 174 | # ending at (0.3, 0.5) has a main axis index of 1). 175 | drag_1_deltas = drag_lift_yx - drag_touch_yx 176 | drag_1_magnitudes = jnp.abs(drag_1_deltas) 177 | drag_1_main_axis = np.argmax(drag_1_magnitudes) 178 | 179 | # y axis 180 | if drag_1_main_axis == 0: 181 | if drag_1_deltas[0] < 0: 182 | scroll = "up" 183 | else: 184 | scroll = "down" 185 | elif drag_1_main_axis == 1: 186 | if drag_1_deltas[1] < 0: 187 | scroll = "left" 188 | else: 189 | scroll = "right" 190 | 191 | return scroll 192 | 193 | def fetch_episode(dataset_name, data_split, get_images, get_annotations, get_actions, output_dir): 194 | filenames = tf.io.gfile.glob(dataset_directories[dataset_name]) 195 | dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP').as_numpy_iterator() 196 | 197 | with open (data_split, "r") as rp: 198 | split_data = json.load(rp) 199 | train_data = split_data["train"] 200 | val_data = split_data["validation"] 201 | test_data = split_data["test"] 202 | print(f"train_data size: {len(train_data)}, val_data size: {len(val_data)}, test_data size: {len(test_data)}") 203 | print(train_data[0], val_data[0], test_data[0]) 204 | 205 | all_parsed_episode = { 206 | "train": [], 207 | "val": [], 208 | "test": [], 209 | } 210 | total_screens = { 211 | "train": 0, 212 | "val": 0, 213 | "test": 0, 214 | } 215 | 216 | episode = [] 217 | episode_id = None 218 | numi = 0 219 | for d in tqdm(dataset, total=490360): 220 | numi+=1 221 | if numi > 490360: 222 | break 223 | ex = tf.train.Example() 224 | ex.ParseFromString(d) 225 | ep_id = ex.features.feature['episode_id'].bytes_list.value[0].decode('utf-8') 226 | if '.' 
in ep_id: 227 | ep_id = ep_id.split(".")[0] 228 | # if (ep_id not in train_data) & (ep_id not in test_data): 229 | # continue 230 | if episode_id is None: 231 | episode_id = ep_id 232 | episode.append(ex) 233 | elif ep_id == episode_id: 234 | episode.append(ex) 235 | else: 236 | # save data 237 | # seed = random.uniform(0,1) 238 | # if seed >= 0.1: 239 | # episode_id = ep_id 240 | # episode = [ex] 241 | # else: 242 | try: 243 | # here is a bug: ep_id is the new episode, this should be episode_id 244 | output = parse_episode(episode, ep_id, get_images=get_images, get_annotations=get_annotations, get_actions=get_actions, output_dir = output_dir) 245 | except Exception as exc: 246 | print(exc) 247 | # bad data point; init a new episode 248 | episode_id = ep_id 249 | episode = [ex] 250 | 251 | if int(episode_id) in train_data: 252 | curr_split = "train" 253 | elif int(episode_id) in val_data: 254 | curr_split = "val" 255 | elif int(episode_id) in test_data: 256 | curr_split = "test" 257 | else: 258 | print("error episode") 259 | # print(all_parsed_episode) 260 | all_parsed_episode[curr_split].append({"episode_id":episode_id, "data":output}) 261 | total_screens[curr_split] += len(episode) 262 | # init a new episode 263 | episode_id = ep_id 264 | episode = [ex] 265 | # last episode 266 | if len(episode) > 0: 267 | # save data 268 | output = parse_episode(episode, ep_id, get_images=get_images, get_annotations=get_annotations, get_actions=get_actions, output_dir = output_dir) 269 | if episode_id in train_data: 270 | curr_split = "train" 271 | elif episode_id in val_data: 272 | curr_split = "val" 273 | elif episode_id in test_data: 274 | curr_split = "test" 275 | else: 276 | assert "error episode" 277 | 278 | all_parsed_episode[curr_split].append({"episode_id":episode_id, "data":output}) 279 | total_screens[curr_split] += len(episode) 280 | 281 | print(len(all_parsed_episode["train"]), total_screens["train"], len(all_parsed_episode["val"]), total_screens["val"], len(all_parsed_episode["test"]), total_screens["test"]) 282 | return all_parsed_episode 283 | 284 | def parse_args(): 285 | parser = argparse.ArgumentParser() 286 | parser.add_argument('--dataset', type=str, default='general') 287 | parser.add_argument("--split_file", type=str, default="dataset/general_texts_splits.json") 288 | parser.add_argument('--output_dir', type=str, default='dataset/t5/general_parsed_episode_t5_clip') 289 | # parser.add_argument('--get_images', action='store_true') 290 | # parser.add_argument('--get_annotations', action='store_true') 291 | # parser.add_argument('--get_actions', action='store_true') 292 | 293 | parser.add_argument('--get_images', default=True, action='store_true') 294 | parser.add_argument('--get_annotations', default=True, action='store_true') 295 | parser.add_argument('--get_actions', default=True, action='store_true') 296 | 297 | args = parser.parse_args() 298 | return args 299 | 300 | if __name__ == '__main__': 301 | 302 | args = parse_args() 303 | print('====Input Arguments====') 304 | print(json.dumps(vars(args), indent=2, sort_keys=False)) 305 | 306 | all_parsed_episode = fetch_episode(args.dataset, args.split_file, args.get_images, args.get_annotations, args.get_actions, args.output_dir) 307 | 308 | with open(f"{args.output_dir}_pre10_train.obj", "wb") as wp: 309 | pickle.dump(all_parsed_episode["train"],wp) 310 | with open(f"{args.output_dir}_pre10_val.obj", "wb") as wp: 311 | pickle.dump(all_parsed_episode["val"],wp) 312 | with open(f"{args.output_dir}_pre10_test.obj", "wb") as wp: 313 | 
pickle.dump(all_parsed_episode["test"],wp) 314 | 315 | # python fetch_dataset_for_t5_blipv2_google_pre10.py --split_file "dataset/splits/standard.json" --output_dir "dataset/owl/google_apps_parsed_episode_owl_pre10" --dataset google_apps 316 | 317 | 318 | -------------------------------------------------------------------------------- /data/metagui.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1EaqgfOkJ9xrPlcRP2mTr-FK1eXU6oCC7?usp=drive_link 2 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | -------------------------------------------------------------------------------- /llava/eval/action_matching.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | import action_type as action_type_lib 6 | 7 | 8 | _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen 9 | ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4 10 | ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4 11 | 12 | # Interval determining if an action is a tap or a swipe. 13 | _SWIPE_DISTANCE_THRESHOLD = 0.04 14 | 15 | 16 | def _yx_in_bounding_boxes( 17 | yx, bounding_boxes 18 | ): 19 | """Check if the (y,x) point is contained in each bounding box. 20 | 21 | Args: 22 | yx: The (y, x) coordinate in pixels of the point. 23 | bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row 24 | represents a bounding box: (y_top_left, x_top_left, box_height, 25 | box_width). Note: containment is inclusive of the bounding box edges. 26 | 27 | Returns: 28 | is_inside: A 1D bool array where each element specifies if the point is 29 | contained within the respective box. 30 | """ 31 | y, x = yx 32 | 33 | # `bounding_boxes` has shape (n_elements, 4); we extract each array along the 34 | # last axis into shape (n_elements, 1), then squeeze unneeded dimension. 35 | top, left, height, width = [ 36 | jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1) 37 | ] 38 | 39 | # The y-axis is inverted for AndroidEnv, so bottom = top + height. 40 | bottom, right = top + height, left + width 41 | 42 | return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and( 43 | x >= left, x <= right) 44 | 45 | 46 | def _resize_annotation_bounding_boxes( 47 | annotation_positions, annotation_width_augment_fraction, 48 | annotation_height_augment_fraction): 49 | """Resize the bounding boxes by the given fractions. 50 | 51 | Args: 52 | annotation_positions: Array of shape (N, 4), where each row represents the 53 | (y, x, height, width) of the bounding boxes. 54 | annotation_width_augment_fraction: The fraction to augment the box widths, 55 | E.g., 1.4 == 240% total increase. 56 | annotation_height_augment_fraction: Same as described for width, but for box 57 | height. 
58 | 59 | Returns: 60 | Resized bounding box. 61 | 62 | """ 63 | height_change = ( 64 | annotation_height_augment_fraction * annotation_positions[:, 2]) 65 | width_change = ( 66 | annotation_width_augment_fraction * annotation_positions[:, 3]) 67 | 68 | # Limit bounding box positions to the screen. 69 | resized_annotations = jnp.stack([ 70 | jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)), 71 | jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)), 72 | jnp.minimum(1, annotation_positions[:, 2] + height_change), 73 | jnp.minimum(1, annotation_positions[:, 3] + width_change), 74 | ], 75 | axis=1) 76 | return resized_annotations 77 | 78 | 79 | def is_tap_action(normalized_start_yx, 80 | normalized_end_yx): 81 | distance = jnp.linalg.norm( 82 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 83 | return distance <= _SWIPE_DISTANCE_THRESHOLD 84 | 85 | 86 | def _is_non_dual_point_action(action_type): 87 | return jnp.not_equal(action_type, action_type_lib.ActionType.DUAL_POINT) 88 | 89 | 90 | def _check_tap_actions_match( 91 | tap_1_yx, 92 | tap_2_yx, 93 | annotation_positions, 94 | matching_tap_distance_threshold_screen_percentage, 95 | annotation_width_augment_fraction, 96 | annotation_height_augment_fraction, 97 | ): 98 | """Determines if two tap actions are the same.""" 99 | resized_annotation_positions = _resize_annotation_bounding_boxes( 100 | annotation_positions, 101 | annotation_width_augment_fraction, 102 | annotation_height_augment_fraction, 103 | ) 104 | 105 | # Check if the ground truth tap action falls in an annotation's bounding box. 106 | tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions) 107 | tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions) 108 | both_in_box = jnp.max(tap1_in_box & tap2_in_box) 109 | 110 | # If the ground-truth tap action falls outside any of the annotation 111 | # bounding boxes or one of the actions is inside a bounding box and the other 112 | # is outside bounding box or vice versa, compare the points using Euclidean 113 | # distance. 114 | within_threshold = ( 115 | jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx)) 116 | <= matching_tap_distance_threshold_screen_percentage 117 | ) 118 | return jnp.logical_or(both_in_box, within_threshold) 119 | 120 | 121 | def _check_drag_actions_match( 122 | drag_1_touch_yx, 123 | drag_1_lift_yx, 124 | drag_2_touch_yx, 125 | drag_2_lift_yx, 126 | ): 127 | """Determines if two drag actions are the same.""" 128 | # Store drag deltas (the change in the y and x coordinates from touch to 129 | # lift), magnitudes, and the index of the main axis, which is the axis with 130 | # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and 131 | # ending at (0.3, 0.5) has a main axis index of 1). 
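# Only the main-axis indices are compared below; the sign (up vs. down,
# left vs. right) and the magnitude of the two drags are not. For example, a
# drag from (0.1, 0.5) to (0.7, 0.5) and a drag from (0.2, 0.4) to (0.9, 0.45)
# both have main axis 0 and therefore count as matching.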
132 | drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx 133 | drag_1_magnitudes = jnp.abs(drag_1_deltas) 134 | drag_1_main_axis = np.argmax(drag_1_magnitudes) 135 | drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx 136 | drag_2_magnitudes = jnp.abs(drag_2_deltas) 137 | drag_2_main_axis = np.argmax(drag_2_magnitudes) 138 | 139 | return jnp.equal(drag_1_main_axis, drag_2_main_axis) 140 | 141 | 142 | def check_actions_match( 143 | action_1_touch_yx, 144 | action_1_lift_yx, 145 | action_1_action_type, 146 | action_2_touch_yx, 147 | action_2_lift_yx, 148 | action_2_action_type, 149 | annotation_positions, 150 | tap_distance_threshold = _TAP_DISTANCE_THRESHOLD, 151 | annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION, 152 | annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION, 153 | ): 154 | """Determines if two actions are considered to be the same. 155 | 156 | Two actions being "the same" is defined here as two actions that would result 157 | in a similar screen state. 158 | 159 | Args: 160 | action_1_touch_yx: The (y, x) coordinates of the first action's touch. 161 | action_1_lift_yx: The (y, x) coordinates of the first action's lift. 162 | action_1_action_type: The action type of the first action. 163 | action_2_touch_yx: The (y, x) coordinates of the second action's touch. 164 | action_2_lift_yx: The (y, x) coordinates of the second action's lift. 165 | action_2_action_type: The action type of the second action. 166 | annotation_positions: The positions of the UI annotations for the screen. It 167 | is A 2D int array of shape (num_bboxes, 4), where each row represents a 168 | bounding box: (y_top_left, x_top_left, box_height, box_width). Note that 169 | containment is inclusive of the bounding box edges. 170 | tap_distance_threshold: The threshold that determines if two taps result in 171 | a matching screen state if they don't fall the same bounding boxes. 172 | annotation_width_augment_fraction: The fraction to increase the width of the 173 | bounding box by. 174 | annotation_height_augment_fraction: The fraction to increase the height of 175 | of the bounding box by. 176 | 177 | Returns: 178 | A boolean representing whether the two given actions are the same or not. 179 | """ 180 | action_1_touch_yx = jnp.asarray(action_1_touch_yx) 181 | action_1_lift_yx = jnp.asarray(action_1_lift_yx) 182 | action_2_touch_yx = jnp.asarray(action_2_touch_yx) 183 | action_2_lift_yx = jnp.asarray(action_2_lift_yx) 184 | 185 | # Checks if at least one of the actions is global (i.e. not DUAL_POINT), 186 | # because if that is the case, only the actions' types need to be compared. 
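# Overall decision: if either action is global, the two actions match iff
# their action types are equal; otherwise both must be the same gesture kind
# (tap vs. drag) and the corresponding tap- or drag-matching check computed
# below must pass.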
187 | has_non_dual_point_action = jnp.logical_or( 188 | _is_non_dual_point_action(action_1_action_type), 189 | _is_non_dual_point_action(action_2_action_type), 190 | ) 191 | 192 | different_dual_point_types = jnp.logical_xor( 193 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 194 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 195 | ) 196 | 197 | is_tap = jnp.logical_and( 198 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 199 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 200 | ) 201 | 202 | taps_match = _check_tap_actions_match( 203 | action_1_touch_yx, 204 | action_2_touch_yx, 205 | annotation_positions, 206 | tap_distance_threshold, 207 | annotation_width_augment_fraction, 208 | annotation_height_augment_fraction, 209 | ) 210 | 211 | taps_match = jnp.logical_and(is_tap, taps_match) 212 | 213 | drags_match = _check_drag_actions_match( 214 | action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx 215 | ) 216 | drags_match = jnp.where(is_tap, False, drags_match) 217 | 218 | return jnp.where( 219 | has_non_dual_point_action, 220 | jnp.equal(action_1_action_type, action_2_action_type), 221 | jnp.where( 222 | different_dual_point_types, 223 | False, 224 | jnp.logical_or(taps_match, drags_match), 225 | ), 226 | ) -------------------------------------------------------------------------------- /llava/eval/action_type.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | class ActionType(enum.IntEnum): 4 | 5 | # Placeholders for unused enum values 6 | UNUSED_0 = 0 7 | UNUSED_1 = 1 8 | UNUSED_2 = 2 9 | UNUSED_8 = 8 10 | UNUSED_9 = 9 11 | 12 | ########### Agent actions ########### 13 | 14 | # A type action that sends text to the emulator. Note that this simply sends 15 | # text and does not perform any clicks for element focus or enter presses for 16 | # submitting text. 17 | TYPE = 3 18 | 19 | # The dual point action used to represent all gestures. 20 | DUAL_POINT = 4 21 | 22 | # These actions differentiate pressing the home and back button from touches. 23 | # They represent explicit presses of back and home performed using ADB. 24 | PRESS_BACK = 5 25 | PRESS_HOME = 6 26 | 27 | # An action representing that ADB command for hitting enter was performed. 28 | PRESS_ENTER = 7 29 | 30 | ########### Episode status actions ########### 31 | 32 | # An action used to indicate the desired task has been completed and resets 33 | # the environment. This action should also be used in the case that the task 34 | # has already been completed and there is nothing to do. 35 | # e.g. The task is to turn on the Wi-Fi when it is already on 36 | STATUS_TASK_COMPLETE = 10 37 | 38 | # An action used to indicate that desired task is impossible to complete and 39 | # resets the environment. This can be a result of many different things 40 | # including UI changes, Android version differences, etc. 
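# e.g. The task asks to open an app that is not installed on the device.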
41 | STATUS_TASK_IMPOSSIBLE = 11 -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | 
review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 
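# The matched rule supplies the grading prompt and the role label used to
# format the two candidate answers in the GPT-4 review request built below.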
96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, 
ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | predictions = [json.loads(line) for line in open(args.result_file)] 45 | predictions = {pred['question_id']: pred for pred in predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | results = {'correct': [], 'incorrect': []} 49 | sqa_results = {} 50 | sqa_results['acc'] = None 51 | sqa_results['correct'] = None 52 | sqa_results['count'] = None 53 | sqa_results['results'] = {} 54 | sqa_results['outputs'] = {} 55 | 56 | for prob_id, prob in split_problems.items(): 57 | if prob_id not in predictions: 58 | continue 59 | pred = predictions[prob_id] 60 | pred_text = pred['text'] 61 | 62 | pattern = re.compile(r'The answer is ([A-Z]).') 63 | res = pattern.findall(pred_text) 64 | if len(res) == 1: 65 | answer = res[0] # 'A', 'B', ... 66 | else: 67 | answer = "FAILED" 68 | 69 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 70 | 71 | analysis = { 72 | 'question_id': prob_id, 73 | 'parsed_ans': answer, 74 | 'ground_truth': args.options[prob['answer']], 75 | 'question': pred['prompt'], 76 | 'pred': pred_text, 77 | 'is_multimodal': '' in pred['prompt'], 78 | } 79 | 80 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 81 | sqa_results['outputs'][prob_id] = pred_text 82 | 83 | if pred_idx == prob['answer']: 84 | results['correct'].append(analysis) 85 | else: 86 | results['incorrect'].append(analysis) 87 | 88 | correct = len(results['correct']) 89 | total = len(results['correct']) + len(results['incorrect']) 90 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 91 | 92 | sqa_results['acc'] = correct / total * 100 93 | sqa_results['correct'] = correct 94 | sqa_results['count'] = total 95 | 96 | with open(args.output_file, 'w') as f: 97 | json.dump(results, f, indent=2) 98 | with open(args.output_result, 'w') as f: 99 | json.dump(sqa_results, f, indent=2) 100 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 20 | 
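# [Illustration added for clarity; not part of the original script.] The three ScienceQA
# eval scripts in this directory share the same answer-parsing flow: the regex
# r'The answer is ([A-Z]).' extracts the option letter from a model prediction, and
# get_pred_idx() (defined below) maps that letter to an index into the problem's choices,
# falling back to a random valid index when parsing fails. A minimal sketch of that flow,
# assuming a prediction string such as "The answer is C." and a three-choice problem:
#
#     pattern = re.compile(r'The answer is ([A-Z]).')
#     res = pattern.findall("The answer is C.")          # -> ['C']
#     answer = res[0] if len(res) == 1 else "FAILED"
#     idx = get_pred_idx(answer, choices=['x', 'y', 'z'], options=["A", "B", "C", "D", "E"])
#     # idx == 2, i.e. the third choice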
21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == 
'__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_aitw_1102.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, 
DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | from utils_data_for_owl import load_for_owl 14 | 15 | from PIL import Image 16 | import math 17 | import action_type 18 | import action_matching 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"] = '7' 21 | 22 | def split_list(lst, n): 23 | """Split a list into n (roughly) equal-sized chunks""" 24 | chunk_size = math.ceil(len(lst) / n) # integer division 25 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 26 | 27 | 28 | def get_chunk(lst, n, k): 29 | chunks = split_list(lst, n) 30 | return chunks[k] 31 | 32 | 33 | def eval_model(args): 34 | # Model 35 | disable_torch_init() 36 | model_path = os.path.expanduser(args.model_path) 37 | print(model_path) 38 | model_name = get_model_name_from_path(model_path) 39 | print(model_name) 40 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 41 | 42 | data = load_for_owl('.', 'test') 43 | # questions = data 44 | questions = json.load(open(os.path.expanduser(args.question_file), "r"))[:1000] 45 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 46 | answers_file = os.path.expanduser(args.answers_file) 47 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 48 | ans_file = open(answers_file, "w") 49 | preds = [] 50 | targets = [] 51 | metrics = {} 52 | partial_correct = 0 53 | text_correct = 0 54 | type_correct = 0 55 | reference_test_positions = [] 56 | for i, line in enumerate(tqdm(questions)): 57 | targets.append(data[i]['target_text']) 58 | reference_test_positions.append(data[i]['anno_pos']) 59 | print('assert: ', targets[-1], line['conversations'][-1]['value']) 60 | idx = line["id"] 61 | # question = line['conversations'][0] 62 | # qs = question['value'].replace('', '').strip() 63 | # cur_prompt = qs 64 | # print("original dataset: ", line['conversations']) 65 | if 'image' in line: 66 | image_file = line["image"] 67 | if type(image_file) == list: 68 | image = [ Image.open(os.path.join(args.image_folder, f)) for f in image_file ] 69 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'] 70 | else: 71 | image = Image.open(os.path.join(args.image_folder, image_file)) 72 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 73 | images = image_tensor.unsqueeze(0).half().cuda() 74 | # if getattr(model.config, 'mm_use_im_start_end', False): 75 | # qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 76 | # else: 77 | # qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 78 | # qs = question['value'].strip() 79 | # cur_prompt = '' + '\n' + cur_prompt 80 | else: 81 | images = None 82 | conv = conv_templates[args.conv_mode].copy() 83 | for j, question in enumerate(line['conversations'][:-1]): 84 | conv.append_message(conv.roles[j%2], question['value']) 85 | # conv.append_message(conv.roles[j%2], None) 86 | conv.append_message(conv.roles[len(line['conversations'])%2], None) 87 | prompt = conv.get_prompt() 88 | # print(prompt) 89 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 90 | 91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 92 | keywords = [stop_str] 93 | 
stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None 94 | 95 | with torch.inference_mode(): 96 | output_ids = model.generate( 97 | input_ids, 98 | images=images, 99 | do_sample=True, 100 | temperature=0.2, 101 | max_new_tokens=1024, 102 | use_cache=True, 103 | stopping_criteria=stopping_criteria, 104 | ) 105 | 106 | input_token_len = input_ids.shape[1] 107 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 108 | if n_diff_input_output > 0: 109 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 110 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 111 | outputs = outputs.strip() 112 | if outputs.endswith(stop_str): 113 | outputs = outputs[:-len(stop_str)] 114 | outputs = outputs.strip() 115 | print(outputs) 116 | 117 | # prompt for answer 118 | if args.answer_prompter: 119 | outputs_reasoning = outputs 120 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 121 | 122 | with torch.inference_mode(): 123 | output_ids = model.generate( 124 | input_ids, 125 | images=images, 126 | do_sample=True, 127 | temperature=0.2, 128 | max_new_tokens=64, 129 | use_cache=True, 130 | stopping_criteria=[stopping_criteria]) 131 | 132 | input_token_len = input_ids.shape[1] 133 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 134 | if n_diff_input_output > 0: 135 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 136 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 137 | outputs = outputs.strip() 138 | if outputs.endswith(stop_str): 139 | outputs = outputs[:-len(stop_str)] 140 | outputs = outputs.strip() 141 | outputs = outputs_reasoning + '\n The answer is ' + outputs 142 | 143 | preds.append(outputs) 144 | ans_id = shortuuid.uuid() 145 | ans_file.write(json.dumps({"question_id": idx, 146 | "prompt": prompt, 147 | "text": outputs, 148 | "answer_id": ans_id, 149 | "model_id": model_name, 150 | "metadata": {}}) + "\n") 151 | ans_file.flush() 152 | ans_file.close() 153 | 154 | print('file closed') 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | # parser.add_argument("--model-path", type=str, default="liuhaotian/llava-lcs558k-scienceqa-vicuna-13b-v1.3") 159 | parser.add_argument("--model-path", type=str, default="/data/maxb/tag/LLaVA/checkpoints/aitwfull_4histwimg_dialog-llama-2-7b-chat-finetune/llava-checkpoint-5000") 160 | parser.add_argument("--model-base", type=str, default=None) 161 | parser.add_argument("--image-folder", type=str, default="") 162 | parser.add_argument("--question-file", type=str, default="/data/maxb/tag/LLaVA/scripts/aitw_data/histwimg/llava_aitwfull_debug_dialog_4histwimg_test_QCM-LEA.json") 163 | parser.add_argument("--answers-file", type=str, default="./debug_dialog_answer_test.jsonl") 164 | parser.add_argument("--conv-mode", type=str, default="llava_llama_2") 165 | parser.add_argument("--num-chunks", type=int, default=1) 166 | parser.add_argument("--chunk-idx", type=int, default=0) 167 | parser.add_argument("--answer-prompter", action="store_true") 168 | # parser.add_argument("--prd_output_path", type=str, default='.') 169 | # parser.add_argument("--eval_name", type=str, default=None) 170 | # parser.add_argument("--eval_data", type=str, 
default='/data/maxb/mmcot2/dataset/owl/general_parsed_episode_owl_test.obj') 171 | # parser.add_argument('--save_path', type=str, default=None) 172 | args = parser.parse_args() 173 | 174 | eval_model(args) 175 | 176 | # CUDA_VISIBLE_DEVICES=3 python model_aitw.py --question-file ../../scripts/llava_aitwv2_train_QCM-LEA.json 177 | # CUDA_VISIBLE_DEVICES=4,5,6,7 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12-llama-2-7b-chat-finetune --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/llava_aitwfullhist12_test_QCM-LEA.json --answers-file ./res_out/llava_try-llama-2-7b-chat-finetune.jsonl 178 | # CUDA_VISIBLE_DEVICES=4,5,6,7 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12-llama-2-7b-chat-finetune/checkpoint-5000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/llava_aitwfullhist12_test_QCM-LEA.json --answers-file ./res_out/5k_llava_try-llama-2-7b-chat-finetune.jsonl 179 | # CUDA_VISIBLE_DEVICES=4,5,6,7 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12-llama-2-7b-chat-finetune/checkpoint-10000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/llava_aitwfullhist12_test_QCM-LEA.json --answers-file ./res_out/10k_llava_try-llama-2-7b-chat-finetune.jsonl 180 | # CUDA_VISIBLE_DEVICES=4,5,6,7 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist20-llama-2-7b-chat-finetune/checkpoint-10000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/llava_aitwfullhist20_test_QCM-LEA.json --answers-file ./res_out/20hist_10k_llava_try-llama-2-7b-chat-finetune.jsonl 181 | # CUDA_VISIBLE_DEVICES=2,3,1,0 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12_centertype-llama-2-7b-chat-finetune --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/memory/llava_aitwfull_hist12_centertype_test_QCM-LEA.json --answers-file ./res_out/centertype_12k_llava_try-llama-2-7b-chat-finetune.jsonl 182 | # CUDA_VISIBLE_DEVICES=2,3,1,0 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12_centertype-llama-2-7b-chat-finetune/checkpoint-5000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/memory/llava_aitwfull_hist12_centertype_test_QCM-LEA.json --answers-file ./res_out/centertype_5k_llava_try-llama-2-7b-chat-finetune.jsonl 183 | # CUDA_VISIBLE_DEVICES=4,5,6,7 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12_centertype-llama-2-7b-chat-finetune/checkpoint-10000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/memory/llava_aitwfull_hist12_centerfull_test_QCM-LEA.json --answers-file ./res_out/centerfull_10k_llava_try-llama-2-7b-chat-finetune.jsonl 184 | 185 | # CUDA_VISIBLE_DEVICES=4,5 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12_ret-llama-2-7b-chat-finetune/checkpoint-10000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/memory/llava_aitwfull_hist12_retrieve_test_QCM-LEA.json --answers-file ./res_out/retrieve_10k_v2_llava_try-llama-2-7b-chat-finetune.jsonl 186 | # CUDA_VISIBLE_DEVICES=4,5 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12_ret-llama-2-7b-chat-finetune/checkpoint-5000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/memory/llava_aitwfull_hist12_retrieve_test_QCM-LEA.json --answers-file ./res_out/retrieve_5k_llava_try-llama-2-7b-chat-finetune.jsonl 187 | # CUDA_VISIBLE_DEVICES=4,5 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12_highlevel-llama-2-7b-chat-finetune/checkpoint-10000 
--question-file /data/maxb/tag/LLaVA/scripts/aitw_data/memory/llava_aitwfull_hist12_highlevel_test_QCM-LEA.json --answers-file ./res_out/highlevel_10k_llava_try-llama-2-7b-chat-finetune.jsonl 188 | 189 | # CUDA_VISIBLE_DEVICES=3 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/debug/llava-checkpoint-10000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/histwimg/llava_aitwfull_4histwimg_test_QCM-LEA.json --answers-file ./res_out/4histwimg_llava_try-llama-2-7b-chat-finetune.jsonl 190 | 191 | # CUDA_VISIBLE_DEVICES=3 python model_aitw.py --model-path /data/maxb/tag/LLaVA/checkpoints/llava_hist12_ret-llama-2-7b-chat-finetune/checkpoint-10000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/memory/llava_aitwfull_hist12_retrieve_test_QCM-LEA.json --answers-file ./res_out/retrieve_10k_v2_llava_try-llama-2-7b-chat-finetune.jsonl 192 | 193 | # CUDA_VISIBLE_DEVICES=7 python model_aitw_1102.py --model-path /data/maxb/tag/LLaVA/checkpoints/aitwfull_4histwimg_dialog-llama-2-7b-chat-finetune/llava-checkpoint-5000 --question-file /data/maxb/tag/LLaVA/scripts/aitw_data/histwimg/llava_aitwfull_dialog_4histwimg_test_QCM-LEA.json --answers-file ./res_out/dialog_4histwimg_llava_try-llama-2-7b-chat-finetune.jsonl -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 35 | disable_torch_init() 36 | model_name = os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | 
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions[:10]): 42 | idx = line["question_id"] 43 | image_file = 'COCO_val2014_'+line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = Image.open(os.path.join(args.image_folder, image_file)) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 60 | 61 | stop_str = conv.sep if 
conv.sep_style != SeparatorStyle.TWO else conv.sep2 62 | keywords = [stop_str] 63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 64 | 65 | with torch.inference_mode(): 66 | output_ids = model.generate( 67 | input_ids, 68 | images=image_tensor.unsqueeze(0).half().cuda(), 69 | do_sample=True, 70 | temperature=args.temperature, 71 | top_p=args.top_p, 72 | num_beams=args.num_beams, 73 | # no_repeat_ngram_size=3, 74 | max_new_tokens=1024, 75 | use_cache=True) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | 87 | ans_id = shortuuid.uuid() 88 | ans_file.write(json.dumps({"question_id": idx, 89 | "prompt": cur_prompt, 90 | "text": outputs, 91 | "answer_id": ans_id, 92 | "model_id": model_name, 93 | "metadata": {}}) + "\n") 94 | ans_file.flush() 95 | ans_file.close() 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 100 | parser.add_argument("--model-base", type=str, default=None) 101 | parser.add_argument("--image-folder", type=str, default="") 102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 103 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 104 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 105 | parser.add_argument("--num-chunks", type=int, default=1) 106 | parser.add_argument("--chunk-idx", type=int, default=0) 107 | parser.add_argument("--temperature", type=float, default=0.2) 108 | parser.add_argument("--top_p", type=float, default=None) 109 | parser.add_argument("--num_beams", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | eval_model(args) 113 | -------------------------------------------------------------------------------- /llava/eval/model_vqa_science.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, 
model_name) 35 | 36 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for i, line in enumerate(tqdm(questions)): 42 | idx = line["id"] 43 | question = line['conversations'][0] 44 | qs = question['value'].replace('<image>', '').strip() 45 | cur_prompt = qs 46 | 47 | if 'image' in line: 48 | image_file = line["image"] 49 | image = Image.open(os.path.join(args.image_folder, image_file)) 50 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 51 | images = image_tensor.unsqueeze(0).half().cuda() 52 | if getattr(model.config, 'mm_use_im_start_end', False): 53 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 54 | else: 55 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 56 | cur_prompt = '<image>' + '\n' + cur_prompt 57 | else: 58 | images = None 59 | 60 | conv = conv_templates[args.conv_mode].copy() 61 | conv.append_message(conv.roles[0], qs) 62 | conv.append_message(conv.roles[1], None) 63 | prompt = conv.get_prompt() 64 | 65 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 66 | 67 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 68 | keywords = [stop_str] 69 | stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None 70 | 71 | with torch.inference_mode(): 72 | output_ids = model.generate( 73 | input_ids, 74 | images=images, 75 | do_sample=True, 76 | temperature=0.2, 77 | max_new_tokens=1024, 78 | use_cache=True, 79 | stopping_criteria=stopping_criteria, 80 | ) 81 | 82 | input_token_len = input_ids.shape[1] 83 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 84 | if n_diff_input_output > 0: 85 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 86 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 87 | outputs = outputs.strip() 88 | if outputs.endswith(stop_str): 89 | outputs = outputs[:-len(stop_str)] 90 | outputs = outputs.strip() 91 | 92 | # prompt for answer 93 | if args.answer_prompter: 94 | outputs_reasoning = outputs 95 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 96 | 97 | with torch.inference_mode(): 98 | output_ids = model.generate( 99 | input_ids, 100 | images=images, 101 | do_sample=True, 102 | temperature=0.2, 103 | max_new_tokens=64, 104 | use_cache=True, 105 | stopping_criteria=stopping_criteria) # already a list (or None); do not wrap it in another list 106 | 107 | input_token_len = input_ids.shape[1] 108 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 109 | if n_diff_input_output > 0: 110 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 111 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 112 | outputs = outputs.strip() 113 | if outputs.endswith(stop_str): 114 | outputs = outputs[:-len(stop_str)] 115 | outputs = outputs.strip() 116 | outputs = outputs_reasoning + '\n The answer is ' + outputs 117 | 118 | ans_id = shortuuid.uuid() 119 | ans_file.write(json.dumps({"question_id": idx, 120 | "prompt": cur_prompt, 121 | "text": outputs, 122 
| "answer_id": ans_id, 123 | "model_id": model_name, 124 | "metadata": {}}) + "\n") 125 | ans_file.flush() 126 | ans_file.close() 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 131 | parser.add_argument("--model-base", type=str, default=None) 132 | parser.add_argument("--image-folder", type=str, default="") 133 | parser.add_argument("--question-file", type=str, default="tables/question.json") 134 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 135 | parser.add_argument("--conv-mode", type=str, default="llava_v0") 136 | parser.add_argument("--num-chunks", type=int, default=1) 137 | parser.add_argument("--chunk-idx", type=int, default=0) 138 | parser.add_argument("--answer-prompter", action="store_true") 139 | args = parser.parse_args() 140 | 141 | eval_model(args) 142 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 
29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | 16 | 17 | def load_image(image_file): 18 | if image_file.startswith('http') or image_file.startswith('https'): 19 | response = requests.get(image_file) 20 | image = Image.open(BytesIO(response.content)).convert('RGB') 21 | else: 22 | image = Image.open(image_file).convert('RGB') 23 | return image 24 | 25 | 26 | def eval_model(args): 27 | # Model 28 | disable_torch_init() 29 | 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) 32 | 33 | qs = args.query 34 | if model.config.mm_use_im_start_end: 35 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 36 | else: 37 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 38 | 39 | if 'llama-2' in model_name.lower(): 40 | conv_mode = "llava_llama_2" 41 | elif "v1" in model_name.lower(): 42 | conv_mode = "llava_v1" 43 | elif "mpt" in model_name.lower(): 44 | conv_mode = "mpt" 45 | else: 46 | conv_mode = "llava_v0" 47 | 48 | if args.conv_mode is not None and conv_mode != args.conv_mode: 49 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 50 | else: 51 | args.conv_mode = conv_mode 52 | 53 | conv = 
conv_templates[args.conv_mode].copy() 54 | conv.append_message(conv.roles[0], qs) 55 | conv.append_message(conv.roles[1], None) 56 | prompt = conv.get_prompt() 57 | 58 | image = load_image(args.image_file) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 60 | 61 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 62 | 63 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 64 | keywords = [stop_str] 65 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 66 | 67 | with torch.inference_mode(): 68 | output_ids = model.generate( 69 | input_ids, 70 | images=image_tensor, 71 | do_sample=True, 72 | temperature=0.2, 73 | max_new_tokens=1024, 74 | use_cache=True, 75 | stopping_criteria=[stopping_criteria]) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | print(outputs) 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 91 | parser.add_argument("--model-base", type=str, default=None) 92 | parser.add_argument("--image-file", type=str, required=True) 93 | parser.add_argument("--query", type=str, required=True) 94 | parser.add_argument("--conv-mode", type=str, default=None) 95 | args = parser.parse_args() 96 | 97 | eval_model(args) 98 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-f', '--files', nargs='*', default=None) 13 | parser.add_argument('-i', '--ignore', nargs='*', default=None) 14 | return parser.parse_args() 15 | 16 | 17 | if __name__ == '__main__': 18 | args = parse_args() 19 | 20 | if args.ignore is not None: 21 | args.ignore = [int(x) for x in args.ignore] 22 | 23 | if args.files is not None and len(args.files) > 0: 24 | review_files = args.files 25 | else: 26 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))] 27 | 28 | for review_file in sorted(review_files): 29 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 30 | scores = defaultdict(list) 31 | print(config) 32 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 33 | for review_str in f: 34 | review = json.loads(review_str) 35 | if args.ignore is not None and review['question_id'] in args.ignore: 36 | continue 37 | if 'category' in review: 38 | scores[review['category']].append(review['tuple']) 39 | scores['all'].append(review['tuple']) 40 | else: 41 | if 'tuple' in review: 
42 | scores['all'].append(review['tuple']) 43 | else: 44 | scores['all'].append(review['score']) 45 | for k, v in sorted(scores.items()): 46 | stats = np.asarray(v).mean(0).tolist() 47 | stats = [round(x, 3) for x in stats] 48 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 49 | print(k, round(stats[1]/stats[0]*100, 1)) 50 | print('=================================') 51 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def process_images(images, image_processor, model_cfg): 15 | return image_processor(images, return_tensors='pt')['pixel_values'] 16 | 17 | 18 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 19 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')] # split on the literal '<image>' placeholder 20 | 21 | def insert_separator(X, sep): 22 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 23 | 24 | input_ids = [] 25 | offset = 0 26 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 27 | offset = 1 28 | input_ids.append(prompt_chunks[0][0]) 29 | 30 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 31 | input_ids.extend(x[offset:]) 32 | 33 | if return_tensors is not None: 34 | if return_tensors == 'pt': 35 | return torch.tensor(input_ids, dtype=torch.long) 36 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 37 | return input_ids 38 | 39 | 40 | def get_model_name_from_path(model_path): 41 | model_path = model_path.strip("/") 42 | model_paths = model_path.split("/") 43 | if model_paths[-1].startswith('checkpoint-'): 44 | return model_paths[-2] + "_" + model_paths[-1] 45 | else: 46 | return model_paths[-1] 47 | 48 | 49 | 50 | 51 | class KeywordsStoppingCriteria(StoppingCriteria): 52 | def __init__(self, keywords, tokenizer, input_ids): 53 | self.keywords = keywords 54 | self.keyword_ids = [] 55 | for keyword in keywords: 56 | cur_keyword_ids = tokenizer(keyword).input_ids 57 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 58 | cur_keyword_ids = cur_keyword_ids[1:] 59 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 60 | self.tokenizer = tokenizer 61 | self.start_len = input_ids.shape[1] 62 | 63 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 64 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 65 | offset = min(output_ids.shape[1] - self.start_len, 3) 66 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 67 | for keyword_id in self.keyword_ids: 68 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): # require every token of the keyword to match 69 | return True 70 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 71 | for keyword in self.keywords: 72 | if keyword in outputs: 73 | return True 74 | return False 75 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/__pycache__/builder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/__pycache__/builder.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/__pycache__/llava_arch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/__pycache__/llava_arch.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/builder.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"): 27 | kwargs = {"device_map": device_map} 28 | 29 | if load_8bit: 30 | kwargs['load_in_8bit'] = True 31 | elif load_4bit: 32 | kwargs['load_in_4bit'] = True 33 | kwargs['quantization_config'] = BitsAndBytesConfig( 34 | load_in_4bit=True, 35 | bnb_4bit_compute_dtype=torch.float16, 36 | bnb_4bit_use_double_quant=True, 37 | bnb_4bit_quant_type='nf4' 38 | ) 39 | else: 40 | kwargs['torch_dtype'] = torch.float16 41 | 42 | if 'llava' in model_name.lower(): 43 | # Load LLaVA model 44 | if 'lora' in model_name.lower() and model_base is None: 45 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 46 | if 'lora' in model_name.lower() and model_base is not None: 47 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 48 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 49 | print('Loading LLaVA from base model...') 50 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 51 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 52 | if model.lm_head.weight.shape[0] != token_num: 53 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 54 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 55 | 56 | print('Loading additional LLaVA weights...') 57 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 58 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 59 | else: 60 | # this is probably from HF Hub 61 | from huggingface_hub import hf_hub_download 62 | def load_from_hf(repo_id, filename, subfolder=None): 63 | cache_file = hf_hub_download( 64 | repo_id=repo_id, 65 | filename=filename, 66 | subfolder=subfolder) 67 | return torch.load(cache_file, map_location='cpu') 68 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 69 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 70 | if any(k.startswith('model.model.') for k in non_lora_trainables): 71 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 72 | model.load_state_dict(non_lora_trainables, strict=False) 73 | 74 | from peft import PeftModel 75 | print('Loading LoRA weights...') 76 | model = PeftModel.from_pretrained(model, model_path) 77 | print('Merging LoRA weights...') 78 | model = model.merge_and_unload() 79 | print('Model is loaded...') 80 | elif model_base is not None: 81 | # this may be mm projector only 82 | print('Loading LLaVA from base model...') 83 | if 'mpt' in model_name.lower(): 84 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 85 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 86 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 87 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 88 | model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 89 | else: 90 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 91 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 92 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 93 | 94 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 95 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 96 | model.load_state_dict(mm_projector_weights, strict=False) 97 | else: 98 | if 'mpt' in model_name.lower(): 99 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 100 | model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 101 | else: 102 | tokenizer = 
AutoTokenizer.from_pretrained(model_path, use_fast=False) 103 | model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 104 | else: 105 | # Load language model 106 | if model_base is not None: 107 | # PEFT model 108 | from peft import PeftModel 109 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 110 | model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") 111 | print(f"Loading LoRA weights from {model_path}") 112 | model = PeftModel.from_pretrained(model, model_path) 113 | print(f"Merging weights") 114 | model = model.merge_and_unload() 115 | print('Convert to FP16...') 116 | model.to(torch.float16) 117 | else: 118 | use_fast = False 119 | if 'mpt' in model_name.lower(): 120 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 121 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 122 | else: 123 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 124 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 125 | 126 | image_processor = None 127 | 128 | if 'llava' in model_name.lower(): 129 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 130 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 131 | if mm_use_im_patch_token: 132 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 133 | if mm_use_im_start_end: 134 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 135 | model.resize_token_embeddings(len(tokenizer)) 136 | 137 | vision_tower = model.get_vision_tower() 138 | if not vision_tower.is_loaded: 139 | vision_tower.load_model() 140 | vision_tower.to(device='cuda', dtype=torch.float16) 141 | image_processor = vision_tower.image_processor 142 | 143 | if hasattr(model.config, "max_sequence_length"): 144 | context_len = model.config.max_sequence_length 145 | else: 146 | context_len = 2048 147 | 148 | return tokenizer, model, image_processor, context_len 149 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc: 
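# A minimal usage sketch of `load_pretrained_model` from llava/model/builder.py
# above, assuming an already-merged (non-LoRA) LLaVA checkpoint. The checkpoint
# path and model name are hypothetical placeholders, not files shipped with this
# repo; `model_name` must contain "llava" so the multimodal branch is taken.
from llava.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="/path/to/llava-checkpoint",  # hypothetical local path
    model_base=None,                         # no base model needed once weights are merged
    model_name="llava-7b",
    load_8bit=False,
    load_4bit=False,
)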
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | images: Optional[torch.FloatTensor] = None, 67 | return_dict: Optional[bool] = None, 68 | ) -> Union[Tuple, CausalLMOutputWithPast]: 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | input_ids, 
attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 76 | 77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 78 | outputs = self.model( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask, 81 | past_key_values=past_key_values, 82 | inputs_embeds=inputs_embeds, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | hidden_states = outputs[0] 90 | logits = self.lm_head(hidden_states) 91 | 92 | loss = None 93 | if labels is not None: 94 | # Shift so that tokens < n predict n 95 | shift_logits = logits[..., :-1, :].contiguous() 96 | shift_labels = labels[..., 1:].contiguous() 97 | # Flatten the tokens 98 | loss_fct = CrossEntropyLoss() 99 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 100 | shift_labels = shift_labels.view(-1) 101 | # Enable model/pipeline parallelism 102 | shift_labels = shift_labels.to(shift_logits.device) 103 | loss = loss_fct(shift_logits, shift_labels) 104 | 105 | if not return_dict: 106 | output = (logits,) + outputs[1:] 107 | return (loss,) + output if loss is not None else output 108 | 109 | return CausalLMOutputWithPast( 110 | loss=loss, 111 | logits=logits, 112 | past_key_values=outputs.past_key_values, 113 | hidden_states=outputs.hidden_states, 114 | attentions=outputs.attentions, 115 | ) 116 | 117 | def prepare_inputs_for_generation( 118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 119 | ): 120 | if past_key_values: 121 | input_ids = input_ids[:, -1:] 122 | 123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 124 | if inputs_embeds is not None and past_key_values is None: 125 | model_inputs = {"inputs_embeds": inputs_embeds} 126 | else: 127 | model_inputs = {"input_ids": input_ids} 128 | 129 | model_inputs.update( 130 | { 131 | "past_key_values": past_key_values, 132 | "use_cache": kwargs.get("use_cache"), 133 | "attention_mask": attention_mask, 134 | "images": kwargs.get("images", None), 135 | } 136 | ) 137 | return model_inputs 138 | 139 | AutoConfig.register("llava", LlavaConfig) 140 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 141 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
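# A minimal sketch of what the AutoConfig/AutoModelForCausalLM register() calls
# at the end of llava_llama.py above enable: once `llava.model` is imported (its
# __init__ re-exports LlavaLlamaForCausalLM, as builder.py relies on), a checkpoint
# whose config.json declares `"model_type": "llava"` resolves to LlavaConfig and
# LlavaLlamaForCausalLM through the Hugging Face Auto classes. The checkpoint path
# below is a hypothetical placeholder.
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import llava.model  # importing the package executes the register() calls

cfg = AutoConfig.from_pretrained("/path/to/llava-checkpoint")  # cfg.model_type == "llava"
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llava-checkpoint", torch_dtype=torch.float16
)  # dispatched to LlavaLlamaForCausalLM via the registration above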
14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. 
This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/blocks.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/configuration_mpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/configuration_mpt.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/custom_embedding.cpython-310.pyc: 
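# A small self-contained sketch of the tied-embedding logit computation used in
# LlavaMPTForCausalLM.forward above: the token-embedding matrix doubles as the
# output projection, i.e. F.linear(h, wte.weight) == h @ wte.weight.T. The d_model
# of 16 is a toy value for illustration; 50368 mirrors the MPT default vocab size.
import torch
import torch.nn.functional as F

wte = torch.nn.Embedding(50368, 16)  # (vocab_size, d_model) embedding table
h = torch.randn(2, 5, 16)            # (batch, seq, d_model) hidden states
logits = F.linear(h, wte.weight)     # -> (2, 5, 50368), no separate lm_head weight
assert logits.shape == (2, 5, 50368)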
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/custom_embedding.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/flash_attn_triton.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/flash_attn_triton.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/hf_prefixlm_converter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/hf_prefixlm_converter.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/meta_init_context.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/meta_init_context.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/modeling_mpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/modeling_mpt.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/norm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/norm.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/__pycache__/param_init_fns.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/language_model/mpt/__pycache__/param_init_fns.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 
14 | """ 15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('<pad>', special_tokens=True) 19 | tokenizer.pad_token = '<pad>' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a =
self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} 5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0} 6 | 7 | class MPTConfig(PretrainedConfig): 8 | model_type = 'mpt' 9 | 10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): 11 | """The MPT configuration class. 12 | 13 | Args: 14 | d_model (int): The size of the embedding dimension of the model. 15 | n_heads (int): The number of attention heads. 16 | n_layers (int): The number of layers in the model. 17 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 18 | max_seq_len (int): The maximum sequence length of the model. 19 | vocab_size (int): The size of the vocabulary. 20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 21 | emb_pdrop (float): The dropout probability for the embedding layer. 22 | learned_pos_emb (bool): Whether to use learned positional embeddings 23 | attn_config (Dict): A dictionary used to configure the model's attention module: 24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention 25 | attn_pdrop (float): The dropout probability for the attention layers. 26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 29 | this value. 30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 31 | use the default scale of ``1/sqrt(d_keys)``. 32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 
35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 37 | which sub-sequence each token belongs to. 38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 39 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 40 | alibi_bias_max (int): The maximum value of the alibi bias. 41 | init_device (str): The device to use for parameter initialization. 42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 43 | no_bias (bool): Whether to use bias in all layers. 44 | verbose (int): The verbosity level. 0 is silent. 45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 46 | norm_type (str): choose type of norm to use 47 | multiquery_attention (bool): Whether to use multiquery attention implementation. 48 | use_cache (bool): Whether or not the model should return the last key/values attentions 49 | init_config (Dict): A dictionary used to configure the model initialization: 50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 57 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 58 | if using the baseline_ parameter initialization scheme. 59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
62 | --- 63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 64 | """ 65 | self.d_model = d_model 66 | self.n_heads = n_heads 67 | self.n_layers = n_layers 68 | self.expansion_ratio = expansion_ratio 69 | self.max_seq_len = max_seq_len 70 | self.vocab_size = vocab_size 71 | self.resid_pdrop = resid_pdrop 72 | self.emb_pdrop = emb_pdrop 73 | self.learned_pos_emb = learned_pos_emb 74 | self.attn_config = attn_config 75 | self.init_device = init_device 76 | self.logit_scale = logit_scale 77 | self.no_bias = no_bias 78 | self.verbose = verbose 79 | self.embedding_fraction = embedding_fraction 80 | self.norm_type = norm_type 81 | self.use_cache = use_cache 82 | self.init_config = init_config 83 | if 'name' in kwargs: 84 | del kwargs['name'] 85 | if 'loss_fn' in kwargs: 86 | del kwargs['loss_fn'] 87 | super().__init__(**kwargs) 88 | self._validate_config() 89 | 90 | def _set_config_defaults(self, config, config_defaults): 91 | for (k, v) in config_defaults.items(): 92 | if k not in config: 93 | config[k] = v 94 | return config 95 | 96 | def _validate_config(self): 97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) 98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) 99 | if self.d_model % self.n_heads != 0: 100 | raise ValueError('d_model must be divisible by n_heads') 101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): 102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") 103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: 104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') 107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 108 | raise NotImplementedError('alibi only implemented with torch and triton attention.') 109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') 111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') 113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': 114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 115 | if self.init_config.get('name', None) is None: 116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") 117 | if not self.learned_pos_emb and (not self.attn_config['alibi']): 118 | raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: 
bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, 
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/language_model/mpt/param_init_fns.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from collections.abc import Sequence 4 | from functools import partial 5 | from typing import Optional, Tuple, Union 6 | import torch 7 | from torch import nn 8 | from .norm import NORM_CLASS_REGISTRY 9 | 10 | def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs): 11 | del kwargs 12 | if verbose > 1: 13 | warnings.warn(f"Initializing network using module's reset_parameters attribute") 14 | if hasattr(module, 'reset_parameters'): 15 | module.reset_parameters() 16 | 17 | def fused_init_helper_(module: nn.Module, init_fn_): 18 | _fused = getattr(module, '_fused', None) 19 | if _fused is None: 20 | raise RuntimeError(f'Internal logic error') 21 | (dim, splits) = _fused 22 | splits = (0, *splits, module.weight.size(dim)) 23 | for (s, e) in zip(splits[:-1], splits[1:]): 24 | slice_indices = [slice(None)] * module.weight.ndim 25 | slice_indices[dim] = slice(s, e) 26 | init_fn_(module.weight[slice_indices]) 27 | 28 | def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 29 | del kwargs 30 | if verbose > 1: 31 | warnings.warn(f'If model has bias parameters they are initialized to 0.') 32 | init_div_is_residual = init_div_is_residual 33 | if init_div_is_residual is False: 34 | div_is_residual = 1.0 35 | elif init_div_is_residual is True: 36 | div_is_residual = math.sqrt(2 * n_layers) 37 | elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int): 38 | div_is_residual = init_div_is_residual 39 | elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): 40 | div_is_residual = float(init_div_is_residual) 41 | else: 42 | div_is_residual = 1.0 43 | raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') 44 | if init_div_is_residual is not False: 45 | if verbose > 1: 46 | warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. 
' + f'Set `init_div_is_residual: false` in init config to disable this.') 47 | if isinstance(module, nn.Linear): 48 | if hasattr(module, '_fused'): 49 | fused_init_helper_(module, init_fn_) 50 | else: 51 | init_fn_(module.weight) 52 | if module.bias is not None: 53 | torch.nn.init.zeros_(module.bias) 54 | if init_div_is_residual is not False and getattr(module, '_is_residual', False): 55 | with torch.no_grad(): 56 | module.weight.div_(div_is_residual) 57 | elif isinstance(module, nn.Embedding): 58 | if emb_init_std is not None: 59 | std = emb_init_std 60 | if std == 0: 61 | warnings.warn(f'Embedding layer initialized to 0.') 62 | emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) 63 | if verbose > 1: 64 | warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.') 65 | elif emb_init_uniform_lim is not None: 66 | lim = emb_init_uniform_lim 67 | if isinstance(lim, Sequence): 68 | if len(lim) > 2: 69 | raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.') 70 | if lim[0] == lim[1]: 71 | warnings.warn(f'Embedding layer initialized to {lim[0]}.') 72 | else: 73 | if lim == 0: 74 | warnings.warn(f'Embedding layer initialized to 0.') 75 | lim = [-lim, lim] 76 | (a, b) = lim 77 | emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) 78 | if verbose > 1: 79 | warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.') 80 | else: 81 | emb_init_fn_ = init_fn_ 82 | emb_init_fn_(module.weight) 83 | elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): 84 | if verbose > 1: 85 | warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.') 86 | if hasattr(module, 'weight') and module.weight is not None: 87 | torch.nn.init.ones_(module.weight) 88 | if hasattr(module, 'bias') and module.bias is not None: 89 | torch.nn.init.zeros_(module.bias) 90 | elif isinstance(module, nn.MultiheadAttention): 91 | if module._qkv_same_embed_dim: 92 | assert module.in_proj_weight is not None 93 | assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) 94 | assert d_model is not None 95 | _d = d_model 96 | splits = (0, _d, 2 * _d, 3 * _d) 97 | for (s, e) in zip(splits[:-1], splits[1:]): 98 | init_fn_(module.in_proj_weight[s:e]) 99 | else: 100 | assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) 101 | assert module.in_proj_weight is None 102 | init_fn_(module.q_proj_weight) 103 | init_fn_(module.k_proj_weight) 104 | init_fn_(module.v_proj_weight) 105 | if module.in_proj_bias is not None: 106 | torch.nn.init.zeros_(module.in_proj_bias) 107 | if module.bias_k is not None: 108 | torch.nn.init.zeros_(module.bias_k) 109 | if module.bias_v is not None: 110 | torch.nn.init.zeros_(module.bias_v) 111 | init_fn_(module.out_proj.weight) 112 | if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False): 113 | with torch.no_grad(): 114 | module.out_proj.weight.div_(div_is_residual) 115 | if module.out_proj.bias is not None: 116 | torch.nn.init.zeros_(module.out_proj.bias) 117 | else: 118 | for _ in module.parameters(recurse=False): 119 | raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.') 120 | 121 | def _normal_init_(std, mean=0.0): 122 | return partial(torch.nn.init.normal_, mean=mean, std=std) 123 | 124 | def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, 
d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 125 | del kwargs 126 | init_fn_ = _normal_init_(std=std) 127 | if verbose > 1: 128 | warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') 129 | generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 130 | 131 | def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 132 | del kwargs 133 | if init_std is None: 134 | raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.") 135 | _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 136 | 137 | def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 138 | del kwargs 139 | std = math.sqrt(2 / (5 * d_model)) 140 | _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 141 | 142 | def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 143 | """From section 2.3.1 of GPT-NeoX-20B: 144 | 145 | An Open-Source AutoregressiveLanguage Model — Black et. al. 
(2022) 146 | see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151 147 | and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py 148 | """ 149 | del kwargs 150 | residual_div = n_layers / math.sqrt(10) 151 | if verbose > 1: 152 | warnings.warn(f'setting init_div_is_residual to {residual_div}') 153 | small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 154 | 155 | def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 156 | del kwargs 157 | if verbose > 1: 158 | warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') 159 | kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 160 | generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 161 | 162 | def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 163 | del kwargs 164 | if verbose > 1: 165 | warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') 166 | kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 167 | generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 168 | 169 | def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): 170 | del kwargs 171 | xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) 172 | if verbose > 1: 173 | warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}') 174 | generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 175 | 176 | def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, 
**kwargs): 177 | xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) 178 | if verbose > 1: 179 | warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}') 180 | generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 181 | MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_} -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc: 
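# A minimal sketch of how MODEL_INIT_REGISTRY from param_init_fns.py above is
# typically consumed: a name from init_config selects an init function, which is
# then applied to every sub-module. The toy two-layer model and the
# "kaiming_normal_" choice are illustrative placeholders.
from functools import partial
import torch.nn as nn
from llava.model.language_model.mpt.param_init_fns import MODEL_INIT_REGISTRY

toy = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
init_fn = partial(MODEL_INIT_REGISTRY["kaiming_normal_"], n_layers=2, d_model=8)
toy.apply(init_fn)  # Linear weights get Kaiming-normal init, biases are zeroed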
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | print(vision_tower) 8 | is_absolute_path_exists = os.path.exists(vision_tower) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 10 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 11 | 12 | raise ValueError(f'Unknown vision tower: {vision_tower}') 13 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return 
image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return (self.config.image_size // self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc -------------------------------------------------------------------------------- /llava/train/__pycache__/llava_trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/train/__pycache__/llava_trainer.cpython-310.pyc -------------------------------------------------------------------------------- /llava/train/__pycache__/train.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xbmxb/CoCo-Agent/c53b67908b745c499a363c8b7e95fd6328b61238/llava/train/__pycache__/train.cpython-310.pyc -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | import logging 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import transformers 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | from einops import rearrange 11 | 12 | try: 13 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 14 | except ImportError: 15 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 16 | from flash_attn.bert_padding 
import unpad_input, pad_input 17 | 18 | 19 | def forward( 20 | self, 21 | hidden_states: torch.Tensor, 22 | attention_mask: Optional[torch.Tensor] = None, 23 | position_ids: Optional[torch.Tensor] = None, 24 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 25 | output_attentions: bool = False, 26 | use_cache: bool = False, 27 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 28 | """Input shape: Batch x Time x Channel 29 | 30 | attention_mask: [bsz, q_len] 31 | """ 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | key_states = ( 40 | self.k_proj(hidden_states) 41 | .view(bsz, q_len, self.num_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) 44 | value_states = ( 45 | self.v_proj(hidden_states) 46 | .view(bsz, q_len, self.num_heads, self.head_dim) 47 | .transpose(1, 2) 48 | ) 49 | # [bsz, q_len, nh, hd] 50 | # [bsz, nh, q_len, hd] 51 | 52 | kv_seq_len = key_states.shape[-2] 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 56 | query_states, key_states = apply_rotary_pos_emb( 57 | query_states, key_states, cos, sin, position_ids 58 | ) 59 | # [bsz, nh, t, hd] 60 | assert not output_attentions, "output_attentions is not supported" 61 | assert not use_cache, "use_cache is not supported" 62 | 63 | # Flash attention codes from 64 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 65 | 66 | # transform the data into the format required by flash attention 67 | qkv = torch.stack( 68 | [query_states, key_states, value_states], dim=2 69 | ) # [bsz, nh, 3, q_len, hd] 70 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 71 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 72 | # the attention_mask should be the same as the key_padding_mask 73 | key_padding_mask = attention_mask 74 | 75 | if key_padding_mask is None: 76 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 77 | max_s = q_len 78 | cu_q_lens = torch.arange( 79 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 80 | ) 81 | output = flash_attn_unpadded_qkvpacked_func( 82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 83 | ) 84 | output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) 85 | else: 86 | nheads = qkv.shape[-2] 87 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 88 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 89 | x_unpad = rearrange( 90 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 91 | ) 92 | output_unpad = flash_attn_unpadded_qkvpacked_func( 93 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 94 | ) 95 | output = rearrange( 96 | pad_input( 97 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 98 | ), 99 | "b s (h d) -> b s h d", 100 | h=nheads, 101 | ) 102 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 103 | 104 | 105 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 106 | # requires the attention mask to be the same as the key_padding_mask 107 | def _prepare_decoder_attention_mask( 108 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 109 | ): 110 | # [bsz, seq_len] 111 | return attention_mask 112 | 113 | 114 | def replace_llama_attn_with_flash_attn(): 115 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 116 | if cuda_major < 8: 117 | logging.warning( 118 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 119 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 120 | ) 121 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 122 | _prepare_decoder_attention_mask 123 | ) 124 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 125 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers import Trainer 5 | from typing import Optional 6 | 7 | 8 | def maybe_zero_3(param, ignore_status=False, name=None): 9 | from deepspeed import zero 10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 11 | if hasattr(param, "ds_id"): 12 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 13 | if not ignore_status: 14 | print(name, 'no ignore status') 15 | with zero.GatheredParameters([param]): 16 | param = param.data.detach().cpu().clone() 17 | else: 18 | param = param.detach().cpu().clone() 19 | return param 20 | 21 | 22 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 23 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 24 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 25 | return to_return 26 | 27 | 28 | class LLaVATrainer(Trainer): 29 | 30 | def _save_checkpoint(self, model, trial, metrics=None): 31 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 32 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 33 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 34 | 35 | run_dir = self._get_output_dir(trial=trial) 36 | output_dir = os.path.join(run_dir, checkpoint_folder) 37 | 38 | # Only save Adapter 39 | keys_to_match = ['mm_projector'] 40 | if getattr(self.args, "use_im_start_end", False): 41 | keys_to_match.extend(['embed_tokens', 'embed_in']) 42 | 43 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 44 | 45 | if self.args.local_rank == 0 or self.args.local_rank == -1: 46 | 
self.model.config.save_pretrained(output_dir) 47 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 48 | else: 49 | if self.args.train_vision: 50 | # save the trained vision encoder 51 | output_dir = self.args.output_dir 52 | keys_to_match = ['vision_tower'] 53 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 54 | self.model.config.save_pretrained(output_dir) 55 | current_folder = output_dir.split('/')[-1] 56 | parent_folder = os.path.dirname(output_dir) 57 | if self.args.local_rank == 0 or self.args.local_rank == -1: 58 | mm_projector_folder = os.path.join(output_dir, f"tuned_vision_tower-{self.state.global_step}") 59 | os.makedirs(mm_projector_folder, exist_ok=True) 60 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'tuned_vision_tower.bin')) 61 | self.model.config.mm_vision_tower = os.path.join(mm_projector_folder, f'tuned_vision_tower.bin') 62 | if self.args.train_adapter: 63 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 64 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}/adapter" 65 | 66 | run_dir = self._get_output_dir(trial=trial) 67 | output_dir = os.path.join(run_dir, checkpoint_folder) 68 | 69 | # Only save Adapter 70 | keys_to_match = ['mm_projector'] 71 | if getattr(self.args, "use_im_start_end", False): 72 | keys_to_match.extend(['embed_tokens', 'embed_in']) 73 | 74 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 75 | 76 | if self.args.local_rank == 0 or self.args.local_rank == -1: 77 | self.model.config.save_pretrained(output_dir) 78 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 79 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 80 | 81 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 82 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 83 | pass 84 | else: 85 | super(LLaVATrainer, self)._save(output_dir, state_dict) 86 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
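# Note: `handler` below is a module-level TimedRotatingFileHandler that build_logger()
# creates lazily, attaches to every existing logger, and pairs with StreamToLogger to
# redirect sys.stdout / sys.stderr into the log file.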
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | **Unified model** 2 | 3 | The checkpoint of the unified GUI agent can be found [here](https://drive.google.com/drive/folders/1EaqgfOkJ9xrPlcRP2mTr-FK1eXU6oCC7?usp=drive_link). -------------------------------------------------------------------------------- /run_eg/eval.sh: -------------------------------------------------------------------------------- 1 | # evaluation example 2 | 3 | CUDA_VISIBLE_DEVICES=0 python model_aitw.py --model-path xx --question-file xx --answers-file xx --data_name google_apps_parsed_episode_owl_pre10_pre10 4 | python eval_aitw_cot.py --answers-file xx --prd_output_path xx --eval_name xx --data_name google_apps_parsed_episode_owl_pre10_pre10 -------------------------------------------------------------------------------- /run_eg/prepare_data.sh: -------------------------------------------------------------------------------- 1 | # splits 2 | # standard.json from https://github.com/google-research/google-research/tree/master/android_in_the_wild 3 | # single.json from https://github.com/cooelf/Auto-UI/tree/main 4 | 5 | # prepare data 6 | python fetch_dataset_for_t5_blipv2.py --split_file "dataset/splits/standard.json" --output_dir "dataset/owl/general_parsed_episode_owl" 7 | python covert_aitw_to_llavacot_hist_fullset.py -------------------------------------------------------------------------------- /run_eg/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | PROMPT_VERSION="llava_llama_2" 12 | MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | EXPERIMENT_NAME='aitwfull_fullset_percept_goopre10_fixed' 15 | deepspeed --include localhost:1,5,2,7 llava/train/train_mem.py \ 16 | --run_name $EXPERIMENT_NAME \ 17 | --deepspeed ./scripts/zero2.json \ 18 | --model_name_or_path root_dir/checkpoints/llava-7b-$MODEL_VERSION \ 19 | --version $PROMPT_VERSION \ 20 | --data_path root_dir/scripts/aitw_data/fullset/fullset_8hist_cot_norm/llava_aitwfull_fullsetgoopre_8histlocation_cot_norm_truncted_fixed_train_QCM-LEA.json \ 21 | --vision_tower openai/clip-vit-large-patch14 \ 22 | --pretrain_mm_mlp_adapter ./checkpoints/llava-pretrain-$MODEL_VERSION/mm_projector.bin \ 23 | --mm_vision_select_layer -2 \ 24 | --mm_use_im_start_end False \ 25 
| --mm_use_im_patch_token False \ 26 | --bf16 True \ 27 | --output_dir ./checkpoints/llava-$EXPERIMENT_NAME-$MODEL_VERSION-finetune \ 28 | --max_steps 40000 \ 29 | --per_device_train_batch_size 16 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gradient_accumulation_steps 1 \ 32 | --evaluation_strategy "no" \ 33 | --save_strategy "steps" \ 34 | --save_steps 5000 \ 35 | --save_total_limit 5 \ 36 | --learning_rate 2e-5 \ 37 | --weight_decay 0. \ 38 | --warmup_ratio 0.03 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 1 \ 41 | --tf32 True \ 42 | --model_max_length 2048 \ 43 | --gradient_checkpointing True \ 44 | --dataloader_num_workers 4 \ 45 | --lazy_preprocess True \ 46 | --report_to wandb 47 | -------------------------------------------------------------------------------- /scripts/action_matching.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | import action_type as action_type_lib 6 | 7 | 8 | _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen 9 | ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4 10 | ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4 11 | 12 | # Interval determining if an action is a tap or a swipe. 13 | _SWIPE_DISTANCE_THRESHOLD = 0.04 14 | 15 | 16 | def _yx_in_bounding_boxes( 17 | yx, bounding_boxes 18 | ): 19 | """Check if the (y,x) point is contained in each bounding box. 20 | 21 | Args: 22 | yx: The (y, x) coordinate in pixels of the point. 23 | bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row 24 | represents a bounding box: (y_top_left, x_top_left, box_height, 25 | box_width). Note: containment is inclusive of the bounding box edges. 26 | 27 | Returns: 28 | is_inside: A 1D bool array where each element specifies if the point is 29 | contained within the respective box. 30 | """ 31 | y, x = yx 32 | 33 | # `bounding_boxes` has shape (n_elements, 4); we extract each array along the 34 | # last axis into shape (n_elements, 1), then squeeze unneeded dimension. 35 | top, left, height, width = [ 36 | jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1) 37 | ] 38 | 39 | # The y-axis is inverted for AndroidEnv, so bottom = top + height. 40 | bottom, right = top + height, left + width 41 | 42 | return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and( 43 | x >= left, x <= right) 44 | 45 | 46 | def _resize_annotation_bounding_boxes( 47 | annotation_positions, annotation_width_augment_fraction, 48 | annotation_height_augment_fraction): 49 | """Resize the bounding boxes by the given fractions. 50 | 51 | Args: 52 | annotation_positions: Array of shape (N, 4), where each row represents the 53 | (y, x, height, width) of the bounding boxes. 54 | annotation_width_augment_fraction: The fraction to augment the box widths, 55 | E.g., 1.4 == 240% total increase. 56 | annotation_height_augment_fraction: Same as described for width, but for box 57 | height. 58 | 59 | Returns: 60 | Resized bounding box. 61 | 62 | """ 63 | height_change = ( 64 | annotation_height_augment_fraction * annotation_positions[:, 2]) 65 | width_change = ( 66 | annotation_width_augment_fraction * annotation_positions[:, 3]) 67 | 68 | # Limit bounding box positions to the screen. 
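# The resize below grows each box symmetrically: the top-left corner moves up/left by half
# of the size increase (floored at 0) while the height/width grow by the full augment
# fraction (capped at 1). Coordinates appear to be normalized screen fractions, so with a
# fraction of 1.4 a box of height 0.10 becomes 0.24 tall and its top shifts up by 0.07.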
69 | resized_annotations = jnp.stack([ 70 | jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)), 71 | jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)), 72 | jnp.minimum(1, annotation_positions[:, 2] + height_change), 73 | jnp.minimum(1, annotation_positions[:, 3] + width_change), 74 | ], 75 | axis=1) 76 | return resized_annotations 77 | 78 | 79 | def is_tap_action(normalized_start_yx, 80 | normalized_end_yx): 81 | distance = jnp.linalg.norm( 82 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 83 | return distance <= _SWIPE_DISTANCE_THRESHOLD 84 | 85 | 86 | def _is_non_dual_point_action(action_type): 87 | return jnp.not_equal(action_type, action_type_lib.ActionType.DUAL_POINT) 88 | 89 | 90 | def _check_tap_actions_match( 91 | tap_1_yx, 92 | tap_2_yx, 93 | annotation_positions, 94 | matching_tap_distance_threshold_screen_percentage, 95 | annotation_width_augment_fraction, 96 | annotation_height_augment_fraction, 97 | ): 98 | """Determines if two tap actions are the same.""" 99 | resized_annotation_positions = _resize_annotation_bounding_boxes( 100 | annotation_positions, 101 | annotation_width_augment_fraction, 102 | annotation_height_augment_fraction, 103 | ) 104 | 105 | # Check if the ground truth tap action falls in an annotation's bounding box. 106 | tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions) 107 | tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions) 108 | both_in_box = jnp.max(tap1_in_box & tap2_in_box) 109 | 110 | # If the ground-truth tap action falls outside any of the annotation 111 | # bounding boxes or one of the actions is inside a bounding box and the other 112 | # is outside bounding box or vice versa, compare the points using Euclidean 113 | # distance. 114 | within_threshold = ( 115 | jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx)) 116 | <= matching_tap_distance_threshold_screen_percentage 117 | ) 118 | return jnp.logical_or(both_in_box, within_threshold) 119 | 120 | 121 | def _check_drag_actions_match( 122 | drag_1_touch_yx, 123 | drag_1_lift_yx, 124 | drag_2_touch_yx, 125 | drag_2_lift_yx, 126 | ): 127 | """Determines if two drag actions are the same.""" 128 | # Store drag deltas (the change in the y and x coordinates from touch to 129 | # lift), magnitudes, and the index of the main axis, which is the axis with 130 | # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and 131 | # ending at (0.3, 0.5) has a main axis index of 1). 132 | drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx 133 | drag_1_magnitudes = jnp.abs(drag_1_deltas) 134 | drag_1_main_axis = np.argmax(drag_1_magnitudes) 135 | drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx 136 | drag_2_magnitudes = jnp.abs(drag_2_deltas) 137 | drag_2_main_axis = np.argmax(drag_2_magnitudes) 138 | 139 | return jnp.equal(drag_1_main_axis, drag_2_main_axis) 140 | 141 | 142 | def check_actions_match( 143 | action_1_touch_yx, 144 | action_1_lift_yx, 145 | action_1_action_type, 146 | action_2_touch_yx, 147 | action_2_lift_yx, 148 | action_2_action_type, 149 | annotation_positions, 150 | tap_distance_threshold = _TAP_DISTANCE_THRESHOLD, 151 | annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION, 152 | annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION, 153 | ): 154 | """Determines if two actions are considered to be the same. 155 | 156 | Two actions being "the same" is defined here as two actions that would result 157 | in a similar screen state. 
158 | 159 | Args: 160 | action_1_touch_yx: The (y, x) coordinates of the first action's touch. 161 | action_1_lift_yx: The (y, x) coordinates of the first action's lift. 162 | action_1_action_type: The action type of the first action. 163 | action_2_touch_yx: The (y, x) coordinates of the second action's touch. 164 | action_2_lift_yx: The (y, x) coordinates of the second action's lift. 165 | action_2_action_type: The action type of the second action. 166 | annotation_positions: The positions of the UI annotations for the screen. It 167 | is A 2D int array of shape (num_bboxes, 4), where each row represents a 168 | bounding box: (y_top_left, x_top_left, box_height, box_width). Note that 169 | containment is inclusive of the bounding box edges. 170 | tap_distance_threshold: The threshold that determines if two taps result in 171 | a matching screen state if they don't fall the same bounding boxes. 172 | annotation_width_augment_fraction: The fraction to increase the width of the 173 | bounding box by. 174 | annotation_height_augment_fraction: The fraction to increase the height of 175 | of the bounding box by. 176 | 177 | Returns: 178 | A boolean representing whether the two given actions are the same or not. 179 | """ 180 | action_1_touch_yx = jnp.asarray(action_1_touch_yx) 181 | action_1_lift_yx = jnp.asarray(action_1_lift_yx) 182 | action_2_touch_yx = jnp.asarray(action_2_touch_yx) 183 | action_2_lift_yx = jnp.asarray(action_2_lift_yx) 184 | 185 | # Checks if at least one of the actions is global (i.e. not DUAL_POINT), 186 | # because if that is the case, only the actions' types need to be compared. 187 | has_non_dual_point_action = jnp.logical_or( 188 | _is_non_dual_point_action(action_1_action_type), 189 | _is_non_dual_point_action(action_2_action_type), 190 | ) 191 | 192 | different_dual_point_types = jnp.logical_xor( 193 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 194 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 195 | ) 196 | 197 | is_tap = jnp.logical_and( 198 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 199 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 200 | ) 201 | 202 | taps_match = _check_tap_actions_match( 203 | action_1_touch_yx, 204 | action_2_touch_yx, 205 | annotation_positions, 206 | tap_distance_threshold, 207 | annotation_width_augment_fraction, 208 | annotation_height_augment_fraction, 209 | ) 210 | 211 | taps_match = jnp.logical_and(is_tap, taps_match) 212 | 213 | drags_match = _check_drag_actions_match( 214 | action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx 215 | ) 216 | drags_match = jnp.where(is_tap, False, drags_match) 217 | 218 | return jnp.where( 219 | has_non_dual_point_action, 220 | jnp.equal(action_1_action_type, action_2_action_type), 221 | jnp.where( 222 | different_dual_point_types, 223 | False, 224 | jnp.logical_or(taps_match, drags_match), 225 | ), 226 | ) -------------------------------------------------------------------------------- /scripts/action_type.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | class ActionType(enum.IntEnum): 4 | 5 | # Placeholders for unused enum values 6 | UNUSED_0 = 0 7 | UNUSED_1 = 1 8 | UNUSED_2 = 2 9 | UNUSED_8 = 8 10 | UNUSED_9 = 9 11 | 12 | ########### Agent actions ########### 13 | 14 | # A type action that sends text to the emulator. Note that this simply sends 15 | # text and does not perform any clicks for element focus or enter presses for 16 | # submitting text. 
17 | TYPE = 3 18 | 19 | # The dual point action used to represent all gestures. 20 | DUAL_POINT = 4 21 | 22 | # These actions differentiate pressing the home and back button from touches. 23 | # They represent explicit presses of back and home performed using ADB. 24 | PRESS_BACK = 5 25 | PRESS_HOME = 6 26 | 27 | # An action representing that ADB command for hitting enter was performed. 28 | PRESS_ENTER = 7 29 | 30 | ########### Episode status actions ########### 31 | 32 | # An action used to indicate the desired task has been completed and resets 33 | # the environment. This action should also be used in the case that the task 34 | # has already been completed and there is nothing to do. 35 | # e.g. The task is to turn on the Wi-Fi when it is already on 36 | STATUS_TASK_COMPLETE = 10 37 | 38 | # An action used to indicate that desired task is impossible to complete and 39 | # resets the environment. This can be a result of many different things 40 | # including UI changes, Android version differences, etc. 41 | STATUS_TASK_IMPOSSIBLE = 11 -------------------------------------------------------------------------------- /scripts/convert_sqa_to_llava.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import fire 4 | import re 5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot 6 | 7 | 8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): 9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 10 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 11 | 12 | split_problems = build_prompt_chatbot( 13 | problems, split_indices, prompt_format, 14 | use_caption=False, is_test=False) 15 | 16 | target_format = [] 17 | for prob_id, (input, output) in split_problems.items(): 18 | if input.startswith('Question: '): 19 | input = input.replace('Question: ', '') 20 | if output.startswith('Answer: '): 21 | output = output.replace('Answer: ', '') 22 | 23 | raw_prob_data = problems[prob_id] 24 | if raw_prob_data['image'] is None: 25 | target_format.append({ 26 | "id": prob_id, 27 | "conversations": [ 28 | {'from': 'human', 'value': f"{input}"}, 29 | {'from': 'gpt', 'value': f"{output}"}, 30 | ], 31 | }) 32 | 33 | else: 34 | target_format.append({ 35 | "id": prob_id, 36 | "image": os.path.join(prob_id, raw_prob_data['image']), 37 | "conversations": [ 38 | {'from': 'human', 'value': f"{input}\n"}, 39 | {'from': 'gpt', 'value': f"{output}"}, 40 | ], 41 | }) 42 | 43 | print(f'Number of samples: {len(target_format)}') 44 | 45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: 46 | json.dump(target_format, f, indent=2) 47 | 48 | 49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): 50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 51 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 52 | 53 | split_problems = build_prompt_chatbot( 54 | problems, split_indices, prompt_format, 55 | use_caption=False, is_test=False) 56 | 57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") 58 | for prob_id, (input, output) in split_problems.items(): 59 | if input.startswith('Question: '): 60 | input = input.replace('Question: ', '') 61 | if output.startswith('Answer: '): 62 | output = output.replace('Answer: ', '') 63 | 64 | raw_prob_data = problems[prob_id] 65 | if raw_prob_data['image'] is None: 66 | data = { 67 | "id": prob_id, 68 | 
"instruction": f"{input}", 69 | "output": f"{output}", 70 | } 71 | 72 | else: 73 | data = { 74 | "id": prob_id, 75 | "image": os.path.join(prob_id, raw_prob_data['image']), 76 | "instruction": f"{input}\n", 77 | "output": f"{output}", 78 | } 79 | writer.write(json.dumps(data) + '\n') 80 | writer.close() 81 | 82 | 83 | def main(task, **kwargs): 84 | globals()[task](**kwargs) 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(main) 89 | -------------------------------------------------------------------------------- /scripts/covert_aitw_to_llava.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | from utils_data_for_owl import load_for_owl 3 | 4 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA", name=''): 5 | # split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 6 | # problems = json.load(open(os.path.join(base_dir, "problems.json"))) 7 | 8 | # split_problems = build_prompt_chatbot( 9 | # problems, split_indices, prompt_format, 10 | # use_caption=False, is_test=False) 11 | orig_data = load_for_owl(base_dir, split) 12 | 13 | 14 | target_format = [] 15 | for idx, text in enumerate(orig_data): 16 | input = text['text'].split('AI: ')[0] 17 | output = 'AI: ' + text['text'].split('AI: ')[1] 18 | 19 | if input.startswith('Human: '): 20 | input = input.replace('Human: ', '') 21 | if output.startswith('AI: '): 22 | output = output.replace('AI: ', '') 23 | assert 'image' in text.keys() 24 | target_format.append({ 25 | "id": text['image'], 26 | "image": text['image'], 27 | "conversations": [ 28 | {'from': 'human', 'value': f"{input}"}, 29 | {'from': 'gpt', 'value': f"{output}"}, 30 | ], 31 | # "target_text": text['target_text'], 32 | # "annos": text['anno_pos'] 33 | }) 34 | 35 | print(f'Number of samples: {len(target_format)}') 36 | # outputfile = os.path.join('./aitw_data/memory/', f"llava_aitwfull{name}_{split}_{prompt_format}.json") 37 | # if os.path.exists(outputfile): 38 | # raise FileExistsError 39 | # with open(outputfile, "w") as f: 40 | # json.dump(target_format, f, indent=2) 41 | 42 | convert_to_llava('.', 'train', name='_hist12_highlevel') -------------------------------------------------------------------------------- /scripts/covert_aitw_to_llavacot_hist_fullset.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | from utils_data_for_owl_cot_hist import load_for_owl 3 | from llava.conversation import conv_templates 4 | from llava.mm_utils import tokenizer_image_token 5 | from llava.constants import IMAGE_TOKEN_INDEX 6 | from transformers import AutoTokenizer 7 | from tqdm import tqdm 8 | 9 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA", name=''): 10 | # split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 11 | # problems = json.load(open(os.path.join(base_dir, "problems.json"))) 12 | 13 | # split_problems = build_prompt_chatbot( 14 | # problems, split_indices, prompt_format, 15 | # use_caption=False, is_test=False) 16 | model_path = '/data/maxb/tag/LLaVA/checkpoints/llava-7b-llama-2-7b-chat' 17 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 18 | orig_data = load_for_owl(base_dir, split) 19 | conv = conv_templates['llava_llama_2'].copy() 20 | 21 | target_format = [] 22 | trunced = 0 23 | for idx, text in enumerate(tqdm(orig_data)): 24 | input = text['text'].split('Next action:\nAI: ')[0] + 'Next action:\n' 25 | output = 'AI: ' + text['text'].split('Next 
action:\nAI: ')[1] 26 | 27 | if input.startswith('Human: '): 28 | input = input.replace('Human: ', '') 29 | if output.startswith('AI: '): 30 | output = output.replace('AI: ', '') 31 | assert 'image' in text.keys() 32 | conv = conv_templates['llava_llama_2'].copy() 33 | conv.append_message(conv.roles[0], input) 34 | conv.append_message(conv.roles[1], output) 35 | prompt = conv.get_prompt() 36 | # print("first:", prompt) 37 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=None)#.unsqueeze(0) 38 | # print(input_ids) 39 | # print(len(input_ids)) 40 | if len(input_ids) > 2048: 41 | # print('pick out layout to keep less than 2048') 42 | trunced += 1 43 | while len(input_ids) > 2048: 44 | # if idx > 28660: 45 | # print(text) 46 | # print('**prompt**',prompt) 47 | # print("trunc:", len(input_ids)) 48 | input_locs = input.split('')[-1].split('\nPrevious Actions')[0] 49 | # print(input_locs) 50 | input_locs_ = '\n'.join(input_locs.strip().split('\n')[:-1]) 51 | if len(input_locs_) >= len(input_locs): 52 | print("input_locs: ", input_locs, "input_locs_: ", input_locs_) 53 | os._exit(0) 54 | # print(input_locs_) 55 | input = input.replace(input_locs, input_locs_) 56 | conv = conv_templates['llava_llama_2'].copy() 57 | conv.append_message(conv.roles[0], input) 58 | conv.append_message(conv.roles[1], output) 59 | prompt = conv.get_prompt() 60 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=None) 61 | 62 | # os._exit(0) 63 | target_format.append({ 64 | "id": text['image'], 65 | "image": text['image'], 66 | "conversations": [ 67 | {'from': 'human', 'value': f"{input}"}, 68 | {'from': 'gpt', 'value': f"{output}"}, 69 | ], 70 | # "target_text": text['target_text'], 71 | # "annos": text['anno_pos'] 72 | }) 73 | print(f'Number of samples: {len(target_format)}, trunced samples: {trunced}') 74 | outputfile = os.path.join('/data/maxb/tag/LLaVA/scripts/aitw_data/fullset/fullset_8hist_cot_norm/', f"llava_aitwfull{name}_{split}_{prompt_format}.json") 75 | if os.path.exists(outputfile): 76 | raise FileExistsError 77 | with open(outputfile, "w") as f: 78 | json.dump(target_format, f, indent=2) 79 | 80 | # convert_to_llava('.', 'train', name='_fullsetgoopre_8histlocation_cot_norm_truncted_fixed') 81 | # convert_to_llava('.', 'test', name='_fullsetgoopre_8histlocation_cot_norm_truncted') 82 | # convert_to_llava('.', 'test', name='_fixsingle_fullsetgoopre_8histlocation_cot_norm_truncted') 83 | # convert_to_llava('.', 'test', name='_install_fullsetgoopre_8histlocation_cot_norm_truncted') 84 | # convert_to_llava('.', 'test', name='_google_fullsetgoopre_8histlocation_cot_norm_truncted') 85 | convert_to_llava('.', 'test', name='_webshop_fullsetgoopre_8histlocation_cot_norm_truncted') -------------------------------------------------------------------------------- /scripts/covert_aitw_to_mmicl.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | from utils_data_for_owl import load_for_owl 3 | 4 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA", name=''): 5 | # split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 6 | # problems = json.load(open(os.path.join(base_dir, "problems.json"))) 7 | 8 | # split_problems = build_prompt_chatbot( 9 | # problems, split_indices, prompt_format, 10 | # use_caption=False, is_test=False) 11 | orig_data = load_for_owl(base_dir, split) 12 | 13 | 14 | target_format = [] 15 | for idx, text in enumerate(orig_data): 16 | 
input = text['text'].split('AI: ')[0] 17 | output = 'AI: ' + text['text'].split('AI: ')[1] 18 | input = input.replace('', 'image 0: 图') 19 | if input.startswith('Human: '): 20 | input = input.replace('Human: ', '') 21 | if output.startswith('AI: '): 22 | output = output.replace('AI: ', '') 23 | assert 'image' in text.keys() 24 | target_format.append({ 25 | "id": text['image'], 26 | "input_image": [text['image']], 27 | "input_text": input, 28 | "output_text": output, 29 | # "conversations": [ 30 | # {'from': 'human', 'value': f"{input}"}, 31 | # {'from': 'gpt', 'value': f"{output}"}, 32 | # ], 33 | # "target_text": text['target_text'], 34 | # "annos": text['anno_pos'] 35 | }) 36 | 37 | print(f'Number of samples: {len(target_format)}') 38 | outputfile = os.path.join('./aitw_data/mmicl/', f"llava_aitwfull{name}_{split}_{prompt_format}.json") 39 | if os.path.exists(outputfile): 40 | raise FileExistsError 41 | with open(outputfile, "w") as f: 42 | json.dump(target_format, f, indent=2) 43 | 44 | convert_to_llava('.', 'test', name='_hist4') -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | 24 | # python merge_lora_weights.py --model-path liuhaotian/llava-llama-2-7b-chat-lightning-lora-preview --model-base meta-llama/Llama-2-7b-chat-hf --save-model-path ../checkpoints/llava-7b-llama-2-7b-chat_2 --------------------------------------------------------------------------------
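A usage sketch, not part of the repository's own scripts: after merging the LoRA weights as above, the saved checkpoint directory can presumably be evaluated with the AITW pipeline in the same way as run_eg/eval.sh (arguments marked xx are placeholders):

CUDA_VISIBLE_DEVICES=0 python model_aitw.py --model-path ../checkpoints/llava-7b-llama-2-7b-chat_2 --question-file xx --answers-file xx --data_name google_apps_parsed_episode_owl_pre10_pre10
python eval_aitw_cot.py --answers-file xx --prd_output_path xx --eval_name xx --data_name google_apps_parsed_episode_owl_pre10_pre10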