├── LICENSE.txt ├── OCRBench ├── OCRBench │ ├── FullTest.json │ └── OCRBench.json ├── README.md ├── example.py ├── images │ ├── GPT4V_Gemini.png │ └── all_data.png └── scripts │ ├── GPT4V.py │ ├── Genimi.py │ ├── LLaVA1_5.py │ ├── MiniMonkey.py │ ├── blip2.py │ ├── blip2_vicuna_instruct.py │ ├── bliva.py │ ├── interlm.py │ ├── interlm2.py │ ├── internvl2_s │ ├── intervl.py │ ├── llavar.py │ ├── mPLUG-DocOwl15.py │ ├── mPLUG-owl.py │ ├── mPLUG-owl2.py │ ├── minigpt4v2.py │ ├── monkey.py │ ├── qwenvl.py │ └── qwenvl_api.py ├── OCRBench_v2 ├── README.md ├── eval_scripts │ ├── IoUscore_metric.py │ ├── TEDS_metric.py │ ├── __pycache__ │ │ ├── IoUscore_metric.cpython-310.pyc │ │ ├── TEDS_metric.cpython-310.pyc │ │ ├── page_ocr_metric.cpython-310.pyc │ │ ├── parallel.cpython-310.pyc │ │ ├── spotting_metric.cpython-310.pyc │ │ └── vqa_metric.cpython-310.pyc │ ├── eval.py │ ├── get_score.py │ ├── page_ocr_metric.py │ ├── parallel.py │ ├── spotting_eval │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── rrc_evaluation_funcs_1_1.cpython-310.pyc │ │ │ ├── rrc_evaluation_funcs_1_1.cpython-39.pyc │ │ │ ├── script.cpython-310.pyc │ │ │ └── script.cpython-39.pyc │ │ ├── gt.zip │ │ ├── gt │ │ │ └── gt_img_0.txt │ │ ├── readme.txt │ │ ├── results.zip │ │ ├── rrc_evaluation_funcs_1_1.py │ │ ├── script.py │ │ ├── script_test_ch4_t4_e1-1577983164.zip │ │ ├── submit.zip │ │ └── submit │ │ │ └── res_img_0.txt │ ├── spotting_metric.py │ └── vqa_metric.py ├── pred_folder │ └── internvl2_5_26b.json └── requirements.txt └── README.md /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yuliang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /OCRBench/README.md: -------------------------------------------------------------------------------- 1 | # OCRBench: On the Hidden Mystery of OCR in Large Multimodal Models 2 | 3 | 4 | > Large models have recently played a dominant role in natural language processing and multimodal vision-language learning. However, their effectiveness in text-related visual tasks remains relatively unexplored. 
In this paper, we conducted a comprehensive evaluation of Large Multimodal Models, such as GPT4V and Gemini, in various text-related visual tasks including Text Recognition, Scene Text-Centric Visual Question Answering (VQA), Document-Oriented VQA, Key Information Extraction (KIE), and Handwritten Mathematical Expression Recognition (HMER). To facilitate the assessment of Optical Character Recognition (OCR) capabilities in Large Multimodal Models, we propose OCRBench, a comprehensive evaluation benchmark. Our study encompasses 29 datasets, making it the most comprehensive OCR evaluation benchmark available. Furthermore, our study reveals both the strengths and weaknesses of these models, particularly in handling multilingual text, handwritten text, non-semantic text, and mathematical expression recognition. Most importantly, the baseline results showcased in this study could provide a foundational framework for the conception and assessment of innovative strategies targeted at enhancing zero-shot multimodal techniques. 5 | 6 | **[Project Page [This Page]](https://github.com/Yuliang-Liu/MultimodalOCR)** | **[Paper](https://arxiv.org/abs/2305.07895)** | **[OCRBench Leaderboard](https://huggingface.co/spaces/echo840/ocrbench-leaderboard)** | **[Opencompass Leaderboard](https://rank.opencompass.org.cn/leaderboard-multimodal)** | 7 | 8 | 9 | # Data 10 | To reduce false positives, we filter out, from all datasets, questions whose answers contain fewer than 4 symbols. 11 | | Data | Link | Description | 12 | | --- | --- | --- | 13 | | Full Test Json | [Full Test](./OCRBench/FullTest.json) | This file contains the test data used in Table 1 and Table 2 of the [Paper](https://arxiv.org/abs/2305.07895). | 14 | | OCRBench Json | [OCRBench](./OCRBench/OCRBench.json) | This file contains the OCRBench test data used in Table 3 of the [Paper](https://arxiv.org/abs/2305.07895). | 15 | | All Test Images | [All Images](https://drive.google.com/file/d/1U5AtLoJ7FrJe9yfcbssfeLmlKb7dTosc/view?usp=drive_link) | This file contains all the testing images used in the [Paper](https://arxiv.org/abs/2305.07895), including the OCRBench images. | 16 | | OCRBench Images | [OCRBench Images](https://drive.google.com/file/d/1a3VRJx3V3SdOmPr7499Ky0Ug8AwqGUHO/view?usp=drive_link) | This file only contains the images used in OCRBench. | 17 | | Test Results | [Test Results](https://drive.google.com/drive/folders/15XlHCuNTavI1Ihqm4G7u3J34BHpkaqyE?usp=drive_link) | This file contains the result files for the tested models. | 18 | 19 | 20 | # OCRBench 21 | 22 | OCRBench is a comprehensive evaluation benchmark designed to assess the OCR capabilities of Large Multimodal Models. It comprises five components: Text Recognition, Scene Text-Centric VQA, Document-Oriented VQA, Key Information Extraction, and Handwritten Mathematical Expression Recognition. The benchmark includes 1000 question-answer pairs, and all the answers undergo manual verification and correction to ensure a more precise evaluation. 23 | 24 | You can find the results of Large Multimodal Models in the **[OCRBench Leaderboard](https://huggingface.co/spaces/echo840/ocrbench-leaderboard)**. If you would like to include your model in the OCRBench leaderboard, please follow the evaluation instructions provided below and feel free to contact us via email at zhangli123@hust.edu.cn; we will update the leaderboard promptly. 25 | 26 | 27 | 28 | # Evaluation 29 | The test code for evaluating models in the paper can be found in [scripts](./scripts). 
Before conducting the evaluation, you need to configure the model weights and environment based on the official code links provided in the scripts. If you want to evaluate other models, please edit the sections marked "TODO" in [example](./example.py). 30 | 31 | You can also use [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for evaluation. 32 | 33 | Example evaluation scripts: 34 | ```bash 35 | 36 | python ./scripts/monkey.py --image_folder ./OCRBench_Images --OCRBench_file ./OCRBench/OCRBench.json --save_name Monkey_OCRBench --num_workers GPU_Nums # Test on OCRBench 37 | python ./scripts/monkey.py --image_folder ./OCRBench_Images --OCRBench_file ./OCRBench/FullTest.json --save_name Monkey_FullTest --num_workers GPU_Nums # Full Test 38 | 39 | ``` 40 | 41 | # Citation 42 | If you wish to refer to the baseline results published here, please use the following BibTeX entry: 43 | ```BibTeX 44 | @article{Liu_2024, 45 | title={OCRBench: on the hidden mystery of OCR in large multimodal models}, 46 | volume={67}, 47 | ISSN={1869-1919}, 48 | url={http://dx.doi.org/10.1007/s11432-024-4235-6}, 49 | DOI={10.1007/s11432-024-4235-6}, 50 | number={12}, 51 | journal={Science China Information Sciences}, 52 | publisher={Springer Science and Business Media LLC}, 53 | author={Liu, Yuliang and Li, Zhang and Huang, Mingxin and Yang, Biao and Yu, Wenwen and Li, Chunyuan and Yin, Xu-Cheng and Liu, Cheng-Lin and Jin, Lianwen and Bai, Xiang}, 54 | year={2024}, 55 | month=dec } 56 | ``` 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /OCRBench/example.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | # TODO model packages import 13 | # from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | def split_list(lst, n): 16 | length = len(lst) 17 | avg = length // n # size of each split 18 | result = [] # holds the resulting sublists 19 | for i in range(n - 1): 20 | result.append(lst[i*avg:(i+1)*avg]) 21 | result.append(lst[(n-1)*avg:]) 22 | return result 23 | 24 | def save_json(json_list,save_path): 25 | with open(save_path, 'w') as file: 26 | json.dump(json_list, file,indent=4) 27 | 28 | def _get_args(): 29 | parser = ArgumentParser() 30 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 31 | parser.add_argument("--output_folder", type=str, default="./results") 32 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 33 | parser.add_argument("--model_path", type=str, default="")#TODO Set the path to your model's weights 34 | parser.add_argument("--save_name", type=str, default="") #TODO Set the name of the JSON file to save in the output_folder. 
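# Note: --num_workers also acts as the GPU count; worker i is expected to place its model and inputs on cuda:i (see the commented generation template in eval_worker below).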
35 | parser.add_argument("--num_workers", type=int, default=8) 36 | args = parser.parse_args() 37 | return args 38 | 39 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 40 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 41 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 42 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 43 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 44 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 45 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 46 | 47 | def eval_worker(args, data, eval_id, output_queue): 48 | print(f"Process {eval_id} start.") 49 | checkpoint = args.model_path 50 | 51 | # TODO model init 52 | 53 | # model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval() 54 | # tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) 55 | # tokenizer.padding_side = 'left' 56 | # tokenizer.pad_token_id = tokenizer.eod_id 57 | 58 | for i in tqdm(range(len(data))): 59 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 60 | qs = data[i]['question'] 61 | 62 | # TODO Generation process 63 | # query = f'{img_path} {qs} Answer: ' 64 | 65 | # input_ids = tokenizer(query, return_tensors='pt', padding='longest') 66 | # attention_mask = input_ids.attention_mask 67 | # input_ids = input_ids.input_ids 68 | 69 | # pred = model.generate( 70 | # input_ids=input_ids.to(f'cuda:{eval_id}'), 71 | # attention_mask=attention_mask.to(f'cuda:{eval_id}'), 72 | # do_sample=False, 73 | # num_beams=1, 74 | # max_new_tokens=100, 75 | # min_new_tokens=1, 76 | # length_penalty=1, 77 | # num_return_sequences=1, 78 | # output_hidden_states=True, 79 | # use_cache=True, 80 | # pad_token_id=tokenizer.eod_id, 81 | # eos_token_id=tokenizer.eod_id, 82 | # ) 83 | # response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip() 84 | data[i]['predict'] = response 85 | output_queue.put({eval_id: data}) 86 | print(f"Process {eval_id} has completed.") 87 | 88 | if __name__=="__main__": 89 | multiprocessing.set_start_method('spawn') 90 | args = _get_args() 91 | 92 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 93 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 94 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 95 | else: 96 | data_path = args.OCRBench_file 97 | 98 | with open(data_path, "r") as f: 99 | data = json.load(f) 100 | 101 | data_list = split_list(data, args.num_workers) 102 | 103 | output_queue = Manager().Queue() 104 | 105 | pool = Pool(processes=args.num_workers) 106 | for i in range(len(data_list)): 107 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 108 | pool.close() 109 | pool.join() 110 | 111 | 112 | results = {} 113 | while not output_queue.empty(): 114 | result = output_queue.get() 115 | results.update(result) 116 | data = [] 117 | for i in range(len(data_list)): 118 | data.extend(results[i]) 119 | 120 | for i in range(len(data)): 121 | data_type = data[i]["type"] 122 | dataset_name = data[i]["dataset_name"] 123 | answers = data[i]["answers"] 124 | if data[i].get('predict',0)==0: 125 | continue 126 | predict = data[i]['predict'] 127 | data[i]['result'] = 0 128 | if dataset_name == "HME100k": 129 | if type(answers)==list: 130 | for j in range(len(answers)): 131 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 132 | predict = predict.strip().replace("\n"," ").replace(" ","") 133 | if answer in predict: 134 | data[i]['result'] = 1 135 | else: 136 | answers = answers.strip().replace("\n"," ").replace(" ","") 137 | predict = predict.strip().replace("\n"," ").replace(" ","") 138 | if answers in predict: 139 | data[i]['result'] = 1 140 | else: 141 | if type(answers)==list: 142 | for j in range(len(answers)): 143 | answer = answers[j].lower().strip().replace("\n"," ") 144 | predict = predict.lower().strip().replace("\n"," ") 145 | if answer in predict: 146 | data[i]['result'] = 1 147 | else: 148 | answers = answers.lower().strip().replace("\n"," ") 149 | predict = predict.lower().strip().replace("\n"," ") 150 | if answers in predict: 151 | data[i]['result'] = 1 152 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 153 | if len(data)==1000: 154 | for i in range(len(data)): 155 | if data[i].get("result",100)==100: 156 | continue 157 | OCRBench_score[data[i]['type']] += data[i]['result'] 158 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 159 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 160 | print("###########################OCRBench##############################") 161 | print(f"Text Recognition(Total 300):{recognition_score}") 162 | print("------------------Details of Recognition Score-------------------") 163 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 164 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 165 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 166 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 167 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 168 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 169 | 
print("----------------------------------------------------------------") 170 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 171 | print("----------------------------------------------------------------") 172 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 173 | print("----------------------------------------------------------------") 174 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 175 | print("----------------------------------------------------------------") 176 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 177 | print("----------------------Final Score-------------------------------") 178 | print(f"Final Score(Total 1000): {Final_score}") 179 | else: 180 | for i in range(len(data)): 181 | num_all[data[i]['dataset_name']] += 1 182 | if data[i].get("result",100)==100: 183 | continue 184 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 185 | for key in AllDataset_score.keys(): 186 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 187 | -------------------------------------------------------------------------------- /OCRBench/images/GPT4V_Gemini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench/images/GPT4V_Gemini.png -------------------------------------------------------------------------------- /OCRBench/images/all_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench/images/all_data.png -------------------------------------------------------------------------------- /OCRBench/scripts/GPT4V.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import requests 3 | from tqdm import tqdm 4 | import json 5 | from PIL import Image 6 | import random 7 | import time 8 | import pathlib 9 | import textwrap 10 | from argparse import ArgumentParser 11 | import google.generativeai as genai 12 | import json 13 | from PIL import Image 14 | from IPython.display import display 15 | from IPython.display import Markdown 16 | from tqdm import tqdm 17 | import os 18 | def encode_image(image_path): 19 | with open(image_path, "rb") as image_file: 20 | return base64.b64encode(image_file.read()).decode('utf-8') 21 | def save_json(json_list,save_path): 22 | with open(save_path, 'w') as file: 23 | json.dump(json_list, file,indent=4) 24 | def _get_args(): 25 | parser = ArgumentParser() 26 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 27 | parser.add_argument("--output_path", type=str, default="./results") 28 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 29 | parser.add_argument("--OPENAI_API_KEY", type=str, default="") 30 | parser.add_argument("--API_BASE", type=str, default="") 31 | parser.add_argument("--model", type=str, default="gpt-4-vision-preview") 32 | args = parser.parse_args() 33 | return args 34 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 35 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented 
VQA":0,"Doc-oriented VQA":0, 36 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 37 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 38 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 39 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 40 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 41 | 42 | if __name__ == "__main__": 43 | args = _get_args() 44 | 45 | if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")): 46 | data_path = os.path.join(args.output_path,f"{args.model}.json") 47 | else: 48 | data_path = args.OCRBench_file 49 | 50 | with open(data_path,"r") as f: 51 | data = json.load(f) 52 | for i in tqdm(range(len(data))): 53 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 54 | question = data[i]['question'] 55 | if data[i].get("predict", 0)!=0: 56 | print(f"{img_path} predict exist, continue.") 57 | continue 58 | base64_image = encode_image(img_path) 59 | headers = { 60 | "Content-Type": "application/json", 61 | "Authorization": f"Bearer {args.OPENAI_API_KEY}" 62 | } 63 | payload = { 64 | "model": args.model, 65 | "messages": [ 66 | { 67 | "role": "user", 68 | "content": [ 69 | { 70 | "type": "text", 71 | "text": f"{question}" 72 | }, 73 | { 74 | "type": "image_url", 75 | "image_url": { 76 | "url": f"data:image/jpeg;base64,{base64_image}" 77 | } 78 | } 79 | ] 80 | } 81 | ], 82 | "max_tokens": 500 83 | } 84 | try: 85 | response = requests.post(args.API_BASE, headers=headers, json=payload) 86 | print(response.json()) 87 | answer = response.json()['choices'][0]['message']['content'] 88 | data[i]['predict'] = answer 89 | save_json(data, os.path.join(args.output_path,f"{args.model}.json")) 90 | except: 91 | time.sleep(100) 92 | print(f"{img_path} error") 93 | for i in range(len(data)): 94 | data_type = data[i]["type"] 95 | dataset_name = data[i]["dataset_name"] 96 | answers = data[i]["answers"] 97 | if data[i].get('predict',0)==0: 98 | continue 99 | predict = data[i]['predict'] 100 | data[i]['result'] = 0 101 | if dataset_name == "HME100k": 102 | if type(answers)==list: 103 | for j in range(len(answers)): 104 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 105 | predict = predict.strip().replace("\n"," ").replace(" ","") 106 | if answer in predict: 107 | data[i]['result'] = 1 108 | else: 109 | answers = answers.strip().replace("\n"," ").replace(" ","") 110 | predict = predict.strip().replace("\n"," ").replace(" ","") 111 | if answers in predict: 112 | data[i]['result'] = 1 113 | else: 114 | if type(answers)==list: 115 | for j in range(len(answers)): 116 | answer = answers[j].lower().strip().replace("\n"," ") 117 | predict = predict.lower().strip().replace("\n"," ") 118 | if answer in predict: 119 | data[i]['result'] = 1 120 | else: 121 | answers = answers.lower().strip().replace("\n"," ") 122 | predict = predict.lower().strip().replace("\n"," ") 123 | if answers in predict: 124 | data[i]['result'] = 1 125 | save_json(data, 
os.path.join(args.output_path,f"{args.model}.json")) 126 | for i in range(len(data)): 127 | if data[i].get("result",100)==100: 128 | continue 129 | OCRBench_score[data[i]['type']] += data[i]['result'] 130 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 131 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 132 | print("###########################OCRBench##############################") 133 | print(f"Text Recognition(Total 300):{recognition_score}") 134 | print("------------------Details of Recognition Score-------------------") 135 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 136 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 137 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 138 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 139 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 140 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 141 | print("----------------------------------------------------------------") 142 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 143 | print("----------------------------------------------------------------") 144 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 145 | print("----------------------------------------------------------------") 146 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 147 | print("----------------------------------------------------------------") 148 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 149 | print("----------------------Final Score-------------------------------") 150 | print(f"Final Score(Total 1000): {Final_score}") -------------------------------------------------------------------------------- /OCRBench/scripts/Genimi.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import textwrap 3 | from argparse import ArgumentParser 4 | import google.generativeai as genai 5 | import json 6 | from PIL import Image 7 | from IPython.display import display 8 | from IPython.display import Markdown 9 | from tqdm import tqdm 10 | import os 11 | import sys 12 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 13 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 14 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 15 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 16 | 
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 17 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 18 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 19 | def save_json(json_list,save_path): 20 | with open(save_path, 'w') as file: 21 | json.dump(json_list, file,indent=4) 22 | def _get_args(): 23 | parser = ArgumentParser() 24 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 25 | parser.add_argument("--output_path", type=str, default="./results") 26 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 27 | parser.add_argument("--GOOGLE_API_KEY", type=str, default="") 28 | parser.add_argument("--model", type=str, default="gemini-pro-vision") 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | if __name__ == "__main__": 34 | args = _get_args() 35 | genai.configure(api_key=args.GOOGLE_API_KEY) 36 | model = genai.GenerativeModel(args.model) 37 | 38 | if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")): 39 | data_path = os.path.join(args.output_path,f"{args.model}.json") 40 | else: 41 | data_path = args.OCRBench_file 42 | 43 | with open(data_path, "r") as f: 44 | data = json.load(f) 45 | for i in tqdm(range(len(data))): 46 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 47 | question = data[i]['question'] 48 | if data[i].get("predict", 0)!=0: 49 | print(f"{img_path} predict exist, continue.") 50 | continue 51 | try: 52 | img = Image.open(img_path).convert("RGB") 53 | response = model.generate_content([question, img]) 54 | data[i]['predict'] = response.text 55 | save_json(data, os.path.join(args.output_path,f"{args.model}.json")) 56 | except: 57 | print(f"{img_path}: API call failed.") 58 | for i in range(len(data)): 59 | data_type = data[i]["type"] 60 | dataset_name = data[i]["dataset_name"] 61 | answers = data[i]["answers"] 62 | if data[i].get('predict',0)==0: 63 | continue 64 | predict = data[i]['predict'] 65 | data[i]['result'] = 0 66 | if dataset_name == "HME100k": 67 | if type(answers)==list: 68 | for j in range(len(answers)): 69 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 70 | predict = predict.strip().replace("\n"," ").replace(" ","") 71 | if answer in predict: 72 | data[i]['result'] = 1 73 | else: 74 | answers = answers.strip().replace("\n"," ").replace(" ","") 75 | predict = predict.strip().replace("\n"," ").replace(" ","") 76 | if answers in predict: 77 | data[i]['result'] = 1 78 | else: 79 | if type(answers)==list: 80 | for j in range(len(answers)): 81 | answer = answers[j].lower().strip().replace("\n"," ") 82 | predict = predict.lower().strip().replace("\n"," ") 83 | if answer in predict: 84 | data[i]['result'] = 1 85 | else: 86 | answers = answers.lower().strip().replace("\n"," ") 87 | predict = predict.lower().strip().replace("\n"," ") 88 | if answers in predict: 89 | data[i]['result'] = 1 90 | save_json(data, os.path.join(args.output_path,f"{args.model}.json")) 91 | for i in range(len(data)): 92 | if data[i].get("result",100)==100: 93 | continue 94 | OCRBench_score[data[i]['type']] += data[i]['result'] 95 | recognition_score = OCRBench_score['Regular Text 
Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 96 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 97 | print("###########################OCRBench##############################") 98 | print(f"Text Recognition(Total 300):{recognition_score}") 99 | print("------------------Details of Recognition Score-------------------") 100 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 101 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 102 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 103 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 104 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 105 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 106 | print("----------------------------------------------------------------") 107 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 108 | print("----------------------------------------------------------------") 109 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 110 | print("----------------------------------------------------------------") 111 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 112 | print("----------------------------------------------------------------") 113 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 114 | print("----------------------Final Score-------------------------------") 115 | print(f"Final Score(Total 1000): {Final_score}") -------------------------------------------------------------------------------- /OCRBench/scripts/LLaVA1_5.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 13 | from llava.conversation import conv_templates, SeparatorStyle 14 | from llava.model.builder import load_pretrained_model 15 | from llava.utils import disable_torch_init 16 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 17 | 18 | # https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa_loader.py 19 | 20 | def split_list(lst, n): 21 | length = len(lst) 22 | avg = length // n # 每份的大小 23 | result = [] # 存储分割后的子列表 24 | for i in range(n - 1): 25 | result.append(lst[i*avg:(i+1)*avg]) 26 | result.append(lst[(n-1)*avg:]) 27 | return result 28 | 29 | def save_json(json_list,save_path): 30 | with open(save_path, 'w') as file: 31 | json.dump(json_list, file,indent=4) 32 | 33 | def _get_args(): 34 | parser = ArgumentParser() 35 | parser.add_argument("--image_folder", 
type=str, default="./OCRBench_Images") 36 | parser.add_argument("--output_folder", type=str, default="./results") 37 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 38 | parser.add_argument("--model_path", type=str, default="liuhaotian/llava-v1.5-7b") 39 | parser.add_argument("--model_base", type=str, default=None) 40 | parser.add_argument("--save_name", type=str, default="llava1_5_7b") 41 | parser.add_argument("--conv_mode", type=str, default="vicuna_v1") 42 | parser.add_argument("--num_workers", type=int, default=8) 43 | parser.add_argument("--temperature", type=float, default=0.0) 44 | parser.add_argument("--top_p", type=float, default=None) 45 | parser.add_argument("--num_beams", type=int, default=1) 46 | args = parser.parse_args() 47 | return args 48 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 49 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 50 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 51 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 52 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 53 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 54 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 55 | 56 | def eval_worker(args, data, eval_id, output_queue): 57 | print(f"Process {eval_id} start.") 58 | device = f"cuda:{eval_id}" 59 | disable_torch_init() 60 | model_path = os.path.expanduser(args.model_path) 61 | model_name = get_model_name_from_path(model_path) 62 | tokenizer, model, image_processor, context_len = load_pretrained_model( model_path = model_path, model_base = args.model_base, model_name = model_name,device = device) 63 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: 64 | args.conv_mode = args.conv_mode + '_mmtag' 65 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') 66 | for i in tqdm(range(len(data))): 67 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 68 | qs = data[i]['question'] 69 | qs = qs+"\nAnswer the question using a single word or phrase." 
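# Editor's note: the block below prepends LLaVA's image placeholder to the question (wrapping DEFAULT_IMAGE_TOKEN
# in the start/end image tokens when model.config.mm_use_im_start_end is set, otherwise using DEFAULT_IMAGE_TOKEN alone),
# then renders the wrapped question through the selected conversation template to build the final prompt.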
70 | if model.config.mm_use_im_start_end: 71 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 72 | else: 73 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 74 | conv = conv_templates[args.conv_mode].copy() 75 | conv.append_message(conv.roles[0], qs) 76 | conv.append_message(conv.roles[1], None) 77 | prompt = conv.get_prompt() 78 | 79 | image = Image.open(img_path).convert('RGB') 80 | image_tensor = process_images([image], image_processor, model.config) 81 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0) 82 | if data[i].get("predict", 0)!=0: 83 | print(f"{img_path} predict exist, continue.") 84 | continue 85 | 86 | stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2 87 | input_ids = input_ids.to(device=device, non_blocking=True) 88 | with torch.inference_mode(): 89 | output_ids = model.generate( 90 | input_ids, 91 | images=image_tensor.to(dtype=torch.float16, device=device, non_blocking=True), 92 | do_sample=True if args.temperature > 0 else False, 93 | temperature=args.temperature, 94 | top_p=args.top_p, 95 | num_beams=args.num_beams, 96 | max_new_tokens=128, 97 | use_cache=True) 98 | 99 | input_token_len = input_ids.shape[1] 100 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 101 | if n_diff_input_output > 0: 102 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 103 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 104 | outputs = outputs.strip() 105 | if outputs.endswith(stop_str): 106 | outputs = outputs[:-len(stop_str)] 107 | outputs = outputs.strip() 108 | 109 | data[i]['predict'] = outputs 110 | output_queue.put({eval_id: data}) 111 | print(f"Process {eval_id} has completed.") 112 | 113 | if __name__=="__main__": 114 | multiprocessing.set_start_method('spawn') 115 | args = _get_args() 116 | 117 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 118 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 119 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 120 | else: 121 | data_path = args.OCRBench_file 122 | 123 | with open(data_path, "r") as f: 124 | data = json.load(f) 125 | 126 | data_list = split_list(data, args.num_workers) 127 | output_queue = Manager().Queue() 128 | 129 | pool = Pool(processes=args.num_workers) 130 | for i in range(len(data_list)): 131 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 132 | pool.close() 133 | pool.join() 134 | 135 | results = {} 136 | while not output_queue.empty(): 137 | result = output_queue.get() 138 | results.update(result) 139 | data = [] 140 | for i in range(len(data_list)): 141 | data.extend(results[i]) 142 | 143 | 144 | for i in range(len(data)): 145 | data_type = data[i]["type"] 146 | dataset_name = data[i]["dataset_name"] 147 | answers = data[i]["answers"] 148 | if data[i].get('predict',0)==0: 149 | continue 150 | predict = data[i]['predict'] 151 | data[i]['result'] = 0 152 | if dataset_name == "HME100k": 153 | if type(answers)==list: 154 | for j in range(len(answers)): 155 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 156 | predict = predict.strip().replace("\n"," ").replace(" ","") 157 | if answer in predict: 158 | data[i]['result'] = 1 159 | else: 160 | answers = answers.strip().replace("\n"," ").replace(" ","") 161 | predict = predict.strip().replace("\n"," ").replace(" ","") 162 | if answers in predict: 163 | data[i]['result'] = 1 164 | else: 165 | if type(answers)==list: 166 | for j in range(len(answers)): 167 | answer = answers[j].lower().strip().replace("\n"," ") 168 | predict = predict.lower().strip().replace("\n"," ") 169 | if answer in predict: 170 | data[i]['result'] = 1 171 | else: 172 | answers = answers.lower().strip().replace("\n"," ") 173 | predict = predict.lower().strip().replace("\n"," ") 174 | if answers in predict: 175 | data[i]['result'] = 1 176 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 177 | if len(data)==1000: 178 | for i in range(len(data)): 179 | if data[i].get("result",100)==100: 180 | continue 181 | OCRBench_score[data[i]['type']] += data[i]['result'] 182 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 183 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 184 | print("###########################OCRBench##############################") 185 | print(f"Text Recognition(Total 300):{recognition_score}") 186 | print("------------------Details of Recognition Score-------------------") 187 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 188 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 189 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 190 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 191 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 192 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 193 | 
print("----------------------------------------------------------------") 194 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 195 | print("----------------------------------------------------------------") 196 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 197 | print("----------------------------------------------------------------") 198 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 199 | print("----------------------------------------------------------------") 200 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 201 | print("----------------------Final Score-------------------------------") 202 | print(f"Final Score(Total 1000): {Final_score}") 203 | else: 204 | for i in range(len(data)): 205 | num_all[data[i]['dataset_name']] += 1 206 | if data[i].get("result",100)==100: 207 | continue 208 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 209 | for key in AllDataset_score.keys(): 210 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 211 | -------------------------------------------------------------------------------- /OCRBench/scripts/blip2.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 13 | import torch 14 | # https://huggingface.co/Salesforce/blip2-opt-6.7b 15 | 16 | def split_list(lst, n): 17 | length = len(lst) 18 | avg = length // n # 每份的大小 19 | result = [] # 存储分割后的子列表 20 | for i in range(n - 1): 21 | result.append(lst[i*avg:(i+1)*avg]) 22 | result.append(lst[(n-1)*avg:]) 23 | return result 24 | 25 | def save_json(json_list,save_path): 26 | with open(save_path, 'w') as file: 27 | json.dump(json_list, file,indent=4) 28 | 29 | def _get_args(): 30 | parser = ArgumentParser() 31 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 32 | parser.add_argument("--output_folder", type=str, default="./results") 33 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 34 | parser.add_argument("--model_path", type=str, default="./model_weights/blip2-opt-6.7b") 35 | parser.add_argument("--save_name", type=str, default="blip2_opt_6_7b") 36 | parser.add_argument("--num_workers", type=int, default=8) 37 | args = parser.parse_args() 38 | return args 39 | 40 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 41 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 42 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 43 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 44 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 45 | num_all = 
{"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 46 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 47 | 48 | def eval_worker(args, data, eval_id, output_queue): 49 | print(f"Process {eval_id} start.") 50 | processor = Blip2Processor.from_pretrained(args.model_path) 51 | model = Blip2ForConditionalGeneration.from_pretrained(args.model_path, load_in_8bit=False, device_map={"": eval_id}, torch_dtype=torch.float16) 52 | for i in tqdm(range(len(data))): 53 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 54 | qs = data[i]['question'] 55 | if data[i].get("predict", 0)!=0: 56 | print(f"{img_path} predict exist, continue.") 57 | continue 58 | image = Image.open(img_path).convert("RGB") 59 | prompt = f"Question: {qs} Answer:" 60 | inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=f"cuda:{eval_id}", dtype=torch.float16) 61 | generated_ids = model.generate(**inputs) 62 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() 63 | data[i]['predict'] = generated_text 64 | output_queue.put({eval_id: data}) 65 | print(f"Process {eval_id} has completed.") 66 | 67 | if __name__=="__main__": 68 | multiprocessing.set_start_method('spawn') 69 | args = _get_args() 70 | 71 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 72 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 73 | print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.") 74 | else: 75 | data_path = args.OCRBench_file 76 | 77 | with open(data_path, "r") as f: 78 | data = json.load(f) 79 | 80 | data_list = split_list(data, args.num_workers) 81 | output_queue = Manager().Queue() 82 | 83 | pool = Pool(processes=args.num_workers) 84 | for i in range(len(data_list)): 85 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 86 | pool.close() 87 | pool.join() 88 | 89 | results = {} 90 | while not output_queue.empty(): 91 | result = output_queue.get() 92 | results.update(result) 93 | data = [] 94 | for i in range(len(data_list)): 95 | data.extend(results[i]) 96 | 97 | for i in range(len(data)): 98 | data_type = data[i]["type"] 99 | dataset_name = data[i]["dataset_name"] 100 | answers = data[i]["answers"] 101 | if data[i].get('predict',0)==0: 102 | continue 103 | predict = data[i]['predict'] 104 | data[i]['result'] = 0 105 | if dataset_name == "HME100k": 106 | if type(answers)==list: 107 | for j in range(len(answers)): 108 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 109 | predict = predict.strip().replace("\n"," ").replace(" ","") 110 | if answer in predict: 111 | data[i]['result'] = 1 112 | else: 113 | answers = answers.strip().replace("\n"," ").replace(" ","") 114 | predict = predict.strip().replace("\n"," ").replace(" ","") 115 | if answers in predict: 116 | data[i]['result'] = 1 117 | else: 118 | if type(answers)==list: 119 | for j in range(len(answers)): 120 | answer = answers[j].lower().strip().replace("\n"," ") 121 | predict = predict.lower().strip().replace("\n"," ") 122 | if answer in predict: 123 | data[i]['result'] = 1 124 | else: 125 | answers = answers.lower().strip().replace("\n"," ") 126 | predict = predict.lower().strip().replace("\n"," ") 127 | if answers in 
predict: 128 | data[i]['result'] = 1 129 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 130 | if len(data)==1000: 131 | for i in range(len(data)): 132 | if data[i].get("result",100)==100: 133 | continue 134 | OCRBench_score[data[i]['type']] += data[i]['result'] 135 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 136 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 137 | print("###########################OCRBench##############################") 138 | print(f"Text Recognition(Total 300):{recognition_score}") 139 | print("------------------Details of Recognition Score-------------------") 140 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 141 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 142 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 143 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 144 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 145 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 146 | print("----------------------------------------------------------------") 147 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 148 | print("----------------------------------------------------------------") 149 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 150 | print("----------------------------------------------------------------") 151 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 152 | print("----------------------------------------------------------------") 153 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 154 | print("----------------------Final Score-------------------------------") 155 | print(f"Final Score(Total 1000): {Final_score}") 156 | else: 157 | for i in range(len(data)): 158 | num_all[data[i]['dataset_name']] += 1 159 | if data[i].get("result",100)==100: 160 | continue 161 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 162 | for key in AllDataset_score.keys(): 163 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") -------------------------------------------------------------------------------- /OCRBench/scripts/blip2_vicuna_instruct.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration 13 | 14 | # https://huggingface.co/Salesforce/instructblip-vicuna-7b 15 | def split_list(lst, n): 16 | length = len(lst) 17 | avg = length // n # 每份的大小 18 | result = [] # 
存储分割后的子列表 19 | for i in range(n - 1): 20 | result.append(lst[i*avg:(i+1)*avg]) 21 | result.append(lst[(n-1)*avg:]) 22 | return result 23 | 24 | def save_json(json_list,save_path): 25 | with open(save_path, 'w') as file: 26 | json.dump(json_list, file,indent=4) 27 | 28 | def _get_args(): 29 | parser = ArgumentParser() 30 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 31 | parser.add_argument("--output_folder", type=str, default="./results") 32 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 33 | parser.add_argument("--model_path", type=str, default="./model_weights/instructblip-vicuna-7b") 34 | parser.add_argument("--save_name", type=str, default="instructblip_vicuna_7b") 35 | parser.add_argument("--num_workers", type=int, default=8) 36 | parser.add_argument("--temperature", type=float, default=0.0) 37 | args = parser.parse_args() 38 | return args 39 | 40 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 41 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 42 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 43 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 44 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 45 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 46 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 47 | 48 | def eval_worker(args, data, eval_id, output_queue): 49 | print(f"Process {eval_id} start.") 50 | device = f"cuda:{eval_id}" 51 | model = InstructBlipForConditionalGeneration.from_pretrained(args.model_path) 52 | processor = InstructBlipProcessor.from_pretrained(args.model_path) 53 | model.to(device) 54 | for i in tqdm(range(len(data))): 55 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 56 | qs = data[i]['question'] 57 | if data[i].get("predict", 0)!=0: 58 | print(f"{img_path} predict exist, continue.") 59 | continue 60 | image = Image.open(img_path).convert('RGB') 61 | inputs = processor(images=image, text=qs, return_tensors="pt").to(device) 62 | outputs = model.generate( 63 | **inputs, 64 | do_sample=False, 65 | num_beams=5, 66 | max_length=100, 67 | min_length=1, 68 | top_p=0.9, 69 | repetition_penalty=1.5, 70 | length_penalty=1.0, 71 | temperature=0, 72 | ) 73 | generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip() 74 | data[i]['predict'] = generated_text 75 | output_queue.put({eval_id: data}) 76 | print(f"Process {eval_id} has completed.") 77 | 78 | if __name__=="__main__": 79 | multiprocessing.set_start_method('spawn') 80 | args = _get_args() 81 | 82 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 83 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 84 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 85 | else: 86 | data_path = args.OCRBench_file 87 | 88 | with open(data_path, "r") as f: 89 | data = json.load(f) 90 | 91 | data_list = split_list(data, args.num_workers) 92 | output_queue = Manager().Queue() 93 | 94 | pool = Pool(processes=args.num_workers) 95 | for i in range(len(data_list)): 96 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 97 | pool.close() 98 | pool.join() 99 | 100 | 101 | results = {} 102 | while not output_queue.empty(): 103 | result = output_queue.get() 104 | results.update(result) 105 | data = [] 106 | for i in range(len(data_list)): 107 | data.extend(results[i]) 108 | 109 | for i in range(len(data)): 110 | data_type = data[i]["type"] 111 | dataset_name = data[i]["dataset_name"] 112 | answers = data[i]["answers"] 113 | if data[i].get('predict',0)==0: 114 | continue 115 | predict = data[i]['predict'] 116 | data[i]['result'] = 0 117 | if dataset_name == "HME100k": 118 | if type(answers)==list: 119 | for j in range(len(answers)): 120 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 121 | predict = predict.strip().replace("\n"," ").replace(" ","") 122 | if answer in predict: 123 | data[i]['result'] = 1 124 | else: 125 | answers = answers.strip().replace("\n"," ").replace(" ","") 126 | predict = predict.strip().replace("\n"," ").replace(" ","") 127 | if answers in predict: 128 | data[i]['result'] = 1 129 | else: 130 | if type(answers)==list: 131 | for j in range(len(answers)): 132 | answer = answers[j].lower().strip().replace("\n"," ") 133 | predict = predict.lower().strip().replace("\n"," ") 134 | if answer in predict: 135 | data[i]['result'] = 1 136 | else: 137 | answers = answers.lower().strip().replace("\n"," ") 138 | predict = predict.lower().strip().replace("\n"," ") 139 | if answers in predict: 140 | data[i]['result'] = 1 141 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 142 | if len(data)==1000: 143 | for i in range(len(data)): 144 | if data[i].get("result",100)==100: 145 | continue 146 | OCRBench_score[data[i]['type']] += data[i]['result'] 147 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 148 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 149 | print("###########################OCRBench##############################") 150 | print(f"Text Recognition(Total 300):{recognition_score}") 151 | print("------------------Details of Recognition Score-------------------") 152 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 153 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 154 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 155 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 156 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 157 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 158 | 
print("----------------------------------------------------------------") 159 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 160 | print("----------------------------------------------------------------") 161 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 162 | print("----------------------------------------------------------------") 163 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 164 | print("----------------------------------------------------------------") 165 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 166 | print("----------------------Final Score-------------------------------") 167 | print(f"Final Score(Total 1000): {Final_score}") 168 | else: 169 | for i in range(len(data)): 170 | num_all[data[i]['dataset_name']] += 1 171 | if data[i].get("result",100)==100: 172 | continue 173 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 174 | for key in AllDataset_score.keys(): 175 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 176 | -------------------------------------------------------------------------------- /OCRBench/scripts/bliva.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | from bliva.models import load_model_and_preprocess 12 | import numpy as np 13 | 14 | # https://github.com/mlpc-ucsd/BLIVA/blob/main/evaluate.py 15 | 16 | def disable_torch_init(): 17 | """ 18 | Disable the redundant torch default initialization to accelerate model creation. 
19 | """ 20 | import torch 21 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 22 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 23 | 24 | def split_list(lst, n): 25 | length = len(lst) 26 | avg = length // n # 每份的大小 27 | result = [] # 存储分割后的子列表 28 | for i in range(n - 1): 29 | result.append(lst[i*avg:(i+1)*avg]) 30 | result.append(lst[(n-1)*avg:]) 31 | return result 32 | 33 | def save_json(json_list,save_path): 34 | with open(save_path, 'w') as file: 35 | json.dump(json_list, file,indent=4) 36 | 37 | def _get_args(): 38 | parser = ArgumentParser() 39 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 40 | parser.add_argument("--output_folder", type=str, default="./results") 41 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 42 | parser.add_argument("--model_path", type=str, default="bliva_vicuna") 43 | parser.add_argument("--save_name", type=str, default="bliva") 44 | parser.add_argument("--num_workers", type=int, default=8) 45 | args = parser.parse_args() 46 | return args 47 | 48 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 49 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 50 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 51 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 52 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 53 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 54 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 55 | 56 | def eval_worker(args, data, eval_id, output_queue): 57 | print(f"Process {eval_id} start.") 58 | device = f"cuda:{eval_id}" 59 | np.random.seed(0) 60 | disable_torch_init() 61 | if "vicuna" in args.model_path.lower(): 62 | print("load bliva-vicuna") 63 | model, vis_processors, _ = load_model_and_preprocess(name=args.model_path, model_type="vicuna7b", is_eval=True, device=device) 64 | if "flant5xxl" in args.model_path.lower(): 65 | print("load bliva-flant5xxl") 66 | model, vis_processors, _ = load_model_and_preprocess(name=args.model_path, model_type="flant5xxl", is_eval=True, device=device) 67 | vis_processor = vis_processors["eval"] 68 | for i in tqdm(range(len(data))): 69 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 70 | qs = data[i]['question'] 71 | if data[i].get("predict", 0)!=0: 72 | print(f"{img_path} predict exist, continue.") 73 | continue 74 | image = Image.open(img_path).convert('RGB') 75 | question = [qs] 76 | image = vis_processor(image).unsqueeze(0).to(device) 77 | outputs = model.generate({"image": image, "prompt": qs}, max_length=150) 78 | data[i]['predict'] = outputs[0].split('### Assistant:')[0] 79 | output_queue.put({eval_id: data}) 80 | print(f"Process {eval_id} has completed.") 81 | 82 | if __name__=="__main__": 83 | 
multiprocessing.set_start_method('spawn') 84 | args = _get_args() 85 | 86 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 87 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 88 | print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.") 89 | else: 90 | data_path = args.OCRBench_file 91 | 92 | with open(data_path, "r") as f: 93 | data = json.load(f) 94 | 95 | data_list = split_list(data, args.num_workers) 96 | output_queue = Manager().Queue() 97 | 98 | pool = Pool(processes=args.num_workers) 99 | for i in range(len(data_list)): 100 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 101 | pool.close() 102 | pool.join() 103 | 104 | 105 | results = {} 106 | while not output_queue.empty(): 107 | result = output_queue.get() 108 | results.update(result) 109 | data = [] 110 | for i in range(len(data_list)): 111 | data.extend(results[i]) 112 | 113 | for i in range(len(data)): 114 | data_type = data[i]["type"] 115 | dataset_name = data[i]["dataset_name"] 116 | answers = data[i]["answers"] 117 | if data[i].get('predict',0)==0: 118 | continue 119 | predict = data[i]['predict'] 120 | data[i]['result'] = 0 121 | if dataset_name == "HME100k": 122 | if type(answers)==list: 123 | for j in range(len(answers)): 124 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 125 | predict = predict.strip().replace("\n"," ").replace(" ","") 126 | if answer in predict: 127 | data[i]['result'] = 1 128 | else: 129 | answers = answers.strip().replace("\n"," ").replace(" ","") 130 | predict = predict.strip().replace("\n"," ").replace(" ","") 131 | if answers in predict: 132 | data[i]['result'] = 1 133 | else: 134 | if type(answers)==list: 135 | for j in range(len(answers)): 136 | answer = answers[j].lower().strip().replace("\n"," ") 137 | predict = predict.lower().strip().replace("\n"," ") 138 | if answer in predict: 139 | data[i]['result'] = 1 140 | else: 141 | answers = answers.lower().strip().replace("\n"," ") 142 | predict = predict.lower().strip().replace("\n"," ") 143 | if answers in predict: 144 | data[i]['result'] = 1 145 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 146 | if len(data)==1000: 147 | for i in range(len(data)): 148 | if data[i].get("result",100)==100: 149 | continue 150 | OCRBench_score[data[i]['type']] += data[i]['result'] 151 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 152 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 153 | print("###########################OCRBench##############################") 154 | print(f"Text Recognition(Total 300):{recognition_score}") 155 | print("------------------Details of Recognition Score-------------------") 156 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 157 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 158 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 159 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting 
Recognition']}") 160 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 161 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 162 | print("----------------------------------------------------------------") 163 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 164 | print("----------------------------------------------------------------") 165 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 166 | print("----------------------------------------------------------------") 167 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 168 | print("----------------------------------------------------------------") 169 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 170 | print("----------------------Final Score-------------------------------") 171 | print(f"Final Score(Total 1000): {Final_score}") 172 | else: 173 | for i in range(len(data)): 174 | num_all[data[i]['dataset_name']] += 1 175 | if data[i].get("result",100)==100: 176 | continue 177 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 178 | for key in AllDataset_score.keys(): 179 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") -------------------------------------------------------------------------------- /OCRBench/scripts/interlm.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | from transformers import AutoModel, AutoTokenizer 12 | # https://github.com/InternLM/InternLM-XComposer/tree/main/InternLM-XComposer-1.0 13 | def split_list(lst, n): 14 | length = len(lst) 15 | avg = length // n # 每份的大小 16 | result = [] # 存储分割后的子列表 17 | for i in range(n - 1): 18 | result.append(lst[i*avg:(i+1)*avg]) 19 | result.append(lst[(n-1)*avg:]) 20 | return result 21 | 22 | def save_json(json_list,save_path): 23 | with open(save_path, 'w') as file: 24 | json.dump(json_list, file,indent=4) 25 | 26 | def _get_args(): 27 | parser = ArgumentParser() 28 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 29 | parser.add_argument("--output_folder", type=str, default="./results") 30 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 31 | parser.add_argument("--model_path", type=str, default='internlm/internlm-xcomposer-7b')#TODO Set the address of your model's weights 32 | parser.add_argument("--save_name", type=str, default="internlm-xcomposer-7b") #TODO Set the name of the JSON file you save in the output_folder. 
33 | parser.add_argument("--num_workers", type=int, default=1) 34 | args = parser.parse_args() 35 | return args 36 | 37 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 38 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 39 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 40 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 41 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 42 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 43 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 44 | 45 | def eval_worker(args, data, eval_id, output_queue): 46 | print(f"Process {eval_id} start.") 47 | checkpoint = args.model_path 48 | 49 | torch.set_grad_enabled(False) 50 | 51 | # init model and tokenizer 52 | model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True,device_map=f'cuda:{eval_id}').eval() 53 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) 54 | model.tokenizer = tokenizer 55 | 56 | for i in tqdm(range(len(data))): 57 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 58 | qs = data[i]['question'] 59 | response = model.generate(qs, img_path) 60 | data[i]['predict'] = response 61 | output_queue.put({eval_id: data}) 62 | print(f"Process {eval_id} has completed.") 63 | 64 | if __name__=="__main__": 65 | multiprocessing.set_start_method('spawn') 66 | args = _get_args() 67 | 68 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 69 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 70 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 71 | else: 72 | data_path = args.OCRBench_file 73 | 74 | with open(data_path, "r") as f: 75 | data = json.load(f) 76 | 77 | data_list = split_list(data, args.num_workers) 78 | 79 | output_queue = Manager().Queue() 80 | 81 | pool = Pool(processes=args.num_workers) 82 | for i in range(len(data_list)): 83 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 84 | pool.close() 85 | pool.join() 86 | 87 | results = {} 88 | while not output_queue.empty(): 89 | result = output_queue.get() 90 | results.update(result) 91 | data = [] 92 | for i in range(len(data_list)): 93 | data.extend(results[i]) 94 | 95 | for i in range(len(data)): 96 | data_type = data[i]["type"] 97 | dataset_name = data[i]["dataset_name"] 98 | answers = data[i]["answers"] 99 | if data[i].get('predict',0)==0: 100 | continue 101 | predict = data[i]['predict'] 102 | data[i]['result'] = 0 103 | if dataset_name == "HME100k": 104 | if type(answers)==list: 105 | for j in range(len(answers)): 106 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 107 | predict = predict.strip().replace("\n"," ").replace(" ","") 108 | if answer in predict: 109 | data[i]['result'] = 1 110 | else: 111 | answers = answers.strip().replace("\n"," ").replace(" ","") 112 | predict = predict.strip().replace("\n"," ").replace(" ","") 113 | if answers in predict: 114 | data[i]['result'] = 1 115 | else: 116 | if type(answers)==list: 117 | for j in range(len(answers)): 118 | answer = answers[j].lower().strip().replace("\n"," ") 119 | predict = predict.lower().strip().replace("\n"," ") 120 | if answer in predict: 121 | data[i]['result'] = 1 122 | else: 123 | answers = answers.lower().strip().replace("\n"," ") 124 | predict = predict.lower().strip().replace("\n"," ") 125 | if answers in predict: 126 | data[i]['result'] = 1 127 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 128 | if len(data)==1000: 129 | for i in range(len(data)): 130 | if data[i].get("result",100)==100: 131 | continue 132 | OCRBench_score[data[i]['type']] += data[i]['result'] 133 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 134 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 135 | print("###########################OCRBench##############################") 136 | print(f"Text Recognition(Total 300):{recognition_score}") 137 | print("------------------Details of Recognition Score-------------------") 138 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 139 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 140 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 141 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 142 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 143 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 144 | print("----------------------------------------------------------------") 145 
| print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 146 | print("----------------------------------------------------------------") 147 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 148 | print("----------------------------------------------------------------") 149 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 150 | print("----------------------------------------------------------------") 151 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 152 | print("----------------------Final Score-------------------------------") 153 | print(f"Final Score(Total 1000): {Final_score}") 154 | else: 155 | for i in range(len(data)): 156 | num_all[data[i]['dataset_name']] += 1 157 | if data[i].get("result",100)==100: 158 | continue 159 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 160 | for key in AllDataset_score.keys(): 161 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 162 | -------------------------------------------------------------------------------- /OCRBench/scripts/interlm2.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | from transformers import AutoModel, AutoTokenizer 12 | #https://github.com/InternLM/InternLM-XComposer/tree/main 13 | 14 | def split_list(lst, n): 15 | length = len(lst) 16 | avg = length // n # 每份的大小 17 | result = [] # 存储分割后的子列表 18 | for i in range(n - 1): 19 | result.append(lst[i*avg:(i+1)*avg]) 20 | result.append(lst[(n-1)*avg:]) 21 | return result 22 | 23 | def save_json(json_list,save_path): 24 | with open(save_path, 'w') as file: 25 | json.dump(json_list, file,indent=4) 26 | 27 | def _get_args(): 28 | parser = ArgumentParser() 29 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 30 | parser.add_argument("--output_folder", type=str, default="./results") 31 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 32 | parser.add_argument("--model_path", type=str, default='internlm/internlm-xcomposer2-vl-7b')#TODO Set the address of your model's weights 33 | parser.add_argument("--save_name", type=str, default="internlm-xcomposer2-vl-7b") #TODO Set the name of the JSON file you save in the output_folder. 
34 | parser.add_argument("--num_workers", type=int, default=1) 35 | args = parser.parse_args() 36 | return args 37 | 38 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 39 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 40 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 41 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 42 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 43 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 44 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 45 | 46 | def eval_worker(args, data, eval_id, output_queue): 47 | print(f"Process {eval_id} start.") 48 | checkpoint = args.model_path 49 | torch.set_grad_enabled(False) 50 | 51 | # init model and tokenizer 52 | model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True,device_map=f'cuda:{eval_id}').eval() 53 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) 54 | 55 | for i in tqdm(range(len(data))): 56 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 57 | qs = data[i]['question'] 58 | text = f'{qs}' 59 | with torch.cuda.amp.autocast(): 60 | response, _ = model.chat(tokenizer, query=text, image=img_path, history=[], do_sample=False) 61 | data[i]['predict'] = response 62 | output_queue.put({eval_id: data}) 63 | print(f"Process {eval_id} has completed.") 64 | 65 | if __name__=="__main__": 66 | multiprocessing.set_start_method('spawn') 67 | args = _get_args() 68 | 69 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 70 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 71 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 72 | else: 73 | data_path = args.OCRBench_file 74 | 75 | with open(data_path, "r") as f: 76 | data = json.load(f) 77 | 78 | data_list = split_list(data, args.num_workers) 79 | 80 | output_queue = Manager().Queue() 81 | 82 | pool = Pool(processes=args.num_workers) 83 | for i in range(len(data_list)): 84 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 85 | pool.close() 86 | pool.join() 87 | 88 | results = {} 89 | while not output_queue.empty(): 90 | result = output_queue.get() 91 | results.update(result) 92 | data = [] 93 | for i in range(len(data_list)): 94 | data.extend(results[i]) 95 | 96 | for i in range(len(data)): 97 | data_type = data[i]["type"] 98 | dataset_name = data[i]["dataset_name"] 99 | answers = data[i]["answers"] 100 | if data[i].get('predict',0)==0: 101 | continue 102 | predict = data[i]['predict'] 103 | data[i]['result'] = 0 104 | if dataset_name == "HME100k": 105 | if type(answers)==list: 106 | for j in range(len(answers)): 107 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 108 | predict = predict.strip().replace("\n"," ").replace(" ","") 109 | if answer in predict: 110 | data[i]['result'] = 1 111 | else: 112 | answers = answers.strip().replace("\n"," ").replace(" ","") 113 | predict = predict.strip().replace("\n"," ").replace(" ","") 114 | if answers in predict: 115 | data[i]['result'] = 1 116 | else: 117 | if type(answers)==list: 118 | for j in range(len(answers)): 119 | answer = answers[j].lower().strip().replace("\n"," ") 120 | predict = predict.lower().strip().replace("\n"," ") 121 | if answer in predict: 122 | data[i]['result'] = 1 123 | else: 124 | answers = answers.lower().strip().replace("\n"," ") 125 | predict = predict.lower().strip().replace("\n"," ") 126 | if answers in predict: 127 | data[i]['result'] = 1 128 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 129 | if len(data)==1000: 130 | for i in range(len(data)): 131 | if data[i].get("result",100)==100: 132 | continue 133 | OCRBench_score[data[i]['type']] += data[i]['result'] 134 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 135 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 136 | print("###########################OCRBench##############################") 137 | print(f"Text Recognition(Total 300):{recognition_score}") 138 | print("------------------Details of Recognition Score-------------------") 139 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 140 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 141 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 142 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 143 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 144 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 145 | print("----------------------------------------------------------------") 
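        # Score breakdown when the 1000-question OCRBench set is evaluated: Text Recognition
        # contributes 300 points (six sub-tasks x 50), Scene Text-centric VQA 200,
        # Doc-oriented VQA 200, Key Information Extraction 200, and Handwritten Mathematical
        # Expression Recognition 100, for a maximum Final Score of 1000 (one point per
        # correctly answered question). For any other input size, the else branch below
        # reports per-dataset accuracy instead.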
146 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 147 | print("----------------------------------------------------------------") 148 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 149 | print("----------------------------------------------------------------") 150 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 151 | print("----------------------------------------------------------------") 152 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 153 | print("----------------------Final Score-------------------------------") 154 | print(f"Final Score(Total 1000): {Final_score}") 155 | else: 156 | for i in range(len(data)): 157 | num_all[data[i]['dataset_name']] += 1 158 | if data[i].get("result",100)==100: 159 | continue 160 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 161 | for key in AllDataset_score.keys(): 162 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 163 | -------------------------------------------------------------------------------- /OCRBench/scripts/intervl.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | from PIL import Image 12 | from transformers import AutoModel, CLIPImageProcessor 13 | from transformers import AutoTokenizer 14 | 15 | #https://github.com/OpenGVLab/InternVL 16 | 17 | def split_list(lst, n): 18 | length = len(lst) 19 | avg = length // n # 每份的大小 20 | result = [] # 存储分割后的子列表 21 | for i in range(n - 1): 22 | result.append(lst[i*avg:(i+1)*avg]) 23 | result.append(lst[(n-1)*avg:]) 24 | return result 25 | 26 | def save_json(json_list,save_path): 27 | with open(save_path, 'w') as file: 28 | json.dump(json_list, file,indent=4) 29 | 30 | def _get_args(): 31 | parser = ArgumentParser() 32 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 33 | parser.add_argument("--output_folder", type=str, default="./results") 34 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 35 | parser.add_argument("--model_path", type=str, default='OpenGVLab/InternVL-Chat-Chinese-V1-1')#TODO Set the address of your model's weights 36 | parser.add_argument("--save_name", type=str, default="InternVL-Chat-Chinese-V1-1") #TODO Set the name of the JSON file you save in the output_folder. 
37 | parser.add_argument("--num_workers", type=int, default=1) 38 | args = parser.parse_args() 39 | return args 40 | 41 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 42 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 43 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 44 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 45 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 46 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 47 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 48 | 49 | def eval_worker(args, data, eval_id, output_queue): 50 | print(f"Process {eval_id} start.") 51 | checkpoint = args.model_path 52 | model = AutoModel.from_pretrained( 53 | checkpoint, 54 | torch_dtype=torch.bfloat16, 55 | low_cpu_mem_usage=True, 56 | trust_remote_code=True, 57 | device_map='cuda').eval() 58 | 59 | tokenizer = AutoTokenizer.from_pretrained(checkpoint) 60 | 61 | for i in tqdm(range(len(data))): 62 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 63 | qs = data[i]['question'] 64 | image = Image.open(img_path).convert('RGB') 65 | image = image.resize((448, 448)) 66 | image_processor = CLIPImageProcessor.from_pretrained(checkpoint) 67 | 68 | pixel_values = image_processor(images=image, return_tensors='pt').pixel_values 69 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 70 | 71 | generation_config = dict( 72 | num_beams=1, 73 | max_new_tokens=512, 74 | do_sample=False, 75 | ) 76 | response = model.chat(tokenizer, pixel_values, qs, generation_config) 77 | data[i]['predict'] = response 78 | output_queue.put({eval_id: data}) 79 | print(f"Process {eval_id} has completed.") 80 | 81 | if __name__=="__main__": 82 | multiprocessing.set_start_method('spawn') 83 | args = _get_args() 84 | 85 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 86 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 87 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 88 | else: 89 | data_path = args.OCRBench_file 90 | 91 | with open(data_path, "r") as f: 92 | data = json.load(f) 93 | 94 | data_list = split_list(data, args.num_workers) 95 | 96 | output_queue = Manager().Queue() 97 | 98 | pool = Pool(processes=args.num_workers) 99 | for i in range(len(data_list)): 100 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 101 | pool.close() 102 | pool.join() 103 | 104 | results = {} 105 | while not output_queue.empty(): 106 | result = output_queue.get() 107 | results.update(result) 108 | data = [] 109 | for i in range(len(data_list)): 110 | data.extend(results[i]) 111 | 112 | for i in range(len(data)): 113 | data_type = data[i]["type"] 114 | dataset_name = data[i]["dataset_name"] 115 | answers = data[i]["answers"] 116 | if data[i].get('predict',0)==0: 117 | continue 118 | predict = data[i]['predict'] 119 | data[i]['result'] = 0 120 | if dataset_name == "HME100k": 121 | if type(answers)==list: 122 | for j in range(len(answers)): 123 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 124 | predict = predict.strip().replace("\n"," ").replace(" ","") 125 | if answer in predict: 126 | data[i]['result'] = 1 127 | else: 128 | answers = answers.strip().replace("\n"," ").replace(" ","") 129 | predict = predict.strip().replace("\n"," ").replace(" ","") 130 | if answers in predict: 131 | data[i]['result'] = 1 132 | else: 133 | if type(answers)==list: 134 | for j in range(len(answers)): 135 | answer = answers[j].lower().strip().replace("\n"," ") 136 | predict = predict.lower().strip().replace("\n"," ") 137 | if answer in predict: 138 | data[i]['result'] = 1 139 | else: 140 | answers = answers.lower().strip().replace("\n"," ") 141 | predict = predict.lower().strip().replace("\n"," ") 142 | if answers in predict: 143 | data[i]['result'] = 1 144 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 145 | if len(data)==1000: 146 | for i in range(len(data)): 147 | if data[i].get("result",100)==100: 148 | continue 149 | OCRBench_score[data[i]['type']] += data[i]['result'] 150 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 151 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 152 | print("###########################OCRBench##############################") 153 | print(f"Text Recognition(Total 300):{recognition_score}") 154 | print("------------------Details of Recognition Score-------------------") 155 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 156 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 157 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 158 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 159 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 160 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 161 | 
print("----------------------------------------------------------------") 162 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 163 | print("----------------------------------------------------------------") 164 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 165 | print("----------------------------------------------------------------") 166 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 167 | print("----------------------------------------------------------------") 168 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 169 | print("----------------------Final Score-------------------------------") 170 | print(f"Final Score(Total 1000): {Final_score}") 171 | else: 172 | for i in range(len(data)): 173 | num_all[data[i]['dataset_name']] += 1 174 | if data[i].get("result",100)==100: 175 | continue 176 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 177 | for key in AllDataset_score.keys(): 178 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 179 | -------------------------------------------------------------------------------- /OCRBench/scripts/mPLUG-DocOwl15.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import os 4 | from argparse import ArgumentParser 5 | from multiprocessing import Manager, Pool, Queue 6 | 7 | import torch 8 | from mplug_docowl.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX 9 | from mplug_docowl.conversation import conv_templates 10 | from mplug_docowl.mm_utils import ( 11 | KeywordsStoppingCriteria, 12 | get_model_name_from_path, 13 | process_images, 14 | tokenizer_image_token, 15 | ) 16 | from mplug_docowl.model.builder import load_pretrained_model 17 | from mplug_docowl.processor import DocProcessor 18 | from tqdm import tqdm 19 | from transformers import TextStreamer 20 | 21 | 22 | # https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/DocOwl1.5/docowl_infer.py 23 | def split_list(lst, n): 24 | length = len(lst) 25 | avg = length // n # 每份的大小 26 | result = [] # 存储分割后的子列表 27 | for i in range(n - 1): 28 | result.append(lst[i * avg : (i + 1) * avg]) 29 | result.append(lst[(n - 1) * avg :]) 30 | return result 31 | 32 | 33 | def save_json(json_list, save_path): 34 | with open(save_path, "w", encoding="utf-8") as file: 35 | json.dump(json_list, file, indent=4) 36 | 37 | 38 | def _get_args(): 39 | parser = ArgumentParser() 40 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 41 | parser.add_argument("--output_folder", type=str, default="./results") 42 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 43 | parser.add_argument("--model_path", type=str, default="mPLUG/DocOwl1.5") 44 | parser.add_argument("--save_name", type=str, default="mplug-DocOwl1.5") 45 | parser.add_argument("--conv_mode", type=str, default="mplug_owl2") 46 | parser.add_argument("--num_workers", type=int, default=8) 47 | parser.add_argument("--temperature", type=float, default=0.0) 48 | args = parser.parse_args() 49 | return args 50 | 51 | 52 | OCRBench_score = { 53 | "Regular Text Recognition": 0, 54 | "Irregular Text Recognition": 0, 55 | "Artistic Text Recognition": 0, 56 | "Handwriting Recognition": 0, 57 | "Digit String Recognition": 0, 58 | "Non-Semantic Text Recognition": 0, 59 | "Scene Text-centric VQA": 0, 
60 | "Doc-oriented VQA": 0, 61 | "Key Information Extraction": 0, 62 | "Handwritten Mathematical Expression Recognition": 0, 63 | } 64 | AllDataset_score = { 65 | "IIIT5K": 0, 66 | "svt": 0, 67 | "IC13_857": 0, 68 | "IC15_1811": 0, 69 | "svtp": 0, 70 | "ct80": 0, 71 | "cocotext": 0, 72 | "ctw": 0, 73 | "totaltext": 0, 74 | "HOST": 0, 75 | "WOST": 0, 76 | "WordArt": 0, 77 | "IAM": 0, 78 | "ReCTS": 0, 79 | "ORAND": 0, 80 | "NonSemanticText": 0, 81 | "SemanticText": 0, 82 | "STVQA": 0, 83 | "textVQA": 0, 84 | "ocrVQA": 0, 85 | "ESTVQA": 0, 86 | "ESTVQA_cn": 0, 87 | "docVQA": 0, 88 | "infographicVQA": 0, 89 | "ChartQA": 0, 90 | "ChartQA_Human": 0, 91 | "FUNSD": 0, 92 | "SROIE": 0, 93 | "POIE": 0, 94 | "HME100k": 0, 95 | } 96 | num_all = { 97 | "IIIT5K": 0, 98 | "svt": 0, 99 | "IC13_857": 0, 100 | "IC15_1811": 0, 101 | "svtp": 0, 102 | "ct80": 0, 103 | "cocotext": 0, 104 | "ctw": 0, 105 | "totaltext": 0, 106 | "HOST": 0, 107 | "WOST": 0, 108 | "WordArt": 0, 109 | "IAM": 0, 110 | "ReCTS": 0, 111 | "ORAND": 0, 112 | "NonSemanticText": 0, 113 | "SemanticText": 0, 114 | "STVQA": 0, 115 | "textVQA": 0, 116 | "ocrVQA": 0, 117 | "ESTVQA": 0, 118 | "ESTVQA_cn": 0, 119 | "docVQA": 0, 120 | "infographicVQA": 0, 121 | "ChartQA": 0, 122 | "ChartQA_Human": 0, 123 | "FUNSD": 0, 124 | "SROIE": 0, 125 | "POIE": 0, 126 | "HME100k": 0, 127 | } 128 | 129 | 130 | def eval_worker(args, data, eval_id, output_queue): 131 | print(f"Process {eval_id} start.") 132 | model_name = get_model_name_from_path(args.model_path) 133 | tokenizer, model, _, _ = load_pretrained_model( 134 | args.model_path, 135 | None, 136 | model_name, 137 | load_8bit=False, 138 | load_4bit=False, 139 | device=f"cuda:{eval_id}", 140 | ) 141 | 142 | doc_image_processor = DocProcessor( 143 | image_size=448, 144 | anchors="grid_9", 145 | add_global_img=True, 146 | add_textual_crop_indicator=True, 147 | ) 148 | 149 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 150 | 151 | for i in tqdm(range(len(data))): 152 | img_path = os.path.join(args.image_folder, data[i]["image_path"]) 153 | qs = data[i]["question"] 154 | if data[i].get("predict", 0) != 0: 155 | print(f"{img_path} predict exist, continue.") 156 | continue 157 | 158 | image_tensor, patch_positions, text = doc_image_processor( 159 | images=img_path, query="<|image|>" + qs 160 | ) 161 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 162 | patch_positions = patch_positions.to(model.device) 163 | 164 | conv = conv_templates["mplug_owl2"].copy() 165 | conv.append_message(conv.roles[0], text) 166 | conv.append_message(conv.roles[1], None) 167 | prompt = conv.get_prompt() 168 | 169 | input_ids = ( 170 | tokenizer_image_token( 171 | prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" 172 | ) 173 | .unsqueeze(0) 174 | .to(model.device) 175 | ) 176 | 177 | stop_str = conv.sep2 178 | keywords = [stop_str] 179 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 180 | with torch.inference_mode(): 181 | output_ids = model.generate( 182 | input_ids, 183 | images=image_tensor, 184 | patch_positions=patch_positions, 185 | do_sample=False, 186 | temperature=1.0, 187 | max_new_tokens=512, 188 | streamer=streamer, 189 | use_cache=True, 190 | stopping_criteria=[stopping_criteria], 191 | ) 192 | 193 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip() 194 | data[i]["predict"] = outputs 195 | output_queue.put({eval_id: data}) 196 | print(f"Process {eval_id} has completed.") 197 | 198 | 199 | if __name__ == 
"__main__": 200 | multiprocessing.set_start_method("spawn") 201 | args = _get_args() 202 | 203 | if os.path.exists(os.path.join(args.output_folder, f"{args.save_name}.json")): 204 | data_path = os.path.join(args.output_folder, f"{args.save_name}.json") 205 | print( 206 | f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}." 207 | ) 208 | else: 209 | data_path = args.OCRBench_file 210 | 211 | with open(data_path, "r", encoding="utf-8") as f: 212 | data = json.load(f) 213 | 214 | data_list = split_list(data, args.num_workers) 215 | output_queue = Manager().Queue() 216 | 217 | pool = Pool(processes=args.num_workers) 218 | for i in range(len(data_list)): 219 | # pool.apply(eval_worker, args=(args, data_list[i], i, output_queue)) 220 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 221 | pool.close() 222 | pool.join() 223 | 224 | results = {} 225 | while not output_queue.empty(): 226 | result = output_queue.get() 227 | results.update(result) 228 | data = [] 229 | for i in range(len(data_list)): 230 | data.extend(results[i]) 231 | 232 | for i in range(len(data)): 233 | data_type = data[i]["type"] 234 | dataset_name = data[i]["dataset_name"] 235 | answers = data[i]["answers"] 236 | if data[i].get("predict", 0) == 0: 237 | continue 238 | predict = data[i]["predict"] 239 | data[i]["result"] = 0 240 | if dataset_name == "HME100k": 241 | if type(answers) == list: 242 | for j in range(len(answers)): 243 | answer = answers[j].strip().replace("\n", " ").replace(" ", "") 244 | predict = predict.strip().replace("\n", " ").replace(" ", "") 245 | if answer in predict: 246 | data[i]["result"] = 1 247 | else: 248 | answers = answers.strip().replace("\n", " ").replace(" ", "") 249 | predict = predict.strip().replace("\n", " ").replace(" ", "") 250 | if answers in predict: 251 | data[i]["result"] = 1 252 | else: 253 | if type(answers) == list: 254 | for j in range(len(answers)): 255 | answer = answers[j].lower().strip().replace("\n", " ") 256 | predict = predict.lower().strip().replace("\n", " ") 257 | if answer in predict: 258 | data[i]["result"] = 1 259 | else: 260 | answers = answers.lower().strip().replace("\n", " ") 261 | predict = predict.lower().strip().replace("\n", " ") 262 | if answers in predict: 263 | data[i]["result"] = 1 264 | save_json(data, os.path.join(args.output_folder, f"{args.save_name}.json")) 265 | if len(data) == 1000: 266 | for i in range(len(data)): 267 | if data[i].get("result", 100) == 100: 268 | continue 269 | OCRBench_score[data[i]["type"]] += data[i]["result"] 270 | recognition_score = ( 271 | OCRBench_score["Regular Text Recognition"] 272 | + OCRBench_score["Irregular Text Recognition"] 273 | + OCRBench_score["Artistic Text Recognition"] 274 | + OCRBench_score["Handwriting Recognition"] 275 | + OCRBench_score["Digit String Recognition"] 276 | + OCRBench_score["Non-Semantic Text Recognition"] 277 | ) 278 | Final_score = ( 279 | recognition_score 280 | + OCRBench_score["Scene Text-centric VQA"] 281 | + OCRBench_score["Doc-oriented VQA"] 282 | + OCRBench_score["Key Information Extraction"] 283 | + OCRBench_score["Handwritten Mathematical Expression Recognition"] 284 | ) 285 | print("###########################OCRBench##############################") 286 | print(f"Text Recognition(Total 300):{recognition_score}") 287 | print("------------------Details of Recognition Score-------------------") 288 | print( 289 | f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}" 290 | ) 291 | 
print( 292 | f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}" 293 | ) 294 | print( 295 | f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}" 296 | ) 297 | print( 298 | f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}" 299 | ) 300 | print( 301 | f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}" 302 | ) 303 | print( 304 | f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}" 305 | ) 306 | print("----------------------------------------------------------------") 307 | print( 308 | f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}" 309 | ) 310 | print("----------------------------------------------------------------") 311 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 312 | print("----------------------------------------------------------------") 313 | print( 314 | f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}" 315 | ) 316 | print("----------------------------------------------------------------") 317 | print( 318 | f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}" 319 | ) 320 | print("----------------------Final Score-------------------------------") 321 | print(f"Final Score(Total 1000): {Final_score}") 322 | else: 323 | for i in range(len(data)): 324 | num_all[data[i]["dataset_name"]] += 1 325 | if data[i].get("result", 100) == 100: 326 | continue 327 | AllDataset_score[data[i]["dataset_name"]] += data[i]["result"] 328 | for key in AllDataset_score.keys(): 329 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 330 | -------------------------------------------------------------------------------- /OCRBench/scripts/mPLUG-owl.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | import sys 13 | sys.path.append("./scripts/mPLUG-Owl/mPLUG-Owl/") 14 | from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration 15 | from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer 16 | from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor 17 | 18 | # https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl 19 | def split_list(lst, n): 20 | length = len(lst) 21 | avg = length // n # 每份的大小 22 | result = [] # 存储分割后的子列表 23 | for i in range(n - 1): 24 | result.append(lst[i*avg:(i+1)*avg]) 25 | result.append(lst[(n-1)*avg:]) 26 | return result 27 | 28 | def save_json(json_list,save_path): 29 | with open(save_path, 'w') as file: 30 | json.dump(json_list, file,indent=4) 31 | 32 | def _get_args(): 33 | parser = ArgumentParser() 34 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 35 | parser.add_argument("--output_folder", type=str, default="./results") 36 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 37 | parser.add_argument("--model_path", type=str, default="./model_weights/mplug-owl") 38 | parser.add_argument("--save_name", type=str, default="mplug-owl") 39 | parser.add_argument("--num_workers", type=int, default=8) 40 | args = 
parser.parse_args() 41 | return args 42 | 43 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 44 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 45 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 46 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 47 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 48 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 49 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 50 | 51 | def eval_worker(args, data, eval_id, output_queue): 52 | print(f"Process {eval_id} start.") 53 | pretrained_ckpt = args.model_path 54 | model = MplugOwlForConditionalGeneration.from_pretrained( 55 | pretrained_ckpt, 56 | torch_dtype=torch.bfloat16, 57 | ) 58 | model.to(f"cuda:{eval_id}") 59 | image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt) 60 | tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt) 61 | processor = MplugOwlProcessor(image_processor, tokenizer) 62 | for i in tqdm(range(len(data))): 63 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 64 | qs = data[i]['question'] 65 | prompts = [ 66 | f'''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 67 | Human: 68 | Human: {qs} 69 | AI: '''] 70 | if data[i].get("predict", 0)!=0: 71 | print(f"{img_path} predict exist, continue.") 72 | continue 73 | generate_kwargs = { 74 | 'do_sample': False, 75 | 'top_k': 1, 76 | 'max_length': 100 77 | } 78 | images = [Image.open(img_path)] 79 | inputs = processor(text=prompts, images=images, return_tensors='pt') 80 | inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()} 81 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 82 | with torch.no_grad(): 83 | res = model.generate(**inputs, **generate_kwargs) 84 | sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True) 85 | data[i]['predict'] = sentence 86 | output_queue.put({eval_id: data}) 87 | print(f"Process {eval_id} has completed.") 88 | 89 | if __name__=="__main__": 90 | multiprocessing.set_start_method('spawn') 91 | args = _get_args() 92 | 93 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 94 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 95 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 96 | else: 97 | data_path = args.OCRBench_file 98 | 99 | with open(data_path, "r") as f: 100 | data = json.load(f) 101 | 102 | data_list = split_list(data, args.num_workers) 103 | output_queue = Manager().Queue() 104 | 105 | pool = Pool(processes=args.num_workers) 106 | for i in range(len(data_list)): 107 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 108 | pool.close() 109 | pool.join() 110 | 111 | results = {} 112 | while not output_queue.empty(): 113 | result = output_queue.get() 114 | results.update(result) 115 | data = [] 116 | for i in range(len(data_list)): 117 | data.extend(results[i]) 118 | 119 | for i in range(len(data)): 120 | data_type = data[i]["type"] 121 | dataset_name = data[i]["dataset_name"] 122 | answers = data[i]["answers"] 123 | if data[i].get('predict',0)==0: 124 | continue 125 | predict = data[i]['predict'] 126 | data[i]['result'] = 0 127 | if dataset_name == "HME100k": 128 | if type(answers)==list: 129 | for j in range(len(answers)): 130 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 131 | predict = predict.strip().replace("\n"," ").replace(" ","") 132 | if answer in predict: 133 | data[i]['result'] = 1 134 | else: 135 | answers = answers.strip().replace("\n"," ").replace(" ","") 136 | predict = predict.strip().replace("\n"," ").replace(" ","") 137 | if answers in predict: 138 | data[i]['result'] = 1 139 | else: 140 | if type(answers)==list: 141 | for j in range(len(answers)): 142 | answer = answers[j].lower().strip().replace("\n"," ") 143 | predict = predict.lower().strip().replace("\n"," ") 144 | if answer in predict: 145 | data[i]['result'] = 1 146 | else: 147 | answers = answers.lower().strip().replace("\n"," ") 148 | predict = predict.lower().strip().replace("\n"," ") 149 | if answers in predict: 150 | data[i]['result'] = 1 151 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 152 | if len(data)==1000: 153 | for i in range(len(data)): 154 | if data[i].get("result",100)==100: 155 | continue 156 | OCRBench_score[data[i]['type']] += data[i]['result'] 157 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 158 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 159 | print("###########################OCRBench##############################") 160 | print(f"Text Recognition(Total 300):{recognition_score}") 161 | print("------------------Details of Recognition Score-------------------") 162 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 163 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 164 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 165 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 166 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 167 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 168 | 
print("----------------------------------------------------------------") 169 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 170 | print("----------------------------------------------------------------") 171 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 172 | print("----------------------------------------------------------------") 173 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 174 | print("----------------------------------------------------------------") 175 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 176 | print("----------------------Final Score-------------------------------") 177 | print(f"Final Score(Total 1000): {Final_score}") 178 | else: 179 | for i in range(len(data)): 180 | num_all[data[i]['dataset_name']] += 1 181 | if data[i].get("result",100)==100: 182 | continue 183 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 184 | for key in AllDataset_score.keys(): 185 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 186 | -------------------------------------------------------------------------------- /OCRBench/scripts/mPLUG-owl2.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | from transformers import TextStreamer 13 | from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 14 | from mplug_owl2.conversation import conv_templates, SeparatorStyle 15 | from mplug_owl2.model.builder import load_pretrained_model 16 | from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 17 | 18 | # https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2 19 | def split_list(lst, n): 20 | length = len(lst) 21 | avg = length // n # 每份的大小 22 | result = [] # 存储分割后的子列表 23 | for i in range(n - 1): 24 | result.append(lst[i*avg:(i+1)*avg]) 25 | result.append(lst[(n-1)*avg:]) 26 | return result 27 | 28 | def save_json(json_list,save_path): 29 | with open(save_path, 'w') as file: 30 | json.dump(json_list, file,indent=4) 31 | 32 | def _get_args(): 33 | parser = ArgumentParser() 34 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 35 | parser.add_argument("--output_folder", type=str, default="./results") 36 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 37 | parser.add_argument("--model_path", type=str, default="./model_weights/mplug-owl2") 38 | parser.add_argument("--save_name", type=str, default="mplug-owl2") 39 | parser.add_argument("--conv_mode", type=str, default="mplug_owl2") 40 | parser.add_argument("--num_workers", type=int, default=8) 41 | parser.add_argument("--temperature", type=float, default=0.0) 42 | args = parser.parse_args() 43 | return args 44 | 45 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 46 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 47 | "Key Information Extraction":0,"Handwritten Mathematical Expression 
Recognition":0} 48 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 49 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 50 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 51 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 52 | 53 | def eval_worker(args, data, eval_id, output_queue): 54 | print(f"Process {eval_id} start.") 55 | model_name = get_model_name_from_path(args.model_path) 56 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, None, model_name, load_8bit=False, load_4bit=False, device=f"cuda:{eval_id}") 57 | for i in tqdm(range(len(data))): 58 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 59 | qs = data[i]['question'] 60 | if data[i].get("predict", 0)!=0: 61 | print(f"{img_path} predict exist, continue.") 62 | continue 63 | conv = conv_templates[args.conv_mode].copy() 64 | roles = conv.roles 65 | image = Image.open(img_path).convert('RGB') 66 | max_edge = max(image.size) # We recommand you to resize to squared image for BEST performance. 67 | image = image.resize((max_edge, max_edge)) 68 | image_tensor = process_images([image], image_processor) 69 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 70 | 71 | inp = DEFAULT_IMAGE_TOKEN + qs 72 | conv.append_message(conv.roles[0], inp) 73 | conv.append_message(conv.roles[1], None) 74 | prompt = conv.get_prompt() 75 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 76 | stop_str = conv.sep2 77 | keywords = [stop_str] 78 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 79 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 80 | with torch.inference_mode(): 81 | output_ids = model.generate( 82 | input_ids, 83 | images=image_tensor, 84 | do_sample=False, 85 | temperature=args.temperature, 86 | max_new_tokens=100, 87 | streamer=streamer, 88 | use_cache=True, 89 | stopping_criteria=[stopping_criteria]) 90 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 91 | data[i]['predict'] = outputs 92 | output_queue.put({eval_id: data}) 93 | print(f"Process {eval_id} has completed.") 94 | 95 | if __name__=="__main__": 96 | multiprocessing.set_start_method('spawn') 97 | args = _get_args() 98 | 99 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 100 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 101 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 102 | else: 103 | data_path = args.OCRBench_file 104 | 105 | with open(data_path, "r") as f: 106 | data = json.load(f) 107 | 108 | data_list = split_list(data, args.num_workers) 109 | output_queue = Manager().Queue() 110 | 111 | pool = Pool(processes=args.num_workers) 112 | for i in range(len(data_list)): 113 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 114 | pool.close() 115 | pool.join() 116 | 117 | results = {} 118 | while not output_queue.empty(): 119 | result = output_queue.get() 120 | results.update(result) 121 | data = [] 122 | for i in range(len(data_list)): 123 | data.extend(results[i]) 124 | 125 | for i in range(len(data)): 126 | data_type = data[i]["type"] 127 | dataset_name = data[i]["dataset_name"] 128 | answers = data[i]["answers"] 129 | if data[i].get('predict',0)==0: 130 | continue 131 | predict = data[i]['predict'] 132 | data[i]['result'] = 0 133 | if dataset_name == "HME100k": 134 | if type(answers)==list: 135 | for j in range(len(answers)): 136 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 137 | predict = predict.strip().replace("\n"," ").replace(" ","") 138 | if answer in predict: 139 | data[i]['result'] = 1 140 | else: 141 | answers = answers.strip().replace("\n"," ").replace(" ","") 142 | predict = predict.strip().replace("\n"," ").replace(" ","") 143 | if answers in predict: 144 | data[i]['result'] = 1 145 | else: 146 | if type(answers)==list: 147 | for j in range(len(answers)): 148 | answer = answers[j].lower().strip().replace("\n"," ") 149 | predict = predict.lower().strip().replace("\n"," ") 150 | if answer in predict: 151 | data[i]['result'] = 1 152 | else: 153 | answers = answers.lower().strip().replace("\n"," ") 154 | predict = predict.lower().strip().replace("\n"," ") 155 | if answers in predict: 156 | data[i]['result'] = 1 157 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 158 | if len(data)==1000: 159 | for i in range(len(data)): 160 | if data[i].get("result",100)==100: 161 | continue 162 | OCRBench_score[data[i]['type']] += data[i]['result'] 163 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 164 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 165 | print("###########################OCRBench##############################") 166 | print(f"Text Recognition(Total 300):{recognition_score}") 167 | print("------------------Details of Recognition Score-------------------") 168 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 169 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 170 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 171 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 172 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 173 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 174 | 
print("----------------------------------------------------------------") 175 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 176 | print("----------------------------------------------------------------") 177 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 178 | print("----------------------------------------------------------------") 179 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 180 | print("----------------------------------------------------------------") 181 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 182 | print("----------------------Final Score-------------------------------") 183 | print(f"Final Score(Total 1000): {Final_score}") 184 | else: 185 | for i in range(len(data)): 186 | num_all[data[i]['dataset_name']] += 1 187 | if data[i].get("result",100)==100: 188 | continue 189 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 190 | for key in AllDataset_score.keys(): 191 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 192 | -------------------------------------------------------------------------------- /OCRBench/scripts/minigpt4v2.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | import sys 13 | sys.path.append("./scripts/MiniGPT-4/") 14 | from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser 15 | from minigpt4.conversation.conversation import CONV_VISION_minigptv2 16 | from minigpt4.common.config import Config 17 | import random 18 | # https://github.com/Vision-CAIR/MiniGPT-4/blob/main/eval_scripts/eval_vqa.py 19 | 20 | 21 | def split_list(lst, n): 22 | length = len(lst) 23 | avg = length // n # 每份的大小 24 | result = [] # 存储分割后的子列表 25 | for i in range(n - 1): 26 | result.append(lst[i*avg:(i+1)*avg]) 27 | result.append(lst[(n-1)*avg:]) 28 | return result 29 | 30 | def save_json(json_list,save_path): 31 | with open(save_path, 'w') as file: 32 | json.dump(json_list, file,indent=4) 33 | 34 | def _get_args(): 35 | parser = ArgumentParser() 36 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 37 | parser.add_argument("--output_folder", type=str, default="./results") 38 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 39 | parser.add_argument("--cfg-path", default='./scripts/MiniGPT-4/eval_configs/minigptv2_eval.yaml') 40 | parser.add_argument("--save_name", type=str, default="minigptv2") 41 | parser.add_argument("--num_workers", type=int, default=1) 42 | parser.add_argument("--temperature", type=float, default=0.0) 43 | parser.add_argument( 44 | "--options", 45 | nargs="+", 46 | help="override some settings in the used config, the key-value pair " 47 | "in xxx=yyy format will be merged into config file (deprecate), " 48 | "change to --cfg-options instead.", 49 | ) 50 | args = parser.parse_args() 51 | return args 52 | 53 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 54 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric 
VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0, 55 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 56 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 57 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 58 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 59 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 60 | 61 | def eval_worker(args, data, eval_id, output_queue): 62 | 63 | print(f"Process {eval_id} start.") 64 | device = f'cuda:{eval_id}' 65 | cfg = Config(args) 66 | model, vis_processor = init_model(args, device) 67 | conv_temp = CONV_VISION_minigptv2.copy() 68 | conv_temp.system = "" 69 | model.eval() 70 | instruction_pool = [ 71 | "[vqa] {}" 72 | ] 73 | for i in tqdm(range(len(data))): 74 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 75 | qs = data[i]['question'] 76 | if data[i].get("predict", 0)!=0: 77 | print(f"{img_path} predict exist, continue.") 78 | continue 79 | image = Image.open(img_path).convert("RGB") 80 | image = vis_processor(image) 81 | image = image.unsqueeze(0).to(device) 82 | # question = self.text_processor(qs) 83 | instruction = random.choice(instruction_pool).format(qs) 84 | instruction = " {} ".format(instruction) 85 | texts = prepare_texts(instruction, conv_temp) # warp the texts with conversation template 86 | answers = model.generate(image, texts, max_new_tokens=100, do_sample=False) 87 | data[i]['predict'] = answers[0] 88 | output_queue.put({eval_id: data}) 89 | print(f"Process {eval_id} has completed.") 90 | 91 | if __name__=="__main__": 92 | multiprocessing.set_start_method('spawn') 93 | args = _get_args() 94 | 95 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 96 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 97 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 98 | else: 99 | data_path = args.OCRBench_file 100 | 101 | with open(data_path, "r") as f: 102 | data = json.load(f) 103 | 104 | data_list = split_list(data, args.num_workers) 105 | output_queue = Manager().Queue() 106 | 107 | pool = Pool(processes=args.num_workers) 108 | for i in range(len(data_list)): 109 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 110 | pool.close() 111 | pool.join() 112 | 113 | results = {} 114 | while not output_queue.empty(): 115 | result = output_queue.get() 116 | results.update(result) 117 | data = [] 118 | for i in range(len(data_list)): 119 | data.extend(results[i]) 120 | 121 | for i in range(len(data)): 122 | data_type = data[i]["type"] 123 | dataset_name = data[i]["dataset_name"] 124 | answers = data[i]["answers"] 125 | if data[i].get('predict',0)==0: 126 | continue 127 | predict = data[i]['predict'] 128 | data[i]['result'] = 0 129 | if dataset_name == "HME100k": 130 | if type(answers)==list: 131 | for j in range(len(answers)): 132 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 133 | predict = predict.strip().replace("\n"," ").replace(" ","") 134 | if answer in predict: 135 | data[i]['result'] = 1 136 | else: 137 | answers = answers.strip().replace("\n"," ").replace(" ","") 138 | predict = predict.strip().replace("\n"," ").replace(" ","") 139 | if answers in predict: 140 | data[i]['result'] = 1 141 | else: 142 | if type(answers)==list: 143 | for j in range(len(answers)): 144 | answer = answers[j].lower().strip().replace("\n"," ") 145 | predict = predict.lower().strip().replace("\n"," ") 146 | if answer in predict: 147 | data[i]['result'] = 1 148 | else: 149 | answers = answers.lower().strip().replace("\n"," ") 150 | predict = predict.lower().strip().replace("\n"," ") 151 | if answers in predict: 152 | data[i]['result'] = 1 153 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 154 | if len(data)==1000: 155 | for i in range(len(data)): 156 | if data[i].get("result",100)==100: 157 | continue 158 | OCRBench_score[data[i]['type']] += data[i]['result'] 159 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 160 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 161 | print("###########################OCRBench##############################") 162 | print(f"Text Recognition(Total 300):{recognition_score}") 163 | print("------------------Details of Recognition Score-------------------") 164 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 165 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 166 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 167 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 168 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 169 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 170 | 
print("----------------------------------------------------------------") 171 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 172 | print("----------------------------------------------------------------") 173 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 174 | print("----------------------------------------------------------------") 175 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 176 | print("----------------------------------------------------------------") 177 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 178 | print("----------------------Final Score-------------------------------") 179 | print(f"Final Score(Total 1000): {Final_score}") 180 | else: 181 | for i in range(len(data)): 182 | num_all[data[i]['dataset_name']] += 1 183 | if data[i].get("result",100)==100: 184 | continue 185 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 186 | for key in AllDataset_score.keys(): 187 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 188 | -------------------------------------------------------------------------------- /OCRBench/scripts/monkey.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | 12 | from transformers import AutoModelForCausalLM, AutoTokenizer 13 | 14 | # https://github.com/Yuliang-Liu/Monkey 15 | 16 | def split_list(lst, n): 17 | length = len(lst) 18 | avg = length // n # 每份的大小 19 | result = [] # 存储分割后的子列表 20 | for i in range(n - 1): 21 | result.append(lst[i*avg:(i+1)*avg]) 22 | result.append(lst[(n-1)*avg:]) 23 | return result 24 | 25 | def save_json(json_list,save_path): 26 | with open(save_path, 'w') as file: 27 | json.dump(json_list, file,indent=4) 28 | 29 | def _get_args(): 30 | parser = ArgumentParser() 31 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 32 | parser.add_argument("--output_folder", type=str, default="./results") 33 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 34 | parser.add_argument("--model_path", type=str, default="echo840/Monkey") 35 | parser.add_argument("--save_name", type=str, default="monkey") 36 | parser.add_argument("--num_workers", type=int, default=8) 37 | args = parser.parse_args() 38 | return args 39 | 40 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 41 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0, 42 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 43 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 44 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 45 | num_all = 
{"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 46 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 47 | 48 | def eval_worker(args, data, eval_id, output_queue): 49 | print(f"Process {eval_id} start.") 50 | checkpoint = args.model_path 51 | model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=f'cuda:{eval_id}', trust_remote_code=True).eval() 52 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) 53 | tokenizer.padding_side = 'left' 54 | tokenizer.pad_token_id = tokenizer.eod_id 55 | 56 | for i in tqdm(range(len(data))): 57 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 58 | qs = data[i]['question'] 59 | query = f'{img_path} {qs} Answer: ' 60 | 61 | input_ids = tokenizer(query, return_tensors='pt', padding='longest') 62 | attention_mask = input_ids.attention_mask 63 | input_ids = input_ids.input_ids 64 | 65 | pred = model.generate( 66 | input_ids=input_ids.to(f'cuda:{eval_id}'), 67 | attention_mask=attention_mask.to(f'cuda:{eval_id}'), 68 | do_sample=False, 69 | num_beams=1, 70 | max_new_tokens=100, 71 | min_new_tokens=1, 72 | length_penalty=1, 73 | num_return_sequences=1, 74 | output_hidden_states=True, 75 | use_cache=True, 76 | pad_token_id=tokenizer.eod_id, 77 | eos_token_id=tokenizer.eod_id, 78 | ) 79 | response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip() 80 | data[i]['predict'] = response 81 | output_queue.put({eval_id: data}) 82 | print(f"Process {eval_id} has completed.") 83 | 84 | if __name__=="__main__": 85 | multiprocessing.set_start_method('spawn') 86 | args = _get_args() 87 | 88 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 89 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 90 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 91 | else: 92 | data_path = args.OCRBench_file 93 | 94 | with open(data_path, "r") as f: 95 | data = json.load(f) 96 | 97 | data_list = split_list(data, args.num_workers) 98 | 99 | output_queue = Manager().Queue() 100 | 101 | pool = Pool(processes=args.num_workers) 102 | for i in range(len(data_list)): 103 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 104 | pool.close() 105 | pool.join() 106 | 107 | 108 | results = {} 109 | while not output_queue.empty(): 110 | result = output_queue.get() 111 | results.update(result) 112 | data = [] 113 | for i in range(len(data_list)): 114 | data.extend(results[i]) 115 | 116 | for i in range(len(data)): 117 | data_type = data[i]["type"] 118 | dataset_name = data[i]["dataset_name"] 119 | answers = data[i]["answers"] 120 | if data[i].get('predict',0)==0: 121 | continue 122 | predict = data[i]['predict'] 123 | data[i]['result'] = 0 124 | if dataset_name == "HME100k": 125 | if type(answers)==list: 126 | for j in range(len(answers)): 127 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 128 | predict = predict.strip().replace("\n"," ").replace(" ","") 129 | if answer in predict: 130 | data[i]['result'] = 1 131 | else: 132 | answers = answers.strip().replace("\n"," ").replace(" ","") 133 | predict = predict.strip().replace("\n"," ").replace(" ","") 134 | if answers in predict: 135 | data[i]['result'] = 1 136 | else: 137 | if type(answers)==list: 138 | for j in range(len(answers)): 139 | answer = answers[j].lower().strip().replace("\n"," ") 140 | predict = predict.lower().strip().replace("\n"," ") 141 | if answer in predict: 142 | data[i]['result'] = 1 143 | else: 144 | answers = answers.lower().strip().replace("\n"," ") 145 | predict = predict.lower().strip().replace("\n"," ") 146 | if answers in predict: 147 | data[i]['result'] = 1 148 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 149 | if len(data)==1000: 150 | for i in range(len(data)): 151 | if data[i].get("result",100)==100: 152 | continue 153 | OCRBench_score[data[i]['type']] += data[i]['result'] 154 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 155 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 156 | print("###########################OCRBench##############################") 157 | print(f"Text Recognition(Total 300):{recognition_score}") 158 | print("------------------Details of Recognition Score-------------------") 159 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 160 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 161 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 162 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 163 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 164 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 165 | 
print("----------------------------------------------------------------") 166 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 167 | print("----------------------------------------------------------------") 168 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 169 | print("----------------------------------------------------------------") 170 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 171 | print("----------------------------------------------------------------") 172 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 173 | print("----------------------Final Score-------------------------------") 174 | print(f"Final Score(Total 1000): {Final_score}") 175 | else: 176 | for i in range(len(data)): 177 | num_all[data[i]['dataset_name']] += 1 178 | if data[i].get("result",100)==100: 179 | continue 180 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 181 | for key in AllDataset_score.keys(): 182 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 183 | -------------------------------------------------------------------------------- /OCRBench/scripts/qwenvl.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import math 9 | import multiprocessing 10 | from multiprocessing import Pool, Queue, Manager 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | 13 | # https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/evaluate_vqa.py 14 | def split_list(lst, n): 15 | length = len(lst) 16 | avg = length // n # 每份的大小 17 | result = [] # 存储分割后的子列表 18 | for i in range(n - 1): 19 | result.append(lst[i*avg:(i+1)*avg]) 20 | result.append(lst[(n-1)*avg:]) 21 | return result 22 | 23 | def save_json(json_list,save_path): 24 | with open(save_path, 'w') as file: 25 | json.dump(json_list, file,indent=4) 26 | 27 | def _get_args(): 28 | parser = ArgumentParser() 29 | parser.add_argument("--image_folder", type=str, default="./OCRBench_Images") 30 | parser.add_argument("--output_folder", type=str, default="./results") 31 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 32 | parser.add_argument("--model_path", type=str, default="Qwen/Qwen-VL") 33 | parser.add_argument("--save_name", type=str, default="qwenvl") 34 | parser.add_argument("--num_workers", type=int, default=1) 35 | args = parser.parse_args() 36 | return args 37 | 38 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 39 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0, 40 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 41 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 42 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 43 | num_all = 
{"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 44 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 45 | 46 | def eval_worker(args, data, eval_id, output_queue): 47 | print(f"Process {eval_id} start.") 48 | checkpoint = args.model_path 49 | model = AutoModelForCausalLM.from_pretrained( 50 | checkpoint, device_map=f'cuda:{eval_id}', trust_remote_code=True).eval() 51 | 52 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, 53 | trust_remote_code=True) 54 | tokenizer.padding_side = 'left' 55 | tokenizer.pad_token_id = tokenizer.eod_id 56 | 57 | for i in tqdm(range(len(data))): 58 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 59 | qs = data[i]['question'] 60 | # query = f'{img_path} {qs} Answer: ' 61 | query = f'{img_path}{qs} Answer:' 62 | input_ids = tokenizer(query, return_tensors='pt', padding='longest') 63 | attention_mask = input_ids.attention_mask 64 | input_ids = input_ids.input_ids 65 | 66 | pred = model.generate( 67 | input_ids=input_ids.to(f'cuda:{eval_id}'), 68 | attention_mask=attention_mask.to(f'cuda:{eval_id}'), 69 | do_sample=False, 70 | num_beams=1, 71 | max_new_tokens=100, 72 | min_new_tokens=1, 73 | length_penalty=1, 74 | num_return_sequences=1, 75 | output_hidden_states=True, 76 | use_cache=True, 77 | pad_token_id=tokenizer.eod_id, 78 | eos_token_id=tokenizer.eod_id, 79 | ) 80 | response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip() 81 | data[i]['predict'] = response 82 | output_queue.put({eval_id: data}) 83 | print(f"Process {eval_id} has completed.") 84 | 85 | if __name__=="__main__": 86 | multiprocessing.set_start_method('spawn') 87 | args = _get_args() 88 | if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")): 89 | data_path = os.path.join(args.output_folder,f"{args.save_name}.json") 90 | print(f"output_path:{data_path} exist! 
Only generate the results that were not generated in {data_path}.") 91 | else: 92 | data_path = args.OCRBench_file 93 | 94 | with open(data_path, "r") as f: 95 | data = json.load(f) 96 | 97 | data_list = split_list(data, args.num_workers) 98 | 99 | output_queue = Manager().Queue() 100 | 101 | pool = Pool(processes=args.num_workers) 102 | for i in range(len(data_list)): 103 | pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue)) 104 | pool.close() 105 | pool.join() 106 | 107 | results = {} 108 | while not output_queue.empty(): 109 | result = output_queue.get() 110 | results.update(result) 111 | data = [] 112 | for i in range(len(data_list)): 113 | data.extend(results[i]) 114 | 115 | for i in range(len(data)): 116 | data_type = data[i]["type"] 117 | dataset_name = data[i]["dataset_name"] 118 | answers = data[i]["answers"] 119 | if data[i].get('predict',0)==0: 120 | continue 121 | predict = data[i]['predict'] 122 | data[i]['result'] = 0 123 | if dataset_name == "HME100k": 124 | if type(answers)==list: 125 | for j in range(len(answers)): 126 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 127 | predict = predict.strip().replace("\n"," ").replace(" ","") 128 | if answer in predict: 129 | data[i]['result'] = 1 130 | else: 131 | answers = answers.strip().replace("\n"," ").replace(" ","") 132 | predict = predict.strip().replace("\n"," ").replace(" ","") 133 | if answers in predict: 134 | data[i]['result'] = 1 135 | else: 136 | if type(answers)==list: 137 | for j in range(len(answers)): 138 | answer = answers[j].lower().strip().replace("\n"," ") 139 | predict = predict.lower().strip().replace("\n"," ") 140 | if answer in predict: 141 | data[i]['result'] = 1 142 | else: 143 | answers = answers.lower().strip().replace("\n"," ") 144 | predict = predict.lower().strip().replace("\n"," ") 145 | if answers in predict: 146 | data[i]['result'] = 1 147 | save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json")) 148 | if len(data)==1000: 149 | for i in range(len(data)): 150 | if data[i].get("result",100)==100: 151 | continue 152 | OCRBench_score[data[i]['type']] += data[i]['result'] 153 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 154 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition'] 155 | print("###########################OCRBench##############################") 156 | print(f"Text Recognition(Total 300):{recognition_score}") 157 | print("------------------Details of Recognition Score-------------------") 158 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 159 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 160 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 161 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 162 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 163 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 164 | 
print("----------------------------------------------------------------") 165 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 166 | print("----------------------------------------------------------------") 167 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 168 | print("----------------------------------------------------------------") 169 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 170 | print("----------------------------------------------------------------") 171 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 172 | print("----------------------Final Score-------------------------------") 173 | print(f"Final Score(Total 1000): {Final_score}") 174 | else: 175 | for i in range(len(data)): 176 | num_all[data[i]['dataset_name']] += 1 177 | if data[i].get("result",100)==100: 178 | continue 179 | AllDataset_score[data[i]['dataset_name']] += data[i]['result'] 180 | for key in AllDataset_score.keys(): 181 | print(f"{key}: {AllDataset_score[key]/float(num_all[key])}") 182 | -------------------------------------------------------------------------------- /OCRBench/scripts/qwenvl_api.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from argparse import ArgumentParser 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | import sys 7 | from http import HTTPStatus 8 | from dashscope import MultiModalConversation 9 | import time 10 | # You should follow the instructions here befor strat: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start 11 | OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0, 12 | "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0, 13 | "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0} 14 | AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 15 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 16 | num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0, 17 | "STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0} 18 | def save_json(json_list,save_path): 19 | with open(save_path, 'w') as file: 20 | json.dump(json_list, file,indent=4) 21 | 22 | def call_with_local_file(img_path, question, model_name): 23 | """Sample of use local file. 24 | linux&mac file schema: file:///home/images/test.png 25 | windows file schema: file://D:/images/abc.png 26 | """ 27 | local_file_path1 = f'file://{img_path}' 28 | messages = [{ 29 | 'role': 'system', 30 | 'content': [{ 31 | 'text': 'You are a helpful assistant.' 
32 | }] 33 | }, { 34 | 'role': 35 | 'user', 36 | 'content': [ 37 | { 38 | 'image': local_file_path1 39 | }, 40 | { 41 | 'text': question 42 | }, 43 | ] 44 | }] 45 | response = MultiModalConversation.call(model=model_name, messages=messages) 46 | # time.sleep(2) #For qwenvl-max you may need to add this line to avoid the limits. 47 | print(response) 48 | return response['output']['choices'][0]["message"]['content'][0]['text'] 49 | 50 | 51 | def _get_args(): 52 | parser = ArgumentParser() 53 | parser.add_argument("--image_folder", type=str, default="./data") 54 | parser.add_argument("--output_path", type=str, default="./results") 55 | parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json") 56 | parser.add_argument("--model", type=str, default="qwen-vl-max") 57 | args = parser.parse_args() 58 | return args 59 | 60 | 61 | if __name__ == "__main__": 62 | args = _get_args() 63 | if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")): 64 | data_path = os.path.join(args.output_path,f"{args.model}.json") 65 | else: 66 | data_path = args.OCRBench_file 67 | with open(data_path, "r") as f: 68 | data = json.load(f) 69 | for i in tqdm(range(len(data))): 70 | img_path = os.path.join(args.image_folder, data[i]['image_path']) 71 | question = data[i]['question'] 72 | if data[i].get("predict", 0)!=0: 73 | print(f"{img_path} predict exist, continue.") 74 | continue 75 | try: 76 | response = call_with_local_file(img_path, question, args.model) 77 | data[i]['predict'] = response 78 | except: 79 | print("QwenVL api failed") 80 | save_json(data, os.path.join(args.output_path,f"{args.model}.json")) 81 | for i in range(len(data)): 82 | data_type = data[i]["type"] 83 | dataset_name = data[i]["dataset_name"] 84 | answers = data[i]["answers"] 85 | if data[i].get('predict',0)==0: 86 | continue 87 | predict = data[i]['predict'] 88 | data[i]['result'] = 0 89 | if dataset_name == "HME100k": 90 | if type(answers)==list: 91 | for j in range(len(answers)): 92 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 93 | predict = predict.strip().replace("\n"," ").replace(" ","") 94 | if answer in predict: 95 | data[i]['result'] = 1 96 | else: 97 | answers = answers.strip().replace("\n"," ").replace(" ","") 98 | predict = predict.strip().replace("\n"," ").replace(" ","") 99 | if answers in predict: 100 | data[i]['result'] = 1 101 | else: 102 | if type(answers)==list: 103 | for j in range(len(answers)): 104 | answer = answers[j].lower().strip().replace("\n"," ") 105 | predict = predict.lower().strip().replace("\n"," ") 106 | if answer in predict: 107 | data[i]['result'] = 1 108 | else: 109 | answers = answers.lower().strip().replace("\n"," ") 110 | predict = predict.lower().strip().replace("\n"," ") 111 | if answers in predict: 112 | data[i]['result'] = 1 113 | save_json(data, os.path.join(args.output_path,f"{args.model}.json")) 114 | for i in range(len(data)): 115 | if data[i].get("result",100)==100: 116 | continue 117 | OCRBench_score[data[i]['type']] += data[i]['result'] 118 | recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition'] 119 | Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression 
Recognition'] 120 | print("###########################OCRBench##############################") 121 | print(f"Text Recognition(Total 300):{recognition_score}") 122 | print("------------------Details of Recognition Score-------------------") 123 | print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}") 124 | print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}") 125 | print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}") 126 | print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}") 127 | print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}") 128 | print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}") 129 | print("----------------------------------------------------------------") 130 | print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}") 131 | print("----------------------------------------------------------------") 132 | print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}") 133 | print("----------------------------------------------------------------") 134 | print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}") 135 | print("----------------------------------------------------------------") 136 | print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}") 137 | print("----------------------Final Score-------------------------------") 138 | print(f"Final Score(Total 1000): {Final_score}") 139 | -------------------------------------------------------------------------------- /OCRBench_v2/README.md: -------------------------------------------------------------------------------- 1 | # OCRBench v2: An Improved Benchmark for Evaluating Large Multimodal Models on Visual Text Localization and Reasoning 2 | 3 | > Scoring the Optical Character Recognition (OCR) capabilities of Large Multimodal Models (LMMs) has witnessed growing interest recently. Existing benchmarks have highlighted the impressive performance of LMMs in text recognition; however, their abilities in certain challenging tasks, such as text localization, handwritten content extraction, and logical reasoning, remain underexplored. To bridge this gap, we introduce OCRBench v2, a large-scale bilingual text-centric benchmark with currently the most comprehensive set of tasks (4X more tasks than the previous multi-scene benchmark OCRBench), the widest coverage of scenarios (31 diverse scenarios including street scene, receipt, formula, diagram, and so on), and thorough evaluation metrics, with a total of 10,000 human-verified question-answering pairs and a high proportion of difficult samples. After carefully benchmarking state-of-the-art LMMs on OCRBench v2, we find that 36 out of 38 LMMs score below 50 (100 in total) and suffer from five-type limitations, including less frequently encountered text recognition, fine-grained perception, layout perception, complex element parsing, and logical reasoning. 4 | 5 | **[Project Page](https://github.com/Yuliang-Liu/MultimodalOCR)** | **[Paper](https://arxiv.org/abs/2501.00321)** | **[OCRBench v2 Leaderboard](https://huggingface.co/spaces/ling99/OCRBench-v2-leaderboard)** 6 | 7 |

8 | 9 |

10 | 11 | # Data 12 | You can download OCRBench v2 from [Google Drive](https://drive.google.com/file/d/1Hk1TMu--7nr5vJ7iaNwMQZ_Iw9W_KI3C/view?usp=sharing). 13 | After downloading and extracting the dataset, the directory structure is as follows: 14 | ``` 15 | OCRBench_v2/ 16 | ├── EN_part/ 17 | ├── CN_part/ 18 | ├── OCRBench_v2.json 19 | ``` 20 | # Evaluation 21 | 22 | ## Environment 23 | All Python dependencies required for the evaluation process are specified in **requirements.txt**. 24 | To set up the environment, run the following commands in the project directory: 25 | ```bash 26 | conda create -n ocrbench_v2 python==3.10 -y 27 | conda activate ocrbench_v2 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Inference 32 | To evaluate a model's performance on OCRBench v2, save each model prediction in the `predict` field of the corresponding entry in the JSON file. 33 | 
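For reference, the snippet below is a minimal, illustrative sketch of how such a prediction file could be produced. It is not part of the official toolkit: `run_model` is a placeholder for your own LMM inference call, the output name `my_model.json` is arbitrary, and the paths assume the dataset was extracted to `./OCRBench_v2`. The expected entry format is shown in the example that follows.

```python
# Minimal sketch: fill the "predict" field of each OCRBench v2 entry and save
# the result so that ./eval_scripts/eval.py can score it.
import json
import os

def run_model(image_path, question):
    # Placeholder for your model's inference call (assumption, not provided here).
    return "model answer"

# OCRBench_v2.json ships with the dataset; every entry already carries
# "image_path", "question" and "answers".
with open("./OCRBench_v2/OCRBench_v2.json", "r") as f:
    samples = json.load(f)

for sample in samples:
    # Assumption: image_path is relative to the extracted dataset folder.
    image_path = os.path.join("./OCRBench_v2", sample["image_path"])
    sample["predict"] = run_model(image_path, sample["question"])

os.makedirs("./pred_folder", exist_ok=True)
with open("./pred_folder/my_model.json", "w") as f:
    json.dump(samples, f, ensure_ascii=False, indent=4)
```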
34 | Example structure of the JSON file: 35 | 36 | ```json 37 | [ 38 | { 39 | "dataset_name": "xx", 40 | "type": "xx", 41 | "id": 0, 42 | "image_path": "xx", 43 | "question": "xx", 44 | "answers": [ 45 | "xx" 46 | ], 47 | "predict": "xx" 48 | }, 49 | ... 50 | ] 51 | ``` 52 | 53 | ## Evaluation Scripts 54 | After obtaining the inference results from the model, you can use the following scripts to calculate the final score for OCRBench v2. For example, `./pred_folder/internvl2_5_26b.json` contains sample inference results generated by InternVL2.5-26B using [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). To compute the score for each sample, you can use the script `./eval_scripts/eval.py`. The results will be saved in `./res_folder`. 55 | 56 | ```bash 57 | python ./eval_scripts/eval.py --input_path ./pred_folder/internvl2_5_26b.json --output_path ./res_folder/internvl2_5_26b.json 58 | ``` 59 | 60 | Once the scores for all samples have been calculated, you can use the script `./eval_scripts/get_score.py` to compute the overall metrics for OCRBench v2. 61 | 62 | ```bash 63 | python ./eval_scripts/get_score.py --json_file ./res_folder/internvl2_5_26b.json 64 | ``` 65 | 66 | # Leaderboard 67 | 68 | ## Performance of LMMs on English subsets 69 | 70 | 

71 | 72 |

73 | 74 | ## Performance of LMMs on Chinese subsets 75 | 76 |

77 | 78 |

79 | 80 | # Copyright Statement 81 | The data are collected from public datasets and community user contributions. This dataset is for research purposes only and not for commercial use. If you have any copyright concerns, please contact ling_fu@hust.edu.cn. 82 | 83 | # Citation 84 | ```BibTeX 85 | @misc{fu2024ocrbenchv2improvedbenchmark, 86 | title={OCRBench v2: An Improved Benchmark for Evaluating Large Multimodal Models on Visual Text Localization and Reasoning}, 87 | author={Ling Fu and Biao Yang and Zhebin Kuang and Jiajun Song and Yuzhe Li and Linghao Zhu and Qidi Luo and Xinyu Wang and Hao Lu and Mingxin Huang and Zhang Li and Guozhi Tang and Bin Shan and Chunhui Lin and Qi Liu and Binghong Wu and Hao Feng and Hao Liu and Can Huang and Jingqun Tang and Wei Chen and Lianwen Jin and Yuliang Liu and Xiang Bai}, 88 | year={2024}, 89 | eprint={2501.00321}, 90 | archivePrefix={arXiv}, 91 | primaryClass={cs.CV}, 92 | url={https://arxiv.org/abs/2501.00321}, 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/IoUscore_metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import ast 4 | import ipdb 5 | from vqa_metric import vqa_evaluation 6 | 7 | 8 | def calculate_iou(box1, box2): 9 | 10 | try: 11 | box1 = [int(coordinate) for coordinate in box1] 12 | box2 = [int(coordinate) for coordinate in box2] 13 | except: 14 | return 0 15 | 16 | x1_inter = max(box1[0], box2[0]) 17 | y1_inter = max(box1[1], box2[1]) 18 | x2_inter = min(box1[2], box2[2]) 19 | y2_inter = min(box1[3], box2[3]) 20 | 21 | inter_area = max(0, x2_inter - x1_inter) * max(0, y2_inter - y1_inter) 22 | 23 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) 24 | box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) 25 | 26 | union_area = box1_area + box2_area - inter_area 27 | 28 | iou = inter_area / union_area if union_area != 0 else 0 29 | 30 | return iou 31 | 32 | 33 | def vqa_with_position_evaluation(predict, img_metas): 34 | 35 | score_content, score_bbox = .0, .0 36 | if "answer" in predict.keys(): 37 | score_content = vqa_evaluation(predict["answer"], img_metas["answers"]) 38 | if "bbox" in predict.keys(): 39 | gt_bbox = img_metas["bbox"] 40 | try: 41 | predict_bbox_list = ast.literal_eval(predict["bbox"]) 42 | score_bbox = calculate_iou(predict_bbox_list, gt_bbox) 43 | except: 44 | score_bbox = 0 45 | return 0.5 * score_content + 0.5 * score_bbox 46 | 47 | 48 | def extract_coordinates(text): 49 | # Regex pattern to match coordinates in either (x1, y1, x2, y2) or [x1, y1, x2, y2] format 50 | 51 | pattern = r'[\(\[]\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*[\)\]]' 52 | 53 | matches = list(re.finditer(pattern, text)) 54 | coords_list = [] 55 | coords_set = set() 56 | for match in matches: 57 | 58 | x1, y1, x2, y2 = map(int, match.groups()) 59 | 60 | if all(0 <= n <= 1000 for n in [x1, y1, x2, y2]): 61 | coords = (x1, y1, x2, y2) 62 | 63 | if coords in coords_set: 64 | coords_list = [c for c in coords_list if c != coords] 65 | 66 | coords_list.append(coords) 67 | coords_set.add(coords) 68 | if coords_list: 69 | last_coords = coords_list[-1] 70 | return list(last_coords) 71 | else: 72 | return None 73 | 74 | 75 | if __name__ == "__main__": 76 | 77 | print("Example for Text Grounding task.") 78 | box1 = [50, 50, 150, 150] 79 | box2 = [60, 60, 140, 140] 80 | iou_score = calculate_iou(box1, box2) 81 | print(f"IoU score: {iou_score}") 82 | 83 | print("Example for VQA with position 
task.") 84 | pred = {"content": "The content is Hello Buddies", "bbox": box1} 85 | gt = {"content": "Hello Buddies", "bbox": box2} 86 | 87 | vqa_score = vqa_evaluation(pred["content"], gt["content"]) 88 | iou_score = calculate_iou(pred["bbox"], gt["bbox"]) 89 | 90 | print(f"VQA score: {vqa_score}") 91 | print(f"IoU score: {iou_score}") 92 | -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/__pycache__/IoUscore_metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/__pycache__/IoUscore_metric.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/__pycache__/TEDS_metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/__pycache__/TEDS_metric.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/__pycache__/page_ocr_metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/__pycache__/page_ocr_metric.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/__pycache__/parallel.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/__pycache__/parallel.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/__pycache__/spotting_metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/__pycache__/spotting_metric.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/__pycache__/vqa_metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/__pycache__/vqa_metric.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/get_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import ipdb 4 | import argparse 5 | 6 | 7 | def calculate_average(scores_dict): 8 | averages = {key: sum(values) / len(values) for key, values in scores_dict.items() if len(values) > 0} 9 | return averages 10 | 11 | 12 | def main(): 13 | # Set up argument parser 14 | parser = argparse.ArgumentParser(description="Process a JSON file to calculate scores.") 15 | parser.add_argument("--json_file", type=str, required=True, help="Path to the JSON file containing inference data.") 16 | args = parser.parse_args() 17 | 18 | # Load data from JSON file 19 | inference_file = args.json_file 20 | 
if not os.path.exists(inference_file): 21 | print(f"Error: File '{inference_file}' does not exist.") 22 | return 23 | 24 | with open(inference_file, "r") as f: 25 | data_list = json.load(f) 26 | 27 | en_text_recognition_list, en_text_detection_list, en_text_spotting_list, en_relationship_extraction_list = [], [], [], [] 28 | en_element_parsing_list, en_mathematical_calculation_list, en_visual_text_understanding_list = [], [], [] 29 | en_knowledge_reasoning_list = [] 30 | 31 | cn_text_recognition_list, cn_relationship_extraction_list = [], [] 32 | cn_element_parsing_list, cn_visual_text_understanding_list = [], [] 33 | cn_knowledge_reasoning_list = [] 34 | 35 | res_list = [] 36 | for item in data_list: 37 | if "ignore" in item.keys(): 38 | assert item["ignore"] == "True" 39 | 40 | elif item["type"] == "text recognition en" or item["type"] == "fine-grained text recognition en" or item["type"] == "full-page OCR en": 41 | en_text_recognition_list.append(item["score"]) 42 | 43 | elif item["type"] == "text grounding en" or item["type"] == "VQA with position en": 44 | en_text_detection_list.append(item["score"]) 45 | 46 | elif item["type"] == "text spotting en": 47 | en_text_spotting_list.append(item["score"]) 48 | 49 | elif item["type"] == "key information extraction en" or item["type"] == "key information mapping en": 50 | en_relationship_extraction_list.append(item["score"]) 51 | 52 | elif item["type"] == "document parsing en" or item["type"] == "chart parsing en" \ 53 | or item["type"] == "table parsing en" or item["type"] == "formula recognition en": 54 | en_element_parsing_list.append(item["score"]) 55 | 56 | elif item["type"] == "math QA en" or item["type"] == "text counting en": 57 | en_mathematical_calculation_list.append(item["score"]) 58 | 59 | elif item["type"] == "document classification en" \ 60 | or item["type"] == "cognition VQA en" or item["type"] == "diagram QA en": 61 | en_visual_text_understanding_list.append(item["score"]) 62 | 63 | elif item["type"] == "reasoning VQA en" or item["type"] == "science QA en" \ 64 | or item["type"] == "APP agent en" or item["type"] == "ASCII art classification en": 65 | en_knowledge_reasoning_list.append(item["score"]) 66 | 67 | elif item["type"] == "full-page OCR cn": 68 | cn_text_recognition_list.append(item["score"]) 69 | 70 | elif item["type"] == "key information extraction cn" or item["type"] == "handwritten answer extraction cn": 71 | cn_relationship_extraction_list.append(item["score"]) 72 | 73 | elif item["type"] == "document parsing cn" or item["type"] == "table parsing cn" or item["type"] == "formula recognition cn": 74 | cn_element_parsing_list.append(item["score"]) 75 | 76 | elif item["type"] == "cognition VQA cn": 77 | cn_visual_text_understanding_list.append(item["score"]) 78 | 79 | elif item["type"] == "reasoning VQA cn" or item["type"] == "text translation cn": 80 | cn_knowledge_reasoning_list.append(item["score"]) 81 | 82 | else: 83 | raise ValueError("Unknown task type!") 84 | 85 | en_scores = { 86 | "text_recognition": en_text_recognition_list, 87 | "text_detection": en_text_detection_list, 88 | "text_spotting": en_text_spotting_list, 89 | "relationship_extraction": en_relationship_extraction_list, 90 | "element_parsing": en_element_parsing_list, 91 | "mathematical_calculation": en_mathematical_calculation_list, 92 | "visual_text_understanding": en_visual_text_understanding_list, 93 | "knowledge_reasoning": en_knowledge_reasoning_list 94 | } 95 | 96 | cn_scores = { 97 | "text_recognition": cn_text_recognition_list, 98 | 
"relationship_extraction": cn_relationship_extraction_list, 99 | "element_parsing": cn_element_parsing_list, 100 | "visual_text_understanding": cn_visual_text_understanding_list, 101 | "knowledge_reasoning": cn_knowledge_reasoning_list 102 | } 103 | 104 | en_averages = calculate_average(en_scores) 105 | cn_averages = calculate_average(cn_scores) 106 | 107 | print("English Scores:") 108 | for key, score in en_averages.items(): 109 | print(f"{key}: {score:.3f} (Count: {len(en_scores[key])})") 110 | 111 | print("\nChinese Scores:") 112 | for key, score in cn_averages.items(): 113 | print(f"{key}: {score:.3f} (Count: {len(cn_scores[key])})") 114 | 115 | score_en_overall = sum(en_averages.values()) / len(en_averages) 116 | score_cn_overall = sum(cn_averages.values()) / len(cn_averages) 117 | 118 | print("\nOverall Scores:") 119 | print(f"English Overall Score: {score_en_overall:.3f}") 120 | print(f"Chinese Overall Score: {score_cn_overall:.3f}") 121 | 122 | print("End of Code!") 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/page_ocr_metric.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import nltk 4 | from nltk.metrics import precision, recall, f_measure 5 | import numpy as np 6 | import jieba 7 | import re 8 | from nltk.translate import meteor_score 9 | 10 | 11 | def contain_chinese_string(text): 12 | chinese_pattern = re.compile(r'[\u4e00-\u9fa5]') 13 | return bool(chinese_pattern.search(text)) 14 | 15 | def cal_per_metrics(pred, gt): 16 | metrics = {} 17 | 18 | if contain_chinese_string(gt) or contain_chinese_string(pred): 19 | reference = jieba.lcut(gt) 20 | hypothesis = jieba.lcut(pred) 21 | else: 22 | reference = gt.split() 23 | hypothesis = pred.split() 24 | 25 | metrics["bleu"] = nltk.translate.bleu([reference], hypothesis) 26 | metrics["meteor"] = meteor_score.meteor_score([reference], hypothesis) 27 | 28 | reference = set(reference) 29 | hypothesis = set(hypothesis) 30 | metrics["f_measure"] = f_measure(reference, hypothesis) 31 | 32 | metrics["precision"] = precision(reference, hypothesis) 33 | metrics["recall"] = recall(reference, hypothesis) 34 | metrics["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt)) 35 | return metrics 36 | 37 | 38 | if __name__ == "__main__": 39 | 40 | # Examples for region text recognition and read all text tasks 41 | predict_text = "metrics['edit_dist'] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))" 42 | true_text = "metrics = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))" 43 | 44 | scores = cal_per_metrics(predict_text, true_text) 45 | 46 | predict_text = "metrics['edit_dist'] len(gt))" 47 | true_text = "metrics = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))" 48 | 49 | scores = cal_per_metrics(predict_text, true_text) 50 | print(scores) 51 | -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/parallel.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from concurrent.futures import ProcessPoolExecutor, as_completed 3 | 4 | def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0): 5 | """ 6 | A parallel version of the map function with a progress bar. 7 | 8 | Args: 9 | array (array-like): An array to iterate over. 
10 | function (function): A python function to apply to the elements of array 11 | n_jobs (int, default=16): The number of cores to use 12 | use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of 13 | keyword arguments to function 14 | front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. 15 | Useful for catching bugs 16 | Returns: 17 | [function(array[0]), function(array[1]), ...] 18 | """ 19 | # We run the first few iterations serially to catch bugs 20 | if front_num > 0: 21 | front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]] 22 | else: 23 | front = [] 24 | # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. 25 | if n_jobs == 1: 26 | return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] 27 | # Assemble the workers 28 | with ProcessPoolExecutor(max_workers=n_jobs) as pool: 29 | # Pass the elements of array into function 30 | if use_kwargs: 31 | futures = [pool.submit(function, **a) for a in array[front_num:]] 32 | else: 33 | futures = [pool.submit(function, a) for a in array[front_num:]] 34 | kwargs = { 35 | 'total': len(futures), 36 | 'unit': 'it', 37 | 'unit_scale': True, 38 | 'leave': True 39 | } 40 | # Print out the progress as tasks complete 41 | for f in tqdm(as_completed(futures), **kwargs): 42 | pass 43 | out = [] 44 | # Get the results from the futures. 45 | for i, future in tqdm(enumerate(futures)): 46 | try: 47 | out.append(future.result()) 48 | except Exception as e: 49 | out.append(e) 50 | return front + out -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/__init__.py -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/__pycache__/rrc_evaluation_funcs_1_1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/__pycache__/rrc_evaluation_funcs_1_1.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/__pycache__/rrc_evaluation_funcs_1_1.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/__pycache__/rrc_evaluation_funcs_1_1.cpython-39.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/__pycache__/script.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/__pycache__/script.cpython-310.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/__pycache__/script.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/__pycache__/script.cpython-39.pyc -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/gt.zip -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/gt/gt_img_0.txt: -------------------------------------------------------------------------------- 1 | 442,380,507,380,507,399,442,399,CHEROKEE 2 | 506,380,547,380,547,397,506,397,STREET 3 | 481,399,536,399,536,417,481,417,BIKES 4 | 443,425,469,425,469,438,443,438,### 5 | 471,425,505,425,505,438,471,438,### 6 | 513,425,543,425,543,439,513,439,### -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/readme.txt: -------------------------------------------------------------------------------- 1 | INSTRUCTIONS FOR THE STANDALONE SCRIPTS 2 | Requirements: 3 | - Python version 3. 4 | - Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions. 5 | 6 | Procedure: 7 | Download the ZIP file for the requested script and unzip it to a directory. 8 | 9 | Open a terminal in the directory and run the command: 10 | python script.py -g=gt.zip -s=submit.zip 11 | 12 | If you have already installed all the required modules, you will see the method's results or an error message if the submitted file is not correct. 13 | 14 | If a module is not present, install it with pip: pip install 'module' 15 | 16 | In the case of the Polygon module, use: 'pip install Polygon3' 17 | 18 | parameters: 19 | -g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', 'gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task. 20 | -s: Path of your method's results file. 21 | 22 | Optional parameters: 23 | -o: Path to a directory where the file 'results.zip' that contains per-sample results will be copied. 24 | -p: JSON string parameters to override the script default parameters. The parameters that can be overridden are inside the function 'default_evaluation_params' located at the beginning of the evaluation script. 
25 | 26 | Example: python script.py -g=gt.zip -s=submit.zip -o=./ -p={\"IOU_CONSTRAINT\":0.8} -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/results.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/results.zip -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/script_test_ch4_t4_e1-1577983164.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/script_test_ch4_t4_e1-1577983164.zip -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/submit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yuliang-Liu/MultimodalOCR/b5ecad3e3408dd924497d9329ff4b0b8295dfe15/OCRBench_v2/eval_scripts/spotting_eval/submit.zip -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_eval/submit/res_img_0.txt: -------------------------------------------------------------------------------- 1 | 0,0,1000,0,1000,1000,0,1000,CHEROKEE STREET BIKES -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/spotting_metric.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import ast 4 | import ipdb 5 | import shutil 6 | import zipfile 7 | import subprocess 8 | import spotting_eval.rrc_evaluation_funcs_1_1 as rrc_evaluation_funcs 9 | from spotting_eval.script import default_evaluation_params,validate_data,evaluate_method 10 | 11 | 12 | def extract_bounding_boxes_robust(predict_str): 13 | """ 14 | Extract coordinates and text content from the given prediction string, 15 | handling potential format issues. 16 | 17 | Args: 18 | predict_str (str): Model prediction output as a string. 19 | 20 | Returns: 21 | list: Extracted data in the format [[x1, y1, x2, y2, text_content], ...]. 22 | Returns None if no valid data is extracted. 
23 | """ 24 | results = [] 25 | seen = set() 26 | 27 | # try parsing with ast.literal_eval 28 | try: 29 | data = ast.literal_eval(predict_str) 30 | except Exception: 31 | data = None 32 | 33 | if data is not None: 34 | if isinstance(data, (list, tuple)): 35 | for item in data: 36 | if isinstance(item, (list, tuple)) and len(item) >= 5: 37 | x1_str, y1_str, x2_str, y2_str = item[:4] 38 | text_content = item[4] 39 | 40 | x1_str = str(x1_str).strip() 41 | y1_str = str(y1_str).strip() 42 | x2_str = str(x2_str).strip() 43 | y2_str = str(y2_str).strip() 44 | text_content = str(text_content).replace("\n", "").strip().strip('"').strip("'") 45 | 46 | try: 47 | x1 = int(x1_str) 48 | y1 = int(y1_str) 49 | x2 = int(x2_str) 50 | y2 = int(y2_str) 51 | 52 | if not (0 <= x1 <= 1000 and 0 <= y1 <= 1000 and 0 <= x2 <= 1000 and 0 <= y2 <= 1000): 53 | continue 54 | 55 | key = (x1, y1, x2, y2, text_content) 56 | if key in seen: 57 | continue 58 | 59 | seen.add(key) 60 | results.append([x1, y1, x2, y2, text_content]) 61 | except ValueError: 62 | continue 63 | else: 64 | # try parsing with regular expression 65 | 66 | list_content = predict_str 67 | items = re.findall(r'[\[\(]\s*([^\[\]\(\)]*?)\s*[\]\)]', list_content) 68 | 69 | if not items: 70 | return None 71 | 72 | for item in items: 73 | parts = item.split(',', 4) 74 | if len(parts) < 5: 75 | continue 76 | 77 | x1_str, y1_str, x2_str, y2_str, text_content = parts 78 | 79 | x1_str = x1_str.strip() 80 | y1_str = y1_str.strip() 81 | x2_str = x2_str.strip() 82 | y2_str = y2_str.strip() 83 | text_content = text_content.replace("\n", "").strip().strip('"').strip("'") 84 | 85 | try: 86 | x1 = int(x1_str) 87 | y1 = int(y1_str) 88 | x2 = int(x2_str) 89 | y2 = int(y2_str) 90 | 91 | if not (0 <= x1 <= 1000 and 0 <= y1 <= 1000 and 0 <= x2 <= 1000 and 0 <= y2 <= 1000): 92 | continue 93 | 94 | key = (x1, y1, x2, y2, text_content) 95 | if key in seen: 96 | continue 97 | 98 | seen.add(key) 99 | results.append([x1, y1, x2, y2, text_content]) 100 | except ValueError: 101 | continue 102 | 103 | if not results: 104 | return None 105 | 106 | return results 107 | 108 | 109 | def zip_folder(source_folder, destination_zip): 110 | abs_source = os.path.abspath(source_folder) 111 | abs_destination = os.path.abspath(destination_zip) 112 | 113 | with zipfile.ZipFile(abs_destination, 'w', zipfile.ZIP_DEFLATED) as zf: 114 | for root, _, files in os.walk(abs_source): 115 | for file in files: 116 | abs_file_path = os.path.join(root, file) 117 | 118 | relative_path = os.path.relpath(abs_file_path, abs_source) 119 | zf.write(abs_file_path, relative_path) 120 | 121 | 122 | def spotting_evaluation(prediction_list, img_metas): 123 | score = 0 124 | 125 | submit_path = "./eval_scripts/spotting_eval/submit" 126 | gt_path = "./eval_scripts/spotting_eval/gt" 127 | submit_zip_path = "./eval_scripts/spotting_eval/submit.zip" 128 | gt_zip_path = "./eval_scripts/spotting_eval/gt.zip" 129 | for file_path in [submit_path, gt_path, submit_zip_path, gt_zip_path]: 130 | if "zip" in file_path: 131 | if os.path.exists(file_path): 132 | os.remove(file_path) 133 | else: 134 | if os.path.exists(file_path): 135 | shutil.rmtree(file_path) 136 | os.makedirs(file_path) 137 | 138 | res_submit_list = [] 139 | for item in prediction_list: 140 | if len(item) != 5: 141 | ipdb.set_trace() 142 | x1, y1, x2, y2, rec = item 143 | if x1 >= x2 or y1 >= y2: 144 | continue 145 | 146 | res_submit_list.append(",".join([str(x1),str(y1),str(x2),str(y1),str(x2),str(y2),str(x1),str(y2),rec])) 147 | 148 | res_gt_list = [] 149 | for 
bbox, rec in zip(img_metas["bbox"], img_metas["content"]): 150 | x_coords = bbox[0::2] 151 | y_coords = bbox[1::2] 152 | 153 | x1, y1 = min(x_coords), min(y_coords) 154 | x2, y2 = max(x_coords), max(y_coords) 155 | 156 | res_gt_list.append(",".join([str(x1),str(y1),str(x2),str(y1),str(x2),str(y2),str(x1),str(y2),rec])) 157 | 158 | if len(res_submit_list) == 0 or len(res_gt_list) == 0: 159 | return 0 160 | 161 | with open(os.path.join(submit_path,"res_img_0.txt"), "w") as f: 162 | for item in res_submit_list[:-1]: 163 | f.write(item + "\n") 164 | f.write(res_submit_list[-1]) 165 | 166 | with open(os.path.join(gt_path,"gt_img_0.txt"), "w") as f: 167 | for item in res_gt_list[:-1]: 168 | f.write(item + "\n") 169 | f.write(res_gt_list[-1]) 170 | 171 | zip_folder(submit_path, submit_zip_path) 172 | zip_folder(gt_path, gt_zip_path) 173 | 174 | command = { 175 | 'g': gt_zip_path, 176 | 's': submit_zip_path, 177 | 'o': './', 178 | 'p': '{"IOU_CONSTRAINT":0.5}' 179 | } 180 | 181 | # run rrc_evaluation_funcs 182 | result = rrc_evaluation_funcs.main_evaluation(command,default_evaluation_params,validate_data,evaluate_method) 183 | score = result["method"]["hmean"] 184 | return score 185 | -------------------------------------------------------------------------------- /OCRBench_v2/eval_scripts/vqa_metric.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | import ipdb 5 | import math 6 | import numpy as np 7 | 8 | 9 | def levenshtein_distance(s1, s2): 10 | if len(s1) > len(s2): 11 | s1, s2 = s2, s1 12 | 13 | distances = range(len(s1) + 1) 14 | for i2, c2 in enumerate(s2): 15 | distances_ = [i2+1] 16 | for i1, c1 in enumerate(s1): 17 | if c1 == c2: 18 | distances_.append(distances[i1]) 19 | else: 20 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 21 | distances = distances_ 22 | return distances[-1] 23 | 24 | 25 | def vqa_evaluation(predict, answers): 26 | score = 0 27 | if type(answers)==list: 28 | for j in range(len(answers)): 29 | if isinstance(answers[j], (int, float)): 30 | answers[j] = str(answers[j]) 31 | try: 32 | answer = answers[j].lower().strip().replace("\n"," ") 33 | except: 34 | ipdb.set_trace() 35 | if isinstance(predict, (int, float)): 36 | predict = str(predict) 37 | predict = predict.lower().strip().replace("\n"," ") 38 | if len(answer.split()) < 5: 39 | if answer in predict: 40 | score = 1 41 | else: 42 | dist = levenshtein_distance(predict, answer) 43 | length = max(len(predict), len(answer)) 44 | ANLS_value = 0.0 if length == 0 else float(dist) / float(length) 45 | ANLS_value = 1 - ANLS_value 46 | 47 | if ANLS_value >= 0.5 and ANLS_value > score: 48 | score = ANLS_value 49 | 50 | else: 51 | answers = answers.lower().strip().replace("\n"," ") 52 | predict = predict.lower().strip().replace("\n"," ") 53 | if len(answers.split()) < 5: 54 | if answers in predict: 55 | score = 1 56 | else: 57 | dist = levenshtein_distance(predict, answers) 58 | length = max(len(predict), len(answers)) 59 | ANLS_value = 0.0 if length == 0 else float(dist) / float(length) 60 | ANLS_value = 1 - ANLS_value 61 | 62 | if ANLS_value >= 0.5 and ANLS_value > score: 63 | score = ANLS_value 64 | 65 | return score 66 | 67 | 68 | def cn_vqa_evaluation(predict, answers): 69 | score = 0 70 | if type(answers)==list: 71 | for j in range(len(answers)): 72 | if isinstance(answers[j], (int, float)): 73 | answers[j] = str(answers[j]) 74 | try: 75 | answer = answers[j].lower().strip().replace("\n"," ").replace(" ", 
"") 76 | except: 77 | ipdb.set_trace() 78 | if isinstance(predict, (int, float)): 79 | predict = str(predict) 80 | predict = predict.lower().strip().replace("\n"," ").replace(" ", "") 81 | if len(answer.split(",")) < 4: 82 | if answer in predict: 83 | score = 1 84 | else: 85 | dist = levenshtein_distance(predict, answer) 86 | length = max(len(predict), len(answer)) 87 | ANLS_value = 0.0 if length == 0 else float(dist) / float(length) 88 | ANLS_value = 1 - ANLS_value 89 | 90 | if ANLS_value >= 0.5 and ANLS_value > score: 91 | score = ANLS_value 92 | 93 | else: 94 | answers = answers.lower().strip().replace("\n"," ").replace(" ", "") 95 | predict = predict.lower().strip().replace("\n"," ").replace(" ", "") 96 | if len(answer.split(",")) < 4: 97 | if answers in predict: 98 | score = 1 99 | else: 100 | dist = levenshtein_distance(predict, answers) 101 | length = max(len(predict), len(answers)) 102 | ANLS_value = 0.0 if length == 0 else float(dist) / float(length) 103 | ANLS_value = 1 - ANLS_value 104 | 105 | if ANLS_value >= 0.5 and ANLS_value > score: 106 | score = ANLS_value 107 | 108 | return score 109 | 110 | 111 | def vqa_evaluation_case_sensitive(predict, answers): 112 | score = 0 113 | if type(answers)==list: 114 | for j in range(len(answers)): 115 | if isinstance(answers[j], (int, float)): 116 | answers[j] = str(answers[j]) 117 | try: 118 | answer = answers[j].strip().replace("\n"," ") 119 | except: 120 | ipdb.set_trace() 121 | predict = predict.strip().replace("\n"," ") 122 | if len(answer.split()) < 5: 123 | if answer in predict: 124 | score = 1 125 | else: 126 | dist = levenshtein_distance(predict, answer) 127 | length = max(len(predict), len(answer)) 128 | ANLS_value = 0.0 if length == 0 else float(dist) / float(length) 129 | ANLS_value = 1 - ANLS_value 130 | 131 | if ANLS_value >= 0.5 and ANLS_value > score: 132 | score = ANLS_value 133 | 134 | else: 135 | answers = answers.strip().replace("\n"," ") 136 | predict = predict.strip().replace("\n"," ") 137 | if len(answers.split()) < 5: 138 | if answers in predict: 139 | score = 1 140 | else: 141 | dist = levenshtein_distance(predict, answers) 142 | length = max(len(predict), len(answers)) 143 | ANLS_value = 0.0 if length == 0 else float(dist) / float(length) 144 | ANLS_value = 1 - ANLS_value 145 | 146 | if ANLS_value >= 0.5 and ANLS_value > score: 147 | score = ANLS_value 148 | 149 | return score 150 | 151 | 152 | def extract_first_number(string): 153 | match = re.search(r'\d+', string) 154 | if match: 155 | return int(match.group()) 156 | return None 157 | 158 | 159 | def counting_evaluation(predict, answers, eval_method): 160 | score = 0 161 | 162 | if isinstance(predict, str): 163 | predict_processed = predict.lower().strip().replace("\n", " ") 164 | elif math.isnan(predict): 165 | return 0 166 | else: 167 | predict_processed = int(predict) 168 | if type(answers)==list: 169 | temp_score = 0 170 | for j in range(len(answers)): 171 | if isinstance(answers[j], (int, float)): 172 | answers[j] = str(answers[j]) 173 | answer = answers[j].lower().strip().replace("\n"," ") 174 | if eval_method == "exact match": 175 | if answer in predict: 176 | score = 1 177 | else: 178 | score = 0 179 | elif eval_method == "regression": 180 | predict_number = extract_first_number(predict_processed) 181 | if predict_number: 182 | 183 | answer = int(answer) 184 | 185 | if predict_number <= 0 or predict_number >= 2 * answer: 186 | score = 0 187 | else: 188 | iou = 1 - abs(predict_number - answer) / answer 189 | if iou > 0.5: 190 | score = iou 191 | else: 
192 | score = 0 193 | else: 194 | score = 0 195 | if score > temp_score: 196 | temp_score = score 197 | score = temp_score 198 | 199 | else: 200 | answers = answers.lower().strip().replace("\n"," ") 201 | predict = predict.lower().strip().replace("\n"," ") 202 | if eval_method == "exact match": 203 | if answer in predict: 204 | score = 1 205 | else: 206 | score = 0 207 | elif eval_method == "regression": 208 | predict = extract_first_number(predict) 209 | if predict: 210 | answer = int(answer) 211 | if predict <= 0 or predict >= 2 * answer: 212 | score = 0 213 | else: 214 | iou = 1 - abs(predict - answer) / answer 215 | 216 | if iou > 0.5: 217 | score = iou 218 | else: 219 | score = 0 220 | else: 221 | score = 0 222 | return score 223 | 224 | 225 | def math_expression_evaluation(predict, answers): 226 | score = 0 227 | if type(answers)==list: 228 | for j in range(len(answers)): 229 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 230 | predict = predict.strip().replace("\n"," ").replace(" ","") 231 | if answer in predict: 232 | score = 1 233 | else: 234 | answers = answers.strip().replace("\n"," ").replace(" ","") 235 | predict = predict.strip().replace("\n"," ").replace(" ","") 236 | if answers in predict: 237 | score = 1 238 | return score 239 | 240 | 241 | def remove_text_tags(latex_str): 242 | """ 243 | Removes LaTeX \text{...} tags while keeping their content. 244 | 245 | :param latex_str: A string containing LaTeX expressions 246 | :return: The processed string with \text{...} tags removed 247 | """ 248 | 249 | pattern = r'\\text\{([^{}]*)\}' 250 | 251 | processed_str = re.sub(pattern, r'\1', latex_str) 252 | 253 | return processed_str 254 | 255 | 256 | def cn_math_expression_evaluation(predict, answers): 257 | score = 0 258 | 259 | assert len(answers) == 1 260 | answers = [remove_text_tags(answers[0])] 261 | predict = remove_text_tags(predict) 262 | 263 | if type(answers)==list: 264 | for j in range(len(answers)): 265 | answer = answers[j].strip().replace("\n"," ").replace(" ","") 266 | predict = predict.strip().replace("\n"," ").replace(" ","") 267 | if answer in predict: 268 | score = 1 269 | else: 270 | answers = answers.strip().replace("\n"," ").replace(" ","") 271 | predict = predict.strip().replace("\n"," ").replace(" ","") 272 | if answers in predict: 273 | score = 1 274 | return score 275 | 276 | 277 | if __name__ == "__main__": 278 | test_predict = "apple pie and banana" 279 | test_answers = ["apple", "banana pie", "apple pie and orange"] 280 | 281 | vqa_score = vqa_evaluation(test_predict, test_answers) 282 | print(f"VQA evaluation score for predict '{test_predict}' and answers {test_answers}: {vqa_score}") 283 | -------------------------------------------------------------------------------- /OCRBench_v2/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | distance 3 | apted 4 | lxml 5 | zss 6 | Levenshtein 7 | editdistance 8 | nltk 9 | jieba 10 | Polygon3 11 | tqdm 12 | ipdb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OCRBench & OCRBench v2 2 | 3 | **This is the repository of the [OCRBench](./OCRBench/README.md) & [OCRBench v2](./OCRBench_v2/README.md).** 4 | 5 | **OCRBench** is a comprehensive evaluation benchmark designed to assess the OCR capabilities of Large Multimodal Models. 
It comprises five components: Text Recognition, SceneText-Centric VQA, Document-Oriented VQA, Key Information Extraction, and Handwritten Mathematical Expression Recognition. The benchmark includes 1000 question-answer pairs, and all the answers undergo manual verification and correction to ensure a more precise evaluation. More details can be found in [OCRBench README](./OCRBench/README.md). 6 | 7 |
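To make the scoring protocol concrete, here is a minimal sketch of the kind of inclusion-based matching used in the evaluation scripts in this repository (see, for example, `vqa_evaluation` in `OCRBench_v2/eval_scripts/vqa_metric.py`): a prediction counts as correct when a ground-truth answer string appears inside it. The field names `predict` and `answers`, and the accuracy loop, are illustrative assumptions rather than the exact schema of `OCRBench.json`; refer to `example.py` and the scripts folder for the actual pipeline.

```python
# Illustrative sketch only: substring-style matching of ground-truth answers
# against a model prediction. The "predict" and "answers" field names are
# assumptions for this example, not the guaranteed schema of OCRBench.json.
import json

def is_correct(prediction, answers):
    """Return True if any ground-truth answer string appears in the prediction."""
    if isinstance(answers, str):
        answers = [answers]
    pred = str(prediction).lower().strip().replace("\n", " ")
    return any(str(ans).lower().strip() in pred for ans in answers)

with open("./OCRBench/OCRBench/OCRBench.json", "r") as f:
    samples = json.load(f)

# "predict" is assumed to be filled in with your model's output for each sample;
# the released JSON ships the questions and answers, not model predictions.
correct = sum(is_correct(s.get("predict", ""), s.get("answers", [])) for s in samples)
print(f"Accuracy: {correct / len(samples):.3f} over {len(samples)} samples")
```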

8 | 9 |

10 | 11 | **OCRBench v2** is a large-scale bilingual text-centric benchmark with currently the most comprehensive set of tasks (4× more tasks than the previous multi-scene benchmark OCRBench), the widest coverage of scenarios (31 diverse scenarios including street scene, receipt, formula, diagram, and so on), and thorough evaluation metrics, with a total of 10,000 human-verified question-answering pairs and a high proportion of difficult samples. More details can be found in [OCRBench v2 README](./OCRBench_v2/README.md). 12 | 13 |
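To see how the released evaluation scripts turn per-sample results into reported numbers, the sketch below follows the conventions of `OCRBench_v2/eval_scripts/get_score.py`: each record carries a `type` and a `score` field, and records marked `"ignore": "True"` are skipped. Unlike `get_score.py`, which further groups the task types into eight English and five Chinese categories, this sketch simply reports the mean score per task type; the input path points at the sample prediction file shipped in `pred_folder` and is only an example.

```python
# Minimal sketch: group an OCRBench v2 prediction file by task type and report
# mean scores, using the "type"/"score"/"ignore" fields read by get_score.py.
# The input path below is illustrative.
import json
from collections import defaultdict

with open("./OCRBench_v2/pred_folder/internvl2_5_26b.json", "r") as f:
    data = json.load(f)

scores_by_type = defaultdict(list)
for item in data:
    if item.get("ignore") == "True":  # manually excluded samples
        continue
    scores_by_type[item["type"]].append(item["score"])

for task_type, scores in sorted(scores_by_type.items()):
    print(f"{task_type}: {sum(scores) / len(scores):.3f} ({len(scores)} samples)")
```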

14 | 15 |

16 | 17 | # News 18 | * ```2024.12.31``` 🚀 [OCRBench v2](./OCRBench_v2/README.md) is released. 19 | * ```2024.12.11``` 🚀 OCRBench has been accepted by [Science China Information Sciences](https://link.springer.com/article/10.1007/s11432-024-4235-6). 20 | * ```2024.5.19 ``` 🚀 We release [DTVQA](https://github.com/ShuoZhang2003/DT-VQA) to explore the capabilities of Large Multimodal Models on dense text. 21 | * ```2024.5.01 ``` 🚀 Thanks to [SWHL](https://github.com/Yuliang-Liu/MultimodalOCR/issues/29) for releasing [ChineseOCRBench](https://huggingface.co/datasets/SWHL/ChineseOCRBench). 22 | * ```2024.3.26 ``` 🚀 OCRBench is now supported in [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). 23 | * ```2024.3.12 ``` 🚀 We plan to construct OCRBench v2 to include more OCR tasks and data. Any contribution will be appreciated. 24 | * ```2024.2.25 ``` 🚀 OCRBench is now supported in [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). 25 | 26 | 27 | # Other Related Multilingual Datasets 28 | | Data | Link | Description | 29 | | --- | --- | --- | 30 | | EST-VQA Dataset (CVPR 2020, English and Chinese) | [Link](https://github.com/xinke-wang/EST-VQA) | On the General Value of Evidence, and Bilingual Scene-Text Visual Question Answering. | 31 | | Swahili Dataset (ICDAR 2024) | [Link](https://arxiv.org/abs/2405.11437) | The First Swahili Language Scene Text Detection and Recognition Dataset. | 32 | | Urdu Dataset (ICDAR 2024) | [Link](https://arxiv.org/abs/2405.12533) | Dataset and Benchmark for Urdu Natural Scenes Text Detection, Recognition and Visual Question Answering. | 33 | | MTVQA (9 languages) | [Link](https://arxiv.org/abs/2405.11985) | MTVQA: Benchmarking Multilingual Text-Centric Visual Question Answering. | 34 | | EVOBC (Oracle Bone Script Evolution Dataset) | [Link](https://arxiv.org/abs/2401.12467) | We systematically collected ancient characters from authoritative texts and websites spanning six historical stages. | 35 | | HUST-OBC (Oracle Bone Script Character Dataset) | [Link](https://arxiv.org/abs/2401.15365) | For deciphering oracle bone script characters. 
| 36 | 37 | # Citation 38 | If you wish to refer to the baseline results published here, please use the following BibTeX entries: 39 | ```BibTeX 40 | @article{Liu_2024, 41 | title={OCRBench: on the hidden mystery of OCR in large multimodal models}, 42 | volume={67}, 43 | ISSN={1869-1919}, 44 | url={http://dx.doi.org/10.1007/s11432-024-4235-6}, 45 | DOI={10.1007/s11432-024-4235-6}, 46 | number={12}, 47 | journal={Science China Information Sciences}, 48 | publisher={Springer Science and Business Media LLC}, 49 | author={Liu, Yuliang and Li, Zhang and Huang, Mingxin and Yang, Biao and Yu, Wenwen and Li, Chunyuan and Yin, Xu-Cheng and Liu, Cheng-Lin and Jin, Lianwen and Bai, Xiang}, 50 | year={2024}, 51 | month=dec } 52 | 53 | @misc{fu2024ocrbenchv2improvedbenchmark, 54 | title={OCRBench v2: An Improved Benchmark for Evaluating Large Multimodal Models on Visual Text Localization and Reasoning}, 55 | author={Ling Fu and Biao Yang and Zhebin Kuang and Jiajun Song and Yuzhe Li and Linghao Zhu and Qidi Luo and Xinyu Wang and Hao Lu and Mingxin Huang and Zhang Li and Guozhi Tang and Bin Shan and Chunhui Lin and Qi Liu and Binghong Wu and Hao Feng and Hao Liu and Can Huang and Jingqun Tang and Wei Chen and Lianwen Jin and Yuliang Liu and Xiang Bai}, 56 | year={2024}, 57 | eprint={2501.00321}, 58 | archivePrefix={arXiv}, 59 | primaryClass={cs.CV}, 60 | url={https://arxiv.org/abs/2501.00321}, 61 | } 62 | ``` 63 | 64 | 65 | 66 | --------------------------------------------------------------------------------