├── .gitignore
├── README.md
├── analyze_failures.ipynb
├── convert.py
├── grade_cli.py
├── grading.py
├── parsing.py
├── providers
│   ├── README.md
│   ├── __init__.py
│   ├── azure_docintelligence.py
│   ├── chunkr.py
│   ├── config.py
│   ├── gcloud.py
│   ├── llm.py
│   ├── reducto.py
│   ├── textract.py
│   └── unstructured.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | *.pyc
3 | *.csv
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RD Table Bench Invocation + Grading Code
2 |
3 | This repo contains the code for invoking each provider and grading the results. It is a fork of Reducto's original repo, https://github.com/reductoai/rd-tablebench, with improved support for different LLM providers (OpenAI, Anthropic, Gemini).
4 |
5 | The proprietary providers that Reducto implemented have not been touched and will not work with the grading CLI.
6 |
7 | ## Installing Dependencies
8 |
9 | ```
10 | pip install -r requirements.txt
11 | ```
12 |
13 | ## Downloading Data
14 |
15 | https://huggingface.co/datasets/reducto/rd-tablebench/blob/main/README.md
16 |
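One way to pull the dataset locally is with `huggingface_hub` (not pinned in `requirements.txt`, so install it separately); a minimal sketch:

```
from huggingface_hub import snapshot_download

# Download the dataset snapshot; point INPUT_DIR at this directory afterwards.
snapshot_download(
    repo_id="reducto/rd-tablebench",
    repo_type="dataset",
    local_dir="rd-tablebench",
)
```
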
17 | ## Env Vars
18 |
19 | Create an `.env` file with the following:
20 |
21 | ```
22 | # directory where the huggingface dataset is downloaded
23 | INPUT_DIR=
24 |
25 | # directory where the output will be saved
26 | OUTPUT_DIR=
27 |
28 | # note: only need keys for providers you want to use
29 | OPENAI_API_KEY=
30 | GEMINI_API_KEY=
31 | ANTHROPIC_API_KEY=
32 | ...
33 | ```
34 |
35 | ## Parsing
36 |
37 | `python -m providers.llm --model gemini-2.0-flash-exp --num-workers 10`
38 |
39 | ## Grading
40 |
41 | `python -m grade_cli --model gemini-2.0-flash-exp`
42 |
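Scores are written to `./scores/<model>_scores.csv` with `filename` and `score` columns. A quick way to inspect a run with polars (already in `requirements.txt`), assuming you graded `gemini-2.0-flash-exp`:

```
import polars as pl

scores = pl.read_csv("./scores/gemini-2.0-flash-exp_scores.csv")
print(scores["score"].mean(), scores["score"].std())
```
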
--------------------------------------------------------------------------------
/analyze_failures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 34,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/plain": [
11 | "polars.config.Config"
12 | ]
13 | },
14 | "execution_count": 34,
15 | "metadata": {},
16 | "output_type": "execute_result"
17 | }
18 | ],
19 | "source": [
20 | "import polars as pl\n",
21 | "from pathlib import Path\n",
22 | "from providers.config import settings\n",
23 | "import webbrowser\n",
24 | "import random\n",
25 | "\n",
26 | "pl.Config.set_fmt_str_lengths(200)"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 16,
32 | "metadata": {},
33 | "outputs": [
34 | {
35 | "data": {
36 | "text/html": [
37 | "
\n",
44 | "
shape: (662, 7)filename | gemini-2.0-flash-exp | gemini-1.5-pro | gemini-1.5-flash | gpt-4o-mini | gpt-4o | claude-3-5-sonnet-latest |
---|
str | f64 | f64 | f64 | f64 | f64 | f64 |
"31365_png.rf.c15a4717175cda592… | null | 1.0 | 0.72 | 0.79 | 0.83 | 1.0 |
"3066_png.rf.a271ccaade01ec4a7e… | 1.0 | 0.9 | 1.0 | 1.0 | 1.0 | 1.0 |
"42405_png.rf.29f7f349b16be7e94… | null | 0.88 | 0.87 | 0.75 | null | 0.71 |
"6464_png.rf.87df45257826205d38… | null | 1.0 | 0.66 | 0.82 | null | null |
"14587_png.rf.8ac51dfae9c3323be… | null | null | 0.52 | 0.22 | null | null |
… | … | … | … | … | … | … |
"4369_png.rf.22210329dce81c468f… | null | 0.94 | null | 0.33 | 0.32 | 0.62 |
"20008_png.rf.11bd3ea6ad0610c46… | null | 0.85 | null | 0.26 | 0.02 | 0.01 |
"3794_png.rf.f3b9a9fce3f6f5b4e4… | null | 0.58 | null | 0.08 | null | 0.5 |
"29352_png.rf.262ea50e23c787b4f… | null | 0.78 | null | 0.7 | null | 0.49 |
"999_png.rf.db81a5df0db1f0c4854… | null | null | null | 0.21 | null | null |
"
45 | ],
46 | "text/plain": [
47 | "shape: (662, 7)\n",
48 | "┌───────────────┬──────────────┬──────────────┬──────────────┬─────────────┬────────┬──────────────┐\n",
49 | "│ filename ┆ gemini-2.0-f ┆ gemini-1.5-p ┆ gemini-1.5-f ┆ gpt-4o-mini ┆ gpt-4o ┆ claude-3-5-s │\n",
50 | "│ --- ┆ lash-exp ┆ ro ┆ lash ┆ --- ┆ --- ┆ onnet-latest │\n",
51 | "│ str ┆ --- ┆ --- ┆ --- ┆ f64 ┆ f64 ┆ --- │\n",
52 | "│ ┆ f64 ┆ f64 ┆ f64 ┆ ┆ ┆ f64 │\n",
53 | "╞═══════════════╪══════════════╪══════════════╪══════════════╪═════════════╪════════╪══════════════╡\n",
54 | "│ 31365_png.rf. ┆ null ┆ 1.0 ┆ 0.72 ┆ 0.79 ┆ 0.83 ┆ 1.0 │\n",
55 | "│ c15a4717175cd ┆ ┆ ┆ ┆ ┆ ┆ │\n",
56 | "│ a592… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
57 | "│ 3066_png.rf.a ┆ 1.0 ┆ 0.9 ┆ 1.0 ┆ 1.0 ┆ 1.0 ┆ 1.0 │\n",
58 | "│ 271ccaade01ec ┆ ┆ ┆ ┆ ┆ ┆ │\n",
59 | "│ 4a7e… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
60 | "│ 42405_png.rf. ┆ null ┆ 0.88 ┆ 0.87 ┆ 0.75 ┆ null ┆ 0.71 │\n",
61 | "│ 29f7f349b16be ┆ ┆ ┆ ┆ ┆ ┆ │\n",
62 | "│ 7e94… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
63 | "│ 6464_png.rf.8 ┆ null ┆ 1.0 ┆ 0.66 ┆ 0.82 ┆ null ┆ null │\n",
64 | "│ 7df4525782620 ┆ ┆ ┆ ┆ ┆ ┆ │\n",
65 | "│ 5d38… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
66 | "│ 14587_png.rf. ┆ null ┆ null ┆ 0.52 ┆ 0.22 ┆ null ┆ null │\n",
67 | "│ 8ac51dfae9c33 ┆ ┆ ┆ ┆ ┆ ┆ │\n",
68 | "│ 23be… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
69 | "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n",
70 | "│ 4369_png.rf.2 ┆ null ┆ 0.94 ┆ null ┆ 0.33 ┆ 0.32 ┆ 0.62 │\n",
71 | "│ 2210329dce81c ┆ ┆ ┆ ┆ ┆ ┆ │\n",
72 | "│ 468f… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
73 | "│ 20008_png.rf. ┆ null ┆ 0.85 ┆ null ┆ 0.26 ┆ 0.02 ┆ 0.01 │\n",
74 | "│ 11bd3ea6ad061 ┆ ┆ ┆ ┆ ┆ ┆ │\n",
75 | "│ 0c46… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
76 | "│ 3794_png.rf.f ┆ null ┆ 0.58 ┆ null ┆ 0.08 ┆ null ┆ 0.5 │\n",
77 | "│ 3b9a9fce3f6f5 ┆ ┆ ┆ ┆ ┆ ┆ │\n",
78 | "│ b4e4… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
79 | "│ 29352_png.rf. ┆ null ┆ 0.78 ┆ null ┆ 0.7 ┆ null ┆ 0.49 │\n",
80 | "│ 262ea50e23c78 ┆ ┆ ┆ ┆ ┆ ┆ │\n",
81 | "│ 7b4f… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
82 | "│ 999_png.rf.db ┆ null ┆ null ┆ null ┆ 0.21 ┆ null ┆ null │\n",
83 | "│ 81a5df0db1f0c ┆ ┆ ┆ ┆ ┆ ┆ │\n",
84 | "│ 4854… ┆ ┆ ┆ ┆ ┆ ┆ │\n",
85 | "└───────────────┴──────────────┴──────────────┴──────────────┴─────────────┴────────┴──────────────┘"
86 | ]
87 | },
88 | "execution_count": 16,
89 | "metadata": {},
90 | "output_type": "execute_result"
91 | }
92 | ],
93 | "source": [
94 | "models = [\n",
95 | " \"gemini-2.0-flash-exp\",\n",
96 | " \"gemini-1.5-pro\",\n",
97 | " \"gemini-1.5-flash\",\n",
98 | " \"gpt-4o-mini\",\n",
99 | " \"gpt-4o\",\n",
100 | " \"claude-3-5-sonnet-latest\",\n",
101 | "]\n",
102 | "\n",
103 | "dfs = []\n",
104 | "for model in models:\n",
105 | " df = pl.read_csv(f\"./scores/{model}_scores.csv\")\n",
106 | " dfs.append(df.rename({\"score\": model}))\n",
107 | "\n",
108 | "merged_df = dfs[0]\n",
109 | "for df in dfs[1:]:\n",
110 | " merged_df = merged_df.join(df, on=\"filename\", how=\"full\", coalesce=True)\n",
111 | "\n",
112 | "merged_df = merged_df.with_columns(pl.col(pl.Float64).round(2))\n",
113 | "merged_df"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 74,
119 | "metadata": {},
120 | "outputs": [
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "36278_png.rf.745f6e6efb75dd8a2247cad732e53033.html\n",
126 | "None\n",
127 | "/Users/sergey/Downloads/rd-tablebench/outputs/gpt-4o-mini-raw/36278_png.rf.745f6e6efb75dd8a2247cad732e53033.html\n",
128 | "/Users/sergey/Downloads/rd-tablebench/groundtruth/36278_png.rf.745f6e6efb75dd8a2247cad732e53033.html\n",
129 | "/Users/sergey/Downloads/rd-tablebench/pdfs/36278_png.rf.745f6e6efb75dd8a2247cad732e53033.pdf\n"
130 | ]
131 | }
132 | ],
133 | "source": [
134 | "def open_output_file(filename: str, model: str):\n",
135 | " url = settings.output_dir / f\"{model}-raw\" / filename\n",
136 | " assert url.exists(), f\"File {url} does not exist\"\n",
137 | " print(url)\n",
138 | "\n",
139 | " webbrowser.open(str(url))\n",
140 | "\n",
141 | "\n",
142 | "def open_ground_truth_file(filename: str):\n",
143 | " url = settings.input_dir / \"groundtruth\" / filename\n",
144 | " assert url.exists(), f\"File {url} does not exist\"\n",
145 | " print(url)\n",
146 | " webbrowser.open(str(url))\n",
147 | "\n",
148 | "\n",
149 | "def open_source_file(filename: str):\n",
150 | " url = settings.input_dir / \"pdfs\" / filename.replace(\".html\", \".pdf\")\n",
151 | " assert url.exists(), f\"File {url} does not exist\"\n",
152 | " print(url)\n",
153 | " webbrowser.open(str(url))\n",
154 | "\n",
155 | "\n",
156 | "random_filename = random.choice(merged_df.to_dicts())\n",
157 | "print(random_filename[\"filename\"])\n",
158 | "print(random_filename[\"claude-3-5-sonnet-latest\"])\n",
159 | "open_output_file(random_filename[\"filename\"], \"gpt-4o-mini\")\n",
160 | "open_ground_truth_file(random_filename[\"filename\"])\n",
161 | "open_source_file(random_filename[\"filename\"])\n"
162 | ]
163 | }
164 | ],
165 | "metadata": {
166 | "kernelspec": {
167 | "display_name": ".venv",
168 | "language": "python",
169 | "name": "python3"
170 | },
171 | "language_info": {
172 | "codemirror_mode": {
173 | "name": "ipython",
174 | "version": 3
175 | },
176 | "file_extension": ".py",
177 | "mimetype": "text/x-python",
178 | "name": "python",
179 | "nbconvert_exporter": "python",
180 | "pygments_lexer": "ipython3",
181 | "version": "3.12.5"
182 | }
183 | },
184 | "nbformat": 4,
185 | "nbformat_minor": 2
186 | }
187 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.typing as npt
3 | from lxml import etree
4 |
5 |
6 | def html_to_numpy(html_string: str) -> npt.NDArray[np.str_]:
7 | dom_tree = etree.HTML(html_string, parser=etree.HTMLParser())
8 | table_rows: list[list[str]] = []
9 | span_info: dict[int, tuple[str, int]] = {}
10 |
11 | for table_row in dom_tree.xpath("//tr"):
12 | current_row: list[str] = []
13 | column_index = 0
14 |
15 | while span_info.get(column_index, (None, 0))[1] > 0:
16 | current_row.append(span_info[column_index][0])
17 | span_info[column_index] = (
18 | span_info[column_index][0],
19 | span_info[column_index][1] - 1,
20 | )
21 | if span_info[column_index][1] == 0:
22 | del span_info[column_index]
23 | column_index += 1
24 |
25 | for table_cell in table_row.xpath("td|th"):
26 | while span_info.get(column_index, (None, 0))[1] > 0:
27 | current_row.append(span_info[column_index][0])
28 | span_info[column_index] = (
29 | span_info[column_index][0],
30 | span_info[column_index][1] - 1,
31 | )
32 | if span_info[column_index][1] == 0:
33 | del span_info[column_index]
34 | column_index += 1
35 |
36 | row_span = int(table_cell.get("rowspan", "1"))
37 | col_span = int(table_cell.get("colspan", "1"))
38 | cell_text = "".join(table_cell.itertext()).strip()
39 |
40 | if row_span > 1:
41 | for i in range(col_span):
42 | span_info[column_index + i] = (cell_text, row_span - 1)
43 |
44 | for _ in range(col_span):
45 | current_row.append(cell_text)
46 | column_index += col_span
47 |
48 | while span_info.get(column_index, (None, 0))[1] > 0:
49 | current_row.append(span_info[column_index][0])
50 | span_info[column_index] = (
51 | span_info[column_index][0],
52 | span_info[column_index][1] - 1,
53 | )
54 | if span_info[column_index][1] == 0:
55 | del span_info[column_index]
56 | column_index += 1
57 |
58 | table_rows.append(current_row)
59 |
60 | max_columns = max(map(len, table_rows)) if table_rows else 0
61 | for row in table_rows:
62 | row.extend([""] * (max_columns - len(row)))
63 |
64 | return np.array(table_rows)
65 |
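# Illustrative example of the behaviour (rowspan/colspan are expanded by repeating the
# cell text, and short rows are padded so every row has the same number of columns):
#
#   >>> html_to_numpy("<table><tr><td colspan='2'>a</td></tr><tr><td>b</td><td>c</td></tr></table>")
#   array([['a', 'a'],
#          ['b', 'c']], dtype='<U1')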
--------------------------------------------------------------------------------
/grade_cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 |
5 | from convert import html_to_numpy
6 | from grading import table_similarity
7 | import polars as pl
8 |
9 | from providers.config import settings
10 |
11 |
12 | def main(model: str, folder: str, save_to_csv: bool):
13 | groundtruth = settings.input_dir / "groundtruth"
14 | scores = []
15 |
16 | html_files = glob.glob(os.path.join(folder, "*.html"))
17 | for pred_html_path in html_files:
18 | # The filename might be something like: 10035_png.rf.07e8e5bf2e9ad4e77a84fd38d1f53f38.html
19 | base_name = os.path.basename(pred_html_path)
20 |
21 | # Build the path to the corresponding ground-truth file
22 | gt_html_path = os.path.join(groundtruth, base_name)
23 | if not os.path.exists(gt_html_path):
24 | continue
25 |
26 | with open(pred_html_path, "r") as f:
27 | pred_html = f.read()
28 |
29 | with open(gt_html_path, "r") as f:
30 | gt_html = f.read()
31 |
32 | # Convert HTML -> NumPy arrays
33 | try:
34 | pred_array = html_to_numpy(pred_html)
35 | gt_array = html_to_numpy(gt_html)
36 |
37 | # Compute similarity (0.0 to 1.0)
38 | score = table_similarity(gt_array, pred_array)
39 | except Exception as e:
40 | print(f"Error converting {base_name}: {e}")
41 | continue
42 |
43 | scores.append((base_name, score))
44 | print(f"{base_name}: {score:.4f}")
45 |
46 | score_dicts = [{"filename": fname, "score": scr} for fname, scr in scores]
47 | df = pl.DataFrame(score_dicts)
48 | print(
49 | f"Average score for {model}: {df['score'].mean():.2f} with std {df['score'].std():.2f}"
50 | )
51 | if save_to_csv:
52 | df.write_csv(f"./scores/{model}_scores.csv")
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--model", type=str, required=True)
58 | parser.add_argument("--save-to-csv", type=bool, default=True)
59 | args = parser.parse_args()
60 |
61 | model_dir = settings.output_dir / f"{args.model}-raw"
62 | assert model_dir.exists(), f"Model directory {model_dir} does not exist"
63 | main(args.model, model_dir, args.save_to_csv)
64 |
--------------------------------------------------------------------------------
/grading.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.typing as npt
3 | from Levenshtein import distance as levenshtein_distance
4 |
5 |
6 | BATCH_SIZE = 150
7 |
8 | # Scoring parameters (you can adjust these as needed)
9 | S_ROW_MATCH = 5 # Match score for row alignment
10 | G_ROW = -3 # Gap penalty for row alignment (insertion/deletion of a row)
11 | S_CELL_MATCH = 1 # Match score for cell matching
12 | P_CELL_MISMATCH = -1 # Penalty for cell mismatch
13 | G_COL = -1 # Gap penalty for column alignment
14 |
15 |
16 | def cell_match_score(cell1: str | None, cell2: str | None) -> float:
17 | """Compute the match score between two cells considering partial matches."""
18 | if cell1 is None or cell2 is None:
19 | return P_CELL_MISMATCH # Penalty for gaps or mismatches
20 | if cell1 == cell2:
21 | return S_CELL_MATCH # Cells are identical
22 |
23 | # Compute the Levenshtein distance using the optimized library
24 | distance = levenshtein_distance(cell1, cell2)
25 | max_len = max(len(cell1), len(cell2))
26 | if max_len == 0:
27 | normalized_distance = 0.0 # Both cells are empty strings
28 | else:
29 | normalized_distance = distance / max_len
30 | similarity = 1.0 - normalized_distance # Similarity between 0 and 1
31 | match_score = P_CELL_MISMATCH + similarity * (S_CELL_MATCH - P_CELL_MISMATCH)
32 | return match_score
33 |
34 |
35 | def needleman_wunsch(
36 | seq1: list[str], seq2: list[str], gap_penalty: int
37 | ) -> tuple[list[str | None], list[str | None], float]:
38 | """
39 | Perform Needleman-Wunsch alignment between two sequences with free end gaps.
40 |
41 | Parameters:
42 | seq1, seq2: sequences to align (lists of strings)
43 | gap_penalty: penalty for gaps (insertions/deletions)
44 |
45 | Returns:
46 | alignment_a, alignment_b: aligned sequences with gaps represented by None
47 | score: total alignment score
48 | """
49 | m = len(seq1)
50 | n = len(seq2)
51 |
52 | # Initialize the scoring matrix
53 | score_matrix = np.zeros((m + 1, n + 1), dtype=np.float32)
54 | traceback = np.full((m + 1, n + 1), None)
55 |
56 | # Initialize the first row and column (no gap penalties for leading gaps)
57 | for i in range(1, m + 1):
58 | traceback[i, 0] = "up"
59 | for j in range(1, n + 1):
60 | traceback[0, j] = "left"
61 |
62 | # Fill the rest of the matrix
63 | for i in range(1, m + 1):
64 | seq1_i = seq1[i - 1]
65 | for j in range(1, n + 1):
66 | seq2_j = seq2[j - 1]
67 | match = score_matrix[i - 1, j - 1] + cell_match_score(seq1_i, seq2_j)
68 | delete = score_matrix[i - 1, j] + gap_penalty
69 | insert = score_matrix[i, j - 1] + gap_penalty
70 | max_score = max(match, delete, insert)
71 | score_matrix[i, j] = max_score
72 | if max_score == match:
73 | traceback[i, j] = "diag"
74 | elif max_score == delete:
75 | traceback[i, j] = "up"
76 | else:
77 | traceback[i, j] = "left"
78 |
79 | # Traceback from the position with the highest score in the last row or column
80 | i, j = m, n
81 | max_score = score_matrix[i, j]
82 | max_i, max_j = i, j
83 | # Find the maximum score in the last row and column for free end gaps
84 | last_row = score_matrix[:, n]
85 | last_col = score_matrix[m, :]
86 | if last_row.max() > max_score:
87 | max_i = last_row.argmax()
88 | max_j = n
89 | max_score = last_row[max_i]
90 | if last_col.max() > max_score:
91 | max_i = m
92 | max_j = last_col.argmax()
93 | max_score = last_col[max_j]
94 |
95 | # Traceback to get the aligned sequences
96 | alignment_a: list[str | None] = []
97 | alignment_b: list[str | None] = []
98 | i, j = max_i, max_j
99 | while i > 0 or j > 0:
100 | tb_direction = traceback[i, j]
101 | if i > 0 and j > 0 and tb_direction == "diag":
102 | alignment_a.insert(0, seq1[i - 1])
103 | alignment_b.insert(0, seq2[j - 1])
104 | i -= 1
105 | j -= 1
106 | elif i > 0 and (j == 0 or tb_direction == "up"):
107 | alignment_a.insert(0, seq1[i - 1])
108 | alignment_b.insert(0, None) # Gap in seq2
109 | i -= 1
110 | elif j > 0 and (i == 0 or tb_direction == "left"):
111 | alignment_a.insert(0, None) # Gap in seq1
112 | alignment_b.insert(0, seq2[j - 1])
113 | j -= 1
114 | else:
115 | break # Should not reach here
116 |
117 | return alignment_a, alignment_b, max_score
118 |
119 |
120 | def table_similarity(
121 | ground_truth: npt.NDArray[np.str_], prediction: npt.NDArray[np.str_]
122 | ) -> float:
123 | """
124 | Compute the similarity between two tables represented as ndarrays of strings,
125 | allowing for a subset of rows at the top or bottom without penalization (to avoid penalizing subtable cropping).
126 |
127 | Parameters:
128 | ground_truth, prediction: ndarrays of strings representing the tables
129 |
130 | Returns:
131 | similarity: similarity score between 0 and 1
132 | """
133 |
134 | # Remove newlines and normalize whitespace in cells
135 | def normalize_cell(cell: str) -> str:
136 | return "".join(cell.replace("\n", " ").replace("-", "").split()).replace(
137 | " ", ""
138 | )
139 |
140 | # Apply normalization to both ground truth and prediction arrays
141 | vectorized_normalize = np.vectorize(normalize_cell)
142 | ground_truth = vectorized_normalize(ground_truth)
143 | prediction = vectorized_normalize(prediction)
144 |
145 | # Convert to lists of lists for easier manipulation
146 | gt_rows = [list(row) for row in ground_truth]
147 | pred_rows = [list(row) for row in prediction]
148 |
149 | # Precompute the column alignment scores between all pairs of rows
150 | m = len(gt_rows)
151 | n = len(pred_rows)
152 | row_match_scores = np.zeros((m, n), dtype=np.float32)
153 |
154 | for i in range(m):
155 | gt_row = gt_rows[i]
156 | for j in range(n):
157 | pred_row = pred_rows[j]
158 | # Align columns of the two rows
159 | _, _, col_score = needleman_wunsch(gt_row, pred_row, G_COL)
160 | # Adjusted row match score
161 | row_match_scores[i, j] = col_score + S_ROW_MATCH
162 |
163 | # Initialize the scoring matrix for row alignment with free end gaps
164 | score_matrix = np.zeros((m + 1, n + 1), dtype=np.float32)
165 | traceback = np.full((m + 1, n + 1), None)
166 |
167 | # No gap penalties for leading gaps
168 | for i in range(1, m + 1):
169 | traceback[i, 0] = "up"
170 | for j in range(1, n + 1):
171 | traceback[0, j] = "left"
172 |
173 | # Fill the rest of the scoring matrix
174 | for i in range(1, m + 1):
175 | for j in range(1, n + 1):
176 | match = score_matrix[i - 1, j - 1] + row_match_scores[i - 1, j - 1]
177 | delete = score_matrix[i - 1, j] + G_ROW
178 | insert = score_matrix[i, j - 1] + G_ROW
179 | max_score = max(match, delete, insert)
180 | score_matrix[i, j] = max_score
181 | if max_score == match:
182 | traceback[i, j] = "diag"
183 | elif max_score == delete:
184 | traceback[i, j] = "up"
185 | else:
186 | traceback[i, j] = "left"
187 |
188 | # Traceback from the position with the highest score in the last row or column
189 | i, j = m, n
190 | max_score = score_matrix[i, j]
191 | max_i, max_j = i, j
192 | # Find the maximum score in the last row and column for free end gaps
193 | last_row = score_matrix[:, n]
194 | last_col = score_matrix[m, :]
195 | if last_row.max() > max_score:
196 | max_i = last_row.argmax()
197 | max_j = n
198 | max_score = last_row[max_i]
199 | if last_col.max() > max_score:
200 | max_i = m
201 | max_j = last_col.argmax()
202 | max_score = last_col[max_j]
203 |
204 | # Traceback to get the aligned rows
205 | alignment_gt_rows: list[list[str | None]] = []
206 | alignment_pred_rows: list[list[str | None]] = []
207 | i, j = max_i, max_j
208 | while i > 0 or j > 0:
209 | tb_direction = traceback[i, j]
210 | if i > 0 and j > 0 and tb_direction == "diag":
211 | alignment_gt_rows.insert(0, gt_rows[i - 1])
212 | alignment_pred_rows.insert(0, pred_rows[j - 1])
213 | i -= 1
214 | j -= 1
215 | elif i > 0 and (j == 0 or tb_direction == "up"):
216 | alignment_gt_rows.insert(0, gt_rows[i - 1])
217 | alignment_pred_rows.insert(
218 | 0, [None] * len(gt_rows[i - 1])
219 | ) # Gap in prediction
220 | i -= 1
221 | elif j > 0 and (i == 0 or tb_direction == "left"):
222 | alignment_gt_rows.insert(
223 | 0, [None] * len(pred_rows[j - 1])
224 | ) # Gap in ground truth
225 | alignment_pred_rows.insert(0, pred_rows[j - 1])
226 | j -= 1
227 | else:
228 | break # Should not reach here
229 |
230 | # Compute the actual total score
231 | actual_total_score = max_score
232 |
233 | # Compute the total possible score
234 | num_aligned_rows = len(alignment_gt_rows)
235 | if num_aligned_rows == 0:
236 | return 0.0 # Avoid division by zero
237 | max_row_score = num_aligned_rows * (S_ROW_MATCH + len(gt_rows[0]) * S_CELL_MATCH)
238 | total_possible_score = max_row_score
239 |
240 | # Normalize the similarity score
241 | similarity = actual_total_score / total_possible_score
242 | return max(0.0, min(similarity, 1.0))
243 |
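# Illustrative example: identical tables score 1.0; the score drops toward 0.0 as
# cells, rows, or columns diverge.
#
#   >>> import numpy as np
#   >>> gt = np.array([["a", "b"], ["c", "d"]])
#   >>> float(table_similarity(gt, gt))
#   1.0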
--------------------------------------------------------------------------------
/parsing.py:
--------------------------------------------------------------------------------
1 | """
2 | For each format, this code extracts the largest HTML table from the response.
3 | """
4 |
5 | import json
6 | from typing import Any
7 | import os
8 |
9 | import argparse
10 | import glob
11 |
12 |
13 | def parse_textract_response(path: str) -> tuple[str | None, Any]:
14 | if not os.path.exists(path):
15 | return None, None
16 |
17 | with open(path, "r") as f:
18 | data = json.load(f)
19 |
20 | return data["html_table"], data
21 |
22 |
23 | def parse_gcloud_response(path: str) -> tuple[str | None, Any]:
24 | if not os.path.exists(path):
25 | return None, None
26 |
27 | try:
28 | with open(path, "r") as f:
29 | data = json.load(f)
30 | except Exception:
31 | return None, None
32 |
33 | return data["html_table"], data
34 |
35 |
36 | def parse_reducto_response(path: str) -> tuple[str | None, Any]:
37 | if not os.path.exists(path):
38 | return None, None
39 |
40 | with open(path, "r") as f:
41 | data = json.load(f)
42 |
43 | if "error" in data:
44 | return None, data
45 |
46 | longest_html = None
47 | max_length = 0
48 |
49 | for chunk in data["result"]["chunks"]:
50 | blocks = chunk["blocks"]
51 | for block in blocks:
52 | if block["type"] == "Table":
53 | if len(block["content"]) > max_length:
54 | max_length = len(block["content"])
55 | longest_html = block["content"]
56 |
57 | return longest_html, data
58 |
59 |
60 | def parse_chunkr_response(path: str) -> tuple[str | None, Any]:
61 | if not os.path.exists(path):
62 | return None, None
63 |
64 | with open(path, "r") as f:
65 | data = json.load(f)
66 |
67 | if data.get("status") != "Succeeded":
68 | return None, data
69 |
70 | largest_html = None
71 | max_length = 0
72 |
73 | try:
74 | for output in (
75 | data.get("output", [])
76 | if "chunks" not in data.get("output")
77 | else data["output"]["chunks"]
78 | ):
79 | for segment in output.get("segments", []):
80 | if segment.get("segment_type") == "Table" and segment.get("html"):
81 | if len(segment["html"]) > max_length:
82 | max_length = len(segment["html"])
83 | largest_html = segment["html"]
84 | except Exception:
85 | import traceback
86 |
87 | traceback.print_exc()
88 | print(data)
89 |
90 | return largest_html, data
91 |
92 |
93 | def parse_unstructured_response(path: str) -> tuple[str | None, Any]:
94 | if not os.path.exists(path):
95 | return None, None
96 |
97 | with open(path, "r") as f:
98 | data = json.load(f)
99 |
100 | largest_html = None
101 | max_length = 0
102 |
103 | for element in data.get("elements", []):
104 | if element.get("type") == "Table" and element.get("metadata", {}).get(
105 | "text_as_html"
106 | ):
107 | html = element["metadata"]["text_as_html"]
108 | if len(html) > max_length:
109 | max_length = len(html)
110 | largest_html = html
111 |
112 | return largest_html, data
113 |
114 |
115 | def parse_gpt4o_response(path: str) -> tuple[str | None, Any]:
116 | if not os.path.exists(path):
117 | return None, None
118 |
119 | with open(path, "r") as f:
120 | data = json.load(f)
121 |
122 | html = data["html_table"]
123 | # Extract just the table portion between
124 | start = html.find("")
125 | end = html.find("
") + 8
126 | if start != -1 and end != -1:
127 | return html[start:end], data
128 | return None, data
129 |
130 |
131 | def parse_gemini_response(path: str) -> tuple[str | None, Any]:
132 | if not os.path.exists(path):
133 | return None, None
134 |
135 | with open(path, "r") as f:
136 | data = json.load(f)
137 |
138 | html = data["html_table"]
139 | # Extract just the table portion between
140 | start = html.find("")
141 | end = html.find("
") + 8
142 | if start != -1 and end != -1:
143 | return html[start:end], data
144 | return None, data
145 |
146 |
147 | def parse_azure_response(path: str) -> tuple[str | None, Any]:
148 | data = None
149 | try:
150 | with open(path, "r") as f:
151 | data = json.load(f)
152 |
153 | def azure_to_html(table: Any) -> str:
154 | html = ""
155 | for row_index in range(table["rowCount"]):
156 | html += ""
157 | for col_index in range(table["columnCount"]):
158 | cell = next(
159 | (
160 | c
161 | for c in table["cells"]
162 | if c["rowIndex"] == row_index
163 | and c["columnIndex"] == col_index
164 | ),
165 | None,
166 | )
167 | if cell:
168 | content = (
169 | cell["content"]
170 | .replace(":selected:", "")
171 | .replace(":unselected:", "")
172 | )
173 | tag = "th" if cell.get("kind") == "columnHeader" else "td"
174 | rowspan = (
175 | f" rowspan='{cell['rowSpan']}'" if "rowSpan" in cell else ""
176 | )
177 | colspan = (
178 | f" colspan='{cell['columnSpan']}'"
179 | if "columnSpan" in cell
180 | else ""
181 | )
182 | html += f"<{tag}{rowspan}{colspan}>{content}{tag}>"
183 | else:
184 | pass
185 | html += "
"
186 | html += "
"
187 | return html
188 |
189 | # Find table with largest area (row count * column count)
190 | largest_table = max(
191 | data["tables"], key=lambda t: t["rowCount"] * t["columnCount"]
192 | )
193 | return azure_to_html(largest_table), data
194 | except Exception:
195 | return None, data
196 |
197 |
198 | PARSERS = {
199 | "textract": parse_textract_response,
200 | "gcloud": parse_gcloud_response,
201 | "reducto": parse_reducto_response,
202 | "chunkr": parse_chunkr_response,
203 | "unstructured": parse_unstructured_response,
204 | "gpt4o": parse_gpt4o_response,
205 | "azure": parse_azure_response,
206 | "gemini": parse_gemini_response,
207 | }
208 |
209 |
210 | def main():
211 | parser = argparse.ArgumentParser()
212 | parser.add_argument(
213 | "--provider",
214 | required=True,
215 | choices=PARSERS.keys(),
216 | help="Which parser to use (e.g. 'gpt4o', 'azure', etc.).",
217 | )
218 | parser.add_argument(
219 | "--input-folder",
220 | required=True,
221 | help="Folder containing .json files from the provider.",
222 | )
223 | parser.add_argument(
224 | "--output-folder",
225 | required=True,
226 | help="Folder to write the extracted .html files.",
227 | )
228 | args = parser.parse_args()
229 |
230 | # Get the parser function based on the provider
231 | parse_func = PARSERS[args.provider]
232 |
233 | # Ensure output folder exists
234 | os.makedirs(args.output_folder, exist_ok=True)
235 |
236 | # Find all JSON files under input folder (recursively)
237 | json_paths = glob.glob(
238 | os.path.join(args.input_folder, "**", "*.json"), recursive=True
239 | )
240 |
241 | for json_file in json_paths:
242 | # Parse the JSON to get HTML
243 | html, raw_data = parse_func(json_file)
244 |
245 | if not html:
246 | # No table found or parse error
247 | print(f"Skipping (no HTML found): {json_file}")
248 | continue
249 |
250 | # Build output path: replace .json with .html and replicate subfolders if desired
251 | relative_path = os.path.relpath(json_file, start=args.input_folder)
252 | out_name = os.path.splitext(relative_path)[0] + ".html"
253 | out_path = os.path.join(args.output_folder, out_name)
254 |
255 | # Make sure subdirectories exist
256 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
257 |
258 | # Write the HTML to file
259 | with open(out_path, "w") as f:
260 | f.write(html)
261 |
262 | print(f"Saved HTML to: {out_path}")
263 |
264 |
265 | if __name__ == "__main__":
266 | main()
267 |
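# Example invocation (illustrative; the folder arguments are placeholders):
#   python parsing.py --provider textract --input-folder <provider_json_dir> --output-folder <provider_html_dir>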
--------------------------------------------------------------------------------
/providers/README.md:
--------------------------------------------------------------------------------
1 | # Providers
2 |
3 | This directory contains the code for the different providers that are used to extract tables from PDFs.
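
The LLM provider (`llm.py`) reads its directories and API keys from `config.py`, which loads the `.env` file described in the top-level README; the proprietary providers are standalone scripts with their own hard-coded paths and environment variables. A minimal sketch of the shared settings:

```
from providers.config import settings

# Directories and API keys come from the .env file at the repo root.
print(settings.input_dir, settings.output_dir)
```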
--------------------------------------------------------------------------------
/providers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Filimoa/rd-tablebench/4f9b12479cbef94f052b09191f3bffabdf719823/providers/__init__.py
--------------------------------------------------------------------------------
/providers/azure_docintelligence.py:
--------------------------------------------------------------------------------
1 | # import libraries
2 | import os
3 | from azure.core.credentials import AzureKeyCredential
4 | from azure.ai.documentintelligence import DocumentIntelligenceClient
5 | from azure.ai.documentintelligence.models import AnalyzeDocumentRequest
6 | import glob
7 | import json
8 | from concurrent.futures import ThreadPoolExecutor, as_completed
9 | from tqdm import tqdm
10 | from azure.core.exceptions import HttpResponseError
11 | import backoff
12 |
13 | endpoint = os.environ["AZURE_ENDPOINT"]
14 | key = os.environ["AZURE_KEY"]
15 |
16 | base_path = os.path.expanduser("~/data/human_table_benchmark")
17 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True)
18 |
19 | document_intelligence_client = DocumentIntelligenceClient(
20 | endpoint=endpoint, credential=AzureKeyCredential(key)
21 | )
22 |
23 |
24 | def is_rate_limit_error(exception):
25 | return isinstance(exception, HttpResponseError) and exception.status_code == 429
26 |
27 |
28 | @backoff.on_exception(
29 | backoff.expo, HttpResponseError, giveup=lambda e: not is_rate_limit_error(e)
30 | )
31 | def analyze_document(file_content):
32 | return document_intelligence_client.begin_analyze_document(
33 | "prebuilt-layout", AnalyzeDocumentRequest(bytes_source=file_content)
34 | ).result()
35 |
36 |
37 | def process_pdf(pdf_path: str):
38 | output_path = pdf_path.replace("pdfs", "azure").replace(".pdf", ".json")
39 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
40 |
41 | try:
42 | with open(pdf_path, "rb") as file:
43 | result = analyze_document(file.read())
44 |
45 | with open(output_path, "w") as f:
46 | json.dump(result.as_dict(), f, indent=2)
47 |
48 | return pdf_path, None
49 | except HttpResponseError as e:
50 | if e.status_code == 429: # Rate limit exceeded
51 | return pdf_path, "Rate limit"
52 | else:
53 | return pdf_path, str(e)
54 | except Exception as e:
55 | return pdf_path, str(e)
56 |
57 |
58 | def process_all_pdfs(pdfs: list[str]):
59 | max_workers = 200
60 |
61 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
62 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs}
63 |
64 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs")
65 |
66 | for future in as_completed(futures):
67 | pdf_path, error = future.result()
68 | if error:
69 | print(f"Error processing {pdf_path}: {error}")
70 | progress_bar.update(1)
71 |
72 | progress_bar.close()
73 |
74 | print(f"Processed {len(pdfs)} PDFs")
75 |
76 |
77 | if __name__ == "__main__":
78 | process_all_pdfs(pdfs)
79 |
--------------------------------------------------------------------------------
/providers/chunkr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import asyncio
5 | import aiohttp
6 | import time
7 | from typing import Tuple
8 | from tqdm import tqdm
9 |
10 | base_path = os.path.expanduser("~/data/human_table_benchmark")
11 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True)
12 |
13 | API_KEY = os.environ.get("CHUNKR_API_KEY")
14 | if not API_KEY:
15 | raise ValueError("CHUNKR_API_KEY environment variable is not set")
16 |
17 | CHUNKR_URL = "https://api.chunkr.ai/api/v1/task"
18 | HEADERS = {"Authorization": API_KEY}
19 |
20 |
21 | async def process_pdf(
22 | pdf_path: str, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore
23 | ) -> Tuple[str, dict]:
24 | async with semaphore:
25 | try:
26 | with open(pdf_path, "rb") as file:
27 | data = aiohttp.FormData()
28 | data.add_field(
29 | "file",
30 | file,
31 | filename=os.path.basename(pdf_path),
32 | content_type="application/pdf",
33 | )
34 | data.add_field("model", "HighQuality")
35 | data.add_field("target_chunk_length", "512")
36 | data.add_field("ocr_strategy", "Auto")
37 |
38 | async with session.post(
39 | CHUNKR_URL, headers=HEADERS, data=data
40 | ) as response:
41 | if response.status == 200:
42 | task_info = await response.json()
43 | task_id = task_info["task_id"]
44 | result = await poll_task(pdf_path, task_id, session)
45 | return pdf_path, result
46 | else:
47 | return pdf_path, {
48 | "error": f"Error: {response.status}, {await response.text()}"
49 | }
50 | except Exception as e:
51 | return pdf_path, {"error": str(e)}
52 |
53 |
54 | async def poll_task(
55 | pdf_path: str, task_id: str, session: aiohttp.ClientSession
56 | ) -> dict:
57 | task_url = f"{CHUNKR_URL}/{task_id}"
58 | start_time = time.time()
59 | while time.time() - start_time < 60 * 10:
60 | try:
61 | async with session.get(task_url, headers=HEADERS) as response:
62 | if response.status == 200:
63 | task_info = await response.json()
64 | if task_info["status"] == "Succeeded":
65 | return task_info
66 | elif task_info["status"] in ["Failed", "Canceled"]:
67 | return {"error": f"Task failed or canceled: {task_info}"}
68 | else:
69 | return {
70 | "error": f"Error polling task: {response.status}, {await response.text()}"
71 | }
72 | except Exception as e:
73 | return {"error": str(e)}
74 | await asyncio.sleep(5) # Wait 5 seconds before polling again
75 | return {"error": "Timeout: Task did not complete within 10 minutes"}
76 |
77 |
78 | async def process_all_pdfs(pdfs: list[str]):
79 | semaphore = asyncio.Semaphore(100) # Limit to 100
80 | async with aiohttp.ClientSession() as session:
81 | tasks = [process_pdf(pdf, session, semaphore) for pdf in pdfs]
82 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs")
83 | results = []
84 | for task in asyncio.as_completed(tasks):
85 | pdf_path, result = await task
86 | results.append((pdf_path, result))
87 | progress_bar.update(1)
88 | progress_bar.close()
89 |
90 | for pdf_path, result in results:
91 | output_path = pdf_path.replace("pdfs", "chunkr_nov1").replace(".pdf", ".json")
92 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
93 | with open(output_path, "w") as f:
94 | json.dump(result, f, indent=2)
95 |
96 | print(f"Processed {len(results)} PDFs")
97 |
98 |
99 | asyncio.run(process_all_pdfs(pdfs))
100 |
--------------------------------------------------------------------------------
/providers/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from dotenv import load_dotenv
3 | import os
4 | from pathlib import Path
5 |
6 | load_dotenv()
7 |
8 |
9 | class Settings(BaseModel):
10 | input_dir: Path
11 | output_dir: Path
12 | openai_api_key: str | None = None
13 | gemini_api_key: str | None = None
14 | anthropic_api_key: str | None = None
15 |
16 |
17 | settings = Settings(
18 | input_dir=os.getenv("INPUT_DIR"),
19 | output_dir=os.getenv("OUTPUT_DIR"),
20 | openai_api_key=os.getenv("OPENAI_API_KEY"),
21 | gemini_api_key=os.getenv("GEMINI_API_KEY"),
22 | anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
23 | )
24 |
--------------------------------------------------------------------------------
/providers/gcloud.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | from google.api_core.client_options import ClientOptions
3 | from google.cloud import documentai
4 | from google.cloud.documentai_toolbox import document
5 | import os
6 | import glob
7 | import json
8 | from concurrent.futures import ThreadPoolExecutor, as_completed
9 | from tqdm import tqdm
10 | import time
11 | import backoff
12 |
13 | project_id = os.environ["GCP_PROJECT_ID"]
14 | location = "us"
15 | processor_id = os.environ["GCP_PROCESSOR_ID"]
16 | mime_type = "application/pdf"
17 |
18 | base_path = os.path.expanduser("~/data/human_table_benchmark")
19 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True)
20 |
21 | opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
22 | client = documentai.DocumentProcessorServiceClient(client_options=opts)
23 | name = client.processor_path(project_id, location, processor_id)
24 |
25 |
26 | @backoff.on_exception(
27 | backoff.expo, Exception, giveup=lambda e: "quota" not in str(e).lower(), max_tries=5
28 | )
29 | def process_document(file_path: str) -> Tuple[Optional[str], float]:
30 | start_time = time.time()
31 |
32 | # Read the file into memory
33 | with open(file_path, "rb") as pdf_file:
34 | pdf_content = pdf_file.read()
35 |
36 | # Load binary data
37 | raw_document = documentai.RawDocument(content=pdf_content, mime_type=mime_type)
38 |
39 | # Configure the process request
40 | request = documentai.ProcessRequest(
41 | name=name,
42 | raw_document=raw_document,
43 | )
44 |
45 | result = client.process_document(request=request)
46 |
47 | toolbox_document = document.Document.from_documentai_document(result.document)
48 |
49 | for page in toolbox_document.pages:
50 | for table in page.tables:
51 | html_table = table.to_dataframe().to_html()
52 | processing_time = time.time() - start_time
53 | return html_table, processing_time
54 |
55 | processing_time = time.time() - start_time
56 | return None, processing_time
57 |
58 |
59 | def process_pdf(pdf_path: str):
60 | output_path = pdf_path.replace("pdfs", "gcloud").replace(".pdf", ".json")
61 |
62 | if os.path.exists(output_path):
63 | return pdf_path, None
64 |
65 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
66 |
67 | try:
68 | html_table, processing_time = process_document(pdf_path)
69 |
70 | result = {"html_table": html_table, "processing_time": processing_time}
71 |
72 | with open(output_path, "w") as f:
73 | json.dump(result, f, indent=2)
74 |
75 | return pdf_path, None
76 | except Exception as e:
77 | return pdf_path, str(e)
78 |
79 |
80 | def process_all_pdfs(pdfs: list[str]):
81 | max_workers = 5
82 |
83 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
84 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs}
85 |
86 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs")
87 |
88 | for future in as_completed(futures):
89 | pdf_path, error = future.result()
90 | if error:
91 | print(f"Error processing {pdf_path}: {error}")
92 | progress_bar.update(1)
93 |
94 | progress_bar.close()
95 |
96 | print(f"Processed {len(pdfs)} PDFs")
97 |
98 |
99 | if __name__ == "__main__":
100 | process_all_pdfs(pdfs)
101 |
--------------------------------------------------------------------------------
/providers/llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import glob
4 | import os
5 | from concurrent.futures import ThreadPoolExecutor, as_completed
6 | from io import BytesIO
7 | from typing import Any, Literal
8 |
9 | import backoff
10 | import openai
11 | from openai import OpenAI
12 | from pdf2image import convert_from_path
13 | from tqdm import tqdm
14 |
15 | from providers.config import settings
16 |
17 |
18 | base_path = os.path.expanduser(settings.input_dir)
19 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True)
20 |
21 |
22 | def convert_pdf_to_base64_image(pdf_path):
23 | images = convert_from_path(pdf_path, first_page=1, last_page=1)
24 | img_buffer = BytesIO()
25 | images[0].save(img_buffer, format="PNG")
26 | return base64.b64encode(img_buffer.getvalue()).decode("utf-8")
27 |
28 |
29 | @backoff.on_exception(backoff.expo, (openai.RateLimitError), max_tries=5)
30 | def analyze_document_openai_sdk(base64_image, model: str):
31 | response = client.chat.completions.create(
32 | model=model,
33 | messages=[
34 | {
35 | "role": "user",
36 | "content": [
37 | {
38 | "type": "text",
39 | "text": "Convert the image to an HTML table. The output should begin with . Specify rowspan and colspan attributes when they are greater than 1. Do not specify any other attributes. Only use table related HTML tags, no additional formatting is required.",
40 | },
41 | {
42 | "type": "image_url",
43 | "image_url": {"url": f"data:image/png;base64,{base64_image}"},
44 | },
45 | ],
46 | }
47 | ],
48 | max_tokens=4096,
49 | )
50 | return response.choices[0].message.content
51 |
52 |
53 | @backoff.on_exception(backoff.expo, (openai.RateLimitError), max_tries=5)
54 | def analyze_document_anthropic(base64_image, model: str):
55 | response = client.messages.create(
56 | model=model,
57 | max_tokens=4096,
58 | messages=[
59 | {
60 | "role": "user",
61 | "content": [
62 | {
63 | "type": "text",
64 | "text": "Convert the image to an HTML table. The output should begin with . Specify rowspan and colspan attributes when they are greater than 1. Do not specify any other attributes. Only use table related HTML tags, no additional formatting is required.",
65 | },
66 | {
67 | "type": "image",
68 | "source": {
69 | "type": "base64",
70 | "media_type": "image/png",
71 | "data": base64_image,
72 | },
73 | },
74 | ],
75 | }
76 | ],
77 | )
78 | return response.content[0].text
79 |
80 |
81 | def parse_gemini_response(content: str) -> tuple[str | None, Any]:
82 | # Extract just the table portion between <table> and </table>
83 | start = content.find("<table>")
84 | end = content.find("</table>") + 8
85 | if start != -1 and end != -1:
86 | return content[start:end], None
87 | return None, None
88 |
89 |
90 | def process_pdf(pdf_path: str, model: str):
91 | output_path = pdf_path.replace("pdfs", f"outputs/{model}-raw").replace(
92 | ".pdf", ".html"
93 | )
94 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
95 |
96 | try:
97 | base64_image = convert_pdf_to_base64_image(pdf_path)
98 | if "gemini" in model:
99 | html_table = analyze_document_openai_sdk(base64_image, model)
100 | elif "gpt" in model:
101 | html_table = analyze_document_openai_sdk(base64_image, model)
102 | elif "claude" in model:
103 | html_table = analyze_document_anthropic(base64_image, model)
104 | else:
105 | raise ValueError(f"Unknown model: {model}")
106 |
107 | html, _ = parse_gemini_response(html_table)
108 |
109 | if not html:
110 | print(f"Skipping (no HTML found): {pdf_path}")
111 | return pdf_path, None
112 |
113 | with open(output_path, "w") as f:
114 | f.write(html_table)
115 |
116 | return pdf_path, None
117 | except Exception as e:
118 | return pdf_path, str(e)
119 |
120 |
121 | def process_all_pdfs(
122 | pdfs: list[str],
123 | model: Literal[
124 | "gemini-2.0-flash-exp",
125 | "gemini-1.5-pro",
126 | "gemini-1.5-flash",
127 | "gpt-4o-mini",
128 | "gpt-4o",
129 | "claude-3-5-sonnet-latest",
130 | ],
131 | max_workers: int,
132 | ):
133 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
134 | futures = {executor.submit(process_pdf, pdf, model): pdf for pdf in pdfs}
135 |
136 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs")
137 |
138 | for future in as_completed(futures):
139 | pdf_path, error = future.result()
140 | if error:
141 | print(f"Error processing {pdf_path}: {error}")
142 | progress_bar.update(1)
143 |
144 | progress_bar.close()
145 |
146 | print(f"Processed {len(pdfs)} PDFs")
147 |
148 |
149 | if __name__ == "__main__":
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument("--model", type=str, required=True)
152 | parser.add_argument("--num-workers", type=int, default=1)
153 | args = parser.parse_args()
154 |
155 | if "gemini" in args.model:
156 | assert settings.gemini_api_key
157 |
158 | client = OpenAI(
159 | api_key=settings.gemini_api_key,
160 | base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
161 | )
162 | elif "gpt" in args.model:
163 | assert settings.openai_api_key
164 |
165 | client = OpenAI(api_key=settings.openai_api_key)
166 | elif "claude" in args.model:
167 | from anthropic import Anthropic
168 |
169 | assert settings.anthropic_api_key
170 | client = Anthropic(api_key=settings.anthropic_api_key)
171 | else:
172 | raise ValueError(f"Unknown model: {args.model}")
173 |
174 | process_all_pdfs(pdfs, args.model, args.num_workers)
175 |
--------------------------------------------------------------------------------
/providers/reducto.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import asyncio
5 | import aiohttp
6 | import time
7 | from typing import Tuple
8 | from tqdm import tqdm
9 |
10 | base_path = os.path.expanduser("~/data/human_table_benchmark")
11 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True)
12 |
13 | API_KEY = os.environ.get("REDUCTO_API_KEY")
14 | if not API_KEY:
15 | raise ValueError("REDUCTO_API_KEY environment variable is not set")
16 |
17 | REDUCTO_URL = "https://platform.reducto.ai"
18 | HEADERS = {"Authorization": f"Bearer {API_KEY}"}
19 |
20 |
21 | async def upload_pdf(
22 | pdf_path: str, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore
23 | ) -> Tuple[str, str]:
24 | async with semaphore:
25 | try:
26 | upload_url = f"{REDUCTO_URL}/upload"
27 | with open(pdf_path, "rb") as file:
28 | data = aiohttp.FormData()
29 | data.add_field(
30 | "file",
31 | file,
32 | filename=os.path.basename(pdf_path),
33 | content_type="application/pdf",
34 | )
35 |
36 | async with session.post(
37 | upload_url, headers=HEADERS, data=data
38 | ) as response:
39 | if response.status == 200:
40 | result = await response.json()
41 | return pdf_path, result["file_id"]
42 | else:
43 | return (
44 | pdf_path,
45 | f"Error: {response.status}, {await response.text()}",
46 | )
47 | except Exception as e:
48 | return pdf_path, f"Error: {str(e)}"
49 |
50 |
51 | async def parse_async(
52 | file_id: str,
53 | pdf_path: str,
54 | session: aiohttp.ClientSession,
55 | semaphore: asyncio.Semaphore,
56 | ) -> Tuple[str, str]:
57 | async with semaphore:
58 | try:
59 | parse_url = f"{REDUCTO_URL}/parse_async"
60 | payload = {
61 | "document_url": f"{file_id}",
62 | "advanced_options": {
63 | "ocr_system": "combined",
64 | },
65 | }
66 |
67 | async with session.post(
68 | parse_url, headers=HEADERS, json=payload
69 | ) as response:
70 | if response.status == 200:
71 | result = await response.json()
72 | return pdf_path, result["job_id"]
73 | else:
74 | return (
75 | pdf_path,
76 | f"Error: {response.status}, {await response.text()}",
77 | )
78 | except Exception as e:
79 | return pdf_path, f"Error: {str(e)}"
80 |
81 |
82 | async def poll_job(
83 | file_id: str, job_id: str, session: aiohttp.ClientSession
84 | ) -> Tuple[str, dict]:
85 | job_url = f"{REDUCTO_URL}/job/{job_id}"
86 | start_time = time.time()
87 | while time.time() - start_time < 60 * 10: # 10 minutes timeout
88 | try:
89 | async with session.get(job_url, headers=HEADERS) as response:
90 | if response.status == 200:
91 | job_info = await response.json()
92 | if job_info["status"] == "Completed":
93 | return file_id, job_info["result"]
94 | elif job_info["status"] == "Failed":
95 | return file_id, {"error": f"Job failed: {job_info}"}
96 | else:
97 | return file_id, {
98 | "error": f"Error polling job: {response.status}, {await response.text()}"
99 | }
100 | except Exception as e:
101 | return file_id, {"error": str(e)}
102 | await asyncio.sleep(5) # Wait 5 seconds before polling again
103 | return file_id, {"error": "Timeout: Job did not complete within 10 minutes"}
104 |
105 |
106 | async def process_pdf(
107 | pdf_path: str,
108 | session: aiohttp.ClientSession,
109 | upload_semaphore: asyncio.Semaphore,
110 | parse_semaphore: asyncio.Semaphore,
111 | ) -> Tuple[str, dict]:
112 | pdf_path, file_id = await upload_pdf(pdf_path, session, upload_semaphore)
113 | if file_id.startswith("Error"):
114 | return pdf_path, {"error": file_id}
115 |
116 | file_id, job_id = await parse_async(file_id, pdf_path, session, parse_semaphore)
117 | if job_id.startswith("Error"):
118 | return pdf_path, {"error": job_id}
119 |
120 | return await poll_job(pdf_path, job_id, session)
121 |
122 |
123 | async def process_all_pdfs(pdfs: list[str]):
124 | upload_semaphore = asyncio.Semaphore(10)
125 | parse_semaphore = asyncio.Semaphore(10)
126 |
127 | async with aiohttp.ClientSession() as session:
128 | tasks = [
129 | process_pdf(pdf, session, upload_semaphore, parse_semaphore) for pdf in pdfs
130 | ]
131 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs")
132 | results = []
133 | for task in asyncio.as_completed(tasks):
134 | pdf_path, result = await task
135 | results.append((pdf_path, result))
136 | progress_bar.update(1)
137 | progress_bar.close()
138 |
139 | for pdf_path, result in results:
140 | output_path = pdf_path.replace("pdfs", "reducto_nov1").replace(".pdf", ".json")
141 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
142 | with open(output_path, "w") as f:
143 | json.dump(result, f, indent=2)
144 |
145 | print(f"Processed {len(results)} PDFs")
146 |
147 |
148 | async def main():
149 | await process_all_pdfs(pdfs)
150 |
151 |
152 | if __name__ == "__main__":
153 | asyncio.run(main())
154 |
--------------------------------------------------------------------------------
/providers/textract.py:
--------------------------------------------------------------------------------
1 | from pdf2image import convert_from_path
2 | from textractor import Textractor
3 | from textractor.data.constants import TextractFeatures
4 | import re
5 | import os
6 | import glob
7 | import json
8 | from concurrent.futures import ThreadPoolExecutor, as_completed
9 | from tqdm import tqdm
10 | import time
11 | import backoff
12 |
13 | extractor = Textractor(region_name="us-west-2")
14 |
15 | base_path = os.path.expanduser("~/data/human_table_benchmark")
16 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True)
17 |
18 |
19 | @backoff.on_exception(
20 | backoff.expo, Exception, giveup=lambda e: "limit" not in str(e).lower(), max_tries=5
21 | )
22 | def extract_tables(pdf_path: str):
23 | start_time = time.time()
24 |
25 | image = convert_from_path(pdf_path)[0]
26 |
27 | document = extractor.analyze_document(
28 | file_source=image,
29 | features=[TextractFeatures.TABLES],
30 | )
31 |
32 | if len(document.tables) == 0:
33 | return None, time.time() - start_time
34 |
35 | html_table = document.tables[0].to_html()
36 |
37 | html_table = re.sub(r"<caption>.*?</caption>", "", html_table, flags=re.DOTALL)
38 |
39 | end_time = time.time()
40 | processing_time = end_time - start_time
41 |
42 | return html_table, processing_time
43 |
44 |
45 | def process_pdf(pdf_path: str):
46 | output_path = pdf_path.replace("pdfs", "textract").replace(".pdf", ".json")
47 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
48 |
49 | try:
50 | html_table, processing_time = extract_tables(pdf_path)
51 |
52 | if html_table is None:
53 | result = {"html_table": None, "processing_time": processing_time}
54 | else:
55 | result = {"html_table": html_table, "processing_time": processing_time}
56 |
57 | with open(output_path, "w") as f:
58 | json.dump(result, f, indent=2)
59 |
60 | return pdf_path, None
61 | except Exception as e:
62 | return pdf_path, str(e)
63 |
64 |
65 | def process_all_pdfs(pdfs: list[str]):
66 | max_workers = 20
67 |
68 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
69 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs}
70 |
71 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs")
72 |
73 | for future in as_completed(futures):
74 | image_path, error = future.result()
75 | if error:
76 | print(f"Error processing {image_path}: {error}")
77 | progress_bar.update(1)
78 |
79 | progress_bar.close()
80 |
81 | print(f"Processed {len(pdfs)} PDFs")
82 |
83 |
84 | if __name__ == "__main__":
85 | process_all_pdfs(pdfs)
86 |
--------------------------------------------------------------------------------
/providers/unstructured.py:
--------------------------------------------------------------------------------
1 | # Before calling the API, replace filename and ensure sdk is installed: "pip install unstructured-client"
2 | # See https://docs.unstructured.io/api-reference/api-services/sdk for more details
3 |
4 | import os
5 | import glob
6 | import json
7 | from concurrent.futures import ThreadPoolExecutor, as_completed
8 | from tqdm import tqdm
9 | import unstructured_client
10 | from unstructured_client.models import operations, shared
11 | import backoff
12 | import time
13 |
14 | # Replace with your actual API key
15 | API_KEY = os.environ["UNSTRUCTURED_API_KEY"]
16 | SERVER_URL = "https://api.unstructuredapp.io"
17 |
18 | client = unstructured_client.UnstructuredClient(
19 | api_key_auth=API_KEY,
20 | server_url=SERVER_URL,
21 | )
22 |
23 | base_path = os.path.expanduser("~/data/human_table_benchmark")
24 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True)
25 |
26 |
27 | @backoff.on_exception(backoff.expo, Exception, max_tries=3)
28 | def process_single_file(file_path: str):
29 | with open(file_path, "rb") as f:
30 | data = f.read()
31 |
32 | start_time = time.time()
33 | req = operations.PartitionRequest(
34 | partition_parameters=shared.PartitionParameters(
35 | files=shared.Files(
36 | content=data,
37 | file_name=os.path.basename(file_path),
38 | ),
39 | strategy=shared.Strategy.HI_RES,
40 | ),
41 | )
42 |
43 | res = client.general.partition(request=req)
44 | processing_time = time.time() - start_time
45 | return {"elements": res.elements, "processing_time": processing_time}
46 |
47 |
48 | def process_pdf(pdf_path: str):
49 | output_path = pdf_path.replace("pdfs", "unstructured").replace(".pdf", ".json")
50 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
51 |
52 | try:
53 | result = process_single_file(pdf_path)
54 | with open(output_path, "w") as f:
55 | json.dump(result, f, indent=2)
56 | except Exception as e:
57 | print(f"Error processing {pdf_path}: {str(e)}")
58 |
59 |
60 | def process_all_pdfs(pdfs: list[str]):
61 | max_workers = 10 # Adjust this based on API rate limits
62 |
63 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
64 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs}
65 |
66 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs")
67 |
68 | for _ in as_completed(futures):
69 | progress_bar.update(1)
70 |
71 | progress_bar.close()
72 |
73 | print(f"Processed {len(pdfs)} PDFs")
74 |
75 |
76 | def test_single_file(file_path: str):
77 | try:
78 | result = process_single_file(file_path)
79 | print(result)
80 | except Exception as e:
81 | print(f"Error processing {file_path}: {str(e)}")
82 |
83 |
84 | if __name__ == "__main__":
85 | process_all_pdfs(pdfs)
86 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohappyeyeballs==2.4.4
2 | aiohttp==3.11.11
3 | aiolimiter==1.2.1
4 | aiosignal==1.3.2
5 | annotated-types==0.7.0
6 | anthropic==0.43.0
7 | anyio==4.8.0
8 | attrs==24.3.0
9 | backoff==2.2.1
10 | cachetools==5.5.0
11 | certifi==2024.12.14
12 | charset-normalizer==3.4.1
13 | click==8.1.8
14 | distro==1.9.0
15 | docstring-parser==0.16
16 | frozenlist==1.5.0
17 | google-ai-generativelanguage==0.6.10
18 | google-api-core==2.24.0
19 | google-api-python-client==2.158.0
20 | google-auth==2.37.0
21 | google-auth-httplib2==0.2.0
22 | google-cloud-aiplatform==1.76.0
23 | google-cloud-bigquery==3.27.0
24 | google-cloud-core==2.4.1
25 | google-cloud-resource-manager==1.14.0
26 | google-cloud-storage==2.19.0
27 | google-crc32c==1.6.0
28 | google-generativeai==0.8.3
29 | google-resumable-media==2.7.2
30 | googleapis-common-protos==1.66.0
31 | grpc-google-iam-v1==0.14.0
32 | grpcio==1.69.0
33 | grpcio-status==1.69.0
34 | h11==0.14.0
35 | httpcore==1.0.7
36 | httplib2==0.22.0
37 | httpx==0.28.1
38 | idna==3.10
39 | instructor==1.7.2
40 | jinja2==3.1.5
41 | jiter==0.8.2
42 | jsonref==1.1.0
43 | levenshtein==0.26.1
44 | lxml==5.3.0
45 | markdown-it-py==3.0.0
46 | markupsafe==3.0.2
47 | mdurl==0.1.2
48 | multidict==6.1.0
49 | numpy==2.2.1
50 | openai==1.59.7
51 | packaging==24.2
52 | pdf2image==1.17.0
53 | pillow==11.1.0
54 | polars==1.19.0
55 | propcache==0.2.1
56 | proto-plus==1.25.0
57 | protobuf==5.29.3
58 | pyasn1==0.6.1
59 | pyasn1-modules==0.4.1
60 | pydantic==2.10.5
61 | pydantic-core==2.27.2
62 | pygments==2.19.1
63 | pyparsing==3.2.1
64 | python-dateutil==2.9.0.post0
65 | python-dotenv==1.0.1
66 | rapidfuzz==3.11.0
67 | requests==2.32.3
68 | rich==13.9.4
69 | rsa==4.9
70 | shapely==2.0.6
71 | shellingham==1.5.4
72 | six==1.17.0
73 | sniffio==1.3.1
74 | tenacity==9.0.0
75 | tqdm==4.67.1
76 | typer==0.15.1
77 | typing-extensions==4.12.2
78 | uritemplate==4.1.1
79 | urllib3==2.3.0
80 | yarl==1.18.3
81 |
--------------------------------------------------------------------------------