├── .gitignore ├── README.md ├── analyze_failures.ipynb ├── convert.py ├── grade_cli.py ├── grading.py ├── parsing.py ├── providers ├── README.md ├── __init__.py ├── azure_docintelligence.py ├── chunkr.py ├── config.py ├── gcloud.py ├── llm.py ├── reducto.py ├── textract.py └── unstructured.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | *.pyc 3 | *.csv 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RD Table Bench Invocation + Grading Code 2 | 3 | This repo contains the code for invoking each provider and grading the results. This is a fork of Reducto's original repo https://github.com/reductoai/rd-tablebench with improved support for different LLM providers (OpenAI, Anthropic, Gemini). 4 | 5 | The proprietary providers that Reducto implemented have not been touched and will not work with the grading CLI. 6 | 7 | ## Installing Dependencies 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Downloading Data 14 | 15 | https://huggingface.co/datasets/reducto/rd-tablebench/blob/main/README.md 16 | 17 | ## Env Vars 18 | 19 | Create a `.env` file with the following: 20 | 21 | ``` 22 | # directory where the huggingface dataset is downloaded 23 | INPUT_DIR= 24 | 25 | # directory where the output will be saved 26 | OUTPUT_DIR= 27 | 28 | # note: only need keys for providers you want to use 29 | OPENAI_API_KEY= 30 | GEMINI_API_KEY= 31 | ANTHROPIC_API_KEY= 32 | ... 33 | ``` 34 | 35 | ## Parsing 36 | 37 | `python -m providers.llm --model gemini-2.0-flash-exp --num-workers 10` 38 | 39 | ## Grading 40 | 41 | `python -m grade_cli --model gemini-2.0-flash-exp` 42 | -------------------------------------------------------------------------------- /analyze_failures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 34, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "polars.config.Config" 12 | ] 13 | }, 14 | "execution_count": 34, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "import polars as pl\n", 21 | "from pathlib import Path\n", 22 | "from providers.config import settings\n", 23 | "import webbrowser\n", 24 | "import random\n", 25 | "\n", 26 | "pl.Config.set_fmt_str_lengths(200)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 16, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/html": [ 37 | "
\n", 44 | "shape: (662, 7)
filenamegemini-2.0-flash-expgemini-1.5-progemini-1.5-flashgpt-4o-minigpt-4oclaude-3-5-sonnet-latest
strf64f64f64f64f64f64
"31365_png.rf.c15a4717175cda592…null1.00.720.790.831.0
"3066_png.rf.a271ccaade01ec4a7e…1.00.91.01.01.01.0
"42405_png.rf.29f7f349b16be7e94…null0.880.870.75null0.71
"6464_png.rf.87df45257826205d38…null1.00.660.82nullnull
"14587_png.rf.8ac51dfae9c3323be…nullnull0.520.22nullnull
"4369_png.rf.22210329dce81c468f…null0.94null0.330.320.62
"20008_png.rf.11bd3ea6ad0610c46…null0.85null0.260.020.01
"3794_png.rf.f3b9a9fce3f6f5b4e4…null0.58null0.08null0.5
"29352_png.rf.262ea50e23c787b4f…null0.78null0.7null0.49
"999_png.rf.db81a5df0db1f0c4854…nullnullnull0.21nullnull
" 45 | ], 46 | "text/plain": [ 47 | "shape: (662, 7)\n", 48 | "┌───────────────┬──────────────┬──────────────┬──────────────┬─────────────┬────────┬──────────────┐\n", 49 | "│ filename ┆ gemini-2.0-f ┆ gemini-1.5-p ┆ gemini-1.5-f ┆ gpt-4o-mini ┆ gpt-4o ┆ claude-3-5-s │\n", 50 | "│ --- ┆ lash-exp ┆ ro ┆ lash ┆ --- ┆ --- ┆ onnet-latest │\n", 51 | "│ str ┆ --- ┆ --- ┆ --- ┆ f64 ┆ f64 ┆ --- │\n", 52 | "│ ┆ f64 ┆ f64 ┆ f64 ┆ ┆ ┆ f64 │\n", 53 | "╞═══════════════╪══════════════╪══════════════╪══════════════╪═════════════╪════════╪══════════════╡\n", 54 | "│ 31365_png.rf. ┆ null ┆ 1.0 ┆ 0.72 ┆ 0.79 ┆ 0.83 ┆ 1.0 │\n", 55 | "│ c15a4717175cd ┆ ┆ ┆ ┆ ┆ ┆ │\n", 56 | "│ a592… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 57 | "│ 3066_png.rf.a ┆ 1.0 ┆ 0.9 ┆ 1.0 ┆ 1.0 ┆ 1.0 ┆ 1.0 │\n", 58 | "│ 271ccaade01ec ┆ ┆ ┆ ┆ ┆ ┆ │\n", 59 | "│ 4a7e… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 60 | "│ 42405_png.rf. ┆ null ┆ 0.88 ┆ 0.87 ┆ 0.75 ┆ null ┆ 0.71 │\n", 61 | "│ 29f7f349b16be ┆ ┆ ┆ ┆ ┆ ┆ │\n", 62 | "│ 7e94… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 63 | "│ 6464_png.rf.8 ┆ null ┆ 1.0 ┆ 0.66 ┆ 0.82 ┆ null ┆ null │\n", 64 | "│ 7df4525782620 ┆ ┆ ┆ ┆ ┆ ┆ │\n", 65 | "│ 5d38… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 66 | "│ 14587_png.rf. ┆ null ┆ null ┆ 0.52 ┆ 0.22 ┆ null ┆ null │\n", 67 | "│ 8ac51dfae9c33 ┆ ┆ ┆ ┆ ┆ ┆ │\n", 68 | "│ 23be… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 69 | "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", 70 | "│ 4369_png.rf.2 ┆ null ┆ 0.94 ┆ null ┆ 0.33 ┆ 0.32 ┆ 0.62 │\n", 71 | "│ 2210329dce81c ┆ ┆ ┆ ┆ ┆ ┆ │\n", 72 | "│ 468f… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 73 | "│ 20008_png.rf. ┆ null ┆ 0.85 ┆ null ┆ 0.26 ┆ 0.02 ┆ 0.01 │\n", 74 | "│ 11bd3ea6ad061 ┆ ┆ ┆ ┆ ┆ ┆ │\n", 75 | "│ 0c46… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 76 | "│ 3794_png.rf.f ┆ null ┆ 0.58 ┆ null ┆ 0.08 ┆ null ┆ 0.5 │\n", 77 | "│ 3b9a9fce3f6f5 ┆ ┆ ┆ ┆ ┆ ┆ │\n", 78 | "│ b4e4… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 79 | "│ 29352_png.rf. ┆ null ┆ 0.78 ┆ null ┆ 0.7 ┆ null ┆ 0.49 │\n", 80 | "│ 262ea50e23c78 ┆ ┆ ┆ ┆ ┆ ┆ │\n", 81 | "│ 7b4f… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 82 | "│ 999_png.rf.db ┆ null ┆ null ┆ null ┆ 0.21 ┆ null ┆ null │\n", 83 | "│ 81a5df0db1f0c ┆ ┆ ┆ ┆ ┆ ┆ │\n", 84 | "│ 4854… ┆ ┆ ┆ ┆ ┆ ┆ │\n", 85 | "└───────────────┴──────────────┴──────────────┴──────────────┴─────────────┴────────┴──────────────┘" 86 | ] 87 | }, 88 | "execution_count": 16, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "models = [\n", 95 | " \"gemini-2.0-flash-exp\",\n", 96 | " \"gemini-1.5-pro\",\n", 97 | " \"gemini-1.5-flash\",\n", 98 | " \"gpt-4o-mini\",\n", 99 | " \"gpt-4o\",\n", 100 | " \"claude-3-5-sonnet-latest\",\n", 101 | "]\n", 102 | "\n", 103 | "dfs = []\n", 104 | "for model in models:\n", 105 | " df = pl.read_csv(f\"./scores/{model}_scores.csv\")\n", 106 | " dfs.append(df.rename({\"score\": model}))\n", 107 | "\n", 108 | "merged_df = dfs[0]\n", 109 | "for df in dfs[1:]:\n", 110 | " merged_df = merged_df.join(df, on=\"filename\", how=\"full\", coalesce=True)\n", 111 | "\n", 112 | "merged_df = merged_df.with_columns(pl.col(pl.Float64).round(2))\n", 113 | "merged_df" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 74, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "36278_png.rf.745f6e6efb75dd8a2247cad732e53033.html\n", 126 | "None\n", 127 | "/Users/sergey/Downloads/rd-tablebench/outputs/gpt-4o-mini-raw/36278_png.rf.745f6e6efb75dd8a2247cad732e53033.html\n", 128 | "/Users/sergey/Downloads/rd-tablebench/groundtruth/36278_png.rf.745f6e6efb75dd8a2247cad732e53033.html\n", 129 | "/Users/sergey/Downloads/rd-tablebench/pdfs/36278_png.rf.745f6e6efb75dd8a2247cad732e53033.pdf\n" 130 | ] 131 | } 132 | 
], 133 | "source": [ 134 | "def open_output_file(filename: str, model: str):\n", 135 | " url = settings.output_dir / f\"{model}-raw\" / filename\n", 136 | " assert url.exists(), f\"File {url} does not exist\"\n", 137 | " print(url)\n", 138 | "\n", 139 | " webbrowser.open(str(url))\n", 140 | "\n", 141 | "\n", 142 | "def open_ground_truth_file(filename: str):\n", 143 | " url = settings.input_dir / \"groundtruth\" / filename\n", 144 | " assert url.exists(), f\"File {url} does not exist\"\n", 145 | " print(url)\n", 146 | " webbrowser.open(str(url))\n", 147 | "\n", 148 | "\n", 149 | "def open_source_file(filename: str):\n", 150 | " url = settings.input_dir / \"pdfs\" / filename.replace(\".html\", \".pdf\")\n", 151 | " assert url.exists(), f\"File {url} does not exist\"\n", 152 | " print(url)\n", 153 | " webbrowser.open(str(url))\n", 154 | "\n", 155 | "\n", 156 | "random_filename = random.choice(merged_df.to_dicts())\n", 157 | "print(random_filename[\"filename\"])\n", 158 | "print(random_filename[\"claude-3-5-sonnet-latest\"])\n", 159 | "open_output_file(random_filename[\"filename\"], \"gpt-4o-mini\")\n", 160 | "open_ground_truth_file(random_filename[\"filename\"])\n", 161 | "open_source_file(random_filename[\"filename\"])\n" 162 | ] 163 | } 164 | ], 165 | "metadata": { 166 | "kernelspec": { 167 | "display_name": ".venv", 168 | "language": "python", 169 | "name": "python3" 170 | }, 171 | "language_info": { 172 | "codemirror_mode": { 173 | "name": "ipython", 174 | "version": 3 175 | }, 176 | "file_extension": ".py", 177 | "mimetype": "text/x-python", 178 | "name": "python", 179 | "nbconvert_exporter": "python", 180 | "pygments_lexer": "ipython3", 181 | "version": "3.12.5" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 2 186 | } 187 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.typing as npt 3 | from lxml import etree 4 | 5 | 6 | def html_to_numpy(html_string: str) -> npt.NDArray[np.str_]: 7 | dom_tree = etree.HTML(html_string, parser=etree.HTMLParser()) 8 | table_rows: list[list[str]] = [] 9 | span_info: dict[int, tuple[str, int]] = {} 10 | 11 | for table_row in dom_tree.xpath("//tr"): 12 | current_row: list[str] = [] 13 | column_index = 0 14 | 15 | while span_info.get(column_index, (None, 0))[1] > 0: 16 | current_row.append(span_info[column_index][0]) 17 | span_info[column_index] = ( 18 | span_info[column_index][0], 19 | span_info[column_index][1] - 1, 20 | ) 21 | if span_info[column_index][1] == 0: 22 | del span_info[column_index] 23 | column_index += 1 24 | 25 | for table_cell in table_row.xpath("td|th"): 26 | while span_info.get(column_index, (None, 0))[1] > 0: 27 | current_row.append(span_info[column_index][0]) 28 | span_info[column_index] = ( 29 | span_info[column_index][0], 30 | span_info[column_index][1] - 1, 31 | ) 32 | if span_info[column_index][1] == 0: 33 | del span_info[column_index] 34 | column_index += 1 35 | 36 | row_span = int(table_cell.get("rowspan", "1")) 37 | col_span = int(table_cell.get("colspan", "1")) 38 | cell_text = "".join(table_cell.itertext()).strip() 39 | 40 | if row_span > 1: 41 | for i in range(col_span): 42 | span_info[column_index + i] = (cell_text, row_span - 1) 43 | 44 | for _ in range(col_span): 45 | current_row.append(cell_text) 46 | column_index += col_span 47 | 48 | while span_info.get(column_index, (None, 0))[1] > 0: 49 | current_row.append(span_info[column_index][0]) 50 | 
span_info[column_index] = ( 51 | span_info[column_index][0], 52 | span_info[column_index][1] - 1, 53 | ) 54 | if span_info[column_index][1] == 0: 55 | del span_info[column_index] 56 | column_index += 1 57 | 58 | table_rows.append(current_row) 59 | 60 | max_columns = max(map(len, table_rows)) if table_rows else 0 61 | for row in table_rows: 62 | row.extend([""] * (max_columns - len(row))) 63 | 64 | return np.array(table_rows) 65 | -------------------------------------------------------------------------------- /grade_cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | 5 | from convert import html_to_numpy 6 | from grading import table_similarity 7 | import polars as pl 8 | 9 | from providers.config import settings 10 | 11 | 12 | def main(model: str, folder: str, save_to_csv: bool): 13 | groundtruth = settings.input_dir / "groundtruth" 14 | scores = [] 15 | 16 | html_files = glob.glob(os.path.join(folder, "*.html")) 17 | for pred_html_path in html_files: 18 | # The filename might be something like: 10035_png.rf.07e8e5bf2e9ad4e77a84fd38d1f53f38.html 19 | base_name = os.path.basename(pred_html_path) 20 | 21 | # Build the path to the corresponding ground-truth file 22 | gt_html_path = os.path.join(groundtruth, base_name) 23 | if not os.path.exists(gt_html_path): 24 | continue 25 | 26 | with open(pred_html_path, "r") as f: 27 | pred_html = f.read() 28 | 29 | with open(gt_html_path, "r") as f: 30 | gt_html = f.read() 31 | 32 | # Convert HTML -> NumPy arrays 33 | try: 34 | pred_array = html_to_numpy(pred_html) 35 | gt_array = html_to_numpy(gt_html) 36 | 37 | # Compute similarity (0.0 to 1.0) 38 | score = table_similarity(gt_array, pred_array) 39 | except Exception as e: 40 | print(f"Error converting {base_name}: {e}") 41 | continue 42 | 43 | scores.append((base_name, score)) 44 | print(f"{base_name}: {score:.4f}") 45 | 46 | score_dicts = [{"filename": fname, "score": scr} for fname, scr in scores] 47 | df = pl.DataFrame(score_dicts) 48 | print( 49 | f"Average score for {model}: {df['score'].mean():.2f} with std {df['score'].std():.2f}" 50 | ) 51 | if save_to_csv: 52 | df.write_csv(f"./scores/{model}_scores.csv") 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--model", type=str, required=True) 58 | parser.add_argument("--save-to-csv", type=bool, default=True) 59 | args = parser.parse_args() 60 | 61 | model_dir = settings.output_dir / f"{args.model}-raw" 62 | assert model_dir.exists(), f"Model directory {model_dir} does not exist" 63 | main(args.model, model_dir, args.save_to_csv) 64 | -------------------------------------------------------------------------------- /grading.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.typing as npt 3 | from Levenshtein import distance as levenshtein_distance 4 | 5 | 6 | BATCH_SIZE = 150 7 | 8 | # Scoring parameters (you can adjust these as needed) 9 | S_ROW_MATCH = 5 # Match score for row alignment 10 | G_ROW = -3 # Gap penalty for row alignment (insertion/deletion of a row) 11 | S_CELL_MATCH = 1 # Match score for cell matching 12 | P_CELL_MISMATCH = -1 # Penalty for cell mismatch 13 | G_COL = -1 # Gap penalty for column alignment 14 | 15 | 16 | def cell_match_score(cell1: str | None, cell2: str | None) -> float: 17 | """Compute the match score between two cells considering partial matches.""" 18 | if cell1 is None or cell2 is None: 19 | 
return P_CELL_MISMATCH # Penalty for gaps or mismatches 20 | if cell1 == cell2: 21 | return S_CELL_MATCH # Cells are identical 22 | 23 | # Compute the Levenshtein distance using the optimized library 24 | distance = levenshtein_distance(cell1, cell2) 25 | max_len = max(len(cell1), len(cell2)) 26 | if max_len == 0: 27 | normalized_distance = 0.0 # Both cells are empty strings 28 | else: 29 | normalized_distance = distance / max_len 30 | similarity = 1.0 - normalized_distance # Similarity between 0 and 1 31 | match_score = P_CELL_MISMATCH + similarity * (S_CELL_MATCH - P_CELL_MISMATCH) 32 | return match_score 33 | 34 | 35 | def needleman_wunsch( 36 | seq1: list[str], seq2: list[str], gap_penalty: int 37 | ) -> tuple[list[str | None], list[str | None], float]: 38 | """ 39 | Perform Needleman-Wunsch alignment between two sequences with free end gaps. 40 | 41 | Parameters: 42 | seq1, seq2: sequences to align (lists of strings) 43 | gap_penalty: penalty for gaps (insertions/deletions) 44 | 45 | Returns: 46 | alignment_a, alignment_b: aligned sequences with gaps represented by None 47 | score: total alignment score 48 | """ 49 | m = len(seq1) 50 | n = len(seq2) 51 | 52 | # Initialize the scoring matrix 53 | score_matrix = np.zeros((m + 1, n + 1), dtype=np.float32) 54 | traceback = np.full((m + 1, n + 1), None) 55 | 56 | # Initialize the first row and column (no gap penalties for leading gaps) 57 | for i in range(1, m + 1): 58 | traceback[i, 0] = "up" 59 | for j in range(1, n + 1): 60 | traceback[0, j] = "left" 61 | 62 | # Fill the rest of the matrix 63 | for i in range(1, m + 1): 64 | seq1_i = seq1[i - 1] 65 | for j in range(1, n + 1): 66 | seq2_j = seq2[j - 1] 67 | match = score_matrix[i - 1, j - 1] + cell_match_score(seq1_i, seq2_j) 68 | delete = score_matrix[i - 1, j] + gap_penalty 69 | insert = score_matrix[i, j - 1] + gap_penalty 70 | max_score = max(match, delete, insert) 71 | score_matrix[i, j] = max_score 72 | if max_score == match: 73 | traceback[i, j] = "diag" 74 | elif max_score == delete: 75 | traceback[i, j] = "up" 76 | else: 77 | traceback[i, j] = "left" 78 | 79 | # Traceback from the position with the highest score in the last row or column 80 | i, j = m, n 81 | max_score = score_matrix[i, j] 82 | max_i, max_j = i, j 83 | # Find the maximum score in the last row and column for free end gaps 84 | last_row = score_matrix[:, n] 85 | last_col = score_matrix[m, :] 86 | if last_row.max() > max_score: 87 | max_i = last_row.argmax() 88 | max_j = n 89 | max_score = last_row[max_i] 90 | if last_col.max() > max_score: 91 | max_i = m 92 | max_j = last_col.argmax() 93 | max_score = last_col[max_j] 94 | 95 | # Traceback to get the aligned sequences 96 | alignment_a: list[str | None] = [] 97 | alignment_b: list[str | None] = [] 98 | i, j = max_i, max_j 99 | while i > 0 or j > 0: 100 | tb_direction = traceback[i, j] 101 | if i > 0 and j > 0 and tb_direction == "diag": 102 | alignment_a.insert(0, seq1[i - 1]) 103 | alignment_b.insert(0, seq2[j - 1]) 104 | i -= 1 105 | j -= 1 106 | elif i > 0 and (j == 0 or tb_direction == "up"): 107 | alignment_a.insert(0, seq1[i - 1]) 108 | alignment_b.insert(0, None) # Gap in seq2 109 | i -= 1 110 | elif j > 0 and (i == 0 or tb_direction == "left"): 111 | alignment_a.insert(0, None) # Gap in seq1 112 | alignment_b.insert(0, seq2[j - 1]) 113 | j -= 1 114 | else: 115 | break # Should not reach here 116 | 117 | return alignment_a, alignment_b, max_score 118 | 119 | 120 | def table_similarity( 121 | ground_truth: npt.NDArray[np.str_], prediction: npt.NDArray[np.str_] 
122 | ) -> float: 123 | """ 124 | Compute the similarity between two tables represented as ndarrays of strings, 125 | allowing for a subset of rows at the top or bottom without penalization (to avoid penalizing subtable cropping). 126 | 127 | Parameters: 128 | ground_truth, prediction: ndarrays of strings representing the tables 129 | 130 | Returns: 131 | similarity: similarity score between 0 and 1 132 | """ 133 | 134 | # Remove newlines and normalize whitespace in cells 135 | def normalize_cell(cell: str) -> str: 136 | return "".join(cell.replace("\n", " ").replace("-", "").split()).replace( 137 | " ", "" 138 | ) 139 | 140 | # Apply normalization to both ground truth and prediction arrays 141 | vectorized_normalize = np.vectorize(normalize_cell) 142 | ground_truth = vectorized_normalize(ground_truth) 143 | prediction = vectorized_normalize(prediction) 144 | 145 | # Convert to lists of lists for easier manipulation 146 | gt_rows = [list(row) for row in ground_truth] 147 | pred_rows = [list(row) for row in prediction] 148 | 149 | # Precompute the column alignment scores between all pairs of rows 150 | m = len(gt_rows) 151 | n = len(pred_rows) 152 | row_match_scores = np.zeros((m, n), dtype=np.float32) 153 | 154 | for i in range(m): 155 | gt_row = gt_rows[i] 156 | for j in range(n): 157 | pred_row = pred_rows[j] 158 | # Align columns of the two rows 159 | _, _, col_score = needleman_wunsch(gt_row, pred_row, G_COL) 160 | # Adjusted row match score 161 | row_match_scores[i, j] = col_score + S_ROW_MATCH 162 | 163 | # Initialize the scoring matrix for row alignment with free end gaps 164 | score_matrix = np.zeros((m + 1, n + 1), dtype=np.float32) 165 | traceback = np.full((m + 1, n + 1), None) 166 | 167 | # No gap penalties for leading gaps 168 | for i in range(1, m + 1): 169 | traceback[i, 0] = "up" 170 | for j in range(1, n + 1): 171 | traceback[0, j] = "left" 172 | 173 | # Fill the rest of the scoring matrix 174 | for i in range(1, m + 1): 175 | for j in range(1, n + 1): 176 | match = score_matrix[i - 1, j - 1] + row_match_scores[i - 1, j - 1] 177 | delete = score_matrix[i - 1, j] + G_ROW 178 | insert = score_matrix[i, j - 1] + G_ROW 179 | max_score = max(match, delete, insert) 180 | score_matrix[i, j] = max_score 181 | if max_score == match: 182 | traceback[i, j] = "diag" 183 | elif max_score == delete: 184 | traceback[i, j] = "up" 185 | else: 186 | traceback[i, j] = "left" 187 | 188 | # Traceback from the position with the highest score in the last row or column 189 | i, j = m, n 190 | max_score = score_matrix[i, j] 191 | max_i, max_j = i, j 192 | # Find the maximum score in the last row and column for free end gaps 193 | last_row = score_matrix[:, n] 194 | last_col = score_matrix[m, :] 195 | if last_row.max() > max_score: 196 | max_i = last_row.argmax() 197 | max_j = n 198 | max_score = last_row[max_i] 199 | if last_col.max() > max_score: 200 | max_i = m 201 | max_j = last_col.argmax() 202 | max_score = last_col[max_j] 203 | 204 | # Traceback to get the aligned rows 205 | alignment_gt_rows: list[list[str | None]] = [] 206 | alignment_pred_rows: list[list[str | None]] = [] 207 | i, j = max_i, max_j 208 | while i > 0 or j > 0: 209 | tb_direction = traceback[i, j] 210 | if i > 0 and j > 0 and tb_direction == "diag": 211 | alignment_gt_rows.insert(0, gt_rows[i - 1]) 212 | alignment_pred_rows.insert(0, pred_rows[j - 1]) 213 | i -= 1 214 | j -= 1 215 | elif i > 0 and (j == 0 or tb_direction == "up"): 216 | alignment_gt_rows.insert(0, gt_rows[i - 1]) 217 | alignment_pred_rows.insert( 218 | 0, 
[None] * len(gt_rows[i - 1]) 219 | ) # Gap in prediction 220 | i -= 1 221 | elif j > 0 and (i == 0 or tb_direction == "left"): 222 | alignment_gt_rows.insert( 223 | 0, [None] * len(pred_rows[j - 1]) 224 | ) # Gap in ground truth 225 | alignment_pred_rows.insert(0, pred_rows[j - 1]) 226 | j -= 1 227 | else: 228 | break # Should not reach here 229 | 230 | # Compute the actual total score 231 | actual_total_score = max_score 232 | 233 | # Compute the total possible score 234 | num_aligned_rows = len(alignment_gt_rows) 235 | if num_aligned_rows == 0: 236 | return 0.0 # Avoid division by zero 237 | max_row_score = num_aligned_rows * (S_ROW_MATCH + len(gt_rows[0]) * S_CELL_MATCH) 238 | total_possible_score = max_row_score 239 | 240 | # Normalize the similarity score 241 | similarity = actual_total_score / total_possible_score 242 | return max(0.0, min(similarity, 1.0)) 243 | -------------------------------------------------------------------------------- /parsing.py: -------------------------------------------------------------------------------- 1 | """ 2 | For each format, this code extracts the largest HTML table from the response. 3 | """ 4 | 5 | import json 6 | from typing import Any 7 | import os 8 | 9 | import argparse 10 | import glob 11 | 12 | 13 | def parse_textract_response(path: str) -> tuple[str | None, Any]: 14 | if not os.path.exists(path): 15 | return None, None 16 | 17 | with open(path, "r") as f: 18 | data = json.load(f) 19 | 20 | return data["html_table"], data 21 | 22 | 23 | def parse_gcloud_response(path: str) -> tuple[str | None, Any]: 24 | if not os.path.exists(path): 25 | return None, None 26 | 27 | try: 28 | with open(path, "r") as f: 29 | data = json.load(f) 30 | except Exception: 31 | return None, None 32 | 33 | return data["html_table"], data 34 | 35 | 36 | def parse_reducto_response(path: str) -> tuple[str | None, Any]: 37 | if not os.path.exists(path): 38 | return None, None 39 | 40 | with open(path, "r") as f: 41 | data = json.load(f) 42 | 43 | if "error" in data: 44 | return None, data 45 | 46 | longest_html = None 47 | max_length = 0 48 | 49 | for chunk in data["result"]["chunks"]: 50 | blocks = chunk["blocks"] 51 | for block in blocks: 52 | if block["type"] == "Table": 53 | if len(block["content"]) > max_length: 54 | max_length = len(block["content"]) 55 | longest_html = block["content"] 56 | 57 | return longest_html, data 58 | 59 | 60 | def parse_chunkr_response(path: str) -> tuple[str | None, Any]: 61 | if not os.path.exists(path): 62 | return None, None 63 | 64 | with open(path, "r") as f: 65 | data = json.load(f) 66 | 67 | if data.get("status") != "Succeeded": 68 | return None, data 69 | 70 | largest_html = None 71 | max_length = 0 72 | 73 | try: 74 | for output in ( 75 | data.get("output", []) 76 | if "chunks" not in data.get("output") 77 | else data["output"]["chunks"] 78 | ): 79 | for segment in output.get("segments", []): 80 | if segment.get("segment_type") == "Table" and segment.get("html"): 81 | if len(segment["html"]) > max_length: 82 | max_length = len(segment["html"]) 83 | largest_html = segment["html"] 84 | except Exception: 85 | import traceback 86 | 87 | traceback.print_exc() 88 | print(data) 89 | 90 | return largest_html, data 91 | 92 | 93 | def parse_unstructured_response(path: str) -> tuple[str | None, Any]: 94 | if not os.path.exists(path): 95 | return None, None 96 | 97 | with open(path, "r") as f: 98 | data = json.load(f) 99 | 100 | largest_html = None 101 | max_length = 0 102 | 103 | for element in data.get("elements", []): 104 | if 
element.get("type") == "Table" and element.get("metadata", {}).get( 105 | "text_as_html" 106 | ): 107 | html = element["metadata"]["text_as_html"] 108 | if len(html) > max_length: 109 | max_length = len(html) 110 | largest_html = html 111 | 112 | return largest_html, data 113 | 114 | 115 | def parse_gpt4o_response(path: str) -> tuple[str | None, Any]: 116 | if not os.path.exists(path): 117 | return None, None 118 | 119 | with open(path, "r") as f: 120 | data = json.load(f) 121 | 122 | html = data["html_table"] 123 | # Extract just the table portion between and
124 | start = html.find("<table>") 125 | end = html.find("</table>") + 8 126 | if start != -1 and end != -1: 127 | return html[start:end], data 128 | return None, data 129 | 130 | 131 | def parse_gemini_response(path: str) -> tuple[str | None, Any]: 132 | if not os.path.exists(path): 133 | return None, None 134 | 135 | with open(path, "r") as f: 136 | data = json.load(f) 137 | 138 | html = data["html_table"] 139 | # Extract just the table portion between <table> and </table>
140 | start = html.find("<table>") 141 | end = html.find("</table>") + 8 142 | if start != -1 and end != -1: 143 | return html[start:end], data 144 | return None, data 145 | 146 | 147 | def parse_azure_response(path: str) -> tuple[str | None, Any]: 148 | data = None 149 | try: 150 | with open(path, "r") as f: 151 | data = json.load(f) 152 | 153 | def azure_to_html(table: Any) -> str: 154 | html = "<table>" 155 | for row_index in range(table["rowCount"]): 156 | html += "<tr>" 157 | for col_index in range(table["columnCount"]): 158 | cell = next( 159 | ( 160 | c 161 | for c in table["cells"] 162 | if c["rowIndex"] == row_index 163 | and c["columnIndex"] == col_index 164 | ), 165 | None, 166 | ) 167 | if cell: 168 | content = ( 169 | cell["content"] 170 | .replace(":selected:", "") 171 | .replace(":unselected:", "") 172 | ) 173 | tag = "th" if cell.get("kind") == "columnHeader" else "td" 174 | rowspan = ( 175 | f" rowspan='{cell['rowSpan']}'" if "rowSpan" in cell else "" 176 | ) 177 | colspan = ( 178 | f" colspan='{cell['columnSpan']}'" 179 | if "columnSpan" in cell 180 | else "" 181 | ) 182 | html += f"<{tag}{rowspan}{colspan}>{content}</{tag}>" 183 | else: 184 | pass 185 | html += "</tr>" 186 | html += "</table>
" 187 | return html 188 | 189 | # Find table with largest area (row count * column count) 190 | largest_table = max( 191 | data["tables"], key=lambda t: t["rowCount"] * t["columnCount"] 192 | ) 193 | return azure_to_html(largest_table), data 194 | except Exception: 195 | return None, data 196 | 197 | 198 | PARSERS = { 199 | "textract": parse_textract_response, 200 | "gcloud": parse_gcloud_response, 201 | "reducto": parse_reducto_response, 202 | "chunkr": parse_chunkr_response, 203 | "unstructured": parse_unstructured_response, 204 | "gpt4o": parse_gpt4o_response, 205 | "azure": parse_azure_response, 206 | "gemini": parse_gemini_response, 207 | } 208 | 209 | 210 | def main(): 211 | parser = argparse.ArgumentParser() 212 | parser.add_argument( 213 | "--provider", 214 | required=True, 215 | choices=PARSERS.keys(), 216 | help="Which parser to use (e.g. 'gpt4o', 'azure', etc.).", 217 | ) 218 | parser.add_argument( 219 | "--input-folder", 220 | required=True, 221 | help="Folder containing .json files from the provider.", 222 | ) 223 | parser.add_argument( 224 | "--output-folder", 225 | required=True, 226 | help="Folder to write the extracted .html files.", 227 | ) 228 | args = parser.parse_args() 229 | 230 | # Get the parser function based on the provider 231 | parse_func = PARSERS[args.provider] 232 | 233 | # Ensure output folder exists 234 | os.makedirs(args.output_folder, exist_ok=True) 235 | 236 | # Find all JSON files under input folder (recursively) 237 | json_paths = glob.glob( 238 | os.path.join(args.input_folder, "**", "*.json"), recursive=True 239 | ) 240 | 241 | for json_file in json_paths: 242 | # Parse the JSON to get HTML 243 | html, raw_data = parse_func(json_file) 244 | 245 | if not html: 246 | # No table found or parse error 247 | print(f"Skipping (no HTML found): {json_file}") 248 | continue 249 | 250 | # Build output path: replace .json with .html and replicate subfolders if desired 251 | relative_path = os.path.relpath(json_file, start=args.input_folder) 252 | out_name = os.path.splitext(relative_path)[0] + ".html" 253 | out_path = os.path.join(args.output_folder, out_name) 254 | 255 | # Make sure subdirectories exist 256 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 257 | 258 | # Write the HTML to file 259 | with open(out_path, "w") as f: 260 | f.write(html) 261 | 262 | print(f"Saved HTML to: {out_path}") 263 | 264 | 265 | if __name__ == "__main__": 266 | main() 267 | -------------------------------------------------------------------------------- /providers/README.md: -------------------------------------------------------------------------------- 1 | # Providers 2 | 3 | This directory contains the code for the different providers that are used to extract tables from PDFs. 
-------------------------------------------------------------------------------- /providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Filimoa/rd-tablebench/4f9b12479cbef94f052b09191f3bffabdf719823/providers/__init__.py -------------------------------------------------------------------------------- /providers/azure_docintelligence.py: -------------------------------------------------------------------------------- 1 | # import libraries 2 | import os 3 | from azure.core.credentials import AzureKeyCredential 4 | from azure.ai.documentintelligence import DocumentIntelligenceClient 5 | from azure.ai.documentintelligence.models import AnalyzeDocumentRequest 6 | import glob 7 | import json 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | from tqdm import tqdm 10 | from azure.core.exceptions import HttpResponseError 11 | import backoff 12 | 13 | endpoint = os.environ["AZURE_ENDPOINT"] 14 | key = os.environ["AZURE_KEY"] 15 | 16 | base_path = os.path.expanduser("~/data/human_table_benchmark") 17 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True) 18 | 19 | document_intelligence_client = DocumentIntelligenceClient( 20 | endpoint=endpoint, credential=AzureKeyCredential(key) 21 | ) 22 | 23 | 24 | def is_rate_limit_error(exception): 25 | return isinstance(exception, HttpResponseError) and exception.status_code == 429 26 | 27 | 28 | @backoff.on_exception( 29 | backoff.expo, HttpResponseError, giveup=lambda e: not is_rate_limit_error(e) 30 | ) 31 | def analyze_document(file_content): 32 | return document_intelligence_client.begin_analyze_document( 33 | "prebuilt-layout", AnalyzeDocumentRequest(bytes_source=file_content) 34 | ).result() 35 | 36 | 37 | def process_pdf(pdf_path: str): 38 | output_path = pdf_path.replace("pdfs", "azure").replace(".pdf", ".json") 39 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 40 | 41 | try: 42 | with open(pdf_path, "rb") as file: 43 | result = analyze_document(file.read()) 44 | 45 | with open(output_path, "w") as f: 46 | json.dump(result.as_dict(), f, indent=2) 47 | 48 | return pdf_path, None 49 | except HttpResponseError as e: 50 | if e.status_code == 429: # Rate limit exceeded 51 | return pdf_path, "Rate limit" 52 | else: 53 | return pdf_path, str(e) 54 | except Exception as e: 55 | return pdf_path, str(e) 56 | 57 | 58 | def process_all_pdfs(pdfs: list[str]): 59 | max_workers = 200 60 | 61 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 62 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs} 63 | 64 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs") 65 | 66 | for future in as_completed(futures): 67 | pdf_path, error = future.result() 68 | if error: 69 | print(f"Error processing {pdf_path}: {error}") 70 | progress_bar.update(1) 71 | 72 | progress_bar.close() 73 | 74 | print(f"Processed {len(pdfs)} PDFs") 75 | 76 | 77 | if __name__ == "__main__": 78 | process_all_pdfs(pdfs) 79 | -------------------------------------------------------------------------------- /providers/chunkr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import asyncio 5 | import aiohttp 6 | import time 7 | from typing import Tuple 8 | from tqdm import tqdm 9 | 10 | base_path = os.path.expanduser("~/data/human_table_benchmark") 11 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True) 12 | 13 | API_KEY = 
os.environ.get("CHUNKR_API_KEY") 14 | if not API_KEY: 15 | raise ValueError("CHUNKR_API_KEY environment variable is not set") 16 | 17 | CHUNKR_URL = "https://api.chunkr.ai/api/v1/task" 18 | HEADERS = {"Authorization": API_KEY} 19 | 20 | 21 | async def process_pdf( 22 | pdf_path: str, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore 23 | ) -> Tuple[str, dict]: 24 | async with semaphore: 25 | try: 26 | with open(pdf_path, "rb") as file: 27 | data = aiohttp.FormData() 28 | data.add_field( 29 | "file", 30 | file, 31 | filename=os.path.basename(pdf_path), 32 | content_type="application/pdf", 33 | ) 34 | data.add_field("model", "HighQuality") 35 | data.add_field("target_chunk_length", "512") 36 | data.add_field("ocr_strategy", "Auto") 37 | 38 | async with session.post( 39 | CHUNKR_URL, headers=HEADERS, data=data 40 | ) as response: 41 | if response.status == 200: 42 | task_info = await response.json() 43 | task_id = task_info["task_id"] 44 | result = await poll_task(pdf_path, task_id, session) 45 | return pdf_path, result 46 | else: 47 | return pdf_path, { 48 | "error": f"Error: {response.status}, {await response.text()}" 49 | } 50 | except Exception as e: 51 | return pdf_path, {"error": str(e)} 52 | 53 | 54 | async def poll_task( 55 | pdf_path: str, task_id: str, session: aiohttp.ClientSession 56 | ) -> dict: 57 | task_url = f"{CHUNKR_URL}/{task_id}" 58 | start_time = time.time() 59 | while time.time() - start_time < 60 * 10: 60 | try: 61 | async with session.get(task_url, headers=HEADERS) as response: 62 | if response.status == 200: 63 | task_info = await response.json() 64 | if task_info["status"] == "Succeeded": 65 | return task_info 66 | elif task_info["status"] in ["Failed", "Canceled"]: 67 | return {"error": f"Task failed or canceled: {task_info}"} 68 | else: 69 | return { 70 | "error": f"Error polling task: {response.status}, {await response.text()}" 71 | } 72 | except Exception as e: 73 | return {"error": str(e)} 74 | await asyncio.sleep(5) # Wait 5 seconds before polling again 75 | return {"error": "Timeout: Task did not complete within 10 minutes"} 76 | 77 | 78 | async def process_all_pdfs(pdfs: list[str]): 79 | semaphore = asyncio.Semaphore(100) # Limit to 100 80 | async with aiohttp.ClientSession() as session: 81 | tasks = [process_pdf(pdf, session, semaphore) for pdf in pdfs] 82 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs") 83 | results = [] 84 | for task in asyncio.as_completed(tasks): 85 | pdf_path, result = await task 86 | results.append((pdf_path, result)) 87 | progress_bar.update(1) 88 | progress_bar.close() 89 | 90 | for pdf_path, result in results: 91 | output_path = pdf_path.replace("pdfs", "chunkr_nov1").replace(".pdf", ".json") 92 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 93 | with open(output_path, "w") as f: 94 | json.dump(result, f, indent=2) 95 | 96 | print(f"Processed {len(results)} PDFs") 97 | 98 | 99 | asyncio.run(process_all_pdfs(pdfs)) 100 | -------------------------------------------------------------------------------- /providers/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from dotenv import load_dotenv 3 | import os 4 | from pathlib import Path 5 | 6 | load_dotenv() 7 | 8 | 9 | class Settings(BaseModel): 10 | input_dir: Path 11 | output_dir: Path 12 | openai_api_key: str | None = None 13 | gemini_api_key: str | None = None 14 | anthropic_api_key: str | None = None 15 | 16 | 17 | settings = Settings( 18 | 
input_dir=os.getenv("INPUT_DIR"), 19 | output_dir=os.getenv("OUTPUT_DIR"), 20 | openai_api_key=os.getenv("OPENAI_API_KEY"), 21 | gemini_api_key=os.getenv("GEMINI_API_KEY"), 22 | anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"), 23 | ) 24 | -------------------------------------------------------------------------------- /providers/gcloud.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | from google.api_core.client_options import ClientOptions 3 | from google.cloud import documentai 4 | from google.cloud.documentai_toolbox import document 5 | import os 6 | import glob 7 | import json 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | from tqdm import tqdm 10 | import time 11 | import backoff 12 | 13 | project_id = os.environ["GCP_PROJECT_ID"] 14 | location = "us" 15 | processor_id = os.environ["GCP_PROCESSOR_ID"] 16 | mime_type = "application/pdf" 17 | 18 | base_path = os.path.expanduser("~/data/human_table_benchmark") 19 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True) 20 | 21 | opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com") 22 | client = documentai.DocumentProcessorServiceClient(client_options=opts) 23 | name = client.processor_path(project_id, location, processor_id) 24 | 25 | 26 | @backoff.on_exception( 27 | backoff.expo, Exception, giveup=lambda e: "quota" not in str(e).lower(), max_tries=5 28 | ) 29 | def process_document(file_path: str) -> Tuple[Optional[str], float]: 30 | start_time = time.time() 31 | 32 | # Read the file into memory 33 | with open(file_path, "rb") as pdf_file: 34 | pdf_content = pdf_file.read() 35 | 36 | # Load binary data 37 | raw_document = documentai.RawDocument(content=pdf_content, mime_type=mime_type) 38 | 39 | # Configure the process request 40 | request = documentai.ProcessRequest( 41 | name=name, 42 | raw_document=raw_document, 43 | ) 44 | 45 | result = client.process_document(request=request) 46 | 47 | toolbox_document = document.Document.from_documentai_document(result.document) 48 | 49 | for page in toolbox_document.pages: 50 | for table in page.tables: 51 | html_table = table.to_dataframe().to_html() 52 | processing_time = time.time() - start_time 53 | return html_table, processing_time 54 | 55 | processing_time = time.time() - start_time 56 | return None, processing_time 57 | 58 | 59 | def process_pdf(pdf_path: str): 60 | output_path = pdf_path.replace("pdfs", "gcloud").replace(".pdf", ".json") 61 | 62 | if os.path.exists(output_path): 63 | return pdf_path, None 64 | 65 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 66 | 67 | try: 68 | html_table, processing_time = process_document(pdf_path) 69 | 70 | result = {"html_table": html_table, "processing_time": processing_time} 71 | 72 | with open(output_path, "w") as f: 73 | json.dump(result, f, indent=2) 74 | 75 | return pdf_path, None 76 | except Exception as e: 77 | return pdf_path, str(e) 78 | 79 | 80 | def process_all_pdfs(pdfs: list[str]): 81 | max_workers = 5 82 | 83 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 84 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs} 85 | 86 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs") 87 | 88 | for future in as_completed(futures): 89 | pdf_path, error = future.result() 90 | if error: 91 | print(f"Error processing {pdf_path}: {error}") 92 | progress_bar.update(1) 93 | 94 | progress_bar.close() 95 | 96 | print(f"Processed {len(pdfs)} PDFs") 97 | 98 | 99 | 
if __name__ == "__main__": 100 | process_all_pdfs(pdfs) 101 | -------------------------------------------------------------------------------- /providers/llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | import glob 4 | import os 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from io import BytesIO 7 | from typing import Any, Literal 8 | 9 | import backoff 10 | import openai 11 | from openai import OpenAI 12 | from pdf2image import convert_from_path 13 | from tqdm import tqdm 14 | 15 | from providers.config import settings 16 | 17 | 18 | base_path = os.path.expanduser(settings.input_dir) 19 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True) 20 | 21 | 22 | def convert_pdf_to_base64_image(pdf_path): 23 | images = convert_from_path(pdf_path, first_page=1, last_page=1) 24 | img_buffer = BytesIO() 25 | images[0].save(img_buffer, format="PNG") 26 | return base64.b64encode(img_buffer.getvalue()).decode("utf-8") 27 | 28 | 29 | @backoff.on_exception(backoff.expo, (openai.RateLimitError), max_tries=5) 30 | def analyze_document_openai_sdk(base64_image, model: str): 31 | response = client.chat.completions.create( 32 | model=model, 33 | messages=[ 34 | { 35 | "role": "user", 36 | "content": [ 37 | { 38 | "type": "text", 39 | "text": "Convert the image to an HTML table. The output should begin with <table> and end with </table>
. Specify rowspan and colspan attributes when they are greater than 1. Do not specify any other attributes. Only use table related HTML tags, no additional formatting is required.", 40 | }, 41 | { 42 | "type": "image_url", 43 | "image_url": {"url": f"data:image/png;base64,{base64_image}"}, 44 | }, 45 | ], 46 | } 47 | ], 48 | max_tokens=4096, 49 | ) 50 | return response.choices[0].message.content 51 | 52 | 53 | @backoff.on_exception(backoff.expo, (openai.RateLimitError), max_tries=5) 54 | def analyze_document_anthropic(base64_image, model: str): 55 | response = client.messages.create( 56 | model=model, 57 | max_tokens=4096, 58 | messages=[ 59 | { 60 | "role": "user", 61 | "content": [ 62 | { 63 | "type": "text", 64 | "text": "Convert the image to an HTML table. The output should begin with <table> and end with </table>
. Specify rowspan and colspan attributes when they are greater than 1. Do not specify any other attributes. Only use table related HTML tags, no additional formatting is required.", 65 | }, 66 | { 67 | "type": "image", 68 | "source": { 69 | "type": "base64", 70 | "media_type": "image/png", 71 | "data": base64_image, 72 | }, 73 | }, 74 | ], 75 | } 76 | ], 77 | ) 78 | return response.content[0].text 79 | 80 | 81 | def parse_gemini_response(content: str) -> tuple[str | None, Any]: 82 | # Extract just the table portion between <table> and </table> 83 | start = content.find("<table>") 84 | end = content.find("</table>
") + 8 85 | if start != -1 and end != -1: 86 | return content[start:end], None 87 | return None, None 88 | 89 | 90 | def process_pdf(pdf_path: str, model: str): 91 | output_path = pdf_path.replace("pdfs", f"outputs/{model}-raw").replace( 92 | ".pdf", ".html" 93 | ) 94 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 95 | 96 | try: 97 | base64_image = convert_pdf_to_base64_image(pdf_path) 98 | if "gemini" in model: 99 | html_table = analyze_document_openai_sdk(base64_image, model) 100 | elif "gpt" in model: 101 | html_table = analyze_document_openai_sdk(base64_image, model) 102 | elif "claude" in model: 103 | html_table = analyze_document_anthropic(base64_image, model) 104 | else: 105 | raise ValueError(f"Unknown model: {model}") 106 | 107 | html, _ = parse_gemini_response(html_table) 108 | 109 | if not html: 110 | print(f"Skipping (no HTML found): {pdf_path}") 111 | return pdf_path, None 112 | 113 | with open(output_path, "w") as f: 114 | f.write(html_table) 115 | 116 | return pdf_path, None 117 | except Exception as e: 118 | return pdf_path, str(e) 119 | 120 | 121 | def process_all_pdfs( 122 | pdfs: list[str], 123 | model: Literal[ 124 | "gemini-2.0-flash-exp", 125 | "gemini-1.5-pro", 126 | "gemini-1.5-flash", 127 | "gpt-4o-mini", 128 | "gpt-4o", 129 | "claude-3-5-sonnet-latest", 130 | ], 131 | max_workers: int, 132 | ): 133 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 134 | futures = {executor.submit(process_pdf, pdf, model): pdf for pdf in pdfs} 135 | 136 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs") 137 | 138 | for future in as_completed(futures): 139 | pdf_path, error = future.result() 140 | if error: 141 | print(f"Error processing {pdf_path}: {error}") 142 | progress_bar.update(1) 143 | 144 | progress_bar.close() 145 | 146 | print(f"Processed {len(pdfs)} PDFs") 147 | 148 | 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("--model", type=str, required=True) 152 | parser.add_argument("--num-workers", type=int, default=1) 153 | args = parser.parse_args() 154 | 155 | if "gemini" in args.model: 156 | assert settings.gemini_api_key 157 | 158 | client = OpenAI( 159 | api_key=settings.gemini_api_key, 160 | base_url="https://generativelanguage.googleapis.com/v1beta/openai/", 161 | ) 162 | elif "gpt" in args.model: 163 | assert settings.openai_api_key 164 | 165 | client = OpenAI(api_key=settings.openai_api_key) 166 | elif "claude" in args.model: 167 | from anthropic import Anthropic 168 | 169 | assert settings.anthropic_api_key 170 | client = Anthropic(api_key=settings.anthropic_api_key) 171 | else: 172 | raise ValueError(f"Unknown model: {args.model}") 173 | 174 | process_all_pdfs(pdfs, args.model, args.num_workers) 175 | -------------------------------------------------------------------------------- /providers/reducto.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import asyncio 5 | import aiohttp 6 | import time 7 | from typing import Tuple 8 | from tqdm import tqdm 9 | 10 | base_path = os.path.expanduser("~/data/human_table_benchmark") 11 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True) 12 | 13 | API_KEY = os.environ.get("REDUCTO_API_KEY") 14 | if not API_KEY: 15 | raise ValueError("REDUCTO_API_KEY environment variable is not set") 16 | 17 | REDUCTO_URL = "https://platform.reducto.ai" 18 | HEADERS = {"Authorization": f"Bearer {API_KEY}"} 19 | 20 | 21 | async def upload_pdf( 22 | 
pdf_path: str, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore 23 | ) -> Tuple[str, str]: 24 | async with semaphore: 25 | try: 26 | upload_url = f"{REDUCTO_URL}/upload" 27 | with open(pdf_path, "rb") as file: 28 | data = aiohttp.FormData() 29 | data.add_field( 30 | "file", 31 | file, 32 | filename=os.path.basename(pdf_path), 33 | content_type="application/pdf", 34 | ) 35 | 36 | async with session.post( 37 | upload_url, headers=HEADERS, data=data 38 | ) as response: 39 | if response.status == 200: 40 | result = await response.json() 41 | return pdf_path, result["file_id"] 42 | else: 43 | return ( 44 | pdf_path, 45 | f"Error: {response.status}, {await response.text()}", 46 | ) 47 | except Exception as e: 48 | return pdf_path, f"Error: {str(e)}" 49 | 50 | 51 | async def parse_async( 52 | file_id: str, 53 | pdf_path: str, 54 | session: aiohttp.ClientSession, 55 | semaphore: asyncio.Semaphore, 56 | ) -> Tuple[str, str]: 57 | async with semaphore: 58 | try: 59 | parse_url = f"{REDUCTO_URL}/parse_async" 60 | payload = { 61 | "document_url": f"{file_id}", 62 | "advanced_options": { 63 | "ocr_system": "combined", 64 | }, 65 | } 66 | 67 | async with session.post( 68 | parse_url, headers=HEADERS, json=payload 69 | ) as response: 70 | if response.status == 200: 71 | result = await response.json() 72 | return pdf_path, result["job_id"] 73 | else: 74 | return ( 75 | pdf_path, 76 | f"Error: {response.status}, {await response.text()}", 77 | ) 78 | except Exception as e: 79 | return pdf_path, f"Error: {str(e)}" 80 | 81 | 82 | async def poll_job( 83 | file_id: str, job_id: str, session: aiohttp.ClientSession 84 | ) -> Tuple[str, dict]: 85 | job_url = f"{REDUCTO_URL}/job/{job_id}" 86 | start_time = time.time() 87 | while time.time() - start_time < 60 * 10: # 10 minutes timeout 88 | try: 89 | async with session.get(job_url, headers=HEADERS) as response: 90 | if response.status == 200: 91 | job_info = await response.json() 92 | if job_info["status"] == "Completed": 93 | return file_id, job_info["result"] 94 | elif job_info["status"] == "Failed": 95 | return file_id, {"error": f"Job failed: {job_info}"} 96 | else: 97 | return file_id, { 98 | "error": f"Error polling job: {response.status}, {await response.text()}" 99 | } 100 | except Exception as e: 101 | return file_id, {"error": str(e)} 102 | await asyncio.sleep(5) # Wait 5 seconds before polling again 103 | return file_id, {"error": "Timeout: Job did not complete within 10 minutes"} 104 | 105 | 106 | async def process_pdf( 107 | pdf_path: str, 108 | session: aiohttp.ClientSession, 109 | upload_semaphore: asyncio.Semaphore, 110 | parse_semaphore: asyncio.Semaphore, 111 | ) -> Tuple[str, dict]: 112 | pdf_path, file_id = await upload_pdf(pdf_path, session, upload_semaphore) 113 | if file_id.startswith("Error"): 114 | return pdf_path, {"error": file_id} 115 | 116 | file_id, job_id = await parse_async(file_id, pdf_path, session, parse_semaphore) 117 | if job_id.startswith("Error"): 118 | return pdf_path, {"error": job_id} 119 | 120 | return await poll_job(pdf_path, job_id, session) 121 | 122 | 123 | async def process_all_pdfs(pdfs: list[str]): 124 | upload_semaphore = asyncio.Semaphore(10) 125 | parse_semaphore = asyncio.Semaphore(10) 126 | 127 | async with aiohttp.ClientSession() as session: 128 | tasks = [ 129 | process_pdf(pdf, session, upload_semaphore, parse_semaphore) for pdf in pdfs 130 | ] 131 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs") 132 | results = [] 133 | for task in asyncio.as_completed(tasks): 134 | pdf_path, result 
= await task 135 | results.append((pdf_path, result)) 136 | progress_bar.update(1) 137 | progress_bar.close() 138 | 139 | for pdf_path, result in results: 140 | output_path = pdf_path.replace("pdfs", "reducto_nov1").replace(".pdf", ".json") 141 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 142 | with open(output_path, "w") as f: 143 | json.dump(result, f, indent=2) 144 | 145 | print(f"Processed {len(results)} PDFs") 146 | 147 | 148 | async def main(): 149 | await process_all_pdfs(pdfs) 150 | 151 | 152 | if __name__ == "__main__": 153 | asyncio.run(main()) 154 | -------------------------------------------------------------------------------- /providers/textract.py: -------------------------------------------------------------------------------- 1 | from pdf2image import convert_from_path 2 | from textractor import Textractor 3 | from textractor.data.constants import TextractFeatures 4 | import re 5 | import os 6 | import glob 7 | import json 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | from tqdm import tqdm 10 | import time 11 | import backoff 12 | 13 | extractor = Textractor(region_name="us-west-2") 14 | 15 | base_path = os.path.expanduser("~/data/human_table_benchmark") 16 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True) 17 | 18 | 19 | @backoff.on_exception( 20 | backoff.expo, Exception, giveup=lambda e: "limit" not in str(e).lower(), max_tries=5 21 | ) 22 | def extract_tables(pdf_path: str): 23 | start_time = time.time() 24 | 25 | image = convert_from_path(pdf_path)[0] 26 | 27 | document = extractor.analyze_document( 28 | file_source=image, 29 | features=[TextractFeatures.TABLES], 30 | ) 31 | 32 | if len(document.tables) == 0: 33 | return None, time.time() - start_time 34 | 35 | html_table = document.tables[0].to_html() 36 | 37 | html_table = re.sub(r".*?", "", html_table, flags=re.DOTALL) 38 | 39 | end_time = time.time() 40 | processing_time = end_time - start_time 41 | 42 | return html_table, processing_time 43 | 44 | 45 | def process_pdf(pdf_path: str): 46 | output_path = pdf_path.replace("pdfs", "textract").replace(".pdf", ".json") 47 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 48 | 49 | try: 50 | html_table, processing_time = extract_tables(pdf_path) 51 | 52 | if html_table is None: 53 | result = {"html_table": None, "processing_time": processing_time} 54 | else: 55 | result = {"html_table": html_table, "processing_time": processing_time} 56 | 57 | with open(output_path, "w") as f: 58 | json.dump(result, f, indent=2) 59 | 60 | return pdf_path, None 61 | except Exception as e: 62 | return pdf_path, str(e) 63 | 64 | 65 | def process_all_pdfs(pdfs: list[str]): 66 | max_workers = 20 67 | 68 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 69 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs} 70 | 71 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs") 72 | 73 | for future in as_completed(futures): 74 | image_path, error = future.result() 75 | if error: 76 | print(f"Error processing {image_path}: {error}") 77 | progress_bar.update(1) 78 | 79 | progress_bar.close() 80 | 81 | print(f"Processed {len(pdfs)} PDFs") 82 | 83 | 84 | if __name__ == "__main__": 85 | process_all_pdfs(pdfs) 86 | -------------------------------------------------------------------------------- /providers/unstructured.py: -------------------------------------------------------------------------------- 1 | # Before calling the API, replace filename and ensure sdk is installed: "pip install 
unstructured-client" 2 | # See https://docs.unstructured.io/api-reference/api-services/sdk for more details 3 | 4 | import os 5 | import glob 6 | import json 7 | from concurrent.futures import ThreadPoolExecutor, as_completed 8 | from tqdm import tqdm 9 | import unstructured_client 10 | from unstructured_client.models import operations, shared 11 | import backoff 12 | import time 13 | 14 | # Replace with your actual API key 15 | API_KEY = os.environ["UNSTRUCTURED_API_KEY"] 16 | SERVER_URL = "https://api.unstructuredapp.io" 17 | 18 | client = unstructured_client.UnstructuredClient( 19 | api_key_auth=API_KEY, 20 | server_url=SERVER_URL, 21 | ) 22 | 23 | base_path = os.path.expanduser("~/data/human_table_benchmark") 24 | pdfs = glob.glob(os.path.join(base_path, "**", "*.pdf"), recursive=True) 25 | 26 | 27 | @backoff.on_exception(backoff.expo, Exception, max_tries=3) 28 | def process_single_file(file_path: str): 29 | with open(file_path, "rb") as f: 30 | data = f.read() 31 | 32 | start_time = time.time() 33 | req = operations.PartitionRequest( 34 | partition_parameters=shared.PartitionParameters( 35 | files=shared.Files( 36 | content=data, 37 | file_name=os.path.basename(file_path), 38 | ), 39 | strategy=shared.Strategy.HI_RES, 40 | ), 41 | ) 42 | 43 | res = client.general.partition(request=req) 44 | processing_time = time.time() - start_time 45 | return {"elements": res.elements, "processing_time": processing_time} 46 | 47 | 48 | def process_pdf(pdf_path: str): 49 | output_path = pdf_path.replace("pdfs", "unstructured").replace(".pdf", ".json") 50 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 51 | 52 | try: 53 | result = process_single_file(pdf_path) 54 | with open(output_path, "w") as f: 55 | json.dump(result, f, indent=2) 56 | except Exception as e: 57 | print(f"Error processing {pdf_path}: {str(e)}") 58 | 59 | 60 | def process_all_pdfs(pdfs: list[str]): 61 | max_workers = 10 # Adjust this based on API rate limits 62 | 63 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 64 | futures = {executor.submit(process_pdf, pdf): pdf for pdf in pdfs} 65 | 66 | progress_bar = tqdm(total=len(pdfs), desc="Processing PDFs") 67 | 68 | for _ in as_completed(futures): 69 | progress_bar.update(1) 70 | 71 | progress_bar.close() 72 | 73 | print(f"Processed {len(pdfs)} PDFs") 74 | 75 | 76 | def test_single_file(file_path: str): 77 | try: 78 | result = process_single_file(file_path) 79 | print(result) 80 | except Exception as e: 81 | print(f"Error processing {file_path}: {str(e)}") 82 | 83 | 84 | if __name__ == "__main__": 85 | process_all_pdfs(pdfs) 86 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.4 2 | aiohttp==3.11.11 3 | aiolimiter==1.2.1 4 | aiosignal==1.3.2 5 | annotated-types==0.7.0 6 | anthropic==0.43.0 7 | anyio==4.8.0 8 | attrs==24.3.0 9 | backoff==2.2.1 10 | cachetools==5.5.0 11 | certifi==2024.12.14 12 | charset-normalizer==3.4.1 13 | click==8.1.8 14 | distro==1.9.0 15 | docstring-parser==0.16 16 | frozenlist==1.5.0 17 | google-ai-generativelanguage==0.6.10 18 | google-api-core==2.24.0 19 | google-api-python-client==2.158.0 20 | google-auth==2.37.0 21 | google-auth-httplib2==0.2.0 22 | google-cloud-aiplatform==1.76.0 23 | google-cloud-bigquery==3.27.0 24 | google-cloud-core==2.4.1 25 | google-cloud-resource-manager==1.14.0 26 | google-cloud-storage==2.19.0 27 | google-crc32c==1.6.0 28 | google-generativeai==0.8.3 
29 | google-resumable-media==2.7.2 30 | googleapis-common-protos==1.66.0 31 | grpc-google-iam-v1==0.14.0 32 | grpcio==1.69.0 33 | grpcio-status==1.69.0 34 | h11==0.14.0 35 | httpcore==1.0.7 36 | httplib2==0.22.0 37 | httpx==0.28.1 38 | idna==3.10 39 | instructor==1.7.2 40 | jinja2==3.1.5 41 | jiter==0.8.2 42 | jsonref==1.1.0 43 | levenshtein==0.26.1 44 | lxml==5.3.0 45 | markdown-it-py==3.0.0 46 | markupsafe==3.0.2 47 | mdurl==0.1.2 48 | multidict==6.1.0 49 | numpy==2.2.1 50 | openai==1.59.7 51 | packaging==24.2 52 | pdf2image==1.17.0 53 | pillow==11.1.0 54 | polars==1.19.0 55 | propcache==0.2.1 56 | proto-plus==1.25.0 57 | protobuf==5.29.3 58 | pyasn1==0.6.1 59 | pyasn1-modules==0.4.1 60 | pydantic==2.10.5 61 | pydantic-core==2.27.2 62 | pygments==2.19.1 63 | pyparsing==3.2.1 64 | python-dateutil==2.9.0.post0 65 | python-dotenv==1.0.1 66 | rapidfuzz==3.11.0 67 | requests==2.32.3 68 | rich==13.9.4 69 | rsa==4.9 70 | shapely==2.0.6 71 | shellingham==1.5.4 72 | six==1.17.0 73 | sniffio==1.3.1 74 | tenacity==9.0.0 75 | tqdm==4.67.1 76 | typer==0.15.1 77 | typing-extensions==4.12.2 78 | uritemplate==4.1.1 79 | urllib3==2.3.0 80 | yarl==1.18.3 81 | --------------------------------------------------------------------------------