├── .gitattributes ├── HalDet-LLaVA ├── README.md ├── finetune_task_lora.sh ├── inference.py ├── llava │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── eval │ │ ├── eval_gpt_review.py │ │ ├── eval_gpt_review_bench.py │ │ ├── eval_gpt_review_visual.py │ │ ├── eval_pope.py │ │ ├── eval_science_qa.py │ │ ├── eval_science_qa_gpt4.py │ │ ├── eval_science_qa_gpt4_requery.py │ │ ├── eval_textvqa.py │ │ ├── generate_webpage_data_from_table.py │ │ ├── m4c_evaluator.py │ │ ├── model_qa.py │ │ ├── model_vqa.py │ │ ├── model_vqa_loader.py │ │ ├── model_vqa_mmbench.py │ │ ├── model_vqa_qbench.py │ │ ├── model_vqa_science.py │ │ ├── qa_baseline_gpt35.py │ │ ├── run_llava.py │ │ ├── summarize_gpt_review.py │ │ ├── table │ │ │ ├── answer │ │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ │ ├── answer_bard.jsonl │ │ │ │ ├── answer_gpt35.jsonl │ │ │ │ ├── answer_llama-13b.jsonl │ │ │ │ └── answer_vicuna-13b.jsonl │ │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ │ ├── model.jsonl │ │ │ ├── prompt.jsonl │ │ │ ├── question.jsonl │ │ │ ├── results │ │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ │ ├── review │ │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ │ ├── reviewer.jsonl │ │ │ └── rule.json │ │ └── webpage │ │ │ ├── figures │ │ │ ├── alpaca.png │ │ │ ├── bard.jpg │ │ │ ├── chatgpt.svg │ │ │ ├── llama.jpg │ │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ │ └── vicuna.jpeg │ │ │ ├── index.html │ │ │ ├── script.js │ │ │ └── styles.css │ ├── mm_utils.py │ ├── model │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model │ │ │ ├── llava_llama.py │ │ │ ├── llava_mpt.py │ │ │ └── mpt │ │ │ │ ├── adapt_tokenizer.py │ │ │ │ ├── attention.py │ │ │ │ ├── blocks.py │ │ │ │ ├── configuration_mpt.py │ │ │ │ ├── custom_embedding.py │ │ │ │ ├── flash_attn_triton.py │ │ │ │ ├── hf_prefixlm_converter.py │ │ │ │ ├── meta_init_context.py │ │ │ │ ├── modeling_mpt.py │ │ │ │ ├── norm.py │ │ │ │ └── param_init_fns.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ ├── multimodal_projector │ │ │ └── builder.py │ │ └── utils.py │ ├── serve │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── controller.py │ │ ├── examples │ │ │ ├── extreme_ironing.jpg │ │ │ └── waterview.jpg │ │ ├── gradio_web_server.py │ │ ├── model_worker.py │ │ ├── register_worker.py │ │ └── test_message.py │ ├── train │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llama_xformers_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── train.py │ │ ├── train_mem.py │ │ └── train_xformers.py │ └── utils.py ├── requirements.txt └── scripts │ ├── zero2.json │ ├── zero3.json │ └── zero3_offload.json ├── LICENSE ├── README.md ├── examples ├── 058214af21a03013.jpg ├── 43adec54f56ff7af.jpg ├── 508.jpg └── 63.jpg ├── figs ├── .DS_Store ├── datasetinfo.jpg ├── easydetect.jpg ├── framework.png ├── huggingface.svg ├── intro.png ├── view.png ├── 条形图.png └── 饼图.png ├── pipeline ├── claim_generate.py ├── config │ └── config.yaml ├── examples │ ├── animal.jpg │ ├── ball.jpg │ ├── football.jpg │ └── sandbeach.jpg ├── openai_wrapper.py ├── prompts │ ├── claim_generate.yaml │ ├── query_generate.yaml │ └── verify.yaml ├── query_generate.py ├── run_pipeline.py ├── tool │ ├── detect.py │ ├── google_serper.py │ └── ocr.py ├── tool_execute.py └── verify.py ├── requirements.txt 
└── vqa.mp4 /.gitattributes: -------------------------------------------------------------------------------- 1 | figs/动图.mp4 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /HalDet-LLaVA/README.md: -------------------------------------------------------------------------------- 1 | ## Train dataset 2 | 3 | The training dataset was constructed in three steps: 4 | - step1: We used LLaVA to generate raw responses on the training sets of [MSCOCO-2014](https://cocodataset.org/), [VQA-v2](https://visualqa.org/) and [TextVQA](https://textvqa.org/dataset/). 5 | - step2: We prompted GPT-3.5 to incorporate hallucinated text covering objects, attributes, scene text, and factual aspects, which was then manually reviewed. 6 | - step3: We used the [UniHD](https://arxiv.org/abs/2402.03190) pipeline to generate labels and rationales for the text, which were then manually screened and revised to obtain the training dataset. 7 | 8 | Training dataset metadata: 9 | We constructed 1270 instructions for fine-tuning. The ratio of hallucination claims to non-hallucination claims is 2244:1633. We provide reference tool information and a reference prompt in the training set. Below is an example of the training data: 10 | 11 | ```json 12 | { 13 | "id": 6, 14 | "image_path": "/TextVQA/train_images/160dee3be9ec3cbc.jpg", 15 | "claim_list": ["The laptop brand is Toshiba","Toshiba is a multinational conglomerate with a rich history","Toshiba was founded in 1885"], 16 | "ref_tool_info": "Here is the object detection expert model's result: laptop [0.003, 0.001, 0.996, 0.996] \nHere is the scene text recognition expert model's result: ANNIVERSARY [0.065, 0.638, 0.952, 0.826] TONGFaNG [0.462, 0.523, 0.575, 0.542] \nHere is the external knowledge: 1. Toshiba Corporation (株式会社東芝, Kabushikigaisha Tōshiba, English: /təˈʃiːbə, tɒ-, toʊ-/) is a Japanese multinational electronics company headquartered in Minato, Tokyo, Japan. 2. Toshiba's early history has two strands: One is", 17 | "ref_claim_label": ["hallucination", "non-hallucination", "hallucination"], 18 | "ref_reason": [{"claim1": "hallucination","reason": "The scene text recognition expert model's result shows the text 'TONGFANG' on the laptop, not Toshiba. Therefore, there's a hallucination."},{"claim2": "non-hallucination","reason": "Based on the external knowledge provided, Toshiba is indeed a multinational conglomerate with a rich history. Therefore, there's no hallucination."},{"claim3": "hallucination","reason": "According to the external knowledge, Toshiba was founded in 1939 by the merger of Shibaura Seisakusho and Tokyo Denki, not in 1885. Therefore, there's a hallucination."}], 19 | "ref_prompt": "Given an image, a list of claims from Multimodal Large Language Models and some supplementary information by external tools, you are required to judge whether each claim in the list conflicts with the image, following these rules: \n1. You must carefully judge from four aspects, including the object, attributes, scene text and fact.\n2. You must carefully utilize supplementary information.\n3. You must carefully judge whether the visual information in the image conflicts with each claim. If there is a conflict, the result for that claim is labeled as 'hallucination'; otherwise, it is labeled as 'non-hallucination'.\n4. Finally, You MUST only respond in a dictionary format. 
DO NOT RESPOND WITH ANYTHING ELSE.\n" 20 | } 21 | ``` 22 | 23 | **Note**: If you want to select LLaVA as the target model for training, you need to convert the data to its [custom dataset format](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md). 24 | 25 | 26 | 27 | ## HalDet-LLaVA 28 | 29 | HalDet-LLaVA is trained on the [MHaluBench training set](https://huggingface.co/datasets/openkg/MHaluBench/blob/main/MHaluBench_train.json) using LLaVA-v1.5; the specific training parameters can be found in [finetune_task_lora.sh](https://github.com/zjunlp/EasyDetect/blob/main/HalDet-LLaVA/finetune_task_lora.sh). 30 | 31 | We trained HalDet-LLaVA on a single A800 GPU in one hour. If you don't have enough GPU resources, distributed training scripts will be provided soon. 32 | 33 | You can run inference with HalDet-LLaVA using [inference.py](https://github.com/zjunlp/EasyDetect/blob/main/HalDet-LLaVA/inference.py). 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /HalDet-LLaVA/finetune_task_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed /llava/train/train_mem.py \ 4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 5 | --deepspeed /scripts/zero3.json \ 6 | --model_name_or_path ./llava_1.5_7b \ 7 | --version v1 \ 8 | --data_path ./train.json \ 9 | --image_folder . \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length True \ 17 | --bf16 True \ 18 | --output_dir ./checkpoints/llava-v1.5-7b-lora \ 19 | --num_train_epochs 2 \ 20 | --per_device_train_batch_size 16 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 1 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-4 \ 28 | --weight_decay 0. 
\ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | -------------------------------------------------------------------------------- /HalDet-LLaVA/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from llava.constants import ( 4 | IMAGE_TOKEN_INDEX, 5 | DEFAULT_IMAGE_TOKEN, 6 | DEFAULT_IM_START_TOKEN, 7 | DEFAULT_IM_END_TOKEN, 8 | IMAGE_PLACEHOLDER, 9 | ) 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import ( 14 | process_images, 15 | tokenizer_image_token, 16 | get_model_name_from_path, 17 | KeywordsStoppingCriteria, 18 | ) 19 | 20 | from PIL import Image 21 | import base64 22 | import requests 23 | 24 | from io import BytesIO 25 | import re 26 | 27 | class HalDetLLaVA: 28 | def __init__(self, model_path): 29 | self.model_path = model_path 30 | self.model_name = get_model_name_from_path(self.model_path) 31 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(self.model_path, None, self.model_name, device="cuda:0") 32 | 33 | def image_parser(self, image_file, sep): 34 | out = image_file.split(sep) 35 | return out 36 | 37 | def load_image(self, image_file): 38 | if image_file.startswith("http") or image_file.startswith("https"): 39 | response = requests.get(image_file) 40 | image = Image.open(BytesIO(response.content)).convert("RGB") 41 | else: 42 | image = Image.open(image_file).convert("RGB") 43 | return image 44 | 45 | def load_images(self, image_files): 46 | out = [] 47 | for image_file in image_files: 48 | image = self.load_image(image_file) 49 | out.append(image) 50 | return out 51 | 52 | def get_response(self, query, image_file, conv_mode=None, sep=",", temperature=0.8, top_p=None, num_beams=1, max_new_tokens=2048): 53 | # Model 54 | disable_torch_init() 55 | qs = query 56 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 57 | if IMAGE_PLACEHOLDER in qs: 58 | if self.model.config.mm_use_im_start_end: 59 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 60 | else: 61 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 62 | else: 63 | if self.model.config.mm_use_im_start_end: 64 | qs = image_token_se + "\n" + qs 65 | else: 66 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 67 | inferred_conv_mode = "llava_v1" 68 | if conv_mode is not None and conv_mode != inferred_conv_mode: 69 | print( 70 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 71 | inferred_conv_mode, conv_mode, conv_mode 72 | ) 73 | ) 74 | else: 75 | conv_mode = inferred_conv_mode 76 | 77 | conv = conv_templates[conv_mode].copy() 78 | conv.append_message(conv.roles[0], qs) 79 | conv.append_message(conv.roles[1], None) 80 | prompt = conv.get_prompt() 81 | 82 | image_files = self.image_parser(image_file, sep) 83 | images = self.load_images(image_files) 84 | images_tensor = process_images( 85 | images, 86 | self.image_processor, 87 | self.model.config 88 | ).to(self.model.device, dtype=torch.float16) 89 | 90 | input_ids = ( 91 | tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 92 | .unsqueeze(0) 93 | .cuda(0) 94 | ) 95 | 96 | 
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 97 | keywords = [stop_str] 98 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) 99 | 100 | with torch.inference_mode(): 101 | output_ids = self.model.generate( 102 | input_ids, 103 | images=images_tensor, 104 | do_sample=True if temperature > 0 else False, 105 | temperature=temperature, 106 | top_p=top_p, 107 | num_beams=num_beams, 108 | max_new_tokens=max_new_tokens, 109 | use_cache=True, 110 | stopping_criteria=[stopping_criteria], 111 | ) 112 | 113 | input_token_len = input_ids.shape[1] 114 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 115 | if n_diff_input_output > 0: 116 | print( 117 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" 118 | ) 119 | outputs = self.tokenizer.batch_decode( 120 | output_ids[:, input_token_len:], skip_special_tokens=True 121 | )[0] 122 | outputs = outputs.strip() 123 | if outputs.endswith(stop_str): 124 | outputs = outputs[: -len(stop_str)] 125 | outputs = outputs.strip() 126 | return outputs 127 | 128 | if __name__ == '__main__': 129 | model = HalDetLLaVA("zjunlp/HalDet-llava-7b") 130 | query = "Given an image, a list of claims from Multimodal Large Language Models and some supplementary information by external tools, you are required to judge whether each claim in the list conflicts with the image, following these rules: \n1. You must carefully judge from four aspects, including the object, attributes, scene text and fact. \n2. You must carefully utilize supplementary information. \n3. You must carefully judge whether the visual information in the image conflicts with each claim. If there is a conflict, the result for that claim is labeled as \"hallucination\"; otherwise, it is labeled as \"non-hallucination\". \n4. Finally, You MUST only respond in a dictionary format. DO NOT RESPOND WITH ANYTHING ELSE.\nHere is the claim list: claim1: The cafe in the image is named \"Hauptbahnhof\" \nSupplementary information:\nHere is the object detection expert model's result: cafe [0.703, 0.621, 0.770, 0.650] \nHere is the scene text recognition expert model's result: Hauptbahnhof [0.571, 0.627, 0.622, 0.649] ZEITCAFE [0.707, 0.629, 0.775, 0.659] \nHere is the external knowledge: none information" 131 | image_file = "../examples/058214af21a03013.jpg" 132 | res = model.get_response(query,image_file) 133 | print(res) -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
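# Note on the constants below: IGNORE_INDEX masks prompt tokens out of the
# training loss, while IMAGE_TOKEN_INDEX is the sentinel id that
# tokenizer_image_token() substitutes for each DEFAULT_IMAGE_TOKEN occurrence
# when building input_ids (see inference.py above).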
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | 
time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 
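# The reviewer's reply is parsed by parse_score() above, which expects the two
# scores on the first line of the reply (e.g. "8 7") and records [-1, -1] for
# malformed replies.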
96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, 
ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 | text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", 
type=str) 67 | parser.add_argument("--question-file", type=str) 68 | parser.add_argument("--result-file", type=str) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | else: 62 | pred = predictions[prob_id] 63 | pred_text = pred['text'] 64 | 65 | if pred_text in args.options: 66 | answer = pred_text 67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 68 | answer = pred_text[0] 69 | else: 70 | pattern = re.compile(r'The answer is ([A-Z]).') 71 | res = pattern.findall(pred_text) 72 | if len(res) == 1: 73 | answer = res[0] # 'A', 'B', ... 
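# e.g. a prediction of "The answer is B." is parsed to "B"; predictions that
# match none of the formats above are marked "FAILED" below.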
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '<image>' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return 
parser.parse_args() 19 | 20 | 21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
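# Our, requery, and GPT-4 predictions are all parsed with the same
# "The answer is X." pattern; a FAILED parse is mapped to a random valid
# choice by get_pred_idx().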
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | """) 130 | print("=====================================") 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = 
read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 
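# disable_torch_init() below replaces the default nn.Linear / nn.LayerNorm
# weight initializers with no-ops; from_pretrained() overwrites the weights
# anyway, so this just speeds up model construction.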
35 | disable_torch_init() 36 | model_name = os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, 
args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions): 42 | idx = line["question_id"] 43 | image_file = line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = Image.open(os.path.join(args.image_folder, image_file)) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 60 | 61 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 62 | keywords = [stop_str] 63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 64 | 65 | with torch.inference_mode(): 66 | output_ids = model.generate( 67 | input_ids, 68 | images=image_tensor.unsqueeze(0).half().cuda(), 69 | do_sample=True if args.temperature > 0 else False, 70 | temperature=args.temperature, 71 | top_p=args.top_p, 72 | num_beams=args.num_beams, 73 | # no_repeat_ngram_size=3, 74 | max_new_tokens=1024, 75 | use_cache=True) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | 87 | ans_id = shortuuid.uuid() 88 | ans_file.write(json.dumps({"question_id": idx, 89 | "prompt": cur_prompt, 90 | "text": outputs, 91 | "answer_id": ans_id, 92 | "model_id": model_name, 93 | "metadata": {}}) + "\n") 94 | ans_file.flush() 95 | ans_file.close() 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 100 | parser.add_argument("--model-base", type=str, default=None) 101 | parser.add_argument("--image-folder", type=str, default="") 102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 103 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 104 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 105 | parser.add_argument("--num-chunks", type=int, default=1) 106 | parser.add_argument("--chunk-idx", type=int, default=0) 107 | parser.add_argument("--temperature", type=float, default=0.2) 108 | parser.add_argument("--top_p", type=float, default=None) 109 | parser.add_argument("--num_beams", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | eval_model(args) 113 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/model_vqa_loader.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | 
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | from PIL import Image 16 | import math 17 | 18 | 19 | def split_list(lst, n): 20 | """Split a list into n (roughly) equal-sized chunks""" 21 | chunk_size = math.ceil(len(lst) / n) # integer division 22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 23 | 24 | 25 | def get_chunk(lst, n, k): 26 | chunks = split_list(lst, n) 27 | return chunks[k] 28 | 29 | 30 | # Custom dataset class 31 | class CustomDataset(Dataset): 32 | def __init__(self, questions, image_folder, tokenizer, image_processor, model_config): 33 | self.questions = questions 34 | self.image_folder = image_folder 35 | self.tokenizer = tokenizer 36 | self.image_processor = image_processor 37 | self.model_config = model_config 38 | 39 | def __getitem__(self, index): 40 | line = self.questions[index] 41 | image_file = line["image"] 42 | qs = line["text"] 43 | if self.model_config.mm_use_im_start_end: 44 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 45 | else: 46 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | conv.append_message(conv.roles[0], qs) 50 | conv.append_message(conv.roles[1], None) 51 | prompt = conv.get_prompt() 52 | 53 | image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB') 54 | image_tensor = process_images([image], self.image_processor, self.model_config)[0] 55 | 56 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') 57 | 58 | return input_ids, image_tensor 59 | 60 | def __len__(self): 61 | return len(self.questions) 62 | 63 | 64 | # DataLoader 65 | def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4): 66 | assert batch_size == 1, "batch_size must be 1" 67 | dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config) 68 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) 69 | return data_loader 70 | 71 | 72 | def eval_model(args): 73 | # Model 74 | disable_torch_init() 75 | model_path = os.path.expanduser(args.model_path) 76 | model_name = get_model_name_from_path(model_path) 77 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 78 | 79 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 80 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 81 | answers_file = os.path.expanduser(args.answers_file) 82 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 83 | ans_file = open(answers_file, "w") 84 | 85 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: 86 | args.conv_mode = args.conv_mode + '_mmtag' 87 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') 88 | 89 | data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config) 90 | 91 | for (input_ids, image_tensor), line in 
tqdm(zip(data_loader, questions), total=len(questions)): 92 | idx = line["question_id"] 93 | cur_prompt = line["text"] 94 | 95 | input_ids = input_ids.to(device='cuda', non_blocking=True) 96 | 97 | with torch.inference_mode(): 98 | output_ids = model.generate( 99 | input_ids, 100 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), 101 | do_sample=True if args.temperature > 0 else False, 102 | temperature=args.temperature, 103 | top_p=args.top_p, 104 | num_beams=args.num_beams, 105 | max_new_tokens=args.max_new_tokens, 106 | use_cache=True) 107 | 108 | input_token_len = input_ids.shape[1] 109 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 110 | if n_diff_input_output > 0: 111 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 112 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 113 | outputs = outputs.strip() 114 | 115 | ans_id = shortuuid.uuid() 116 | ans_file.write(json.dumps({"question_id": idx, 117 | "prompt": cur_prompt, 118 | "text": outputs, 119 | "answer_id": ans_id, 120 | "model_id": model_name, 121 | "metadata": {}}) + "\n") 122 | # ans_file.flush() 123 | ans_file.close() 124 | 125 | if __name__ == "__main__": 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 128 | parser.add_argument("--model-base", type=str, default=None) 129 | parser.add_argument("--image-folder", type=str, default="") 130 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 131 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 132 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 133 | parser.add_argument("--num-chunks", type=int, default=1) 134 | parser.add_argument("--chunk-idx", type=int, default=0) 135 | parser.add_argument("--temperature", type=float, default=0.2) 136 | parser.add_argument("--top_p", type=float, default=None) 137 | parser.add_argument("--num_beams", type=int, default=1) 138 | parser.add_argument("--max_new_tokens", type=int, default=128) 139 | args = parser.parse_args() 140 | 141 | eval_model(args) 142 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/model_vqa_mmbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | import pandas as pd 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path 14 | 15 | from PIL import Image 16 | import math 17 | 18 | 19 | all_options = ['A', 'B', 'C', 'D'] 20 | 21 | 22 | def split_list(lst, n): 23 | """Split a list into n (roughly) equal-sized chunks""" 24 | chunk_size = math.ceil(len(lst) / n) # integer division 25 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 26 | 27 | 28 | def get_chunk(lst, n, k): 29 | chunks = split_list(lst, n) 30 | return chunks[k] 31 | 32 | 33 | def is_none(value): 34 | if value is None: 35 | return True 36 | if type(value) is float and math.isnan(value): 
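# is_none treats genuine None, float NaN, and the literal strings 'nan'/'none'
# (as produced by pandas round-trips of empty cells) as missing. For example:
#   is_none(float('nan'))  # -> True
#   is_none('None')        # -> True
#   is_none('')            # -> False; an empty option string is kept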
37 | return True 38 | if type(value) is str and value.lower() == 'nan': 39 | return True 40 | if type(value) is str and value.lower() == 'none': 41 | return True 42 | return False 43 | 44 | def get_options(row, options): 45 | parsed_options = [] 46 | for option in options: 47 | option_value = row[option] 48 | if is_none(option_value): 49 | break 50 | parsed_options.append(option_value) 51 | return parsed_options 52 | 53 | 54 | def eval_model(args): 55 | # Model 56 | disable_torch_init() 57 | model_path = os.path.expanduser(args.model_path) 58 | model_name = get_model_name_from_path(model_path) 59 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 60 | 61 | questions = pd.read_table(os.path.expanduser(args.question_file)) 62 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 63 | answers_file = os.path.expanduser(args.answers_file) 64 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 65 | ans_file = open(answers_file, "w") 66 | 67 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: 68 | args.conv_mode = args.conv_mode + '_mmtag' 69 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') 70 | 71 | for index, row in tqdm(questions.iterrows(), total=len(questions)): 72 | options = get_options(row, all_options) 73 | cur_option_char = all_options[:len(options)] 74 | 75 | if args.all_rounds: 76 | num_rounds = len(options) 77 | else: 78 | num_rounds = 1 79 | 80 | for round_idx in range(num_rounds): 81 | idx = row['index'] 82 | question = row['question'] 83 | hint = row['hint'] 84 | image = load_image_from_base64(row['image']) 85 | if not is_none(hint): 86 | question = hint + '\n' + question 87 | for option_char, option in zip(all_options[:len(options)], options): 88 | question = question + '\n' + option_char + '. ' + option 89 | qs = cur_prompt = question 90 | if model.config.mm_use_im_start_end: 91 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 92 | else: 93 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 94 | 95 | if args.single_pred_prompt: 96 | if args.lang == 'cn': 97 | qs = qs + '\n' + "请直接回答选项字母。" 98 | else: 99 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 
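# In --all-rounds mode the same question is asked once per candidate, and the
# options are rotated at the end of each round (see the rotation step further
# below) so every candidate appears in every letter slot. A worked example with
# hypothetical options:
#   options = ['cat', 'dog', 'bird']
#   options[1:] + options[:1]  # -> ['dog', 'bird', 'cat']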
100 | 101 | conv = conv_templates[args.conv_mode].copy() 102 | conv.append_message(conv.roles[0], qs) 103 | conv.append_message(conv.roles[1], None) 104 | prompt = conv.get_prompt() 105 | 106 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 107 | 108 | image_tensor = process_images([image], image_processor, model.config)[0] 109 | # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 110 | 111 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 112 | 113 | with torch.inference_mode(): 114 | output_ids = model.generate( 115 | input_ids, 116 | images=image_tensor.unsqueeze(0).half().cuda(), 117 | do_sample=True if args.temperature > 0 else False, 118 | temperature=args.temperature, 119 | top_p=args.top_p, 120 | num_beams=args.num_beams, 121 | # no_repeat_ngram_size=3, 122 | max_new_tokens=1024, 123 | use_cache=True) 124 | 125 | input_token_len = input_ids.shape[1] 126 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 127 | if n_diff_input_output > 0: 128 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 129 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 130 | outputs = outputs.strip() 131 | if outputs.endswith(stop_str): 132 | outputs = outputs[:-len(stop_str)] 133 | outputs = outputs.strip() 134 | 135 | ans_id = shortuuid.uuid() 136 | ans_file.write(json.dumps({"question_id": idx, 137 | "round_id": round_idx, 138 | "prompt": cur_prompt, 139 | "text": outputs, 140 | "options": options, 141 | "option_char": cur_option_char, 142 | "answer_id": ans_id, 143 | "model_id": model_name, 144 | "metadata": {}}) + "\n") 145 | ans_file.flush() 146 | 147 | # rotate options 148 | options = options[1:] + options[:1] 149 | cur_option_char = cur_option_char[1:] + cur_option_char[:1] 150 | ans_file.close() 151 | 152 | if __name__ == "__main__": 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 155 | parser.add_argument("--model-base", type=str, default=None) 156 | parser.add_argument("--image-folder", type=str, default="") 157 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 158 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 159 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 160 | parser.add_argument("--num-chunks", type=int, default=1) 161 | parser.add_argument("--chunk-idx", type=int, default=0) 162 | parser.add_argument("--temperature", type=float, default=0.2) 163 | parser.add_argument("--top_p", type=float, default=None) 164 | parser.add_argument("--num_beams", type=int, default=1) 165 | parser.add_argument("--all-rounds", action="store_true") 166 | parser.add_argument("--single-pred-prompt", action="store_true") 167 | parser.add_argument("--lang", type=str, default="en") 168 | args = parser.parse_args() 169 | 170 | eval_model(args) 171 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/model_vqa_qbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import json 5 | 6 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 7 | from llava.conversation import conv_templates, SeparatorStyle 8 | 
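# The Q-Bench entries consumed below are JSON objects shaped roughly like the
# following (field values are illustrative, not taken from the dataset):
#   {"img_path": "example.jpg",
#    "question": "How is the clarity of this image?",
#    "candidates": ["High", "Medium", "Low"]}
# The script fills in a "response" field on each entry as it goes.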
from llava.model.builder import load_pretrained_model 9 | from llava.utils import disable_torch_init 10 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 11 | 12 | from PIL import Image 13 | 14 | import requests 15 | from PIL import Image 16 | from io import BytesIO 17 | 18 | 19 | def load_image(image_file): 20 | if image_file.startswith('http') or image_file.startswith('https'): 21 | response = requests.get(image_file) 22 | image = Image.open(BytesIO(response.content)).convert('RGB') 23 | else: 24 | image = Image.open(image_file).convert('RGB') 25 | return image 26 | 27 | 28 | def eval_model(args): 29 | # Model 30 | disable_torch_init() 31 | 32 | model_name = get_model_name_from_path(args.model_path) 33 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True) 34 | 35 | 36 | 37 | 38 | with open(args.questions_file) as f: 39 | llvqa_data = json.load(f) 40 | 41 | for i, llddata in enumerate(tqdm(llvqa_data)): 42 | filename = llddata["img_path"] 43 | if args.lang == "en": 44 | message = llddata["question"] + "\nChoose between one of the options as follows:\n" 45 | elif args.lang == "zh": 46 | message = llddata["question"] + "\n在下列选项中选择一个:\n" 47 | else: 48 | raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.") 49 | for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]): 50 | message += f"{choice} {ans}\n" 51 | qs = message 52 | 53 | if model.config.mm_use_im_start_end: 54 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 55 | else: 56 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 57 | 58 | if 'llama-2' in model_name.lower(): 59 | conv_mode = "llava_llama_2" 60 | elif "v1" in model_name.lower(): 61 | conv_mode = "llava_v1" 62 | elif "mpt" in model_name.lower(): 63 | conv_mode = "mpt" 64 | else: 65 | conv_mode = "llava_v0" 66 | 67 | if args.conv_mode is not None and conv_mode != args.conv_mode: 68 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 69 | else: 70 | args.conv_mode = conv_mode 71 | 72 | conv = conv_templates[args.conv_mode].copy() 73 | conv.append_message(conv.roles[0], qs) 74 | conv.append_message(conv.roles[1], None) 75 | prompt = conv.get_prompt() 76 | 77 | image = load_image(args.image_folder + filename) 78 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 79 | 80 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 81 | 82 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 83 | keywords = [stop_str] 84 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 85 | 86 | 87 | with torch.inference_mode(): 88 | output_ids = model.generate( 89 | input_ids, 90 | images=image_tensor, 91 | num_beams=1, 92 | do_sample=False, 93 | temperature=0, 94 | max_new_tokens=1024, 95 | use_cache=True, 96 | stopping_criteria=[stopping_criteria]) 97 | 98 | input_token_len = input_ids.shape[1] 99 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 100 | if n_diff_input_output > 0: 101 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 102 | outputs =
tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 103 | outputs = outputs.strip() 104 | if outputs.endswith(stop_str): 105 | outputs = outputs[:-len(stop_str)] 106 | outputs = outputs.strip() 107 | llddata["response"] = outputs 108 | with open(args.answers_file, "a") as wf: 109 | wf.write(json.dumps(llddata) + "\n") 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--model-path", type=str, default="llava-v1.5") 114 | parser.add_argument("--model-base", type=str, default=None) 115 | parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa") 116 | parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json") 117 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 118 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 119 | parser.add_argument("--lang", type=str, default="en") 120 | args = parser.parse_args() 121 | 122 | eval_model(args) 123 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/model_vqa_science.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # ceiling division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for i, line in enumerate(tqdm(questions)): 42 | idx = line["id"] 43 | question = line['conversations'][0] 44 | qs = question['value'].replace('<image>', '').strip() 45 | cur_prompt = qs 46 | 47 | if 'image' in line: 48 | image_file = line["image"] 49 | image = Image.open(os.path.join(args.image_folder, image_file)) 50 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 51 | images = image_tensor.unsqueeze(0).half().cuda() 52 | if getattr(model.config, 'mm_use_im_start_end', False): 53 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 54 | else: 55 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 56 | cur_prompt = '<image>' + '\n' + cur_prompt 57 | else: 58 | images = None 59 | 60 | if args.single_pred_prompt: 61
| qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 62 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly." 63 | 64 | conv = conv_templates[args.conv_mode].copy() 65 | conv.append_message(conv.roles[0], qs) 66 | conv.append_message(conv.roles[1], None) 67 | prompt = conv.get_prompt() 68 | 69 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 70 | 71 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 72 | keywords = [stop_str] 73 | stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None 74 | 75 | with torch.inference_mode(): 76 | output_ids = model.generate( 77 | input_ids, 78 | images=images, 79 | do_sample=True if args.temperature > 0 else False, 80 | temperature=args.temperature, 81 | max_new_tokens=1024, 82 | use_cache=True, 83 | stopping_criteria=stopping_criteria, 84 | ) 85 | 86 | input_token_len = input_ids.shape[1] 87 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 88 | if n_diff_input_output > 0: 89 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 90 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 91 | outputs = outputs.strip() 92 | if outputs.endswith(stop_str): 93 | outputs = outputs[:-len(stop_str)] 94 | outputs = outputs.strip() 95 | 96 | # prompt for answer 97 | if args.answer_prompter: 98 | outputs_reasoning = outputs 99 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 100 | 101 | with torch.inference_mode(): 102 | output_ids = model.generate( 103 | input_ids, 104 | images=images, 105 | do_sample=True if args.temperature > 0 else False, 106 | temperature=args.temperature, 107 | max_new_tokens=64, 108 | use_cache=True, 109 | stopping_criteria=stopping_criteria) 110 | 111 | input_token_len = input_ids.shape[1] 112 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 113 | if n_diff_input_output > 0: 114 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 115 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 116 | outputs = outputs.strip() 117 | if outputs.endswith(stop_str): 118 | outputs = outputs[:-len(stop_str)] 119 | outputs = outputs.strip() 120 | outputs = outputs_reasoning + '\n The answer is ' + outputs 121 | 122 | ans_id = shortuuid.uuid() 123 | ans_file.write(json.dumps({"question_id": idx, 124 | "prompt": cur_prompt, 125 | "text": outputs, 126 | "answer_id": ans_id, 127 | "model_id": model_name, 128 | "metadata": {}}) + "\n") 129 | ans_file.flush() 130 | ans_file.close() 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 135 | parser.add_argument("--model-base", type=str, default=None) 136 | parser.add_argument("--image-folder", type=str, default="") 137 | parser.add_argument("--question-file", type=str, default="tables/question.json") 138 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 139 | parser.add_argument("--conv-mode", type=str, default="llava_v0") 140 | parser.add_argument("--num-chunks", type=int, default=1) 141 | parser.add_argument("--chunk-idx", type=int,
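# A hypothetical two-GPU sharding pattern for this script; each process writes
# its own shard of answers (flag values are illustrative):
#   CUDA_VISIBLE_DEVICES=0 python -m llava.eval.model_vqa_science \
#       --question-file tables/question.json --answers-file answer_0.jsonl \
#       --num-chunks 2 --chunk-idx 0
#   CUDA_VISIBLE_DEVICES=1 python -m llava.eval.model_vqa_science \
#       --question-file tables/question.json --answers-file answer_1.jsonl \
#       --num-chunks 2 --chunk-idx 1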
default=0) 142 | parser.add_argument("--temperature", type=float, default=0.2) 143 | parser.add_argument("--answer-prompter", action="store_true") 144 | parser.add_argument("--single-pred-prompt", action="store_true") 145 | args = parser.parse_args() 146 | 147 | eval_model(args) 148 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import ( 5 | IMAGE_TOKEN_INDEX, 6 | DEFAULT_IMAGE_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | IMAGE_PLACEHOLDER, 10 | ) 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import ( 15 | process_images, 16 | tokenizer_image_token, 17 | get_model_name_from_path, 18 | KeywordsStoppingCriteria, 19 | ) 20 | 21 | from PIL import Image 22 | 23 | import requests 24 | from PIL import Image 25 | from io import 
BytesIO 26 | import re 27 | 28 | 29 | def image_parser(args): 30 | out = args.image_file.split(args.sep) 31 | return out 32 | 33 | 34 | def load_image(image_file): 35 | if image_file.startswith("http") or image_file.startswith("https"): 36 | response = requests.get(image_file) 37 | image = Image.open(BytesIO(response.content)).convert("RGB") 38 | else: 39 | image = Image.open(image_file).convert("RGB") 40 | return image 41 | 42 | 43 | def load_images(image_files): 44 | out = [] 45 | for image_file in image_files: 46 | image = load_image(image_file) 47 | out.append(image) 48 | return out 49 | 50 | 51 | def eval_model(args): 52 | # Model 53 | disable_torch_init() 54 | 55 | model_name = get_model_name_from_path(args.model_path) 56 | tokenizer, model, image_processor, context_len = load_pretrained_model( 57 | args.model_path, args.model_base, model_name 58 | ) 59 | 60 | qs = args.query 61 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 62 | if IMAGE_PLACEHOLDER in qs: 63 | if model.config.mm_use_im_start_end: 64 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 65 | else: 66 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 67 | else: 68 | if model.config.mm_use_im_start_end: 69 | qs = image_token_se + "\n" + qs 70 | else: 71 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 72 | 73 | if "llama-2" in model_name.lower(): 74 | conv_mode = "llava_llama_2" 75 | elif "v1" in model_name.lower(): 76 | conv_mode = "llava_v1" 77 | elif "mpt" in model_name.lower(): 78 | conv_mode = "mpt" 79 | else: 80 | conv_mode = "llava_v0" 81 | 82 | if args.conv_mode is not None and conv_mode != args.conv_mode: 83 | print( 84 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 85 | conv_mode, args.conv_mode, args.conv_mode 86 | ) 87 | ) 88 | else: 89 | args.conv_mode = conv_mode 90 | 91 | conv = conv_templates[args.conv_mode].copy() 92 | conv.append_message(conv.roles[0], qs) 93 | conv.append_message(conv.roles[1], None) 94 | prompt = conv.get_prompt() 95 | 96 | image_files = image_parser(args) 97 | images = load_images(image_files) 98 | images_tensor = process_images( 99 | images, 100 | image_processor, 101 | model.config 102 | ).to(model.device, dtype=torch.float16) 103 | 104 | input_ids = ( 105 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 106 | .unsqueeze(0) 107 | .cuda(2) 108 | ) 109 | 110 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 111 | keywords = [stop_str] 112 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 113 | 114 | with torch.inference_mode(): 115 | output_ids = model.generate( 116 | input_ids, 117 | images=images_tensor, 118 | do_sample=True if args.temperature > 0 else False, 119 | temperature=args.temperature, 120 | top_p=args.top_p, 121 | num_beams=args.num_beams, 122 | max_new_tokens=args.max_new_tokens, 123 | use_cache=True, 124 | stopping_criteria=[stopping_criteria], 125 | ) 126 | 127 | input_token_len = input_ids.shape[1] 128 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 129 | if n_diff_input_output > 0: 130 | print( 131 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" 132 | ) 133 | outputs = tokenizer.batch_decode( 134 | output_ids[:, input_token_len:], skip_special_tokens=True 135 | )[0] 136 | outputs = outputs.strip() 137 | if outputs.endswith(stop_str): 138 | outputs = outputs[: -len(stop_str)] 139 | outputs = outputs.strip() 140 | 
return outputs 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 146 | parser.add_argument("--model-base", type=str, default=None) 147 | parser.add_argument("--image-file", type=str, required=True) 148 | parser.add_argument("--query", type=str, required=True) 149 | parser.add_argument("--conv-mode", type=str, default=None) 150 | parser.add_argument("--sep", type=str, default=",") 151 | parser.add_argument("--temperature", type=float, default=0.2) 152 | parser.add_argument("--top_p", type=float, default=None) 153 | parser.add_argument("--num_beams", type=int, default=1) 154 | parser.add_argument("--max_new_tokens", type=int, default=512) 155 | args = parser.parse_args() 156 | 157 | eval_model(args) 158 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", 
"model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. 
Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for visual questions"} 5 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/HalDet-LLaVA/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/HalDet-LLaVA/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/HalDet-LLaVA/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/HalDet-LLaVA/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card
{ 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 35 | new_images.append(image) 36 | else: 37 | return image_processor(images, return_tensors='pt')['pixel_values'] 38 | if all(x.shape == new_images[0].shape for x in new_images): 39 | new_images = torch.stack(new_images, dim=0) 40 | return new_images 41 | 42 | 43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, 
return_tensors=None): 44 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')] 45 | 46 | def insert_separator(X, sep): 47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 48 | 49 | input_ids = [] 50 | offset = 0 51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 52 | offset = 1 53 | input_ids.append(prompt_chunks[0][0]) 54 | 55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 56 | input_ids.extend(x[offset:]) 57 | 58 | if return_tensors is not None: 59 | if return_tensors == 'pt': 60 | return torch.tensor(input_ids, dtype=torch.long) 61 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 62 | return input_ids 63 | 64 | 65 | def get_model_name_from_path(model_path): 66 | model_path = model_path.strip("/") 67 | model_paths = model_path.split("/") 68 | if model_paths[-1].startswith('checkpoint-'): 69 | return model_paths[-2] + "_" + model_paths[-1] 70 | else: 71 | return model_paths[-1] 72 | 73 | class KeywordsStoppingCriteria(StoppingCriteria): 74 | def __init__(self, keywords, tokenizer, input_ids): 75 | self.keywords = keywords 76 | self.keyword_ids = [] 77 | self.max_keyword_len = 0 78 | for keyword in keywords: 79 | cur_keyword_ids = tokenizer(keyword).input_ids 80 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 81 | cur_keyword_ids = cur_keyword_ids[1:] 82 | if len(cur_keyword_ids) > self.max_keyword_len: 83 | self.max_keyword_len = len(cur_keyword_ids) 84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 85 | self.tokenizer = tokenizer 86 | self.start_len = input_ids.shape[1] 87 | 88 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 89 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 90 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 91 | for keyword_id in self.keyword_ids: 92 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 93 | return True 94 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 95 | for keyword in self.keywords: 96 | if keyword in outputs: 97 | return True 98 | return False 99 | 100 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 101 | outputs = [] 102 | for i in range(output_ids.shape[0]): 103 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 104 | return all(outputs) 105 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | # from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path,
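# For context on tokenizer_image_token above: it tokenizes the prompt around
# each '<image>' placeholder and splices in IMAGE_TOKEN_INDEX (the sentinel id
# -200 from llava.constants) where the image features are later inserted. A
# sketch (token ids are illustrative):
#   ids = tokenizer_image_token("A photo of <image>", tokenizer)
#   # -> [bos, ...text token ids..., -200], with -200 marking the image slot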
target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 27 | 28 | 29 | class LlavaConfig(LlamaConfig): 30 | model_type = "llava" 31 | 32 | 33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 34 | config_class = LlavaConfig 35 | 36 | def __init__(self, config: LlamaConfig): 37 | super(LlavaLlamaModel, self).__init__(config) 38 | 39 | 40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaConfig 42 | 43 | def __init__(self, config): 44 | super(LlamaForCausalLM, self).__init__(config) 45 | self.model = LlavaLlamaModel(config) 46 | self.pretraining_tp = config.pretraining_tp 47 | self.vocab_size = config.vocab_size 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | return_dict: Optional[bool] = None, 69 | ) -> Union[Tuple, CausalLMOutputWithPast]: 70 | 71 | if inputs_embeds is None: 72 | ( 73 | input_ids, 74 | position_ids, 75 | attention_mask, 76 | past_key_values, 77 | inputs_embeds, 78 | labels 79 | ) = self.prepare_inputs_labels_for_multimodal( 80 | input_ids, 81 | position_ids, 82 | attention_mask, 83 | past_key_values, 84 | labels, 85 | images 86 | ) 87 | 88 | return super().forward( 89 | input_ids=input_ids, 90 | attention_mask=attention_mask, 91 | position_ids=position_ids, 92 | past_key_values=past_key_values, 93 | inputs_embeds=inputs_embeds, 94 | labels=labels, 95 | use_cache=use_cache, 96 | output_attentions=output_attentions, 97 | output_hidden_states=output_hidden_states, 98 | return_dict=return_dict 99 | ) 100 | 101 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 102 | images = kwargs.pop("images", None) 103 | _inputs = super().prepare_inputs_for_generation( 104 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 105 | ) 106 | if images is not None: 107 | _inputs['images'] = images 108 | return _inputs 109 | 110 | AutoConfig.register("llava", LlavaConfig) 111 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 112 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
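# Note on the registration pattern at the end of llava_llama.py above:
# registering LlavaConfig and LlavaLlamaForCausalLM with the Auto* classes
# lets stock Hugging Face loading resolve model_type "llava" without
# special-casing. A minimal sketch (checkpoint path is hypothetical):
#   from transformers import AutoModelForCausalLM
#   import llava.model  # importing the package triggers the register() calls
#   model = AutoModelForCausalLM.from_pretrained("path/to/llava-checkpoint")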
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in 
https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('<pad>', special_tokens=True) 19 | tokenizer.pad_token = '<pad>' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 
27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 
as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billion parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | <Tip warning={true}> 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | </Tip> 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Linear(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, 
weight=True, dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, 
args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return 
(self.config.image_size // self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/HalDet-LLaVA/llava/serve/__init__.py -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if 'llama-2' in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ('user', 'assistant') 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | # Similar operation in model_worker.py 56 | image_tensor = process_images([image], image_processor, model.config) 57 | if type(image_tensor) is list: 58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 59 | else: 60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 61 | 62 | while True: 63 | try: 64 | inp = input(f"{roles[0]}: ") 65 | except EOFError: 66 | inp = "" 67 | if not inp: 68 | print("exit...") 69 | break 70 | 71 | print(f"{roles[1]}: ", end="") 72 | 73 | if image is not None: 74 | # first message 75 | if model.config.mm_use_im_start_end: 76 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 77 
| else: 78 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 79 | conv.append_message(conv.roles[0], inp) 80 | image = None 81 | else: 82 | # later messages 83 | conv.append_message(conv.roles[0], inp) 84 | conv.append_message(conv.roles[1], None) 85 | prompt = conv.get_prompt() 86 | 87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 89 | keywords = [stop_str] 90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 92 | 93 | with torch.inference_mode(): 94 | output_ids = model.generate( 95 | input_ids, 96 | images=image_tensor, 97 | do_sample=True if args.temperature > 0 else False, 98 | temperature=args.temperature, 99 | max_new_tokens=args.max_new_tokens, 100 | streamer=streamer, 101 | use_cache=True, 102 | stopping_criteria=[stopping_criteria]) 103 | 104 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 105 | conv.messages[-1][-1] = outputs 106 | 107 | if args.debug: 108 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 114 | parser.add_argument("--model-base", type=str, default=None) 115 | parser.add_argument("--image-file", type=str, required=True) 116 | parser.add_argument("--device", type=str, default="cuda") 117 | parser.add_argument("--conv-mode", type=str, default=None) 118 | parser.add_argument("--temperature", type=float, default=0.2) 119 | parser.add_argument("--max-new-tokens", type=int, default=512) 120 | parser.add_argument("--load-8bit", action="store_true") 121 | parser.add_argument("--load-4bit", action="store_true") 122 | parser.add_argument("--debug", action="store_true") 123 | args = parser.parse_args() 124 | main(args) 125 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/HalDet-LLaVA/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/HalDet-LLaVA/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m llava.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from 
flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def 
replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. 
If it is, then the mask is all zeros. 78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 
4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /HalDet-LLaVA/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 
96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /HalDet-LLaVA/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.29.3 2 | addict==2.4.0 3 | aiofiles==23.2.1 4 | aiohttp==3.9.5 5 | aiosignal==1.3.1 6 | albumentations==1.4.4 7 | aliyun-python-sdk-core==2.15.1 8 | aliyun-python-sdk-kms==2.16.2 9 | altair==5.3.0 10 | annotated-types==0.6.0 11 | anyio==4.3.0 12 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work 13 | async-timeout==4.0.3 14 | asynctest==0.13.0 15 | attrs==23.2.0 16 | bitsandbytes==0.38.1 17 | Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work 18 | certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi 19 | cffi==1.16.0 20 | chardet==5.2.0 21 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 22 | click==8.1.7 23 | codecov==2.1.13 24 | colorama==0.4.6 25 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work 26 | contourpy==1.2.1 27 | coverage==7.4.4 28 | crcmod==1.7 29 | cryptography==42.0.5 30 | cycler==0.12.1 31 | Cython==3.0.10 32 | debugpy @ file:///croot/debugpy_1690905042057/work 33 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work 34 | defusedxml==0.7.1 35 | distro==1.9.0 36 | einops==0.7.0 37 | einops-exts==0.0.4 38 | exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work 39 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work 40 | fastapi==0.110.1 41 | ffmpy==0.3.2 42 | filelock @ file:///croot/filelock_1700591183607/work 43 | flake8==7.0.0 44 | fonttools==4.51.0 45 | frozenlist==1.4.1 46 | fsspec==2024.3.1 47 | git-lfs==1.6 48 | gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645438755360/work 49 | gradio==3.44.4 50 | gradio_client==0.5.1 51 | -e git+https://github.com/IDEA-Research/GroundingDINO.git@3a2b344737cb0f3c3f12d0f7f58be5dc71198289#egg=groundingdino 52 | h11==0.14.0 53 | httpcore==1.0.5 54 | httpx==0.27.0 55 | huggingface-hub==0.22.2 56 | idna @ file:///croot/idna_1666125576474/work 57 | imageio==2.34.0 58 | imgaug==0.4.0 59 | importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1710971335535/work 60 | importlib_resources==6.4.0 61 | iniconfig==2.0.0 62 | interrogate==1.7.0 63 | 
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work 64 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1701831663892/work 65 | isort==5.13.2 66 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work 67 | Jinja2 @ file:///croot/jinja2_1706733616596/work 68 | jmespath==0.10.0 69 | joblib==1.4.0 70 | jsonschema==4.21.1 71 | jsonschema-specifications==2023.12.1 72 | jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1710255804825/work 73 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257447442/work 74 | kiwisolver==1.4.5 75 | kwarray==0.6.18 76 | lanms_neo==1.0.2 77 | lazy_loader==0.4 78 | lmdb==1.4.1 79 | Markdown==3.6 80 | markdown-it-py==3.0.0 81 | MarkupSafe @ file:///croot/markupsafe_1704205993651/work 82 | matplotlib==3.8.4 83 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work 84 | mccabe==0.7.0 85 | mdurl==0.1.2 86 | mkl-fft @ file:///croot/mkl_fft_1695058164594/work 87 | mkl-random @ file:///croot/mkl_random_1695059800811/work 88 | mkl-service==2.4.0 89 | mmcv==2.0.0 90 | mmdet==3.0.0 91 | mmengine==0.10.3 92 | -e git+https://github.com/OpenKG-ORG/EasyDetect.git@7b857283933f4f84a31b8ae5c0df7af7a1afba7d#egg=mmocr&subdirectory=pipeline/mmocr 93 | model-index==0.1.11 94 | mpmath @ file:///croot/mpmath_1690848262763/work 95 | multidict==6.0.5 96 | nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work 97 | networkx @ file:///croot/networkx_1690561992265/work 98 | nltk==3.8.1 99 | numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp39-cp39-linux_x86_64.whl#sha256=6094eeedd869502faa0fd0a8c5ad3a70c5779be06ddd1feb7627e5c212fac420 100 | nvidia-cublas-cu12==12.1.3.1 101 | nvidia-cuda-cupti-cu12==12.1.105 102 | nvidia-cuda-nvrtc-cu12==12.1.105 103 | nvidia-cuda-runtime-cu12==12.1.105 104 | nvidia-cudnn-cu12==8.9.2.26 105 | nvidia-cufft-cu12==11.0.2.54 106 | nvidia-curand-cu12==10.3.2.106 107 | nvidia-cusolver-cu12==11.4.5.107 108 | nvidia-cusparse-cu12==12.1.0.106 109 | nvidia-nccl-cu12==2.19.3 110 | nvidia-nvjitlink-cu12==12.4.127 111 | nvidia-nvtx-cu12==12.1.105 112 | openai==1.23.1 113 | opencv-python==4.9.0.80 114 | opencv-python-headless==4.9.0.80 115 | opendatalab==0.0.10 116 | openmim==0.3.9 117 | openxlab==0.0.38 118 | ordered-set==4.1.0 119 | orjson==3.10.1 120 | oss2==2.17.0 121 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1710075952259/work 122 | pandas==2.2.2 123 | parameterized==0.9.0 124 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work 125 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work 126 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 127 | pillow @ file:///croot/pillow_1707233021655/work 128 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work 129 | pluggy==1.4.0 130 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work 131 | protobuf==5.26.1 132 | psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1705722404069/work 133 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 134 | pure-eval @ 
file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 135 | py==1.11.0 136 | pyclipper==1.3.0.post5 137 | pycocotools==2.0.7 138 | pycodestyle==2.11.1 139 | pycparser==2.22 140 | pycryptodome==3.20.0 141 | pydantic==2.7.0 142 | pydantic_core==2.18.1 143 | pydub==0.25.1 144 | pyflakes==3.2.0 145 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work 146 | pyparsing==3.1.2 147 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work 148 | pytest==8.1.1 149 | pytest-cov==5.0.0 150 | pytest-runner==6.0.1 151 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work 152 | python-multipart==0.0.9 153 | pytz==2023.4 154 | PyYAML==6.0.1 155 | pyzmq @ file:///croot/pyzmq_1705605076900/work 156 | rapidfuzz==3.8.1 157 | referencing==0.34.0 158 | regex==2024.4.16 159 | requests @ file:///croot/requests_1707355572290/work 160 | rich==13.4.2 161 | rpds-py==0.18.0 162 | ruff==0.4.0 163 | safetensors==0.4.3 164 | scikit-image==0.22.0 165 | scikit-learn==1.4.2 166 | scipy==1.13.0 167 | semantic-version==2.10.0 168 | sentencepiece==0.2.0 169 | shapely==2.0.4 170 | shellingham==1.5.4 171 | shortuuid==1.0.13 172 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 173 | sniffio==1.3.1 174 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 175 | starlette==0.37.2 176 | supervision==0.19.0 177 | sympy @ file:///croot/sympy_1701397643339/work 178 | tabulate==0.9.0 179 | termcolor==2.4.0 180 | terminaltables==3.1.10 181 | threadpoolctl==3.4.0 182 | tifffile==2024.4.18 183 | timm==0.9.16 184 | tokenizers==0.13.3 185 | tomli==2.0.1 186 | tomlkit==0.12.0 187 | toolz==0.12.1 188 | torch==2.0.1 189 | torchaudio==2.0.2 190 | torchvision==0.15.2 191 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1708363103305/work 192 | tqdm==4.65.2 193 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work 194 | transformers==4.31.0 195 | triton==2.0.0 196 | typer==0.12.3 197 | typing_extensions==4.11.0 198 | tzdata==2024.1 199 | ubelt==1.3.5 200 | urllib3 @ file:///croot/urllib3_1707770551213/work 201 | uvicorn==0.29.0 202 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work 203 | websockets==11.0.3 204 | xdoctest==1.1.3 205 | yapf==0.40.2 206 | yarl==1.9.4 207 | zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work 208 | -------------------------------------------------------------------------------- /HalDet-LLaVA/scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /HalDet-LLaVA/scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | 
"initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /HalDet-LLaVA/scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /examples/058214af21a03013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/058214af21a03013.jpg -------------------------------------------------------------------------------- /examples/43adec54f56ff7af.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/43adec54f56ff7af.jpg -------------------------------------------------------------------------------- /examples/508.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/508.jpg -------------------------------------------------------------------------------- /examples/63.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/63.jpg -------------------------------------------------------------------------------- /figs/.DS_Store: 
-------------------------------------------------------------------------------- /examples/058214af21a03013.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/058214af21a03013.jpg
-------------------------------------------------------------------------------- /examples/43adec54f56ff7af.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/43adec54f56ff7af.jpg
-------------------------------------------------------------------------------- /examples/508.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/508.jpg
-------------------------------------------------------------------------------- /examples/63.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/examples/63.jpg
-------------------------------------------------------------------------------- /figs/datasetinfo.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/figs/datasetinfo.jpg
-------------------------------------------------------------------------------- /figs/easydetect.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/figs/easydetect.jpg
-------------------------------------------------------------------------------- /figs/framework.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/figs/framework.png
-------------------------------------------------------------------------------- /figs/intro.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/figs/intro.png
-------------------------------------------------------------------------------- /figs/view.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/figs/view.png
-------------------------------------------------------------------------------- /figs/条形图.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/figs/条形图.png
-------------------------------------------------------------------------------- /figs/饼图.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/figs/饼图.png
-------------------------------------------------------------------------------- /pipeline/claim_generate.py: --------------------------------------------------------------------------------
1 | import json
2 | import yaml
3 | 
4 | class ClaimGenerator:
5 |     def __init__(self, config, chat):
6 |         with open(config["prompts"]["claim_generate"],"r",encoding='utf-8') as file:
7 |             self.prompt = yaml.load(file, yaml.FullLoader)
8 |         self.chat = chat
9 | 
10 |     def get_response(self, text):
11 |         user_prompt = self.prompt["user"].format(text=text)
12 |         message = [
13 |             {"role": "system", "content": self.prompt["system"]},
14 |             {"role": "user", "content": user_prompt}
15 |         ]
16 |         response = self.chat.get_response(message=message)
17 |         try:
18 |             response = json.loads(response)
19 |         except Exception as e:
20 |             print(e)
21 | 
22 |         claim_list = []
23 |         cnt = 0
24 |         for seg in response:
25 |             for cla in seg["claims"]:
26 |                 cnt += 1
27 |                 claim_list.append("claim{}: {}".format(str(cnt), cla["claim"]))
28 |         claim_list = "\n".join(claim_list)
29 |         return response, claim_list
30 | 
31 | 
32 | 
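A minimal sketch of driving ClaimGenerator on its own (the input text is illustrative; a filled-in config.yaml and the SyncChat wrapper defined in pipeline/openai_wrapper.py below are assumed):

import yaml
from pipeline.openai_wrapper import SyncChat
from pipeline.claim_generate import ClaimGenerator

with open("pipeline/config/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.load(f, yaml.FullLoader)

chat = SyncChat(model="gpt-4-1106-preview", config=config["openai"])
generator = ClaimGenerator(config=config, chat=chat)

# response is the parsed segment/claim structure; claim_list is the
# newline-joined "claim1: ..." string consumed by the query generator.
response, claim_list = generator.get_response(
    text="This drink is Fresca. It is commonly available in the United States."
)
print(claim_list)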
-------------------------------------------------------------------------------- /pipeline/config/config.yaml: --------------------------------------------------------------------------------
1 | openai:
2 |   api_key: input your openai api key
3 |   base_url:
4 |   temperature: 0.2
5 |   max_tokens: 1024
6 | tool:
7 |   detect:
8 |     groundingdino_config: the path of GroundingDINO_SwinT_OGC.py
9 |     model_path: the path of groundingdino_swint_ogc.pth
10 |     device: cuda:0
11 |     BOX_TRESHOLD: 0.35
12 |     TEXT_TRESHOLD: 0.25
13 |     AREA_THRESHOLD: 0.001
14 |   ocr:
15 |     dbnetpp_config: the path of dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015.py
16 |     dbnetpp_path: the path of dbnetpp.pth
17 |     maerec_config: the path of maerec_b_union14m.py
18 |     maerec_path: the path of maerec_b.pth
19 |     device: cuda:0
20 |     content: word.number
21 |     cachefiles_path: the path of cache_files to save temp images
22 |     BOX_TRESHOLD: 0.2
23 |     TEXT_TRESHOLD: 0.25
24 |   google_serper:
25 |     serper_api_key: input your serper api key
26 |     snippet_cnt: 10
27 | prompts:
28 |   claim_generate: pipeline/prompts/claim_generate.yaml
29 |   query_generate: pipeline/prompts/query_generate.yaml
30 |   verify: pipeline/prompts/verify.yaml
-------------------------------------------------------------------------------- /pipeline/examples/animal.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/pipeline/examples/animal.jpg
-------------------------------------------------------------------------------- /pipeline/examples/ball.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/pipeline/examples/ball.jpg
-------------------------------------------------------------------------------- /pipeline/examples/football.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/pipeline/examples/football.jpg
-------------------------------------------------------------------------------- /pipeline/examples/sandbeach.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/pipeline/examples/sandbeach.jpg
-------------------------------------------------------------------------------- /pipeline/openai_wrapper.py: --------------------------------------------------------------------------------
1 | import asyncio
2 | from openai import OpenAI, AsyncOpenAI
3 | 
4 | class SyncChat:
5 |     def __init__(self, model, config):
6 |         if config["base_url"] is not None:
7 |             self.sync_client = OpenAI(base_url=config["base_url"], api_key=config["api_key"])
8 |         else:
9 |             self.sync_client = OpenAI(api_key=config["api_key"])
10 |         self.model = model
11 | 
12 |     def get_response(self, message, temperature=0.2, max_tokens=1024):
13 |         response = self.sync_client.chat.completions.create(
14 |             model=self.model,
15 |             messages=message,
16 |             temperature=temperature,
17 |             max_tokens=max_tokens)
18 |         return response.choices[0].message.content
19 | 
20 | 
21 | class AsyncChat:
22 |     def __init__(self, model, config):
23 |         if config["base_url"] is not None:
24 |             self.async_client = AsyncOpenAI(base_url=config["base_url"], api_key=config["api_key"])
25 |         else:
26 |             self.async_client = AsyncOpenAI(api_key=config["api_key"])
27 |         self.model = model
28 | 
29 |     async def get_response(self, messages, temperature=0.2, max_tokens=1024):
30 |         async def openai_reply(message):
31 |             response = await self.async_client.chat.completions.create(
32 |                 model=self.model,
33 |                 messages=message,
34 |                 temperature=temperature,
35 |                 max_tokens=max_tokens,)
36 |             return response.choices[0].message.content
37 | 
38 |         response_list = [openai_reply(message) for message in messages]
39 |         return await asyncio.gather(*response_list)
40 | 
-------------------------------------------------------------------------------- /pipeline/prompts/claim_generate.yaml: --------------------------------------------------------------------------------
1 | system: |-
2 |   You are a brilliant claim generator.
3 | user: |-
4 |   Given a segment of text generated by a large visual language model, an assertion is a statement that can be checked against the visual information and verified by humans. Your task is to first divide the text into segments, then accurately identify and extract every asserted claim within each segment. Then, resolve any coreference (pronouns or other referring expressions) in each claim for clarity. Each claim should be concise (fewer than 15 words) and self-contained.
5 |   Your response MUST be a list of dictionaries. Each dictionary contains two keys, "segment" and "claims." The key "segment" corresponds to each segment of the given text (each segment should match the original text's segments and be arranged in the original order). Then, the value corresponding to the key "claims" is a list of assertions, extracted based on this segment. Each dictionary within this list should contain the key "claim," corresponding to the extracted claim (with all references resolved).
6 |   You MUST only respond in the format as described below. DO NOT RESPOND WITH ANYTHING ELSE. ADDING ANY OTHER EXTRA NOTES THAT VIOLATE THE RESPONSE FORMAT IS BANNED. START YOUR RESPONSE WITH '['.
7 |   WHEN THERE ARE DOUBLE QUOTATION MARKS " IN THE GENERATED SEGMENT AND CLAIM, YOU NEED TO ADD AN ESCAPE CHARACTER \ BEFORE THEM!!!
8 |   [response format]:
9 |   [
10 |   {{
11 |   "segment": "One segment of the given text; segments must match the original text's segments and be arranged in the original order",
12 |   "claims": [{{
13 |   "claim": "Ensure that the claim is fewer than 15 words and conveys a complete idea. Resolve any coreference (pronouns or other referring expressions) in the claim for clarity",
14 |   }},
15 |   ...]
16 |   }},
17 |   ...
18 |   ]
19 | 
20 |   Here are two examples:
21 |   [text]: This drink is Fresca. It is a lemon-lime flavored soft drink, commonly available in the United States.
22 |   [response]: [{{"segment": "This drink is Fresca.","claims": [{{"claim": "This drink is Fresca"}}]}}, {{"segment": "It is a lemon-lime flavored soft drink, commonly available in the United States.","claims": [{{"claim": "Fresca is a lemon-lime flavored soft drink"}}, {{"claim": "Fresca is commonly available in the United States"}}]}}]
23 | 
24 |   [text]: The book with the cover featuring Dylan Thomas, \"Quite Early One Morning,\" was published in 1999, and was written by Dylan Thomas himself.
25 |   [response]: [{{"segment": "The book with the cover featuring Dylan Thomas, \"Quite Early One Morning,\" was published in 1999, and was written by Dylan Thomas himself.","claims": [{{"claim": "\"Quite Early One Morning\" was published in 1999"}}, {{"claim": "The book's cover features Dylan Thomas"}}, {{"claim": "The book was written by Dylan Thomas"}}]}}]
26 | 
27 |   [text]: {text}
28 |   [response]:
-------------------------------------------------------------------------------- /pipeline/query_generate.py: --------------------------------------------------------------------------------
1 | import json
2 | import yaml
3 | import copy
4 | import asyncio
5 | from nltk.corpus import wordnet
6 | 
7 | class QueryGenerator:
8 |     def __init__(self, config, chat):
9 |         with open(config["prompts"]["query_generate"],"r",encoding='utf-8') as file:
10 |             self.prompt = yaml.load(file, yaml.FullLoader)
11 |         self.chat = chat
12 | 
13 |     def objects_extract(self, claim_list):
14 |         user_prompt = self.prompt[self.type]["object"]["user"].format(claims=claim_list)
15 |         message = [[
16 |             {"role": "system", "content": self.prompt[self.type]["object"]["system"]},
17 |             {"role": "user", "content": user_prompt}
18 |         ],]
19 |         loop = asyncio.new_event_loop()
20 |         asyncio.set_event_loop(loop)
21 |         response = loop.run_until_complete(self.chat.get_response(messages=message))
22 | 
23 |         try:
24 |             response = json.loads(response[0])
25 |         except Exception as e:
26 |             print(e)
27 | 
28 |         objects = set()
29 |         for key in response:
30 |             object_list = response[key].split(".")
31 |             response[key] = object_list
32 |             for object in object_list:
33 |                 if object != "none":
34 |                     objects.add(object)
35 | 
36 |         objects = ".".join(objects)
37 |         return response, objects
38 | 
39 |     def get_hypernyms(self, word):
40 |         synsets = wordnet.synsets(word)
41 |         hypernyms = []
42 | 
43 |         for synset in synsets:
44 |             for hypernym in synset.hypernyms():
45 |                 hypernyms.extend(hypernym.lemma_names())
46 | 
47 |         hypernyms = list(set(hypernyms))
48 |         hypernyms = ".".join(hypernyms)
49 |         return hypernyms
50 | 
51 |     def remove_hypernyms(self, objects):
52 |         hypernyms_dict = {}
53 |         for object in objects:
54 |             hypernyms = self.get_hypernyms(object)
55 |             hypernyms_dict[object] = hypernyms
56 | 
57 |         backup = copy.deepcopy(objects)
58 |         for object in objects:
59 |             hypernyms_list = []
60 |             for key in hypernyms_dict:
61 |                 if key != object:
62 |                     hypernyms_list.append(hypernyms_dict[key])
63 |             hypernyms_list = ".".join(hypernyms_list)
64 |             if object in hypernyms_list:
65 |                 backup.remove(object)
66 | 
67 |         objects = ".".join(backup)
68 |         return objects
69 | 
70 |     def filter(self, res, object_list):
71 |         attribute_ques_list = json.loads(res[0])
72 |         scenetext_ques_list = json.loads(res[1])
73 |         fact_ques_list = json.loads(res[2])
74 |         objects = set()
75 |         for key in fact_ques_list:
76 |             if fact_ques_list[key][0] != "none":
77 |                 object_list[key] = ["none"]  # a fact question supersedes this claim's objects
78 |                 attribute_ques_list[key] = ["none"]
79 |                 scenetext_ques_list[key] = ["none"]
80 |             else:
81 |                 for object in object_list[key]:
82 |                     if object != "none":
83 |                         objects.add(object)
84 | 
85 |         objects = self.remove_hypernyms(objects)
86 |         return attribute_ques_list, scenetext_ques_list, fact_ques_list, objects
87 | 
88 |     def get_response(self, claim_list, type):
89 |         self.type = type
90 |         object_list, objects = self.objects_extract(claim_list=claim_list)
91 |         self.message_list = [
"content": self.prompt[type]["attribute"]["system"]}, {"role": "user", "content": self.prompt[type]["attribute"]["user"].format(objects=objects,claims=claim_list)}], 93 | [{"role": "system", "content": self.prompt[type]["scene-text"]["system"]}, {"role": "user", "content": self.prompt[type]["scene-text"]["user"].format(claims=claim_list)}], 94 | [{"role": "system", "content": self.prompt[type]["fact"]["system"]}, {"role": "user", "content": self.prompt[type]["fact"]["user"].format(claims=claim_list)}] 95 | ] 96 | loop = asyncio.new_event_loop() 97 | asyncio.set_event_loop(loop) 98 | res = loop.run_until_complete(self.chat.get_response(messages=self.message_list)) 99 | if self.type == "image-to-text": 100 | attribute_ques_list, scenetext_ques_list, fact_ques_list, objects = self.filter(res, object_list) 101 | else: 102 | attribute_ques_list, scenetext_ques_list, fact_ques_list = json.loads(res[0]), json.loads(res[1]), json.loads(res[2]) 103 | 104 | return objects, attribute_ques_list, scenetext_ques_list, fact_ques_list -------------------------------------------------------------------------------- /pipeline/run_pipeline.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from pipeline.openai_wrapper import * 3 | from pipeline.claim_generate import * 4 | from pipeline.query_generate import * 5 | from pipeline.tool_execute import * 6 | from pipeline.verify import * 7 | 8 | class Pipeline: 9 | def __init__(self): 10 | with open("pipeline/config/config.yaml", 'r', encoding='utf-8') as file: 11 | self.config = yaml.load(file, yaml.FullLoader) 12 | 13 | self.syncchat = SyncChat(model="gpt-4-1106-preview", config=self.config["openai"]) 14 | self.asyncchat = AsyncChat(model="gpt-4-1106-preview", config=self.config["openai"]) 15 | self.visionchat = SyncChat(model="gpt-4-vision-preview", config=self.config["openai"]) 16 | 17 | self.claim_generator = ClaimGenerator(config=self.config,chat=self.syncchat) 18 | self.query_generator = QueryGenerator(config=self.config,chat=self.asyncchat) 19 | self.tool = Tool(config=self.config) 20 | self.verifier = Verifier(config=self.config, chat=self.visionchat) 21 | 22 | def run(self, text, image_path, type): 23 | response, claim_list = self.claim_generator.get_response(text=text) 24 | objects, attribute_ques_list, scenetext_ques_list, fact_ques_list = self.query_generator.get_response(claim_list=claim_list, type=type) 25 | object_res, attribue_res, text_res, fact_res = self.tool.execute(image_path=image_path, 26 | objects=objects, 27 | attribute_list=attribute_ques_list, 28 | scenetext_list=scenetext_ques_list, 29 | fact_list=fact_ques_list) 30 | response = self.verifier.get_response(type, object_res, attribue_res, text_res, fact_res, claim_list, image_path) 31 | return response,claim_list 32 | 33 | 34 | -------------------------------------------------------------------------------- /pipeline/tool/detect.py: -------------------------------------------------------------------------------- 1 | # Utilizing the following code snippet based on the algorithm proposed in the paper: 2 | # Title: "Woodpecker: Hallucination correction for multimodal large language models" 3 | # Author: Yin et al. 
-------------------------------------------------------------------------------- /pipeline/tool/detect.py: --------------------------------------------------------------------------------
1 | # Utilizing the following code snippet based on the algorithm proposed in the paper:
2 | # Title: "Woodpecker: Hallucination correction for multimodal large language models"
3 | # Author: Yin et al.
4 | # Original code source: https://github.com/BradyFU/Woodpecker/blob/main/models/detector.py
5 | 
6 | import yaml
7 | import torch
8 | import os
9 | import shortuuid
10 | import numpy as np
11 | from PIL import Image
12 | from torchvision.ops import box_convert
13 | from pipeline.tool.ocr import *
14 | from pipeline.GroundingDINO.groundingdino.util.inference import load_model, load_image, predict
15 | 
16 | class DetectModel:
17 |     def __init__(self, config):
18 |         self.config = config
19 |         self.model = load_model(self.config["tool"]["detect"]["groundingdino_config"],
20 |                                 self.config["tool"]["detect"]["model_path"],
21 |                                 device=self.config["tool"]["detect"]["device"])
22 | 
23 | 
24 |     def execute(self, image_path, content, box_threshold, text_threshold, save_path):
25 |         image_source, image = load_image(image_path)
26 |         boxes, _, phrases = predict(model=self.model,image=image,caption=content,box_threshold=box_threshold,text_threshold=text_threshold,device=self.config["tool"]["detect"]["device"])
27 |         h, w, _ = image_source.shape
28 |         torch_boxes = boxes * torch.Tensor([w, h, w, h])
29 |         xyxy = box_convert(boxes=torch_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
30 |         normed_xyxy = np.around(np.clip(xyxy / np.array([w, h, w, h]), 0., 1.), 3).tolist()
31 |         result = {"boxes":normed_xyxy, "phrases":phrases, "save_path":[]}
32 |         if save_path is not None:
33 |             dir_name = image_path.split("/")[-1][:-4]
34 |             cache_dir = save_path + dir_name
35 |             os.makedirs(cache_dir, exist_ok=True)
36 |             image_path_list = []
37 |             for box, norm_box in zip(xyxy, normed_xyxy):
38 |                 # skip crops whose normalized area falls below AREA_THRESHOLD
39 |                 if (norm_box[2]-norm_box[0]) * (norm_box[3]-norm_box[1]) < self.config["tool"]["detect"]["AREA_THRESHOLD"]:
40 |                     continue
41 |                 crop_id = shortuuid.uuid()
42 |                 crop_img = Image.fromarray(image_source).crop(box)
43 |                 crop_path = os.path.join(cache_dir, f"{crop_id}.jpg")
44 |                 crop_img.save(crop_path)
45 |                 image_path_list.append(crop_path)
46 |             result["save_path"] = image_path_list
47 | 
48 |         return result
49 | 
-------------------------------------------------------------------------------- /pipeline/tool/google_serper.py: --------------------------------------------------------------------------------
1 | # Utilizing the following code snippet based on the algorithm proposed in the paper:
2 | # Title: "FacTool: Factuality Detection in Generative AI--A Tool Augmented Framework for Multi-Task and Multi-Domain Scenarios"
3 | # Author: Chern et al.
4 | # Original code source: https://github.com/GAIR-NLP/factool/blob/main/factool/knowledge_qa/google_serper.py
5 | 
6 | import asyncio
7 | import aiohttp
8 | 
9 | class GoogleSerperAPIWrapper():
10 |     """Wrapper around the Serper.dev Google Search API.
11 |     You can create a free API key at https://serper.dev.
12 |     In this pipeline the key is read from the serper_api_key entry
13 |     under tool.google_serper in pipeline/config/config.yaml and passed
14 |     in through the config argument of the constructor.
15 |     Example:
16 |     .. code-block:: python
17 |         from pipeline.tool.google_serper import GoogleSerperAPIWrapper
18 |         google_serper = GoogleSerperAPIWrapper(config)
19 |     """
20 |     def __init__(self, config):
21 |         self.config = config
22 |         self.k = self.config["tool"]["google_serper"]["snippet_cnt"]
23 |         self.gl = "us"
24 |         self.hl = "en"
25 |         self.serper_api_key = self.config["tool"]["google_serper"]["serper_api_key"]
26 |         assert self.serper_api_key is not None, "Please set serper_api_key in pipeline/config/config.yaml."
27 |         assert self.serper_api_key != '', "Please set serper_api_key in pipeline/config/config.yaml."
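# The methods below POST each query to https://google.serper.dev/search and
# condense answer-box, knowledge-graph, and organic results into
# {"content", "source"} snippets, roughly snippet_cnt of them per claim.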
28 | 29 | async def _google_serper_search_results(self, session, search_term: str, gl: str, hl: str) -> dict: 30 | headers = { 31 | "X-API-KEY": self.serper_api_key, 32 | "Content-Type": "application/json", 33 | } 34 | params = {"q": search_term, "gl": gl, "hl": hl} 35 | async with session.post( 36 | "https://google.serper.dev/search", headers=headers, params=params, raise_for_status=True 37 | ) as response: 38 | return await response.json() 39 | 40 | def _parse_results(self, results): 41 | snippets = [] 42 | if results.get("answerBox"): 43 | answer_box = results.get("answerBox", {}) 44 | if answer_box.get("answer"): 45 | element = {"content":answer_box.get("answer"),"source":"None"} 46 | return [element] 47 | elif answer_box.get("snippet"): 48 | element = {"content":answer_box.get("snippet").replace("\n", " "),"source":"None"} 49 | return [element] 50 | elif answer_box.get("snippetHighlighted"): 51 | element = {"content":answer_box.get("snippetHighlighted"),"source":"None"} 52 | return [element] 53 | 54 | if results.get("knowledgeGraph"): 55 | kg = results.get("knowledgeGraph", {}) 56 | title = kg.get("title") 57 | entity_type = kg.get("type") 58 | if entity_type: 59 | element = {"content":f"{title}: {entity_type}","source":"None"} 60 | snippets.append(element) 61 | description = kg.get("description") 62 | if description: 63 | element = {"content":description,"source":"None"} 64 | snippets.append(element) 65 | for attribute, value in kg.get("attributes", {}).items(): 66 | element = {"content":f"{attribute}: {value}","source":"None"} 67 | snippets.append(element) 68 | 69 | for result in results["organic"][: self.k]: 70 | if "snippet" in result: 71 | if result["snippet"].find("Missing") != -1: 72 | continue 73 | element = {"content":result["snippet"],"source":result["link"]} 74 | snippets.append(element) 75 | for attribute, value in result.get("attributes", {}).items(): 76 | element = {"content":f"{attribute}: {value}","source":result["link"]} 77 | if element["content"].find("Missing") != -1: 78 | continue 79 | snippets.append(element) 80 | 81 | if len(snippets) == 0: 82 | element = {"content":"No good Google Search Result was found","source":"None"} 83 | return [element] 84 | 85 | snippets = snippets[:int(self.k / 2)] 86 | 87 | return snippets 88 | 89 | async def parallel_searches(self, search_queries, gl, hl): 90 | async with aiohttp.ClientSession() as session: 91 | tasks = [self._google_serper_search_results(session, query, gl, hl) for query in search_queries] 92 | search_results = await asyncio.gather(*tasks, return_exceptions=True) 93 | return search_results 94 | 95 | 96 | async def run(self, queries): 97 | """Run query through GoogleSearch and parse result.""" 98 | flattened_queries = [] 99 | 100 | for sublist in queries: 101 | if sublist is None: 102 | sublist = ['None', 'None'] 103 | for item in sublist: 104 | flattened_queries.append(item) 105 | 106 | results = await self.parallel_searches(flattened_queries, gl=self.gl, hl=self.hl) 107 | snippets_list = [] 108 | for i in range(len(results)): 109 | snippets_list.append(self._parse_results(results[i])) 110 | snippets_split = [snippets_list[i] + snippets_list[i+1] for i in range(0, len(snippets_list), 2)] 111 | return snippets_split 112 | 113 | 114 | def execute(self, content): 115 | query_list = [content.split(",")[0][2:-1],content.split(",")[1][2:-2]] 116 | loop = asyncio.new_event_loop() 117 | asyncio.set_event_loop(loop) 118 | search_outputs_for_claims = loop.run_until_complete(self.run([query_list])) 119 | evidences = 
[[output['content'] for output in search_outputs_for_claim] for search_outputs_for_claim in
120 |                      search_outputs_for_claims]
121 |         return evidences[0]
122 | 
-------------------------------------------------------------------------------- /pipeline/tool/ocr.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from pipeline.mmocr.mmocr.apis.inferencers import MMOCRInferencer
4 | 
5 | class OCRModel:
6 |     def __init__(self, config):
7 |         self.config = config
8 |         self.mmocr_inferencer = MMOCRInferencer(
9 |             det=self.config["tool"]["ocr"]["dbnetpp_config"],
10 |             det_weights=self.config["tool"]["ocr"]["dbnetpp_path"],
11 |             rec=self.config["tool"]["ocr"]["maerec_config"],
12 |             rec_weights=self.config["tool"]["ocr"]["maerec_path"],
13 |             device=self.config["tool"]["ocr"]["device"])
14 | 
15 |     def get_single_result(self, image_path):
16 |         data = Image.open(image_path).convert("RGB")
17 |         img = np.array(data)
18 |         self.mmocr_inferencer.mode = 'rec'
19 |         result = self.mmocr_inferencer(img, return_vis=True)
20 |         result = result['predictions'][0]
21 |         rec_text = result['rec_texts'][0]
22 |         rec_score = result['rec_scores'][0]
23 |         # return the recognized string directly; the confidence is in rec_score if needed
24 |         return rec_text
25 | 
26 |     def execute(self, image_path_list):
27 |         ocr_det_res = []
28 |         for image_path in image_path_list:
29 |             res = self.get_single_result(image_path)
30 |             ocr_det_res.append(res)
31 |         return ocr_det_res
-------------------------------------------------------------------------------- /pipeline/tool_execute.py: --------------------------------------------------------------------------------
1 | import base64
2 | from pipeline.openai_wrapper import *
3 | from pipeline.tool.detect import *
4 | from pipeline.tool.ocr import *
5 | from pipeline.tool.google_serper import *
6 | 
7 | class Tool:
8 |     def __init__(self, config):
9 |         self.config = config
10 |         self.detector = DetectModel(config=self.config)
11 |         self.ocr = OCRModel(config=self.config)
12 |         self.visionchat = SyncChat(model="gpt-4-vision-preview", config=self.config["openai"])
13 |         self.search = GoogleSerperAPIWrapper(config=self.config)
14 | 
15 |     def get_object_res(self, image_path, objects):
16 |         object_res = self.detector.execute(image_path=image_path,
17 |                                            content=objects,
18 |                                            box_threshold=self.config["tool"]["detect"]["BOX_TRESHOLD"],
19 |                                            text_threshold=self.config["tool"]["detect"]["TEXT_TRESHOLD"],
20 |                                            save_path=None)
21 |         return object_res
22 | 
23 |     def get_ocr_res(self, image_path, scenetext_list):
24 |         use_ocr = False
25 |         for key in scenetext_list:
26 |             if scenetext_list[key][0] != "none":
27 |                 use_ocr = True
28 |         ocr_res = None
29 |         if use_ocr:
30 |             ocr_res = self.detector.execute(image_path=image_path,
31 |                                             content=self.config["tool"]["ocr"]["content"],
32 |                                             box_threshold=self.config["tool"]["ocr"]["BOX_TRESHOLD"],
33 |                                             text_threshold=self.config["tool"]["ocr"]["TEXT_TRESHOLD"],
34 |                                             save_path=self.config["tool"]["ocr"]["cachefiles_path"])
35 | 
36 |             ocr_res["phrases"] = self.ocr.execute(image_path_list=ocr_res["save_path"])
37 |         return ocr_res
38 | 
39 |     def get_attribute_res(self, image_path, attribute_list):
40 |         def encode_image(image_path):
41 |             with open(image_path, "rb") as image_file:
42 |                 return base64.b64encode(image_file.read()).decode('utf-8')
43 |         queries = ""
44 |         cnt = 1
45 |         for key in attribute_list:
46 |             if attribute_list[key][0] != "none":
47 |                 for query in attribute_list[key]:
48 |                     queries += str(cnt) + "." + query + "\n"
+ query + "\n" 49 | cnt += 1 50 | if queries == "": 51 | attribue_res = "none information" 52 | else: 53 | img = encode_image(image_path) 54 | message = [{"role": "user","content": [{"type": "image_url","image_url": f"data:image/jpeg;base64,{img}"},{"type": "text", "text": queries}]}] 55 | attribue_res = self.visionchat.get_response(message=message) 56 | return attribue_res 57 | 58 | def get_fact_res(self, fact_list): 59 | fact_res = "" 60 | cnt = 1 61 | for key in fact_list: 62 | if fact_list[key][0] != "none": 63 | evidences = self.search.execute(content=str(fact_list[key])) 64 | for evidence in evidences: 65 | fact_res += str(cnt) + "." + evidence + "\n" 66 | cnt += 1 67 | if fact_res == "": 68 | fact_res = "none information" 69 | return fact_res 70 | 71 | def execute(self, image_path, objects, attribute_list, scenetext_list, fact_list): 72 | object_res = self.get_object_res(image_path, objects) 73 | attribue_res = self.get_attribute_res(image_path, attribute_list) 74 | ocr_res = self.get_ocr_res(image_path, scenetext_list) 75 | fact_res = self.get_fact_res(fact_list) 76 | return object_res, attribue_res, ocr_res, fact_res 77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /pipeline/verify.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import base64 3 | 4 | 5 | class Verifier: 6 | def __init__(self, config, chat): 7 | with open(config["prompts"]["verify"],"r",encoding='utf-8') as file: 8 | self.prompt = yaml.load(file, yaml.FullLoader) 9 | self.chat = chat 10 | 11 | def get_response(self, type, object_res, attribue_res, text_res, fact_res, claim_list, image_path): 12 | def encode_image(image_path): 13 | with open(image_path, "rb") as image_file: 14 | return base64.b64encode(image_file.read()).decode('utf-8') 15 | input = ''' 16 | Here is the object detection expert model's result: 17 | {object} 18 | 19 | Here is the scene text recognition expert model's result: 20 | {text} 21 | 22 | Here is the attribute information: 23 | {attribute} 24 | 25 | Here is the external knowledge: 26 | {fact} 27 | 28 | Here is the claim list: 29 | {claims} 30 | 31 | Output: 32 | ''' 33 | 34 | object_det_res, text_det_res = "", "" 35 | for object_name, box in zip(object_res["phrases"], object_res["boxes"]): 36 | object_det_res += "{} {} \n".format(object_name, str(box)) 37 | 38 | if text_res != None: 39 | for text_name, box in zip(text_res["phrases"], text_res["boxes"]): 40 | text_det_res += text_name + " " + str(box) + "\n" 41 | else: 42 | text_det_res = "none information" 43 | 44 | if type == "image-to-text": 45 | img1 = encode_image("pipeline/examples/sandbeach.jpg") 46 | img2 = encode_image("pipeline/examples/football.jpg") 47 | else: 48 | img1 = encode_image("pipeline/examples/animal.jpg") 49 | img2 = encode_image("pipeline/examples/ball.jpg") 50 | base64_source_image = encode_image(image_path) 51 | content = [ 52 | {"type": "text", "text": self.prompt[type]["user"]}, 53 | {"type": "image_url","image_url": f"data:image/jpeg;base64,{img1}"}, 54 | {"type": "text", "text": self.prompt[type]["example1"]}, 55 | {"type": "image_url","image_url": f"data:image/jpeg;base64,{img2}"}, 56 | {"type": "text", "text": self.prompt[type]["example2"]}, 57 | {"type": "image_url","image_url": f"data:image/jpeg;base64,{base64_source_image}"}, 58 | {"type": "text", "text": input.format(object=object_det_res,text=text_det_res,attribute=attribue_res,fact=fact_res,claims=claim_list)} 59 | ] 60 | 61 | 62 | 
message = [ 63 | { 64 | 'role': 'system', 65 | 'content': self.prompt[type]["system"] 66 | }, 67 | { 68 | "role": "user", 69 | "content": content, 70 | } 71 | ] 72 | 73 | response = self.chat.get_response(message=message) 74 | return response 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.29.3 2 | addict==2.4.0 3 | aiofiles==23.2.1 4 | aiohttp==3.9.5 5 | aiosignal==1.3.1 6 | albumentations==1.4.4 7 | aliyun-python-sdk-core==2.15.1 8 | aliyun-python-sdk-kms==2.16.2 9 | altair==5.3.0 10 | annotated-types==0.6.0 11 | antlr4-python3-runtime==4.9.3 12 | anyio==4.3.0 13 | asttokens==2.4.1 14 | async-timeout==4.0.3 15 | asynctest==0.13.0 16 | attrs==23.2.0 17 | bitsandbytes==0.38.1 18 | braceexpand==0.1.7 19 | Brotli==1.0.9 20 | certifi==2024.2.2 21 | cffi==1.16.0 22 | chardet==5.2.0 23 | charset-normalizer==2.0.4 24 | click==8.1.7 25 | codecov==2.1.13 26 | colorama==0.4.6 27 | comm==0.2.2 28 | contourpy==1.2.1 29 | coverage==7.4.4 30 | crcmod==1.7 31 | cryptography==42.0.5 32 | cycler==0.12.1 33 | Cython==3.0.10 34 | debugpy==1.6.7 35 | decorator==5.1.1 36 | defusedxml==0.7.1 37 | distro==1.9.0 38 | einops==0.7.0 39 | einops-exts==0.0.4 40 | exceptiongroup==1.2.0 41 | executing==2.0.1 42 | fastapi==0.110.1 43 | ffmpy==0.3.2 44 | filelock==3.13.1 45 | flake8==7.0.0 46 | fonttools==4.51.0 47 | frozenlist==1.4.1 48 | fsspec==2024.3.1 49 | git-lfs==1.6 50 | gmpy2==2.1.2 51 | gradio==3.44.4 52 | gradio_client==0.5.1 53 | groundingdino==0.1.0 54 | h11==0.14.0 55 | httpcore==1.0.5 56 | httpx==0.27.0 57 | huggingface-hub==0.22.2 58 | idna==3.4 59 | imageio==2.34.0 60 | imgaug==0.4.0 61 | importlib_metadata==7.1.0 62 | importlib_resources==6.4.0 63 | iniconfig==2.0.0 64 | interrogate==1.7.0 65 | iopath==0.1.10 66 | ipykernel==6.29.3 67 | ipython==8.18.1 68 | isort==5.13.2 69 | jedi==0.19.1 70 | Jinja2==3.1.3 71 | jmespath==0.10.0 72 | joblib==1.4.0 73 | jsonschema==4.21.1 74 | jsonschema-specifications==2023.12.1 75 | jupyter_client==8.6.1 76 | jupyter_core==5.7.2 77 | kiwisolver==1.4.5 78 | kwarray==0.6.18 79 | lanms_neo==1.0.2 80 | lazy_loader==0.4 81 | lmdb==1.4.1 82 | Markdown==3.6 83 | markdown-it-py==3.0.0 84 | MarkupSafe==2.1.3 85 | matplotlib==3.8.4 86 | matplotlib-inline==0.1.7 87 | mccabe==0.7.0 88 | mdurl==0.1.2 89 | mkl-fft==1.3.8 90 | mkl-random==1.2.4 91 | mkl-service==2.4.0 92 | mmcv==2.0.0 93 | mmdet==3.0.0 94 | mmengine==0.10.3 95 | mmocr==1.0.0 96 | model-index==0.1.11 97 | mpmath==1.3.0 98 | multidict==6.0.5 99 | nest_asyncio==1.6.0 100 | networkx==3.1 101 | nltk==3.8.1 102 | numpy==1.26.4 103 | nvidia-cublas-cu12==12.1.3.1 104 | nvidia-cuda-cupti-cu12==12.1.105 105 | nvidia-cuda-nvrtc-cu12==12.1.105 106 | nvidia-cuda-runtime-cu12==12.1.105 107 | nvidia-cudnn-cu12==8.9.2.26 108 | nvidia-cufft-cu12==11.0.2.54 109 | nvidia-curand-cu12==10.3.2.106 110 | nvidia-cusolver-cu12==11.4.5.107 111 | nvidia-cusparse-cu12==12.1.0.106 112 | nvidia-nccl-cu12==2.19.3 113 | nvidia-nvjitlink-cu12==12.4.127 114 | nvidia-nvtx-cu12==12.1.105 115 | omegaconf==2.3.0 116 | openai==1.23.1 117 | opencv-python==4.9.0.80 118 | opencv-python-headless==4.9.0.80 119 | opendatalab==0.0.10 120 | openmim==0.3.9 121 | openxlab==0.0.38 122 | ordered-set==4.1.0 123 | orjson==3.10.1 124 | oss2==2.17.0 125 | packaging==24.0 126 | pandas==2.2.2 127 | parameterized==0.9.0 128 | parso==0.8.4 129 | peft==0.11.1 130 | pexpect==4.9.0 131 | pickleshare==0.7.5 132 | 
pillow==10.2.0 133 | pip==23.3.1 134 | platformdirs==4.2.0 135 | pluggy==1.4.0 136 | portalocker==2.8.2 137 | progressbar2==4.4.2 138 | prompt-toolkit==3.0.42 139 | protobuf==5.26.1 140 | psutil==5.9.8 141 | ptyprocess==0.7.0 142 | pure-eval==0.2.2 143 | py==1.11.0 144 | pyclipper==1.3.0.post5 145 | pycocotools==2.0.7 146 | pycodestyle==2.11.1 147 | pycparser==2.22 148 | pycryptodome==3.20.0 149 | pydantic==2.7.0 150 | pydantic_core==2.18.1 151 | pydub==0.25.1 152 | pyflakes==3.2.0 153 | Pygments==2.17.2 154 | pyparsing==3.1.2 155 | PySocks==1.7.1 156 | pytest==8.1.1 157 | pytest-cov==5.0.0 158 | pytest-runner==6.0.1 159 | python-dateutil==2.9.0 160 | python-multipart==0.0.9 161 | python-utils==3.8.2 162 | pytz==2023.4 163 | PyYAML==6.0.1 164 | pyzmq==25.1.2 165 | rapidfuzz==3.8.1 166 | referencing==0.34.0 167 | regex==2024.4.16 168 | requests==2.31.0 169 | rich==13.4.2 170 | rpds-py==0.18.0 171 | ruff==0.4.0 172 | safetensors==0.4.3 173 | scikit-image==0.22.0 174 | scikit-learn==1.4.2 175 | scipy==1.13.0 176 | semantic-version==2.10.0 177 | sentencepiece==0.2.0 178 | setuptools==60.2.0 179 | shapely==2.0.4 180 | shellingham==1.5.4 181 | shortuuid==1.0.13 182 | six==1.16.0 183 | sniffio==1.3.1 184 | stack-data==0.6.2 185 | starlette==0.37.2 186 | supervision==0.19.0 187 | sympy==1.12 188 | tabulate==0.9.0 189 | termcolor==2.4.0 190 | terminaltables==3.1.10 191 | threadpoolctl==3.4.0 192 | tifffile==2024.4.18 193 | timm==0.9.16 194 | tokenizers==0.13.3 195 | tomli==2.0.1 196 | tomlkit==0.12.0 197 | toolz==0.12.1 198 | torch==2.0.1 199 | torchaudio==2.0.2 200 | torchvision==0.15.2 201 | tornado==6.4 202 | tqdm==4.65.2 203 | traitlets==5.14.3 204 | transformers==4.31.0 205 | triton==2.0.0 206 | typer==0.12.3 207 | typing_extensions==4.11.0 208 | tzdata==2024.1 209 | ubelt==1.3.5 210 | urllib3==2.1.0 211 | uvicorn==0.29.0 212 | visual-genome==1.1.1 213 | wcwidth==0.2.13 214 | webdataset==0.2.86 215 | websockets==11.0.3 216 | wget==3.2 217 | wheel==0.41.2 218 | xdoctest==1.1.3 219 | yapf==0.40.2 220 | yarl==1.9.4 221 | zipp==3.17.0 222 | -------------------------------------------------------------------------------- /vqa.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/EasyDetect/ad519a26ce595dd412c17e05bdf28c4a3dbee696/vqa.mp4 --------------------------------------------------------------------------------