├── .env.example ├── .eslintrc.json ├── .gitignore ├── .prettierrc ├── .streamlit └── config.toml ├── LICENSE ├── README.md ├── assets └── dashboard-gif.gif ├── dashboard ├── .gitignore ├── Home.py ├── README.md ├── pages │ ├── 1_Performance_Metrics.py │ └── 2_Test_Result.py └── utils │ ├── data_loader.py │ └── style.py ├── data └── receipt.json ├── jest.config.ts ├── package-lock.json ├── package.json ├── prisma └── schema.prisma ├── requirements.txt ├── src ├── evaluation │ ├── index.ts │ ├── json.ts │ └── text.ts ├── index.ts ├── models.example.yaml ├── models │ ├── awsTextract.ts │ ├── azure.ts │ ├── base.ts │ ├── dashscope.ts │ ├── gemini.ts │ ├── googleDocumentAI.ts │ ├── index.ts │ ├── llm.ts │ ├── mistral.ts │ ├── omniAI.ts │ ├── openai.ts │ ├── openrouter.ts │ ├── registry.ts │ ├── shared │ │ ├── index.ts │ │ ├── prompt.ts │ │ └── tokenCost.ts │ ├── togetherai.ts │ ├── unstructured.ts │ └── zerox.ts ├── types │ ├── data.ts │ ├── index.ts │ └── model.ts └── utils │ ├── dataLoader.ts │ ├── db.ts │ ├── file.ts │ ├── htmlToMarkdown.ts │ ├── index.ts │ ├── logs.ts │ └── zod.ts ├── tests └── evaluation │ └── json.test.ts └── tsconfig.json /.env.example: -------------------------------------------------------------------------------- 1 | # OmniAI 2 | OMNIAI_API_URL= 3 | OMNIAI_API_KEY= 4 | 5 | # OpenAI 6 | OPENAI_API_KEY= 7 | 8 | # Anthropic 9 | ANTHROPIC_API_KEY= 10 | 11 | # Gemini 12 | GOOGLE_GENERATIVE_AI_API_KEY= 13 | 14 | # DeepSeek 15 | DEEPSEEK_API_KEY= 16 | 17 | # Mistral 18 | MISTRAL_API_KEY= 19 | 20 | # Unstructured 21 | UNSTRUCTURED_API_KEY= 22 | 23 | # AWS Textract 24 | AWS_REGION= 25 | AWS_ACCESS_KEY_ID= 26 | AWS_SECRET_ACCESS_KEY= 27 | 28 | # Azure 29 | AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT= 30 | AZURE_DOCUMENT_INTELLIGENCE_KEY= 31 | 32 | # Google Document AI 33 | GOOGLE_LOCATION= 34 | GOOGLE_PROJECT_ID= 35 | GOOGLE_PROCESSOR_ID= 36 | GOOGLE_APPLICATION_CREDENTIALS_PATH= 37 | 38 | # Database (load data from database & save results to database) 39 | DATABASE_URL= 40 | -------------------------------------------------------------------------------- /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "eslint:recommended", 4 | "plugin:@typescript-eslint/recommended", 5 | "prettier" 6 | ], 7 | "parser": "@typescript-eslint/parser", 8 | "plugins": ["@typescript-eslint"], 9 | "root": true 10 | } 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Dependencies and build artifacts 2 | /node_modules 3 | /dist 4 | venv 5 | 6 | # Configuration and credentials 7 | .env 8 | google_credentials.json 9 | models.yaml 10 | 11 | # Data and model files 12 | data 13 | /results 14 | eng.traineddata 15 | scripts 16 | 17 | # System files 18 | .DS_Store -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "printWidth": 90, 3 | "semi": true, 4 | "singleQuote": true, 5 | "tabWidth": 2, 6 | "trailingComma": "all" 7 | } 8 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [server] 2 | runOnSave = true 3 | 4 | [theme] 5 | base="light" 6 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Omni OCR Benchmark](https://omniai-images.s3.us-east-1.amazonaws.com/omni-ocr-benchmark.png)](https://getomni.ai/ocr-benchmark) 2 | 3 | # Omni OCR Benchmark 4 | 5 | A benchmarking tool that compares the OCR and data extraction capabilities of different large multimodal models such as gpt-4o, evaluating both text and JSON extraction accuracy. The goal of this benchmark is to publish a comprehensive benchmark of OCR accuracy across traditional OCR providers and multimodal language models. The evaluation dataset and methodologies are all open source, and we encourage expanding this benchmark to encompass any additional providers. 6 | 7 | [**Open Source LLM Benchmark Results (Mar 2025)**](https://getomni.ai/blog/benchmarking-open-source-models-for-ocr) | [**Dataset**](https://huggingface.co/datasets/getomni-ai/ocr-benchmark) 8 | 9 | [**Benchmark Results (Feb 2025)**](https://getomni.ai/ocr-benchmark) | [**Dataset**](https://huggingface.co/datasets/getomni-ai/ocr-benchmark) 10 | 11 | ![image](https://github.com/user-attachments/assets/2be179ad-0abd-4f0e-b73a-7d5a70390367) 12 | 13 | 14 | ## Methodology 15 | 16 | The primary goal is to evaluate JSON extraction from documents. To evaluate this, the Omni benchmark runs Document ⇒ OCR ⇒ Extraction, measuring how well a model can OCR a page and return that content in a format that an LLM can parse. 17 | 18 | ![methodology](https://omniai-images.s3.us-east-1.amazonaws.com/methodology-diagram.png) 19 | 20 | ## Evaluation Metrics 21 | 22 | ### JSON accuracy 23 | 24 | We use a modified [json-diff](https://github.com/zgrossbart/jdd) to identify differences between predicted and ground truth JSON objects. You can review the [evaluation/json.ts](./src/evaluation/json.ts) file to see the exact implementation.
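For illustration, here is a minimal sketch of calling the exported `calculateJsonAccuracy` helper directly (the import path and sample values are illustrative, and the expected score assumes the field counter sees the three leaf values below):

```typescript
import { calculateJsonAccuracy } from './src/evaluation/json';

const groundTruth = { totals: { total: 48.43, tax: 6.18 }, merchant: { name: 'Nick the Greek' } };
const predicted = { totals: { total: 48.43, tax: 6.19 }, merchant: { name: 'Nick the Greek' } };

// One modified field (totals.tax); with three leaf fields this gives 1 - 1/3 ≈ 0.6667
const { score, jsonDiffStats } = calculateJsonAccuracy(groundTruth, predicted);
console.log(score, jsonDiffStats);
```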
Accuracy is calculated as: 25 | 26 | ```math 27 | \text{Accuracy} = 1 - \frac{\text{number of differing fields}}{\text{total fields}} 28 | ``` 29 | 30 | ![json-diff](https://omniai-images.s3.us-east-1.amazonaws.com/json_accuracy.png) 31 | 32 | ### Text similarity 33 | 34 | While the primary benchmark metric is JSON accuracy, we have included [Levenshtein distance](https://en.wikipedia.org/wiki/Levenshtein_distance) as a measurement of text similarity between extracted and ground truth text. 35 | Lower distance indicates higher similarity. Note that this scoring method heavily penalizes accurate text that does not conform to the exact layout of the ground truth data. 36 | 37 | In the example below, an LLM could decode both blocks of text without any issue. All the information is 100% accurate, but slight rearrangements of the header text (address, phone number, etc.) result in a large edit-distance penalty. A minimal scoring sketch appears below, after the model configuration notes. 38 | 39 | ![text-similarity](https://omniai-images.s3.us-east-1.amazonaws.com/edit_distance.png) 40 | 41 | ## Running the benchmark 42 | 43 | 1. Clone the repo and install dependencies: `npm install` 44 | 2. Prepare your test data: 45 | 1. For local data, add individual files to the `data` folder. 46 | 2. To pull from a DB, add `DATABASE_URL` to your `.env`. 47 | 3. Copy the `models.example.yaml` file to `models.yaml`. Set up API keys in `.env` for the models you want to test. See the list of [supported models](#supported-models) below. 48 | 4. Run the benchmark: `npm run benchmark` 49 | 5. Results will be saved in the `results/<timestamp>/results.json` file. 50 | 51 | ## Supported models 52 | 53 | To enable specific models, create a `models.yaml` file in the `src` directory. Check out the [models.example.yaml](./src/models.example.yaml) file for the required variables. 54 | 55 | ```yaml 56 | models: 57 | - ocr: gemini-2.0-flash-001 # The model to use for OCR 58 | extraction: gpt-4o # The model to use for JSON extraction 59 | 60 | - ocr: gpt-4o 61 | extraction: gpt-4o 62 | directImageExtraction: true # Whether to use the model's native image extraction capabilities 63 | ``` 64 | 65 | You can view the configuration for each model in the [src/models/](./src/models/) folder.
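As referenced in the text-similarity section above, the normalization can be sketched with the `fastest-levenshtein` package already listed in `package.json`; the exact scoring lives in `src/evaluation/text.ts` and may normalize differently:

```typescript
import { distance } from 'fastest-levenshtein';

// Normalized similarity in [0, 1]; 1.0 means the strings are identical.
function textSimilarity(expected: string, actual: string): number {
  const maxLen = Math.max(expected.length, actual.length);
  if (maxLen === 0) return 1;
  return 1 - distance(expected, actual) / maxLen;
}

console.log(textSimilarity('Total $48.43', 'Total $48.43')); // 1
console.log(textSimilarity('Total $48.43', 'Tota1 $48.43')); // ≈ 0.92 (one substitution over 12 characters)
```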
66 | 67 | ### Closed-source LLMs 68 | 69 | | Model Provider | Models | OCR | JSON Extraction | Required ENV Variables | 70 | | -------------- | ------------------------------------------------------------ | --- | --------------- | ---------------------------------- | 71 | | Anthropic | `claude-3-5-sonnet-20241022` | ✅ | ✅ | `ANTHROPIC_API_KEY` | 72 | | OpenAI | `gpt-4o` | ✅ | ✅ | `OPENAI_API_KEY` | 73 | | Gemini | `gemini-2.0-flash-001`, `gemini-1.5-pro`, `gemini-1.5-flash` | ✅ | ✅ | `GOOGLE_GENERATIVE_AI_API_KEY` | 74 | | Mistral | `mistral-ocr` | ✅ | ❌ | `MISTRAL_API_KEY` | 75 | | OmniAI | `omniai` | ✅ | ✅ | `OMNIAI_API_KEY`, `OMNIAI_API_URL` | 76 | 77 | ### Open-source LLMs 78 | 79 | | Model Provider | Models | OCR | JSON Extraction | Required ENV Variables | 80 | | -------------- | --------------------------------------------------------------------------------------------------- | --- | --------------- | ---------------------- | 81 | | Gemma 3 | `google/gemma-3-27b-it` | ✅ | ❌ | | 82 | | Qwen 2.5 | `qwen2.5-vl-32b-instruct`, `qwen2.5-vl-72b-instruct` | ✅ | ❌ | | 83 | | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo`, `meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo` | ✅ | ❌ | | 84 | | ZeroX | `zerox` | ✅ | ✅ | `OPENAI_API_KEY` | 85 | 86 | ### Cloud OCR Providers 87 | 88 | | Model Provider | Models | OCR | JSON Extraction | Required ENV Variables | 89 | | -------------- | ----------------------------- | --- | --------------- | ---------------------------------------------------------------------------------------------------- | 90 | | AWS | `aws-text-extract` | ✅ | ❌ | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION` | 91 | | Azure | `azure-document-intelligence` | ✅ | ❌ | `AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT`, `AZURE_DOCUMENT_INTELLIGENCE_KEY` | 92 | | Google | `google-document-ai` | ✅ | ❌ | `GOOGLE_LOCATION`, `GOOGLE_PROJECT_ID`, `GOOGLE_PROCESSOR_ID`, `GOOGLE_APPLICATION_CREDENTIALS_PATH` | 93 | | Unstructured | `unstructured` | ✅ | ❌ | `UNSTRUCTURED_API_KEY` | 94 | 95 | - LLMs are instructed to use the following [system prompts](./src/models/shared/prompt.ts) for OCR and JSON extraction. 96 | - For Google Document AI, you need to include `google_credentials.json` in the `data` folder. 97 | 98 | ## Benchmark Dashboard 99 | 100 | ![dashboard](./assets/dashboard-gif.gif) 101 | 102 | You can use the benchmark dashboard to view the results of each test run. Check out the [dashboard documentation](dashboard/README.md) for more details. 103 | 104 | ## License 105 | 106 | This project is licensed under the MIT License - see the LICENSE file for details.
107 | -------------------------------------------------------------------------------- /assets/dashboard-gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/getomni-ai/benchmark/f54e6b595131ca7ef320a9058a4f0f6481e6c17d/assets/dashboard-gif.gif -------------------------------------------------------------------------------- /dashboard/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ -------------------------------------------------------------------------------- /dashboard/Home.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from utils.style import SIDEBAR_STYLE 4 | 5 | st.set_page_config( 6 | page_title="OCR Benchmark Dashboard", 7 | ) 8 | st.markdown(SIDEBAR_STYLE, unsafe_allow_html=True) 9 | 10 | st.title("OCR Benchmark Dashboard") 11 | st.markdown( 12 | """ 13 | Welcome to the OCR Benchmark Dashboard! This tool helps analyze and visualize OCR and extraction model performance. 14 | 15 | ### Available Pages: 16 | - **Performance Metrics**: View detailed performance metrics, costs, and latency analysis 17 | - **Test Results**: View detailed test results 18 | 19 | Choose a page from the sidebar to get started. 20 | """ 21 | ) 22 | -------------------------------------------------------------------------------- /dashboard/README.md: -------------------------------------------------------------------------------- 1 | # OCR Benchmark Dashboard 2 | 3 | ![dashboard](../assets/dashboard-gif.gif) 4 | 5 | This dashboard is used to view the results of the OCR Benchmark. 6 | 7 | ## Getting started 8 | 9 | 1. Create a virtual environment: 10 | 11 | ```bash 12 | python3 -m venv venv 13 | source venv/bin/activate 14 | ``` 15 | 16 | 2. Install Python dependencies: 17 | 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | 3. Run the dashboard: 23 | 24 | ```bash 25 | streamlit run dashboard/Home.py 26 | ``` 27 | 28 | 4. The dashboard will open in your browser and show: 29 | - Model comparison charts for JSON accuracy and text similarity 30 | - Cost and latency charts for each model 31 | - Detailed performance statistics for each model combination 32 | - Test results table with individual test cases 33 | 34 | The dashboard automatically loads results from your `results` folder (or from your database when `DATABASE_URL` is set) and lets you switch between different test runs.
35 | -------------------------------------------------------------------------------- /dashboard/pages/1_Performance_Metrics.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from datetime import datetime 3 | import plotly.express as px 4 | import pandas as pd 5 | 6 | from utils.data_loader import load_run_list, load_results_for_run 7 | from utils.style import SIDEBAR_STYLE 8 | 9 | st.set_page_config(page_title="Performance Metrics") 10 | st.markdown(SIDEBAR_STYLE, unsafe_allow_html=True) 11 | 12 | 13 | def create_results_table(results): 14 | """Create a DataFrame from test results""" 15 | rows = [] 16 | 17 | for test in results: # Results is a list of test cases 18 | row = { 19 | "Image": test["fileUrl"], 20 | "OCR Model": test["ocrModel"], 21 | "Extraction Model": test["extractionModel"], 22 | "Levenshtein Score": test.get("levenshteinDistance", 0), 23 | "JSON Accuracy": test.get("jsonAccuracy", 0), 24 | "Total Cost": test.get("usage", {}).get("totalCost", 0), 25 | "Duration (ms)": test.get("usage", {}).get("duration", 0), 26 | "Metadata": test.get("metadata", {}), 27 | } 28 | rows.append(row) 29 | 30 | return pd.DataFrame(rows) 31 | 32 | 33 | def create_model_comparison_table(results): 34 | """Create a DataFrame comparing different model combinations""" 35 | model_stats = {} 36 | 37 | for test in results: 38 | if "error" in test and test["error"]: 39 | continue 40 | 41 | model_key = ( 42 | f"{test['extractionModel']} (IMG2JSON)" 43 | if test.get("directImageExtraction", False) 44 | else f"{test['ocrModel']} → {test['extractionModel']}" 45 | ) 46 | if model_key not in model_stats: 47 | model_stats[model_key] = { 48 | "count": 0, 49 | "json_accuracy": 0, 50 | "text_accuracy": 0, 51 | "total_cost": 0.0, 52 | "ocr_cost": 0.0, 53 | "extraction_cost": 0.0, 54 | "ocr_latency": 0, 55 | "extraction_latency": 0, 56 | "extraction_count": 0, 57 | "ocr_input_tokens": 0, 58 | "ocr_output_tokens": 0, 59 | "extraction_input_tokens": 0, 60 | "extraction_output_tokens": 0, 61 | } 62 | 63 | stats = model_stats[model_key] 64 | stats["count"] += 1 65 | stats["text_accuracy"] += test.get("levenshteinDistance", 0) or 0 66 | # Ensure None values are converted to 0.0 67 | totalCost = test.get("usage", {}).get("totalCost") 68 | stats["total_cost"] += 0.0 if totalCost is None else totalCost 69 | usage = test.get("usage", {}) 70 | 71 | # Handle possible None values in ocr cost 72 | ocrCost = usage.get("ocr", {}).get("totalCost") 73 | stats["ocr_cost"] += 0.0 if ocrCost is None else ocrCost 74 | 75 | stats["ocr_latency"] += usage.get("ocr", {}).get("duration", 0) / 1000 76 | stats["ocr_input_tokens"] += usage.get("ocr", {}).get("inputTokens", 0) 77 | stats["ocr_output_tokens"] += usage.get("ocr", {}).get("outputTokens", 0) 78 | 79 | # Add token counting 80 | if usage.get("extraction"): 81 | stats["extraction_input_tokens"] += usage.get("extraction", {}).get( 82 | "inputTokens", 0 83 | ) 84 | stats["extraction_output_tokens"] += usage.get("extraction", {}).get( 85 | "outputTokens", 0 86 | ) 87 | 88 | # Only add JSON accuracy and extraction stats if extraction was performed 89 | if "jsonAccuracy" in test and usage.get("extraction"): 90 | stats["extraction_count"] += 1 91 | stats["json_accuracy"] += test["jsonAccuracy"] 92 | # Handle possible None values in extraction cost 93 | extractionCost = usage.get("extraction", {}).get("totalCost") 94 | stats["extraction_cost"] += ( 95 | 0.0 if extractionCost is None else extractionCost 96 | ) 97 | 
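            # Durations are reported in milliseconds; divide by 1000 so latency stats are in seconds for the charts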
stats["extraction_latency"] += ( 98 | usage.get("extraction", {}).get("duration", 0) / 1000 99 | ) 100 | 101 | # Calculate averages 102 | for stats in model_stats.values(): 103 | stats["text_accuracy"] /= stats["count"] 104 | stats["ocr_latency"] /= stats["count"] 105 | stats["ocr_cost"] /= stats["count"] 106 | stats["total_cost"] /= stats["count"] 107 | stats["ocr_input_tokens"] /= stats["count"] 108 | stats["ocr_output_tokens"] /= stats["count"] 109 | 110 | # Calculate extraction-related averages only if there were extractions 111 | if stats["extraction_count"] > 0: 112 | stats["json_accuracy"] /= stats["extraction_count"] 113 | stats["extraction_latency"] /= stats["extraction_count"] 114 | stats["extraction_cost"] /= stats["extraction_count"] 115 | stats["extraction_input_tokens"] /= stats["extraction_count"] 116 | stats["extraction_output_tokens"] /= stats["extraction_count"] 117 | 118 | # Convert to DataFrame 119 | df = pd.DataFrame.from_dict(model_stats, orient="index") 120 | df.index.name = "Model Combination" 121 | return df 122 | 123 | 124 | def create_accuracy_comparison_charts(results): 125 | """Create separate DataFrames for JSON and Text accuracy comparisons""" 126 | model_accuracies = {} 127 | 128 | for test in results: 129 | if "error" in test and test["error"]: 130 | continue 131 | 132 | model_key = ( 133 | f"{test['extractionModel']} (IMG2JSON)" 134 | if test.get("directImageExtraction", False) 135 | else f"{test['ocrModel']} → {test['extractionModel']}" 136 | ) 137 | if model_key not in model_accuracies: 138 | model_accuracies[model_key] = { 139 | "count": 0, 140 | "json_accuracy": 0, 141 | "text_similarity": 0, 142 | "total_matched_items": 0, 143 | "total_items": 0, 144 | "extraction_count": 0, 145 | } 146 | 147 | stats = model_accuracies[model_key] 148 | stats["count"] += 1 149 | stats["text_similarity"] += test.get("levenshteinDistance", 0) or 0 150 | 151 | # Handle JSON accuracy if present 152 | if "jsonAccuracy" in test: 153 | stats["extraction_count"] += 1 154 | stats["json_accuracy"] += test["jsonAccuracy"] or 0 155 | 156 | # Calculate final averages 157 | for stats in model_accuracies.values(): 158 | stats["text_similarity"] /= stats["count"] 159 | 160 | # Calculate JSON accuracy only if there were extractions 161 | if stats["extraction_count"] > 0: 162 | stats["json_accuracy"] /= stats["extraction_count"] 163 | else: 164 | stats["json_accuracy"] = 0 165 | 166 | # Create DataFrames 167 | json_df = pd.DataFrame( 168 | { 169 | "Model": model_accuracies.keys(), 170 | "JSON Accuracy": [ 171 | stats["json_accuracy"] for stats in model_accuracies.values() 172 | ], 173 | } 174 | ).set_index("Model") 175 | 176 | text_df = pd.DataFrame( 177 | { 178 | "Model": model_accuracies.keys(), 179 | "Text Similarity": [ 180 | stats["text_similarity"] for stats in model_accuracies.values() 181 | ], 182 | } 183 | ).set_index("Model") 184 | 185 | return json_df, text_df 186 | 187 | 188 | def main(): 189 | st.title("Performance Metrics") 190 | 191 | # Load only the list of runs initially 192 | runs = load_run_list() 193 | 194 | if not runs: 195 | st.warning("No benchmark runs found.") 196 | return 197 | 198 | # Create columns for the header section 199 | col1, col2 = st.columns([2, 3]) 200 | 201 | with col1: 202 | # Create a dropdown to select the test run 203 | selected_timestamp = st.selectbox( 204 | "Select Test Run", 205 | [run["timestamp"] for run in runs], 206 | format_func=lambda x: datetime.strptime(x, "%Y-%m-%d-%H-%M-%S").strftime( 207 | "%Y-%m-%d %H:%M:%S" 208 | ), 209 | ) 210 
| 211 | # Load the detailed results only when a run is selected 212 | run_data = load_results_for_run(selected_timestamp) 213 | 214 | with col2: 215 | st.markdown('<div style="margin-top: 28px;"></div>
', unsafe_allow_html=True) 216 | with st.expander("Run Details", expanded=True): 217 | if run_data.get("run_by"): 218 | st.markdown(f"**Run By:** {run_data['run_by']}") 219 | if run_data.get("description"): 220 | st.markdown(f"**Description:** {run_data['description']}") 221 | st.markdown(f"**Total # of documents:** {run_data['total_documents']}") 222 | st.markdown(f"**Status:** {run_data['status'].title()}") 223 | st.markdown(f"**Created:** {run_data['created_at']}") 224 | if run_data.get("completed_at"): 225 | st.markdown(f"**Completed:** {run_data['completed_at']}") 226 | 227 | results = run_data["results"] 228 | 229 | st.header("Evaluation Metrics by Model") 230 | json_df, text_df = create_accuracy_comparison_charts(results) 231 | fig1 = px.bar( 232 | json_df.reset_index().sort_values("JSON Accuracy", ascending=False), 233 | x="Model", 234 | y="JSON Accuracy", 235 | title="JSON Accuracy by Model", 236 | height=600, 237 | color_discrete_sequence=["#636EFA"], 238 | ) 239 | fig1.update_layout(showlegend=False) 240 | fig1.update_traces(texttemplate="%{y:.1%}", textposition="outside") 241 | st.plotly_chart(fig1) 242 | 243 | fig2 = px.bar( 244 | text_df.reset_index().sort_values("Text Similarity", ascending=False), 245 | x="Model", 246 | y="Text Similarity", 247 | title="Text Similarity by Model", 248 | height=600, 249 | color_discrete_sequence=["#636EFA"], 250 | ) 251 | fig2.update_layout(showlegend=False) 252 | fig2.update_traces(texttemplate="%{y:.1%}", textposition="outside") 253 | st.plotly_chart(fig2) 254 | 255 | # Model Statistics Table 256 | st.header("Model Performance Statistics") 257 | model_stats = create_model_comparison_table(results) 258 | st.dataframe( 259 | model_stats.style.format( 260 | { 261 | "json_accuracy": "{:.2%}", 262 | "text_accuracy": "{:.2%}", 263 | "ocr_latency": "{:.2f} s", "extraction_latency": "{:.2f} s", 264 | "total_cost": "${:.4f}", 265 | "count": "{:.0f}", 266 | } 267 | ) 268 | ) 269 | 270 | # Cost and Latency Charts 271 | st.header("Cost and Latency Analysis") 272 | 273 | # Cost per 1,000 pages chart 274 | cost_df = pd.DataFrame(model_stats["total_cost"] * 1000).reset_index() 275 | cost_df.columns = ["Model", "Cost per 1,000 Pages"] 276 | fig4 = px.bar( 277 | cost_df.sort_values("Cost per 1,000 Pages", ascending=True), 278 | x="Model", 279 | y="Cost per 1,000 Pages", 280 | title="Cost per 1,000 Pages by Model Combination", 281 | height=600, 282 | color_discrete_sequence=["#EF553B"], 283 | ) 284 | fig4.update_layout(showlegend=False) 285 | fig4.update_traces(texttemplate="$%{y:.2f}", textposition="outside") 286 | st.plotly_chart(fig4) 287 | 288 | # Create stacked bar chart for cost breakdown per 1,000 pages 289 | cost_breakdown_df = pd.DataFrame( 290 | { 291 | "Model": model_stats.index, 292 | "OCR": model_stats["ocr_cost"] * 1000, 293 | "Extraction": model_stats["extraction_cost"] * 1000, 294 | } 295 | ) 296 | 297 | # Calculate cost per 1,000 pages for sorting 298 | cost_breakdown_df["Total"] = ( 299 | cost_breakdown_df["OCR"] + cost_breakdown_df["Extraction"] 300 | ) 301 | fig_cost = px.bar( 302 | cost_breakdown_df.sort_values("Total", ascending=True), 303 | x="Model", 304 | y=["OCR", "Extraction"], 305 | title="Cost per 1,000 Pages Breakdown by Model Combination (OCR + Extraction)", 306 | height=600, 307 | color_discrete_sequence=["#636EFA", "#EF553B"], 308 | ) 309 | fig_cost.update_layout( 310 | barmode="stack", 311 | showlegend=True, 312 | legend_title="Phase", 313 | yaxis=dict( 314 | title="Cost per 1,000 Pages (USD)", 315 | range=[ 316 | 0, 317 | cost_breakdown_df["Total"].max() * 1.2,
318 | ], 319 | ), 320 | ) 321 | fig_cost.update_traces(texttemplate="$%{y:.2f}", textposition="inside") 322 | st.plotly_chart(fig_cost) 323 | 324 | # Create stacked bar chart for latency 325 | latency_df = pd.DataFrame( 326 | { 327 | "Model": model_stats.index, 328 | "OCR": model_stats["ocr_latency"], 329 | "Extraction": model_stats["extraction_latency"], 330 | } 331 | ) 332 | 333 | # Calculate total latency for labels 334 | latency_df["Total"] = latency_df.get("OCR", 0) + latency_df.get("Extraction", 0) 335 | fig5 = px.bar( 336 | latency_df.sort_values("Total", ascending=True), 337 | x="Model", 338 | y=["OCR", "Extraction"], 339 | title="Latency by Model Combination (OCR + Extraction)", 340 | height=600, 341 | color_discrete_sequence=["#636EFA", "#EF553B"], 342 | ) 343 | fig5.update_layout( 344 | barmode="stack", 345 | showlegend=True, 346 | legend_title="Phase", 347 | yaxis=dict( 348 | range=[ 349 | 0, 350 | latency_df["Total"].max() * 1.2, 351 | ] # Set y-axis range to 120% of max value 352 | ), 353 | ) 354 | fig5.update_traces(texttemplate="%{y:.2f}s", textposition="inside") 355 | st.plotly_chart(fig5) 356 | 357 | # Total latency chart 358 | total_latency_df = pd.DataFrame( 359 | { 360 | "Model": model_stats.index, 361 | "Total Latency": model_stats["ocr_latency"] 362 | + model_stats["extraction_latency"], 363 | } 364 | ) 365 | fig6 = px.bar( 366 | total_latency_df.sort_values("Total Latency", ascending=True), 367 | x="Model", 368 | y="Total Latency", 369 | title="Total Latency by Model Combination", 370 | height=600, 371 | color_discrete_sequence=["#636EFA"], 372 | ) 373 | fig6.update_layout(showlegend=False) 374 | fig6.update_traces(texttemplate="%{y:.2f}s", textposition="outside") 375 | st.plotly_chart(fig6) 376 | 377 | # Add new token usage chart at the bottom 378 | st.header("Token Usage Analysis") 379 | token_df = pd.DataFrame( 380 | { 381 | "Model": model_stats.index, 382 | "Input Tokens": model_stats["ocr_input_tokens"], 383 | "Output Tokens": model_stats["ocr_output_tokens"], 384 | "Extraction Input Tokens": model_stats["extraction_input_tokens"], 385 | "Extraction Output Tokens": model_stats["extraction_output_tokens"], 386 | } 387 | ) 388 | 389 | # Calculate total tokens for sorting 390 | token_df["Total"] = ( 391 | token_df["Input Tokens"] 392 | + token_df["Output Tokens"] 393 | + token_df["Extraction Input Tokens"] 394 | + token_df["Extraction Output Tokens"] 395 | ) 396 | 397 | fig_tokens = px.bar( 398 | token_df.sort_values("Total", ascending=True), 399 | x="Model", 400 | y=[ 401 | "Input Tokens", 402 | "Output Tokens", 403 | "Extraction Input Tokens", 404 | "Extraction Output Tokens", 405 | ], 406 | title="Average Token Usage per Page by Model Combination", 407 | height=600, 408 | color_discrete_sequence=["#636EFA", "#EF553B", "#7B83FB", "#F76D57"], 409 | ) 410 | 411 | fig_tokens.update_layout( 412 | barmode="stack", 413 | showlegend=True, 414 | legend_title="Token Type", 415 | yaxis=dict( 416 | title="Number of Tokens", 417 | range=[0, token_df["Total"].max() * 1.2], 418 | ), 419 | ) 420 | fig_tokens.update_traces(texttemplate="%{y:.0f}", textposition="inside") 421 | st.plotly_chart(fig_tokens) 422 | 423 | # Detailed Results Table 424 | st.header("Test Results") 425 | df = create_results_table(results) 426 | st.dataframe(df) 427 | 428 | 429 | if __name__ == "__main__": 430 | main() 431 | -------------------------------------------------------------------------------- /dashboard/pages/2_Test_Result.py: 
-------------------------------------------------------------------------------- 1 | import base64 2 | import requests 3 | import streamlit as st 4 | from difflib import HtmlDiff 5 | from utils.data_loader import ( 6 | load_run_list, 7 | load_results_for_run, 8 | format_timestamp, 9 | load_one_result, 10 | ) 11 | from utils.style import SIDEBAR_STYLE 12 | 13 | 14 | st.set_page_config(page_title="Test Results", layout="wide") 15 | st.markdown(SIDEBAR_STYLE, unsafe_allow_html=True) 16 | 17 | 18 | def display_json_diff(test_case, container): 19 | """Display JSON differences in a readable format""" 20 | # First check for errors 21 | if "error" in test_case and test_case["error"]: 22 | container.subheader("Error Message") 23 | container.error(test_case["error"]) 24 | return 25 | 26 | # If no errors, display JSON diff as before 27 | if "jsonDiff" in test_case or "fullJsonDiff" in test_case: 28 | container.subheader("JSON Differences") 29 | # Display diff stats 30 | stats = test_case["jsonDiffStats"] 31 | cols = container.columns(4) 32 | cols[0].metric("Additions", stats["additions"]) 33 | cols[1].metric("Missing", stats["deletions"]) 34 | cols[2].metric("Modifications", stats["modifications"]) 35 | cols[3].metric("Total Changes", stats["total"]) 36 | 37 | cols = container.columns(2) 38 | total_fields = test_case.get("jsonAccuracyResult", {}).get("totalFields", 0) 39 | cols[0].metric("Total Fields", total_fields) 40 | cols[1].metric("Accuracy", test_case.get("jsonAccuracy", 0)) 41 | 42 | # Create tabs for different diff views 43 | tab_summary, tab_full, tab_ground_truth, tab_predicted, tab_schema = ( 44 | container.tabs( 45 | ["Summary Diff", "Full Diff", "Ground Truth", "Predicted", "Schema"] 46 | ) 47 | ) 48 | 49 | with tab_summary: 50 | if "jsonDiff" in test_case: 51 | tab_summary.json(test_case["jsonDiff"]) 52 | else: 53 | tab_summary.warning("Summary diff not available") 54 | 55 | with tab_full: 56 | if "fullJsonDiff" in test_case: 57 | tab_full.json(test_case["fullJsonDiff"]) 58 | else: 59 | tab_full.warning("Full diff not available") 60 | 61 | with tab_ground_truth: 62 | tab_ground_truth.json(test_case["trueJson"]) 63 | 64 | with tab_predicted: 65 | tab_predicted.json(test_case["predictedJson"]) 66 | 67 | with tab_schema: 68 | tab_schema.json(test_case.get("jsonSchema", {})) 69 | 70 | 71 | def display_file_preview(test_case, container): 72 | """Display the original file preview and model information""" 73 | container.subheader("File Preview") 74 | 75 | # Display model information 76 | model_info = container.expander("Model Information", expanded=True) 77 | with model_info: 78 | cols = model_info.columns(2) 79 | 80 | # Display OCR model if available 81 | ocr_model = test_case.get("ocrModel", "Not specified") 82 | cols[0].markdown(f"**OCR Model:** {ocr_model}") 83 | 84 | # Display extraction model if available 85 | extraction_model = test_case.get("extractionModel", "Not specified") 86 | cols[1].markdown(f"**Extraction Model:** {extraction_model}") 87 | 88 | # Display direct image extraction flag if available 89 | direct_image = test_case.get("directImageExtraction", False) 90 | model_info.markdown( 91 | f"**Direct Image Extraction:** {'Yes' if direct_image else 'No'}" 92 | ) 93 | 94 | def show_pdf(url): 95 | """Embed a remote PDF in the page via a base64 data URI""" 96 | try: 97 | response = requests.get(url) 98 | response.raise_for_status() # Raise an exception for bad status codes 99 | base64_pdf = base64.b64encode(response.content).decode("utf-8") 100 | pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" type="application/pdf"></iframe>' 101 | st.markdown(pdf_display, unsafe_allow_html=True) 102 | except
Exception as e: 103 | st.error(f"Failed to load PDF: {str(e)}") 104 | st.markdown(f"You can [view the PDF directly]({url}) in a new tab.") 105 | 106 | # Display file preview 107 | if "fileUrl" in test_case: 108 | file_url = test_case["fileUrl"] 109 | if file_url.lower().endswith(".pdf"): 110 | with container: 111 | show_pdf(file_url) 112 | else: 113 | container.image(file_url, width=700) 114 | else: 115 | container.warning("No file preview available") 116 | 117 | 118 | def display_markdown_diff(test_case): 119 | """Display markdown differences in a side-by-side view""" 120 | if "trueMarkdown" in test_case and "predictedMarkdown" in test_case: 121 | st.subheader("Markdown Differences") 122 | 123 | # Create HTML diff 124 | differ = HtmlDiff() 125 | diff_html = differ.make_file( 126 | test_case["trueMarkdown"].splitlines(), 127 | test_case["predictedMarkdown"].splitlines(), 128 | fromdesc="True Markdown", 129 | todesc="Predicted Markdown", 130 | ) 131 | 132 | # Display side-by-side view 133 | st.markdown("### Side by Side Comparison") 134 | cols = st.columns(2) 135 | with cols[0]: 136 | st.markdown("**True Markdown**") 137 | st.text_area("", test_case["trueMarkdown"], height=400, key="true_markdown") 138 | with cols[1]: 139 | st.markdown("**Predicted Markdown**") 140 | st.text_area( 141 | "", test_case["predictedMarkdown"], height=400, key="predicted_markdown" 142 | ) 143 | 144 | # Display HTML diff (optional, behind expander) 145 | with st.expander("View HTML Diff"): 146 | st.components.v1.html(diff_html, height=600, scrolling=True) 147 | 148 | 149 | def main(): 150 | st.title("Test Results") 151 | 152 | # Load only the list of runs initially 153 | runs = load_run_list() 154 | 155 | if not runs: 156 | st.warning("No results found.") 157 | return 158 | 159 | # 1. Select which test run (timestamp) 160 | col1, col2 = st.columns(2) 161 | with col1: 162 | selected_timestamp = st.selectbox( 163 | "Select Test Run", 164 | [run["timestamp"] for run in runs], 165 | format_func=format_timestamp, 166 | ) 167 | 168 | # Load minimal results for the selected run (just for the dropdown) 169 | run_data = load_results_for_run(selected_timestamp, include_metrics_only=True) 170 | 171 | # Get all test cases 172 | all_test_cases = run_data.get("results", []) 173 | 174 | # Filter out None values and then filter for diffs 175 | all_test_cases = [test for test in all_test_cases if test is not None] 176 | 177 | # 2. Filter test cases for ones that have a non-empty JSON diff 178 | results_with_diffs = [ 179 | test 180 | for test in all_test_cases 181 | if isinstance(test, dict) 182 | and isinstance(test.get("jsonDiffStats"), dict) 183 | and test["jsonDiffStats"].get("total", 0) > 0 184 | ] 185 | 186 | if not results_with_diffs: 187 | st.warning("No test cases have JSON differences for this run.") 188 | return 189 | 190 | # 3. 
Build the dropdown items from only those filtered test cases 191 | test_case_labels = { 192 | f"{test['id']}": idx for idx, test in enumerate(results_with_diffs) 193 | } 194 | 195 | # Initialize session state for selected index if it doesn't exist 196 | if "selected_test_idx" not in st.session_state: 197 | st.session_state.selected_test_idx = 0 198 | 199 | with col2: 200 | # Create a row for the dropdown and navigation buttons 201 | dropdown_col, nav_col_1, nav_col_2 = st.columns([4, 0.5, 0.5]) 202 | 203 | with dropdown_col: 204 | # Update session state when dropdown changes 205 | selected_test_id = st.selectbox( 206 | "Select Test Case (Only Cases with Differences)", 207 | options=list(test_case_labels.keys()), 208 | format_func=lambda x: f"{x}", 209 | index=st.session_state.selected_test_idx, 210 | ) 211 | st.session_state.selected_test_idx = test_case_labels[selected_test_id] 212 | 213 | with nav_col_1: 214 | # Spacer to vertically align the buttons with the selectbox 215 | st.write( 216 | '<div style="margin-top: 28px;"></div>
', 217 | unsafe_allow_html=True, 218 | ) 219 | # Add navigation buttons 220 | if st.button("←"): 221 | if st.session_state.selected_test_idx > 0: 222 | st.session_state.selected_test_idx -= 1 223 | st.rerun() 224 | 225 | with nav_col_2: 226 | st.write( 227 | '<div style="margin-top: 28px;"></div>
', 227 | unsafe_allow_html=True, 228 | ) 229 | if st.button("→"): 230 | if st.session_state.selected_test_idx < len(results_with_diffs) - 1: 231 | st.session_state.selected_test_idx += 1 232 | st.rerun() 233 | 234 | # 4. Load only the selected test case 235 | selected_result_id = results_with_diffs[st.session_state.selected_test_idx]["id"] 236 | detailed_data = load_one_result(selected_timestamp, selected_result_id) 237 | test_case = detailed_data["result"] 238 | 239 | # Display run metadata if available 240 | if detailed_data.get("description") or detailed_data.get("run_by"): 241 | with st.expander("Run Details", expanded=False): 242 | cols = st.columns(3) 243 | with cols[0]: 244 | st.markdown(f"**Status:** {detailed_data['status'].title()}") 245 | if detailed_data.get("run_by"): 246 | st.markdown(f"**Run By:** {detailed_data['run_by']}") 247 | with cols[1]: 248 | st.markdown(f"**Created:** {detailed_data['created_at']}") 249 | if detailed_data.get("completed_at"): 250 | st.markdown(f"**Completed:** {detailed_data['completed_at']}") 251 | with cols[2]: 252 | if detailed_data.get("description"): 253 | st.markdown(f"**Description:** {detailed_data['description']}") 254 | 255 | # Display file URL 256 | st.markdown(f"**File URL:** [{test_case['fileUrl']}]({test_case['fileUrl']})") 257 | 258 | # Create two columns for file preview and JSON diff 259 | left_col, right_col = st.columns(2) 260 | 261 | # Display file preview on the left 262 | with left_col: 263 | display_file_preview(test_case, left_col) 264 | 265 | # Display JSON diff on the right 266 | with right_col: 267 | display_json_diff(test_case, right_col) 268 | 269 | # Display markdown diff at the bottom 270 | st.markdown("---") # Add a separator 271 | display_markdown_diff(test_case) 272 | 273 | 274 | if __name__ == "__main__": 275 | main() 276 | -------------------------------------------------------------------------------- /dashboard/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pathlib import Path 4 | from datetime import datetime 5 | from dotenv import load_dotenv 6 | from sqlalchemy.sql import text 7 | from sqlalchemy.orm import sessionmaker 8 | from sqlalchemy import create_engine 9 | from typing import Dict, Any, List, TypedDict, Optional 10 | 11 | load_dotenv() 12 | 13 | 14 | class BenchmarkRunMetadata(TypedDict): 15 | timestamp: str 16 | status: str 17 | run_by: Optional[str] 18 | description: Optional[str] 19 | total_documents: Optional[int] 20 | created_at: Optional[str] 21 | completed_at: Optional[str] 22 | 23 | 24 | def load_run_list_from_folder( 25 | results_dir: str = "results", 26 | ) -> List[BenchmarkRunMetadata]: 27 | """Load list of benchmark runs from the results directory""" 28 | results_path = Path(results_dir) 29 | result_dirs = [d for d in results_path.iterdir() if d.is_dir()] 30 | runs = [] 31 | 32 | for dir_path in result_dirs: 33 | timestamp = dir_path.name 34 | json_path = dir_path / "results.json" 35 | if json_path.exists(): 36 | runs.append( 37 | { 38 | "timestamp": timestamp, 39 | "status": "completed", # Assuming completed if file exists 40 | "run_by": None, 41 | "description": None, 42 | "total_documents": None, 43 | "created_at": format_timestamp(timestamp), 44 | "completed_at": format_timestamp(timestamp), 45 | } 46 | ) 47 | 48 | return sorted(runs, key=lambda x: x["timestamp"], reverse=True) 49 | 50 | 51 | def load_run_list_from_db() -> List[BenchmarkRunMetadata]: 52 | """Load list of benchmark runs from 
database""" 53 | database_url = os.getenv("DATABASE_URL") 54 | engine = create_engine(database_url) 55 | Session = sessionmaker(bind=engine) 56 | session = Session() 57 | 58 | query = text( 59 | """ 60 | SELECT 61 | timestamp, 62 | status, 63 | run_by, 64 | description, 65 | total_documents, 66 | created_at, 67 | completed_at 68 | FROM benchmark_runs 69 | ORDER BY created_at DESC 70 | """ 71 | ) 72 | 73 | rows = session.execute(query) 74 | runs = [] 75 | 76 | for row in rows: 77 | runs.append( 78 | { 79 | "timestamp": row.timestamp, 80 | "status": row.status, 81 | "run_by": row.run_by, 82 | "description": row.description, 83 | "total_documents": row.total_documents, 84 | "created_at": ( 85 | row.created_at.strftime("%Y-%m-%d %H:%M:%S") 86 | if row.created_at 87 | else None 88 | ), 89 | "completed_at": ( 90 | row.completed_at.strftime("%Y-%m-%d %H:%M:%S") 91 | if row.completed_at 92 | else None 93 | ), 94 | } 95 | ) 96 | 97 | session.close() 98 | return runs 99 | 100 | 101 | def load_results_for_run_from_folder( 102 | timestamp: str, results_dir: str = "results" 103 | ) -> Dict[str, Any]: 104 | """Load results for a specific run from folder""" 105 | results_path = Path(results_dir) / timestamp / "results.json" 106 | if results_path.exists(): 107 | with open(results_path) as f: 108 | results = json.load(f) 109 | # Assign id to each result if not already present 110 | for idx, result in enumerate(results): 111 | if "id" not in result: 112 | result["id"] = idx 113 | total_documents = len(results) 114 | return { 115 | "results": results, 116 | "status": "completed", 117 | "run_by": None, 118 | "description": None, 119 | "total_documents": total_documents, 120 | "created_at": format_timestamp(timestamp), 121 | "completed_at": format_timestamp(timestamp), 122 | } 123 | return {} 124 | 125 | 126 | def load_results_for_run_from_db( 127 | timestamp: str, include_metrics_only: bool = True 128 | ) -> Dict[str, Any]: 129 | """Load results for a specific run from database""" 130 | database_url = os.getenv("DATABASE_URL") 131 | engine = create_engine(database_url) 132 | Session = sessionmaker(bind=engine) 133 | session = Session() 134 | 135 | if not include_metrics_only: 136 | output_string = """ 137 | 'trueMarkdown', bres.true_markdown, 138 | 'predictedMarkdown', bres.predicted_markdown, 139 | 'trueJson', bres.true_json, 140 | 'predictedJson', bres.predicted_json, 141 | 'jsonDiff', bres.json_diff, 142 | 'fullJsonDiff', bres.full_json_diff, 143 | """ 144 | else: 145 | output_string = "" 146 | 147 | query = text( 148 | f""" 149 | WITH filtered_run AS ( 150 | SELECT id, timestamp, status, run_by, description, total_documents, created_at, completed_at 151 | FROM benchmark_runs 152 | WHERE timestamp = :timestamp 153 | ) 154 | SELECT 155 | fr.timestamp, 156 | fr.status, 157 | fr.run_by, 158 | fr.description, 159 | fr.total_documents, 160 | fr.created_at, 161 | fr.completed_at, 162 | json_agg( 163 | json_build_object( 164 | 'id', bres.id, 165 | 'fileUrl', bres.file_url, 166 | 'ocrModel', bres.ocr_model, 167 | 'extractionModel', bres.extraction_model, 168 | 'directImageExtraction', bres.direct_image_extraction, 169 | {output_string} 170 | 'levenshteinDistance', bres.levenshtein_distance, 171 | 'jsonAccuracy', bres.json_accuracy, 172 | 'jsonAccuracyResult', bres.json_accuracy_result, 173 | 'jsonDiffStats', bres.json_diff_stats, 174 | 'metadata', bres.metadata, 175 | 'usage', bres.usage, 176 | 'error', bres.error 177 | ) 178 | ) as results 179 | FROM filtered_run fr 180 | LEFT JOIN benchmark_results bres ON fr.id 
= bres.benchmark_run_id 181 | GROUP BY fr.id, fr.timestamp, fr.status, fr.run_by, fr.description, fr.total_documents, fr.created_at, fr.completed_at 182 | """ 183 | ) 184 | 185 | row = session.execute(query, {"timestamp": timestamp}).first() 186 | session.close() 187 | 188 | if row: 189 | return { 190 | "results": row.results, 191 | "status": row.status, 192 | "total_documents": row.total_documents, 193 | "run_by": row.run_by, 194 | "description": row.description, 195 | "created_at": ( 196 | row.created_at.strftime("%Y-%m-%d %H:%M:%S") if row.created_at else None 197 | ), 198 | "completed_at": ( 199 | row.completed_at.strftime("%Y-%m-%d %H:%M:%S") 200 | if row.completed_at 201 | else None 202 | ), 203 | } 204 | return {} 205 | 206 | 207 | def load_one_result_from_db(timestamp: str, id: str) -> Dict[str, Any]: 208 | """Load one test case result from database for a specific run and file""" 209 | database_url = os.getenv("DATABASE_URL") 210 | engine = create_engine(database_url) 211 | Session = sessionmaker(bind=engine) 212 | session = Session() 213 | 214 | query = text( 215 | """ 216 | WITH filtered_results AS ( 217 | SELECT * 218 | FROM benchmark_results 219 | WHERE id = :id 220 | ) 221 | SELECT 222 | br.timestamp, 223 | br.status, 224 | br.run_by, 225 | br.description, 226 | br.total_documents, 227 | br.created_at, 228 | br.completed_at, 229 | json_build_object( 230 | 'id', fr.id, 231 | 'fileUrl', fr.file_url, 232 | 'ocrModel', fr.ocr_model, 233 | 'extractionModel', fr.extraction_model, 234 | 'directImageExtraction', fr.direct_image_extraction, 235 | 'trueMarkdown', fr.true_markdown, 236 | 'predictedMarkdown', fr.predicted_markdown, 237 | 'trueJson', fr.true_json, 238 | 'predictedJson', fr.predicted_json, 239 | 'jsonDiff', fr.json_diff, 240 | 'fullJsonDiff', fr.full_json_diff, 241 | 'jsonDiffStats', fr.json_diff_stats, 242 | 'levenshteinDistance', fr.levenshtein_distance, 243 | 'jsonAccuracy', fr.json_accuracy, 244 | 'jsonAccuracyResult', fr.json_accuracy_result, 245 | 'jsonSchema', fr.json_schema, 246 | 'metadata', fr.metadata, 247 | 'usage', fr.usage, 248 | 'error', fr.error 249 | ) as result 250 | FROM benchmark_runs br 251 | INNER JOIN filtered_results fr ON br.id = fr.benchmark_run_id 252 | WHERE br.timestamp = :timestamp 253 | LIMIT 1 254 | """ 255 | ) 256 | 257 | row = session.execute(query, {"timestamp": timestamp, "id": id}).first() 258 | session.close() 259 | 260 | if row: 261 | return { 262 | "result": row.result, 263 | "status": row.status, 264 | "run_by": row.run_by, 265 | "description": row.description, 266 | "created_at": ( 267 | row.created_at.strftime("%Y-%m-%d %H:%M:%S") if row.created_at else None 268 | ), 269 | "completed_at": ( 270 | row.completed_at.strftime("%Y-%m-%d %H:%M:%S") 271 | if row.completed_at 272 | else None 273 | ), 274 | } 275 | return {} 276 | 277 | 278 | def load_one_result_from_folder( 279 | timestamp: str, id: str, results_dir: str = "results" 280 | ) -> Dict[str, Any]: 281 | """Load one test case result from folder for a specific run and file""" 282 | results_path = Path(results_dir) / timestamp / "results.json" 283 | if results_path.exists(): 284 | with open(results_path) as f: 285 | results = json.load(f) 286 | for idx, result in enumerate(results): 287 | if idx == id: 288 | return { 289 | "result": result, 290 | "status": "completed", 291 | "run_by": None, 292 | "description": None, 293 | "created_at": format_timestamp(timestamp), 294 | "completed_at": format_timestamp(timestamp), 295 | } 296 | return {} 297 | 298 | 299 | def load_run_list() -> 
List[BenchmarkRunMetadata]: 300 | """Load list of benchmark runs from either database or local files""" 301 | if os.getenv("DATABASE_URL"): 302 | return load_run_list_from_db() 303 | return load_run_list_from_folder() 304 | 305 | 306 | def load_results_for_run( 307 | timestamp: str, include_metrics_only: bool = True 308 | ) -> Dict[str, Any]: 309 | """Load results for a specific run from either database or local files""" 310 | if os.getenv("DATABASE_URL"): 311 | return load_results_for_run_from_db(timestamp, include_metrics_only) 312 | return load_results_for_run_from_folder(timestamp) 313 | 314 | 315 | def load_one_result(timestamp: str, id: str) -> Dict[str, Any]: 316 | """Load one test case result from either database or local files""" 317 | if os.getenv("DATABASE_URL"): 318 | return load_one_result_from_db(timestamp, id) 319 | return load_one_result_from_folder(timestamp, id) 320 | 321 | 322 | def format_timestamp(timestamp: str) -> str: 323 | """Convert timestamp string to readable format""" 324 | return datetime.strptime(timestamp, "%Y-%m-%d-%H-%M-%S").strftime( 325 | "%Y-%m-%d %H:%M:%S" 326 | ) 327 | -------------------------------------------------------------------------------- /dashboard/utils/style.py: -------------------------------------------------------------------------------- 1 | SIDEBAR_STYLE = """ 2 | 7 | """ 8 | -------------------------------------------------------------------------------- /data/receipt.json: -------------------------------------------------------------------------------- 1 | { 2 | "imageUrl": "https://omni-demo-data.s3.us-east-1.amazonaws.com/templates/receipt.png", 3 | "metadata": { 4 | "orientation": 0, 5 | "documentQuality": "clean", 6 | "resolution": [612, 792], 7 | "language": "EN" 8 | }, 9 | "jsonSchema": { 10 | "type": "object", 11 | "required": ["merchant", "receipt_details", "totals"], 12 | "properties": { 13 | "totals": { 14 | "type": "object", 15 | "required": ["total"], 16 | "properties": { 17 | "tax": { 18 | "type": "number", 19 | "description": "Tax amount" 20 | }, 21 | "total": { 22 | "type": "number", 23 | "description": "Final total amount" 24 | }, 25 | "subtotal": { 26 | "type": "number", 27 | "description": "Subtotal before tax and fees" 28 | } 29 | }, 30 | "description": "Payment totals" 31 | }, 32 | "merchant": { 33 | "type": "object", 34 | "required": ["name"], 35 | "properties": { 36 | "name": { 37 | "type": "string", 38 | "description": "Business name" 39 | }, 40 | "phone": { 41 | "type": "string", 42 | "description": "Contact phone number" 43 | }, 44 | "address": { 45 | "type": "string", 46 | "description": "Store location address" 47 | } 48 | }, 49 | "description": "Basic merchant information" 50 | }, 51 | "line_items": { 52 | "type": "array", 53 | "items": { 54 | "type": "object", 55 | "required": ["description", "amount"], 56 | "properties": { 57 | "amount": { 58 | "type": "number", 59 | "description": "Price of the item" 60 | }, 61 | "description": { 62 | "type": "string", 63 | "description": "Item name or description" 64 | } 65 | } 66 | }, 67 | "description": "List of purchased items" 68 | }, 69 | "receipt_details": { 70 | "type": "object", 71 | "required": ["date"], 72 | "properties": { 73 | "date": { 74 | "type": "string", 75 | "description": "Transaction date" 76 | }, 77 | "time": { 78 | "type": "string", 79 | "description": "Transaction time" 80 | }, 81 | "receipt_number": { 82 | "type": "string", 83 | "description": "Receipt or ticket number" 84 | } 85 | }, 86 | "description": "Transaction details" 87 | }, 88 | 
"payment": { 89 | "type": "object", 90 | "properties": { 91 | "payment_method": { 92 | "type": "string", 93 | "description": "" 94 | }, 95 | "card_last_four_digits": { 96 | "type": "string", 97 | "description": "" 98 | } 99 | } 100 | } 101 | } 102 | }, 103 | "trueJsonOutput": { 104 | "totals": { 105 | "tax": 6.18, 106 | "total": 48.43, 107 | "subtotal": 42.25 108 | }, 109 | "merchant": { 110 | "name": "Nick the Greek Souvlaki & Gyro House", 111 | "phone": "(415) 757-0426", 112 | "address": "121 Spear Street, Suite B08, San Francisco, CA 94105" 113 | }, 114 | "line_items": [ 115 | { 116 | "amount": 12.5, 117 | "description": "Beef/Lamb Gyro Pita" 118 | }, 119 | { 120 | "amount": 13.25, 121 | "description": "Gyro Bowl" 122 | }, 123 | { 124 | "amount": 16.5, 125 | "description": "Pork Gyro Pita" 126 | } 127 | ], 128 | "receipt_details": { 129 | "date": "November 8, 2024", 130 | "time": "2:16 PM", 131 | "receipt_number": "NKZ1" 132 | }, 133 | "payment": { 134 | "payment_method": "Mastercard", 135 | "card_last_four_digits": "0920" 136 | } 137 | }, 138 | "trueMarkdownOutput": "**NICK THE GREEK**\n\nSOUVLAKI & GYRO HOUSE\n\n**San Francisco**\n\n121 spear streeet \nSuite B08 \nsan francisco, CA \n94105 \n(415) 757-0426 \nwww.nickthegreeksj.com\n\nNovember 8, 2024 \n2:16 PM \nSamantha\n\nTicket: 17 \nReceipt: NKZ1 \nAuthorization: CF2D4F\n\nMastercard \nAID A0 00 00 00 04 10 10\n\n**TO GO**\n\nBeef/Lamb Gyro Pita $12.50 \nGyro Bowl $13.25 \nBeef/Lamb Gyro \nPork Gyro Pita $16.50 \nFries & Drink ($4.00)\n\nSubtotal $42.25 \nSF Mandate (6%) $2.54 \n8.625% (8.625%) $3.64\n\n**Total** $48.43 \nMastercard 0920 (Contactless) $48.43" 139 | } 140 | -------------------------------------------------------------------------------- /jest.config.ts: -------------------------------------------------------------------------------- 1 | // jest.config.js 2 | module.exports = { 3 | preset: 'ts-jest', 4 | testEnvironment: 'node', 5 | moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json'], 6 | testMatch: ['**/tests/**/*.test.ts'], 7 | transform: { 8 | '^.+\\.(ts|tsx)$': 'ts-jest', 9 | }, 10 | }; 11 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "benchmark", 3 | "version": "0.0.1", 4 | "description": "OCR Benchmark", 5 | "main": "index.js", 6 | "scripts": { 7 | "build": "tsc", 8 | "test": "jest", 9 | "benchmark": "ts-node src/index.ts" 10 | }, 11 | "dependencies": { 12 | "@ai-sdk/anthropic": "^1.2.12", 13 | "@ai-sdk/azure": "^1.1.9", 14 | "@ai-sdk/deepseek": "^0.1.6", 15 | "@ai-sdk/google": "^1.1.10", 16 | "@ai-sdk/openai": "^1.3.14", 17 | "@aws-sdk/client-textract": "^3.716.0", 18 | "@azure-rest/ai-document-intelligence": "^1.0.0", 19 | "@azure/core-auth": "^1.9.0", 20 | "@google-cloud/documentai": "^8.12.0", 21 | "@google/generative-ai": "^0.21.0", 22 | "@huggingface/hub": "^1.0.1", 23 | "@mistralai/mistralai": "^1.5.1", 24 | "@prisma/client": "^6.3.1", 25 | "ai": "^4.3.16", 26 | "axios": "^1.7.9", 27 | "canvas": "^3.1.0", 28 | "cli-progress": "^3.12.0", 29 | "dotenv": "^16.4.7", 30 | "fastest-levenshtein": "^1.0.16", 31 | "form-data": "^4.0.2", 32 | "jimp": "^1.6.0", 33 | "json-diff": "^1.0.6", 34 | "lodash": "^4.17.21", 35 | "moment": "^2.30.1", 36 | "openai": "^4.94.0", 37 | "p-limit": "^3.1.0", 38 | "pdfkit": "^0.17.1", 39 | "pg": "^8.13.1", 40 | "sharp": "^0.33.5", 41 | "together-ai": "^0.13.0", 42 | "turndown": "^7.2.0", 43 | "zerox": "^1.0.43" 44 | }, 45 | 
"devDependencies": { 46 | "@eslint/js": "^9.17.0", 47 | "@types/jest": "^29.5.14", 48 | "eslint": "^9.17.0", 49 | "jest": "^29.7.0", 50 | "prettier": "^3.4.2", 51 | "prisma": "^6.3.1", 52 | "ts-jest": "^29.2.5", 53 | "ts-node": "^10.9.2", 54 | "typescript": "^5.7.2", 55 | "typescript-eslint": "^8.18.1" 56 | }, 57 | "keywords": [ 58 | "OCR", 59 | "Benchmark", 60 | "LLM" 61 | ], 62 | "author": "[@annapo23, @tylermaran, @kailingding, @zeeshan]", 63 | "license": "ISC" 64 | } 65 | -------------------------------------------------------------------------------- /prisma/schema.prisma: -------------------------------------------------------------------------------- 1 | generator client { 2 | provider = "prisma-client-js" 3 | } 4 | 5 | datasource db { 6 | provider = "postgresql" 7 | url = env("DATABASE_URL") 8 | } 9 | 10 | model BenchmarkRun { 11 | id String @id @default(uuid()) @db.Uuid 12 | completedAt DateTime? @map("completed_at") 13 | createdAt DateTime @default(now()) @map("created_at") 14 | description String? @map("description") 15 | error String? 16 | modelsConfig Json @map("models_config") // The models.yaml configuration 17 | results BenchmarkResult[] 18 | runBy String? @map("run_by") 19 | status String // 'running', 'completed', 'failed' 20 | timestamp String // timestamp format: YYYY-MM-DD-HH-mm-ss 21 | totalDocuments Int @map("total_documents") 22 | 23 | @@map("benchmark_runs") 24 | } 25 | 26 | model BenchmarkResult { 27 | id String @id @default(uuid()) @db.Uuid 28 | benchmarkRun BenchmarkRun @relation(fields: [benchmarkRunId], references: [id]) 29 | benchmarkRunId String @map("benchmark_run_id") 30 | createdAt DateTime @default(now()) @map("created_at") 31 | directImageExtraction Boolean @default(false) @map("direct_image_extraction") 32 | error String? 33 | extractionModel String? @map("extraction_model") 34 | fileUrl String @map("file_url") 35 | fullJsonDiff Json? @map("full_json_diff") 36 | jsonAccuracy Float? @map("json_accuracy") 37 | jsonAccuracyResult Json? @map("json_accuracy_result") 38 | jsonDiff Json? @map("json_diff") 39 | jsonDiffStats Json? @map("json_diff_stats") 40 | jsonSchema Json @map("json_schema") 41 | levenshteinDistance Float? @map("levenshtein_distance") 42 | metadata Json @map("metadata") 43 | ocrModel String @map("ocr_model") 44 | predictedJson Json? @map("predicted_json") 45 | predictedMarkdown String? @map("predicted_markdown") 46 | trueJson Json @map("true_json") 47 | trueMarkdown String @map("true_markdown") 48 | usage Json? 
49 | 50 | @@map("benchmark_results") 51 | } 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.41.1 2 | pandas==2.2.3 3 | datetime==5.5 4 | plotly==5.24.1 5 | sqlalchemy==2.0.38 6 | psycopg2-binary==2.9.10 7 | python-dotenv==1.0.1 -------------------------------------------------------------------------------- /src/evaluation/index.ts: -------------------------------------------------------------------------------- 1 | export * from './text'; 2 | export * from './json'; 3 | -------------------------------------------------------------------------------- /src/evaluation/json.ts: -------------------------------------------------------------------------------- 1 | import { diff } from 'json-diff'; 2 | 3 | interface DiffStats { 4 | additions: number; 5 | deletions: number; 6 | modifications: number; 7 | total: number; 8 | } 9 | 10 | export interface AccuracyResult { 11 | score: number; 12 | fullJsonDiff: Record; 13 | jsonDiff: Record; 14 | jsonDiffStats?: DiffStats; 15 | totalFields: number; 16 | } 17 | 18 | /** 19 | * Calculates accuracy for JSON structure and primitive values only 20 | * 21 | * The accuracy is calculated as: 22 | * 1 - (number of differences / total fields in actual) 23 | * 24 | * Differences include: 25 | * - Additions: Fields present in predicted but not in actual 26 | * - Deletions: Fields present in actual but not in predicted 27 | * - Modifications: Fields present in both but with different values 28 | * 29 | * A score of 1.0 means the JSONs are identical 30 | * A score of 0.0 means completely different 31 | */ 32 | export const calculateJsonAccuracy = ( 33 | actual: Record, 34 | predicted: Record, 35 | ignoreCases: boolean = false, 36 | ): AccuracyResult => { 37 | // Convert strings to uppercase if ignoreCases is true 38 | const processedActual = ignoreCases ? convertStringsToUppercase(actual) : actual; 39 | const processedPredicted = ignoreCases 40 | ? 
41 | : predicted;
42 | 
43 | // Get the diff result
44 | const fullDiffResult = diff(processedActual, processedPredicted, {
45 | full: true,
46 | sort: true,
47 | });
48 | const diffResult = diff(processedActual, processedPredicted, { sort: true });
49 | const totalFields = countTotalFields(processedActual);
50 | 
51 | if (!diffResult) {
52 | // If there's no diff, the JSONs are identical
53 | return {
54 | score: 1,
55 | jsonDiff: {},
56 | fullJsonDiff: {},
57 | jsonDiffStats: {
58 | additions: 0,
59 | deletions: 0,
60 | modifications: 0,
61 | total: 0,
62 | },
63 | totalFields,
64 | };
65 | }
66 | 
67 | const changes = countChanges(diffResult);
68 | const score = Math.max(
69 | 0,
70 | 1 - (changes.additions + changes.deletions + changes.modifications) / totalFields,
71 | );
72 | 
73 | return {
74 | score: Number(score.toFixed(4)),
75 | jsonDiff: diffResult,
76 | fullJsonDiff: fullDiffResult,
77 | jsonDiffStats: changes,
78 | totalFields,
79 | };
80 | };
81 | 
82 | /**
83 | * Recursively converts all string values in an object to uppercase
84 | */
85 | const convertStringsToUppercase = (obj: any): any => {
86 | if (obj === null || typeof obj !== 'object') {
87 | return obj;
88 | }
89 | 
90 | if (Array.isArray(obj)) {
91 | return obj.map((item) => convertStringsToUppercase(item));
92 | }
93 | 
94 | const result: Record<string, any> = {};
95 | for (const key in obj) {
96 | const value = obj[key];
97 | if (typeof value === 'string') {
98 | result[key] = value.toUpperCase();
99 | } else if (typeof value === 'object' && value !== null) {
100 | result[key] = convertStringsToUppercase(value);
101 | } else {
102 | result[key] = value;
103 | }
104 | }
105 | return result;
106 | };
107 | 
108 | export const countChanges = (diffResult: any): DiffStats => {
109 | const changes: DiffStats = {
110 | additions: 0,
111 | deletions: 0,
112 | modifications: 0,
113 | total: 0,
114 | };
115 | 
116 | const traverse = (obj: any) => {
117 | if (!obj || typeof obj !== 'object') {
118 | return;
119 | }
120 | 
121 | for (const key in obj) {
122 | const value = obj[key];
123 | 
124 | if (Array.isArray(value)) {
125 | // Handle array diffs
126 | value.forEach((item) => {
127 | // Check if item is in the expected [operation, element] format
128 | if (!Array.isArray(item) || item.length !== 2) {
129 | return;
130 | }
131 | 
132 | const [operation, element] = item;
133 | if (element === null || typeof element !== 'object') {
134 | // Handle primitive value changes in arrays
135 | switch (operation) {
136 | case '+':
137 | changes.additions++;
138 | break;
139 | case '-':
140 | changes.deletions++;
141 | break;
142 | }
143 | } else {
144 | switch (operation) {
145 | // Handle array element additions and deletions
146 | case '+':
147 | changes.additions += countTotalFields(element);
148 | break;
149 | case '-':
150 | changes.deletions += countTotalFields(element);
151 | break;
152 | case '~':
153 | // Handle array element modifications
154 | traverse(element);
155 | break;
156 | }
157 | }
158 | });
159 | } else {
160 | if (key.endsWith('__deleted')) {
161 | if (value === null || typeof value !== 'object') {
162 | changes.deletions++;
163 | } else {
164 | changes.deletions += countTotalFields(value);
165 | }
166 | } else if (key.endsWith('__added')) {
167 | if (value === null || typeof value !== 'object') {
168 | changes.additions++;
169 | } else {
170 | changes.additions += countTotalFields(value);
171 | }
172 | } else if (typeof value === 'object' && value !== null) {
173 | if (value.__old !== undefined && 
value.__new !== undefined) { 174 | if (value.__old === null && value.__new !== null) { 175 | changes.modifications += countTotalFields(value.__new) || 1; 176 | } else { 177 | changes.modifications += countTotalFields(value.__old) || 1; 178 | } 179 | } else { 180 | traverse(value); 181 | } 182 | } 183 | } 184 | } 185 | }; 186 | 187 | traverse(diffResult); 188 | 189 | changes.total = changes.additions + changes.deletions + changes.modifications; 190 | return changes; 191 | }; 192 | 193 | export function countTotalFields(obj: any): number { 194 | let count = 0; 195 | 196 | const traverse = (current: any) => { 197 | if (!current || typeof current !== 'object') { 198 | return; 199 | } 200 | 201 | if (Array.isArray(current)) { 202 | // Traverse into array elements if they're objects 203 | current.forEach((item) => { 204 | if (typeof item === 'object' && item !== null) { 205 | traverse(item); 206 | } else { 207 | count++; 208 | } 209 | }); 210 | } else { 211 | for (const key in current) { 212 | // Skip diff metadata keys 213 | if (key.includes('__')) { 214 | continue; 215 | } 216 | 217 | // Only count primitive value fields 218 | if ( 219 | current[key] === null || 220 | typeof current[key] === 'string' || 221 | typeof current[key] === 'number' || 222 | typeof current[key] === 'boolean' 223 | ) { 224 | count++; 225 | } 226 | // Recurse into nested objects and arrays 227 | else if (typeof current[key] === 'object') { 228 | traverse(current[key]); 229 | } 230 | } 231 | } 232 | }; 233 | 234 | traverse(obj); 235 | return count; 236 | } 237 | -------------------------------------------------------------------------------- /src/evaluation/text.ts: -------------------------------------------------------------------------------- 1 | import { distance } from 'fastest-levenshtein'; 2 | 3 | /** 4 | * Calculates text similarity between original and OCR text using Levenshtein distance 5 | * Returns a score between 0 and 1, where: 6 | * 1.0 = texts are identical 7 | * 0.0 = texts are completely different 8 | */ 9 | export const calculateTextSimilarity = (original: string, predicted: string): number => { 10 | if (original === predicted) return 1; 11 | if (!original.length || !predicted.length) return 0; 12 | 13 | // Normalize strings 14 | const normalizedOriginal = original.trim().toLowerCase(); 15 | const normalizedPredicted = predicted.trim().toLowerCase(); 16 | 17 | // Calculate Levenshtein distance 18 | const levenshteinDistance = distance(normalizedOriginal, normalizedPredicted); 19 | 20 | // Normalize score between 0 and 1 21 | const maxLength = Math.max(normalizedOriginal.length, normalizedPredicted.length); 22 | const similarity = 1 - levenshteinDistance / maxLength; 23 | 24 | return Number(similarity.toFixed(4)); 25 | }; 26 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | import dotenv from 'dotenv'; 2 | import path from 'path'; 3 | import moment from 'moment'; 4 | import cliProgress from 'cli-progress'; 5 | import { isEmpty } from 'lodash'; 6 | import pLimit from 'p-limit'; 7 | import yaml from 'js-yaml'; 8 | import fs from 'fs'; 9 | 10 | import { BenchmarkRun } from '@prisma/client'; 11 | import { calculateJsonAccuracy, calculateTextSimilarity } from './evaluation'; 12 | import { getModelProvider } from './models'; 13 | import { Result } from './types'; 14 | import { 15 | createResultFolder, 16 | loadLocalData, 17 | writeToFile, 18 | loadFromDb, 19 | 
createBenchmarkRun,
20 | saveResult,
21 | completeBenchmarkRun,
22 | } from './utils';
23 | 
24 | dotenv.config();
25 | 
26 | /* -------------------------------------------------------------------------- */
27 | /* Benchmark Config */
28 | /* -------------------------------------------------------------------------- */
29 | 
30 | const MODEL_CONCURRENCY = {
31 | 'aws-textract': 50,
32 | 'azure-document-intelligence': 50,
33 | 'claude-3-5-sonnet-20241022': 10,
34 | 'gemini-2.0-flash-001': 30,
35 | 'mistral-ocr': 5,
36 | 'gpt-4o': 50,
37 | 'qwen2.5-vl-32b-instruct': 10,
38 | 'qwen2.5-vl-72b-instruct': 10,
39 | 'google/gemma-3-27b-it': 10,
40 | 'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo': 10,
41 | 'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo': 10,
42 | omniai: 30,
43 | zerox: 50,
44 | };
45 | 
46 | interface ModelConfig {
47 | ocr: string;
48 | extraction?: string;
49 | directImageExtraction?: boolean;
50 | }
51 | 
52 | // Load models config
53 | const loadModelsConfig = () => {
54 | try {
55 | const configPath = path.join(__dirname, 'models.yaml');
56 | const fileContents = fs.readFileSync(configPath, 'utf8');
57 | const config = yaml.load(fileContents) as { models: ModelConfig[] };
58 | return config.models;
59 | } catch (error) {
60 | console.error('Error loading models config:', error);
61 | return [] as ModelConfig[];
62 | }
63 | };
64 | 
65 | const MODELS = loadModelsConfig();
66 | 
67 | const DATA_FOLDER = path.join(__dirname, '../data');
68 | 
69 | const DATABASE_URL = process.env.DATABASE_URL;
70 | 
71 | const TIMEOUT_MS = 10 * 60 * 1000; // 10 minutes in milliseconds
72 | 
73 | const withTimeout = async <T>(promise: Promise<T>, operation: string) => {
74 | let timeoutId: NodeJS.Timeout;
75 | 
76 | const timeoutPromise = new Promise<never>((_, reject) => {
77 | timeoutId = setTimeout(() => {
78 | reject(new Error(`${operation} operation timed out after ${TIMEOUT_MS}ms`));
79 | }, TIMEOUT_MS);
80 | });
81 | 
82 | try {
83 | const result = await Promise.race([promise, timeoutPromise]);
84 | clearTimeout(timeoutId);
85 | return result;
86 | } catch (error) {
87 | clearTimeout(timeoutId);
88 | console.error(`Timeout error in ${operation}:`, error);
89 | throw error;
90 | }
91 | };
92 | 
93 | /* -------------------------------------------------------------------------- */
94 | /* Run Benchmark */
95 | /* -------------------------------------------------------------------------- */
96 | 
97 | const timestamp = moment(new Date()).format('YYYY-MM-DD-HH-mm-ss');
98 | const resultFolder = createResultFolder(timestamp);
99 | 
100 | const runBenchmark = async () => {
101 | const data = DATABASE_URL ? await loadFromDb() : loadLocalData(DATA_FOLDER);
102 | const results: Result[] = [];
103 | 
104 | // Create benchmark run
105 | let benchmarkRun: BenchmarkRun;
106 | if (DATABASE_URL) {
107 | benchmarkRun = await createBenchmarkRun(timestamp, MODELS, data.length);
108 | }
109 | 
110 | // Create multiple progress bars
111 | const multibar = new cliProgress.MultiBar({
112 | format: '{model} |{bar}| {percentage}% | {value}/{total}',
113 | barCompleteChar: '\u2588',
114 | barIncompleteChar: '\u2591',
115 | clearOnComplete: false,
116 | hideCursor: true,
117 | });
118 | 
119 | // Create progress bars for each model
120 | const progressBars = MODELS.reduce(
121 | (acc, model) => ({
122 | ...acc,
123 | [`${model.directImageExtraction ? `${model.extraction} (IMG2JSON)` : `${model.ocr}-${model.extraction}`}`]:
124 | multibar.create(data.length, 0, {
125 | model: `${model.directImageExtraction ? 
`${model.extraction} (IMG2JSON)` : `${model.ocr} -> ${model.extraction}`}`, 126 | }), 127 | }), 128 | {}, 129 | ); 130 | 131 | const modelPromises = MODELS.map( 132 | async ({ ocr: ocrModel, extraction: extractionModel, directImageExtraction }) => { 133 | // Calculate concurrent requests based on rate limit 134 | const concurrency = Math.min( 135 | MODEL_CONCURRENCY[ocrModel as keyof typeof MODEL_CONCURRENCY] ?? 20, 136 | MODEL_CONCURRENCY[extractionModel as keyof typeof MODEL_CONCURRENCY] ?? 20, 137 | ); 138 | const limit = pLimit(concurrency); 139 | 140 | const promises = data.map((item) => 141 | limit(async () => { 142 | const ocrModelProvider = getModelProvider(ocrModel); 143 | const extractionModelProvider = extractionModel 144 | ? getModelProvider(extractionModel) 145 | : undefined; 146 | 147 | const result: Result = { 148 | fileUrl: item.imageUrl, 149 | metadata: item.metadata, 150 | jsonSchema: item.jsonSchema, 151 | ocrModel, 152 | extractionModel, 153 | directImageExtraction, 154 | trueMarkdown: item.trueMarkdownOutput, 155 | trueJson: item.trueJsonOutput, 156 | predictedMarkdown: undefined, 157 | predictedJson: undefined, 158 | levenshteinDistance: undefined, 159 | jsonAccuracy: undefined, 160 | jsonDiff: undefined, 161 | fullJsonDiff: undefined, 162 | jsonDiffStats: undefined, 163 | jsonAccuracyResult: undefined, 164 | usage: undefined, 165 | }; 166 | 167 | try { 168 | if (directImageExtraction) { 169 | const extractionResult = await withTimeout( 170 | extractionModelProvider.extractFromImage(item.imageUrl, item.jsonSchema), 171 | `JSON extraction: ${extractionModel}`, 172 | ); 173 | result.predictedJson = extractionResult.json; 174 | result.usage = { 175 | ...extractionResult.usage, 176 | ocr: undefined, 177 | extraction: extractionResult.usage, 178 | }; 179 | } else { 180 | let ocrResult; 181 | if (ocrModel === 'ground-truth') { 182 | result.predictedMarkdown = item.trueMarkdownOutput; 183 | } else { 184 | if (ocrModelProvider) { 185 | ocrResult = await withTimeout( 186 | ocrModelProvider.ocr(item.imageUrl), 187 | `OCR: ${ocrModel}`, 188 | ); 189 | result.predictedMarkdown = ocrResult.text; 190 | result.usage = { 191 | ...ocrResult.usage, 192 | ocr: ocrResult.usage, 193 | extraction: undefined, 194 | }; 195 | } 196 | } 197 | 198 | let extractionResult; 199 | if (extractionModelProvider) { 200 | extractionResult = await withTimeout( 201 | extractionModelProvider.extractFromText( 202 | result.predictedMarkdown, 203 | item.jsonSchema, 204 | ocrResult?.imageBase64s, 205 | ), 206 | `JSON extraction: ${extractionModel}`, 207 | ); 208 | result.predictedJson = extractionResult.json; 209 | 210 | const mergeUsage = (base: any, additional: any) => ({ 211 | duration: (base?.duration ?? 0) + (additional?.duration ?? 0), 212 | inputTokens: (base?.inputTokens ?? 0) + (additional?.inputTokens ?? 0), 213 | outputTokens: 214 | (base?.outputTokens ?? 0) + (additional?.outputTokens ?? 0), 215 | totalTokens: (base?.totalTokens ?? 0) + (additional?.totalTokens ?? 0), 216 | inputCost: (base?.inputCost ?? 0) + (additional?.inputCost ?? 0), 217 | outputCost: (base?.outputCost ?? 0) + (additional?.outputCost ?? 0), 218 | totalCost: (base?.totalCost ?? 0) + (additional?.totalCost ?? 0), 219 | }); 220 | 221 | result.usage = { 222 | ocr: result.usage?.ocr ?? 
{}, 223 | extraction: extractionResult.usage, 224 | ...mergeUsage(result.usage, extractionResult.usage), 225 | }; 226 | } 227 | } 228 | 229 | if (result.predictedMarkdown) { 230 | result.levenshteinDistance = calculateTextSimilarity( 231 | item.trueMarkdownOutput, 232 | result.predictedMarkdown, 233 | ); 234 | } 235 | 236 | if (!isEmpty(result.predictedJson)) { 237 | const jsonAccuracyResult = calculateJsonAccuracy( 238 | item.trueJsonOutput, 239 | result.predictedJson, 240 | ); 241 | result.jsonAccuracy = jsonAccuracyResult.score; 242 | result.jsonDiff = jsonAccuracyResult.jsonDiff; 243 | result.fullJsonDiff = jsonAccuracyResult.fullJsonDiff; 244 | result.jsonDiffStats = jsonAccuracyResult.jsonDiffStats; 245 | result.jsonAccuracyResult = jsonAccuracyResult; 246 | } 247 | } catch (error) { 248 | result.error = error; 249 | console.error( 250 | `Error processing ${item.imageUrl} with ${ocrModel} and ${extractionModel}:\n`, 251 | error, 252 | ); 253 | } 254 | 255 | if (benchmarkRun) { 256 | await saveResult(benchmarkRun.id, result); 257 | } 258 | 259 | // Update progress bar for this model 260 | progressBars[ 261 | `${directImageExtraction ? `${extractionModel} (IMG2JSON)` : `${ocrModel}-${extractionModel}`}` 262 | ].increment(); 263 | return result; 264 | }), 265 | ); 266 | 267 | // Process items concurrently for this model 268 | const modelResults = await Promise.all(promises); 269 | 270 | results.push(...modelResults); 271 | }, 272 | ); 273 | 274 | // Process each model with its own concurrency limit 275 | await Promise.all(modelPromises); 276 | 277 | // Stop all progress bars 278 | multibar.stop(); 279 | 280 | // Complete benchmark run successfully 281 | if (benchmarkRun) { 282 | await completeBenchmarkRun(benchmarkRun.id); 283 | } 284 | 285 | writeToFile(path.join(resultFolder, 'results.json'), results); 286 | }; 287 | 288 | runBenchmark(); 289 | -------------------------------------------------------------------------------- /src/models.example.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | - ocr: ground-truth 3 | extraction: gpt-4o 4 | 5 | - ocr: gpt-4o 6 | extraction: gpt-4o 7 | 8 | - ocr: gpt-4o 9 | extraction: gpt-4o 10 | directImageExtraction: true 11 | 12 | # - ocr: gemini-2.0-flash-001 13 | # extraction: gpt-4o 14 | 15 | # - ocr: azure-gpt-4o 16 | # extraction: azure-gpt-4o 17 | 18 | # - ocr: claude-3-5-sonnet-20241022 19 | # extraction: claude-3-5-sonnet-20241022 20 | 21 | # - ocr: claude-3-5-sonnet-20241022 22 | # extraction: claude-3-5-sonnet-20241022 23 | # directImageExtraction: true 24 | 25 | # - ocr: zerox 26 | # extraction: gpt-4o 27 | 28 | # - ocr: omniai 29 | # extraction: gpt-4o 30 | 31 | # - ocr: aws-textract 32 | # extraction: gpt-4o 33 | 34 | # - ocr: google-document-ai 35 | # extraction: gpt-4o 36 | 37 | # - ocr: azure-document-intelligence 38 | # extraction: gpt-4o 39 | 40 | # - ocr: unstructured 41 | # extraction: gpt-4o 42 | 43 | # - ocr: gpt-4o 44 | # extraction: deepseek-chat 45 | -------------------------------------------------------------------------------- /src/models/awsTextract.ts: -------------------------------------------------------------------------------- 1 | import { TextractClient, AnalyzeDocumentCommand } from '@aws-sdk/client-textract'; 2 | import { ModelProvider } from './base'; 3 | 4 | // https://aws.amazon.com/textract/pricing/ 5 | // $4 per 1000 pages for the first 1M pages, Layout model 6 | const COST_PER_PAGE = 4 / 1000; 7 | 8 | export class AWSTextractProvider extends ModelProvider 
{ 9 | private client: TextractClient; 10 | 11 | constructor() { 12 | super('aws-textract'); 13 | this.client = new TextractClient({ 14 | region: process.env.AWS_REGION, 15 | credentials: { 16 | accessKeyId: process.env.AWS_ACCESS_KEY_ID!, 17 | secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY!, 18 | }, 19 | }); 20 | } 21 | 22 | async ocr(imagePath: string) { 23 | try { 24 | // Convert image URL to base64 25 | const response = await fetch(imagePath); 26 | const arrayBuffer = await response.arrayBuffer(); 27 | const buffer = Buffer.from(arrayBuffer); 28 | 29 | const start = performance.now(); 30 | const command = new AnalyzeDocumentCommand({ 31 | Document: { 32 | Bytes: buffer, 33 | }, 34 | FeatureTypes: ['LAYOUT'], 35 | }); 36 | 37 | const result = await this.client.send(command); 38 | const end = performance.now(); 39 | 40 | // Extract text from blocks 41 | const text = 42 | result.Blocks?.filter((block) => block.Text) 43 | .map((block) => block.Text) 44 | .join('\n') || ''; 45 | 46 | return { 47 | text, 48 | usage: { 49 | duration: end - start, 50 | totalCost: COST_PER_PAGE, // the input is always 1 page. 51 | }, 52 | }; 53 | } catch (error) { 54 | console.error('AWS Textract Error:', error); 55 | throw error; 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/models/azure.ts: -------------------------------------------------------------------------------- 1 | import { AzureKeyCredential } from '@azure/core-auth'; 2 | import DocumentIntelligence, { 3 | DocumentIntelligenceClient, 4 | getLongRunningPoller, 5 | isUnexpected, 6 | AnalyzeOperationOutput, 7 | } from '@azure-rest/ai-document-intelligence'; 8 | 9 | import { ModelProvider } from './base'; 10 | 11 | // https://azure.microsoft.com/en-us/pricing/details/ai-document-intelligence/ 12 | // $10 per 1000 pages for the first 1M pages, Prebuilt-Layout model 13 | const COST_PER_PAGE = 10 / 1000; 14 | 15 | export class AzureDocumentIntelligenceProvider extends ModelProvider { 16 | private client: DocumentIntelligenceClient; 17 | 18 | constructor() { 19 | super('azure-document-intelligence'); 20 | 21 | const endpoint = process.env.AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT; 22 | const apiKey = process.env.AZURE_DOCUMENT_INTELLIGENCE_KEY; 23 | 24 | if (!endpoint || !apiKey) { 25 | throw new Error('Missing required Azure Document Intelligence configuration'); 26 | } 27 | 28 | this.client = DocumentIntelligence(endpoint, new AzureKeyCredential(apiKey)); 29 | } 30 | 31 | async ocr(imagePath: string) { 32 | try { 33 | const start = performance.now(); 34 | 35 | const initialResponse = await this.client 36 | .path('/documentModels/{modelId}:analyze', 'prebuilt-layout') 37 | .post({ 38 | contentType: 'application/json', 39 | body: { 40 | urlSource: imagePath, 41 | }, 42 | queryParameters: { outputContentFormat: 'markdown' }, 43 | }); 44 | 45 | if (isUnexpected(initialResponse)) { 46 | throw initialResponse.body.error; 47 | } 48 | 49 | const poller = getLongRunningPoller(this.client, initialResponse); 50 | const result = (await poller.pollUntilDone()).body as AnalyzeOperationOutput; 51 | const analyzeResult = result.analyzeResult; 52 | const text = analyzeResult?.content; 53 | 54 | const end = performance.now(); 55 | 56 | return { 57 | text, 58 | usage: { 59 | duration: end - start, 60 | totalCost: COST_PER_PAGE, // the input is always 1 page. 
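// Azure's prebuilt-layout API does not report token usage, so only duration and the flat per-page cost are recorded here.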
61 | },
62 | };
63 | } catch (error) {
64 | console.error('Azure Document Intelligence Error:', error);
65 | throw error;
66 | }
67 | }
68 | }
69 | 
--------------------------------------------------------------------------------
/src/models/base.ts:
--------------------------------------------------------------------------------
1 | import { JsonSchema, Usage } from '../types';
2 | 
3 | export class ModelProvider {
4 | model: string;
5 | outputDir?: string;
6 | 
7 | constructor(model: string, outputDir?: string) {
8 | this.model = model;
9 | this.outputDir = outputDir;
10 | }
11 | 
12 | async ocr(imagePath: string): Promise<{
13 | text: string;
14 | imageBase64s?: string[];
15 | usage: Usage;
16 | }> {
17 | throw new Error('Not implemented');
18 | }
19 | 
20 | async extractFromText?(
21 | text: string,
22 | schema: JsonSchema,
23 | imageBase64s?: string[],
24 | ): Promise<{
25 | json: Record<string, any>;
26 | usage: Usage;
27 | }> {
28 | throw new Error('Not implemented');
29 | }
30 | 
31 | async extractFromImage?(
32 | imagePath: string,
33 | schema: JsonSchema,
34 | ): Promise<{
35 | json: Record<string, any>;
36 | usage: Usage;
37 | }> {
38 | throw new Error('Not implemented');
39 | }
40 | }
41 | 
--------------------------------------------------------------------------------
/src/models/dashscope.ts:
--------------------------------------------------------------------------------
1 | import OpenAI from 'openai';
2 | import sharp from 'sharp';
3 | 
4 | import { ModelProvider } from './base';
5 | import { Usage } from '../types';
6 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared';
7 | 
8 | export class DashscopeProvider extends ModelProvider {
9 | private client: OpenAI;
10 | 
11 | constructor(model: string) {
12 | super(model);
13 | 
14 | const apiKey = process.env.DASHSCOPE_API_KEY;
15 | if (!apiKey) {
16 | throw new Error('Missing required Dashscope API key');
17 | }
18 | 
19 | this.client = new OpenAI({
20 | baseURL: 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
21 | apiKey,
22 | });
23 | }
24 | 
25 | async ocr(imagePath: string): Promise<{
26 | text: string;
27 | imageBase64s?: string[];
28 | usage: Usage;
29 | }> {
30 | const start = performance.now();
31 | 
32 | // Fetch the image
33 | const imageResponse = await fetch(imagePath);
34 | const imageBuffer = await imageResponse.arrayBuffer();
35 | 
36 | // compress the image
37 | const resizedImageBuffer = await sharp(Buffer.from(imageBuffer))
38 | .jpeg({ quality: 90 })
39 | .toBuffer();
40 | 
41 | // Convert to base64
42 | const resizedImageBase64 = `data:image/jpeg;base64,${resizedImageBuffer.toString('base64')}`;
43 | 
44 | const response = await this.client.chat.completions.create({
45 | model: this.model,
46 | messages: [
47 | {
48 | role: 'user',
49 | content: [
50 | { type: 'text', text: OCR_SYSTEM_PROMPT },
51 | {
52 | type: 'image_url',
53 | image_url: {
54 | url: resizedImageBase64,
55 | },
56 | },
57 | ],
58 | },
59 | ],
60 | });
61 | 
62 | const end = performance.now();
63 | 
64 | const inputTokens = response.usage?.prompt_tokens || 0;
65 | const outputTokens = response.usage?.completion_tokens || 0;
66 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
67 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens);
68 | 
69 | return {
70 | text: response.choices[0].message.content || '',
71 | usage: {
72 | duration: end - start,
73 | inputTokens,
74 | outputTokens,
75 | totalTokens: inputTokens + outputTokens,
76 | inputCost,
77 | outputCost,
78 | totalCost: inputCost + 
outputCost, 79 | }, 80 | }; 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/models/gemini.ts: -------------------------------------------------------------------------------- 1 | import { GoogleGenerativeAI } from '@google/generative-ai'; 2 | 3 | import { ModelProvider } from './base'; 4 | import { 5 | IMAGE_EXTRACTION_SYSTEM_PROMPT, 6 | JSON_EXTRACTION_SYSTEM_PROMPT, 7 | OCR_SYSTEM_PROMPT, 8 | } from './shared/prompt'; 9 | import { calculateTokenCost } from './shared/tokenCost'; 10 | import { getMimeType } from '../utils'; 11 | import { JsonSchema } from '../types'; 12 | 13 | export class GeminiProvider extends ModelProvider { 14 | private client: GoogleGenerativeAI; 15 | 16 | constructor(model: string) { 17 | super(model); 18 | 19 | const apiKey = process.env.GOOGLE_GENERATIVE_AI_API_KEY; 20 | 21 | if (!apiKey) { 22 | throw new Error('Missing required Google Generative AI configuration'); 23 | } 24 | 25 | this.client = new GoogleGenerativeAI(apiKey); 26 | } 27 | 28 | async ocr(imagePath: string) { 29 | try { 30 | const start = performance.now(); 31 | 32 | const model = this.client.getGenerativeModel({ 33 | model: this.model, 34 | generationConfig: { temperature: 0 }, 35 | }); 36 | 37 | // read image and convert to base64 38 | const response = await fetch(imagePath); 39 | const imageBuffer = await response.arrayBuffer(); 40 | const base64Image = Buffer.from(imageBuffer).toString('base64'); 41 | 42 | const imagePart = { 43 | inlineData: { 44 | data: base64Image, 45 | mimeType: getMimeType(imagePath), 46 | }, 47 | }; 48 | 49 | const ocrResult = await model.generateContent([OCR_SYSTEM_PROMPT, imagePart]); 50 | const text = ocrResult.response.text(); 51 | 52 | const end = performance.now(); 53 | 54 | const ocrInputTokens = ocrResult.response.usageMetadata.promptTokenCount; 55 | const ocrOutputTokens = ocrResult.response.usageMetadata.candidatesTokenCount; 56 | const inputCost = calculateTokenCost(this.model, 'input', ocrInputTokens); 57 | const outputCost = calculateTokenCost(this.model, 'output', ocrOutputTokens); 58 | 59 | return { 60 | text, 61 | usage: { 62 | duration: end - start, 63 | inputTokens: ocrInputTokens, 64 | outputTokens: ocrOutputTokens, 65 | totalTokens: ocrInputTokens + ocrOutputTokens, 66 | inputCost, 67 | outputCost, 68 | totalCost: inputCost + outputCost, 69 | }, 70 | }; 71 | } catch (error) { 72 | console.error('Google Generative AI OCR Error:', error); 73 | throw error; 74 | } 75 | } 76 | 77 | // FIXME: JSON output might not be 100% correct yet, because Gemini uses a subset of OpenAPI 3.0 schema 78 | // https://sdk.vercel.ai/providers/ai-sdk-providers/google-generative-ai#schema-limitations 79 | async extractFromText(text: string, schema: JsonSchema) { 80 | const filteredSchema = this.convertSchemaForGemini(schema); 81 | 82 | const start = performance.now(); 83 | const model = this.client.getGenerativeModel({ 84 | model: this.model, 85 | generationConfig: { 86 | temperature: 0, 87 | responseMimeType: 'application/json', 88 | responseSchema: filteredSchema, 89 | }, 90 | }); 91 | 92 | const result = await model.generateContent([JSON_EXTRACTION_SYSTEM_PROMPT, text]); 93 | 94 | const json = JSON.parse(result.response.text()); 95 | 96 | const end = performance.now(); 97 | 98 | const inputTokens = result.response.usageMetadata.promptTokenCount; 99 | const outputTokens = result.response.usageMetadata.candidatesTokenCount; 100 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens); 101 | const 
outputCost = calculateTokenCost(this.model, 'output', outputTokens); 102 | 103 | return { 104 | json, 105 | usage: { 106 | duration: end - start, 107 | inputTokens, 108 | outputTokens, 109 | totalTokens: inputTokens + outputTokens, 110 | inputCost, 111 | outputCost, 112 | totalCost: inputCost + outputCost, 113 | }, 114 | }; 115 | } 116 | 117 | // FIXME: JSON output might not be 100% correct yet, because Gemini uses a subset of OpenAPI 3.0 schema 118 | // https://sdk.vercel.ai/providers/ai-sdk-providers/google-generative-ai#schema-limitations 119 | async extractFromImage(imagePath: string, schema: JsonSchema) { 120 | const filteredSchema = this.convertSchemaForGemini(schema); 121 | 122 | // read image and convert to base64 123 | const response = await fetch(imagePath); 124 | const imageBuffer = await response.arrayBuffer(); 125 | const base64Image = Buffer.from(imageBuffer).toString('base64'); 126 | 127 | const start = performance.now(); 128 | 129 | const model = this.client.getGenerativeModel({ 130 | model: this.model, 131 | generationConfig: { 132 | temperature: 0, 133 | responseMimeType: 'application/json', 134 | responseSchema: filteredSchema, 135 | }, 136 | }); 137 | 138 | const imagePart = { 139 | inlineData: { 140 | data: base64Image, 141 | mimeType: getMimeType(imagePath), 142 | }, 143 | }; 144 | 145 | const result = await model.generateContent([ 146 | IMAGE_EXTRACTION_SYSTEM_PROMPT, 147 | imagePart, 148 | ]); 149 | 150 | const json = JSON.parse(result.response.text()); 151 | 152 | const end = performance.now(); 153 | 154 | const inputTokens = result.response.usageMetadata.promptTokenCount; 155 | const outputTokens = result.response.usageMetadata.candidatesTokenCount; 156 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens); 157 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens); 158 | 159 | return { 160 | json, 161 | usage: { 162 | duration: end - start, 163 | inputTokens, 164 | outputTokens, 165 | totalTokens: inputTokens + outputTokens, 166 | inputCost, 167 | outputCost, 168 | totalCost: inputCost + outputCost, 169 | }, 170 | }; 171 | } 172 | 173 | convertSchemaForGemini(schema) { 174 | // Deep clone the schema to avoid modifying the original 175 | const newSchema = JSON.parse(JSON.stringify(schema)); 176 | 177 | function processSchemaNode(node) { 178 | if (!node || typeof node !== 'object') return node; 179 | 180 | // Fix enum type definition 181 | if (node.type === 'enum' && node.enum) { 182 | node.type = 'string'; 183 | } 184 | // Handle case where enum array exists but type isn't specified 185 | if (node.enum && !node.type) { 186 | node.type = 'string'; 187 | } 188 | 189 | // Remove additionalProperties constraints 190 | if ('additionalProperties' in node) { 191 | delete node.additionalProperties; 192 | } 193 | 194 | // Handle 'not' validation keyword 195 | if (node.not) { 196 | if (node.not.type === 'null') { 197 | delete node.not; 198 | node.nullable = false; 199 | } else { 200 | processSchemaNode(node.not); 201 | } 202 | } 203 | 204 | // Handle arrays 205 | if (node.type === 'array' && node.items) { 206 | // Move required fields to items level 207 | if (node.required) { 208 | if (!node.items.required) { 209 | node.items.required = node.required; 210 | } else { 211 | node.items.required = [ 212 | ...new Set([...node.items.required, ...node.required]), 213 | ]; 214 | } 215 | delete node.required; 216 | } 217 | 218 | processSchemaNode(node.items); 219 | } 220 | 221 | // Handle objects with properties 222 | if (node.properties) { 
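// Recursively normalize each property so nested schemas also conform to Gemini's OpenAPI subset.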
223 | Object.entries(node.properties).forEach(([key, prop]) => { 224 | node.properties[key] = processSchemaNode(prop); 225 | }); 226 | } 227 | 228 | return node; 229 | } 230 | 231 | return processSchemaNode(newSchema); 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /src/models/googleDocumentAI.ts: -------------------------------------------------------------------------------- 1 | import fs from 'fs'; 2 | import { DocumentProcessorServiceClient } from '@google-cloud/documentai'; 3 | import { ModelProvider } from './base'; 4 | 5 | // https://cloud.google.com/document-ai/pricing 6 | // $1.5 per 1000 pages for the first 5M pages 7 | const COST_PER_PAGE = 1.5 / 1000; 8 | 9 | export class GoogleDocumentAIProvider extends ModelProvider { 10 | private client: DocumentProcessorServiceClient; 11 | private processorPath: string; 12 | 13 | constructor() { 14 | super('google-document-ai'); 15 | 16 | const projectId = process.env.GOOGLE_PROJECT_ID; 17 | const location = process.env.GOOGLE_LOCATION || 'us'; // default to 'us' 18 | const processorId = process.env.GOOGLE_PROCESSOR_ID; 19 | 20 | if (!projectId || !processorId) { 21 | throw new Error('Missing required Google Document AI configuration'); 22 | } 23 | 24 | const credentials = JSON.parse( 25 | fs.readFileSync(process.env.GOOGLE_APPLICATION_CREDENTIALS_PATH || '', 'utf8'), 26 | ); 27 | this.client = new DocumentProcessorServiceClient({ 28 | credentials, 29 | }); 30 | 31 | this.processorPath = `projects/${projectId}/locations/${location}/processors/${processorId}`; 32 | } 33 | 34 | async ocr(imagePath: string) { 35 | try { 36 | // Download the image 37 | const response = await fetch(imagePath); 38 | const arrayBuffer = await response.arrayBuffer(); 39 | const imageContent = Buffer.from(arrayBuffer).toString('base64'); 40 | 41 | // Determine MIME type from URL 42 | const mimeType = this.getMimeType(imagePath); 43 | 44 | const request = { 45 | name: this.processorPath, 46 | rawDocument: { 47 | content: imageContent, 48 | mimeType: mimeType, 49 | }, 50 | }; 51 | 52 | const start = performance.now(); 53 | const [result] = await this.client.processDocument(request); 54 | const { document } = result; 55 | const end = performance.now(); 56 | 57 | // Extract text from the document 58 | const text = document?.text || ''; 59 | 60 | return { 61 | text, 62 | usage: { 63 | duration: end - start, 64 | totalCost: COST_PER_PAGE, // the input is always 1 page. 
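// Document AI reports no token usage; duration and the flat per-page cost are the only metrics recorded.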
65 | }, 66 | }; 67 | } catch (error) { 68 | console.error('Google Document AI Error:', error); 69 | throw error; 70 | } 71 | } 72 | 73 | private getMimeType(url: string): string { 74 | const extension = url.split('.').pop()?.toLowerCase(); 75 | switch (extension) { 76 | case 'pdf': 77 | return 'application/pdf'; 78 | case 'png': 79 | return 'image/png'; 80 | case 'jpg': 81 | case 'jpeg': 82 | return 'image/jpeg'; 83 | case 'tiff': 84 | case 'tif': 85 | return 'image/tiff'; 86 | case 'gif': 87 | return 'image/gif'; 88 | case 'bmp': 89 | return 'image/bmp'; 90 | default: 91 | return 'image/png'; // default to PNG 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/models/index.ts: -------------------------------------------------------------------------------- 1 | export * from './registry'; 2 | -------------------------------------------------------------------------------- /src/models/llm.ts: -------------------------------------------------------------------------------- 1 | import { 2 | generateText, 3 | generateObject, 4 | CoreMessage, 5 | CoreUserMessage, 6 | NoObjectGeneratedError, 7 | } from 'ai'; 8 | import { createOpenAI } from '@ai-sdk/openai'; 9 | import { createAnthropic } from '@ai-sdk/anthropic'; 10 | import { createGoogleGenerativeAI } from '@ai-sdk/google'; 11 | import { createDeepSeek } from '@ai-sdk/deepseek'; 12 | import { createAzure } from '@ai-sdk/azure'; 13 | 14 | import { ExtractionResult, JsonSchema } from '../types'; 15 | import { generateZodSchema, writeResultToFile } from '../utils'; 16 | import { calculateTokenCost } from './shared'; 17 | import { ModelProvider } from './base'; 18 | import { 19 | OCR_SYSTEM_PROMPT, 20 | JSON_EXTRACTION_SYSTEM_PROMPT, 21 | IMAGE_EXTRACTION_SYSTEM_PROMPT, 22 | } from './shared'; 23 | import { 24 | OPENAI_MODELS, 25 | ANTHROPIC_MODELS, 26 | GOOGLE_GENERATIVE_AI_MODELS, 27 | FINETUNED_MODELS, 28 | DEEPSEEK_MODELS, 29 | AZURE_OPENAI_MODELS, 30 | } from './registry'; 31 | 32 | export const createModelProvider = (model: string) => { 33 | if (OPENAI_MODELS.includes(model)) { 34 | return createOpenAI({ 35 | apiKey: process.env.OPENAI_API_KEY, 36 | baseURL: process.env.OPENAI_ENDPOINT || 'https://api.openai.com/v1', 37 | }); 38 | } 39 | if (AZURE_OPENAI_MODELS.includes(model)) { 40 | return createAzure({ 41 | apiKey: process.env.AZURE_OPENAI_API_KEY, 42 | resourceName: process.env.AZURE_OPENAI_RESOURCE_NAME, 43 | }); 44 | } 45 | if (FINETUNED_MODELS.includes(model)) { 46 | return createOpenAI({ apiKey: process.env.OPENAI_API_KEY }); 47 | } 48 | if (ANTHROPIC_MODELS.includes(model)) { 49 | return createAnthropic({ apiKey: process.env.ANTHROPIC_API_KEY }); 50 | } 51 | if (GOOGLE_GENERATIVE_AI_MODELS.includes(model)) { 52 | return createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY }); 53 | } 54 | if (DEEPSEEK_MODELS.includes(model)) { 55 | return createDeepSeek({ apiKey: process.env.DEEPSEEK_API_KEY }); 56 | } 57 | throw new Error(`Model '${model}' does not support image inputs`); 58 | }; 59 | 60 | export class LLMProvider extends ModelProvider { 61 | constructor(model: string) { 62 | if (AZURE_OPENAI_MODELS.includes(model)) { 63 | const openaiModel = model.replace('azure-', ''); 64 | super(openaiModel); 65 | } else { 66 | super(model); 67 | } 68 | } 69 | 70 | async ocr(imagePath: string) { 71 | const modelProvider = createModelProvider(this.model); 72 | 73 | let imageMessage: CoreUserMessage = { 74 | role: 'user', 75 | content: [ 76 | { 77 | type: 'image', 78 | image: 
imagePath, 79 | }, 80 | ], 81 | }; 82 | 83 | if (ANTHROPIC_MODELS.includes(this.model)) { 84 | // read image and convert to base64 85 | const response = await fetch(imagePath); 86 | const imageBuffer = await response.arrayBuffer(); 87 | const base64Image = Buffer.from(imageBuffer).toString('base64'); 88 | imageMessage.content = [ 89 | { 90 | type: 'image', 91 | image: base64Image, 92 | }, 93 | ]; 94 | } 95 | 96 | if (GOOGLE_GENERATIVE_AI_MODELS.includes(this.model)) { 97 | // gemini requires a text message in user messages 98 | imageMessage.content = [ 99 | { 100 | type: 'text', 101 | text: ' ', 102 | }, 103 | { 104 | type: 'file', 105 | data: imagePath, 106 | mimeType: 'image/png', 107 | }, 108 | ]; 109 | } 110 | 111 | const messages: CoreMessage[] = [ 112 | { role: 'system', content: OCR_SYSTEM_PROMPT }, 113 | imageMessage, 114 | ]; 115 | 116 | const start = performance.now(); 117 | const { text, usage: ocrUsage } = await generateText({ 118 | model: modelProvider(this.model), 119 | messages, 120 | }); 121 | const end = performance.now(); 122 | 123 | const inputCost = calculateTokenCost(this.model, 'input', ocrUsage.promptTokens); 124 | const outputCost = calculateTokenCost( 125 | this.model, 126 | 'output', 127 | ocrUsage.completionTokens, 128 | ); 129 | 130 | const usage = { 131 | duration: end - start, 132 | inputTokens: ocrUsage.promptTokens, 133 | outputTokens: ocrUsage.completionTokens, 134 | totalTokens: ocrUsage.totalTokens, 135 | inputCost, 136 | outputCost, 137 | totalCost: inputCost + outputCost, 138 | }; 139 | 140 | return { 141 | text, 142 | usage, 143 | }; 144 | } 145 | 146 | async extractFromText(text: string, schema: JsonSchema, imageBase64s?: string[]) { 147 | const modelProvider = createModelProvider(this.model); 148 | 149 | let imageMessages: CoreMessage[] = []; 150 | if (imageBase64s && imageBase64s.length > 0) { 151 | imageMessages = [ 152 | { 153 | role: 'user', 154 | content: imageBase64s.map((base64) => ({ 155 | type: 'image', 156 | image: base64, 157 | })), 158 | }, 159 | ]; 160 | } 161 | const messages: CoreMessage[] = [ 162 | { role: 'system', content: JSON_EXTRACTION_SYSTEM_PROMPT }, 163 | ...imageMessages, 164 | { role: 'user', content: text }, 165 | ]; 166 | 167 | const zodSchema = generateZodSchema(schema); 168 | 169 | const start = performance.now(); 170 | 171 | let json, extractionUsage; 172 | try { 173 | const { object, usage } = await generateObject({ 174 | model: modelProvider(this.model), 175 | messages, 176 | schema: zodSchema, 177 | temperature: 0, 178 | }); 179 | json = object; 180 | extractionUsage = usage; 181 | } catch (error) { 182 | // if cause is AI_TypeValidationError, then still parse the json 183 | if (error instanceof NoObjectGeneratedError) { 184 | const errorText = error.text; 185 | json = JSON.parse(errorText); 186 | extractionUsage = error.usage; 187 | } else { 188 | throw error; 189 | } 190 | } 191 | 192 | const end = performance.now(); 193 | const inputCost = calculateTokenCost( 194 | this.model, 195 | 'input', 196 | extractionUsage.promptTokens, 197 | ); 198 | const outputCost = calculateTokenCost( 199 | this.model, 200 | 'output', 201 | extractionUsage.completionTokens, 202 | ); 203 | 204 | const usage = { 205 | duration: end - start, 206 | inputTokens: extractionUsage.promptTokens, 207 | outputTokens: extractionUsage.completionTokens, 208 | totalTokens: extractionUsage.totalTokens, 209 | inputCost, 210 | outputCost, 211 | totalCost: inputCost + outputCost, 212 | }; 213 | 214 | return { 215 | json, 216 | usage, 217 | }; 218 | } 219 
| 220 | async extractFromImage(imagePath: string, schema: JsonSchema) { 221 | const modelProvider = createModelProvider(this.model); 222 | 223 | let imageMessage: CoreUserMessage = { 224 | role: 'user', 225 | content: [ 226 | { 227 | type: 'image', 228 | image: imagePath, 229 | }, 230 | ], 231 | }; 232 | 233 | if (ANTHROPIC_MODELS.includes(this.model)) { 234 | // read image and convert to base64 235 | const response = await fetch(imagePath); 236 | const imageBuffer = await response.arrayBuffer(); 237 | const base64Image = Buffer.from(imageBuffer).toString('base64'); 238 | imageMessage.content = [ 239 | { 240 | type: 'image', 241 | image: base64Image, 242 | }, 243 | ]; 244 | } 245 | 246 | if (GOOGLE_GENERATIVE_AI_MODELS.includes(this.model)) { 247 | // gemini requires a text message in user messages 248 | imageMessage.content = [ 249 | { 250 | type: 'text', 251 | text: ' ', 252 | }, 253 | { 254 | type: 'file', 255 | data: imagePath, 256 | mimeType: 'image/png', 257 | }, 258 | ]; 259 | } 260 | 261 | const messages: CoreMessage[] = [ 262 | { role: 'system', content: IMAGE_EXTRACTION_SYSTEM_PROMPT }, 263 | imageMessage, 264 | ]; 265 | 266 | const zodSchema = generateZodSchema(schema); 267 | 268 | const start = performance.now(); 269 | const { object: json, usage: extractionUsage } = await generateObject({ 270 | model: modelProvider(this.model), 271 | messages, 272 | schema: zodSchema, 273 | temperature: 0, 274 | }); 275 | const end = performance.now(); 276 | 277 | const inputCost = calculateTokenCost( 278 | this.model, 279 | 'input', 280 | extractionUsage.promptTokens, 281 | ); 282 | const outputCost = calculateTokenCost( 283 | this.model, 284 | 'output', 285 | extractionUsage.completionTokens, 286 | ); 287 | 288 | const usage = { 289 | duration: end - start, 290 | inputTokens: extractionUsage.promptTokens, 291 | outputTokens: extractionUsage.completionTokens, 292 | totalTokens: extractionUsage.totalTokens, 293 | inputCost, 294 | outputCost, 295 | totalCost: inputCost + outputCost, 296 | }; 297 | 298 | return { 299 | json, 300 | usage, 301 | }; 302 | } 303 | } 304 | -------------------------------------------------------------------------------- /src/models/mistral.ts: -------------------------------------------------------------------------------- 1 | import { Mistral } from '@mistralai/mistralai'; 2 | 3 | import { ModelProvider } from './base'; 4 | 5 | // $1.00 per 1000 images 6 | const COST_PER_IMAGE = 0.001; 7 | 8 | export class MistralProvider extends ModelProvider { 9 | private client: Mistral; 10 | 11 | constructor() { 12 | super('mistral-ocr'); 13 | 14 | const apiKey = process.env.MISTRAL_API_KEY; 15 | 16 | if (!apiKey) { 17 | throw new Error('Missing required Mistral API key'); 18 | } 19 | 20 | this.client = new Mistral({ 21 | apiKey, 22 | }); 23 | } 24 | 25 | async ocr(imagePath: string) { 26 | try { 27 | const start = performance.now(); 28 | 29 | const response = await this.client.ocr.process({ 30 | model: 'mistral-ocr-latest', 31 | document: { 32 | imageUrl: imagePath, 33 | }, 34 | includeImageBase64: true, 35 | }); 36 | 37 | const text = response.pages.map((page) => page.markdown).join('\n'); 38 | const end = performance.now(); 39 | 40 | const imageBase64s = response.pages.flatMap((page) => 41 | page.images.map((image) => image.imageBase64).filter((base64) => base64), 42 | ); 43 | 44 | return { 45 | text, 46 | imageBase64s, 47 | usage: { 48 | duration: end - start, 49 | totalCost: COST_PER_IMAGE, 50 | }, 51 | }; 52 | } catch (error) { 53 | console.error('Mistral OCR Error:', 
error);
54 | throw error;
55 | }
56 | }
57 | }
58 | 
--------------------------------------------------------------------------------
/src/models/omniAI.ts:
--------------------------------------------------------------------------------
1 | import axios from 'axios';
2 | import FormData from 'form-data';
3 | 
4 | import { JsonSchema } from '../types';
5 | import { ModelProvider } from './base';
6 | 
7 | // https://getomni.ai/pricing
8 | // 1 cent per page
9 | const COST_PER_PAGE = 0.01;
10 | 
11 | interface ExtractResponse {
12 | ocr: {
13 | pages: Array<{
14 | page: number;
15 | content: string;
16 | }>;
17 | inputTokens: number;
18 | outputTokens: number;
19 | };
20 | extracted?: Record<string, any>; // Only present when schema is provided
21 | }
22 | 
23 | export const sendExtractRequest = async (
24 | imageUrl: string,
25 | schema?: JsonSchema,
26 | ): Promise<ExtractResponse> => {
27 | const apiKey = process.env.OMNIAI_API_KEY;
28 | if (!apiKey) {
29 | throw new Error('Missing OMNIAI_API_KEY in .env');
30 | }
31 | 
32 | const formData = new FormData();
33 | formData.append('url', imageUrl);
34 | 
35 | // Add optional parameters if provided
36 | if (schema) {
37 | formData.append('schema', JSON.stringify(schema));
38 | }
39 | 
40 | try {
41 | const response = await axios.post(
42 | `${process.env.OMNIAI_API_URL}/extract/sync`,
43 | formData,
44 | {
45 | headers: {
46 | 'x-api-key': apiKey,
47 | ...formData.getHeaders(),
48 | },
49 | },
50 | );
51 | 
52 | return response.data.result;
53 | } catch (error) {
54 | if (axios.isAxiosError(error)) {
55 | throw new Error(
56 | `Failed to extract from image: ${JSON.stringify(error.response?.data) || JSON.stringify(error.message)}`,
57 | );
58 | }
59 | throw error;
60 | }
61 | };
62 | 
63 | export class OmniAIProvider extends ModelProvider {
64 | constructor(model: string) {
65 | super(model);
66 | }
67 | 
68 | async ocr(imagePath: string) {
69 | const start = performance.now();
70 | const response = await sendExtractRequest(imagePath);
71 | const end = performance.now();
72 | 
73 | const text = response.ocr.pages.map((page) => page.content).join('\n');
74 | const inputTokens = response.ocr.inputTokens;
75 | const outputTokens = response.ocr.outputTokens;
76 | 
77 | return {
78 | text,
79 | usage: {
80 | duration: end - start,
81 | inputTokens,
82 | outputTokens,
83 | totalTokens: inputTokens + outputTokens,
84 | totalCost: COST_PER_PAGE,
85 | },
86 | };
87 | }
88 | 
89 | async extractFromImage(imagePath: string, schema?: JsonSchema) {
90 | const start = performance.now();
91 | const response = await sendExtractRequest(imagePath, schema);
92 | const end = performance.now();
93 | 
94 | const inputToken = response.ocr.inputTokens;
95 | const outputToken = response.ocr.outputTokens;
96 | 
97 | return {
98 | json: response.extracted || {},
99 | usage: {
100 | duration: end - start,
101 | inputTokens: inputToken,
102 | outputTokens: outputToken,
103 | totalTokens: inputToken + outputToken,
104 | totalCost: 0, // TODO: extraction cost is included in the OCR cost, 1 cent per page
105 | },
106 | };
107 | }
108 | }
109 | 
--------------------------------------------------------------------------------
/src/models/openai.ts:
--------------------------------------------------------------------------------
1 | import OpenAI from 'openai';
2 | import sharp from 'sharp';
3 | 
4 | import { ModelProvider } from './base';
5 | import { Usage } from '../types';
6 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared';
7 | 
8 | export class OpenAIProvider extends ModelProvider {
9 | private client: 
OpenAI; 10 | 11 | constructor(model: string) { 12 | super(model); 13 | 14 | const apiKey = process.env.COMPATIBLE_OPENAI_API_KEY; 15 | const baseURL = process.env.COMPATIBLE_OPENAI_BASE_URL; 16 | if (!apiKey) { 17 | throw new Error('Missing required API key'); 18 | } 19 | 20 | this.client = new OpenAI({ 21 | baseURL, 22 | apiKey, 23 | }); 24 | } 25 | 26 | async ocr(imagePath: string): Promise<{ 27 | text: string; 28 | imageBase64s?: string[]; 29 | usage: Usage; 30 | }> { 31 | const start = performance.now(); 32 | 33 | const response = await this.client.chat.completions.create({ 34 | model: this.model, 35 | messages: [ 36 | { 37 | role: 'user', 38 | content: [ 39 | { type: 'text', text: OCR_SYSTEM_PROMPT }, 40 | { 41 | type: 'image_url', 42 | image_url: { 43 | url: imagePath, 44 | }, 45 | }, 46 | ], 47 | }, 48 | ], 49 | }); 50 | 51 | const end = performance.now(); 52 | 53 | const inputTokens = response.usage?.prompt_tokens || 0; 54 | const outputTokens = response.usage?.completion_tokens || 0; 55 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens); 56 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens); 57 | 58 | return { 59 | text: response.choices[0].message.content || '', 60 | usage: { 61 | duration: end - start, 62 | inputTokens, 63 | outputTokens, 64 | totalTokens: inputTokens + outputTokens, 65 | inputCost, 66 | outputCost, 67 | totalCost: inputCost + outputCost, 68 | }, 69 | }; 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/models/openrouter.ts: -------------------------------------------------------------------------------- 1 | import OpenAI from 'openai'; 2 | import { ChatCompletionMessageParam } from 'openai/resources/chat/completions'; 3 | 4 | import { ModelProvider } from './base'; 5 | import { Usage } from '../types'; 6 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared'; 7 | 8 | export class OpenRouterProvider extends ModelProvider { 9 | private client: OpenAI; 10 | 11 | constructor(model: string) { 12 | super(model); 13 | 14 | const apiKey = process.env.OPENROUTER_API_KEY; 15 | if (!apiKey) { 16 | throw new Error('Missing required OpenRouter API key'); 17 | } 18 | 19 | this.client = new OpenAI({ 20 | baseURL: 'https://openrouter.ai/api/v1', 21 | apiKey, 22 | defaultHeaders: { 23 | 'HTTP-Referer': process.env.SITE_URL || 'https://github.com/omni-ai/benchmark', 24 | 'X-Title': 'OmniAI OCR Benchmark', 25 | }, 26 | }); 27 | } 28 | 29 | async ocr(imagePath: string): Promise<{ 30 | text: string; 31 | imageBase64s?: string[]; 32 | usage: Usage; 33 | }> { 34 | const start = performance.now(); 35 | 36 | const messages: ChatCompletionMessageParam[] = [ 37 | { 38 | role: 'user', 39 | content: [ 40 | { type: 'text', text: OCR_SYSTEM_PROMPT }, 41 | { type: 'image_url', image_url: { url: imagePath } }, 42 | ], 43 | }, 44 | ]; 45 | 46 | const response = await this.client.chat.completions.create({ 47 | model: this.model, 48 | messages, 49 | }); 50 | 51 | const end = performance.now(); 52 | 53 | const inputTokens = response.usage?.prompt_tokens || 0; 54 | const outputTokens = response.usage?.completion_tokens || 0; 55 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens); 56 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens); 57 | 58 | return { 59 | text: response.choices[0].message.content || '', 60 | usage: { 61 | duration: end - start, 62 | inputTokens, 63 | outputTokens, 64 | totalTokens: inputTokens + outputTokens, 65 | 
inputCost, 66 | outputCost, 67 | totalCost: inputCost + outputCost, 68 | }, 69 | }; 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/models/registry.ts: -------------------------------------------------------------------------------- 1 | import { AzureDocumentIntelligenceProvider } from './azure'; 2 | import { AWSTextractProvider } from './awsTextract'; 3 | import { DashscopeProvider } from './dashscope'; 4 | import { GeminiProvider } from './gemini'; 5 | import { GoogleDocumentAIProvider } from './googleDocumentAI'; 6 | import { LLMProvider } from './llm'; 7 | import { MistralProvider } from './mistral'; 8 | import { OmniAIProvider } from './omniAI'; 9 | import { OpenAIProvider } from './openai'; 10 | import { OpenRouterProvider } from './openrouter'; 11 | import { TogetherProvider } from './togetherai'; 12 | import { UnstructuredProvider } from './unstructured'; 13 | import { ZeroxProvider } from './zerox'; 14 | 15 | export const OPENAI_MODELS = [ 16 | 'chatgpt-4o-latest', 17 | 'gpt-4o-mini', 18 | 'gpt-4o', 19 | 'o1', 20 | 'o1-mini', 21 | 'o3-mini', 22 | 'o4-mini', 23 | 'gpt-4o-2024-11-20', 24 | 'gpt-4.1', 25 | 'gpt-4.1-mini', 26 | 'gpt-4.1-nano', 27 | ]; 28 | export const AZURE_OPENAI_MODELS = [ 29 | 'azure-gpt-4o-mini', 30 | 'azure-gpt-4o', 31 | 'azure-o1', 32 | 'azure-o1-mini', 33 | 'azure-o3-mini', 34 | 'azure-gpt-4.1', 35 | 'azure-gpt-4.1-mini', 36 | 'azure-gpt-4.1-nano', 37 | ]; 38 | export const ANTHROPIC_MODELS = [ 39 | 'claude-3-5-sonnet-20241022', 40 | 'claude-3-7-sonnet-20250219', 41 | 'claude-sonnet-4-20250514', 42 | 'claude-opus-4-20250514', 43 | ]; 44 | export const DEEPSEEK_MODELS = ['deepseek-chat']; 45 | export const GOOGLE_GENERATIVE_AI_MODELS = [ 46 | 'gemini-1.5-pro', 47 | 'gemini-1.5-flash', 48 | 'gemini-2.0-flash-001', 49 | 'gemini-2.5-pro-exp-03-25', 50 | 'gemini-2.5-pro-preview-03-25', 51 | 'gemini-2.5-flash-preview-05-20', 52 | ]; 53 | export const OPENROUTER_MODELS = [ 54 | 'qwen/qwen2.5-vl-32b-instruct:free', 55 | 'qwen/qwen-2.5-vl-72b-instruct', 56 | // 'google/gemma-3-27b-it', 57 | 'deepseek/deepseek-chat-v3-0324', 58 | 'meta-llama/llama-3.2-11b-vision-instruct', 59 | 'meta-llama/llama-3.2-90b-vision-instruct', 60 | ]; 61 | export const TOGETHER_MODELS = [ 62 | 'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo', 63 | 'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo', 64 | 'meta-llama/Llama-4-Scout-17B-16E-Instruct', 65 | 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8', 66 | ]; 67 | export const FINETUNED_MODELS = []; 68 | 69 | export const MODEL_PROVIDERS = { 70 | anthropic: { 71 | models: ANTHROPIC_MODELS, 72 | provider: LLMProvider, 73 | }, 74 | aws: { 75 | models: ['aws-textract'], 76 | provider: AWSTextractProvider, 77 | }, 78 | azureOpenai: { 79 | models: AZURE_OPENAI_MODELS, 80 | provider: LLMProvider, 81 | }, 82 | gemini: { 83 | models: GOOGLE_GENERATIVE_AI_MODELS, 84 | provider: GeminiProvider, 85 | }, 86 | google: { 87 | models: ['google-document-ai'], 88 | provider: GoogleDocumentAIProvider, 89 | }, 90 | deepseek: { 91 | models: DEEPSEEK_MODELS, 92 | provider: LLMProvider, 93 | }, 94 | azure: { 95 | models: ['azure-document-intelligence'], 96 | provider: AzureDocumentIntelligenceProvider, 97 | }, 98 | mistral: { 99 | models: ['mistral-ocr'], 100 | provider: MistralProvider, 101 | }, 102 | omniai: { 103 | models: ['omniai'], 104 | provider: OmniAIProvider, 105 | }, 106 | openai: { 107 | models: OPENAI_MODELS, 108 | provider: LLMProvider, 109 | }, 110 | openaiBase: { 111 | models: 
['google/gemma-3-27b-it'],
112 | provider: OpenAIProvider,
113 | },
114 | openrouter: {
115 | models: OPENROUTER_MODELS,
116 | provider: OpenRouterProvider,
117 | },
118 | together: {
119 | models: TOGETHER_MODELS,
120 | provider: TogetherProvider,
121 | },
122 | dashscope: {
123 | models: ['qwen2.5-vl-32b-instruct', 'qwen2.5-vl-72b-instruct'],
124 | provider: DashscopeProvider,
125 | },
126 | unstructured: {
127 | models: ['unstructured'],
128 | provider: UnstructuredProvider,
129 | },
130 | zerox: {
131 | models: ['zerox'],
132 | provider: ZeroxProvider,
133 | },
134 | groundTruth: {
135 | models: ['ground-truth'],
136 | provider: undefined,
137 | },
138 | };
139 | 
140 | export const getModelProvider = (model: string) => {
141 | // Include Openai FT models
142 | MODEL_PROVIDERS['openaiFt'] = {
143 | models: FINETUNED_MODELS,
144 | provider: LLMProvider,
145 | };
146 | const foundProvider = Object.values(MODEL_PROVIDERS).find(
147 | (group) => group.models && group.models.includes(model),
148 | );
149 | 
150 | if (foundProvider) {
151 | if (model === 'ground-truth') {
152 | return undefined;
153 | }
154 | const provider = new foundProvider.provider(model);
155 | return provider;
156 | }
157 | 
158 | throw new Error(`Model '${model}' is not supported.`);
159 | };
160 | 
--------------------------------------------------------------------------------
/src/models/shared/index.ts:
--------------------------------------------------------------------------------
1 | export * from './prompt';
2 | export * from './tokenCost';
3 | 
--------------------------------------------------------------------------------
/src/models/shared/prompt.ts:
--------------------------------------------------------------------------------
1 | export const OCR_SYSTEM_PROMPT = `
2 | Convert the following document to markdown.
3 | Return only the markdown with no explanation text. Do not include delimiters like \`\`\`markdown or \`\`\`html.
4 | 
5 | RULES:
6 | - You must include all information on the page. Do not exclude headers, footers, charts, infographics, or subtext.
7 | - Return tables in an HTML format.
8 | - Logos should be wrapped in brackets. Ex: <logo>Coca-Cola</logo>
9 | - Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>
10 | - Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>
11 | - Prefer using ☐ and ☑ for check boxes.
12 | `;
13 | 
14 | export const JSON_EXTRACTION_SYSTEM_PROMPT = `
15 | Extract data from the following document based on the JSON schema.
16 | Return null if the document does not contain information relevant to schema.
17 | Return only the JSON with no explanation text.
18 | `;
19 | 
20 | export const IMAGE_EXTRACTION_SYSTEM_PROMPT = `
21 | Extract the following JSON schema from the image.
22 | Return only the JSON with no explanation text. 
23 | `; 24 | -------------------------------------------------------------------------------- /src/models/shared/tokenCost.ts: -------------------------------------------------------------------------------- 1 | import { FINETUNED_MODELS } from '../registry'; 2 | 3 | export const TOKEN_COST = { 4 | 'azure-gpt-4o': { 5 | input: 2.5, 6 | output: 10, 7 | }, 8 | 'azure-gpt-4o-mini': { 9 | input: 0.15, 10 | output: 0.6, 11 | }, 12 | 'azure-gpt-4.1': { 13 | input: 2, 14 | output: 8, 15 | }, 16 | 'azure-gpt-4.1-mini': { 17 | input: 0.4, 18 | output: 1.6, 19 | }, 20 | 'azure-gpt-4.1-nano': { 21 | input: 0.1, 22 | output: 0.4, 23 | }, 24 | 'azure-o1': { 25 | input: 15, 26 | output: 60, 27 | }, 28 | 'azure-o1-mini': { 29 | input: 1.1, 30 | output: 4.4, 31 | }, 32 | 'azure-o3-mini': { 33 | input: 1.1, 34 | output: 4.4, 35 | }, 36 | 'claude-3-5-sonnet-20241022': { 37 | input: 3, 38 | output: 15, 39 | }, 40 | 'claude-3-7-sonnet-20250219': { 41 | input: 3, 42 | output: 15, 43 | }, 44 | 'claude-sonnet-4-20250514': { 45 | input: 3, 46 | output: 15, 47 | }, 48 | 'claude-opus-4-20250514': { 49 | input: 15, 50 | output: 75, 51 | }, 52 | 'deepseek-chat': { 53 | input: 0.14, 54 | output: 0.28, 55 | }, 56 | 'gemini-1.5-pro': { 57 | input: 1.25, 58 | output: 5, 59 | }, 60 | 'gemini-1.5-flash': { 61 | input: 0.075, 62 | output: 0.3, 63 | }, 64 | 'gemini-2.0-flash-001': { 65 | input: 0.1, 66 | output: 0.4, 67 | }, 68 | 'gemini-2.5-pro-exp-03-25': { 69 | input: 1.25, 70 | output: 10, 71 | }, 72 | 'gemini-2.5-pro-preview-03-25': { 73 | input: 1.25, 74 | output: 10, 75 | }, 76 | 'gemini-2.5-flash-preview-05-20': { 77 | input: 0.15, 78 | output: 0.6, 79 | }, 80 | 'gpt-4o': { 81 | input: 2.5, 82 | output: 10, 83 | }, 84 | 'gpt-4o-2024-11-20': { 85 | input: 2.5, 86 | output: 10, 87 | }, 88 | 'gpt-4o-mini': { 89 | input: 0.15, 90 | output: 0.6, 91 | }, 92 | 'gpt-4.1': { 93 | input: 2, 94 | output: 8, 95 | }, 96 | 'gpt-4.1-mini': { 97 | input: 0.4, 98 | output: 1.6, 99 | }, 100 | 'gpt-4.1-nano': { 101 | input: 0.1, 102 | output: 0.4, 103 | }, 104 | 105 | o1: { 106 | input: 15, 107 | output: 60, 108 | }, 109 | 'o1-mini': { 110 | input: 1.1, 111 | output: 4.4, 112 | }, 113 | 'o3-mini': { 114 | input: 1.1, 115 | output: 4.4, 116 | }, 117 | 'o4-mini': { 118 | input: 1.1, 119 | output: 4.4, 120 | }, 121 | 'chatgpt-4o-latest': { 122 | input: 2.5, 123 | output: 10, 124 | }, 125 | zerox: { 126 | input: 2.5, 127 | output: 10, 128 | }, 129 | 'qwen2.5-vl-32b-instruct': { 130 | input: 0, // TODO: Add cost 131 | output: 0, // TODO: Add cost 132 | }, 133 | 'qwen2.5-vl-72b-instruct': { 134 | input: 0, // TODO: Add cost 135 | output: 0, // TODO: Add cost 136 | }, 137 | 'google/gemma-3-27b-it': { 138 | input: 0.1, 139 | output: 0.2, 140 | }, 141 | 'deepseek/deepseek-chat-v3-0324': { 142 | input: 0.27, 143 | output: 1.1, 144 | }, 145 | 'meta-llama/llama-3.2-11b-vision-instruct': { 146 | input: 0.055, 147 | output: 0.055, 148 | }, 149 | 'meta-llama/llama-3.2-90b-vision-instruct': { 150 | input: 0.8, 151 | output: 1.6, 152 | }, 153 | 'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo': { 154 | input: 0.18, 155 | output: 0.18, 156 | }, 157 | 'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo': { 158 | input: 1.2, 159 | output: 1.2, 160 | }, 161 | 'meta-llama/Llama-4-Scout-17B-16E-Instruct': { 162 | input: 0.18, 163 | output: 0.59, 164 | }, 165 | 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8': { 166 | input: 0.27, 167 | output: 0.85, 168 | }, 169 | }; 170 | 171 | export const calculateTokenCost = ( 172 | model: string, 173 | type: 'input' | 
'output',
174 |   tokens: number,
175 | ): number => {
176 |   const fineTuneCost = Object.fromEntries(
177 |     FINETUNED_MODELS.map((el) => [el, { input: 3.75, output: 15.0 }]),
178 |   );
179 |   const combinedCost = { ...TOKEN_COST, ...fineTuneCost };
180 |   const modelInfo = combinedCost[model];
181 |   if (!modelInfo) throw new Error(`Model '${model}' is not supported.`);
182 |   return (modelInfo[type] * tokens) / 1_000_000;
183 | };
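184 | 
185 | // Rates in TOKEN_COST are USD per 1M tokens. Worked example (illustrative):
186 | // 1,000 input tokens on 'gpt-4o' => (2.5 * 1000) / 1_000_000 = $0.0025.
187 | 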
--------------------------------------------------------------------------------
/src/models/togetherai.ts:
--------------------------------------------------------------------------------
1 | import Together from 'together-ai';
2 | 
3 | import { ModelProvider } from './base';
4 | import { Usage } from '../types';
5 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared';
6 | 
7 | export class TogetherProvider extends ModelProvider {
8 |   private client: Together;
9 | 
10 |   constructor(model: string) {
11 |     super(model);
12 | 
13 |     const apiKey = process.env.TOGETHER_API_KEY;
14 |     if (!apiKey) {
15 |       throw new Error('Missing required Together API key');
16 |     }
17 | 
18 |     this.client = new Together({ apiKey });
19 |   }
20 | 
21 |   async ocr(imagePath: string): Promise<{
22 |     text: string;
23 |     imageBase64s?: string[];
24 |     usage: Usage;
25 |   }> {
26 |     const start = performance.now();
27 | 
28 |     const response = await this.client.chat.completions.create({
29 |       model: this.model,
30 |       messages: [
31 |         {
32 |           role: 'user',
33 |           content: [
34 |             { type: 'text', text: OCR_SYSTEM_PROMPT },
35 |             { type: 'image_url', image_url: { url: imagePath } },
36 |           ],
37 |         },
38 |       ],
39 |     });
40 | 
41 |     const end = performance.now();
42 | 
43 |     const inputTokens = response.usage?.prompt_tokens || 0;
44 |     const outputTokens = response.usage?.completion_tokens || 0;
45 |     const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
46 |     const outputCost = calculateTokenCost(this.model, 'output', outputTokens);
47 | 
48 |     return {
49 |       text: response.choices[0].message.content || '',
50 |       usage: {
51 |         duration: end - start,
52 |         inputTokens,
53 |         outputTokens,
54 |         totalTokens: inputTokens + outputTokens,
55 |         inputCost,
56 |         outputCost,
57 |         totalCost: inputCost + outputCost,
58 |       },
59 |     };
60 |   }
61 | }
62 | 
--------------------------------------------------------------------------------
/src/models/unstructured.ts:
--------------------------------------------------------------------------------
1 | import axios from 'axios';
2 | import { ModelProvider } from './base';
3 | import { htmlToMarkdown } from '../utils';
4 | 
5 | // $20 per 1,000 pages
6 | const COST_PER_PAGE = 20 / 1000;
7 | 
8 | enum UnstructuredTypes {
9 |   Title = 'Title',
10 |   Header = 'Header',
11 |   NarrativeText = 'NarrativeText',
12 | }
13 | 
14 | interface UnstructuredElement {
15 |   text: string;
16 |   type: UnstructuredTypes;
17 |   metadata: {
18 |     filename: string;
19 |     filetype: string;
20 |     languages: string[];
21 |     page_number: number;
22 |     parent_id?: string;
23 |     text_as_html?: string;
24 |   };
25 |   element_id: string;
26 | }
27 | 
28 | export class UnstructuredProvider extends ModelProvider {
29 |   constructor() {
30 |     super('unstructured');
31 |   }
32 | 
33 |   async ocr(imagePath: string) {
34 |     try {
35 |       const start = performance.now();
36 | 
37 |       const fileName = imagePath.split('/').pop();
38 |       const formData = new FormData();
39 |       const response = await axios.get(imagePath, { responseType: 'arraybuffer' });
40 |       const fileData = Buffer.from(response.data);
41 | 
42 |       formData.append('files', new Blob([fileData]), fileName);
43 | 
44 |       const apiResponse = await axios.post(
45 |         'https://api.unstructuredapp.io/general/v0/general',
46 |         formData,
47 |         {
48 |           headers: {
49 |             accept: 'application/json',
50 |             'unstructured-api-key': process.env.UNSTRUCTURED_API_KEY,
51 |           },
52 |         },
53 |       );
54 | 
55 |       const unstructuredResult = apiResponse.data as UnstructuredElement[];
56 | 
57 |       // Format the result
58 |       let markdown = '';
59 |       if (Array.isArray(unstructuredResult)) {
60 |         markdown = unstructuredResult.reduce((acc, el) => {
61 |           if (el.type === UnstructuredTypes.Title) {
62 |             acc += `\n### ${el.text}\n`;
63 |           } else if (el.type === UnstructuredTypes.NarrativeText) {
64 |             acc += `\n${el.text}\n`;
65 |           } else if (el.metadata?.text_as_html) {
66 |             acc += htmlToMarkdown(el.metadata.text_as_html) + '\n';
67 |           } else if (el.text) {
68 |             acc += el.text + '\n';
69 |           }
70 |           return acc;
71 |         }, '');
72 |       } else {
73 |         markdown = JSON.stringify(unstructuredResult);
74 |       }
75 | 
76 |       const end = performance.now();
77 | 
78 |       return {
79 |         text: markdown,
80 |         usage: {
81 |           duration: end - start,
82 |           totalCost: COST_PER_PAGE, // the input is always 1 page.
83 |         },
84 |       };
85 |     } catch (error) {
86 |       console.error('Unstructured Error:', error);
87 |       throw error;
88 |     }
89 |   }
90 | }
91 | 
--------------------------------------------------------------------------------
/src/models/zerox.ts:
--------------------------------------------------------------------------------
1 | import { zerox } from 'zerox';
2 | 
3 | import { ModelProvider } from './base';
4 | import { calculateTokenCost } from './shared';
5 | 
6 | export class ZeroxProvider extends ModelProvider {
7 |   constructor() {
8 |     super('zerox');
9 |   }
10 | 
11 |   async ocr(imagePath: string) {
12 |     const startTime = performance.now();
13 | 
14 |     const result = await zerox({
15 |       filePath: imagePath,
16 |       openaiAPIKey: process.env.OPENAI_API_KEY,
17 |     });
18 | 
19 |     const endTime = performance.now();
20 | 
21 |     const text = result.pages.map((page) => page.content).join('\n');
22 | 
23 |     const inputCost = calculateTokenCost(this.model, 'input', result.inputTokens);
24 |     const outputCost = calculateTokenCost(this.model, 'output', result.outputTokens);
25 | 
26 |     const usage = {
27 |       duration: endTime - startTime,
28 |       inputTokens: result.inputTokens,
29 |       outputTokens: result.outputTokens,
30 |       totalTokens: result.inputTokens + result.outputTokens,
31 |       inputCost,
32 |       outputCost,
33 |       totalCost: inputCost + outputCost,
34 |     };
35 | 
36 |     return {
37 |       text,
38 |       usage,
39 |     };
40 |   }
41 | }
42 | 
--------------------------------------------------------------------------------
/src/types/data.ts:
--------------------------------------------------------------------------------
1 | import { Usage } from './model';
2 | import { AccuracyResult } from '../evaluation';
3 | 
4 | export interface Input {
5 |   imageUrl: string;
6 |   metadata: Metadata;
7 |   jsonSchema: JsonSchema;
8 |   trueJsonOutput: Record<string, any>;
9 |   trueMarkdownOutput: string;
10 | }
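11 | 
12 | // Illustrative Input (not a dataset record): { imageUrl: 'https://example.com/receipt.png',
13 | //   metadata: { language: 'EN' }, jsonSchema: { type: 'object', properties: { total: { type: 'number' } } },
14 | //   trueJsonOutput: { total: 12.5 }, trueMarkdownOutput: '# Receipt ...' }
15 | 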
16 | export interface Metadata {
17 |   orientation?: number;
18 |   documentQuality?: string;
19 |   resolution?: number[];
20 |   language?: string;
21 | }
22 | 
23 | export interface JsonSchema {
24 |   type: string;
25 |   description?: string;
26 |   properties?: Record<string, JsonSchema>;
27 |   items?: JsonSchema;
28 |   required?: string[];
29 | }
30 | 
31 | export interface Result {
32 |   fileUrl: string;
33 |   metadata: Metadata;
34 |   ocrModel: string;
35 |   extractionModel: string;
36 |   jsonSchema: JsonSchema;
37 |   directImageExtraction?: boolean;
38 |   trueMarkdown: string;
39 |   trueJson: Record<string, any>;
40 |   predictedMarkdown?: string;
41 |   predictedJson?: Record<string, any>;
42 |   levenshteinDistance?: number;
43 |   jsonAccuracy?: number;
44 |   jsonDiff?: Record<string, any>;
45 |   fullJsonDiff?: Record<string, any>;
46 |   jsonDiffStats?: Record<string, any>;
47 |   jsonAccuracyResult?: AccuracyResult;
48 |   usage?: Usage;
49 |   error?: any;
50 | }
51 | 
--------------------------------------------------------------------------------
/src/types/index.ts:
--------------------------------------------------------------------------------
1 | export * from './data';
2 | export * from './model';
3 | 
--------------------------------------------------------------------------------
/src/types/model.ts:
--------------------------------------------------------------------------------
1 | export interface ExtractionResult {
2 |   json?: Record<string, any>;
3 |   text?: string;
4 |   usage: Usage;
5 | }
6 | 
7 | export interface Usage {
8 |   duration?: number;
9 |   inputTokens?: number;
10 |   outputTokens?: number;
11 |   totalTokens?: number;
12 |   inputCost?: number;
13 |   outputCost?: number;
14 |   totalCost?: number;
15 |   ocr?: Usage;
16 |   extraction?: Usage;
17 | }
18 | 
--------------------------------------------------------------------------------
/src/utils/dataLoader.ts:
--------------------------------------------------------------------------------
1 | import { Input } from '../types';
2 | import { Pool } from 'pg';
3 | import fs from 'fs';
4 | import path from 'path';
5 | 
6 | // Pull JSON files from a local folder
7 | export const loadLocalData = (folder: string): Input[] => {
8 |   const files = fs.readdirSync(folder).filter((file) => file.endsWith('.json'));
9 |   const data = files.map((file) => {
10 |     const filePath = path.join(folder, file);
11 |     const fileContent = fs.readFileSync(filePath, 'utf8');
12 |     return JSON.parse(fileContent);
13 |   });
14 | 
15 |   return data;
16 | };
17 | 
18 | // Load benchmark inputs from the documents table.
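19 | // Assumes a table shaped roughly like the following (illustrative DDL, not part of this repo):
20 | //   CREATE TABLE documents (url TEXT, config JSONB, schema JSONB, extracted_json JSONB,
21 | //     markdown TEXT, include_in_training BOOLEAN, created_at TIMESTAMPTZ);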
22 | export const loadFromDb = async (): Promise<Input[]> => {
23 |   const pool = new Pool({
24 |     connectionString: process.env.DATABASE_URL,
25 |     ssl: { rejectUnauthorized: false },
26 |   });
27 | 
28 |   try {
29 |     const result = await pool.query(`
30 |       SELECT
31 |         url AS "imageUrl",
32 |         config AS "metadata",
33 |         schema AS "jsonSchema",
34 |         extracted_json AS "trueJsonOutput",
35 |         markdown AS "trueMarkdownOutput"
36 |       FROM documents
37 |       WHERE include_in_training = FALSE
38 |       ORDER BY created_at
39 |       LIMIT 1000;
40 |     `);
41 | 
42 |     return result.rows as Input[];
43 |   } catch (error) {
44 |     console.error('Error querying data from PostgreSQL:', error);
45 |     throw error;
46 |   } finally {
47 |     await pool.end();
48 |   }
49 | };
50 | 
--------------------------------------------------------------------------------
/src/utils/db.ts:
--------------------------------------------------------------------------------
1 | import { PrismaClient } from '@prisma/client';
2 | import { Result } from '../types';
3 | 
4 | const prisma = new PrismaClient();
5 | 
6 | export async function createBenchmarkRun(
7 |   timestamp: string,
8 |   modelsConfig: any,
9 |   totalDocuments: number,
10 | ) {
11 |   return prisma.benchmarkRun.create({
12 |     data: {
13 |       timestamp,
14 |       status: 'running',
15 |       modelsConfig: { models: modelsConfig },
16 |       totalDocuments,
17 |     },
18 |   });
19 | }
20 | 
21 | export async function saveResult(runId: string, result: Result) {
22 |   return prisma.benchmarkResult.create({
23 |     data: {
24 |       benchmarkRunId: runId,
25 |       fileUrl: result.fileUrl,
26 |       metadata: result.metadata as any,
27 |       ocrModel: result.ocrModel,
28 |       extractionModel: result.extractionModel || '',
29 |       jsonSchema: result.jsonSchema as any,
30 |       directImageExtraction: result.directImageExtraction || false,
31 |       trueMarkdown: result.trueMarkdown,
32 |       trueJson: result.trueJson,
33 |       predictedMarkdown: result.predictedMarkdown,
34 |       predictedJson: result.predictedJson,
35 |       levenshteinDistance: result.levenshteinDistance,
36 |       jsonAccuracy: result.jsonAccuracy,
37 |       jsonDiff: result.jsonDiff,
38 |       fullJsonDiff: result.fullJsonDiff,
39 |       jsonDiffStats: result.jsonDiffStats,
40 |       jsonAccuracyResult: result.jsonAccuracyResult as any,
41 |       usage: result.usage as any,
42 |       error: JSON.stringify(result.error),
43 |     },
44 |   });
45 | }
46 | 
47 | export async function completeBenchmarkRun(runId: string, error?: string) {
48 |   return prisma.benchmarkRun.update({
49 |     where: { id: runId },
50 |     data: {
51 |       status: error ? 'failed' : 'completed',
52 |       completedAt: new Date(),
53 |       error,
54 |     },
55 |   });
56 | }
57 | 
58 | // Clean up function
59 | export async function disconnect() {
60 |   await prisma.$disconnect();
61 | }
62 | 
--------------------------------------------------------------------------------
/src/utils/file.ts:
--------------------------------------------------------------------------------
1 | export const getMimeType = (url: string): string => {
2 |   const extension = url.split('.').pop()?.toLowerCase();
3 |   switch (extension) {
4 |     case 'pdf':
5 |       return 'application/pdf';
6 |     case 'png':
7 |       return 'image/png';
8 |     case 'jpg':
9 |     case 'jpeg':
10 |       return 'image/jpeg';
11 |     case 'tiff':
12 |     case 'tif':
13 |       return 'image/tiff';
14 |     case 'gif':
15 |       return 'image/gif';
16 |     case 'bmp':
17 |       return 'image/bmp';
18 |     default:
19 |       return 'image/png'; // default to PNG
20 |   }
21 | };
22 | 
--------------------------------------------------------------------------------
/src/utils/htmlToMarkdown.ts:
--------------------------------------------------------------------------------
1 | import TurndownService from 'turndown';
2 | 
3 | export function htmlToMarkdown(html: string): string {
4 |   const turndownService = new TurndownService({});
5 | 
6 |   turndownService.addRule('strong', {
7 |     filter: ['strong', 'b'],
8 |     replacement: (content) => `**${content}**`,
9 |   });
10 | 
11 |   // Convert HTML to Markdown
12 |   return turndownService.turndown(html);
13 | }
14 | 
--------------------------------------------------------------------------------
/src/utils/index.ts:
--------------------------------------------------------------------------------
1 | export * from './dataLoader';
2 | export * from './db';
3 | export * from './file';
4 | export * from './htmlToMarkdown';
5 | export * from './logs';
6 | export * from './zod';
7 | 
--------------------------------------------------------------------------------
/src/utils/logs.ts:
--------------------------------------------------------------------------------
1 | import fs from 'fs';
2 | import path from 'path';
3 | 
4 | import { ExtractionResult } from '../types';
5 | 
6 | export const createResultFolder = (folderName: string) => {
7 |   // check if results folder exists
8 |   const resultsFolder = path.join(__dirname, '..', '..', 'results');
9 |   if (!fs.existsSync(resultsFolder)) {
10 |     fs.mkdirSync(resultsFolder, { recursive: true });
11 |   }
12 | 
13 |   const folderPath = path.join(resultsFolder, folderName);
14 |   fs.mkdirSync(folderPath, { recursive: true });
15 |   return folderPath;
16 | };
17 | 
18 | export const writeToFile = (filePath: string, content: any) => {
19 |   fs.writeFileSync(filePath, JSON.stringify(content, null, 2));
20 | };
21 | 
22 | export const writeResultToFile = (
23 |   outputDir: string,
24 |   fileName: string,
25 |   result: ExtractionResult,
26 | ) => {
27 |   fs.writeFileSync(path.join(outputDir, fileName), JSON.stringify(result, null, 2));
28 | };
29 | 
--------------------------------------------------------------------------------
/src/utils/zod.ts:
--------------------------------------------------------------------------------
1 | import { z } from 'zod';
2 | 
3 | const zodTypeMapping = {
4 |   array: (itemSchema: any) => z.array(itemSchema),
5 |   boolean: z.boolean(),
6 |   integer: z.number().int(),
7 |   number: z.number(),
8 |   object: (properties: any) => z.object(properties).strict(),
9 |   string: z.string(),
10 | };
11 | 
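12 | // Illustrative conversion (a sketch of this helper's behavior, not verified output):
13 | //   generateZodSchema({ type: 'object', properties: { total: { type: 'number' } } })
14 | //   => z.object({ total: z.number().nullable() }).strict()
15 | 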
16 | export const generateZodSchema = (schemaDef: any): z.ZodObject<any> => {
17 |   const properties: Record<string, any> = {};
18 | 
19 |   for (const [key, value] of Object.entries(schemaDef.properties) as any) {
20 |     let zodType;
21 | 
22 |     if (value.enum && Array.isArray(value.enum) && value.enum.length > 0) {
23 |       zodType = z.enum(value.enum as [string, ...string[]]);
24 |     } else {
25 |       zodType = zodTypeMapping[value.type];
26 |     }
27 | 
28 |     if (value.type === 'array' && value.items.type === 'object') {
29 |       properties[key] = zodType(generateZodSchema(value.items));
30 |     } else if (value.type === 'array' && value.items.type !== 'object') {
31 |       properties[key] = zodType(zodTypeMapping[value.items.type]);
32 |     } else if (value.type === 'object') {
33 |       properties[key] = generateZodSchema(value);
34 |     } else {
35 |       properties[key] = zodType;
36 |     }
37 | 
38 |     // Make properties nullable by default
39 |     properties[key] = properties[key].nullable();
40 | 
41 |     if (value.description) {
42 |       properties[key] = properties?.[key]?.describe(value?.description);
43 |     }
44 |   }
45 | 
46 |   return z.object(properties).strict();
47 | };
48 | 
--------------------------------------------------------------------------------
/tests/evaluation/json.test.ts:
--------------------------------------------------------------------------------
1 | import {
2 |   calculateJsonAccuracy,
3 |   countTotalFields,
4 |   countChanges,
5 | } from '../../src/evaluation/json';
6 | 
7 | describe('countTotalFields', () => {
8 |   it('counts fields in nested objects including array elements', () => {
9 |     const obj = { a: 1, b: { c: 2, d: [3, { e: 4 }] } };
10 |     expect(countTotalFields(obj)).toBe(4);
11 |   });
12 | 
13 |   it('counts array elements as individual fields', () => {
14 |     const obj = { a: [1, 2, 3], b: 'test', c: true };
15 |     expect(countTotalFields(obj)).toBe(5);
16 |   });
17 | 
18 |   it('counts nested objects within arrays', () => {
19 |     const obj = { a: [{ b: 1 }, { c: 2 }], d: 'test', e: true };
20 |     expect(countTotalFields(obj)).toBe(4);
21 |   });
22 | 
23 |   it('includes null values in field count', () => {
24 |     const obj = { a: null, b: { c: null }, d: 'test' };
25 |     expect(countTotalFields(obj)).toBe(3);
26 |   });
27 | 
28 |   it('excludes fields with __diff metadata suffixes', () => {
29 |     const obj = {
30 |       a: 1,
31 |       b__deleted: true,
32 |       c__added: 'test',
33 |       d: { e: 2 },
34 |     };
35 |     expect(countTotalFields(obj)).toBe(2);
36 |   });
37 | });
38 | 
39 | describe('calculateJsonAccuracy', () => {
40 |   it('returns 0.5 when half of the fields match', () => {
41 |     const actual = { a: 1, b: 2 };
42 |     const predicted = { a: 1, b: 3 };
43 |     const result = calculateJsonAccuracy(actual, predicted);
44 |     expect(result.score).toBe(0.5);
45 |   });
46 | 
47 |   it('handles nested objects in accuracy calculation', () => {
48 |     const actual = { a: 1, b: { c: 2, d: 4, e: 4 } };
49 |     const predicted = { a: 1, b: { c: 2, d: 4, e: 5 } };
50 |     const result = calculateJsonAccuracy(actual, predicted);
51 |     expect(result.score).toBe(0.75);
52 |   });
53 | 
54 |   it('calculates accuracy for nested arrays and objects', () => {
55 |     const actual = { a: 1, b: [{ c: 2, d: 4, e: 4, f: [2, 9] }] };
56 |     const predicted = { a: 1, b: [{ c: 2, d: 4, e: 5, f: [2, 3] }] };
57 |     const result = calculateJsonAccuracy(actual, predicted);
58 |     expect(result.score).toBe(0.5);
59 |   });
60 | 
61 |   it('considers array elements matching regardless of order', () => {
62 |     const actual = {
63 |       a: 1,
64 |       b: [
65 |         { c: 1, d: 2 },
66 |         { c: 3, d: 4 },
67 |       ],
68 |     };
69 |     const predicted = {
70 |       a: 1,
71 |       b: [
72 |         { c: 3, d: 4 },
73 |         { c: 1, d: 2 },
74 |       ],
75 |     };
76 |     const result = calculateJsonAccuracy(actual,
predicted); 77 | expect(result.score).toBe(1); 78 | }); 79 | 80 | it('counts all array elements as unmatched when predicted array is null', () => { 81 | const actual = { a: 1, b: [1, 2, 3] }; 82 | const predicted = { a: 1, b: null }; 83 | const result = calculateJsonAccuracy(actual, predicted); 84 | expect(result.score).toBe(1 / 4); 85 | }); 86 | 87 | it('counts all nested array objects as unmatched when predicted is null', () => { 88 | const actual = { a: 1, b: [{ c: 1, d: 1 }, { c: 2 }, { c: 3, e: 4 }] }; 89 | const predicted = { a: 1, b: null }; 90 | const result = calculateJsonAccuracy(actual, predicted); 91 | expect(result.score).toBe(Number((1 / 6).toFixed(4))); 92 | }); 93 | 94 | it('considers null fields in predicted object as partial matches', () => { 95 | const actual = { a: 1, b: { c: 1, d: { e: 1, f: 2 } } }; 96 | const predicted = { a: 1, b: { c: 1, d: null } }; 97 | const result = calculateJsonAccuracy(actual, predicted); 98 | expect(result.score).toBe(0.5); 99 | }); 100 | 101 | describe('null value comparisons', () => { 102 | it('handles actual null to predicted value comparison', () => { 103 | const actual = { a: [{ b: 1, c: null }] }; 104 | const predicted = { a: [{ b: 1, c: 2 }] }; 105 | const result = calculateJsonAccuracy(actual, predicted); 106 | expect(result.score).toBe(0.5); 107 | }); 108 | 109 | it('handles actual null to predicted object comparison', () => { 110 | const actual = { a: [{ b: 1, c: null, f: 4 }] }; 111 | const predicted = { a: [{ b: 1, c: { d: 2 }, f: 4 }] }; 112 | const result = calculateJsonAccuracy(actual, predicted); 113 | expect(result.score).toBe(0.6667); 114 | }); 115 | 116 | it('handles actual null to predicted complex object comparison', () => { 117 | const actual = { a: [{ b: 1, c: null, f: 4 }] }; 118 | const predicted = { a: [{ b: 1, c: { d: 2, e: 3 }, f: 4 }] }; 119 | const result = calculateJsonAccuracy(actual, predicted); 120 | expect(result.score).toBe(0.3333); 121 | }); 122 | 123 | it('handles actual null to predicted array comparison', () => { 124 | const actual = { a: [{ b: 1, c: null, f: 4 }] }; 125 | const predicted = { a: [{ b: 1, c: [3], f: 4 }] }; 126 | const result = calculateJsonAccuracy(actual, predicted); 127 | expect(result.score).toBe(0.6667); 128 | }); 129 | 130 | it('handles actual value to predicted null comparison', () => { 131 | const actual = { a: [{ b: 1, c: 2 }] }; 132 | const predicted = { a: [{ b: 1, c: null }] }; 133 | const result = calculateJsonAccuracy(actual, predicted); 134 | expect(result.score).toBe(0.5); 135 | }); 136 | 137 | it('handles actual object to predicted null comparison', () => { 138 | const actual = { a: [{ b: 1, c: { d: 2 } }] }; 139 | const predicted = { a: [{ b: 1, c: null }] }; 140 | const result = calculateJsonAccuracy(actual, predicted); 141 | expect(result.score).toBe(0.5); 142 | }); 143 | 144 | it('handles actual complex object to predicted null comparison', () => { 145 | const actual = { a: [{ b: 1, c: { d: 2, e: 3 } }] }; 146 | const predicted = { a: [{ b: 1, c: null }] }; 147 | const result = calculateJsonAccuracy(actual, predicted); 148 | expect(result.score).toBe(0.3333); 149 | }); 150 | 151 | it('handles actual array to predicted null comparison', () => { 152 | const actual = { a: [{ b: 1, c: [3, 2] }] }; 153 | const predicted = { a: [{ b: 1, c: null }] }; 154 | const result = calculateJsonAccuracy(actual, predicted); 155 | expect(result.score).toBe(0.3333); 156 | }); 157 | }); 158 | }); 159 | -------------------------------------------------------------------------------- 
/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es6", 4 | "module": "commonjs", 5 | "lib": ["es6", "dom", "esnext"], 6 | "outDir": "./dist", 7 | "rootDir": "./src", 8 | "strict": false, 9 | "esModuleInterop": true, 10 | "skipLibCheck": true, 11 | "declaration": true, 12 | "sourceMap": true, 13 | "resolveJsonModule": true, 14 | "moduleResolution": "node" 15 | }, 16 | "include": ["src/**/*.ts"], 17 | "exclude": ["node_modules", "dist"] 18 | } 19 | --------------------------------------------------------------------------------