├── README.md
├── eval
│   ├── README.md
│   ├── calculate_score.py
│   ├── inference.sh
│   ├── inference
│   │   ├── .DS_Store
│   │   ├── CheXagent
│   │   │   ├── model_vqa_med.py
│   │   │   └── run_eval_batch.py
│   │   ├── GPT-4V
│   │   │   └── gpt4v.py
│   │   ├── Gemini
│   │   │   └── gemini.py
│   │   ├── LLaVA-Med
│   │   │   ├── model_vqa_med.py
│   │   │   └── run_med_datasets_eval_batch.py
│   │   ├── LLaVA
│   │   │   ├── model_vqa.py
│   │   │   └── run_eval_batch.py
│   │   └── MiniGPTv2
│   │       ├── eval_minigptv2.py
│   │       └── run_eval_batch.py
│   └── model_inference.sh
└── image.png

/README.md:
--------------------------------------------------------------------------------
# ProbMed

[**🌐 Homepage**](https://jackie-2000.github.io/probmed.github.io/) | [**🤗 Dataset**](https://huggingface.co/datasets/rippleripple/ProbMed) | [**🤗 Paper**](https://arxiv.org/pdf/2405.20421) | [**📖 arXiv**](https://arxiv.org/abs/2405.20421) | [**GitHub**](https://github.com/eric-ai-lab/ProbMed/)


This repo contains the evaluation code for the paper "[Worse than Random? An Embarrassingly Simple Probing Evaluation of Large Multimodal Models in Medical VQA](https://arxiv.org/abs/2405.20421)".


## Introduction
We introduce the Probing Evaluation for Medical Diagnosis (ProbMed) dataset to rigorously assess LMM performance in medical imaging through probing evaluation and procedural diagnosis. Probing evaluation pairs each original question with a negation question containing a hallucinated attribute, while procedural diagnosis requires reasoning across multiple diagnostic dimensions for each image, including modality recognition, organ identification, clinical findings, abnormalities, and positional grounding. ProbMed draws from two comprehensive biomedical datasets, MedICaT and ChestX-ray14, to compile a diverse set of 6,303 images. These images span three modalities (X-ray, MRI, and CT scan) and four organs (abdomen, brain, chest, and spine). After preprocessing, we generated a diverse set of high-quality questions for each image covering these diagnostic dimensions, resulting in a total of 57,132 question-answer pairs, an average of 9 pairs per image.

![ProbMed overview](image.png)

## Dataset Creation

ProbMed was created to rigorously evaluate LMMs' readiness for real-life diagnostic tasks, particularly under adversarial conditions. Please refer to our Hugging Face [**🤗 Dataset**](https://huggingface.co/datasets/rippleripple/ProbMed) for more details.

## Evaluation
Please refer to our [eval](eval) folder for more details.
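Each image in ProbMed contributes several yes/no questions, and every ground-truth question is paired with an adversarial negation question about a hallucinated attribute. As a rough illustration of the format the scripts in the [eval](eval) folder consume (the field names follow the inference scripts; the concrete values and `qa_type` strings below are invented for illustration only):

```python
import json

# Illustrative probing pair: a ground-truth question and its adversarial
# negation for the same image. Values and qa_type strings are made up.
example_pair = [
    {"id": "img_0001", "image": "img_0001.png", "qa_type": "modality_gt",
     "question": "Is this image an X-ray?", "answer": "yes"},
    {"id": "img_0001", "image": "img_0001.png", "qa_type": "modality_hallu",
     "question": "Is this image an MRI?", "answer": "no"},
]

# probmed.json (see the folder structure in eval/README.md) is a flat list of
# such entries, roughly 9 per image, which the inference scripts iterate over.
with open("probmed.json") as f:
    questions = json.load(f)
print(len(questions), questions[0].keys())
```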

## 🏆 Leaderboard
| Model           | Modality  | Organ     | Abnormality | Condition/Finding | Position | Overall   |
|-----------------|:---------:|:---------:|:-----------:|:-----------------:|:--------:|:---------:|
| Random Choice   | 25.00     | 25.00     | 50.00       | **35.67**         | **36.48**| 32.13     |
| GPT-4o          | **97.42** | 69.46     | 61.79       | 29.30             | 24.06    | **55.60** |
| GPT-4V          | 92.51     | 71.73     | 53.30       | 35.19             | 22.40    | 55.28     |
| Gemini 1.5 Pro  | 96.47     | 75.69     | 62.59       | 27.93             | 17.54    | 55.08     |
| Med-Flamingo    | 44.15     | 61.39     | 50.00       | 26.33             | 5.65     | 35.66     |
| CheXagent       | 37.25     | 33.95     | **73.31**   | 28.52             | 7.48     | 30.61     |
| BiomedGPT       | 60.25     | 46.81     | 50.31       | 14.13             | 6.11     | 33.34     |
| LLaVA-Med       | 5.48      | 32.96     | 38.76       | 20.38             | 5.33     | 17.90     |
| MiniGPT-v2      | 3.25      | 76.26     | 50.08       | 15.23             | 7.96     | 27.67     |
| LLaVA-v1.6 (7B) | 6.77      | **80.70** | 46.18       | 3.56              | 1.21     | 24.96     |
| LLaVA-v1 (7B)   | 25.27     | 40.53     | 50.00       | 0.34              | 0.11     | 19.30     |

## Contact
- Qianqi Yan: qyan79@ucsc.edu
- Xin Eric Wang: xwang366@ucsc.edu

## Citation

**BibTeX:**
```bibtex
@misc{yan2024worse,
  title={Worse than Random? An Embarrassingly Simple Probing Evaluation of Large Multimodal Models in Medical VQA},
  author={Qianqi Yan and Xuehai He and Xiang Yue and Xin Eric Wang},
  year={2024},
  eprint={2405.20421},
  archivePrefix={arXiv},
  primaryClass={cs.AI}
}
```
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
# Evaluation Guidelines
We provide detailed instructions for evaluation. To run our evaluation script, please make sure the structure of your model outputs matches ours.

## Model Inference

Download our [dataset](https://huggingface.co/datasets/rippleripple/ProbMed) from Hugging Face.

Clone the official repos of the following open-source models:
* LLaVA v1, v1.6 [[repo]](https://github.com/haotian-liu/LLaVA)
* LLaVA-Med [[repo]](https://github.com/microsoft/LLaVA-Med)
* MiniGPTv2 [[repo]](https://github.com/Vision-CAIR/MiniGPT-4)
* CheXagent [[repo]](https://github.com/Stanford-AIMI/CheXagent)
* BiomedGPT [[repo]](https://github.com/taokz/BiomedGPT)
* Med-Flamingo [[repo]](https://github.com/snap-stanford/med-flamingo)

Set up the environment for each open-source model as instructed by its original repo and run inference. For the API-based models (GPT-4o, GPT-4V, and Gemini Pro), set up your API key in the provided scripts under the /inference folder.

For the open-source models, we also provide our inference scripts for reference. To use them, move the scripts under the /inference folder into the corresponding folders cloned from the original repos, following the paths in model_inference.sh.

After setting up, run inference.sh to get model outputs on the question files.


## Get Evaluation Results and Scores

After getting the outputs, run calculate_score.py to get scores for all models.

Your folder structure should look like this:

    .
    project-root
    ├── LLaVA
    │   └── ...
    ├── LLaVA-Med
    │   └── ...
    └── ...
36 | │ 37 | ├── probmed.json 38 | ├── response_file 39 | │ └── llava_v1.json 40 | │ └── llavamed.json 41 | │ └── xxx.json 42 | ├── ablation 43 | │ └── ablation.json 44 | │ └── llava_v1.jsonl 45 | │ └── llavamed.jsonl 46 | │ └── xxx.json 47 | -------------------------------------------------------------------------------- /eval/calculate_score.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | def parse_response(models): 6 | ''' 7 | parse response data from aggregated ans file into modality-organ types 8 | ''' 9 | all_model_data = {} 10 | for model_name in models: 11 | response_data = {} 12 | with open(f"response_file/{model_name}.json", 'r') as f: 13 | response = json.load(f) 14 | for data in response: 15 | if data["image_type"] not in response_data: 16 | response_data[data["image_type"]] = [data] 17 | else: response_data[data["image_type"]].append(data) 18 | all_model_data[model_name] = response_data 19 | 20 | return all_model_data 21 | 22 | 23 | def get_score_binary(response, ans): 24 | ''' 25 | get binary score used for main results and ablation accuracy 26 | ''' 27 | response = response.strip() 28 | if ans == 'yes': 29 | if 'Yes' in response or response.lower() == 'yes' or response.lower() == 'yes.': return 1 30 | else: return 0 31 | else: 32 | if 'No' in response or response.lower() == 'no' or response.lower() == 'no.': return 1 33 | else: return 0 34 | 35 | def get_score_dict(response_data, get_score): 36 | ''' 37 | get score dict according to probmed data setting for later geting float scores 38 | ''' 39 | cur_img_id = response_data[0]['id'] 40 | score = defaultdict(list) 41 | score['id'] = [cur_img_id] 42 | 43 | modality_score = [] 44 | body_part_score = [] 45 | entity_score = [] 46 | grounding_score = [] 47 | 48 | for data in response_data: 49 | if data['id'] != cur_img_id: # next image 50 | score['id'].append(data['id']) 51 | cur_img_id = data['id'] 52 | if len(modality_score) != 2: 53 | modality_score = [] # one of questions unanswered 54 | score['modality'].append(modality_score) 55 | modality_score = [] 56 | if len(body_part_score) != 2: 57 | body_part_score = [] 58 | score['body_part'].append(body_part_score) 59 | body_part_score = [] 60 | score['entity'].append(entity_score) 61 | entity_id = -1 62 | entity_score = [] 63 | score['grounding'].append(grounding_score) 64 | grounding_id = -1 65 | grounding_score = [] 66 | if "modality" in data['qa_type']: 67 | modality_score.append(get_score(data['response'], data['gt_ans'])) 68 | elif "body_part" in data['qa_type']: 69 | body_part_score.append(get_score(data['response'], data['gt_ans'])) 70 | elif data['qa_type'] == 'abnormality': 71 | score['abnormality'].append(get_score(data['response'], data['gt_ans'])) 72 | elif "entity" in data['qa_type']: 73 | if data['qa_type'] == "entity_hallu": # abnormality 0 74 | entity_score = [get_score(data['response'], data['gt_ans'])] 75 | else: 76 | if "gt" in data['qa_type']: 77 | entity_id = data['qa_type'].split('_')[-1] 78 | entity_score_tuple = [get_score(data['response'], data['gt_ans'])] 79 | else: 80 | if data['qa_type'].split('_')[-1] != entity_id: # gt question is not answered 81 | continue 82 | entity_score_tuple.append(get_score(data['response'], data['gt_ans'])) 83 | assert len(entity_score_tuple) == 2 84 | entity_score.append(entity_score_tuple) 85 | else: 86 | if "gt" in data['qa_type']: 87 | grounding_id = data['qa_type'].split('_')[-1] 88 | 
grounding_score_tuple = [get_score(data['response'], data['gt_ans'])] 89 | else: 90 | if data['qa_type'].split('_')[-1] != grounding_id: # gt question is not answered 91 | continue 92 | grounding_score_tuple.append(get_score(data['response'], data['gt_ans'])) 93 | assert len(grounding_score_tuple) == 2 94 | grounding_score.append(grounding_score_tuple) 95 | score['modality'].append(modality_score) 96 | score['body_part'].append(body_part_score) 97 | score['entity'].append(entity_score) 98 | score['grounding'].append(grounding_score) 99 | return score 100 | 101 | def get_score_float(score): 102 | output_score = {} 103 | 104 | tmp = [d for d in score['abnormality'] if not np.isnan(d)] 105 | output_score['abnormality'] = { 106 | 'acc' : sum(tmp) / len(tmp)*100, 107 | 'num' : len(score['abnormality']), 108 | } 109 | 110 | tmp = [] 111 | count_nan, count_all_ones, count_first_one, count_empty = 0, 0, 0, 0 112 | for t in score['modality']: 113 | if not t: 114 | count_empty += 1 115 | continue 116 | if np.isnan(t).any(): 117 | count_nan += 1 118 | if all(elem == 1 for elem in t): 119 | assert not np.isnan(t).any() 120 | count_all_ones += 1 121 | if t[0] == 1: 122 | count_first_one += 1 123 | assert count_nan == 0 124 | output_score['modality'] = { 125 | 'acc' : count_all_ones / ((len(score['modality'])-count_nan-count_empty))*100, 126 | 'acc w. hallu': count_first_one / ((len(score['modality'])-count_nan-count_empty))*100, 127 | 'num' : len(score['modality']) - count_empty 128 | } 129 | 130 | tmp = [] 131 | count_nan, count_all_ones, count_first_one, count_empty = 0, 0, 0, 0 132 | for t in score['body_part']: 133 | if not t: 134 | count_empty += 1 135 | continue 136 | if np.isnan(t).any(): 137 | count_nan += 1 138 | if all(elem == 1 for elem in t): 139 | assert not np.isnan(t).any() 140 | count_all_ones += 1 141 | if t[0] == 1: 142 | count_first_one += 1 143 | assert count_nan == 0 144 | output_score['body_part'] = { 145 | 'acc' : count_all_ones / ((len(score['body_part'])-count_nan-count_empty))*100, 146 | 'acc w. hallu': count_first_one / ((len(score['body_part'])-count_nan-count_empty))*100, 147 | 'num' : len(score['body_part']) - count_empty 148 | } 149 | 150 | count_nan = 0 151 | filtered_list = [] 152 | for l in score['entity']: 153 | if not l: 154 | continue 155 | if isinstance(l[0], list): # Check if the first item is a list 156 | if all(np.nan in x for x in l): 157 | count_nan += 1 158 | continue 159 | filtered_list.append([x for x in l if np.nan not in x]) # remove [np,nan, 1] from l [[np,nan, 1], [0, 1]] 160 | else: 161 | if np.isnan(l[0]): 162 | count_nan += 1 # remove single [np.nan] 163 | else: 164 | filtered_list.append(l) 165 | count_first_1 = 0 166 | count_all_1 = 0 167 | for l in filtered_list: 168 | assert isinstance(l, list) 169 | if all(x[0] == 1 for x in (l if isinstance(l[0], list) else [l])): 170 | count_first_1 += 1 171 | if all(all(y == 1 for y in x) for x in (l if isinstance(l[0], list) else [l])): 172 | count_all_1 += 1 173 | output_score['entity'] = { 174 | 'acc' : count_all_1 / len(filtered_list)*100, 175 | 'acc w. 
hallu' : count_first_1 / len(filtered_list)*100, 176 | 'num' : len(score['entity']) 177 | } 178 | 179 | filtered_list = [] 180 | count_nan = 0 181 | count_empty = 0 182 | for l in score['grounding']: 183 | if not l: # skip empty lists 184 | count_empty += 1 185 | continue 186 | if all(np.nan in x for x in l): 187 | count_nan += 1 188 | continue 189 | filtered_list.append([x for x in l if np.nan not in x]) 190 | count_first_1 = 0 191 | count_all_1 = 0 192 | for l in filtered_list: 193 | if isinstance(l, list) and all(isinstance(x, list) for x in l): # check for list of lists 194 | if all(x[0] == 1 for x in (l if isinstance(l[0], list) else [l])): 195 | count_first_1 += 1 196 | if all(all(y == 1 for y in x) for x in l): 197 | count_all_1 += 1 198 | output_score['grounding'] = { 199 | 'acc' : count_all_1 / len(filtered_list)*100, 200 | 'acc w. hallu' : count_first_1 / len(filtered_list)*100, 201 | 'num' : len(score['grounding']) 202 | } 203 | 204 | return output_score 205 | 206 | def get_scores_probmed(all_model_data): 207 | ''' 208 | all_scores: score per modality_body_part: [KEY] acc, acc w.o. adv pair, num (Tables in Appendix) 209 | all_scores_aggr_question: aggregated score per question type: [KEY] acc, acc w.o. adv pair (Table 5 results) 210 | overall_scores_aggr_question: overall aggregated score per question type: [KEY] acc, acc w.o. adv pair (Table 5 last column) 211 | ''' 212 | all_scores = {} 213 | all_scores_aggr_question = {} 214 | overall_scores_aggr_question = {} 215 | for model_name, model_response in all_model_data.items(): 216 | for image_type, response in model_response.items(): 217 | score_dict = get_score_dict(response, get_score=get_score_binary) 218 | score_per_cat = get_score_float(score_dict) 219 | if model_name not in all_scores: 220 | all_scores[model_name] = {} 221 | all_scores[model_name][image_type] = score_per_cat 222 | aggregated = {} 223 | for modality, questions in all_scores[model_name].items(): 224 | for question, metrics in questions.items(): 225 | if question not in aggregated: 226 | aggregated[question] = { 227 | "acc": 0, 228 | "num": 0 229 | } 230 | if "acc w. hallu" in metrics: 231 | aggregated[question]["acc w. hallu"] = 0 232 | aggregated[question]["acc"] += metrics["acc"] * metrics["num"] 233 | aggregated[question]["num"] += metrics["num"] 234 | if "acc w. hallu" in metrics: 235 | aggregated[question]["acc w. hallu"] += metrics["acc w. hallu"] * metrics["num"] 236 | for question, metrics in aggregated.items(): 237 | if metrics["num"] > 0: 238 | metrics["acc"] /= metrics["num"] 239 | if "acc w. hallu" in metrics: 240 | metrics["acc w. hallu"] /= metrics["num"] 241 | all_scores_aggr_question[model_name] = aggregated 242 | 243 | for model, question_score in all_scores_aggr_question.items(): 244 | overall_scores_aggr_question[model] = { 245 | "acc": 0, 246 | "num": 0, 247 | "acc w.o. adv pair": 0, 248 | "num w.o. adv pair": 0 249 | } 250 | for question, metrics in question_score.items(): 251 | overall_scores_aggr_question[model]["acc"] += metrics["acc"] * metrics["num"] 252 | overall_scores_aggr_question[model]["num"] += metrics["num"] 253 | if "acc w. hallu" in metrics: 254 | overall_scores_aggr_question[model]["acc w.o. adv pair"] += metrics["acc w. hallu"] * metrics["num"] 255 | overall_scores_aggr_question[model]["num w.o. 
adv pair"] += metrics["num"]
        if overall_scores_aggr_question[model]["num"] > 0:
            overall_scores_aggr_question[model]["acc"] /= overall_scores_aggr_question[model]["num"]
        if overall_scores_aggr_question[model]["num w.o. adv pair"] > 0:
            overall_scores_aggr_question[model]["acc w.o. adv pair"] /= overall_scores_aggr_question[model]["num w.o. adv pair"]

    return all_scores, all_scores_aggr_question, overall_scores_aggr_question

def get_model_score_vqa_rad_ablation(ans_file_name):
    '''
    get accuracy with and without adversarial pairing on the VQA-RAD ablation set
    '''
    response_data = []
    if "jsonl" in ans_file_name:
        with open(ans_file_name, 'r') as f:
            for line in f:
                response_data.append(json.loads(line))
    else:
        with open(ans_file_name, 'r') as f:
            response_data = json.load(f)
    score = []
    for i, data in enumerate(response_data):
        if i % 2 == 0:  # original question
            assert data['gt_ans'] == 'yes'
            tmp = []
            tmp.append(get_score_binary(data['response'], 'yes'))
        else:  # adversarial (negation) question
            assert data['gt_ans'] == 'no'
            tmp.append(get_score_binary(data['response'], 'no'))
            score.append(tmp)
    score_wo_adv = []
    score_w_adv = []
    for tmp in score:
        if tmp[0] == 1:
            score_wo_adv.append(1)
            if tmp[1] == 1:
                score_w_adv.append(1)
            else: score_w_adv.append(0)
        else:
            score_w_adv.append(0)
            score_wo_adv.append(0)
    assert len(score_w_adv) == len(score_wo_adv)
    return sum(score_w_adv)/len(score_w_adv), sum(score_wo_adv)/len(score_wo_adv)

def main():
    models = ["chexagent", "gemini", "gpt4v", "llava_v1.6", "llava_v1", "llavamed", "minigptv2", "gpt4o", "med-flamingo", "biomedgpt"]
    all_model_data = parse_response(models)
    all_scores, all_scores_aggr_question, overall_scores_aggr_question = get_scores_probmed(all_model_data)

    # # uncomment the block to print fine-grained accuracy
    # print('=== Printing accuracy in Appendix Tables ===')
    # for model, v in all_scores.items():
    #     for image_type, s in v.items():
    #         print(model, image_type)
    #         print(s)
    # print('=' * 30)

    print('=== Printing accuracy aggregated over modality-organ ===')
    for model, v in all_scores_aggr_question.items():
        print(model, v)
    print('=' * 30)

    print('=== Printing overall accuracy further aggregated over question types and difference w.&w.o. adv. pairs ===')
    for model, overall_score in overall_scores_aggr_question.items():
        print(f"{model} acc. w.o. adv. pair: {overall_score['acc w.o. adv pair']}, acc. w. adv. pair: {overall_score['acc']}, acc. diff: {overall_score['acc w.o. adv pair'] - overall_score['acc']}")
    print('=' * 30)

    print('=== Printing accuracy on ablation set and difference w.&w.o. adv. pairs ===')
    model_names = ["llava_v1.jsonl", "llava_v1.6.jsonl", "llavamed.jsonl", "minigptv2.jsonl", "chexagent.jsonl", "gpt4v.json", "gemini.jsonl", "gpt4o.json", "med-flamingo.jsonl", "biomedgpt.json"]
    summ = []
    for model in model_names:
        score = get_model_score_vqa_rad_ablation(f'ablation/{model}')
        print(f"{model} acc. w.o. adv. pair: {score[1]*100}, acc. w. adv. pair: {score[0]*100}, acc.
diff: {score[1]*100 - score[0]*100}") 325 | summ.append(score[1]*100 - score[0]*100) 326 | print(f"average drop: {sum(summ)/len(summ)}") 327 | print('=' * 30) 328 | 329 | 330 | if __name__ == "__main__": 331 | main() 332 | -------------------------------------------------------------------------------- /eval/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | ( 5 | source activate llava-med || conda activate llava-med 6 | ./model_inference.sh llavamed 7 | conda deactivate 8 | ) 9 | 10 | ( 11 | source activate llava || conda activate llava 12 | ./model_inference.sh llava_v1 13 | conda deactivate 14 | ) 15 | 16 | ( 17 | source activate llava || conda activate llava 18 | ./model_inference.sh llava_v1.6 19 | conda deactivate 20 | ) 21 | 22 | ( 23 | source activate minigptv || conda activate minigptv 24 | ./model_inference.sh minigptv2 25 | conda deactivate 26 | ) 27 | 28 | ( 29 | source activate llama || conda activate llama 30 | ./model_inference.sh chexagent 31 | conda deactivate 32 | ) 33 | 34 | ( 35 | source activate llama || conda activate llama 36 | ./model_inference.sh gpt4v 37 | conda deactivate 38 | ) 39 | 40 | ( 41 | source activate llama || conda activate llama 42 | ./model_inference.sh gemini 43 | conda deactivate 44 | ) 45 | 46 | ( 47 | source activate llama || conda activate llama 48 | ./model_inference.sh gpt4o 49 | conda deactivate 50 | ) 51 | 52 | ( 53 | source activate med-flamingo || conda activate med-flamingo 54 | ./model_inference.sh med-flamingo 55 | conda deactivate 56 | ) 57 | 58 | ( 59 | source activate biomedgpt || conda activate biomedgpt 60 | ./model_inference.sh biomedgpt 61 | conda deactivate 62 | ) 63 | -------------------------------------------------------------------------------- /eval/inference/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/ProbMed/0268b5d7e3af795ba0b30c3710c0c44e4f90158c/eval/inference/.DS_Store -------------------------------------------------------------------------------- /eval/inference/CheXagent/model_vqa_med.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import io 8 | 9 | import requests 10 | import torch 11 | from PIL import Image 12 | from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig 13 | 14 | from PIL import Image 15 | import random 16 | import math 17 | 18 | 19 | def split_list(lst, n): 20 | """Split a list into n (roughly) equal-sized chunks""" 21 | chunk_size = math.ceil(len(lst) / n) # integer division 22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 23 | 24 | 25 | def get_chunk(lst, n, k): 26 | chunks = split_list(lst, n) 27 | return chunks[k] 28 | 29 | def eval_model(args): 30 | # Model 31 | # step 1: Setup constant 32 | device = "cuda" 33 | dtype = torch.float16 34 | 35 | # step 2: Load Processor and Model 36 | processor = AutoProcessor.from_pretrained("path/to/CheXagent", trust_remote_code=True) 37 | generation_config = GenerationConfig.from_pretrained("path/to/CheXagent") 38 | model = AutoModelForCausalLM.from_pretrained("path/to/CheXagent", torch_dtype=dtype, trust_remote_code=True) 39 | model = model.cuda().half() 40 | 41 | questions = json.load(open(os.path.expanduser(args.question_file), 
"r")) 42 | # questions = get_chunk(questions, args.num_chunks, args.chunk_idx - 1) 43 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 44 | answers_file = os.path.expanduser(args.answers_file) 45 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 46 | os.makedirs(os.path.join(os.path.dirname(answers_file), "images"), exist_ok=True) 47 | ans_file = open(answers_file, "w") 48 | save_image_folder = os.path.join(os.path.dirname(os.path.expanduser(args.answers_file)), "images") 49 | for i, line in enumerate(tqdm(questions)): 50 | idx = line["id"] 51 | qa_type = line["qa_type"] 52 | answer = line["answer"] 53 | qs = line["question"] 54 | 55 | qs = qs.replace('', '').strip() 56 | cur_prompt = qs 57 | 58 | image_file = line["image"] 59 | image = Image.open(os.path.join(args.image_folder, image_file)) 60 | inputs = processor(images=image, text=f" USER: {cur_prompt} ASSISTANT: ", return_tensors="pt").to(device=device, dtype=dtype) 61 | 62 | output = model.generate(**inputs, generation_config=generation_config)[0] 63 | response = processor.tokenizer.decode(output, skip_special_tokens=True) 64 | 65 | ans_file.write(json.dumps({"id": idx, 66 | "qa_type": qa_type, 67 | "question": cur_prompt, 68 | "gt_ans": answer, 69 | "response": response}) + "\n") 70 | # ans_file.write(json.dumps({"id": idx, 71 | # "prompt": cur_prompt, 72 | # "text": outputs, 73 | # "answer_id": ans_id, 74 | # "model_id": model_name, 75 | # "metadata": {}}) + "\n") 76 | ans_file.flush() 77 | ans_file.close() 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 82 | parser.add_argument("--image-folder", type=str, default="") 83 | parser.add_argument("--question-file", type=str, default="tables/question.json") 84 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 85 | parser.add_argument("--mm-projector", type=str, default=None) 86 | parser.add_argument("--vision-tower", type=str, default=None) 87 | parser.add_argument("--conv-mode", type=str, default="simple") 88 | parser.add_argument("--num-chunks", type=int, default=1) 89 | parser.add_argument("--chunk-idx", type=int, default=0) 90 | parser.add_argument("--answer-prompter", action="store_true") 91 | args = parser.parse_args() 92 | 93 | eval_model(args) 94 | -------------------------------------------------------------------------------- /eval/inference/CheXagent/run_eval_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | from concurrent.futures import ProcessPoolExecutor 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='Parallel Chexagent evaluation script.') 8 | 9 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 10 | parser.add_argument("--image-folder", type=str, default="") 11 | parser.add_argument("--question-file", type=str, default="tables/question.json") 12 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 13 | parser.add_argument("--mm-projector", type=str, default=None) 14 | parser.add_argument("--vision-tower", type=str, default=None) 15 | parser.add_argument("--conv-mode", type=str, default="simple") 16 | parser.add_argument("--answer-prompter", action="store_true") 17 | parser.add_argument('--num-chunks', type=int, default=1, help='Number of chunks (default: 1).') 18 | parser.add_argument("--chunk-idx", type=int, default=0) 19 | args = 
parser.parse_args() 20 | 21 | return parser.parse_args() 22 | 23 | def run_job(chunk_idx, args): 24 | 25 | cmd = ("CUDA_VISIBLE_DEVICES={chunk_idx} python model_vqa_med.py " 26 | "--model-name {model_name} " 27 | "--question-file {question_file} " 28 | "--image-folder {image_folder} " 29 | "--answers-file {experiment_name_with_split}-chunk{chunk_idx}.jsonl " 30 | "--num-chunks {chunks} " 31 | "--chunk-idx {chunk_idx} ").format( 32 | chunk_idx=chunk_idx, 33 | chunks=args.num_chunks, 34 | model_name=args.model_name, 35 | question_file=args.question_file, 36 | image_folder=args.image_folder, 37 | experiment_name_with_split=args.experiment_name_with_split 38 | ) 39 | 40 | print(cmd) 41 | 42 | subprocess.run(cmd, shell=True, check=True) 43 | 44 | def main(): 45 | args = parse_args() 46 | args.experiment_name_with_split = args.answers_file.split(".jsonl")[0] 47 | 48 | # Create a partial function that accepts only `chunk_idx` 49 | from functools import partial 50 | run_job_with_args = partial(run_job, args=args) 51 | 52 | # Run the jobs in parallel using ProcessPoolExecutor 53 | with ProcessPoolExecutor(max_workers=args.num_chunks) as executor: 54 | list(executor.map(run_job_with_args, range(args.num_chunks))) # Use run_job_with_args instead of lambda 55 | # list(executor.map(run_job_with_args, range(1,4))) # Use run_job_with_args instead of lambda 56 | 57 | # Gather the results 58 | output_file = f"{args.experiment_name_with_split}.jsonl" 59 | with open(output_file, 'w') as outfile: 60 | for idx in range(args.num_chunks): 61 | # for idx in range(1,4): 62 | with open(f"{args.experiment_name_with_split}-chunk{idx}.jsonl") as infile: 63 | outfile.write(infile.read()) 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /eval/inference/GPT-4V/gpt4v.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Optional 3 | import fire 4 | import os 5 | import asyncio 6 | from openai import AsyncAzureOpenAI, AzureOpenAI 7 | from tqdm import tqdm 8 | from tqdm.asyncio import tqdm as async_tqdm 9 | from mimetypes import guess_type 10 | import base64 11 | 12 | def create_client(): 13 | api_base = "your api base" 14 | api_key= "your api key" 15 | deployment_name = 'gpt4v' 16 | api_version = "your api version" 17 | 18 | client = AsyncAzureOpenAI( 19 | api_key=api_key, 20 | api_version=api_version, 21 | base_url=f"{api_base}/openai/deployments/{deployment_name}" 22 | ) 23 | return client 24 | 25 | # Function to encode a local image into data URL 26 | def local_image_to_data_url(image_path): 27 | # Guess the MIME type of the image based on the file extension 28 | mime_type, _ = guess_type(image_path) 29 | if mime_type is None: 30 | mime_type = 'application/octet-stream' # Default MIME type if none is found 31 | 32 | # Read and encode the image file 33 | with open(image_path, "rb") as image_file: 34 | base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8') 35 | 36 | # Construct the data URL 37 | return f"data:{mime_type};base64,{base64_encoded_data}" 38 | 39 | class GPT4V: 40 | def __init__(self, image_folder, async_mode=False, rate=50, max_concurrent_requests=100): 41 | self.is_async = async_mode 42 | self.rate = rate # requests per second 43 | self.sleep_time = 1 / rate 44 | self.max_concurrent_requests = max_concurrent_requests 45 | self.image_folder = image_folder 46 | 47 | api_key = open('api_key.txt', 'r').read() 48 | if self.is_async: 49 | 
self.client = create_client() 50 | else: 51 | self.client = AzureOpenAI() 52 | 53 | 54 | def label(self, meta_data: list[dict]) -> list[dict]: 55 | if self.is_async: 56 | return asyncio.run(self.label_async(meta_data)) 57 | else: 58 | print("Not implemented") 59 | assert False 60 | 61 | async def label_async(self, meta_data: list[str]) -> list[dict]: 62 | results = [] 63 | 64 | semaphore = asyncio.Semaphore(self.max_concurrent_requests) 65 | 66 | async def process_cap(data, i): 67 | idx = data["id"] 68 | gpt_idx = data["gpt_idx"] 69 | qa_type = data["qa_type"] 70 | answer = data["answer"] 71 | qs = data["question"] 72 | image_file = data["image"] 73 | async with semaphore: 74 | messages=[ 75 | { "role": "system", "content": "You are a student in medical school. You are preparing for your final exam. Answer the following question in your practice exam as directed to earn higher scores. You answer will only be for academic purpose." }, 76 | { "role": "user", "content": [ 77 | { 78 | "type": "text", 79 | "text": qs 80 | }, 81 | { 82 | "type": "image_url", 83 | "image_url": { 84 | "url": local_image_to_data_url(self.image_folder + image_file) 85 | } 86 | } 87 | ] } 88 | ] 89 | 90 | try: 91 | response = await self.client.chat.completions.create( 92 | model="gpt4", 93 | messages=messages 94 | ) 95 | response_text = response.choices[0].message.content.strip() 96 | 97 | return { 98 | "i": i, 99 | "data" : { 100 | "id": idx, 101 | "gpt_idx": gpt_idx, 102 | "qa_type": qa_type, 103 | "question": qs, 104 | "gt_ans": answer, 105 | "response": response_text 106 | } 107 | } 108 | 109 | except Exception as e: 110 | print(f"An error occurred: {str(e)}") 111 | return None 112 | 113 | tasks = [process_cap(data, i) for i, data in enumerate(meta_data)] 114 | for task in async_tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"generate responses"): 115 | result = await task 116 | if result is not None: 117 | results.append(result) 118 | await asyncio.sleep(self.sleep_time) 119 | 120 | results = sorted(results, key=lambda x: x['i']) 121 | return results 122 | 123 | 124 | def main( 125 | question_file: Optional[str] = "xx.json", 126 | answers_file: Optional[str] = "xx.json", 127 | image_folder: Optional[str] = "image/folder" 128 | ): 129 | 130 | labeler = GPT4V(image_folder, async_mode=True, rate=60, max_concurrent_requests=100) 131 | 132 | with open(question_file, 'r') as f: 133 | question_data = json.load(f) 134 | 135 | # question_data = question_data[:50] 136 | 137 | indices = list(range(len(question_data))) 138 | 139 | # assign global index 140 | for i, _ in enumerate(question_data): 141 | question_data[i]['gpt_idx'] = i 142 | 143 | results = labeler.label(meta_data=question_data) 144 | results = [r['data'] for r in results] 145 | 146 | # indices 147 | for data in results: 148 | indices.remove(data['gpt_idx']) 149 | 150 | no_effect = 0 151 | while (len(indices) > 0): 152 | before_count = len(indices) 153 | print(f"There are {len(indices)} left") 154 | meta_data = [question_data[i] for i in indices] 155 | tmp_results = labeler.label(meta_data=meta_data) 156 | tmp_results = [r['data'] for r in tmp_results] 157 | results.extend(tmp_results) 158 | # indices 159 | for data in tmp_results: 160 | indices.remove(data['gpt_idx']) 161 | after_count = len(indices) 162 | if after_count == before_count: 163 | no_effect += 1 164 | else: 165 | no_effect = 0 166 | if no_effect >= 3: 167 | break 168 | 169 | results = sorted(results, key=lambda x: x['gpt_idx']) 170 | 171 | with open(answers_file, 'w') as f: 172 | 
json.dump(results, f, indent=4) 173 | 174 | if __name__ == "__main__": 175 | fire.Fire(main) 176 | -------------------------------------------------------------------------------- /eval/inference/Gemini/gemini.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Optional 3 | import fire 4 | import os 5 | import asyncio 6 | from tqdm import tqdm 7 | from tqdm.asyncio import tqdm as async_tqdm 8 | import google.generativeai as genai 9 | import PIL.Image 10 | 11 | class Gemini(): 12 | def __init__(self, image_folder): 13 | self.image_folder = image_folder 14 | api_key = "your api key" 15 | genai.configure(api_key=api_key) 16 | self.model = genai.GenerativeModel('gemini-pro-vision') 17 | 18 | async def process(self, data, i): 19 | idx = data["id"] 20 | gpt_idx = data["gpt_idx"] 21 | qa_type = data["qa_type"] 22 | answer = data["answer"] 23 | qs = data["question"] 24 | image_file = data["image"] 25 | try: 26 | response = await self.model.generate_content_async([qs, PIL.Image.open(self.image_folder + image_file)]) 27 | return { 28 | "i": i, 29 | "data" : { 30 | "id": idx, 31 | "gpt_idx": gpt_idx, 32 | "qa_type": qa_type, 33 | "question": qs, 34 | "gt_ans": answer, 35 | "response": response.text 36 | } 37 | } 38 | except Exception as e: 39 | print(f"An error occurred: {str(e)}") 40 | return None 41 | 42 | def label(self, meta_data: list[dict]) -> list[dict]: 43 | return asyncio.run(self.label_async(meta_data)) 44 | 45 | async def label_async(self, meta_data): 46 | results = [] 47 | 48 | tasks = [self.process(data, i) for i, data in enumerate(meta_data)] 49 | for task in async_tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"generate responses"): 50 | result = await task 51 | if result is not None: 52 | results.append(result) 53 | 54 | results = sorted(results, key=lambda x: x['i']) 55 | return results 56 | 57 | def main( 58 | question_file: Optional[str] = "xx.json", 59 | answers_file: Optional[str] = "xx.json", 60 | image_folder: Optional[str] = "image/folder" 61 | ): 62 | 63 | labeler = Gemini(image_folder) 64 | 65 | with open(question_file, 'r') as f: 66 | question_data = json.load(f) 67 | 68 | indices = list(range(len(question_data))) 69 | 70 | # assign global index 71 | for i, _ in enumerate(question_data): 72 | question_data[i]['gpt_idx'] = i 73 | 74 | results = labeler.label(meta_data=question_data) 75 | results = [r['data'] for r in results] 76 | 77 | # indices 78 | for data in results: 79 | indices.remove(data['gpt_idx']) 80 | 81 | no_effect = 0 82 | while (len(indices) > 0): 83 | before_count = len(indices) 84 | print(f"There are {len(indices)} left") 85 | meta_data = [question_data[i] for i in indices] 86 | tmp_results = labeler.label(meta_data=meta_data) 87 | tmp_results = [r['data'] for r in tmp_results] 88 | results.extend(tmp_results) 89 | # indices 90 | for data in tmp_results: 91 | indices.remove(data['gpt_idx']) 92 | after_count = len(indices) 93 | if after_count == before_count: 94 | no_effect += 1 95 | else: 96 | no_effect = 0 97 | if no_effect >= 3: 98 | break 99 | 100 | results = sorted(results, key=lambda x: x['gpt_idx']) 101 | 102 | with open(answers_file, 'w') as f: 103 | json.dump(results, f, indent=4) 104 | 105 | if __name__ == "__main__": 106 | fire.Fire(main) 107 | -------------------------------------------------------------------------------- /eval/inference/LLaVA-Med/model_vqa_med.py: -------------------------------------------------------------------------------- 1 | import argparse 2 
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava import LlavaLlamaForCausalLM 10 | from llava.conversation import conv_templates 11 | from llava.utils import disable_torch_init 12 | from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria 13 | 14 | from PIL import Image 15 | import random 16 | import math 17 | 18 | 19 | def split_list(lst, n): 20 | """Split a list into n (roughly) equal-sized chunks""" 21 | chunk_size = math.ceil(len(lst) / n) # integer division 22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 23 | 24 | 25 | def get_chunk(lst, n, k): 26 | chunks = split_list(lst, n) 27 | return chunks[k] 28 | 29 | 30 | DEFAULT_IMAGE_TOKEN = "" 31 | DEFAULT_IMAGE_PATCH_TOKEN = "" 32 | DEFAULT_IM_START_TOKEN = "" 33 | DEFAULT_IM_END_TOKEN = "" 34 | 35 | 36 | 37 | 38 | detail_describe_instructions = [ 39 | "Describe the following image in detail.", 40 | "Provide a detailed description of the given image.", 41 | "Give an elaborate explanation of the image you see.", 42 | "Share a comprehensive rundown of the presented image.", 43 | "Offer a thorough analysis of the image.", 44 | "Explain the various aspects of the image before you.", 45 | "Clarify the contents of the displayed image with great detail.", 46 | "Characterize the image using a well-detailed description.", 47 | "Break down the elements of the image in a detailed manner.", 48 | "Walk through the important details of the image.", 49 | "Portray the image with a rich, descriptive narrative.", 50 | "Narrate the contents of the image with precision.", 51 | "Analyze the image in a comprehensive and detailed manner.", 52 | "Illustrate the image through a descriptive explanation.", 53 | "Examine the image closely and share its details.", 54 | "Write an exhaustive depiction of the given image.", 55 | ] 56 | 57 | concise_describe_instructions = [ 58 | "Describe the following image concisely.", 59 | "Provide a brief description of the given image.", 60 | "Offer a succinct explanation of the picture presented.", 61 | "Summarize the visual content of the following image.", 62 | "Give a short and clear explanation of the subsequent image.", 63 | "Share a concise interpretation of the image provided.", 64 | "Present a compact description of the photo's key features.", 65 | "Relay a brief, clear account of the picture shown.", 66 | "Render a clear and concise summary of the photo below.", 67 | "Write a terse but informative summary of the following picture.", 68 | "Create a compact narrative representing the image presented.", 69 | ] 70 | 71 | prompt_pool = detail_describe_instructions + concise_describe_instructions 72 | 73 | prompt_pool = [ "Describe the following image in detail."] 74 | 75 | 76 | def patch_config(config): 77 | patch_dict = { 78 | "use_mm_proj": True, 79 | "mm_vision_tower": "openai/clip-vit-large-patch14", 80 | "mm_hidden_size": 1024 81 | } 82 | 83 | cfg = AutoConfig.from_pretrained(config) 84 | if not hasattr(cfg, "mm_vision_tower"): 85 | print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.') 86 | for k, v in patch_dict.items(): 87 | setattr(cfg, k, v) 88 | cfg.save_pretrained(config) 89 | 90 | 91 | # new stopping implementation 92 | class KeywordsStoppingCriteria(StoppingCriteria): 93 | def __init__(self, keywords, tokenizer, input_ids): 94 | self.keywords = keywords 95 | self.tokenizer = tokenizer 96 | 
self.start_len = None 97 | self.input_ids = input_ids 98 | 99 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 100 | if self.start_len is None: 101 | self.start_len = self.input_ids.shape[1] 102 | else: 103 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 104 | for keyword in self.keywords: 105 | if keyword in outputs: 106 | return True 107 | return False 108 | 109 | 110 | def eval_model(args): 111 | # Model 112 | disable_torch_init() 113 | model_name = os.path.expanduser(args.model_name) 114 | tokenizer = AutoTokenizer.from_pretrained(model_name) 115 | if args.mm_projector is None: 116 | patch_config(model_name) 117 | if "BiomedCLIP" in model_name or "biomed_clip" in model_name: 118 | model = LlavaLlamaForCausalLM.from_pretrained(model_name, use_cache=True).cuda() 119 | model = model.to(torch.float16) 120 | image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16") 121 | 122 | openai_vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16") 123 | vision_config = openai_vision_tower.config 124 | vision_tower = model.model.vision_tower[0] 125 | vision_tower.to(device='cuda', dtype=torch.float16) 126 | setattr(vision_tower, 'config', vision_config) 127 | else: 128 | model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda() 129 | image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) 130 | vision_tower = model.model.vision_tower[0] 131 | vision_tower.to(device='cuda', dtype=torch.float16) 132 | 133 | 134 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 135 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 136 | if mm_use_im_start_end: 137 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 138 | 139 | # import pdb; pdb.set_trace() 140 | vision_config = vision_tower.config 141 | vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] 142 | vision_config.use_im_start_end = mm_use_im_start_end 143 | if mm_use_im_start_end: 144 | vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) 145 | image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 146 | else: 147 | # in case of using a pretrained model with only a MLP projector weights 148 | model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda() 149 | 150 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 151 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 152 | if mm_use_im_start_end: 153 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 154 | 155 | vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).cuda() 156 | 157 | if "BiomedCLIP" in model.config.mm_vision_tower: 158 | image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16") 159 | else: 160 | image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) 161 | 162 | 163 | vision_config = vision_tower.config 164 | vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] 165 | vision_config.use_im_start_end = 
mm_use_im_start_end 166 | if mm_use_im_start_end: 167 | vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) 168 | 169 | image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 170 | 171 | mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size) 172 | mm_projector_weights = torch.load(args.mm_projector, map_location='cpu') 173 | mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) 174 | 175 | model.model.mm_projector = mm_projector.cuda().half() 176 | model.model.vision_tower = [vision_tower] 177 | 178 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 179 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 180 | # questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 181 | answers_file = os.path.expanduser(args.answers_file) 182 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 183 | os.makedirs(os.path.join(os.path.dirname(answers_file), "images"), exist_ok=True) 184 | ans_file = open(answers_file, "w") 185 | save_image_folder = os.path.join(os.path.dirname(os.path.expanduser(args.answers_file)), "images") 186 | for i, line in enumerate(tqdm(questions)): 187 | idx = line["id"] 188 | qa_type = line["qa_type"] 189 | answer = line["answer"] 190 | # question = line['conversations'][0] 191 | # gt_ans = line["conversations"][1] 192 | 193 | # try: 194 | # question = line["conversations"][0] # ['value'].split('\n')[0] 195 | # gt_ans = line["conversations"][1] # ['value'] 196 | # except: 197 | # question = line["conversatons"][0] # ['value'].split('\n')[0] 198 | # gt_ans = line["conversatons"][1] # ['value'] 199 | 200 | # qs = question['value'] 201 | qs = line["question"] 202 | 203 | qs = qs.replace('', '').strip() 204 | cur_prompt = qs 205 | 206 | if 'image' in line: 207 | image_file = line["image"] 208 | image = Image.open(os.path.join(args.image_folder, image_file)) 209 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 210 | images = image_tensor.unsqueeze(0).half().cuda() 211 | if getattr(model.config, 'mm_use_im_start_end', False): 212 | qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN 213 | else: 214 | qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len 215 | cur_prompt = cur_prompt + '\n' + '' 216 | else: 217 | images = None 218 | 219 | if args.conv_mode == 'simple_legacy': 220 | qs += '\n\n### Response:' 221 | # assert gt_ans['from'] == 'gpt' 222 | # conv = default_conversation.copy() 223 | conv = conv_templates[args.conv_mode].copy() 224 | conv.append_message(conv.roles[0], qs) 225 | prompt = conv.get_prompt() 226 | inputs = tokenizer([prompt]) 227 | 228 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 229 | 230 | keywords = ['###'] 231 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 232 | 233 | with torch.inference_mode(): 234 | output_ids = model.generate( 235 | input_ids, 236 | images=images, 237 | do_sample=True, 238 | temperature=0.7, 239 | max_new_tokens=1024, 240 | stopping_criteria=[stopping_criteria]) 241 | 242 | # TODO: new implementation 243 | input_token_len = input_ids.shape[1] 244 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 245 | if n_diff_input_output > 0: 246 | print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the 
input_ids') 247 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 248 | 249 | if args.conv_mode == 'simple_legacy': 250 | while True: 251 | cur_len = len(outputs) 252 | outputs = outputs.strip() 253 | for pattern in ['###', 'Assistant:', 'Response:']: 254 | if outputs.startswith(pattern): 255 | outputs = outputs[len(pattern):].strip() 256 | if len(outputs) == cur_len: 257 | break 258 | 259 | try: 260 | index = outputs.index(conv.sep) 261 | except ValueError: 262 | outputs += conv.sep 263 | index = outputs.index(conv.sep) 264 | 265 | outputs = outputs[:index].strip() 266 | 267 | # prompt for answer 268 | if args.answer_prompter: 269 | outputs_reasoning = outputs 270 | inputs = tokenizer([prompt + outputs_reasoning + ' ###\nANSWER:']) 271 | 272 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 273 | 274 | keywords = ['###'] 275 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 276 | 277 | with torch.inference_mode(): 278 | output_ids = model.generate( 279 | input_ids, 280 | images=images, 281 | do_sample=True, 282 | temperature=0.7, 283 | max_new_tokens=64, 284 | stopping_criteria=[stopping_criteria]) 285 | 286 | input_token_len = input_ids.shape[1] 287 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 288 | if n_diff_input_output > 0: 289 | print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids') 290 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 291 | 292 | try: 293 | index = outputs.index(conv.sep) 294 | except ValueError: 295 | outputs += conv.sep 296 | index = outputs.index(conv.sep) 297 | 298 | outputs = outputs[:index].strip() 299 | outputs = outputs_reasoning + '\n The answer is ' + outputs 300 | 301 | # new implementation ends 302 | 303 | # original implementation 304 | # outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 305 | # try: 306 | # index = outputs.index(conv.sep, len(prompt)) 307 | # except ValueError: 308 | # outputs += conv.sep 309 | # index = outputs.index(conv.sep, len(prompt)) 310 | 311 | # outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 312 | 313 | 314 | ans_id = shortuuid.uuid() 315 | ans_file.write(json.dumps({"id": idx, 316 | "qa_type": qa_type, 317 | "question": cur_prompt, 318 | "gt_ans": answer, 319 | "response": outputs}) + "\n") 320 | # ans_file.write(json.dumps({"id": idx, 321 | # "prompt": cur_prompt, 322 | # "text": outputs, 323 | # "answer_id": ans_id, 324 | # "model_id": model_name, 325 | # "metadata": {}}) + "\n") 326 | ans_file.flush() 327 | ans_file.close() 328 | 329 | if __name__ == "__main__": 330 | parser = argparse.ArgumentParser() 331 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 332 | parser.add_argument("--image-folder", type=str, default="") 333 | parser.add_argument("--question-file", type=str, default="tables/question.json") 334 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 335 | parser.add_argument("--mm-projector", type=str, default=None) 336 | parser.add_argument("--vision-tower", type=str, default=None) 337 | parser.add_argument("--conv-mode", type=str, default="simple") 338 | parser.add_argument("--num-chunks", type=int, default=1) 339 | parser.add_argument("--chunk-idx", type=int, default=0) 340 | parser.add_argument("--answer-prompter", action="store_true") 341 | args = parser.parse_args() 342 | 343 | eval_model(args) 344 | 
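# Example single-chunk invocation of this script (for reference only; the model
# checkpoint and data paths below are placeholders to replace with your own).
# run_med_datasets_eval_batch.py in this folder builds essentially the same
# command once per GPU chunk:
#
#   CUDA_VISIBLE_DEVICES=0 python llava/eval/model_vqa_med.py \
#       --model-name /path/to/llava-med-checkpoint \
#       --question-file /path/to/probmed.json \
#       --image-folder /path/to/images/ \
#       --answers-file llavamed-chunk0.jsonl \
#       --num-chunks 1 \
#       --chunk-idx 0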
-------------------------------------------------------------------------------- /eval/inference/LLaVA-Med/run_med_datasets_eval_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | from concurrent.futures import ProcessPoolExecutor 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='Parallel LLaVA evaluation script.') 8 | 9 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 10 | parser.add_argument("--image-folder", type=str, default="") 11 | parser.add_argument("--question-file", type=str, default="tables/question.json") 12 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 13 | parser.add_argument("--mm-projector", type=str, default=None) 14 | parser.add_argument("--vision-tower", type=str, default=None) 15 | parser.add_argument("--conv-mode", type=str, default="simple") 16 | parser.add_argument("--answer-prompter", action="store_true") 17 | parser.add_argument('--num-chunks', type=int, default=1, help='Number of chunks (default: 1).') 18 | parser.add_argument("--chunk-idx", type=int, default=0) 19 | args = parser.parse_args() 20 | 21 | return parser.parse_args() 22 | 23 | def run_job(chunk_idx, args): 24 | 25 | cmd = ("CUDA_VISIBLE_DEVICES={chunk_idx} python llava/eval/model_vqa_med.py " 26 | "--model-name {model_name} " 27 | "--question-file {question_file} " 28 | "--image-folder {image_folder} " 29 | "--answers-file {experiment_name_with_split}-chunk{chunk_idx}.jsonl " 30 | "--num-chunks {chunks} " 31 | "--chunk-idx {chunk_idx} ").format( 32 | chunk_idx=chunk_idx, 33 | chunks=args.num_chunks, 34 | model_name=args.model_name, 35 | question_file=args.question_file, 36 | image_folder=args.image_folder, 37 | experiment_name_with_split=args.experiment_name_with_split 38 | ) 39 | 40 | print(cmd) 41 | 42 | subprocess.run(cmd, shell=True, check=True) 43 | 44 | def main(): 45 | args = parse_args() 46 | args.experiment_name_with_split = args.answers_file.split(".jsonl")[0] 47 | 48 | # Create a partial function that accepts only `chunk_idx` 49 | from functools import partial 50 | run_job_with_args = partial(run_job, args=args) 51 | 52 | # Run the jobs in parallel using ProcessPoolExecutor 53 | with ProcessPoolExecutor(max_workers=args.num_chunks) as executor: 54 | list(executor.map(run_job_with_args, range(args.num_chunks))) # Use run_job_with_args instead of lambda 55 | # list(executor.map(run_job_with_args, range(1,4))) # Use run_job_with_args instead of lambda 56 | 57 | # Gather the results 58 | output_file = f"{args.experiment_name_with_split}.jsonl" 59 | with open(output_file, 'w') as outfile: 60 | for idx in range(args.num_chunks): 61 | # for idx in range(1,4): 62 | with open(f"{args.experiment_name_with_split}-chunk{idx}.jsonl") as infile: 63 | outfile.write(infile.read()) 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /eval/inference/LLaVA/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 
12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | # questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | # questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 39 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 40 | # questions = get_chunk(questions, args.num_chunks, args.chunk_idx-1) 41 | answers_file = os.path.expanduser(args.answers_file) 42 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 43 | ans_file = open(answers_file, "w") 44 | for line in tqdm(questions): 45 | idx = line["id"] 46 | qa_type = line["qa_type"] 47 | answer = line["answer"] 48 | qs = line["question"] 49 | image_file = line["image"] 50 | # qs = line["text"] 51 | cur_prompt = qs 52 | if model.config.mm_use_im_start_end: 53 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 54 | else: 55 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 56 | 57 | conv = conv_templates[args.conv_mode].copy() 58 | conv.append_message(conv.roles[0], qs) 59 | conv.append_message(conv.roles[1], None) 60 | prompt = conv.get_prompt() 61 | 62 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 63 | 64 | image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') 65 | image_tensor = process_images([image], image_processor, model.config)[0] 66 | 67 | with torch.inference_mode(): 68 | output_ids = model.generate( 69 | input_ids, 70 | images=image_tensor.unsqueeze(0).half().cuda(), 71 | image_sizes=[image.size], 72 | do_sample=True if args.temperature > 0 else False, 73 | temperature=args.temperature, 74 | top_p=args.top_p, 75 | num_beams=args.num_beams, 76 | # no_repeat_ngram_size=3, 77 | max_new_tokens=1024, 78 | use_cache=True) 79 | 80 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 81 | 82 | ans_id = shortuuid.uuid() 83 | # ans_file.write(json.dumps({"question_id": idx, 84 | # "prompt": cur_prompt, 85 | # "text": outputs, 86 | # "answer_id": ans_id, 87 | # "model_id": model_name, 88 | # "metadata": {}}) + "\n") 89 | ans_file.write(json.dumps({"id": idx, 90 | "qa_type": qa_type, 91 | "question": cur_prompt, 92 | "gt_ans": answer, 93 | "response": outputs}) + "\n") 94 | ans_file.flush() 95 | ans_file.close() 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 100 | parser.add_argument("--model-base", type=str, default=None) 101 | parser.add_argument("--image-folder", type=str, default="") 102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 103 | parser.add_argument("--answers-file", 
type=str, default="answer.jsonl") 104 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 105 | parser.add_argument("--num-chunks", type=int, default=1) 106 | parser.add_argument("--chunk-idx", type=int, default=0) 107 | parser.add_argument("--temperature", type=float, default=0.2) 108 | parser.add_argument("--top_p", type=float, default=None) 109 | parser.add_argument("--num_beams", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | eval_model(args) 113 | -------------------------------------------------------------------------------- /eval/inference/LLaVA/run_eval_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | from concurrent.futures import ProcessPoolExecutor 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='Parallel LLaVA evaluation script.') 8 | 9 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 10 | parser.add_argument("--image-folder", type=str, default="") 11 | parser.add_argument("--question-file", type=str, default="tables/question.json") 12 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 13 | parser.add_argument("--mm-projector", type=str, default=None) 14 | parser.add_argument("--vision-tower", type=str, default=None) 15 | parser.add_argument("--conv-mode", type=str, default="simple") 16 | parser.add_argument("--answer-prompter", action="store_true") 17 | parser.add_argument('--num-chunks', type=int, default=1, help='Number of chunks (default: 1).') 18 | parser.add_argument("--chunk-idx", type=int, default=0) 19 | args = parser.parse_args() 20 | 21 | return parser.parse_args() 22 | 23 | def run_job(chunk_idx, args): 24 | 25 | cmd = ("CUDA_VISIBLE_DEVICES={chunk_idx} python llava/eval/model_vqa.py " 26 | "--model-path {model_name} " 27 | "--question-file {question_file} " 28 | "--image-folder {image_folder} " 29 | "--answers-file {experiment_name_with_split}-chunk{chunk_idx}.jsonl " 30 | "--num-chunks {chunks} " 31 | "--chunk-idx {chunk_idx} ").format( 32 | chunk_idx=chunk_idx, 33 | chunks=args.num_chunks, 34 | model_name=args.model_name, 35 | question_file=args.question_file, 36 | image_folder=args.image_folder, 37 | experiment_name_with_split=args.experiment_name_with_split 38 | ) 39 | 40 | print(cmd) 41 | 42 | subprocess.run(cmd, shell=True, check=True) 43 | 44 | def main(): 45 | args = parse_args() 46 | args.experiment_name_with_split = args.answers_file.split(".jsonl")[0] 47 | 48 | # Create a partial function that accepts only `chunk_idx` 49 | from functools import partial 50 | run_job_with_args = partial(run_job, args=args) 51 | 52 | # Run the jobs in parallel using ProcessPoolExecutor 53 | with ProcessPoolExecutor(max_workers=args.num_chunks) as executor: 54 | list(executor.map(run_job_with_args, range(args.num_chunks))) # Use run_job_with_args instead of lambda 55 | # list(executor.map(run_job_with_args, range(1,4))) # Use run_job_with_args instead of lambda 56 | 57 | # Gather the results 58 | output_file = f"{args.experiment_name_with_split}.jsonl" 59 | with open(output_file, 'w') as outfile: 60 | for idx in range(args.num_chunks): 61 | # for idx in range(1,4): 62 | with open(f"{args.experiment_name_with_split}-chunk{idx}.jsonl") as infile: 63 | outfile.write(infile.read()) 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /eval/inference/MiniGPTv2/eval_minigptv2.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import re 5 | import json 6 | from collections import defaultdict 7 | import math 8 | 9 | import numpy as np 10 | from PIL import Image 11 | from tqdm import tqdm 12 | import torch 13 | from minigpt4.datasets.datasets.vqa_datasets import OKVQAEvalData,VizWizEvalData,IconQAEvalData,GQAEvalData,VSREvalData,HMEvalData 14 | from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA 15 | from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval 16 | 17 | from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser 18 | from minigpt4.conversation.conversation import CONV_VISION_minigptv2 19 | from minigpt4.common.config import Config 20 | 21 | 22 | def split_list(lst, n): 23 | """Split a list into n (roughly) equal-sized chunks""" 24 | chunk_size = math.ceil(len(lst) / n) # ceiling division 25 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 26 | 27 | 28 | def get_chunk(lst, n, k): 29 | chunks = split_list(lst, n) 30 | return chunks[k] 31 | 32 | 33 | def eval_model(args): 34 | # Model 35 | model, vis_processor = init_model(args) 36 | conv_temp = CONV_VISION_minigptv2.copy() 37 | conv_temp.system = "" 38 | model.eval() 39 | 40 | # questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 41 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 42 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 43 | answers_file = os.path.expanduser(args.answers_file) 44 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 45 | ans_file = open(answers_file, "w") 46 | for line in tqdm(questions): 47 | idx = line["id"] 48 | image_file = line["image"] 49 | # add minigptv2 tag 50 | # qs = ['[caption] ' + line["question"]] 51 | qs = ['[vqa] ' + line["question"]] 52 | idx = line["id"] 53 | qa_type = line["qa_type"] 54 | answer = line["answer"] 55 | image = Image.open(args.image_folder + image_file).convert('RGB') 56 | image = vis_processor(image) 57 | texts = prepare_texts(qs, conv_temp) # wrap the texts with the conversation template 58 | with torch.no_grad(): 59 | answers = model.generate(torch.tensor(np.array([image])), texts, max_new_tokens=256, do_sample=False) 60 | 61 | ans_file.write(json.dumps({"id": idx, 62 | "qa_type": qa_type, 63 | "question": qs, 64 | "gt_ans": answer, 65 | "response": answers[0]}) + "\n") 66 | ans_file.flush() 67 | ans_file.close() 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.") 72 | parser.add_argument("--name", type=str, default='A2', help="evaluation name") 73 | parser.add_argument("--ckpt", type=str, help="path to model checkpoint.") 74 | parser.add_argument("--eval_opt", type=str, default='all', help="evaluation option.") 75 | parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens") 76 | parser.add_argument("--batch_size", type=int, default=32) 77 | parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") 78 | parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") 79 | parser.add_argument( 80 | "--options", 81 | nargs="+", 82 | help="override some settings in the used config, the key-value pair " 83 | "in xxx=yyy format will be merged into config file (deprecate), " 84 |
"change to --cfg-options instead.", 85 | ) 86 | parser.add_argument("--image-folder", type=str, default="") 87 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 88 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 89 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 90 | parser.add_argument("--num-chunks", type=int, default=1) 91 | parser.add_argument("--chunk-idx", type=int, default=0) 92 | parser.add_argument("--gpu-id", type=int, default=0) 93 | parser.add_argument("--temperature", type=float, default=0.2) 94 | parser.add_argument("--top_p", type=float, default=None) 95 | parser.add_argument("--num_beams", type=int, default=1) 96 | args = parser.parse_args() 97 | 98 | eval_model(args) 99 | -------------------------------------------------------------------------------- /eval/inference/MiniGPTv2/run_eval_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | from concurrent.futures import ProcessPoolExecutor 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='Parallel Minigptv2 evaluation script.') 8 | 9 | parser.add_argument("--cfg-path", type=str, default="eval_configs/minigptv2_eval.yaml") 10 | parser.add_argument("--image-folder", type=str, default="") 11 | parser.add_argument("--question-file", type=str, default="tables/question.json") 12 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 13 | parser.add_argument("--mm-projector", type=str, default=None) 14 | parser.add_argument("--vision-tower", type=str, default=None) 15 | parser.add_argument("--conv-mode", type=str, default="simple") 16 | parser.add_argument("--answer-prompter", action="store_true") 17 | parser.add_argument('--num-chunks', type=int, default=1, help='Number of chunks (default: 1).') 18 | parser.add_argument("--chunk-idx", type=int, default=0) 19 | args = parser.parse_args() 20 | 21 | return parser.parse_args() 22 | 23 | def run_job(chunk_idx, args): 24 | 25 | cmd = ("CUDA_VISIBLE_DEVICES={chunk_idx} python eval_minigptv2.py " 26 | "--cfg-path {cfg_path} " 27 | "--question-file {question_file} " 28 | "--image-folder {image_folder} " 29 | "--answers-file {experiment_name_with_split}-chunk{chunk_idx}.jsonl " 30 | "--num-chunks {chunks} " 31 | "--chunk-idx {chunk_idx} " 32 | "--gpu-id {gpu_id} ").format( 33 | cfg_path=args.cfg_path, 34 | gpu_id=chunk_idx, 35 | chunk_idx=chunk_idx, 36 | chunks=args.num_chunks, 37 | question_file=args.question_file, 38 | image_folder=args.image_folder, 39 | experiment_name_with_split=args.experiment_name_with_split 40 | ) 41 | 42 | print(cmd) 43 | 44 | subprocess.run(cmd, shell=True, check=True) 45 | 46 | def main(): 47 | args = parse_args() 48 | args.experiment_name_with_split = args.answers_file.split(".jsonl")[0] 49 | 50 | # Create a partial function that accepts only `chunk_idx` 51 | from functools import partial 52 | run_job_with_args = partial(run_job, args=args) 53 | 54 | # Run the jobs in parallel using ProcessPoolExecutor 55 | with ProcessPoolExecutor(max_workers=args.num_chunks) as executor: 56 | list(executor.map(run_job_with_args, range(args.num_chunks))) # Use run_job_with_args instead of lambda 57 | # list(executor.map(run_job_with_args, range(1,4))) # Use run_job_with_args instead of lambda 58 | 59 | # Gather the results 60 | output_file = f"{args.experiment_name_with_split}.jsonl" 61 | with open(output_file, 'w') as outfile: 62 | for idx in 
range(args.num_chunks): 63 | # for idx in range(1,4): 64 | with open(f"{args.experiment_name_with_split}-chunk{idx}.jsonl") as infile: 65 | outfile.write(infile.read()) 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /eval/model_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | model_name=$1 5 | 6 | if [ "${model_name}" == "llavamed" ]; then 7 | cd ../LLaVA-Med 8 | elif [ "${model_name}" == "llava_v1" ]; then 9 | cd ../LLaVA 10 | elif [ "${model_name}" == "llava_v1.6" ]; then 11 | cd ../LLaVA 12 | elif [ "${model_name}" == "minigptv2" ]; then 13 | cd ../MiniGPT-4 14 | elif [ "${model_name}" == "chexagent" ]; then 15 | cd ../CheXagent 16 | elif [ "${model_name}" == "gpt4v" ]; then 17 | cd ../gpt4V 18 | elif [ "${model_name}" == "gemini" ]; then 19 | cd ../gemini 20 | elif [ "${model_name}" == "gpt4o" ]; then 21 | cd ../gpt4V 22 | elif [ "${model_name}" == "med-flamingo" ]; then 23 | cd ../med-flamingo 24 | elif [ "${model_name}" == "biomedgpt" ]; then 25 | cd ../BiomedGPT 26 | fi 27 | 28 | echo "==========================================" 29 | 30 | # inference for probmed results 31 | question_file="path to question file" # */probmed.json 32 | answer_file="./response_file/${model_name}" 33 | answer_file_json="./response_file/${model_name}.json" 34 | 35 | # uncomment the following block if you are running inference for ablation study 36 | # question_file="/data3/qianqi/medHVL/vqa/ablation/ablation_question.json" 37 | # answer_file="/data3/qianqi/medHVL/vqa/ablation/${model_name}" 38 | # answer_file_json="/data3/qianqi/medHVL/vqa/ablation/${model_name}.json" 39 | 40 | if [ "${model_name}" == "llavamed" ]; then 41 | python llava/eval/run_med_datasets_eval_batch.py \ 42 | --num-chunks 4 \ 43 | --model-name /model_weights/llavamed/llava_med_in_text_60k \ 44 | --question-file ${question_file} \ 45 | --answers-file ${answer_file} 46 | 47 | rm ${answer_file}-* 48 | 49 | elif [ "${model_name}" == "llava_v1" ]; then 50 | python llava/eval/run_eval_batch.py \ 51 | --num-chunks 4 \ 52 | --model-name /model_weights/llava/llava_v1 \ 53 | --image-folder ${image_folder} \ 54 | --question-file ${question_file} \ 55 | --answers-file ${answer_file} 56 | 57 | rm ${answer_file}-* 58 | 59 | elif [ "${model_name}" == "llava_v1.6" ]; then 60 | python llava/eval/run_eval_batch.py \ 61 | --num-chunks 4 \ 62 | --model-name /model_weights/llava/llava-v1.6-vicuna-7b \ 63 | --image-folder ${image_folder} \ 64 | --question-file ${question_file} \ 65 | --answers-file ${answer_file} 66 | 67 | rm ${answer_file}-* 68 | 69 | elif [ "${model_name}" == "minigptv2" ]; then 70 | python run_eval_batch.py \ 71 | --num-chunks 4 \ 72 | --cfg-path eval_configs/minigptv2_eval.yaml \ 73 | --image-folder ${image_folder} \ 74 | --question-file ${question_file} \ 75 | --answers-file ${answer_file} 76 | 77 | rm ${answer_file}-* 78 | 79 | elif [ "${model_name}" == "chexagent" ]; then 80 | python run_eval_batch.py \ 81 | --num-chunks 4 \ 82 | --image-folder ${image_folder} \ 83 | --question-file ${question_file} \ 84 | --answers-file ${answer_file} 85 | 86 | rm ${answer_file}-* 87 | 88 | elif [ "${model_name}" == "gpt4v" ]; then 89 | python gpt4v.py \ 90 | --image-folder ${image_folder} \ 91 | --question-file ${question_file} \ 92 | --answers-file ${answer_file_json} 93 | 94 | elif [ "${model_name}" == "gemini" ]; then 95 | python run_eval_batch.py \ 96 | --num-chunks 4 \ 97 |
--image-folder ${image_folder} \ 98 | --question-file ${question_file} \ 99 | --answers-file ${answer_file} 100 | 101 | rm ${answer_file}-* 102 | rm ${answer_file}_* 103 | 104 | elif [ "${model_name}" == "gpt4o" ]; then 105 | python gpt4v.py \ 106 | --image-folder ${image_folder} \ 107 | --question-file ${question_file} \ 108 | --answers-file ${answer_file_json} 109 | 110 | elif [ "${model_name}" == "med-flamingo" ]; then 111 | python scripts/run_eval_batch.py \ 112 | --num-chunks 3 \ 113 | --question-file ${question_file} \ 114 | --answers-file ${answer_file} 115 | 116 | rm ${answer_file}-* 117 | 118 | elif [ "${model_name}" == "biomedgpt" ]; then 119 | python evaluate.py \ 120 | ablation.tsv \ 121 | --path /model_weights/biomedgpt_base.pt \ 122 | --user-dir module \ 123 | --task vqa_gen \ 124 | --batch-size 64 \ 125 | --log-format simple --log-interval 10 \ 126 | --seed 7 \ 127 | --gen-subset ablation \ 128 | --results-path ../ablation \ 129 | --fp16 \ 130 | --beam-search-vqa-eval \ 131 | --ema-eval \ 132 | --unnormalized \ 133 | --temperature 1.0 \ 134 | --num-workers 0 \ 135 | --model-overrides "{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}" 136 | 137 | fi 138 | 139 | echo "==========================================" 140 | -------------------------------------------------------------------------------- /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/ProbMed/0268b5d7e3af795ba0b30c3710c0c44e4f90158c/image.png --------------------------------------------------------------------------------