',
216 | unsafe_allow_html=True,
217 | )
218 | # Add navigation buttons
219 | if st.button("←"):
220 | if st.session_state.selected_test_idx > 0:
221 | st.session_state.selected_test_idx -= 1
222 | st.rerun()
223 |
224 | with nav_col_2:
225 | st.write(
226 | '
',
227 | unsafe_allow_html=True,
228 | )
229 | if st.button("→"):
230 | if st.session_state.selected_test_idx < len(results_with_diffs) - 1:
231 | st.session_state.selected_test_idx += 1
232 | st.rerun()
233 |
234 | # 4. Load only the selected test case
235 | selected_result_id = results_with_diffs[st.session_state.selected_test_idx]["id"]
236 | detailed_data = load_one_result(selected_timestamp, selected_result_id)
237 | test_case = detailed_data["result"]
238 |
239 | # Display run metadata if available
240 | if detailed_data.get("description") or detailed_data.get("run_by"):
241 | with st.expander("Run Details", expanded=False):
242 | cols = st.columns(3)
243 | with cols[0]:
244 | st.markdown(f"**Status:** {detailed_data['status'].title()}")
245 | if detailed_data.get("run_by"):
246 | st.markdown(f"**Run By:** {detailed_data['run_by']}")
247 | with cols[1]:
248 | st.markdown(f"**Created:** {detailed_data['created_at']}")
249 | if detailed_data.get("completed_at"):
250 | st.markdown(f"**Completed:** {detailed_data['completed_at']}")
251 | with cols[2]:
252 | if detailed_data.get("description"):
253 | st.markdown(f"**Description:** {detailed_data['description']}")
254 |
255 | # Display file URL
256 | st.markdown(f"**File URL:** [{test_case['fileUrl']}]({test_case['fileUrl']})")
257 |
258 | # Create two columns for file preview and JSON diff
259 | left_col, right_col = st.columns(2)
260 |
261 | # Display file preview on the left
262 | with left_col:
263 | display_file_preview(test_case, left_col)
264 |
265 | # Display JSON diff on the right
266 | with right_col:
267 | display_json_diff(test_case, right_col)
268 |
269 | # Display markdown diff at the bottom
270 | st.markdown("---") # Add a separator
271 | display_markdown_diff(test_case)
272 |
273 |
274 | if __name__ == "__main__":
275 | main()
276 |
--------------------------------------------------------------------------------
/dashboard/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from pathlib import Path
4 | from datetime import datetime
5 | from dotenv import load_dotenv
6 | from sqlalchemy.sql import text
7 | from sqlalchemy.orm import sessionmaker
8 | from sqlalchemy import create_engine
9 | from typing import Dict, Any, List, TypedDict, Optional
10 |
11 | load_dotenv()
12 |
13 |
class BenchmarkRunMetadata(TypedDict):
    """Metadata describing one benchmark run, as listed in the dashboard.

    Folder-based runs only know their timestamp, so the Optional fields are
    None in that mode; database-backed runs populate all fields.
    """

    # Run identifier in YYYY-MM-DD-HH-mm-ss format (also the results folder name).
    timestamp: str
    # 'running', 'completed' or 'failed'; folder-based runs are always 'completed'.
    status: str
    run_by: Optional[str]
    description: Optional[str]
    total_documents: Optional[int]
    # Human-readable 'YYYY-MM-DD HH:MM:SS' strings, or None when unknown.
    created_at: Optional[str]
    completed_at: Optional[str]
22 |
23 |
def load_run_list_from_folder(
    results_dir: str = "results",
) -> List[BenchmarkRunMetadata]:
    """Load the list of benchmark runs from the local results directory.

    Each subdirectory of ``results_dir`` that contains a ``results.json``
    file is treated as one completed run (the directory name is the run
    timestamp). Returns runs sorted newest-first.

    Returns an empty list when the directory does not exist (e.g. no runs
    have been executed yet).
    """
    results_path = Path(results_dir)
    # Fix: a missing results directory previously raised FileNotFoundError
    # from iterdir(); an empty run list is the correct answer in that case.
    if not results_path.is_dir():
        return []

    runs: List[BenchmarkRunMetadata] = []
    for dir_path in results_path.iterdir():
        if not dir_path.is_dir():
            continue
        timestamp = dir_path.name
        if not (dir_path / "results.json").exists():
            continue
        runs.append(
            {
                "timestamp": timestamp,
                # Folder runs carry no status metadata; the presence of
                # results.json implies the run finished.
                "status": "completed",
                "run_by": None,
                "description": None,
                "total_documents": None,
                "created_at": format_timestamp(timestamp),
                "completed_at": format_timestamp(timestamp),
            }
        )

    # Timestamps are zero-padded, so lexicographic sort == chronological.
    return sorted(runs, key=lambda run: run["timestamp"], reverse=True)
49 |
50 |
def load_run_list_from_db() -> List[BenchmarkRunMetadata]:
    """Load the list of benchmark runs from the database.

    Connects using the DATABASE_URL environment variable and returns one
    metadata dict per row of ``benchmark_runs``, newest first (ordered by
    ``created_at`` in SQL).
    """
    database_url = os.getenv("DATABASE_URL")
    engine = create_engine(database_url)
    Session = sessionmaker(bind=engine)
    session = Session()

    query = text(
        """
        SELECT
            timestamp,
            status,
            run_by,
            description,
            total_documents,
            created_at,
            completed_at
        FROM benchmark_runs
        ORDER BY created_at DESC
        """
    )

    # Fix: close the session even when the query raises (the original leaked
    # the connection on errors).
    try:
        rows = session.execute(query)
        runs: List[BenchmarkRunMetadata] = []
        for row in rows:
            runs.append(
                {
                    "timestamp": row.timestamp,
                    "status": row.status,
                    "run_by": row.run_by,
                    "description": row.description,
                    "total_documents": row.total_documents,
                    "created_at": (
                        row.created_at.strftime("%Y-%m-%d %H:%M:%S")
                        if row.created_at
                        else None
                    ),
                    "completed_at": (
                        row.completed_at.strftime("%Y-%m-%d %H:%M:%S")
                        if row.completed_at
                        else None
                    ),
                }
            )
        return runs
    finally:
        session.close()
99 |
100 |
def load_results_for_run_from_folder(
    timestamp: str, results_dir: str = "results"
) -> Dict[str, Any]:
    """Load results for one run from the local results folder.

    Parses ``<results_dir>/<timestamp>/results.json`` and wraps it in run
    metadata. Results missing an "id" are assigned their list index so the
    dashboard can address individual test cases. Returns {} when the file
    does not exist.
    """
    results_path = Path(results_dir) / timestamp / "results.json"
    if not results_path.exists():
        return {}

    with open(results_path) as f:
        results = json.load(f)

    # Backfill a stable id (the list position) for results that lack one.
    for idx, result in enumerate(results):
        result.setdefault("id", idx)

    return {
        "results": results,
        "status": "completed",
        "run_by": None,
        "description": None,
        "total_documents": len(results),
        "created_at": format_timestamp(timestamp),
        "completed_at": format_timestamp(timestamp),
    }
124 |
125 |
def load_results_for_run_from_db(
    timestamp: str, include_metrics_only: bool = True
) -> Dict[str, Any]:
    """Load all results for a specific run from the database.

    Args:
        timestamp: Run identifier (YYYY-MM-DD-HH-mm-ss).
        include_metrics_only: When True (default), the heavy payload columns
            (markdown, true/predicted JSON, diffs) are omitted from each
            result; metric columns are always included.

    Returns:
        Run metadata plus a "results" list, or {} when no run matches.
    """
    database_url = os.getenv("DATABASE_URL")
    engine = create_engine(database_url)
    Session = sessionmaker(bind=engine)
    session = Session()

    # Extra json_build_object pairs for the heavy payload columns. This is an
    # internal constant — never user input — so interpolating it into the SQL
    # below is not an injection risk; the timestamp is bound as a parameter.
    if not include_metrics_only:
        output_string = """
            'trueMarkdown', bres.true_markdown,
            'predictedMarkdown', bres.predicted_markdown,
            'trueJson', bres.true_json,
            'predictedJson', bres.predicted_json,
            'jsonDiff', bres.json_diff,
            'fullJsonDiff', bres.full_json_diff,
        """
    else:
        output_string = ""

    query = text(
        f"""
        WITH filtered_run AS (
            SELECT id, timestamp, status, run_by, description, total_documents, created_at, completed_at
            FROM benchmark_runs
            WHERE timestamp = :timestamp
        )
        SELECT
            fr.timestamp,
            fr.status,
            fr.run_by,
            fr.description,
            fr.total_documents,
            fr.created_at,
            fr.completed_at,
            json_agg(
                json_build_object(
                    'id', bres.id,
                    'fileUrl', bres.file_url,
                    'ocrModel', bres.ocr_model,
                    'extractionModel', bres.extraction_model,
                    'directImageExtraction', bres.direct_image_extraction,
                    {output_string}
                    'levenshteinDistance', bres.levenshtein_distance,
                    'jsonAccuracy', bres.json_accuracy,
                    'jsonAccuracyResult', bres.json_accuracy_result,
                    'jsonDiffStats', bres.json_diff_stats,
                    'metadata', bres.metadata,
                    'usage', bres.usage,
                    'error', bres.error
                )
            ) as results
        FROM filtered_run fr
        LEFT JOIN benchmark_results bres ON fr.id = bres.benchmark_run_id
        GROUP BY fr.id, fr.timestamp, fr.status, fr.run_by, fr.description, fr.total_documents, fr.created_at, fr.completed_at
        """
    )

    # Fix: close the session even when execute() raises (the original leaked
    # the connection on query errors).
    try:
        row = session.execute(query, {"timestamp": timestamp}).first()
    finally:
        session.close()

    if row:
        return {
            "results": row.results,
            "status": row.status,
            "total_documents": row.total_documents,
            "run_by": row.run_by,
            "description": row.description,
            "created_at": (
                row.created_at.strftime("%Y-%m-%d %H:%M:%S") if row.created_at else None
            ),
            "completed_at": (
                row.completed_at.strftime("%Y-%m-%d %H:%M:%S")
                if row.completed_at
                else None
            ),
        }
    return {}
205 |
206 |
def load_one_result_from_db(timestamp: str, id: str) -> Dict[str, Any]:
    """Load one test-case result (with full payloads) from the database.

    Args:
        timestamp: Run identifier (YYYY-MM-DD-HH-mm-ss).
        id: Primary key of the benchmark_results row.

    Returns:
        Run metadata plus a "result" dict, or {} when nothing matches.
    """
    database_url = os.getenv("DATABASE_URL")
    engine = create_engine(database_url)
    Session = sessionmaker(bind=engine)
    session = Session()

    query = text(
        """
        WITH filtered_results AS (
            SELECT *
            FROM benchmark_results
            WHERE id = :id
        )
        SELECT
            br.timestamp,
            br.status,
            br.run_by,
            br.description,
            br.total_documents,
            br.created_at,
            br.completed_at,
            json_build_object(
                'id', fr.id,
                'fileUrl', fr.file_url,
                'ocrModel', fr.ocr_model,
                'extractionModel', fr.extraction_model,
                'directImageExtraction', fr.direct_image_extraction,
                'trueMarkdown', fr.true_markdown,
                'predictedMarkdown', fr.predicted_markdown,
                'trueJson', fr.true_json,
                'predictedJson', fr.predicted_json,
                'jsonDiff', fr.json_diff,
                'fullJsonDiff', fr.full_json_diff,
                'jsonDiffStats', fr.json_diff_stats,
                'levenshteinDistance', fr.levenshtein_distance,
                'jsonAccuracy', fr.json_accuracy,
                'jsonAccuracyResult', fr.json_accuracy_result,
                'jsonSchema', fr.json_schema,
                'metadata', fr.metadata,
                'usage', fr.usage,
                'error', fr.error
            ) as result
        FROM benchmark_runs br
        INNER JOIN filtered_results fr ON br.id = fr.benchmark_run_id
        WHERE br.timestamp = :timestamp
        LIMIT 1
        """
    )

    # Fix: close the session even when execute() raises (the original leaked
    # the connection on query errors).
    try:
        row = session.execute(query, {"timestamp": timestamp, "id": id}).first()
    finally:
        session.close()

    if row:
        return {
            "result": row.result,
            "status": row.status,
            "run_by": row.run_by,
            "description": row.description,
            "created_at": (
                row.created_at.strftime("%Y-%m-%d %H:%M:%S") if row.created_at else None
            ),
            "completed_at": (
                row.completed_at.strftime("%Y-%m-%d %H:%M:%S")
                if row.completed_at
                else None
            ),
        }
    return {}
276 |
277 |
def load_one_result_from_folder(
    timestamp: str, id: str, results_dir: str = "results"
) -> Dict[str, Any]:
    """Load one test-case result for a run from the local results folder.

    Matches on each result's stored "id", falling back to its list index
    (mirroring how ids are assigned in load_results_for_run_from_folder).

    Fix: the original compared the list index directly to ``id``, which is
    annotated as str — an int index never equals a str id, so lookups could
    silently return {}. We now also compare the string forms so int and str
    ids both match. Returns {} when the file or id is not found.
    """
    results_path = Path(results_dir) / timestamp / "results.json"
    if not results_path.exists():
        return {}

    with open(results_path) as f:
        results = json.load(f)

    for idx, result in enumerate(results):
        result_id = result.get("id", idx)
        if result_id == id or str(result_id) == str(id):
            return {
                "result": result,
                "status": "completed",
                "run_by": None,
                "description": None,
                "created_at": format_timestamp(timestamp),
                "completed_at": format_timestamp(timestamp),
            }
    return {}
297 |
298 |
def load_run_list() -> List[BenchmarkRunMetadata]:
    """Load the benchmark run list, preferring the database over local files.

    The database is used when DATABASE_URL is set; otherwise runs are
    discovered in the local results folder.
    """
    use_db = bool(os.getenv("DATABASE_URL"))
    return load_run_list_from_db() if use_db else load_run_list_from_folder()
304 |
305 |
def load_results_for_run(
    timestamp: str, include_metrics_only: bool = True
) -> Dict[str, Any]:
    """Load results for a specific run, preferring the database over local files.

    ``include_metrics_only`` only applies in database mode, where it drops the
    heavy payload columns; folder mode always returns the full results.json.
    """
    if not os.getenv("DATABASE_URL"):
        return load_results_for_run_from_folder(timestamp)
    return load_results_for_run_from_db(timestamp, include_metrics_only)
313 |
314 |
def load_one_result(timestamp: str, id: str) -> Dict[str, Any]:
    """Load one test-case result, preferring the database over local files."""
    if not os.getenv("DATABASE_URL"):
        return load_one_result_from_folder(timestamp, id)
    return load_one_result_from_db(timestamp, id)
320 |
321 |
def format_timestamp(timestamp: str) -> str:
    """Convert a run timestamp (YYYY-MM-DD-HH-mm-ss) to 'YYYY-MM-DD HH:MM:SS'."""
    parsed = datetime.strptime(timestamp, "%Y-%m-%d-%H-%M-%S")
    return parsed.strftime("%Y-%m-%d %H:%M:%S")
327 |
--------------------------------------------------------------------------------
/dashboard/utils/style.py:
--------------------------------------------------------------------------------
1 | SIDEBAR_STYLE = """
2 |
7 | """
8 |
--------------------------------------------------------------------------------
/data/receipt.json:
--------------------------------------------------------------------------------
1 | {
2 | "imageUrl": "https://omni-demo-data.s3.us-east-1.amazonaws.com/templates/receipt.png",
3 | "metadata": {
4 | "orientation": 0,
5 | "documentQuality": "clean",
6 | "resolution": [612, 792],
7 | "language": "EN"
8 | },
9 | "jsonSchema": {
10 | "type": "object",
11 | "required": ["merchant", "receipt_details", "totals"],
12 | "properties": {
13 | "totals": {
14 | "type": "object",
15 | "required": ["total"],
16 | "properties": {
17 | "tax": {
18 | "type": "number",
19 | "description": "Tax amount"
20 | },
21 | "total": {
22 | "type": "number",
23 | "description": "Final total amount"
24 | },
25 | "subtotal": {
26 | "type": "number",
27 | "description": "Subtotal before tax and fees"
28 | }
29 | },
30 | "description": "Payment totals"
31 | },
32 | "merchant": {
33 | "type": "object",
34 | "required": ["name"],
35 | "properties": {
36 | "name": {
37 | "type": "string",
38 | "description": "Business name"
39 | },
40 | "phone": {
41 | "type": "string",
42 | "description": "Contact phone number"
43 | },
44 | "address": {
45 | "type": "string",
46 | "description": "Store location address"
47 | }
48 | },
49 | "description": "Basic merchant information"
50 | },
51 | "line_items": {
52 | "type": "array",
53 | "items": {
54 | "type": "object",
55 | "required": ["description", "amount"],
56 | "properties": {
57 | "amount": {
58 | "type": "number",
59 | "description": "Price of the item"
60 | },
61 | "description": {
62 | "type": "string",
63 | "description": "Item name or description"
64 | }
65 | }
66 | },
67 | "description": "List of purchased items"
68 | },
69 | "receipt_details": {
70 | "type": "object",
71 | "required": ["date"],
72 | "properties": {
73 | "date": {
74 | "type": "string",
75 | "description": "Transaction date"
76 | },
77 | "time": {
78 | "type": "string",
79 | "description": "Transaction time"
80 | },
81 | "receipt_number": {
82 | "type": "string",
83 | "description": "Receipt or ticket number"
84 | }
85 | },
86 | "description": "Transaction details"
87 | },
88 | "payment": {
89 | "type": "object",
90 | "properties": {
91 | "payment_method": {
92 | "type": "string",
93 | "description": ""
94 | },
95 | "card_last_four_digits": {
96 | "type": "string",
97 | "description": ""
98 | }
99 | }
100 | }
101 | }
102 | },
103 | "trueJsonOutput": {
104 | "totals": {
105 | "tax": 6.18,
106 | "total": 48.43,
107 | "subtotal": 42.25
108 | },
109 | "merchant": {
110 | "name": "Nick the Greek Souvlaki & Gyro House",
111 | "phone": "(415) 757-0426",
112 | "address": "121 Spear Street, Suite B08, San Francisco, CA 94105"
113 | },
114 | "line_items": [
115 | {
116 | "amount": 12.5,
117 | "description": "Beef/Lamb Gyro Pita"
118 | },
119 | {
120 | "amount": 13.25,
121 | "description": "Gyro Bowl"
122 | },
123 | {
124 | "amount": 16.5,
125 | "description": "Pork Gyro Pita"
126 | }
127 | ],
128 | "receipt_details": {
129 | "date": "November 8, 2024",
130 | "time": "2:16 PM",
131 | "receipt_number": "NKZ1"
132 | },
133 | "payment": {
134 | "payment_method": "Mastercard",
135 | "card_last_four_digits": "0920"
136 | }
137 | },
138 | "trueMarkdownOutput": "**NICK THE GREEK**\n\nSOUVLAKI & GYRO HOUSE\n\n**San Francisco**\n\n121 spear streeet \nSuite B08 \nsan francisco, CA \n94105 \n(415) 757-0426 \nwww.nickthegreeksj.com\n\nNovember 8, 2024 \n2:16 PM \nSamantha\n\nTicket: 17 \nReceipt: NKZ1 \nAuthorization: CF2D4F\n\nMastercard \nAID A0 00 00 00 04 10 10\n\n**TO GO**\n\nBeef/Lamb Gyro Pita $12.50 \nGyro Bowl $13.25 \nBeef/Lamb Gyro \nPork Gyro Pita $16.50 \nFries & Drink ($4.00)\n\nSubtotal $42.25 \nSF Mandate (6%) $2.54 \n8.625% (8.625%) $3.64\n\n**Total** $48.43 \nMastercard 0920 (Contactless) $48.43"
139 | }
140 |
--------------------------------------------------------------------------------
/jest.config.ts:
--------------------------------------------------------------------------------
// Jest configuration (CommonJS export so jest + ts-node can load this .ts file).
// ts-jest compiles TypeScript test files on the fly.
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
  moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json'],
  // Only files under a tests/ directory ending in .test.ts are collected.
  testMatch: ['**/tests/**/*.test.ts'],
  transform: {
    '^.+\\.(ts|tsx)$': 'ts-jest',
  },
};
11 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "benchmark",
3 | "version": "0.0.1",
4 | "description": "OCR Benchmark",
5 | "main": "index.js",
6 | "scripts": {
7 | "build": "tsc",
8 | "test": "jest",
9 | "benchmark": "ts-node src/index.ts"
10 | },
11 | "dependencies": {
12 | "@ai-sdk/anthropic": "^1.2.12",
13 | "@ai-sdk/azure": "^1.1.9",
14 | "@ai-sdk/deepseek": "^0.1.6",
15 | "@ai-sdk/google": "^1.1.10",
16 | "@ai-sdk/openai": "^1.3.14",
17 | "@aws-sdk/client-textract": "^3.716.0",
18 | "@azure-rest/ai-document-intelligence": "^1.0.0",
19 | "@azure/core-auth": "^1.9.0",
20 | "@google-cloud/documentai": "^8.12.0",
21 | "@google/generative-ai": "^0.21.0",
22 | "@huggingface/hub": "^1.0.1",
23 | "@mistralai/mistralai": "^1.5.1",
24 | "@prisma/client": "^6.3.1",
25 | "ai": "^4.3.16",
26 | "axios": "^1.7.9",
27 | "canvas": "^3.1.0",
28 | "cli-progress": "^3.12.0",
29 | "dotenv": "^16.4.7",
30 | "fastest-levenshtein": "^1.0.16",
31 | "form-data": "^4.0.2",
32 | "jimp": "^1.6.0",
"js-yaml": "^4.1.0",
33 | "json-diff": "^1.0.6",
34 | "lodash": "^4.17.21",
35 | "moment": "^2.30.1",
36 | "openai": "^4.94.0",
37 | "p-limit": "^3.1.0",
38 | "pdfkit": "^0.17.1",
39 | "pg": "^8.13.1",
40 | "sharp": "^0.33.5",
41 | "together-ai": "^0.13.0",
42 | "turndown": "^7.2.0",
43 | "zerox": "^1.0.43"
44 | },
45 | "devDependencies": {
46 | "@eslint/js": "^9.17.0",
47 | "@types/jest": "^29.5.14",
48 | "eslint": "^9.17.0",
49 | "jest": "^29.7.0",
50 | "prettier": "^3.4.2",
51 | "prisma": "^6.3.1",
52 | "ts-jest": "^29.2.5",
53 | "ts-node": "^10.9.2",
54 | "typescript": "^5.7.2",
55 | "typescript-eslint": "^8.18.1"
56 | },
57 | "keywords": [
58 | "OCR",
59 | "Benchmark",
60 | "LLM"
61 | ],
62 | "author": "[@annapo23, @tylermaran, @kailingding, @zeeshan]",
63 | "license": "ISC"
64 | }
65 |
--------------------------------------------------------------------------------
/prisma/schema.prisma:
--------------------------------------------------------------------------------
1 | generator client {
2 | provider = "prisma-client-js"
3 | }
4 |
5 | datasource db {
6 | provider = "postgresql"
7 | url = env("DATABASE_URL")
8 | }
9 |
// One invocation of the OCR benchmark suite. Rows are read by the Python
// dashboard (dashboard/utils/data_loader.py) via the "benchmark_runs" table.
model BenchmarkRun {
  id             String            @id @default(uuid()) @db.Uuid
  completedAt    DateTime?         @map("completed_at")
  createdAt      DateTime          @default(now()) @map("created_at")
  description    String?           @map("description")
  error          String?
  modelsConfig   Json              @map("models_config") // The models.yaml configuration
  results        BenchmarkResult[]
  runBy          String?           @map("run_by")
  status         String // 'running', 'completed', 'failed'
  timestamp      String // timestamp format: YYYY-MM-DD-HH-mm-ss
  totalDocuments Int               @map("total_documents")

  @@map("benchmark_runs")
}
25 |
// One document's result within a benchmark run (per OCR/extraction model pair).
model BenchmarkResult {
  id                    String       @id @default(uuid()) @db.Uuid
  benchmarkRun          BenchmarkRun @relation(fields: [benchmarkRunId], references: [id])
  // Fix: the referenced BenchmarkRun.id is a native uuid column, so the
  // relation scalar must carry @db.Uuid as well — otherwise the foreign key
  // types (uuid vs text) do not match in Postgres.
  benchmarkRunId        String       @map("benchmark_run_id") @db.Uuid
  createdAt             DateTime     @default(now()) @map("created_at")
  directImageExtraction Boolean      @default(false) @map("direct_image_extraction")
  error                 String?
  extractionModel       String?      @map("extraction_model")
  fileUrl               String       @map("file_url")
  fullJsonDiff          Json?        @map("full_json_diff")
  jsonAccuracy          Float?       @map("json_accuracy")
  jsonAccuracyResult    Json?        @map("json_accuracy_result")
  jsonDiff              Json?        @map("json_diff")
  jsonDiffStats         Json?        @map("json_diff_stats")
  jsonSchema            Json         @map("json_schema")
  levenshteinDistance   Float?       @map("levenshtein_distance")
  metadata              Json         @map("metadata")
  ocrModel              String       @map("ocr_model")
  predictedJson         Json?        @map("predicted_json")
  predictedMarkdown     String?      @map("predicted_markdown")
  trueJson              Json         @map("true_json")
  trueMarkdown          String       @map("true_markdown")
  usage                 Json?

  @@map("benchmark_results")
}
52 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.41.1
2 | pandas==2.2.3
3 | datetime==5.5  # NOTE: likely unintended — Python's stdlib datetime needs no install; this pin pulls the third-party "DateTime" (Zope) package
4 | plotly==5.24.1
5 | sqlalchemy==2.0.38
6 | psycopg2-binary==2.9.10
7 | python-dotenv==1.0.1
--------------------------------------------------------------------------------
/src/evaluation/index.ts:
--------------------------------------------------------------------------------
1 | export * from './text';
2 | export * from './json';
3 |
--------------------------------------------------------------------------------
/src/evaluation/json.ts:
--------------------------------------------------------------------------------
1 | import { diff } from 'json-diff';
2 |
// Counts of field-level differences produced by traversing a json-diff result.
interface DiffStats {
  additions: number; // fields present in predicted but not in actual
  deletions: number; // fields present in actual but not in predicted
  modifications: number; // fields present in both but with different values
  total: number; // additions + deletions + modifications
}
9 |
10 | export interface AccuracyResult {
11 | score: number;
12 | fullJsonDiff: Record
;
13 | jsonDiff: Record;
14 | jsonDiffStats?: DiffStats;
15 | totalFields: number;
16 | }
17 |
18 | /**
19 | * Calculates accuracy for JSON structure and primitive values only
20 | *
21 | * The accuracy is calculated as:
22 | * 1 - (number of differences / total fields in actual)
23 | *
24 | * Differences include:
25 | * - Additions: Fields present in predicted but not in actual
26 | * - Deletions: Fields present in actual but not in predicted
27 | * - Modifications: Fields present in both but with different values
28 | *
29 | * A score of 1.0 means the JSONs are identical
30 | * A score of 0.0 means completely different
31 | */
32 | export const calculateJsonAccuracy = (
33 | actual: Record,
34 | predicted: Record,
35 | ignoreCases: boolean = false,
36 | ): AccuracyResult => {
37 | // Convert strings to uppercase if ignoreCases is true
38 | const processedActual = ignoreCases ? convertStringsToUppercase(actual) : actual;
39 | const processedPredicted = ignoreCases
40 | ? convertStringsToUppercase(predicted)
41 | : predicted;
42 |
43 | // Get the diff result
44 | const fullDiffResult = diff(processedActual, processedPredicted, {
45 | full: true,
46 | sort: true,
47 | });
48 | const diffResult = diff(processedActual, processedPredicted, { sort: true });
49 | const totalFields = countTotalFields(processedActual);
50 |
51 | if (!diffResult) {
52 | // If there's no diff, the JSONs are identical
53 | return {
54 | score: 1,
55 | jsonDiff: {},
56 | fullJsonDiff: {},
57 | jsonDiffStats: {
58 | additions: 0,
59 | deletions: 0,
60 | modifications: 0,
61 | total: 0,
62 | },
63 | totalFields,
64 | };
65 | }
66 |
67 | const changes = countChanges(diffResult);
68 | const score = Math.max(
69 | 0,
70 | 1 - (changes.additions + changes.deletions + changes.modifications) / totalFields,
71 | );
72 |
73 | return {
74 | score: Number(score.toFixed(4)),
75 | jsonDiff: diffResult,
76 | fullJsonDiff: fullDiffResult,
77 | jsonDiffStats: changes,
78 | totalFields,
79 | };
80 | };
81 |
82 | /**
83 | * Recursively converts all string values in an object to uppercase
84 | */
85 | const convertStringsToUppercase = (obj: any): any => {
86 | if (obj === null || typeof obj !== 'object') {
87 | return obj;
88 | }
89 |
90 | if (Array.isArray(obj)) {
91 | return obj.map((item) => convertStringsToUppercase(item));
92 | }
93 |
94 | const result: Record = {};
95 | for (const key in obj) {
96 | const value = obj[key];
97 | if (typeof value === 'string') {
98 | result[key] = value.toUpperCase();
99 | } else if (typeof value === 'object' && value !== null) {
100 | result[key] = convertStringsToUppercase(value);
101 | } else {
102 | result[key] = value;
103 | }
104 | }
105 | return result;
106 | };
107 |
/**
 * Tallies additions, deletions and modifications from a json-diff result
 * (the changes-only output of `diff()`). Nested object additions/deletions
 * are weighted by their primitive field count via countTotalFields, so a
 * missing sub-object of 5 fields counts as 5 deletions, not 1.
 */
export const countChanges = (diffResult: any): DiffStats => {
  const changes: DiffStats = {
    additions: 0,
    deletions: 0,
    modifications: 0,
    total: 0,
  };

  const traverse = (obj: any) => {
    if (!obj || typeof obj !== 'object') {
      return;
    }

    for (const key in obj) {
      const value = obj[key];

      if (Array.isArray(value)) {
        // Handle array diffs: json-diff encodes array changes as
        // [operation, element] pairs ('+' added, '-' removed, '~' modified).
        value.forEach((item) => {
          // Check if item is in the expected [operation, element] format
          if (!Array.isArray(item) || item.length !== 2) {
            return;
          }

          const [operation, element] = item;
          if (element === null || typeof element !== 'object') {
            // Handle primitive value changes in arrays
            switch (operation) {
              case '+':
                changes.additions++;
                break;
              case '-':
                changes.deletions++;
                break;
            }
          } else {
            switch (operation) {
              // Handle array element additions and deletions
              case '+':
                changes.additions += countTotalFields(element);
                break;
              case '-':
                changes.deletions += countTotalFields(element);
                break;
              case '~':
                // Handle array element modifications
                traverse(element);
                break;
            }
          }
        });
      } else {
        // Object diffs: json-diff marks removed keys as `key__deleted` and
        // new keys as `key__added`; changed values become { __old, __new }.
        if (key.endsWith('__deleted')) {
          if (value === null || typeof value !== 'object') {
            changes.deletions++;
          } else {
            changes.deletions += countTotalFields(value);
          }
        } else if (key.endsWith('__added')) {
          if (value === null || typeof value !== 'object') {
            changes.additions++;
          } else {
            changes.additions += countTotalFields(value);
          }
        } else if (typeof value === 'object' && value !== null) {
          if (value.__old !== undefined && value.__new !== undefined) {
            // A value change; weight by field count of the non-null side
            // (|| 1 keeps primitive changes counted as a single modification).
            if (value.__old === null && value.__new !== null) {
              changes.modifications += countTotalFields(value.__new) || 1;
            } else {
              changes.modifications += countTotalFields(value.__old) || 1;
            }
          } else {
            traverse(value);
          }
        }
      }
    }
  };

  traverse(diffResult);

  changes.total = changes.additions + changes.deletions + changes.modifications;
  return changes;
};
192 |
193 | export function countTotalFields(obj: any): number {
194 | let count = 0;
195 |
196 | const traverse = (current: any) => {
197 | if (!current || typeof current !== 'object') {
198 | return;
199 | }
200 |
201 | if (Array.isArray(current)) {
202 | // Traverse into array elements if they're objects
203 | current.forEach((item) => {
204 | if (typeof item === 'object' && item !== null) {
205 | traverse(item);
206 | } else {
207 | count++;
208 | }
209 | });
210 | } else {
211 | for (const key in current) {
212 | // Skip diff metadata keys
213 | if (key.includes('__')) {
214 | continue;
215 | }
216 |
217 | // Only count primitive value fields
218 | if (
219 | current[key] === null ||
220 | typeof current[key] === 'string' ||
221 | typeof current[key] === 'number' ||
222 | typeof current[key] === 'boolean'
223 | ) {
224 | count++;
225 | }
226 | // Recurse into nested objects and arrays
227 | else if (typeof current[key] === 'object') {
228 | traverse(current[key]);
229 | }
230 | }
231 | }
232 | };
233 |
234 | traverse(obj);
235 | return count;
236 | }
237 |
--------------------------------------------------------------------------------
/src/evaluation/text.ts:
--------------------------------------------------------------------------------
1 | import { distance } from 'fastest-levenshtein';
2 |
3 | /**
4 | * Calculates text similarity between original and OCR text using Levenshtein distance
5 | * Returns a score between 0 and 1, where:
6 | * 1.0 = texts are identical
7 | * 0.0 = texts are completely different
8 | */
9 | export const calculateTextSimilarity = (original: string, predicted: string): number => {
10 | if (original === predicted) return 1;
11 | if (!original.length || !predicted.length) return 0;
12 |
13 | // Normalize strings
14 | const normalizedOriginal = original.trim().toLowerCase();
15 | const normalizedPredicted = predicted.trim().toLowerCase();
16 |
17 | // Calculate Levenshtein distance
18 | const levenshteinDistance = distance(normalizedOriginal, normalizedPredicted);
19 |
20 | // Normalize score between 0 and 1
21 | const maxLength = Math.max(normalizedOriginal.length, normalizedPredicted.length);
22 | const similarity = 1 - levenshteinDistance / maxLength;
23 |
24 | return Number(similarity.toFixed(4));
25 | };
26 |
--------------------------------------------------------------------------------
/src/index.ts:
--------------------------------------------------------------------------------
1 | import dotenv from 'dotenv';
2 | import path from 'path';
3 | import moment from 'moment';
4 | import cliProgress from 'cli-progress';
5 | import { isEmpty } from 'lodash';
6 | import pLimit from 'p-limit';
7 | import yaml from 'js-yaml';
8 | import fs from 'fs';
9 |
10 | import { BenchmarkRun } from '@prisma/client';
11 | import { calculateJsonAccuracy, calculateTextSimilarity } from './evaluation';
12 | import { getModelProvider } from './models';
13 | import { Result } from './types';
14 | import {
15 | createResultFolder,
16 | loadLocalData,
17 | writeToFile,
18 | loadFromDb,
19 | createBenchmarkRun,
20 | saveResult,
21 | completeBenchmarkRun,
22 | } from './utils';
23 |
24 | dotenv.config();
25 |
26 | /* -------------------------------------------------------------------------- */
27 | /* Benchmark Config */
28 | /* -------------------------------------------------------------------------- */
29 |
// Per-model cap on concurrent requests, keyed by OCR/extraction model name.
// runBenchmark sizes its p-limit pool with the minimum of the OCR and
// extraction model caps; models not listed here default to 20.
const MODEL_CONCURRENCY = {
  'aws-textract': 50,
  'azure-document-intelligence': 50,
  'claude-3-5-sonnet-20241022': 10,
  'gemini-2.0-flash-001': 30,
  'mistral-ocr': 5,
  'gpt-4o': 50,
  'qwen2.5-vl-32b-instruct': 10,
  'qwen2.5-vl-72b-instruct': 10,
  'google/gemma-3-27b-it': 10,
  'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo': 10,
  'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo': 10,
  omniai: 30,
  zerox: 50,
};
45 |
// One entry of models.yaml: the OCR model to run, optionally a separate
// extraction model, and whether extraction works directly from the image
// (labelled "IMG2JSON" in the progress bars — presumably bypassing the OCR
// markdown step; confirm against the model providers).
interface ModelConfig {
  ocr: string;
  extraction?: string;
  directImageExtraction?: boolean;
}
51 |
52 | // Load models config
53 | const loadModelsConfig = () => {
54 | try {
55 | const configPath = path.join(__dirname, 'models.yaml');
56 | const fileContents = fs.readFileSync(configPath, 'utf8');
57 | const config = yaml.load(fileContents) as { models: ModelConfig[] };
58 | return config.models;
59 | } catch (error) {
60 | console.error('Error loading models config:', error);
61 | return [] as ModelConfig[];
62 | }
63 | };
64 |
// Model matrix parsed from models.yaml at module load time.
const MODELS = loadModelsConfig();

// Benchmark documents live in <repo>/data.
const DATA_FOLDER = path.join(__dirname, '../data');

// When set, input data and results are read from / written to Postgres
// instead of local JSON files (see runBenchmark).
const DATABASE_URL = process.env.DATABASE_URL;

const TIMEOUT_MS = 10 * 60 * 1000; // 10 minutes in milliseconds
72 |
73 | const withTimeout = async (promise: Promise, operation: string) => {
74 | let timeoutId: NodeJS.Timeout;
75 |
76 | const timeoutPromise = new Promise((_, reject) => {
77 | timeoutId = setTimeout(() => {
78 | reject(new Error(`${operation} operation timed out after ${TIMEOUT_MS}ms`));
79 | }, TIMEOUT_MS);
80 | });
81 |
82 | try {
83 | const result = await Promise.race([promise, timeoutPromise]);
84 | clearTimeout(timeoutId);
85 | return result;
86 | } catch (error) {
87 | clearTimeout(timeoutId);
88 | console.error(`Timeout error in ${operation}:`, error);
89 | throw error;
90 | }
91 | };
92 |
93 | /* -------------------------------------------------------------------------- */
94 | /* Run Benchmark */
95 | /* -------------------------------------------------------------------------- */
96 |
// Timestamp identifying this benchmark run; also names the local results folder.
const timestamp = moment(new Date()).format('YYYY-MM-DD-HH-mm-ss');
const resultFolder = createResultFolder(timestamp);
99 |
100 | const runBenchmark = async () => {
101 | const data = DATABASE_URL ? await loadFromDb() : loadLocalData(DATA_FOLDER);
102 | const results: Result[] = [];
103 |
104 | // Create benchmark run
105 | let benchmarkRun: BenchmarkRun;
106 | if (DATABASE_URL) {
107 | benchmarkRun = await createBenchmarkRun(timestamp, MODELS, data.length);
108 | }
109 |
110 | // Create multiple progress bars
111 | const multibar = new cliProgress.MultiBar({
112 | format: '{model} |{bar}| {percentage}% | {value}/{total}',
113 | barCompleteChar: '\u2588',
114 | barIncompleteChar: '\u2591',
115 | clearOnComplete: false,
116 | hideCursor: true,
117 | });
118 |
119 | // Create progress bars for each model
120 | const progressBars = MODELS.reduce(
121 | (acc, model) => ({
122 | ...acc,
123 | [`${model.directImageExtraction ? `${model.extraction} (IMG2JSON)` : `${model.ocr}-${model.extraction}`}`]:
124 | multibar.create(data.length, 0, {
125 | model: `${model.directImageExtraction ? `${model.extraction} (IMG2JSON)` : `${model.ocr} -> ${model.extraction}`}`,
126 | }),
127 | }),
128 | {},
129 | );
130 |
131 | const modelPromises = MODELS.map(
132 | async ({ ocr: ocrModel, extraction: extractionModel, directImageExtraction }) => {
133 | // Calculate concurrent requests based on rate limit
134 | const concurrency = Math.min(
135 | MODEL_CONCURRENCY[ocrModel as keyof typeof MODEL_CONCURRENCY] ?? 20,
136 | MODEL_CONCURRENCY[extractionModel as keyof typeof MODEL_CONCURRENCY] ?? 20,
137 | );
138 | const limit = pLimit(concurrency);
139 |
140 | const promises = data.map((item) =>
141 | limit(async () => {
142 | const ocrModelProvider = getModelProvider(ocrModel);
143 | const extractionModelProvider = extractionModel
144 | ? getModelProvider(extractionModel)
145 | : undefined;
146 |
147 | const result: Result = {
148 | fileUrl: item.imageUrl,
149 | metadata: item.metadata,
150 | jsonSchema: item.jsonSchema,
151 | ocrModel,
152 | extractionModel,
153 | directImageExtraction,
154 | trueMarkdown: item.trueMarkdownOutput,
155 | trueJson: item.trueJsonOutput,
156 | predictedMarkdown: undefined,
157 | predictedJson: undefined,
158 | levenshteinDistance: undefined,
159 | jsonAccuracy: undefined,
160 | jsonDiff: undefined,
161 | fullJsonDiff: undefined,
162 | jsonDiffStats: undefined,
163 | jsonAccuracyResult: undefined,
164 | usage: undefined,
165 | };
166 |
167 | try {
168 | if (directImageExtraction) {
169 | const extractionResult = await withTimeout(
170 | extractionModelProvider.extractFromImage(item.imageUrl, item.jsonSchema),
171 | `JSON extraction: ${extractionModel}`,
172 | );
173 | result.predictedJson = extractionResult.json;
174 | result.usage = {
175 | ...extractionResult.usage,
176 | ocr: undefined,
177 | extraction: extractionResult.usage,
178 | };
179 | } else {
180 | let ocrResult;
181 | if (ocrModel === 'ground-truth') {
182 | result.predictedMarkdown = item.trueMarkdownOutput;
183 | } else {
184 | if (ocrModelProvider) {
185 | ocrResult = await withTimeout(
186 | ocrModelProvider.ocr(item.imageUrl),
187 | `OCR: ${ocrModel}`,
188 | );
189 | result.predictedMarkdown = ocrResult.text;
190 | result.usage = {
191 | ...ocrResult.usage,
192 | ocr: ocrResult.usage,
193 | extraction: undefined,
194 | };
195 | }
196 | }
197 |
198 | let extractionResult;
199 | if (extractionModelProvider) {
200 | extractionResult = await withTimeout(
201 | extractionModelProvider.extractFromText(
202 | result.predictedMarkdown,
203 | item.jsonSchema,
204 | ocrResult?.imageBase64s,
205 | ),
206 | `JSON extraction: ${extractionModel}`,
207 | );
208 | result.predictedJson = extractionResult.json;
209 |
210 | const mergeUsage = (base: any, additional: any) => ({
211 | duration: (base?.duration ?? 0) + (additional?.duration ?? 0),
212 | inputTokens: (base?.inputTokens ?? 0) + (additional?.inputTokens ?? 0),
213 | outputTokens:
214 | (base?.outputTokens ?? 0) + (additional?.outputTokens ?? 0),
215 | totalTokens: (base?.totalTokens ?? 0) + (additional?.totalTokens ?? 0),
216 | inputCost: (base?.inputCost ?? 0) + (additional?.inputCost ?? 0),
217 | outputCost: (base?.outputCost ?? 0) + (additional?.outputCost ?? 0),
218 | totalCost: (base?.totalCost ?? 0) + (additional?.totalCost ?? 0),
219 | });
220 |
221 | result.usage = {
222 | ocr: result.usage?.ocr ?? {},
223 | extraction: extractionResult.usage,
224 | ...mergeUsage(result.usage, extractionResult.usage),
225 | };
226 | }
227 | }
228 |
229 | if (result.predictedMarkdown) {
230 | result.levenshteinDistance = calculateTextSimilarity(
231 | item.trueMarkdownOutput,
232 | result.predictedMarkdown,
233 | );
234 | }
235 |
236 | if (!isEmpty(result.predictedJson)) {
237 | const jsonAccuracyResult = calculateJsonAccuracy(
238 | item.trueJsonOutput,
239 | result.predictedJson,
240 | );
241 | result.jsonAccuracy = jsonAccuracyResult.score;
242 | result.jsonDiff = jsonAccuracyResult.jsonDiff;
243 | result.fullJsonDiff = jsonAccuracyResult.fullJsonDiff;
244 | result.jsonDiffStats = jsonAccuracyResult.jsonDiffStats;
245 | result.jsonAccuracyResult = jsonAccuracyResult;
246 | }
247 | } catch (error) {
248 | result.error = error;
249 | console.error(
250 | `Error processing ${item.imageUrl} with ${ocrModel} and ${extractionModel}:\n`,
251 | error,
252 | );
253 | }
254 |
255 | if (benchmarkRun) {
256 | await saveResult(benchmarkRun.id, result);
257 | }
258 |
259 | // Update progress bar for this model
260 | progressBars[
261 | `${directImageExtraction ? `${extractionModel} (IMG2JSON)` : `${ocrModel}-${extractionModel}`}`
262 | ].increment();
263 | return result;
264 | }),
265 | );
266 |
267 | // Process items concurrently for this model
268 | const modelResults = await Promise.all(promises);
269 |
270 | results.push(...modelResults);
271 | },
272 | );
273 |
274 | // Process each model with its own concurrency limit
275 | await Promise.all(modelPromises);
276 |
277 | // Stop all progress bars
278 | multibar.stop();
279 |
280 | // Complete benchmark run successfully
281 | if (benchmarkRun) {
282 | await completeBenchmarkRun(benchmarkRun.id);
283 | }
284 |
285 | writeToFile(path.join(resultFolder, 'results.json'), results);
286 | };
287 |
288 | runBenchmark();
289 |
--------------------------------------------------------------------------------
/src/models.example.yaml:
--------------------------------------------------------------------------------
# Example model matrix — copy to models.yaml and edit.
# Each entry is one benchmark pipeline: `ocr` (image -> markdown) plus an
# optional `extraction` model (markdown -> JSON); `directImageExtraction: true`
# extracts JSON straight from the image. Commented entries are ready-made
# configurations to uncomment.
models:
  - ocr: ground-truth
    extraction: gpt-4o

  - ocr: gpt-4o
    extraction: gpt-4o

  - ocr: gpt-4o
    extraction: gpt-4o
    directImageExtraction: true

  # - ocr: gemini-2.0-flash-001
  #   extraction: gpt-4o

  # - ocr: azure-gpt-4o
  #   extraction: azure-gpt-4o

  # - ocr: claude-3-5-sonnet-20241022
  #   extraction: claude-3-5-sonnet-20241022

  # - ocr: claude-3-5-sonnet-20241022
  #   extraction: claude-3-5-sonnet-20241022
  #   directImageExtraction: true

  # - ocr: zerox
  #   extraction: gpt-4o

  # - ocr: omniai
  #   extraction: gpt-4o

  # - ocr: aws-textract
  #   extraction: gpt-4o

  # - ocr: google-document-ai
  #   extraction: gpt-4o

  # - ocr: azure-document-intelligence
  #   extraction: gpt-4o

  # - ocr: unstructured
  #   extraction: gpt-4o

  # - ocr: gpt-4o
  #   extraction: deepseek-chat
--------------------------------------------------------------------------------
/src/models/awsTextract.ts:
--------------------------------------------------------------------------------
1 | import { TextractClient, AnalyzeDocumentCommand } from '@aws-sdk/client-textract';
2 | import { ModelProvider } from './base';
3 |
4 | // https://aws.amazon.com/textract/pricing/
5 | // $4 per 1000 pages for the first 1M pages, Layout model
6 | const COST_PER_PAGE = 4 / 1000;
7 |
8 | export class AWSTextractProvider extends ModelProvider {
9 | private client: TextractClient;
10 |
11 | constructor() {
12 | super('aws-textract');
13 | this.client = new TextractClient({
14 | region: process.env.AWS_REGION,
15 | credentials: {
16 | accessKeyId: process.env.AWS_ACCESS_KEY_ID!,
17 | secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY!,
18 | },
19 | });
20 | }
21 |
22 | async ocr(imagePath: string) {
23 | try {
24 | // Convert image URL to base64
25 | const response = await fetch(imagePath);
26 | const arrayBuffer = await response.arrayBuffer();
27 | const buffer = Buffer.from(arrayBuffer);
28 |
29 | const start = performance.now();
30 | const command = new AnalyzeDocumentCommand({
31 | Document: {
32 | Bytes: buffer,
33 | },
34 | FeatureTypes: ['LAYOUT'],
35 | });
36 |
37 | const result = await this.client.send(command);
38 | const end = performance.now();
39 |
40 | // Extract text from blocks
41 | const text =
42 | result.Blocks?.filter((block) => block.Text)
43 | .map((block) => block.Text)
44 | .join('\n') || '';
45 |
46 | return {
47 | text,
48 | usage: {
49 | duration: end - start,
50 | totalCost: COST_PER_PAGE, // the input is always 1 page.
51 | },
52 | };
53 | } catch (error) {
54 | console.error('AWS Textract Error:', error);
55 | throw error;
56 | }
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/src/models/azure.ts:
--------------------------------------------------------------------------------
1 | import { AzureKeyCredential } from '@azure/core-auth';
2 | import DocumentIntelligence, {
3 | DocumentIntelligenceClient,
4 | getLongRunningPoller,
5 | isUnexpected,
6 | AnalyzeOperationOutput,
7 | } from '@azure-rest/ai-document-intelligence';
8 |
9 | import { ModelProvider } from './base';
10 |
11 | // https://azure.microsoft.com/en-us/pricing/details/ai-document-intelligence/
12 | // $10 per 1000 pages for the first 1M pages, Prebuilt-Layout model
13 | const COST_PER_PAGE = 10 / 1000;
14 |
15 | export class AzureDocumentIntelligenceProvider extends ModelProvider {
16 | private client: DocumentIntelligenceClient;
17 |
18 | constructor() {
19 | super('azure-document-intelligence');
20 |
21 | const endpoint = process.env.AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT;
22 | const apiKey = process.env.AZURE_DOCUMENT_INTELLIGENCE_KEY;
23 |
24 | if (!endpoint || !apiKey) {
25 | throw new Error('Missing required Azure Document Intelligence configuration');
26 | }
27 |
28 | this.client = DocumentIntelligence(endpoint, new AzureKeyCredential(apiKey));
29 | }
30 |
31 | async ocr(imagePath: string) {
32 | try {
33 | const start = performance.now();
34 |
35 | const initialResponse = await this.client
36 | .path('/documentModels/{modelId}:analyze', 'prebuilt-layout')
37 | .post({
38 | contentType: 'application/json',
39 | body: {
40 | urlSource: imagePath,
41 | },
42 | queryParameters: { outputContentFormat: 'markdown' },
43 | });
44 |
45 | if (isUnexpected(initialResponse)) {
46 | throw initialResponse.body.error;
47 | }
48 |
49 | const poller = getLongRunningPoller(this.client, initialResponse);
50 | const result = (await poller.pollUntilDone()).body as AnalyzeOperationOutput;
51 | const analyzeResult = result.analyzeResult;
52 | const text = analyzeResult?.content;
53 |
54 | const end = performance.now();
55 |
56 | return {
57 | text,
58 | usage: {
59 | duration: end - start,
60 | totalCost: COST_PER_PAGE, // the input is always 1 page.
61 | },
62 | };
63 | } catch (error) {
64 | console.error('Azure Document Intelligence Error:', error);
65 | throw error;
66 | }
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/src/models/base.ts:
--------------------------------------------------------------------------------
1 | import { JsonSchema, Usage } from '../types';
2 |
3 | export class ModelProvider {
4 | model: string;
5 | outputDir?: string;
6 |
7 | constructor(model: string, outputDir?: string) {
8 | this.model = model;
9 | this.outputDir = outputDir;
10 | }
11 |
12 | async ocr(imagePath: string): Promise<{
13 | text: string;
14 | imageBase64s?: string[];
15 | usage: Usage;
16 | }> {
17 | throw new Error('Not implemented');
18 | }
19 |
20 | async extractFromText?(
21 | text: string,
22 | schema: JsonSchema,
23 | imageBase64s?: string[],
24 | ): Promise<{
25 | json: Record;
26 | usage: Usage;
27 | }> {
28 | throw new Error('Not implemented');
29 | }
30 |
31 | async extractFromImage?(
32 | imagePath: string,
33 | schema: JsonSchema,
34 | ): Promise<{
35 | json: Record;
36 | usage: Usage;
37 | }> {
38 | throw new Error('Not implemented');
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/src/models/dashscope.ts:
--------------------------------------------------------------------------------
1 | import OpenAI from 'openai';
2 | import sharp from 'sharp';
3 |
4 | import { ModelProvider } from './base';
5 | import { Usage } from '../types';
6 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared';
7 |
8 | export class DashscopeProvider extends ModelProvider {
9 | private client: OpenAI;
10 |
11 | constructor(model: string) {
12 | super(model);
13 |
14 | const apiKey = process.env.DASHSCOPE_API_KEY;
15 | if (!apiKey) {
16 | throw new Error('Missing required HuggingFace API key');
17 | }
18 |
19 | this.client = new OpenAI({
20 | baseURL: 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
21 | apiKey,
22 | });
23 | }
24 |
25 | async ocr(imagePath: string): Promise<{
26 | text: string;
27 | imageBase64s?: string[];
28 | usage: Usage;
29 | }> {
30 | const start = performance.now();
31 |
32 | // Fetch the image
33 | const imageResponse = await fetch(imagePath);
34 | const imageBuffer = await imageResponse.arrayBuffer();
35 |
36 | // compress the image
37 | const resizedImageBuffer = await sharp(Buffer.from(imageBuffer))
38 | .jpeg({ quality: 90 })
39 | .toBuffer();
40 |
41 | // Convert to base64
42 | const resizedImageBase64 = `data:image/jpeg;base64,${resizedImageBuffer.toString('base64')}`;
43 |
44 | const response = await this.client.chat.completions.create({
45 | model: 'qwen2.5-vl-32b-instruct',
46 | messages: [
47 | {
48 | role: 'user',
49 | content: [
50 | { type: 'text', text: OCR_SYSTEM_PROMPT },
51 | {
52 | type: 'image_url',
53 | image_url: {
54 | url: resizedImageBase64,
55 | },
56 | },
57 | ],
58 | },
59 | ],
60 | });
61 |
62 | const end = performance.now();
63 |
64 | const inputTokens = response.usage?.prompt_tokens || 0;
65 | const outputTokens = response.usage?.completion_tokens || 0;
66 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
67 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens);
68 |
69 | return {
70 | text: response.choices[0].message.content || '',
71 | usage: {
72 | duration: end - start,
73 | inputTokens,
74 | outputTokens,
75 | totalTokens: inputTokens + outputTokens,
76 | inputCost,
77 | outputCost,
78 | totalCost: inputCost + outputCost,
79 | },
80 | };
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/src/models/gemini.ts:
--------------------------------------------------------------------------------
1 | import { GoogleGenerativeAI } from '@google/generative-ai';
2 |
3 | import { ModelProvider } from './base';
4 | import {
5 | IMAGE_EXTRACTION_SYSTEM_PROMPT,
6 | JSON_EXTRACTION_SYSTEM_PROMPT,
7 | OCR_SYSTEM_PROMPT,
8 | } from './shared/prompt';
9 | import { calculateTokenCost } from './shared/tokenCost';
10 | import { getMimeType } from '../utils';
11 | import { JsonSchema } from '../types';
12 |
/**
 * OCR and JSON-extraction provider backed by Google Generative AI (Gemini),
 * using the official @google/generative-ai SDK.
 */
export class GeminiProvider extends ModelProvider {
  private client: GoogleGenerativeAI;

  constructor(model: string) {
    super(model);

    const apiKey = process.env.GOOGLE_GENERATIVE_AI_API_KEY;

    if (!apiKey) {
      throw new Error('Missing required Google Generative AI configuration');
    }

    this.client = new GoogleGenerativeAI(apiKey);
  }

  /**
   * Runs OCR on the image at `imagePath` (a URL): downloads it, inlines it as
   * base64, and asks the model (temperature 0) to transcribe it per
   * OCR_SYSTEM_PROMPT. Returns the text plus token/cost usage info.
   */
  async ocr(imagePath: string) {
    try {
      const start = performance.now();

      const model = this.client.getGenerativeModel({
        model: this.model,
        generationConfig: { temperature: 0 },
      });

      // read image and convert to base64
      const response = await fetch(imagePath);
      const imageBuffer = await response.arrayBuffer();
      const base64Image = Buffer.from(imageBuffer).toString('base64');

      const imagePart = {
        inlineData: {
          data: base64Image,
          mimeType: getMimeType(imagePath),
        },
      };

      const ocrResult = await model.generateContent([OCR_SYSTEM_PROMPT, imagePart]);
      const text = ocrResult.response.text();

      const end = performance.now();

      // NOTE(review): usageMetadata is assumed present on every response —
      // confirm the SDK guarantees this (it is typed optional in some versions).
      const ocrInputTokens = ocrResult.response.usageMetadata.promptTokenCount;
      const ocrOutputTokens = ocrResult.response.usageMetadata.candidatesTokenCount;
      const inputCost = calculateTokenCost(this.model, 'input', ocrInputTokens);
      const outputCost = calculateTokenCost(this.model, 'output', ocrOutputTokens);

      return {
        text,
        usage: {
          duration: end - start,
          inputTokens: ocrInputTokens,
          outputTokens: ocrOutputTokens,
          totalTokens: ocrInputTokens + ocrOutputTokens,
          inputCost,
          outputCost,
          totalCost: inputCost + outputCost,
        },
      };
    } catch (error) {
      console.error('Google Generative AI OCR Error:', error);
      throw error;
    }
  }

  /**
   * Extracts JSON matching `schema` from OCR'd text, using Gemini's native
   * JSON response mode with the schema converted to Gemini's dialect.
   */
  // FIXME: JSON output might not be 100% correct yet, because Gemini uses a subset of OpenAPI 3.0 schema
  // https://sdk.vercel.ai/providers/ai-sdk-providers/google-generative-ai#schema-limitations
  async extractFromText(text: string, schema: JsonSchema) {
    const filteredSchema = this.convertSchemaForGemini(schema);

    const start = performance.now();
    const model = this.client.getGenerativeModel({
      model: this.model,
      generationConfig: {
        temperature: 0,
        responseMimeType: 'application/json',
        responseSchema: filteredSchema,
      },
    });

    const result = await model.generateContent([JSON_EXTRACTION_SYSTEM_PROMPT, text]);

    // The model is constrained to JSON output by responseMimeType above.
    const json = JSON.parse(result.response.text());

    const end = performance.now();

    const inputTokens = result.response.usageMetadata.promptTokenCount;
    const outputTokens = result.response.usageMetadata.candidatesTokenCount;
    const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
    const outputCost = calculateTokenCost(this.model, 'output', outputTokens);

    return {
      json,
      usage: {
        duration: end - start,
        inputTokens,
        outputTokens,
        totalTokens: inputTokens + outputTokens,
        inputCost,
        outputCost,
        totalCost: inputCost + outputCost,
      },
    };
  }

  /**
   * Extracts JSON matching `schema` directly from the image (no OCR step),
   * inlining the image as base64 and using Gemini's JSON response mode.
   */
  // FIXME: JSON output might not be 100% correct yet, because Gemini uses a subset of OpenAPI 3.0 schema
  // https://sdk.vercel.ai/providers/ai-sdk-providers/google-generative-ai#schema-limitations
  async extractFromImage(imagePath: string, schema: JsonSchema) {
    const filteredSchema = this.convertSchemaForGemini(schema);

    // read image and convert to base64
    const response = await fetch(imagePath);
    const imageBuffer = await response.arrayBuffer();
    const base64Image = Buffer.from(imageBuffer).toString('base64');

    const start = performance.now();

    const model = this.client.getGenerativeModel({
      model: this.model,
      generationConfig: {
        temperature: 0,
        responseMimeType: 'application/json',
        responseSchema: filteredSchema,
      },
    });

    const imagePart = {
      inlineData: {
        data: base64Image,
        mimeType: getMimeType(imagePath),
      },
    };

    const result = await model.generateContent([
      IMAGE_EXTRACTION_SYSTEM_PROMPT,
      imagePart,
    ]);

    const json = JSON.parse(result.response.text());

    const end = performance.now();

    const inputTokens = result.response.usageMetadata.promptTokenCount;
    const outputTokens = result.response.usageMetadata.candidatesTokenCount;
    const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
    const outputCost = calculateTokenCost(this.model, 'output', outputTokens);

    return {
      json,
      usage: {
        duration: end - start,
        inputTokens,
        outputTokens,
        totalTokens: inputTokens + outputTokens,
        inputCost,
        outputCost,
        totalCost: inputCost + outputCost,
      },
    };
  }

  /**
   * Rewrites a JSON schema into the subset Gemini accepts: normalizes enum
   * typing, strips `additionalProperties`, rewrites `not: {type: null}` as
   * `nullable: false`, hoists array-level `required` onto `items`, and
   * recurses into nested items/properties. The input schema is not mutated.
   */
  convertSchemaForGemini(schema) {
    // Deep clone the schema to avoid modifying the original
    const newSchema = JSON.parse(JSON.stringify(schema));

    function processSchemaNode(node) {
      if (!node || typeof node !== 'object') return node;

      // Fix enum type definition
      if (node.type === 'enum' && node.enum) {
        node.type = 'string';
      }
      // Handle case where enum array exists but type isn't specified
      if (node.enum && !node.type) {
        node.type = 'string';
      }

      // Remove additionalProperties constraints
      if ('additionalProperties' in node) {
        delete node.additionalProperties;
      }

      // Handle 'not' validation keyword
      if (node.not) {
        if (node.not.type === 'null') {
          delete node.not;
          node.nullable = false;
        } else {
          processSchemaNode(node.not);
        }
      }

      // Handle arrays
      if (node.type === 'array' && node.items) {
        // Move required fields to items level
        if (node.required) {
          if (!node.items.required) {
            node.items.required = node.required;
          } else {
            node.items.required = [
              ...new Set([...node.items.required, ...node.required]),
            ];
          }
          delete node.required;
        }

        processSchemaNode(node.items);
      }

      // Handle objects with properties
      if (node.properties) {
        Object.entries(node.properties).forEach(([key, prop]) => {
          node.properties[key] = processSchemaNode(prop);
        });
      }

      return node;
    }

    return processSchemaNode(newSchema);
  }
}
234 |
--------------------------------------------------------------------------------
/src/models/googleDocumentAI.ts:
--------------------------------------------------------------------------------
1 | import fs from 'fs';
2 | import { DocumentProcessorServiceClient } from '@google-cloud/documentai';
3 | import { ModelProvider } from './base';
4 |
5 | // https://cloud.google.com/document-ai/pricing
6 | // $1.5 per 1000 pages for the first 5M pages
7 | const COST_PER_PAGE = 1.5 / 1000;
8 |
9 | export class GoogleDocumentAIProvider extends ModelProvider {
10 | private client: DocumentProcessorServiceClient;
11 | private processorPath: string;
12 |
13 | constructor() {
14 | super('google-document-ai');
15 |
16 | const projectId = process.env.GOOGLE_PROJECT_ID;
17 | const location = process.env.GOOGLE_LOCATION || 'us'; // default to 'us'
18 | const processorId = process.env.GOOGLE_PROCESSOR_ID;
19 |
20 | if (!projectId || !processorId) {
21 | throw new Error('Missing required Google Document AI configuration');
22 | }
23 |
24 | const credentials = JSON.parse(
25 | fs.readFileSync(process.env.GOOGLE_APPLICATION_CREDENTIALS_PATH || '', 'utf8'),
26 | );
27 | this.client = new DocumentProcessorServiceClient({
28 | credentials,
29 | });
30 |
31 | this.processorPath = `projects/${projectId}/locations/${location}/processors/${processorId}`;
32 | }
33 |
34 | async ocr(imagePath: string) {
35 | try {
36 | // Download the image
37 | const response = await fetch(imagePath);
38 | const arrayBuffer = await response.arrayBuffer();
39 | const imageContent = Buffer.from(arrayBuffer).toString('base64');
40 |
41 | // Determine MIME type from URL
42 | const mimeType = this.getMimeType(imagePath);
43 |
44 | const request = {
45 | name: this.processorPath,
46 | rawDocument: {
47 | content: imageContent,
48 | mimeType: mimeType,
49 | },
50 | };
51 |
52 | const start = performance.now();
53 | const [result] = await this.client.processDocument(request);
54 | const { document } = result;
55 | const end = performance.now();
56 |
57 | // Extract text from the document
58 | const text = document?.text || '';
59 |
60 | return {
61 | text,
62 | usage: {
63 | duration: end - start,
64 | totalCost: COST_PER_PAGE, // the input is always 1 page.
65 | },
66 | };
67 | } catch (error) {
68 | console.error('Google Document AI Error:', error);
69 | throw error;
70 | }
71 | }
72 |
73 | private getMimeType(url: string): string {
74 | const extension = url.split('.').pop()?.toLowerCase();
75 | switch (extension) {
76 | case 'pdf':
77 | return 'application/pdf';
78 | case 'png':
79 | return 'image/png';
80 | case 'jpg':
81 | case 'jpeg':
82 | return 'image/jpeg';
83 | case 'tiff':
84 | case 'tif':
85 | return 'image/tiff';
86 | case 'gif':
87 | return 'image/gif';
88 | case 'bmp':
89 | return 'image/bmp';
90 | default:
91 | return 'image/png'; // default to PNG
92 | }
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/src/models/index.ts:
--------------------------------------------------------------------------------
1 | export * from './registry';
2 |
--------------------------------------------------------------------------------
/src/models/llm.ts:
--------------------------------------------------------------------------------
1 | import {
2 | generateText,
3 | generateObject,
4 | CoreMessage,
5 | CoreUserMessage,
6 | NoObjectGeneratedError,
7 | } from 'ai';
8 | import { createOpenAI } from '@ai-sdk/openai';
9 | import { createAnthropic } from '@ai-sdk/anthropic';
10 | import { createGoogleGenerativeAI } from '@ai-sdk/google';
11 | import { createDeepSeek } from '@ai-sdk/deepseek';
12 | import { createAzure } from '@ai-sdk/azure';
13 |
14 | import { ExtractionResult, JsonSchema } from '../types';
15 | import { generateZodSchema, writeResultToFile } from '../utils';
16 | import { calculateTokenCost } from './shared';
17 | import { ModelProvider } from './base';
18 | import {
19 | OCR_SYSTEM_PROMPT,
20 | JSON_EXTRACTION_SYSTEM_PROMPT,
21 | IMAGE_EXTRACTION_SYSTEM_PROMPT,
22 | } from './shared';
23 | import {
24 | OPENAI_MODELS,
25 | ANTHROPIC_MODELS,
26 | GOOGLE_GENERATIVE_AI_MODELS,
27 | FINETUNED_MODELS,
28 | DEEPSEEK_MODELS,
29 | AZURE_OPENAI_MODELS,
30 | } from './registry';
31 |
32 | export const createModelProvider = (model: string) => {
33 | if (OPENAI_MODELS.includes(model)) {
34 | return createOpenAI({
35 | apiKey: process.env.OPENAI_API_KEY,
36 | baseURL: process.env.OPENAI_ENDPOINT || 'https://api.openai.com/v1',
37 | });
38 | }
39 | if (AZURE_OPENAI_MODELS.includes(model)) {
40 | return createAzure({
41 | apiKey: process.env.AZURE_OPENAI_API_KEY,
42 | resourceName: process.env.AZURE_OPENAI_RESOURCE_NAME,
43 | });
44 | }
45 | if (FINETUNED_MODELS.includes(model)) {
46 | return createOpenAI({ apiKey: process.env.OPENAI_API_KEY });
47 | }
48 | if (ANTHROPIC_MODELS.includes(model)) {
49 | return createAnthropic({ apiKey: process.env.ANTHROPIC_API_KEY });
50 | }
51 | if (GOOGLE_GENERATIVE_AI_MODELS.includes(model)) {
52 | return createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY });
53 | }
54 | if (DEEPSEEK_MODELS.includes(model)) {
55 | return createDeepSeek({ apiKey: process.env.DEEPSEEK_API_KEY });
56 | }
57 | throw new Error(`Model '${model}' does not support image inputs`);
58 | };
59 |
60 | export class LLMProvider extends ModelProvider {
61 | constructor(model: string) {
62 | if (AZURE_OPENAI_MODELS.includes(model)) {
63 | const openaiModel = model.replace('azure-', '');
64 | super(openaiModel);
65 | } else {
66 | super(model);
67 | }
68 | }
69 |
70 | async ocr(imagePath: string) {
71 | const modelProvider = createModelProvider(this.model);
72 |
73 | let imageMessage: CoreUserMessage = {
74 | role: 'user',
75 | content: [
76 | {
77 | type: 'image',
78 | image: imagePath,
79 | },
80 | ],
81 | };
82 |
83 | if (ANTHROPIC_MODELS.includes(this.model)) {
84 | // read image and convert to base64
85 | const response = await fetch(imagePath);
86 | const imageBuffer = await response.arrayBuffer();
87 | const base64Image = Buffer.from(imageBuffer).toString('base64');
88 | imageMessage.content = [
89 | {
90 | type: 'image',
91 | image: base64Image,
92 | },
93 | ];
94 | }
95 |
96 | if (GOOGLE_GENERATIVE_AI_MODELS.includes(this.model)) {
97 | // gemini requires a text message in user messages
98 | imageMessage.content = [
99 | {
100 | type: 'text',
101 | text: ' ',
102 | },
103 | {
104 | type: 'file',
105 | data: imagePath,
106 | mimeType: 'image/png',
107 | },
108 | ];
109 | }
110 |
111 | const messages: CoreMessage[] = [
112 | { role: 'system', content: OCR_SYSTEM_PROMPT },
113 | imageMessage,
114 | ];
115 |
116 | const start = performance.now();
117 | const { text, usage: ocrUsage } = await generateText({
118 | model: modelProvider(this.model),
119 | messages,
120 | });
121 | const end = performance.now();
122 |
123 | const inputCost = calculateTokenCost(this.model, 'input', ocrUsage.promptTokens);
124 | const outputCost = calculateTokenCost(
125 | this.model,
126 | 'output',
127 | ocrUsage.completionTokens,
128 | );
129 |
130 | const usage = {
131 | duration: end - start,
132 | inputTokens: ocrUsage.promptTokens,
133 | outputTokens: ocrUsage.completionTokens,
134 | totalTokens: ocrUsage.totalTokens,
135 | inputCost,
136 | outputCost,
137 | totalCost: inputCost + outputCost,
138 | };
139 |
140 | return {
141 | text,
142 | usage,
143 | };
144 | }
145 |
146 | async extractFromText(text: string, schema: JsonSchema, imageBase64s?: string[]) {
147 | const modelProvider = createModelProvider(this.model);
148 |
149 | let imageMessages: CoreMessage[] = [];
150 | if (imageBase64s && imageBase64s.length > 0) {
151 | imageMessages = [
152 | {
153 | role: 'user',
154 | content: imageBase64s.map((base64) => ({
155 | type: 'image',
156 | image: base64,
157 | })),
158 | },
159 | ];
160 | }
161 | const messages: CoreMessage[] = [
162 | { role: 'system', content: JSON_EXTRACTION_SYSTEM_PROMPT },
163 | ...imageMessages,
164 | { role: 'user', content: text },
165 | ];
166 |
167 | const zodSchema = generateZodSchema(schema);
168 |
169 | const start = performance.now();
170 |
171 | let json, extractionUsage;
172 | try {
173 | const { object, usage } = await generateObject({
174 | model: modelProvider(this.model),
175 | messages,
176 | schema: zodSchema,
177 | temperature: 0,
178 | });
179 | json = object;
180 | extractionUsage = usage;
181 | } catch (error) {
182 | // if cause is AI_TypeValidationError, then still parse the json
183 | if (error instanceof NoObjectGeneratedError) {
184 | const errorText = error.text;
185 | json = JSON.parse(errorText);
186 | extractionUsage = error.usage;
187 | } else {
188 | throw error;
189 | }
190 | }
191 |
192 | const end = performance.now();
193 | const inputCost = calculateTokenCost(
194 | this.model,
195 | 'input',
196 | extractionUsage.promptTokens,
197 | );
198 | const outputCost = calculateTokenCost(
199 | this.model,
200 | 'output',
201 | extractionUsage.completionTokens,
202 | );
203 |
204 | const usage = {
205 | duration: end - start,
206 | inputTokens: extractionUsage.promptTokens,
207 | outputTokens: extractionUsage.completionTokens,
208 | totalTokens: extractionUsage.totalTokens,
209 | inputCost,
210 | outputCost,
211 | totalCost: inputCost + outputCost,
212 | };
213 |
214 | return {
215 | json,
216 | usage,
217 | };
218 | }
219 |
220 | async extractFromImage(imagePath: string, schema: JsonSchema) {
221 | const modelProvider = createModelProvider(this.model);
222 |
223 | let imageMessage: CoreUserMessage = {
224 | role: 'user',
225 | content: [
226 | {
227 | type: 'image',
228 | image: imagePath,
229 | },
230 | ],
231 | };
232 |
233 | if (ANTHROPIC_MODELS.includes(this.model)) {
234 | // read image and convert to base64
235 | const response = await fetch(imagePath);
236 | const imageBuffer = await response.arrayBuffer();
237 | const base64Image = Buffer.from(imageBuffer).toString('base64');
238 | imageMessage.content = [
239 | {
240 | type: 'image',
241 | image: base64Image,
242 | },
243 | ];
244 | }
245 |
246 | if (GOOGLE_GENERATIVE_AI_MODELS.includes(this.model)) {
247 | // gemini requires a text message in user messages
248 | imageMessage.content = [
249 | {
250 | type: 'text',
251 | text: ' ',
252 | },
253 | {
254 | type: 'file',
255 | data: imagePath,
256 | mimeType: 'image/png',
257 | },
258 | ];
259 | }
260 |
261 | const messages: CoreMessage[] = [
262 | { role: 'system', content: IMAGE_EXTRACTION_SYSTEM_PROMPT },
263 | imageMessage,
264 | ];
265 |
266 | const zodSchema = generateZodSchema(schema);
267 |
268 | const start = performance.now();
269 | const { object: json, usage: extractionUsage } = await generateObject({
270 | model: modelProvider(this.model),
271 | messages,
272 | schema: zodSchema,
273 | temperature: 0,
274 | });
275 | const end = performance.now();
276 |
277 | const inputCost = calculateTokenCost(
278 | this.model,
279 | 'input',
280 | extractionUsage.promptTokens,
281 | );
282 | const outputCost = calculateTokenCost(
283 | this.model,
284 | 'output',
285 | extractionUsage.completionTokens,
286 | );
287 |
288 | const usage = {
289 | duration: end - start,
290 | inputTokens: extractionUsage.promptTokens,
291 | outputTokens: extractionUsage.completionTokens,
292 | totalTokens: extractionUsage.totalTokens,
293 | inputCost,
294 | outputCost,
295 | totalCost: inputCost + outputCost,
296 | };
297 |
298 | return {
299 | json,
300 | usage,
301 | };
302 | }
303 | }
304 |
--------------------------------------------------------------------------------
/src/models/mistral.ts:
--------------------------------------------------------------------------------
1 | import { Mistral } from '@mistralai/mistralai';
2 |
3 | import { ModelProvider } from './base';
4 |
5 | // $1.00 per 1000 images
6 | const COST_PER_IMAGE = 0.001;
7 |
8 | export class MistralProvider extends ModelProvider {
9 | private client: Mistral;
10 |
11 | constructor() {
12 | super('mistral-ocr');
13 |
14 | const apiKey = process.env.MISTRAL_API_KEY;
15 |
16 | if (!apiKey) {
17 | throw new Error('Missing required Mistral API key');
18 | }
19 |
20 | this.client = new Mistral({
21 | apiKey,
22 | });
23 | }
24 |
25 | async ocr(imagePath: string) {
26 | try {
27 | const start = performance.now();
28 |
29 | const response = await this.client.ocr.process({
30 | model: 'mistral-ocr-latest',
31 | document: {
32 | imageUrl: imagePath,
33 | },
34 | includeImageBase64: true,
35 | });
36 |
37 | const text = response.pages.map((page) => page.markdown).join('\n');
38 | const end = performance.now();
39 |
40 | const imageBase64s = response.pages.flatMap((page) =>
41 | page.images.map((image) => image.imageBase64).filter((base64) => base64),
42 | );
43 |
44 | return {
45 | text,
46 | imageBase64s,
47 | usage: {
48 | duration: end - start,
49 | totalCost: COST_PER_IMAGE,
50 | },
51 | };
52 | } catch (error) {
53 | console.error('Mistral OCR Error:', error);
54 | throw error;
55 | }
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/src/models/omniAI.ts:
--------------------------------------------------------------------------------
1 | import axios from 'axios';
2 | import FormData from 'form-data';
3 |
4 | import { JsonSchema } from '../types';
5 | import { ModelProvider } from './base';
6 |
7 | // https://getomni.ai/pricing
8 | // 1 cent per page
9 | const COST_PER_PAGE = 0.01;
10 |
11 | interface ExtractResponse {
12 | ocr: {
13 | pages: Array<{
14 | page: number;
15 | content: string;
16 | }>;
17 | inputTokens: number;
18 | outputTokens: number;
19 | };
20 | extracted?: Record; // Only present when schema is provided
21 | }
22 |
23 | export const sendExtractRequest = async (
24 | imageUrl: string,
25 | schema?: JsonSchema,
26 | ): Promise => {
27 | const apiKey = process.env.OMNIAI_API_KEY;
28 | if (!apiKey) {
29 | throw new Error('Missing OMNIAI_API_KEY in .env');
30 | }
31 |
32 | const formData = new FormData();
33 | formData.append('url', imageUrl);
34 |
35 | // Add optional parameters if provided
36 | if (schema) {
37 | formData.append('schema', JSON.stringify(schema));
38 | }
39 |
40 | try {
41 | const response = await axios.post(
42 | `${process.env.OMNIAI_API_URL}/extract/sync`,
43 | formData,
44 | {
45 | headers: {
46 | 'x-api-key': apiKey,
47 | ...formData.getHeaders(),
48 | },
49 | },
50 | );
51 |
52 | return response.data.result;
53 | } catch (error) {
54 | if (axios.isAxiosError(error)) {
55 | throw new Error(
56 | `Failed to extract from image: ${JSON.stringify(error.response?.data) || JSON.stringify(error.message)}`,
57 | );
58 | }
59 | throw error;
60 | }
61 | };
62 |
63 | export class OmniAIProvider extends ModelProvider {
64 | constructor(model: string) {
65 | super(model);
66 | }
67 |
68 | async ocr(imagePath: string) {
69 | const start = performance.now();
70 | const response = await sendExtractRequest(imagePath);
71 | const end = performance.now();
72 |
73 | const text = response.ocr.pages.map((page) => page.content).join('\n');
74 | const inputTokens = response.ocr.inputTokens;
75 | const outputTokens = response.ocr.outputTokens;
76 |
77 | return {
78 | text,
79 | usage: {
80 | duration: end - start,
81 | inputTokens,
82 | outputTokens,
83 | totalTokens: inputTokens + outputTokens,
84 | totalCost: COST_PER_PAGE,
85 | },
86 | };
87 | }
88 |
89 | async extractFromImage(imagePath: string, schema?: JsonSchema) {
90 | const start = performance.now();
91 | const response = await sendExtractRequest(imagePath, schema);
92 | const end = performance.now();
93 |
94 | const inputToken = response.ocr.inputTokens;
95 | const outputToken = response.ocr.outputTokens;
96 |
97 | return {
98 | json: response.extracted || {},
99 | usage: {
100 | duration: end - start,
101 | inputTokens: inputToken,
102 | outputTokens: outputToken,
103 | totalTokens: inputToken + outputToken,
104 | totalCost: 0, // TODO: extraction cost is included in the OCR cost, 1 cent per page
105 | },
106 | };
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/src/models/openai.ts:
--------------------------------------------------------------------------------
1 | import OpenAI from 'openai';
2 | import sharp from 'sharp';
3 |
4 | import { ModelProvider } from './base';
5 | import { Usage } from '../types';
6 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared';
7 |
8 | export class OpenAIProvider extends ModelProvider {
9 | private client: OpenAI;
10 |
11 | constructor(model: string) {
12 | super(model);
13 |
14 | const apiKey = process.env.COMPATIBLE_OPENAI_API_KEY;
15 | const baseURL = process.env.COMPATIBLE_OPENAI_BASE_URL;
16 | if (!apiKey) {
17 | throw new Error('Missing required API key');
18 | }
19 |
20 | this.client = new OpenAI({
21 | baseURL,
22 | apiKey,
23 | });
24 | }
25 |
26 | async ocr(imagePath: string): Promise<{
27 | text: string;
28 | imageBase64s?: string[];
29 | usage: Usage;
30 | }> {
31 | const start = performance.now();
32 |
33 | const response = await this.client.chat.completions.create({
34 | model: this.model,
35 | messages: [
36 | {
37 | role: 'user',
38 | content: [
39 | { type: 'text', text: OCR_SYSTEM_PROMPT },
40 | {
41 | type: 'image_url',
42 | image_url: {
43 | url: imagePath,
44 | },
45 | },
46 | ],
47 | },
48 | ],
49 | });
50 |
51 | const end = performance.now();
52 |
53 | const inputTokens = response.usage?.prompt_tokens || 0;
54 | const outputTokens = response.usage?.completion_tokens || 0;
55 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
56 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens);
57 |
58 | return {
59 | text: response.choices[0].message.content || '',
60 | usage: {
61 | duration: end - start,
62 | inputTokens,
63 | outputTokens,
64 | totalTokens: inputTokens + outputTokens,
65 | inputCost,
66 | outputCost,
67 | totalCost: inputCost + outputCost,
68 | },
69 | };
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/src/models/openrouter.ts:
--------------------------------------------------------------------------------
1 | import OpenAI from 'openai';
2 | import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
3 |
4 | import { ModelProvider } from './base';
5 | import { Usage } from '../types';
6 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared';
7 |
8 | export class OpenRouterProvider extends ModelProvider {
9 | private client: OpenAI;
10 |
11 | constructor(model: string) {
12 | super(model);
13 |
14 | const apiKey = process.env.OPENROUTER_API_KEY;
15 | if (!apiKey) {
16 | throw new Error('Missing required OpenRouter API key');
17 | }
18 |
19 | this.client = new OpenAI({
20 | baseURL: 'https://openrouter.ai/api/v1',
21 | apiKey,
22 | defaultHeaders: {
23 | 'HTTP-Referer': process.env.SITE_URL || 'https://github.com/omni-ai/benchmark',
24 | 'X-Title': 'OmniAI OCR Benchmark',
25 | },
26 | });
27 | }
28 |
29 | async ocr(imagePath: string): Promise<{
30 | text: string;
31 | imageBase64s?: string[];
32 | usage: Usage;
33 | }> {
34 | const start = performance.now();
35 |
36 | const messages: ChatCompletionMessageParam[] = [
37 | {
38 | role: 'user',
39 | content: [
40 | { type: 'text', text: OCR_SYSTEM_PROMPT },
41 | { type: 'image_url', image_url: { url: imagePath } },
42 | ],
43 | },
44 | ];
45 |
46 | const response = await this.client.chat.completions.create({
47 | model: this.model,
48 | messages,
49 | });
50 |
51 | const end = performance.now();
52 |
53 | const inputTokens = response.usage?.prompt_tokens || 0;
54 | const outputTokens = response.usage?.completion_tokens || 0;
55 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
56 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens);
57 |
58 | return {
59 | text: response.choices[0].message.content || '',
60 | usage: {
61 | duration: end - start,
62 | inputTokens,
63 | outputTokens,
64 | totalTokens: inputTokens + outputTokens,
65 | inputCost,
66 | outputCost,
67 | totalCost: inputCost + outputCost,
68 | },
69 | };
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/src/models/registry.ts:
--------------------------------------------------------------------------------
1 | import { AzureDocumentIntelligenceProvider } from './azure';
2 | import { AWSTextractProvider } from './awsTextract';
3 | import { DashscopeProvider } from './dashscope';
4 | import { GeminiProvider } from './gemini';
5 | import { GoogleDocumentAIProvider } from './googleDocumentAI';
6 | import { LLMProvider } from './llm';
7 | import { MistralProvider } from './mistral';
8 | import { OmniAIProvider } from './omniAI';
9 | import { OpenAIProvider } from './openai';
10 | import { OpenRouterProvider } from './openrouter';
11 | import { TogetherProvider } from './togetherai';
12 | import { UnstructuredProvider } from './unstructured';
13 | import { ZeroxProvider } from './zerox';
14 |
15 | export const OPENAI_MODELS = [
16 | 'chatgpt-4o-latest',
17 | 'gpt-4o-mini',
18 | 'gpt-4o',
19 | 'o1',
20 | 'o1-mini',
21 | 'o3-mini',
22 | 'o4-mini',
23 | 'gpt-4o-2024-11-20',
24 | 'gpt-4.1',
25 | 'gpt-4.1-mini',
26 | 'gpt-4.1-nano',
27 | ];
28 | export const AZURE_OPENAI_MODELS = [
29 | 'azure-gpt-4o-mini',
30 | 'azure-gpt-4o',
31 | 'azure-o1',
32 | 'azure-o1-mini',
33 | 'azure-o3-mini',
34 | 'azure-gpt-4.1',
35 | 'azure-gpt-4.1-mini',
36 | 'azure-gpt-4.1-nano',
37 | ];
38 | export const ANTHROPIC_MODELS = [
39 | 'claude-3-5-sonnet-20241022',
40 | 'claude-3-7-sonnet-20250219',
41 | 'claude-sonnet-4-20250514',
42 | 'claude-opus-4-20250514',
43 | ];
44 | export const DEEPSEEK_MODELS = ['deepseek-chat'];
45 | export const GOOGLE_GENERATIVE_AI_MODELS = [
46 | 'gemini-1.5-pro',
47 | 'gemini-1.5-flash',
48 | 'gemini-2.0-flash-001',
49 | 'gemini-2.5-pro-exp-03-25',
50 | 'gemini-2.5-pro-preview-03-25',
51 | 'gemini-2.5-flash-preview-05-20',
52 | ];
53 | export const OPENROUTER_MODELS = [
54 | 'qwen/qwen2.5-vl-32b-instruct:free',
55 | 'qwen/qwen-2.5-vl-72b-instruct',
56 | // 'google/gemma-3-27b-it',
57 | 'deepseek/deepseek-chat-v3-0324',
58 | 'meta-llama/llama-3.2-11b-vision-instruct',
59 | 'meta-llama/llama-3.2-90b-vision-instruct',
60 | ];
61 | export const TOGETHER_MODELS = [
62 | 'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo',
63 | 'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo',
64 | 'meta-llama/Llama-4-Scout-17B-16E-Instruct',
65 | 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8',
66 | ];
67 | export const FINETUNED_MODELS = [];
68 |
69 | export const MODEL_PROVIDERS = {
70 | anthropic: {
71 | models: ANTHROPIC_MODELS,
72 | provider: LLMProvider,
73 | },
74 | aws: {
75 | models: ['aws-textract'],
76 | provider: AWSTextractProvider,
77 | },
78 | azureOpenai: {
79 | models: AZURE_OPENAI_MODELS,
80 | provider: LLMProvider,
81 | },
82 | gemini: {
83 | models: GOOGLE_GENERATIVE_AI_MODELS,
84 | provider: GeminiProvider,
85 | },
86 | google: {
87 | models: ['google-document-ai'],
88 | provider: GoogleDocumentAIProvider,
89 | },
90 | deepseek: {
91 | models: DEEPSEEK_MODELS,
92 | provider: LLMProvider,
93 | },
94 | azure: {
95 | models: ['azure-document-intelligence'],
96 | provider: AzureDocumentIntelligenceProvider,
97 | },
98 | mistral: {
99 | models: ['mistral-ocr'],
100 | provider: MistralProvider,
101 | },
102 | omniai: {
103 | models: ['omniai'],
104 | provider: OmniAIProvider,
105 | },
106 | openai: {
107 | models: OPENAI_MODELS,
108 | provider: LLMProvider,
109 | },
110 | openaiBase: {
111 | models: ['google/gemma-3-27b-it'],
112 | provider: OpenAIProvider,
113 | },
114 | openrouter: {
115 | models: OPENROUTER_MODELS,
116 | provider: OpenRouterProvider,
117 | },
118 | together: {
119 | models: TOGETHER_MODELS,
120 | provider: TogetherProvider,
121 | },
122 | dashscope: {
123 | models: ['qwen2.5-vl-32b-instruct', 'qwen2.5-vl-72b-instruct'],
124 | provider: DashscopeProvider,
125 | },
126 | unstructured: {
127 | models: ['unstructured'],
128 | provider: UnstructuredProvider,
129 | },
130 | zerox: {
131 | models: ['zerox'],
132 | provider: ZeroxProvider,
133 | },
134 | groundTruth: {
135 | models: ['ground-truth'],
136 | provider: undefined,
137 | },
138 | };
139 |
140 | export const getModelProvider = (model: string) => {
141 | // Include Openai FT models
142 | MODEL_PROVIDERS['openaiFt'] = {
143 | models: FINETUNED_MODELS,
144 | provider: LLMProvider,
145 | };
146 | const foundProvider = Object.values(MODEL_PROVIDERS).find(
147 | (group) => group.models && group.models.includes(model),
148 | );
149 |
150 | if (foundProvider) {
151 | if (model === 'ground-truth') {
152 | return undefined;
153 | }
154 | const provider = new foundProvider.provider(model);
155 | return provider;
156 | }
157 |
158 | throw new Error(`Model '${model}' is not supported.`);
159 | };
160 |
--------------------------------------------------------------------------------
/src/models/shared/index.ts:
--------------------------------------------------------------------------------
1 | export * from './prompt';
2 | export * from './tokenCost';
3 |
--------------------------------------------------------------------------------
/src/models/shared/prompt.ts:
--------------------------------------------------------------------------------
1 | export const OCR_SYSTEM_PROMPT = `
2 | Convert the following document to markdown.
3 | Return only the markdown with no explanation text. Do not include delimiters like \`\`\`markdown or \`\`\`html.
4 |
5 | RULES:
6 | - You must include all information on the page. Do not exclude headers, footers, charts, infographics, or subtext.
7 | - Return tables in an HTML format.
8 | - Logos should be wrapped in brackets. Ex: Coca-Cola
9 | - Watermarks should be wrapped in brackets. Ex: OFFICIAL COPY
10 | - Page numbers should be wrapped in brackets. Ex: 14 or 9/22
11 | - Prefer using ☐ and ☑ for check boxes.
12 | `;
13 |
14 | export const JSON_EXTRACTION_SYSTEM_PROMPT = `
15 | Extract data from the following document based on the JSON schema.
16 | Return null if the document does not contain information relevant to schema.
17 | Return only the JSON with no explanation text.
18 | `;
19 |
20 | export const IMAGE_EXTRACTION_SYSTEM_PROMPT = `
21 | Extract the following JSON schema from the image.
22 | Return only the JSON with no explanation text.
23 | `;
24 |
--------------------------------------------------------------------------------
/src/models/shared/tokenCost.ts:
--------------------------------------------------------------------------------
1 | import { FINETUNED_MODELS } from '../registry';
2 |
3 | export const TOKEN_COST = {
4 | 'azure-gpt-4o': {
5 | input: 2.5,
6 | output: 10,
7 | },
8 | 'azure-gpt-4o-mini': {
9 | input: 0.15,
10 | output: 0.6,
11 | },
12 | 'azure-gpt-4.1': {
13 | input: 2,
14 | output: 8,
15 | },
16 | 'azure-gpt-4.1-mini': {
17 | input: 0.4,
18 | output: 1.6,
19 | },
20 | 'azure-gpt-4.1-nano': {
21 | input: 0.1,
22 | output: 0.4,
23 | },
24 | 'azure-o1': {
25 | input: 15,
26 | output: 60,
27 | },
28 | 'azure-o1-mini': {
29 | input: 1.1,
30 | output: 4.4,
31 | },
32 | 'azure-o3-mini': {
33 | input: 1.1,
34 | output: 4.4,
35 | },
36 | 'claude-3-5-sonnet-20241022': {
37 | input: 3,
38 | output: 15,
39 | },
40 | 'claude-3-7-sonnet-20250219': {
41 | input: 3,
42 | output: 15,
43 | },
44 | 'claude-sonnet-4-20250514': {
45 | input: 3,
46 | output: 15,
47 | },
48 | 'claude-opus-4-20250514': {
49 | input: 15,
50 | output: 75,
51 | },
52 | 'deepseek-chat': {
53 | input: 0.14,
54 | output: 0.28,
55 | },
56 | 'gemini-1.5-pro': {
57 | input: 1.25,
58 | output: 5,
59 | },
60 | 'gemini-1.5-flash': {
61 | input: 0.075,
62 | output: 0.3,
63 | },
64 | 'gemini-2.0-flash-001': {
65 | input: 0.1,
66 | output: 0.4,
67 | },
68 | 'gemini-2.5-pro-exp-03-25': {
69 | input: 1.25,
70 | output: 10,
71 | },
72 | 'gemini-2.5-pro-preview-03-25': {
73 | input: 1.25,
74 | output: 10,
75 | },
76 | 'gemini-2.5-flash-preview-05-20': {
77 | input: 0.15,
78 | output: 0.6,
79 | },
80 | 'gpt-4o': {
81 | input: 2.5,
82 | output: 10,
83 | },
84 | 'gpt-4o-2024-11-20': {
85 | input: 2.5,
86 | output: 10,
87 | },
88 | 'gpt-4o-mini': {
89 | input: 0.15,
90 | output: 0.6,
91 | },
92 | 'gpt-4.1': {
93 | input: 2,
94 | output: 8,
95 | },
96 | 'gpt-4.1-mini': {
97 | input: 0.4,
98 | output: 1.6,
99 | },
100 | 'gpt-4.1-nano': {
101 | input: 0.1,
102 | output: 0.4,
103 | },
104 |
105 | o1: {
106 | input: 15,
107 | output: 60,
108 | },
109 | 'o1-mini': {
110 | input: 1.1,
111 | output: 4.4,
112 | },
113 | 'o3-mini': {
114 | input: 1.1,
115 | output: 4.4,
116 | },
117 | 'o4-mini': {
118 | input: 1.1,
119 | output: 4.4,
120 | },
121 | 'chatgpt-4o-latest': {
122 | input: 2.5,
123 | output: 10,
124 | },
125 | zerox: {
126 | input: 2.5,
127 | output: 10,
128 | },
129 | 'qwen2.5-vl-32b-instruct': {
130 | input: 0, // TODO: Add cost
131 | output: 0, // TODO: Add cost
132 | },
133 | 'qwen2.5-vl-72b-instruct': {
134 | input: 0, // TODO: Add cost
135 | output: 0, // TODO: Add cost
136 | },
137 | 'google/gemma-3-27b-it': {
138 | input: 0.1,
139 | output: 0.2,
140 | },
141 | 'deepseek/deepseek-chat-v3-0324': {
142 | input: 0.27,
143 | output: 1.1,
144 | },
145 | 'meta-llama/llama-3.2-11b-vision-instruct': {
146 | input: 0.055,
147 | output: 0.055,
148 | },
149 | 'meta-llama/llama-3.2-90b-vision-instruct': {
150 | input: 0.8,
151 | output: 1.6,
152 | },
153 | 'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo': {
154 | input: 0.18,
155 | output: 0.18,
156 | },
157 | 'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo': {
158 | input: 1.2,
159 | output: 1.2,
160 | },
161 | 'meta-llama/Llama-4-Scout-17B-16E-Instruct': {
162 | input: 0.18,
163 | output: 0.59,
164 | },
165 | 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8': {
166 | input: 0.27,
167 | output: 0.85,
168 | },
169 | };
170 |
171 | export const calculateTokenCost = (
172 | model: string,
173 | type: 'input' | 'output',
174 | tokens: number,
175 | ): number => {
176 | const fineTuneCost = Object.fromEntries(
177 | FINETUNED_MODELS.map((el) => [el, { input: 3.75, output: 15.0 }]),
178 | );
179 | const combinedCost = { ...TOKEN_COST, ...fineTuneCost };
180 | const modelInfo = combinedCost[model];
181 | if (!modelInfo) throw new Error(`Model '${model}' is not supported.`);
182 | return (modelInfo[type] * tokens) / 1_000_000;
183 | };
184 |
--------------------------------------------------------------------------------
/src/models/togetherai.ts:
--------------------------------------------------------------------------------
1 | import Together from 'together-ai';
2 |
3 | import { ModelProvider } from './base';
4 | import { Usage } from '../types';
5 | import { calculateTokenCost, OCR_SYSTEM_PROMPT } from './shared';
6 |
7 | export class TogetherProvider extends ModelProvider {
8 | private client: Together;
9 |
10 | constructor(model: string) {
11 | super(model);
12 |
13 | const apiKey = process.env.TOGETHER_API_KEY;
14 | if (!apiKey) {
15 | throw new Error('Missing required Together API key');
16 | }
17 |
18 | this.client = new Together();
19 | }
20 |
21 | async ocr(imagePath: string): Promise<{
22 | text: string;
23 | imageBase64s?: string[];
24 | usage: Usage;
25 | }> {
26 | const start = performance.now();
27 |
28 | const response = await this.client.chat.completions.create({
29 | model: this.model,
30 | messages: [
31 | {
32 | role: 'user',
33 | content: [
34 | { type: 'text', text: OCR_SYSTEM_PROMPT },
35 | { type: 'image_url', image_url: { url: imagePath } },
36 | ],
37 | },
38 | ],
39 | });
40 |
41 | const end = performance.now();
42 |
43 | const inputTokens = response.usage?.prompt_tokens || 0;
44 | const outputTokens = response.usage?.completion_tokens || 0;
45 | const inputCost = calculateTokenCost(this.model, 'input', inputTokens);
46 | const outputCost = calculateTokenCost(this.model, 'output', outputTokens);
47 |
48 | return {
49 | text: response.choices[0].message.content || '',
50 | usage: {
51 | duration: end - start,
52 | inputTokens,
53 | outputTokens,
54 | totalTokens: inputTokens + outputTokens,
55 | inputCost,
56 | outputCost,
57 | totalCost: inputCost + outputCost,
58 | },
59 | };
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/src/models/unstructured.ts:
--------------------------------------------------------------------------------
1 | import axios from 'axios';
2 | import { ModelProvider } from './base';
3 | import { htmlToMarkdown } from '../utils';
4 |
// Per-page cost for the Unstructured API.
// NOTE(review): the value below works out to $20 per 1,000 pages, but the
// original comment claimed "Fast Pipeline: $1 per 1,000 pages" — confirm
// which pricing tier this benchmark actually uses.
const COST_PER_PAGE = 20 / 1000;
7 |
8 | enum UnstructuredTypes {
9 | Title = 'Title',
10 | Header = 'Header',
11 | NarrativeText = 'NarrativeText',
12 | }
13 |
14 | interface UnstructuredElement {
15 | text: string;
16 | type: UnstructuredTypes;
17 | metadata: {
18 | filename: string;
19 | filetype: string;
20 | languages: string[];
21 | page_number: number;
22 | parent_id?: string;
23 | text_as_html?: string;
24 | };
25 | element_id: string;
26 | }
27 |
28 | export class UnstructuredProvider extends ModelProvider {
29 | constructor() {
30 | super('unstructured');
31 | }
32 |
33 | async ocr(imagePath: string) {
34 | try {
35 | const start = performance.now();
36 |
37 | const fileName = imagePath.split('/').pop()[0];
38 | const formData = new FormData();
39 | const response = await axios.get(imagePath, { responseType: 'arraybuffer' });
40 | const fileData = Buffer.from(response.data);
41 |
42 | formData.append('files', new Blob([fileData]), fileName);
43 |
44 | const apiResponse = await axios.post(
45 | 'https://api.unstructuredapp.io/general/v0/general',
46 | formData,
47 | {
48 | headers: {
49 | accept: 'application/json',
50 | 'unstructured-api-key': process.env.UNSTRUCTURED_API_KEY,
51 | },
52 | },
53 | );
54 |
55 | const unstructuredResult = apiResponse.data as UnstructuredElement[];
56 |
57 | // Format the result
58 | let markdown = '';
59 | if (Array.isArray(unstructuredResult)) {
60 | markdown = unstructuredResult.reduce((acc, el) => {
61 | if (el.type === UnstructuredTypes.Title) {
62 | acc += `\n### ${el.text}\n`;
63 | } else if (el.type === UnstructuredTypes.NarrativeText) {
64 | acc += `\n${el.text}\n`;
65 | } else if (el.metadata?.text_as_html) {
66 | acc += htmlToMarkdown(el.metadata.text_as_html) + '\n';
67 | } else if (el.text) {
68 | acc += el.text + '\n';
69 | }
70 | return acc;
71 | }, '');
72 | } else {
73 | markdown = JSON.stringify(unstructuredResult);
74 | }
75 |
76 | const end = performance.now();
77 |
78 | return {
79 | text: markdown,
80 | usage: {
81 | duration: end - start,
82 | totalCost: COST_PER_PAGE, // the input is always 1 page.
83 | },
84 | };
85 | } catch (error) {
86 | console.error('Unstructured Error:', error);
87 | throw error;
88 | }
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/src/models/zerox.ts:
--------------------------------------------------------------------------------
1 | import { zerox } from 'zerox';
2 |
3 | import { ModelProvider } from './base';
4 | import { calculateTokenCost } from './shared';
5 |
6 | export class ZeroxProvider extends ModelProvider {
7 | constructor() {
8 | super('zerox');
9 | }
10 |
11 | async ocr(imagePath: string) {
12 | const startTime = performance.now();
13 |
14 | const result = await zerox({
15 | filePath: imagePath,
16 | openaiAPIKey: process.env.OPENAI_API_KEY,
17 | });
18 |
19 | const endTime = performance.now();
20 |
21 | const text = result.pages.map((page) => page.content).join('\n');
22 |
23 | const inputCost = calculateTokenCost(this.model, 'input', result.inputTokens);
24 | const outputCost = calculateTokenCost(this.model, 'output', result.outputTokens);
25 |
26 | const usage = {
27 | duration: endTime - startTime,
28 | inputTokens: result.inputTokens,
29 | outputTokens: result.outputTokens,
30 | totalTokens: result.inputTokens + result.outputTokens,
31 | inputCost,
32 | outputCost,
33 | totalCost: inputCost + outputCost,
34 | };
35 |
36 | return {
37 | text,
38 | usage,
39 | };
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/src/types/data.ts:
--------------------------------------------------------------------------------
1 | import { Usage } from './model';
2 | import { AccuracyResult } from '../evaluation';
3 |
4 | export interface Input {
5 | imageUrl: string;
6 | metadata: Metadata;
7 | jsonSchema: JsonSchema;
8 | trueJsonOutput: Record;
9 | trueMarkdownOutput: string;
10 | }
11 |
12 | export interface Metadata {
13 | orientation?: number;
14 | documentQuality?: string;
15 | resolution?: number[];
16 | language?: string;
17 | }
18 |
19 | export interface JsonSchema {
20 | type: string;
21 | description?: string;
22 | properties?: Record;
23 | items?: JsonSchema;
24 | required?: string[];
25 | }
26 |
27 | export interface Result {
28 | fileUrl: string;
29 | metadata: Metadata;
30 | ocrModel: string;
31 | extractionModel: string;
32 | jsonSchema: JsonSchema;
33 | directImageExtraction?: boolean;
34 | trueMarkdown: string;
35 | trueJson: Record;
36 | predictedMarkdown?: string;
37 | predictedJson?: Record;
38 | levenshteinDistance?: number;
39 | jsonAccuracy?: number;
40 | jsonDiff?: Record;
41 | fullJsonDiff?: Record;
42 | jsonDiffStats?: Record;
43 | jsonAccuracyResult?: AccuracyResult;
44 | usage?: Usage;
45 | error?: any;
46 | }
47 |
--------------------------------------------------------------------------------
/src/types/index.ts:
--------------------------------------------------------------------------------
1 | export * from './data';
2 | export * from './model';
3 |
--------------------------------------------------------------------------------
/src/types/model.ts:
--------------------------------------------------------------------------------
1 | export interface ExtractionResult {
2 | json?: Record;
3 | text?: string;
4 | usage: Usage;
5 | }
6 |
7 | export interface Usage {
8 | duration?: number;
9 | inputTokens?: number;
10 | outputTokens?: number;
11 | totalTokens?: number;
12 | inputCost?: number;
13 | outputCost?: number;
14 | totalCost?: number;
15 | ocr?: Usage;
16 | extraction?: Usage;
17 | }
18 |
--------------------------------------------------------------------------------
/src/utils/dataLoader.ts:
--------------------------------------------------------------------------------
1 | import { Input } from '../types';
2 | import { Pool } from 'pg';
3 | import fs from 'fs';
4 | import path from 'path';
5 |
6 | // Pull JSON files from local folder
7 | export const loadLocalData = (folder: string): Input[] => {
8 | const files = fs.readdirSync(folder).filter((file) => file.endsWith('.json'));
9 | const data = files.map((file) => {
10 | const filePath = path.join(folder, file);
11 | const fileContent = fs.readFileSync(filePath, 'utf8');
12 | return JSON.parse(fileContent);
13 | });
14 |
15 | return data;
16 | };
17 |
18 | // Query results from the documents table.
19 | export const loadFromDb = async (): Promise => {
20 | const pool = new Pool({
21 | connectionString: process.env.DATABASE_URL,
22 | ssl: { rejectUnauthorized: false },
23 | });
24 |
25 | try {
26 | const result = await pool.query(`
27 | SELECT
28 | url AS "imageUrl",
29 | config AS "metadata",
30 | schema AS "jsonSchema",
31 | extracted_json AS "trueJsonOutput",
32 | markdown AS "trueMarkdownOutput"
33 | FROM documents
34 | WHERE include_in_training = FALSE
35 | ORDER BY created_at
36 | LIMIT 1000;
37 | `);
38 |
39 | return result.rows as Input[];
40 | } catch (error) {
41 | console.error('Error querying data from PostgreSQL:', error);
42 | throw error;
43 | } finally {
44 | await pool.end();
45 | }
46 | };
47 |
--------------------------------------------------------------------------------
/src/utils/db.ts:
--------------------------------------------------------------------------------
1 | import { PrismaClient } from '@prisma/client';
2 | import { Result } from '../types';
3 |
4 | const prisma = new PrismaClient();
5 |
/**
 * Insert a new benchmark run row with status 'running'.
 *
 * @param timestamp      Identifier timestamp for this run.
 * @param modelsConfig   Model-matrix configuration; stored under a `models` key.
 * @param totalDocuments Number of documents this run will process.
 * @returns The created BenchmarkRun record.
 */
export async function createBenchmarkRun(
  timestamp: string,
  modelsConfig: any,
  totalDocuments: number,
) {
  return prisma.benchmarkRun.create({
    data: {
      timestamp,
      status: 'running',
      modelsConfig: { models: modelsConfig },
      totalDocuments,
    },
  });
}
20 |
/**
 * Persist one per-document benchmark result under a run.
 *
 * Optional fields fall back to neutral defaults ('' / false); `error` is
 * JSON-stringified so arbitrary error shapes fit a string column.
 *
 * @param runId  ID of the parent BenchmarkRun.
 * @param result The computed result for a single document/model pair.
 */
export async function saveResult(runId: string, result: Result) {
  return prisma.benchmarkResult.create({
    data: {
      benchmarkRunId: runId,
      fileUrl: result.fileUrl,
      metadata: result.metadata as any,
      ocrModel: result.ocrModel,
      extractionModel: result.extractionModel || '',
      jsonSchema: result.jsonSchema as any,
      directImageExtraction: result.directImageExtraction || false,
      trueMarkdown: result.trueMarkdown,
      trueJson: result.trueJson,
      predictedMarkdown: result.predictedMarkdown,
      predictedJson: result.predictedJson,
      levenshteinDistance: result.levenshteinDistance,
      jsonAccuracy: result.jsonAccuracy,
      jsonDiff: result.jsonDiff,
      fullJsonDiff: result.fullJsonDiff,
      jsonDiffStats: result.jsonDiffStats,
      jsonAccuracyResult: result.jsonAccuracyResult as any,
      usage: result.usage as any,
      error: JSON.stringify(result.error),
    },
  });
}
46 |
/**
 * Mark a benchmark run as finished: 'failed' when an error message is given,
 * 'completed' otherwise; records the completion time either way.
 */
export async function completeBenchmarkRun(runId: string, error?: string) {
  return prisma.benchmarkRun.update({
    where: { id: runId },
    data: {
      status: error ? 'failed' : 'completed',
      completedAt: new Date(),
      error,
    },
  });
}
57 |
// Clean up function: close the shared Prisma connection so the process can exit.
export async function disconnect() {
  await prisma.$disconnect();
}
62 |
--------------------------------------------------------------------------------
/src/utils/file.ts:
--------------------------------------------------------------------------------
1 | export const getMimeType = (url: string): string => {
2 | const extension = url.split('.').pop()?.toLowerCase();
3 | switch (extension) {
4 | case 'pdf':
5 | return 'application/pdf';
6 | case 'png':
7 | return 'image/png';
8 | case 'jpg':
9 | case 'jpeg':
10 | return 'image/jpeg';
11 | case 'tiff':
12 | case 'tif':
13 | return 'image/tiff';
14 | case 'gif':
15 | return 'image/gif';
16 | case 'bmp':
17 | return 'image/bmp';
18 | default:
19 | return 'image/png'; // default to PNG
20 | }
21 | };
22 |
--------------------------------------------------------------------------------
/src/utils/htmlToMarkdown.ts:
--------------------------------------------------------------------------------
1 | import TurndownService from 'turndown';
2 |
3 | export function htmlToMarkdown(html: string): string {
4 | const turndownService = new TurndownService({});
5 |
6 | turndownService.addRule('strong', {
7 | filter: ['strong', 'b'],
8 | replacement: (content) => `**${content}**`,
9 | });
10 |
11 | // Convert HTML to Markdown
12 | return turndownService.turndown(html);
13 | }
14 |
--------------------------------------------------------------------------------
/src/utils/index.ts:
--------------------------------------------------------------------------------
1 | export * from './dataLoader';
2 | export * from './db';
3 | export * from './file';
4 | export * from './htmlToMarkdown';
5 | export * from './logs';
6 | export * from './zod';
7 |
--------------------------------------------------------------------------------
/src/utils/logs.ts:
--------------------------------------------------------------------------------
1 | import fs from 'fs';
2 | import path from 'path';
3 |
4 | import { ExtractionResult } from '../types';
5 |
6 | export const createResultFolder = (folderName: string) => {
7 | // check if results folder exists
8 | const resultsFolder = path.join(__dirname, '..', '..', 'results');
9 | if (!fs.existsSync(resultsFolder)) {
10 | fs.mkdirSync(resultsFolder, { recursive: true });
11 | }
12 |
13 | const folderPath = path.join(resultsFolder, folderName);
14 | fs.mkdirSync(folderPath, { recursive: true });
15 | return folderPath;
16 | };
17 |
18 | export const writeToFile = (filePath: string, content: any) => {
19 | fs.writeFileSync(filePath, JSON.stringify(content, null, 2));
20 | };
21 |
22 | export const writeResultToFile = (
23 | outputDir: string,
24 | fileName: string,
25 | result: ExtractionResult,
26 | ) => {
27 | fs.writeFileSync(path.join(outputDir, fileName), JSON.stringify(result, null, 2));
28 | };
29 |
--------------------------------------------------------------------------------
/src/utils/zod.ts:
--------------------------------------------------------------------------------
1 | import { z } from 'zod';
2 |
// Maps JSON-schema primitive type names to zod builders.
// `array` and `object` are factories (they need an item/properties schema
// passed in); the remaining entries are ready-made zod types.
const zodTypeMapping = {
  array: (itemSchema: any) => z.array(itemSchema),
  boolean: z.boolean(),
  integer: z.number().int(),
  number: z.number(),
  object: (properties: any) => z.object(properties).strict(),
  string: z.string(),
};
11 |
12 | export const generateZodSchema = (schemaDef: any): z.ZodObject => {
13 | const properties: Record = {};
14 |
15 | for (const [key, value] of Object.entries(schemaDef.properties) as any) {
16 | let zodType;
17 |
18 | if (value.enum && Array.isArray(value.enum) && value.enum.length > 0) {
19 | zodType = z.enum(value.enum as [string, ...string[]]);
20 | } else {
21 | zodType = zodTypeMapping[value.type];
22 | }
23 |
24 | if (value.type === 'array' && value.items.type === 'object') {
25 | properties[key] = zodType(generateZodSchema(value.items));
26 | } else if (value.type === 'array' && value.items.type !== 'object') {
27 | properties[key] = zodType(zodTypeMapping[value.items.type]);
28 | } else if (value.type === 'object') {
29 | properties[key] = generateZodSchema(value);
30 | } else {
31 | properties[key] = zodType;
32 | }
33 |
34 | // Make properties nullable by default
35 | properties[key] = properties[key].nullable();
36 |
37 | if (value.description) {
38 | properties[key] = properties?.[key]?.describe(value?.description);
39 | }
40 | }
41 |
42 | return z.object(properties).strict();
43 | };
44 |
--------------------------------------------------------------------------------
/tests/evaluation/json.test.ts:
--------------------------------------------------------------------------------
1 | import {
2 | calculateJsonAccuracy,
3 | countTotalFields,
4 | countChanges,
5 | } from '../../src/evaluation/json';
6 |
// Tests for countTotalFields: per these cases it counts leaf values
// (primitives and null) in a nested object/array structure — containers
// themselves are not counted — and skips keys carrying diff metadata
// suffixes (__deleted / __added).
describe('countTotalFields', () => {
  it('counts fields in nested objects including array elements', () => {
    // leaves: a, b.c, b.d[0], b.d[1].e
    const obj = { a: 1, b: { c: 2, d: [3, { e: 4 }] } };
    expect(countTotalFields(obj)).toBe(4);
  });

  it('counts array elements as individual fields', () => {
    // leaves: a[0], a[1], a[2], b, c
    const obj = { a: [1, 2, 3], b: 'test', c: true };
    expect(countTotalFields(obj)).toBe(5);
  });

  it('counts nested objects within arrays', () => {
    // leaves: a[0].b, a[1].c, d, e
    const obj = { a: [{ b: 1 }, { c: 2 }], d: 'test', e: true };
    expect(countTotalFields(obj)).toBe(4);
  });

  it('includes null values in field count', () => {
    // null is a countable leaf: a, b.c, d
    const obj = { a: null, b: { c: null }, d: 'test' };
    expect(countTotalFields(obj)).toBe(3);
  });

  it('excludes fields with __diff metadata suffixes', () => {
    // only a and d.e count; b__deleted / c__added are diff markers
    const obj = {
      a: 1,
      b__deleted: true,
      c__added: 'test',
      d: { e: 2 },
    };
    expect(countTotalFields(obj)).toBe(2);
  });
});
38 |
// Tests for calculateJsonAccuracy: per these cases the score is the fraction
// of leaf fields in `actual` that match in `predicted`, rounded to 4 decimal
// places, with array elements matched order-independently.
describe('calculateJsonAccuracy', () => {
  it('returns 0.5 when half of the fields match', () => {
    const actual = { a: 1, b: 2 };
    const predicted = { a: 1, b: 3 };
    const result = calculateJsonAccuracy(actual, predicted);
    expect(result.score).toBe(0.5);
  });

  it('handles nested objects in accuracy calculation', () => {
    // 3 of 4 leaves match (b.e differs)
    const actual = { a: 1, b: { c: 2, d: 4, e: 4 } };
    const predicted = { a: 1, b: { c: 2, d: 4, e: 5 } };
    const result = calculateJsonAccuracy(actual, predicted);
    expect(result.score).toBe(0.75);
  });

  it('calculates accuracy for nested arrays and objects', () => {
    // 3 of 6 leaves match (e, f[1] differ; f counts per element)
    const actual = { a: 1, b: [{ c: 2, d: 4, e: 4, f: [2, 9] }] };
    const predicted = { a: 1, b: [{ c: 2, d: 4, e: 5, f: [2, 3] }] };
    const result = calculateJsonAccuracy(actual, predicted);
    expect(result.score).toBe(0.5);
  });

  it('considers array elements matching regardless of order', () => {
    const actual = {
      a: 1,
      b: [
        { c: 1, d: 2 },
        { c: 3, d: 4 },
      ],
    };
    const predicted = {
      a: 1,
      b: [
        { c: 3, d: 4 },
        { c: 1, d: 2 },
      ],
    };
    const result = calculateJsonAccuracy(actual, predicted);
    expect(result.score).toBe(1);
  });

  it('counts all array elements as unmatched when predicted array is null', () => {
    // only `a` matches; each of the 3 array leaves counts as unmatched
    const actual = { a: 1, b: [1, 2, 3] };
    const predicted = { a: 1, b: null };
    const result = calculateJsonAccuracy(actual, predicted);
    expect(result.score).toBe(1 / 4);
  });

  it('counts all nested array objects as unmatched when predicted is null', () => {
    // 1 of 6 leaves matches; score is rounded to 4 decimals
    const actual = { a: 1, b: [{ c: 1, d: 1 }, { c: 2 }, { c: 3, e: 4 }] };
    const predicted = { a: 1, b: null };
    const result = calculateJsonAccuracy(actual, predicted);
    expect(result.score).toBe(Number((1 / 6).toFixed(4)));
  });

  it('considers null fields in predicted object as partial matches', () => {
    // a and b.c match; d's two leaves (e, f) are unmatched against null
    const actual = { a: 1, b: { c: 1, d: { e: 1, f: 2 } } };
    const predicted = { a: 1, b: { c: 1, d: null } };
    const result = calculateJsonAccuracy(actual, predicted);
    expect(result.score).toBe(0.5);
  });

  // Symmetric null-vs-value cases: a null on either side counts as one (or,
  // for containers, leaf-count many) mismatched fields.
  describe('null value comparisons', () => {
    it('handles actual null to predicted value comparison', () => {
      const actual = { a: [{ b: 1, c: null }] };
      const predicted = { a: [{ b: 1, c: 2 }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.5);
    });

    it('handles actual null to predicted object comparison', () => {
      const actual = { a: [{ b: 1, c: null, f: 4 }] };
      const predicted = { a: [{ b: 1, c: { d: 2 }, f: 4 }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.6667);
    });

    it('handles actual null to predicted complex object comparison', () => {
      const actual = { a: [{ b: 1, c: null, f: 4 }] };
      const predicted = { a: [{ b: 1, c: { d: 2, e: 3 }, f: 4 }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.3333);
    });

    it('handles actual null to predicted array comparison', () => {
      const actual = { a: [{ b: 1, c: null, f: 4 }] };
      const predicted = { a: [{ b: 1, c: [3], f: 4 }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.6667);
    });

    it('handles actual value to predicted null comparison', () => {
      const actual = { a: [{ b: 1, c: 2 }] };
      const predicted = { a: [{ b: 1, c: null }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.5);
    });

    it('handles actual object to predicted null comparison', () => {
      const actual = { a: [{ b: 1, c: { d: 2 } }] };
      const predicted = { a: [{ b: 1, c: null }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.5);
    });

    it('handles actual complex object to predicted null comparison', () => {
      const actual = { a: [{ b: 1, c: { d: 2, e: 3 } }] };
      const predicted = { a: [{ b: 1, c: null }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.3333);
    });

    it('handles actual array to predicted null comparison', () => {
      const actual = { a: [{ b: 1, c: [3, 2] }] };
      const predicted = { a: [{ b: 1, c: null }] };
      const result = calculateJsonAccuracy(actual, predicted);
      expect(result.score).toBe(0.3333);
    });
  });
});
159 |
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "es6",
4 | "module": "commonjs",
5 | "lib": ["es6", "dom", "esnext"],
6 | "outDir": "./dist",
7 | "rootDir": "./src",
8 | "strict": false,
9 | "esModuleInterop": true,
10 | "skipLibCheck": true,
11 | "declaration": true,
12 | "sourceMap": true,
13 | "resolveJsonModule": true,
14 | "moduleResolution": "node"
15 | },
16 | "include": ["src/**/*.ts"],
17 | "exclude": ["node_modules", "dist"]
18 | }
19 |
--------------------------------------------------------------------------------