├── requirements.txt ├── utils ├── model.py ├── preprocessing.py ├── ocr.py ├── visualization.py └── detection.py ├── app.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | transformers 3 | torch 4 | Pillow 5 | huggingface_hub 6 | matplotlib 7 | easyocr 8 | tqdm 9 | pandas 10 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection 3 | 4 | def load_detection_model(): 5 | model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm") 6 | device = "cuda" if torch.cuda.is_available() else "cpu" 7 | model.to(device) 8 | return model, device 9 | 10 | def load_structure_model(device): 11 | model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all") 12 | model.to(device) 13 | return model 14 | -------------------------------------------------------------------------------- /utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | class MaxResize(object): 5 | def __init__(self, max_size=800): 6 | self.max_size = max_size 7 | 8 | def __call__(self, image): 9 | width, height = image.size 10 | current_max_size = max(width, height) 11 | scale = self.max_size / current_max_size 12 | resized_image = image.resize((int(round(scale*width)), int(round(scale*height)))) 13 | return resized_image 14 | 15 | detection_transform = transforms.Compose([ 16 | MaxResize(800), 17 | transforms.ToTensor(), 18 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 19 | ]) 20 | 21 | structure_transform = transforms.Compose([ 22 | MaxResize(1000), 23 | transforms.ToTensor(), 24 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]) 26 | 27 | def prepare_image(image, device): 28 | pixel_values = detection_transform(image).unsqueeze(0) 29 | pixel_values = pixel_values.to(device) 30 | return pixel_values 31 | 32 | def prepare_cropped_image(cropped_image, device): 33 | pixel_values = structure_transform(cropped_image).unsqueeze(0) 34 | pixel_values = pixel_values.to(device) 35 | return pixel_values 36 | -------------------------------------------------------------------------------- /utils/ocr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import easyocr 3 | from tqdm.auto import tqdm 4 | import csv 5 | 6 | reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory 7 | 8 | def apply_ocr(cell_coordinates, cropped_table): 9 | data = dict() 10 | max_num_columns = 0 11 | for idx, row in enumerate(tqdm(cell_coordinates)): 12 | row_text = [] 13 | for cell in row["cells"]: 14 | cell_image = np.array(cropped_table.crop(cell["cell"])) 15 | result = reader.readtext(np.array(cell_image)) 16 | if len(result) > 0: 17 | text = " ".join([x[1] for x in result]) 18 | row_text.append(text) 19 | if len(row_text) > max_num_columns: 20 | max_num_columns = len(row_text) 21 | data[idx] = row_text 22 | 23 | print("Max number of columns:", max_num_columns) 24 | for row, row_data in data.copy().items(): 25 | if len(row_data) != max_num_columns: 26 | row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))] 27 | data[row] = row_data 28 | return data 29 | 30 | def save_csv(data): 31 | with open('output.csv', 'w') as result_file: 32 | wr = csv.writer(result_file, dialect='excel') 33 | for row, row_text in data.items(): 34 | wr.writerow(row_text) 35 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from PIL import Image 3 | from utils.model import load_detection_model, load_structure_model 4 | from utils.preprocessing import prepare_image, prepare_cropped_image 5 | from utils.detection import detect_tables, detect_cells, outputs_to_objects, objects_to_crops, get_cell_coordinates_by_row 6 | from utils.visualization import visualize_detected_tables, plot_results 7 | from utils.ocr import apply_ocr, save_csv 8 | 9 | # Load models 10 | detection_model, device = load_detection_model() 11 | structure_model = load_structure_model(device) 12 | 13 | st.title("Table Detection and OCR with Transformers") 14 | 15 | # Upload image 16 | uploaded_file = st.file_uploader("Upload an image containing a table", type=["png", "jpg", "jpeg"]) 17 | if uploaded_file is not None: 18 | image = Image.open(uploaded_file).convert("RGB") 19 | st.image(image, caption='Uploaded Image', use_column_width=True) 20 | 21 | # Add "no object" to id2label if not present 22 | if len(detection_model.config.id2label) not in detection_model.config.id2label: 23 | detection_model.config.id2label[len(detection_model.config.id2label)] = "no object" 24 | 25 | # Detect tables 26 | pixel_values = prepare_image(image, device) 27 | outputs = detect_tables(detection_model, pixel_values) 28 | objects = outputs_to_objects(outputs, image.size, detection_model.config.id2label) 29 | 30 | st.write("Detected tables:") 31 | st.write(objects) 32 | 33 | # Visualize tables 34 | fig = visualize_detected_tables(image, objects) 35 | st.pyplot(fig) 36 | 37 | # Crop the table 38 | tokens = [] 39 | detection_class_thresholds = {"table": 0.5, "table rotated": 0.5, "no object": 10} 40 | tables_crops = objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=0) 41 | cropped_table = tables_crops[0]['image'].convert("RGB") 42 | 43 | st.image(cropped_table, caption='Cropped Table', use_column_width=True) 44 | 45 | # Detect cells in the cropped table 46 | pixel_values = prepare_cropped_image(cropped_table, device) 47 | outputs = detect_cells(structure_model, pixel_values) 48 | cells = outputs_to_objects(outputs, cropped_table.size, structure_model.config.id2label) 49 | 50 | st.write("Detected cells:") 51 | st.write(cells) 52 | 53 | # Visualize cells 54 | fig = plot_results(cropped_table, cells, "table row") 55 | st.pyplot(fig) 56 | 57 | # Apply OCR 58 | cell_coordinates = get_cell_coordinates_by_row(cells) 59 | data = apply_ocr(cell_coordinates, cropped_table) 60 | 61 | # Display OCR results 62 | st.write("Extracted Table Data:") 63 | st.write(data) 64 | 65 | # Save results as CSV 66 | save_csv(data) 67 | st.write("CSV file saved as output.csv") 68 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.patches as patches 3 | from matplotlib.patches import Patch 4 | 5 | def visualize_detected_tables(img, det_tables): 6 | plt.imshow(img, interpolation="lanczos") 7 | fig = plt.gcf() 8 | fig.set_size_inches(20, 20) 9 | ax = plt.gca() 10 | 11 | for det_table in det_tables: 12 | bbox = det_table['bbox'] 13 | if det_table['label'] == 'table': 14 | facecolor = (1, 0, 0.45) 15 | edgecolor = (1, 0, 0.45) 16 | alpha = 0.3 17 | linewidth = 2 18 | hatch = '//////' 19 | elif det_table['label'] == 'table rotated': 20 | facecolor = (0.95, 0.6, 0.1) 21 | edgecolor = (0.95, 0.6, 0.1) 22 | alpha = 0.3 23 | linewidth = 2 24 | hatch = '//////' 25 | else: 26 | continue 27 | 28 | rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth, 29 | edgecolor='none', facecolor=facecolor, alpha=0.1) 30 | ax.add_patch(rect) 31 | rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth, 32 | edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha) 33 | ax.add_patch(rect) 34 | rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=0, 35 | edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2) 36 | ax.add_patch(rect) 37 | 38 | plt.xticks([], []) 39 | plt.yticks([], []) 40 | legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45), label='Table', hatch='//////', alpha=0.3), 41 | Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1), label='Table (rotated)', hatch='//////', alpha=0.3)] 42 | plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0, 43 | fontsize=10, ncol=2) 44 | plt.gcf().set_size_inches(10, 10) 45 | plt.axis('off') 46 | return fig 47 | 48 | def plot_results(cropped_table, cells, class_to_visualize): 49 | plt.figure(figsize=(16, 10)) 50 | plt.imshow(cropped_table) 51 | ax = plt.gca() 52 | 53 | for cell in cells: 54 | score = cell["score"] 55 | bbox = cell["bbox"] 56 | label = cell["label"] 57 | if label == class_to_visualize: 58 | xmin, ymin, xmax, ymax = tuple(bbox) 59 | ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color="red", linewidth=3)) 60 | text = f'{cell["label"]}: {score:0.2f}' 61 | ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5)) 62 | plt.axis('off') 63 | return plt.gcf() 64 | -------------------------------------------------------------------------------- /utils/detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def detect_tables(model, pixel_values): 4 | with torch.no_grad(): 5 | outputs = model(pixel_values) 6 | return outputs 7 | 8 | def detect_cells(model, pixel_values): 9 | with torch.no_grad(): 10 | outputs = model(pixel_values) 11 | return outputs 12 | 13 | def outputs_to_objects(outputs, img_size, id2label): 14 | def box_cxcywh_to_xyxy(x): 15 | x_c, y_c, w, h = x.unbind(-1) 16 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 17 | return torch.stack(b, dim=1) 18 | 19 | def rescale_bboxes(out_bbox, size): 20 | img_w, img_h = size 21 | b = box_cxcywh_to_xyxy(out_bbox) 22 | b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) 23 | return b 24 | 25 | # Add "no object" to id2label if not present 26 | if len(id2label) not in id2label: 27 | id2label[len(id2label)] = "no object" 28 | 29 | m = outputs.logits.softmax(-1).max(-1) 30 | pred_labels = list(m.indices.detach().cpu().numpy())[0] 31 | pred_scores = list(m.values.detach().cpu().numpy())[0] 32 | pred_bboxes = outputs['pred_boxes'].detach().cpu()[0] 33 | pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)] 34 | 35 | print(f"Predicted labels: {pred_labels}") 36 | print(f"id2label: {id2label}") 37 | 38 | objects = [] 39 | for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes): 40 | try: 41 | class_label = id2label[int(label)] 42 | except KeyError: 43 | print(f"Label {label} not found in id2label. Skipping.") 44 | continue 45 | if not class_label == 'no object': 46 | objects.append({'label': class_label, 'score': float(score), 47 | 'bbox': [float(elem) for elem in bbox]}) 48 | 49 | return objects 50 | 51 | def objects_to_crops(img, tokens, objects, class_thresholds, padding=10): 52 | table_crops = [] 53 | for obj in objects: 54 | if obj['score'] < class_thresholds[obj['label']]: 55 | continue 56 | 57 | cropped_table = {} 58 | bbox = obj['bbox'] 59 | bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding] 60 | cropped_img = img.crop(bbox) 61 | 62 | table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5] 63 | for token in table_tokens: 64 | token['bbox'] = [token['bbox'][0] - bbox[0], token['bbox'][1] - bbox[1], token['bbox'][2] - bbox[0], token['bbox'][3] - bbox[1]] 65 | 66 | if obj['label'] == 'table rotated': 67 | cropped_img = cropped_img.rotate(270, expand=True) 68 | for token in table_tokens: 69 | bbox = token['bbox'] 70 | bbox = [cropped_img.size[0] - bbox[3] - 1, bbox[0], cropped_img.size[0] - bbox[1] - 1, bbox[2]] 71 | token['bbox'] = bbox 72 | 73 | cropped_table['image'] = cropped_img 74 | cropped_table['tokens'] = table_tokens 75 | table_crops.append(cropped_table) 76 | 77 | return table_crops 78 | 79 | def get_cell_coordinates_by_row(table_data): 80 | rows = [entry for entry in table_data if entry['label'] == 'table row'] 81 | columns = [entry for entry in table_data if entry['label'] == 'table column'] 82 | rows.sort(key=lambda x: x['bbox'][1]) 83 | columns.sort(key=lambda x: x['bbox'][0]) 84 | 85 | def find_cell_coordinates(row, column): 86 | cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]] 87 | return cell_bbox 88 | 89 | cell_coordinates = [] 90 | for row in rows: 91 | row_cells = [] 92 | for column in columns: 93 | cell_bbox = find_cell_coordinates(row, column) 94 | row_cells.append({'column': column['bbox'], 'cell': cell_bbox}) 95 | row_cells.sort(key=lambda x: x['column'][0]) 96 | cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)}) 97 | cell_coordinates.sort(key=lambda x: x['row'][1]) 98 | return cell_coordinates 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Table Detection and OCR with Transformers 2 | 3 | This project is a Streamlit app for detecting tables in images, cropping them, detecting cells within the cropped tables, and applying OCR (Optical Character Recognition) to extract the table data into a CSV file. 4 | 5 | ## Directory Structure 6 | 7 | The project is structured as follows: 8 | 9 | 10 | streamlit_table_app/ 11 | 12 | ├── app.py 13 | 14 | ├── requirements.txt 15 | 16 | ├── utils/ 17 | 18 | │ ├── model.py 19 | 20 | │ ├── preprocessing.py 21 | 22 | │ ├── detection.py 23 | 24 | │ ├── visualization.py 25 | 26 | │ ├── ocr.py 27 | 28 | 29 | ### app.py 30 | 31 | This file is the main entry point for the Streamlit app. It handles the user interface, image upload, and the sequence of steps from table detection to OCR and saving results. 32 | 33 | ### Main Features: 34 | - Upload an image containing a table. 35 | - Detect and visualize tables in the image. 36 | - Crop detected tables and visualize them. 37 | - Detect and visualize cells within the cropped tables. 38 | - Perform OCR on the cells to extract table data. 39 | - Save the extracted data as a CSV file. 40 | 41 | ### requirements.txt 42 | 43 | This file lists all the dependencies required for the project. 44 | 45 | #### Dependencies: 46 | - streamlit 47 | - transformers 48 | - torch 49 | - Pillow 50 | - huggingface_hub 51 | - matplotlib 52 | - easyocr 53 | - tqdm 54 | - pandas 55 | 56 | ### utils/model.py 57 | 58 | This file contains functions for loading the table detection and structure recognition models. 59 | 60 | #### Functions: 61 | - `load_detection_model()`: Loads the table detection model. 62 | - `load_structure_model(device)`: Loads the structure recognition model. 63 | 64 | ### utils/preprocessing.py 65 | 66 | This file contains functions for preparing images to be compatible with the models. 67 | 68 | #### Functions: 69 | - `prepare_image(image, device)`: Prepares and normalizes the image for the table detection model. 70 | - `prepare_cropped_image(cropped_image, device)`: Prepares and normalizes the cropped table image for the structure recognition model. 71 | 72 | ### utils/detection.py 73 | 74 | This file contains functions for detecting tables and cells in the images. 75 | 76 | #### Functions: 77 | - `detect_tables(model, pixel_values)`: Uses the table detection model to detect tables. 78 | - `detect_cells(model, pixel_values)`: Uses the structure recognition model to detect cells within cropped tables. 79 | 80 | ### utils/visualization.py 81 | 82 | This file contains functions for visualizing detected tables and cells. 83 | 84 | #### Functions: 85 | - `visualize_detected_tables(img, det_tables)`: Visualizes tables detected in the image. 86 | - `plot_results(cells, class_to_visualize)`: Visualizes detected cells within the cropped table. 87 | 88 | ### utils/ocr.py 89 | 90 | This file contains functions for applying OCR and saving the results as a CSV. 91 | 92 | #### Functions: 93 | - `apply_ocr(cell_coordinates, cropped_table)`: Performs OCR on detected cells to extract text. 94 | - `save_csv(data)`: Saves the extracted table data into a CSV file. 95 | 96 | ## Installation 97 | 98 | To set up the project, execute the following commands: 99 | ```sh 100 | git clone https://github.com/h9-tect/table_parse_using_table_transformers.git # Clone the repository 101 | cd streamlit_table_app # Navigate to the project directory 102 | pip install -r requirements.txt # Install dependencies 103 | ``` 104 | ## Usage 105 | 106 | To run the Streamlit app, execute: 107 | 108 | ```sh 109 | streamlit run app.py 110 | ``` 111 | This will launch your Streamlit app. You can then upload an image containing a table, and the app will process the image, detect tables and cells, apply OCR, and save the extracted table data as a CSV file named output.csv. 112 | 113 | ## Notes 114 | 115 | Ensure you have a CUDA-capable GPU for faster model inference, though the code will run on CPU if a GPU is not available. 116 | The provided pretrained models are from the Hugging Face model hub, specifically designed for table detection and structure recognition tasks. 117 | 118 | ## Contributing 119 | 120 | Feel free to fork this repository and submit pull requests. For significant changes, please open an issue first to discuss what you would like to change. 121 | 122 | --------------------------------------------------------------------------------