├── .gitignore ├── ExtractTabularData ├── ExtractTextFromImage.ipynb └── image_to_df.py ├── GridlinesImprovement ├── apply_gridlines_fix.py ├── cropping.py ├── draw_gridlines_functions.py └── remove_gridlines.py ├── Model Implementation ├── DummyDatabase │ ├── predictions │ │ ├── five_pred.jpg │ │ ├── four_pred.jpg │ │ ├── one_pred.jpg │ │ ├── three_pred.png │ │ └── two_pred.png │ └── test_images │ │ ├── five.jpg │ │ ├── four.jpg │ │ ├── one.jpg │ │ ├── three.png │ │ └── two.png ├── Pre-Processing │ ├── eda.ipynb │ ├── pre_processing.py │ └── preprocessing_utilities.py ├── Training │ ├── __pycache__ │ │ ├── encoder.cpython-39.pyc │ │ ├── path_constants.cpython-39.pyc │ │ └── tablenet_model.cpython-39.pyc │ ├── configurations.py │ ├── dataset.py │ ├── encoder.py │ ├── general_utilities.py │ ├── model_loss.py │ ├── model_training.py │ ├── path_constants.py │ └── tablenet_model.py ├── model_testing.py ├── model_testing_note.ipynb └── tables │ ├── with tables.png │ ├── with │ ├── table-image-110.png │ ├── table-image-111.png │ ├── table-image-120.png │ ├── table-image-159.png │ ├── table-image-169.png │ ├── table-image-170.png │ ├── table-image-172.png │ ├── table-image-181.png │ ├── table-image-26.png │ ├── table-image-3.png │ ├── table-image-39.png │ ├── table-image-61.png │ ├── table-image-74.png │ ├── table-image-82.png │ ├── table-image-84.png │ └── table-image-99.png │ ├── without tables.png │ └── without │ ├── table-image-102.png │ ├── table-image-129.png │ ├── table-image-50.png │ ├── table-image-78.png │ └── table-image-79.png ├── README.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Model extensions 10 | *.pth.tar 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/
--------------------------------------------------------------------------------
/ExtractTabularData/image_to_df.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import PIL
4 | import pytesseract
5 | import os
6 | os.environ['TESSDATA_PREFIX'] = 'Extract Tabular Data/tessdata_dir/'
7 | 
8 | 
9 | def optimizeDf(old_df: pd.DataFrame) -> pd.DataFrame:
10 |     df = old_df[["left", "top", "width", "text"]].copy()
11 |     df['left+width'] = df['left'] + df['width']
12 |     df = df.sort_values(by=['top'], ascending=True)
13 |     df = df.groupby(['top', 'left+width'], sort=False)['text'].sum().unstack('left+width')
14 |     df = df.reindex(sorted(df.columns), axis=1).dropna(how='all').dropna(axis='columns', how='all')
15 |     df = df.fillna('')
16 |     return df
17 | 
18 | def mergeDfColumns(old_df: pd.DataFrame, threshold: int = 10, rotations: int = 5) -> pd.DataFrame: # threshold was 10
19 |     df = old_df.copy()
20 |     for j in range(0, rotations):
21 |         new_columns = {}
22 |         old_columns = df.columns
23 |         i = 0
24 |         while i < len(old_columns):
25 |             if i < len(old_columns) - 1:
26 |                 if any(old_columns[i+1] == old_columns[i] + x for x in range(1, threshold)):
27 |                     new_col = df[old_columns[i]].astype(str) + df[old_columns[i+1]].astype(str)
28 |                     new_columns[old_columns[i+1]] = new_col
29 |                     i = i + 1
30 |                 else:
31 |                     new_columns[old_columns[i]] = df[old_columns[i]]
32 |             else:
33 |                 new_columns[old_columns[i]] = df[old_columns[i]]
34 |             i += 1
35 |         df = pd.DataFrame.from_dict(new_columns).replace('', np.nan).dropna(axis='columns', how='all').replace(np.nan, '')
36 |     return df
37 | 
38 | def mergeDfRows(old_df: pd.DataFrame, threshold: int = 10) -> pd.DataFrame:
39 |     new_df = old_df.iloc[:1]
40 |     for i in range(1, len(old_df)):
41 |         if abs(old_df.index[i] - old_df.index[i - 1]) < threshold:
42 |             new_df.iloc[-1] = new_df.iloc[-1].astype(str) + old_df.iloc[i].astype(str)
43 |         else:
44 |             new_df = pd.concat([new_df, old_df.iloc[[i]]])  # DataFrame.append was removed in pandas 2.x
45 |     return new_df.reset_index(drop=True)
46 | 
47 | def cleanDf(df):
48 |     # Remove columns where every cell holds the same value as the first row (constant columns)
49 |     df = df.loc[:, (df != df.iloc[0]).any()]
50 |     # Remove rows with empty cells or cells with only the '|' symbol
51 |     df = df[(df != '|') & (df != '=') & (df != '') & (pd.notnull(df))]
52 |     # Remove columns with only empty cells
53 |     df = df.dropna(axis=1, how='all')
54 |     return df.fillna('')
55 | 
56 | print(pytesseract.get_tesseract_version())
57 | print(pytesseract.get_languages())
58 | 
59 | """
60 | Best Results: --psm 12 --oem 1
61 | History:
62 | 8) --psm 12 --oem 1 --dpi 3000 -> eng 80%
63 | 7) --psm 12 --oem 2 -> eng 90%, heb 78%-10%+5% -> ',' and '.' it can't decide between the two
64 | 6) --psm 12 --oem 1 -> eng 95%, heb 78%-10%+5% -> ',' and '.'
it cant decide between the two 65 | 5) --psm 12 --oem 0 -> eng 85% 66 | 4) --psm 12 -> eng 90%, heb 75%-15+3%% 67 | 3) --psm 6 -> 40% 68 | 2) --psm 5 -> 10% 69 | 1) --psm 11 -> eng 85%, heb 70%+-15% 70 | URL: https://muthu.co/all-tesseract-ocr-options/ 71 | """ 72 | special_config = '--psm 12 --oem 1' 73 | languages_ = "eng" 74 | 75 | image_path = "Model Implementation/DummyDatabase/predictions/image_crop.png" 76 | 77 | img_pl=PIL.Image.open(image_path) 78 | 79 | data = pytesseract.image_to_data(img_pl, lang=languages_, output_type='data.frame', config=special_config) 80 | 81 | data_imp_sort = optimizeDf(data.copy()) 82 | 83 | df_new_col = mergeDfColumns(data_imp_sort.copy()) 84 | 85 | merged_row_df = mergeDfRows(df_new_col.copy(), threshold = 5) 86 | 87 | cleaned_df = cleanDf(merged_row_df.copy()) 88 | -------------------------------------------------------------------------------- /GridlinesImprovement/apply_gridlines_fix.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | from GridlinesImprovement.draw_gridlines_functions import drawGridlines 4 | from GridlinesImprovement.cropping import cropImage 5 | from GridlinesImprovement.remove_gridlines import removeLines 6 | 7 | 8 | image_path = "Model Implementation/DummyDatabase/test_images/image_gridless.png" 9 | new_image_path = "Model Implementation/DummyDatabase/test_images/image_grided.png" 10 | original = cv2.imread(image_path) 11 | # Remove all gridlines 12 | gridless = removeLines(removeLines(original, 'horizontal'), 'vertical') 13 | # Draw grid lines 14 | images_by_stage = drawGridlines(gridless.copy()) 15 | """ 16 | images_by_stage: (dict) 17 | 'threshold': threshold image 18 | 'vertical': vertical grid lines image 19 | 'horizontal': horizontal grid lines image 20 | 'full': full grid lines image 21 | """ 22 | # Obtain full grid image 23 | full_image = images_by_stage['full'].copy() 24 | # Crop image 25 | cropped_image = cropImage(full_image.copy()) 26 | # Save new image 27 | cv2.imwrite(new_image_path, cropped_image) 28 | # Show image 29 | plt.imshow(cropped_image) -------------------------------------------------------------------------------- /GridlinesImprovement/cropping.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def cropImage(old_image: np.ndarray) -> np.ndarray: 6 | # Get dimensions 7 | hh, ww = old_image.shape[:2] 8 | # Convert to gray 9 | gray = cv2.cvtColor(old_image, cv2.COLOR_BGR2GRAY) 10 | # Threshold 11 | thresh = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY)[1] 12 | # Crop 1 pixel and add 1 pixel white border to ensure outer white regions not considered the small contours 13 | thresh = thresh[1: hh - 1, 1 : ww - 1] 14 | thresh = cv2.copyMakeBorder(thresh, 1, 1, 1, 1, borderType = cv2.BORDER_CONSTANT, value = (255, 255, 255)) 15 | # Get contours 16 | contours = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 17 | contours = contours[0] if len(contours) == 2 else contours[1] 18 | # Get min and max x and y from all bounding boxes larger than half of the image size 19 | thresh_area = hh * ww / 2 20 | xmin = ww 21 | ymin = hh 22 | xmax = 0 23 | ymax = 0 24 | for contour in contours: 25 | area = cv2.contourArea(contour) 26 | if area < thresh_area: 27 | x, y, w, h = cv2.boundingRect(contour) 28 | xmin = x if (x < xmin) else xmin 29 | ymin = y if (y < ymin) else ymin 30 | xmax = x + w - 1 if (x + w - 1 > xmax ) else xmax 31 | ymax 
= y + h - 1 if (y + h - 1 > ymax) else ymax 32 | # Draw bounding box 33 | bounding_box = old_image.copy() 34 | cv2.rectangle(bounding_box, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2) 35 | # Crop old_image at the bounding box, but add 2 all around to keep the black lines 36 | result = old_image[ymin : ymax, xmin : xmax] 37 | return result 38 | -------------------------------------------------------------------------------- /GridlinesImprovement/draw_gridlines_functions.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | """ 6 | Return: 7 | (dict): 8 | 'threshold': threshold image 9 | 'vertical': vertical grid lines image 10 | 'horizontal': horizontal grid lines image 11 | 'full': full grid lines image 12 | """ 13 | def drawGridlines(old_image: np.ndarray) -> dict: 14 | # Get dimensions 15 | hh_, ww_ = old_image.shape[:2] 16 | # Convert image to grayscale 17 | gray = cv2.cvtColor(old_image, cv2.COLOR_BGR2GRAY) 18 | # Threshold on white - binary 19 | thresh = cv2.threshold(gray, 220, 255, cv2.THRESH_BINARY)[1] 20 | # Resize thresh image to a single row 21 | row = cv2.resize(thresh, (ww_, 1), interpolation = cv2.INTER_AREA) 22 | # Threshold on white 23 | thresh_row = cv2.threshold(row, 254, 255, cv2.THRESH_BINARY)[1] 24 | # Apply small amount of morphology to merge with column of text 25 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT , (5, 1)) 26 | thresh_row = cv2.morphologyEx(thresh_row, cv2.MORPH_OPEN, kernel) 27 | # Get vertical contours 28 | contours_v = cv2.findContours(thresh_row, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 29 | contours_v = contours_v[0] if len(contours_v) == 2 else contours_v[1] 30 | full_grid_image = old_image.copy() 31 | vertical_img = old_image.copy() 32 | for contour in contours_v: 33 | x, y, w, h = cv2.boundingRect(contour) 34 | xcenter = x + w // 2 35 | cv2.line(vertical_img, (xcenter, 0), (xcenter, hh_ - 1), (0, 0, 0), 1) 36 | cv2.line(full_grid_image, (xcenter, 0), (xcenter, hh_ - 1), (0, 0, 0), 1) 37 | # Resize thresh image to a single column 38 | column = cv2.resize(thresh, (1, hh_), interpolation = cv2.INTER_AREA) 39 | # Threshold on white - binary 40 | thresh_column = cv2.threshold(column, 254, 255, cv2.THRESH_BINARY)[1] 41 | # Get horizontal contours 42 | contours_h = cv2.findContours(thresh_column, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 43 | contours_h = contours_h[0] if len(contours_h) == 2 else contours_h[1] 44 | horizontal_img = old_image.copy() 45 | for contour in contours_h: 46 | x, y, w, h = cv2.boundingRect(contour) 47 | ycenter = y + h // 2 48 | cv2.line(horizontal_img, (0, ycenter), (ww_ - 1, ycenter), (0, 0, 0), 1) 49 | cv2.line(full_grid_image, (0, ycenter), (ww_ - 1, ycenter), (0, 0, 0), 1) 50 | # Return results as a dictionary 51 | return { 52 | 'threshold': thresh, 53 | 'vertical': vertical_img, 54 | 'horizontal': horizontal_img, 55 | 'full': full_grid_image 56 | } 57 | -------------------------------------------------------------------------------- /GridlinesImprovement/remove_gridlines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def removeLines(old_image: np.ndarray, axis) -> np.ndarray: 6 | gray = cv2.cvtColor(old_image, cv2.COLOR_BGR2GRAY) 7 | thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] 8 | if axis == "horizontal": 9 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 25)) 10 | elif axis == "vertical": 11 | kernel 
= cv2.getStructuringElement(cv2.MORPH_RECT, (25, 1)) 12 | else: 13 | raise ValueError("Axis must be either 'horizontal' or 'vertical' in order to work") 14 | detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations = 2) 15 | contours = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 16 | contours = contours[0] if len(contours) == 2 else contours[1] 17 | result = old_image.copy() 18 | for contour in contours: 19 | cv2.drawContours(result, [contour], -1, (255, 255, 255), 2) 20 | return result 21 | -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/predictions/five_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/predictions/five_pred.jpg -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/predictions/four_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/predictions/four_pred.jpg -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/predictions/one_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/predictions/one_pred.jpg -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/predictions/three_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/predictions/three_pred.png -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/predictions/two_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/predictions/two_pred.png -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/test_images/five.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/test_images/five.jpg -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/test_images/four.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/test_images/four.jpg -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/test_images/one.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/test_images/one.jpg -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/test_images/three.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/test_images/three.png -------------------------------------------------------------------------------- /Model Implementation/DummyDatabase/test_images/two.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/DummyDatabase/test_images/two.png -------------------------------------------------------------------------------- /Model Implementation/Pre-Processing/eda.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import cv2\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import glob\n", 12 | "import numpy as np\n", 13 | "from Training.path_constants import ORIG_DATA_PATH, Marmot_data\n", 14 | "from preprocessing_utilities import create_element_mask, get_table_bounding_box, get_column_bounding_box" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "positive_data = glob.glob(f'{ORIG_DATA_PATH}/Positive/Raw' + '/*.bmp')\n", 24 | "negative_data = glob.glob(f'{ORIG_DATA_PATH}/Negative/Raw' + '/*.bmp')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "fig = plt.figure(figsize = (10, 5))\n", 34 | "x = ['Neg Samples', 'Pos Samples']\n", 35 | "y = [len(negative_data), len(positive_data)]\n", 36 | "plt.bar(x, y,width = 0.4)\n", 37 | "plt.title('Distribution: Positive and Negative Samples')\n", 38 | "plt.show()" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "new_h, new_w = 1024, 1024" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# Negative Example\n", 57 | "image_path = f'{ORIG_DATA_PATH}/Negative/Raw/10.1.1.1.2000_4.bmp'\n", 58 | "image = Image.open(image_path)\n", 59 | "image = image.resize((new_h, new_w))\n", 60 | "table_mask = create_element_mask(new_h, new_w)\n", 61 | "column_mask = create_element_mask(new_h, new_w)\n", 62 | "# Ploting\n", 63 | "f, ax = plt.subplots(1,3, figsize = (20,15))\n", 64 | "ax[0].imshow(np.array(image))\n", 65 | "ax[0].set_title('Original Image')\n", 66 | "ax[1].imshow(table_mask)\n", 67 | "ax[1].set_title('Table Mask')\n", 68 | "ax[2].imshow(column_mask)\n", 69 | "ax[2].set_title('Column Mask')\n", 70 | "plt.show()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Positive Example\n", 80 | "image_path = f'{ORIG_DATA_PATH}/Positive/Raw/10.1.1.1.2006_3.bmp'\n", 81 | "table_xml_path = f'{ORIG_DATA_PATH}/Positive/Labeled/10.1.1.1.2006_3.xml'\n", 82 | 
"column_xml_path = f'{Marmot_data}/10.1.1.1.2006_3.xml'\n", 83 | "# Load image\n", 84 | "image = Image.open(image_path)\n", 85 | "# Resize imageto std 1024, 1024\n", 86 | "w, h = image.size\n", 87 | "image = image.resize((new_h, new_w))\n", 88 | "# Convert to 3 channel image if 1 channel\n", 89 | "if image.mode != 'RGB':\n", 90 | " image = image.convert(\"RGB\")\n", 91 | "# Scaled versions of bbox coordinates of table\n", 92 | "table_bounding_boxes = get_table_bounding_box(table_xml_path, (new_h, new_w))\n", 93 | "# Scaled versions of bbox coordinates of columns\n", 94 | "column_bounding_boxes, table_bounding_boxes = get_column_bounding_box(column_xml_path, (h,w), (new_h, new_w), table_bounding_boxes)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "column_bounding_boxes, table_bounding_boxes" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "plt.figure(figsize = (20,10))\n", 113 | "image_copy = np.array(image).copy()\n", 114 | "for bounding_box in table_bounding_boxes:\n", 115 | " cv2.rectangle(image_copy, (bounding_box[0], bounding_box[1]), (bounding_box[2], bounding_box[3]), (0, 255, 0), 2)\n", 116 | "for bounding_box in column_bounding_boxes:\n", 117 | " cv2.rectangle(image_copy, (bounding_box[0], bounding_box[1]), (bounding_box[2], bounding_box[3]), (255, 255, 0), 2)\n", 118 | "plt.imshow(image_copy)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "table_mask = create_element_mask(new_h, new_w, table_bounding_boxes)\n", 128 | "column_mask = create_element_mask(new_h, new_w, column_bounding_boxes)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "f, ax = plt.subplots(1,3, figsize = (20,15))\n", 138 | "ax[0].imshow(np.array(image_copy))\n", 139 | "ax[0].set_title('Original Image')\n", 140 | "ax[1].imshow(table_mask)\n", 141 | "ax[1].set_title('Table Mask')\n", 142 | "ax[2].imshow(column_mask)\n", 143 | "ax[2].set_title('Column Mask')\n", 144 | "plt.show()" 145 | ] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "name": "python", 156 | "version": "3.9.7" 157 | }, 158 | "orig_nbformat": 4, 159 | "vscode": { 160 | "interpreter": { 161 | "hash": "92cc54a05d1ad6fd73bb4b9111dd84f41a66497e622f98d2a5bcc9478314e882" 162 | } 163 | } 164 | }, 165 | "nbformat": 4, 166 | "nbformat_minor": 2 167 | } 168 | -------------------------------------------------------------------------------- /Model Implementation/Pre-Processing/pre_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from tqdm import tqdm 4 | from PIL import Image 5 | import pandas as pd 6 | from Training.path_constants import ORIG_DATA_PATH, PROCESSED_DATA, IMAGE_PATH, TABLE_MASK_PATH, COL_MASK_PATH, POSITIVE_DATA_LBL, DATA_PATH 7 | from preprocessing_utilities import get_table_bounding_box, get_column_bounding_box, create_element_mask 8 | 9 | 10 | # Make directories to save data 11 | os.makedirs(PROCESSED_DATA, exist_ok = True) 12 | os.makedirs(IMAGE_PATH, exist_ok = True) 13 | os.makedirs(TABLE_MASK_PATH, exist_ok = True) 14 | 
os.makedirs(COL_MASK_PATH, exist_ok = True) 15 | 16 | positive_data = glob.glob(f'{ORIG_DATA_PATH}/Positive/Raw' + '/*.bmp') 17 | negative_data = glob.glob(f'{ORIG_DATA_PATH}/Negative/Raw' + '/*.bmp') 18 | 19 | new_h, new_w = 1024, 1024 20 | 21 | processed_data = [] 22 | for i, data in enumerate([negative_data, positive_data]): 23 | for j, image_path in tqdm(enumerate(data)): 24 | image_name = os.path.basename(image_path) 25 | image = Image.open(image_path) 26 | w, h = image.size 27 | # Convert image to RGB image 28 | image = image.resize((new_h, new_w)) 29 | if image.mode != 'RGB': 30 | image = image.convert("RGB") 31 | table_bounding_boxes, column_bounding_boxes = [], [] 32 | if i == 1: 33 | # Get xml filename 34 | xml_file = image_name.replace('bmp', 'xml') 35 | table_xml_path = os.path.join(POSITIVE_DATA_LBL, xml_file) 36 | column_xml_path = os.path.join(DATA_PATH, xml_file) 37 | # Get bounding boxes 38 | table_bounding_boxes = get_table_bounding_box(table_xml_path, (new_h, new_w)) 39 | if os.path.exists(column_xml_path): 40 | column_bounding_boxes, table_bounding_boxes = get_column_bounding_box(column_xml_path, (h,w), (new_h, new_w), table_bounding_boxes) 41 | else: 42 | column_bounding_boxes = [] 43 | # Create masks 44 | table_mask = create_element_mask(new_h, new_w, table_bounding_boxes) 45 | column_mask = create_element_mask(new_h, new_w, column_bounding_boxes) 46 | # Save images and masks 47 | save_image_path = os.path.join(IMAGE_PATH, image_name.replace('bmp', 'jpg')) 48 | save_table_mask_path = os.path.join(TABLE_MASK_PATH, image_name[:-4] + '_table_mask.png') 49 | save_column_mask_path = os.path.join(COL_MASK_PATH, image_name[:-4] + '_col_mask.png') 50 | image.save(save_image_path) 51 | table_mask.save(save_table_mask_path) 52 | column_mask.save(save_column_mask_path) 53 | # Add data to the dataframe 54 | len_table = len(table_bounding_boxes) 55 | len_columns = len(column_bounding_boxes) 56 | value = (save_image_path, save_table_mask_path, save_column_mask_path, h, w, int(len_table != 0), \ 57 | len_table, len_columns, table_bounding_boxes, column_bounding_boxes) 58 | processed_data.append(value) 59 | 60 | columns_name = ['img_path', 'table_mask', 'col_mask', 'original_height', 'original_width', 'hasTable', 'table_count', 'col_count', 'table_bboxes', 'col_bboxes'] 61 | processed_data = pd.DataFrame(processed_data, columns=columns_name) 62 | # Save dataframe and inspect it's data 63 | processed_data.to_csv(f"{PROCESSED_DATA}/processed_data.csv", index = False) 64 | print(processed_data.tail()) 65 | -------------------------------------------------------------------------------- /Model Implementation/Pre-Processing/preprocessing_utilities.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from PIL import Image 3 | import numpy as np 4 | import xml.etree.ElementTree as ET 5 | 6 | 7 | def get_table_bounding_box(table_xml_path: str, new_image_shape: tuple): 8 | """ 9 | Goal: Extract table coordinates from xml file and scale them to the new image shape 10 | Input: 11 | :param table_xml_path: xml file path 12 | :param new_image_shape: tuple (new_h, new_w) 13 | Return: table_bounding_boxes: List of all the bounding boxes of the tables 14 | """ 15 | tree = ET.parse(table_xml_path) 16 | root = tree.getroot() 17 | left, top, right, bottom = list(map(lambda x: struct.unpack('!d', bytes.fromhex(x))[0], root.get("CropBox").split())) 18 | width = abs(right - left) 19 | height = abs(top - bottom) 20 | table_bounding_boxes = [] 21 | for 
table in root.findall(".//Composite[@Label='TableBody']"): 22 | x0in, y0in, x1in, y1in = list(map(lambda x: struct.unpack('!d', bytes.fromhex(x))[0], table.get("BBox").split())) 23 | x0 = round(new_image_shape[1] * (x0in - left) / width) 24 | x1 = round(new_image_shape[1] * (x1in - left) / width) 25 | y0 = round(new_image_shape[0] * (top - y0in) / height) 26 | y1 = round(new_image_shape[0] * (top - y1in) / height) 27 | table_bounding_boxes.append([x0, y0, x1, y1]) 28 | return table_bounding_boxes 29 | 30 | def get_column_bounding_box(column_xml_path: str, old_image_shape: tuple, new_image_shape: tuple, 31 | table_bounding_box: list, threshhold: int = 3): 32 | """ 33 | Goal: 34 | - Extract column coordinates from the xml file and scale them to the new image shape and the old image shape 35 | - If there are no table_bounding_box present, approximate them using column bounding box 36 | Input: 37 | :param table_xml_path: xml file path 38 | :param old_image_shape: (new_h, new_w) 39 | :param new_image_shape: (new_h, new_w) 40 | :param table_bounding_box: List of table bbox coordinates 41 | :param threshold: the threshold t apply, defualts to 3 42 | Return: tuple (column_bounding_box, table_bounding_box) 43 | """ 44 | tree = ET.parse(column_xml_path) 45 | root = tree.getroot() 46 | x_mins = [round(int(coord.text) * new_image_shape[1] / old_image_shape[1]) for coord in root.findall("./object/bndbox/xmin")] 47 | y_mins = [round(int(coord.text) * new_image_shape[0] / old_image_shape[0]) for coord in root.findall("./object/bndbox/ymin")] 48 | x_maxs = [round(int(coord.text) * new_image_shape[1] / old_image_shape[1]) for coord in root.findall("./object/bndbox/xmax")] 49 | y_maxs = [round(int(coord.text) * new_image_shape[0] / old_image_shape[0]) for coord in root.findall("./object/bndbox/ymax")] 50 | column_bounding_box = [] 51 | for x_min, y_min, x_max, y_max in zip(x_mins, y_mins, x_maxs, y_maxs): 52 | bounding_box = [x_min, y_min, x_max, y_max] 53 | column_bounding_box.append(bounding_box) 54 | if len(table_bounding_box) == 0: 55 | x_min = min([x[0] for x in column_bounding_box]) - threshhold 56 | y_min = min([x[1] for x in column_bounding_box]) - threshhold 57 | x_max = max([x[2] for x in column_bounding_box]) + threshhold 58 | y_max = max([x[3] for x in column_bounding_box]) + threshhold 59 | table_bounding_box = [[x_min, y_min, x_max, y_max]] 60 | return column_bounding_box, table_bounding_box 61 | 62 | def create_element_mask(new_h: int, new_w: int, bounding_boxes: list = None): 63 | """ 64 | Goal: Create a mask based on new_h, new_w and bounding boxes 65 | Input: 66 | :param new_h: height of the mask 67 | :param new_w: width of the mask 68 | :param bounding_boxes: bounding box coordinates 69 | Return: mask: Image 70 | """ 71 | mask = np.zeros((new_h, new_w), dtype = np.int32) 72 | if bounding_boxes is None or len(bounding_boxes) == 0: 73 | return Image.fromarray(mask) 74 | for box in bounding_boxes: 75 | mask[box[1]:box[3], box[0]:box[2]] = 255 76 | return Image.fromarray(mask) 77 | -------------------------------------------------------------------------------- /Model Implementation/Training/__pycache__/encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/Training/__pycache__/encoder.cpython-39.pyc -------------------------------------------------------------------------------- /Model 
Implementation/Training/__pycache__/path_constants.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/Training/__pycache__/path_constants.cpython-39.pyc
--------------------------------------------------------------------------------
/Model Implementation/Training/__pycache__/tablenet_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/Training/__pycache__/tablenet_model.cpython-39.pyc
--------------------------------------------------------------------------------
/Model Implementation/Training/configurations.py:
--------------------------------------------------------------------------------
1 | from Training.path_constants import PROCESSED_DATA
2 | import torch # pip install torch
3 | 
4 | SEED = 0
5 | LEARNING_RATE = 0.0001
6 | EPOCHS = 100
7 | BATCH_SIZE = 2
8 | WEIGHT_DECAY = 3e-4
9 | DATAPATH = f'{PROCESSED_DATA}/processed_data.csv'
10 | MODEL_NAME = "densenet_configuration_4_model_checkpoint.pth.tar"
11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 | 
--------------------------------------------------------------------------------
/Model Implementation/Training/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pandas as pd
4 | import numpy as np
5 | import albumentations as A
6 | from PIL import Image
7 | from torch.utils.data import DataLoader
8 | from torchvision.utils import save_image
9 | from albumentations.pytorch import ToTensorV2
10 | from tqdm import tqdm
11 | import tqdm
12 | from Training.path_constants import PROCESSED_DATA
13 | 
14 | 
15 | class ImageFolder(nn.Module):
16 |     def __init__(self, df, transform = None):
17 |         super(ImageFolder, self).__init__()
18 |         self.df, self.transform = df, transform  # keep a user-supplied transform instead of silently discarding it
19 |         if transform is None:
20 |             self.transform = A.Compose([
21 |                 A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225], max_pixel_value = 255,),
22 |                 ToTensorV2()
23 |             ])
24 |     def __len__(self):
25 |         return len(self.df)
26 | 
27 |     def __getitem__(self, index):
28 |         image_path, table_mask_path, column_mask_path = self.df.iloc[index, 0], self.df.iloc[index, 1], self.df.iloc[index, 2]
29 |         image = np.array(Image.open(image_path))
30 |         table_image = torch.FloatTensor(np.array(Image.open(table_mask_path)) / 255.0).reshape(1, 1024, 1024)
31 |         column_image = torch.FloatTensor(np.array(Image.open(column_mask_path)) / 255.0).reshape(1, 1024, 1024)
32 |         image = self.transform(image = image)['image']
33 |         return {"image": image, "table_image": table_image, "column_image": column_image}
34 | 
35 | def get_mean_std(train_data, transform):
36 |     dataset = ImageFolder(train_data, transform)
37 |     train_loader = DataLoader(dataset, batch_size = 128)
38 |     mean = 0.
39 |     std = 0.
40 | for img_dict in tqdm.tqdm(train_loader): 41 | batch_samples = img_dict["image"].size(0) 42 | images = img_dict["image"].view(batch_samples, img_dict["image"].size(1), -1) 43 | mean += images.mean(2).sum(0) 44 | std += images.std(2).sum(0) 45 | mean /= len(train_loader.dataset) 46 | std /= len(train_loader.dataset) 47 | print(mean) 48 | print(std) 49 | 50 | # Read referencing csv file 51 | df = pd.read_csv(f'{PROCESSED_DATA}/processed_data.csv') 52 | dataset = ImageFolder(df[df['hasTable'] == 1]) 53 | img_num = 0 54 | for img_dict in dataset: 55 | save_image(img_dict["image"], f'image_{img_num}.png') 56 | save_image(img_dict["table_image"], f'table_image_{img_num}.png') 57 | save_image(img_dict["column_image"], f'column_image_{img_num}.png') 58 | img_num += 1 59 | if img_num == 6: 60 | break 61 | -------------------------------------------------------------------------------- /Model Implementation/Training/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from efficientnet_pytorch import EfficientNet 5 | 6 | class VGG19(nn.Module): 7 | def __init__(self, pretrained = True, requires_grad = True): 8 | super(VGG19, self).__init__() 9 | _vgg = torchvision.models.vgg19(pretrained = pretrained).features 10 | self.vgg_pool3 = torch.nn.Sequential() 11 | self.vgg_pool4 = torch.nn.Sequential() 12 | self.vgg_pool5 = torch.nn.Sequential() 13 | for x in range(19): 14 | self.vgg_pool3.add_module(str(x), _vgg[x]) 15 | for x in range(19, 28): 16 | self.vgg_pool4.add_module(str(x), _vgg[x]) 17 | for x in range(28, 37): 18 | self.vgg_pool5.add_module(str(x), _vgg[x]) 19 | if not requires_grad: 20 | for param in self.parameters(): 21 | param.requires_grad = False 22 | 23 | def forward(self, x): 24 | pool_3_out = self.vgg_pool3(x) 25 | pool_4_out = self.vgg_pool4(pool_3_out) 26 | pool_5_out = self.vgg_pool5(pool_4_out) 27 | return (pool_3_out, pool_4_out, pool_5_out) 28 | 29 | class ResNet(nn.Module): 30 | def __init__(self, pretrained = True, requires_grad = True): 31 | super(ResNet, self).__init__() 32 | resnet18 = torchvision.models.resnet34(pretrained = True) 33 | self.layer_1 = nn.Sequential(resnet18.conv1, resnet18.bn1, resnet18.relu, resnet18.maxpool, resnet18.layer1) 34 | self.layer_2 = resnet18.layer2 35 | self.layer_3 = resnet18.layer3 36 | self.layer_4 = resnet18.layer4 37 | if not requires_grad: 38 | for param in self.parameters(): 39 | param.requires_grad = False 40 | 41 | def forward(self, x): 42 | out_1 = self.layer_2(self.layer_1(x)) 43 | out_2 = self.layer_3(out_1) 44 | out_3 = self.layer_4(out_2) 45 | return out_1, out_2, out_3 46 | 47 | class DenseNet(nn.Module): 48 | def __init__(self, pretrained = True, requires_grad = True): 49 | super(DenseNet, self).__init__() 50 | denseNet = torchvision.models.densenet121(pretrained = True).features 51 | self.densenet_out_1 = torch.nn.Sequential() 52 | self.densenet_out_2 = torch.nn.Sequential() 53 | self.densenet_out_3 = torch.nn.Sequential() 54 | for x in range(8): 55 | self.densenet_out_1.add_module(str(x), denseNet[x]) 56 | for x in range(8,10): 57 | self.densenet_out_2.add_module(str(x), denseNet[x]) 58 | self.densenet_out_3.add_module(str(10), denseNet[10]) 59 | if not requires_grad: 60 | for param in self.parameters(): 61 | param.requires_grad = False 62 | 63 | def forward(self, x): 64 | out_1 = self.densenet_out_1(x) 65 | out_2 = self.densenet_out_2(out_1) 66 | out_3 = self.densenet_out_3(out_2) 67 | return out_1, out_2, out_3 
68 | 
69 | class efficientNet_B0(nn.Module):
70 |     def __init__(self, pretrained = True, requires_grad = True):
71 |         super(efficientNet_B0, self).__init__()
72 |         eNet = EfficientNet.from_pretrained('efficientnet-b0')
73 |         self.eNet_out_1 = torch.nn.Sequential()
74 |         self.eNet_out_2 = torch.nn.Sequential()
75 |         self.eNet_out_3 = torch.nn.Sequential()
76 |         blocks = eNet._blocks
77 |         self.eNet_out_1.add_module('_conv_stem', eNet._conv_stem)
78 |         self.eNet_out_1.add_module('_bn0', eNet._bn0)
79 |         for x in range(14):
80 |             self.eNet_out_1.add_module(str(x), blocks[x])
81 |         self.eNet_out_2.add_module(str(14), blocks[14])
82 |         self.eNet_out_3.add_module(str(15), blocks[15])
83 | 
84 |     def forward(self, x):
85 |         out_1 = self.eNet_out_1(x)
86 |         out_2 = self.eNet_out_2(out_1)
87 |         out_3 = self.eNet_out_3(out_2)
88 |         return out_1, out_2, out_3
89 | 
90 | class efficientNet(nn.Module):
91 |     def __init__(self, model_type = 'efficientnet-b0', pretrained = True, requires_grad = True):
92 |         super(efficientNet, self).__init__()
93 |         eNet = EfficientNet.from_pretrained(model_type)
94 |         self.eNet_out_1 = torch.nn.Sequential()
95 |         self.eNet_out_2 = torch.nn.Sequential()
96 |         self.eNet_out_3 = torch.nn.Sequential()
97 |         blocks = eNet._blocks
98 |         self.eNet_out_1.add_module('_conv_stem', eNet._conv_stem)
99 |         self.eNet_out_1.add_module('_bn0', eNet._bn0)
100 |         for x in range(len(blocks)-3):
101 |             self.eNet_out_1.add_module(str(x), blocks[x])
102 |         self.eNet_out_2.add_module(str(len(blocks)-2), blocks[len(blocks)-2])
103 |         self.eNet_out_3.add_module(str(len(blocks)-1), blocks[len(blocks)-1])
104 | 
105 |     def forward(self, x):
106 |         out_1 = self.eNet_out_1(x)
107 |         out_2 = self.eNet_out_2(out_1)
108 |         out_3 = self.eNet_out_3(out_2)
109 |         return out_1, out_2, out_3
110 | 
111 | # model = DenseNet()
112 | # x = torch.randn(1, 3, 1024, 1024)
113 | # model(x)
114 | 
--------------------------------------------------------------------------------
/Model Implementation/Training/general_utilities.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import cv2
4 | import torch
5 | import numpy as np
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 | import albumentations as A
9 | from torch.utils.data import DataLoader
10 | from sklearn.model_selection import train_test_split
11 | from albumentations.pytorch import ToTensorV2
12 | 
13 | from Training.configurations import BATCH_SIZE, DATAPATH, DEVICE, SEED
14 | from Training.path_constants import PROCESSED_DATA
15 | from Training.dataset import ImageFolder
16 | 
17 | 
18 | TRANSFORM = A.Compose([
19 |     A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225], max_pixel_value = 255,),
20 |     ToTensorV2()
21 | ])
22 | # Apply the SEED
23 | def seed_all(SEED_VALUE = SEED):
24 |     random.seed(SEED_VALUE)
25 |     os.environ['PYTHONHASHSEED'] = str(SEED_VALUE)
26 |     np.random.seed(SEED_VALUE)
27 |     torch.manual_seed(SEED_VALUE)
28 |     torch.cuda.manual_seed(SEED_VALUE)
29 |     torch.backends.cudnn.deterministic = True
30 |     torch.backends.cudnn.benchmark = False  # benchmarking must stay off for deterministic cuDNN behaviour
31 | 
32 | def get_data_loaders(data_path = DATAPATH):
33 |     df = pd.read_csv(data_path)
34 |     train_data, test_data = train_test_split(df, test_size = 0.2, random_state = SEED, stratify = df.hasTable)
35 |     train_dataset = ImageFolder(train_data, transform = None)  # ImageFolder in dataset.py takes no isTrain argument
36 |     test_dataset = ImageFolder(test_data, transform = None)
37 |     train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers =
4, pin_memory = True) 38 | test_loader = DataLoader(test_dataset, batch_size = 8, shuffle = False, num_workers = 4, pin_memory = True) 39 | return train_loader, test_loader 40 | 41 | # Save Checkpoint 42 | def save_checkpoint(state, filename = f"{PROCESSED_DATA}/model_checkpoint.pth.tar"): 43 | torch.save(state, filename) 44 | print("Checkpoint Saved at: ", filename) 45 | 46 | # Load the checkpoint we saved 47 | def load_checkpoint(checkpoint, model, optimizer = None): 48 | print("Loading checkpoint...") 49 | model.load_state_dict(checkpoint['state_dict']) 50 | if optimizer is not None: 51 | optimizer.load_state_dict(checkpoint['optimizer']) 52 | last_epoch = checkpoint['epoch'] 53 | tr_metrics = checkpoint['train_metrics'] 54 | te_metrics = checkpoint['test_metrics'] 55 | return last_epoch, tr_metrics, te_metrics 56 | 57 | def write_summary(writer, tr_metrics, te_metrics, epoch): 58 | writer.add_scalar("Table Loss/Train", tr_metrics['table_loss'], global_step = epoch) 59 | writer.add_scalar("Table Loss/Test", te_metrics['table_loss'], global_step = epoch) 60 | writer.add_scalar("Table Acc/Train", tr_metrics['table_acc'], global_step = epoch) 61 | writer.add_scalar("Table Acc/Test", te_metrics['table_acc'], global_step = epoch) 62 | writer.add_scalar("Table F1/Train", tr_metrics['table_f1'], global_step = epoch) 63 | writer.add_scalar("Table F1/Test", te_metrics['table_f1'], global_step = epoch) 64 | writer.add_scalar("Table Precision/Train", tr_metrics['table_precision'], global_step = epoch) 65 | writer.add_scalar("Table Precision/Test", te_metrics['table_precision'], global_step = epoch) 66 | writer.add_scalar("Table Recall/Train", tr_metrics['table_recall'], global_step = epoch) 67 | writer.add_scalar("Table Recall/Test", te_metrics['table_recall'], global_step = epoch) 68 | writer.add_scalar("Column Loss/Train", tr_metrics['column_loss'], global_step = epoch) 69 | writer.add_scalar("Column Loss/Test", te_metrics['column_loss'], global_step = epoch) 70 | writer.add_scalar("Column Acc/Train", tr_metrics['col_acc'], global_step = epoch) 71 | writer.add_scalar("Column Acc/Test", te_metrics['col_acc'], global_step = epoch) 72 | writer.add_scalar("Column F1/Train", tr_metrics['col_f1'], global_step = epoch) 73 | writer.add_scalar("Column F1/Test", te_metrics['col_f1'], global_step = epoch) 74 | writer.add_scalar("Column Precision/Train", tr_metrics['col_precision'], global_step = epoch) 75 | writer.add_scalar("Column Precision/Test", te_metrics['col_precision'], global_step = epoch) 76 | writer.add_scalar("Column Recall/Train", tr_metrics['col_recall'], global_step = epoch) 77 | writer.add_scalar("Column Recall/Test", te_metrics['col_recall'], global_step = epoch) 78 | 79 | def display_metrics(epoch, tr_metrics, te_metrics): 80 | print(f"Epoch: {epoch} \n\ 81 | Table Loss -- Train: {tr_metrics['table_loss']:.3f} Test: {te_metrics['table_loss']:.3f}\n\ 82 | Table Acc -- Train: {tr_metrics['table_acc']:.3f} Test: {te_metrics['table_acc']:.3f}\n\ 83 | Table F1 -- Train: {tr_metrics['table_f1']:.3f} Test: {te_metrics['table_f1']:.3f}\n\ 84 | Table Precision -- Train: {tr_metrics['table_precision']:.3f} Test: {te_metrics['table_precision']:.3f}\n\ 85 | Table Recall -- Train: {tr_metrics['table_recall']:.3f} Test: {te_metrics['table_recall']:.3f}\n\ 86 | \n\ 87 | Col Loss -- Train: {tr_metrics['column_loss']:.3f} Test: {te_metrics['column_loss']:.3f}\n\ 88 | Col Acc -- Train: {tr_metrics['col_acc']:.3f} Test: {te_metrics['col_acc']:.3f}\n\ 89 | Col F1 -- Train: {tr_metrics['col_f1']:.3f} Test: 
{te_metrics['col_f1']:.3f}\n\
90 |     Col Precision -- Train: {tr_metrics['col_precision']:.3f} Test: {te_metrics['col_precision']:.3f}\n\
91 |     Col Recall -- Train: {tr_metrics['col_recall']:.3f} Test: {te_metrics['col_recall']:.3f}\n"
92 |     )
93 | 
94 | def compute_metrics(ground_truth, prediction, threshold = 0.5):
95 |     # Ref: https://stackoverflow.com/a/56649983
96 |     ground_truth = ground_truth.int()
97 |     prediction = (torch.sigmoid(prediction) > threshold).int()
98 |     TP = torch.sum(prediction[ground_truth == 1] == 1)
99 |     TN = torch.sum(prediction[ground_truth == 0] == 0)
100 |     FP = torch.sum(prediction[ground_truth == 0] == 1)  # predicted positive, actually negative
101 |     FN = torch.sum(prediction[ground_truth == 1] == 0)  # predicted negative, actually positive
102 |     acc = (TP + TN) / (TP + TN + FP + FN)
103 |     precision = TP / (FP + TP + 1e-4)
104 |     recall = TP / (FN + TP + 1e-4)
105 |     f1 = 2 * precision * recall / (precision + recall + 1e-4)
106 |     metrics = {
107 |         'acc': acc.item(),
108 |         'f1': f1.item(),
109 |         'precision': precision.item(),
110 |         'recall': recall.item()
111 |     }
112 |     return metrics
113 | 
114 | def display(image, table, column, title = 'Original'):
115 |     f, ax = plt.subplots(1, 3, figsize = (15, 8))
116 |     ax[0].imshow(image)
117 |     ax[0].set_title(f'{title} Image')
118 |     ax[1].imshow(table)
119 |     ax[1].set_title(f'{title} Table Mask')
120 |     ax[2].imshow(column)
121 |     ax[2].set_title(f'{title} Column Mask')
122 |     plt.show()
123 | 
124 | def display_prediction(image, table = None, table_image = None, no_: bool = False):
125 |     if no_:
126 |         f1, ax = plt.subplots(1, 1, figsize = (7, 5))
127 |         ax.imshow(image)
128 |         ax.set_title('Original Image')
129 |         f1.suptitle('No Tables Detected')
130 |     else:
131 |         f2, ax = plt.subplots(1, 3, figsize = (15, 8))
132 |         ax[0].imshow(image)
133 |         ax[0].set_title('Original Image')
134 |         ax[1].imshow(table)
135 |         ax[1].set_title('Image with Predicted Table')
136 |         ax[2].imshow(table_image)
137 |         ax[2].set_title('Predicted Table Example')
138 |     plt.show()
139 | 
140 | def get_TableMasks(test_image, model, transform = TRANSFORM, device = DEVICE):
141 |     image = transform(image = test_image)["image"]
142 |     # Get predictions
143 |     model.eval()
144 |     with torch.no_grad():
145 |         image = image.to(device).unsqueeze(0)
146 |         # with torch.cuda.amp.autocast():
147 |         table_out, column_out = model(image)
148 |         table_out = torch.sigmoid(table_out)
149 |         column_out = torch.sigmoid(column_out)
150 |     # Remove gradients
151 |     table_out = (table_out.cpu().detach().numpy().squeeze(0).transpose(1, 2, 0) > 0.5).astype(int)
152 |     column_out = (column_out.cpu().detach().numpy().squeeze(0).transpose(1, 2, 0) > 0.5).astype(int)
153 |     # Return masks
154 |     return table_out, column_out
155 | 
156 | def fixMasks(image, table_mask, column_mask):
157 |     """ Fix Table Bounding Box to get better OCR predictions """
158 |     table_mask = table_mask.reshape(1024, 1024).astype(np.uint8)
159 |     column_mask = column_mask.reshape(1024, 1024).astype(np.uint8)
160 |     # Get contours of the mask to get number of tables
161 |     contours, table_heirarchy = cv2.findContours(table_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
162 |     table_contours = []
163 |     # Ref: https://www.pyimagesearch.com/2015/02/09/removing-contours-image-using-python-opencv/
164 |     # Remove bad contours
165 |     for c in contours:
166 |         # keep only contours large enough to plausibly be a table
167 |         if cv2.contourArea(c) > 2000:
168 |             table_contours.append(c)
169 |     if len(table_contours) == 0:
170 |         return None
171 |     # Ref : https://docs.opencv.org/4.5.2/da/d0c/tutorial_bounding_rects_circles.html
172 |     # Get bounding box
for the contour 173 | table_bound_rect = [None] * len(table_contours) 174 | for i, c in enumerate(table_contours): 175 | polygon = cv2.approxPolyDP(c, 3, True) 176 | table_bound_rect[i] = cv2.boundingRect(polygon) 177 | # Table bounding Box 178 | table_bound_rect.sort() 179 | column_bound_rects = [] 180 | for x, y, w, h in table_bound_rect: 181 | column_mask_crop = column_mask[y : y + h, x : x + w] 182 | # Get contours of the mask to get number of tables 183 | contours, column_heirarchy = cv2.findContours(column_mask_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 184 | # Get bounding box for the contour 185 | bound_rect = [None] * len(contours) 186 | for i, c in enumerate(contours): 187 | polygon = cv2.approxPolyDP(c, 3, True) 188 | bound_rect[i] = cv2.boundingRect(polygon) 189 | # Adjusting columns as per table coordinates 190 | bound_rect[i] = (bound_rect[i][0] + x, bound_rect[i][1] + y, bound_rect[i][2], bound_rect[i][3]) 191 | column_bound_rects.append(bound_rect) 192 | image = image[...,0].reshape(1024, 1024).astype(np.uint8) 193 | # Draw bounding boxes 194 | color = (0, 255, 0) 195 | thickness = 4 196 | for x, y, w, h in table_bound_rect: 197 | image = cv2.rectangle(image, (x, y),(x + w, y + h), color, thickness) 198 | return image, table_bound_rect, column_bound_rects 199 | -------------------------------------------------------------------------------- /Model Implementation/Training/model_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class TableNetLoss(nn.Module): 4 | def __init__(self): 5 | super(TableNetLoss, self).__init__() 6 | self.bce = nn.BCEWithLogitsLoss() 7 | 8 | def forward(self, table_prediction, table_target, column_prediction = None, column_target = None,): 9 | table_loss = self.bce(table_prediction, table_target) 10 | column_loss = self.bce(column_prediction, column_target) 11 | return table_loss, column_loss 12 | -------------------------------------------------------------------------------- /Model Implementation/Training/model_training.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.optim as optim 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | from torch.utils.tensorboard import SummaryWriter 8 | from pytorch_model_summary import summary 9 | 10 | from Training.configurations import DATAPATH, DEVICE, EPOCHS, LEARNING_RATE, MODEL_NAME, SEED, WEIGHT_DECAY, BATCH_SIZE 11 | from Training.path_constants import PROCESSED_DATA 12 | from Training.tablenet_model import TableNet 13 | from Training.model_loss import TableNetLoss 14 | from Training.general_utilities import compute_metrics, seed_all, get_data_loaders, load_checkpoint, display_metrics, write_summary, save_checkpoint 15 | 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def train_on_epoch(data_loader, model, optimizer, loss, scaler, threshold = 0.5): 21 | combined_loss = [] 22 | table_loss, table_acc, table_precision, table_recall, table_f1 = [], [], [], [], [] 23 | column_loss, column_acc, column_precision, column_recall, column_f1 = [], [], [], [], [] 24 | loop = tqdm(data_loader, leave = True) 25 | for batch_i, image_dict in enumerate(loop): 26 | image = image_dict["image"].to(DEVICE) 27 | table_image = image_dict["table_image"].to(DEVICE) 28 | column_image = image_dict["column_image"].to(DEVICE) 29 | with torch.cuda.amp.autocast(): 30 | table_out, column_out = model(image) 31 | i_table_loss, 
i_column_loss = loss(table_out, table_image, column_out, column_image) 32 | table_loss.append(i_table_loss.item()) 33 | column_loss.append(i_column_loss.item()) 34 | combined_loss.append((i_table_loss + i_column_loss).item()) 35 | # Backward 36 | optimizer.zero_grad() 37 | scaler.scale(i_table_loss + i_column_loss).backward() 38 | scaler.step(optimizer) 39 | scaler.update() 40 | mean_loss = sum(combined_loss) / len(combined_loss) 41 | loop.set_postfix(loss = mean_loss) 42 | cal_metrics_table = compute_metrics(table_image, table_out, threshold) 43 | cal_metrics_col = compute_metrics(column_image, column_out, threshold) 44 | table_f1.append(cal_metrics_table['f1']) 45 | table_precision.append(cal_metrics_table['precision']) 46 | table_acc.append(cal_metrics_table['acc']) 47 | table_recall.append(cal_metrics_table['recall']) 48 | column_f1.append(cal_metrics_col['f1']) 49 | column_acc.append(cal_metrics_col['acc']) 50 | column_precision.append(cal_metrics_col['precision']) 51 | column_recall.append(cal_metrics_col['recall']) 52 | metrics = { 53 | 'combined_loss': np.mean(combined_loss), 54 | 'table_loss': np.mean(table_loss), 55 | 'column_loss': np.mean(column_loss), 56 | 'table_acc': np.mean(table_acc), 57 | 'col_acc': np.mean(column_acc), 58 | 'table_f1': np.mean(table_f1), 59 | 'col_f1': np.mean(column_f1), 60 | 'table_precision': np.mean(table_precision), 61 | 'col_precision': np.mean(column_precision), 62 | 'table_recall': np.mean(table_recall), 63 | 'col_recall': np.mean(column_recall) 64 | } 65 | return metrics 66 | 67 | def test_on_epoch(data_loader, model, loss, threshold = 0.5, device = DEVICE): 68 | combined_loss = [] 69 | table_loss, table_acc, table_precision, table_recall, table_f1 = [], [], [], [], [] 70 | column_loss, column_acc, column_precision, column_recall, column_f1 = [], [], [], [], [] 71 | model.eval() 72 | with torch.no_grad(): 73 | loop = tqdm(data_loader, leave = True) 74 | for batch_i, image_dict in enumerate(loop): 75 | image = image_dict["image"].to(device) 76 | table_image = image_dict["table_image"].to(device) 77 | column_image = image_dict["column_image"].to(device) 78 | with torch.cuda.amp.autocast(): 79 | table_out, column_out = model(image) 80 | i_table_loss, i_column_loss = loss(table_out, table_image, column_out, column_image) 81 | table_loss.append(i_table_loss.item()) 82 | column_loss.append(i_column_loss.item()) 83 | combined_loss.append((i_table_loss + i_column_loss).item()) 84 | mean_loss = sum(combined_loss) / len(combined_loss) 85 | loop.set_postfix(loss=mean_loss) 86 | cal_metrics_table = compute_metrics(table_image, table_out, threshold) 87 | cal_metrics_col = compute_metrics(column_image, column_out, threshold) 88 | table_f1.append(cal_metrics_table['f1']) 89 | table_precision.append(cal_metrics_table['precision']) 90 | table_acc.append(cal_metrics_table['acc']) 91 | table_recall.append(cal_metrics_table['recall']) 92 | column_f1.append(cal_metrics_col['f1']) 93 | column_acc.append(cal_metrics_col['acc']) 94 | column_precision.append(cal_metrics_col['precision']) 95 | column_recall.append(cal_metrics_col['recall']) 96 | metrics = { 97 | 'combined_loss': np.mean(combined_loss), 98 | 'table_loss': np.mean(table_loss), 99 | 'column_loss': np.mean(column_loss), 100 | 'table_acc': np.mean(table_acc), 101 | 'col_acc': np.mean(column_acc), 102 | 'table_f1': np.mean(table_f1), 103 | 'col_f1': np.mean(column_f1), 104 | 'table_precision': np.mean(table_precision), 105 | 'col_precision': np.mean(column_precision), 106 | 'table_recall': 
np.mean(table_recall), 107 | 'col_recall': np.mean(column_recall) 108 | } 109 | model.train() 110 | return metrics 111 | 112 | seed_all(SEED_VALUE = SEED) 113 | checkpoint_name = f'{PROCESSED_DATA}/{MODEL_NAME}' 114 | model = TableNet(encoder = 'densenet', use_pretrained_model = True, basemodel_requires_grad = True) 115 | 116 | print("Model Architecture and Trainable Paramerters") 117 | print("="*50) 118 | print(summary(model, torch.zeros((1, 3, 1024, 1024)), show_input = False, show_hierarchical = True)) 119 | 120 | model = model.to(DEVICE) 121 | optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY) 122 | loss = TableNetLoss() 123 | scaler = torch.cuda.amp.GradScaler() 124 | train_loader, test_loader = get_data_loaders(data_path = DATAPATH) 125 | 126 | # Load checkpoint 127 | if os.path.exists(checkpoint_name): 128 | last_epoch, train_metrics, test_metrics = load_checkpoint(torch.load(checkpoint_name), model) 129 | last_table_f1 = test_metrics['table_f1'] 130 | last_column_f1 = test_metrics['col_f1'] 131 | print("Loading Checkpoint...") 132 | display_metrics(last_epoch, train_metrics, test_metrics) 133 | print() 134 | else: 135 | last_epoch = 0 136 | last_table_f1 = 0. 137 | last_column_f1 = 0. 138 | 139 | # Train Network 140 | print("Training Model\n") 141 | writer = SummaryWriter(f"{PROCESSED_DATA}/runs/TableNet/densenet/configuration_4_batch_{BATCH_SIZE}_learningrate_{LEARNING_RATE}_encoder_train") 142 | # For early stopping 143 | i = 0 144 | 145 | for epoch in range(last_epoch + 1, EPOCHS): 146 | print("="*30) 147 | start = time.time() 148 | train_metrics = train_on_epoch(train_loader, model, optimizer, loss, scaler, threshold = 0.5) 149 | test_metrics = test_on_epoch(test_loader, model, loss, threshold = 0.5) 150 | write_summary(writer, train_metrics, test_metrics, epoch) 151 | end = time.time() 152 | display_metrics(epoch, train_metrics, test_metrics) 153 | if last_table_f1 < test_metrics['table_f1'] or last_column_f1 < test_metrics['col_f1']: 154 | last_table_f1 = test_metrics['table_f1'] 155 | last_column_f1 = test_metrics['col_f1'] 156 | checkpoint = { 157 | 'epoch': epoch, 158 | 'state_dict': model.state_dict(), 159 | 'optimizer': optimizer.state_dict(), 160 | 'train_metrics': train_metrics, 161 | 'test_metrics': test_metrics 162 | } 163 | save_checkpoint(checkpoint, checkpoint_name) 164 | -------------------------------------------------------------------------------- /Model Implementation/Training/path_constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | dir_path = 'DummyDatabase' 4 | marmot_v1 = dir_path + '/marmot_v1' 5 | marmot_extended = dir_path + '/marmot_extended' 6 | ORIG_DATA_PATH = f'{marmot_v1}/marmot_dataset_v1.0/data/English' 7 | DATA_PATH = 'Marmot_data' 8 | PROCESSED_DATA = f'{dir_path}/marmot_processed' 9 | PREDICTIONS = f"{dir_path}/predictions" 10 | TEST_IMAGES = f"{dir_path}/test_images" 11 | MODELS = f"{dir_path}/models" 12 | IMAGE_PATH = os.path.join(PROCESSED_DATA, 'image') 13 | TABLE_MASK_PATH = os.path.join(PROCESSED_DATA, 'table_mask') 14 | COL_MASK_PATH = os.path.join(PROCESSED_DATA, 'col_mask') 15 | Marmot_data = f'{dir_path}/{DATA_PATH}' 16 | POSITIVE_DATA_LBL = os.path.join(ORIG_DATA_PATH, 'Positive','Labeled') 17 | -------------------------------------------------------------------------------- /Model Implementation/Training/tablenet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | from Training.encoder import VGG19, DenseNet, ResNet, efficientNet 4 | 5 | 6 | class TableDecoder(nn.Module): 7 | def __init__(self, channels, kernels, strides): 8 | super(TableDecoder, self).__init__() 9 | self.conv_7_table = nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = kernels[0], stride = strides[0]) 10 | self.upsample_1_table = nn.ConvTranspose2d(in_channels = 256, out_channels=128, kernel_size = kernels[1], stride = strides[1]) 11 | self.upsample_2_table = nn.ConvTranspose2d(in_channels = 128 + channels[0], out_channels = 256, kernel_size = kernels[2], stride = strides[2]) 12 | self.upsample_3_table = nn.ConvTranspose2d(in_channels = 256 + channels[1], out_channels = 1, kernel_size = kernels[3], stride = strides[3]) 13 | 14 | def forward(self, x, pool3_out, pool4_out): 15 | x = self.conv_7_table(x) 16 | out = self.upsample_1_table(x) 17 | out = torch.cat((out, pool4_out), dim=1) 18 | out = self.upsample_2_table(out) 19 | out = torch.cat((out, pool3_out), dim=1) 20 | out = self.upsample_3_table(out) 21 | return out 22 | 23 | class ColumnDecoder(nn.Module): 24 | def __init__(self, channels, kernels, strides): 25 | super(ColumnDecoder, self).__init__() 26 | self.conv_8_column = nn.Sequential( 27 | nn.Conv2d(in_channels = 256,out_channels = 256,kernel_size = kernels[0], stride = strides[0]), 28 | nn.ReLU(inplace=True), 29 | nn.Dropout(0.8), 30 | nn.Conv2d(in_channels = 256,out_channels = 256,kernel_size = kernels[0], stride = strides[0]) 31 | ) 32 | self.upsample_1_column = nn.ConvTranspose2d(in_channels = 256, out_channels=128, kernel_size = kernels[1], stride = strides[1]) 33 | self.upsample_2_column = nn.ConvTranspose2d(in_channels = 128 + channels[0], out_channels = 256, kernel_size = kernels[2], stride = strides[2]) 34 | self.upsample_3_column = nn.ConvTranspose2d( in_channels = 256 + channels[1], out_channels = 1, kernel_size = kernels[3], stride = strides[3]) 35 | 36 | def forward(self, x, pool3_out, pool4_out): 37 | x = self.conv_8_column(x) 38 | out = self.upsample_1_column(x) 39 | out = torch.cat((out, pool4_out), dim=1) 40 | out = self.upsample_2_column(out) 41 | out = torch.cat((out, pool3_out), dim=1) 42 | out = self.upsample_3_column(out) 43 | return out 44 | 45 | class TableNet(nn.Module): 46 | def __init__(self,encoder = 'vgg', use_pretrained_model = True, basemodel_requires_grad = True): 47 | super(TableNet, self).__init__() 48 | self.kernels = [(1,1), (2,2), (2,2),(8,8)] 49 | self.strides = [(1,1), (2,2), (2,2),(8,8)] 50 | self.in_channels = 512 51 | if encoder == 'vgg': 52 | self.base_model = VGG19(pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad) 53 | self.pool_channels = [512, 256] 54 | elif encoder == 'resnet': 55 | self.base_model = ResNet(pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad) 56 | self.pool_channels = [256, 128] 57 | elif encoder == 'densenet': 58 | self.base_model = DenseNet(pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad) 59 | self.pool_channels = [512, 256] 60 | self.in_channels = 1024 61 | self.kernels = [(1,1), (1,1), (2,2),(16,16)] 62 | self.strides = [(1,1), (1,1), (2,2),(16,16)] 63 | elif 'efficientnet' in encoder: 64 | self.base_model = efficientNet(model_type = encoder, pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad) 65 | if 'b0' in encoder: 66 | self.pool_channels = [192, 192] 67 | self.in_channels = 320 68 | elif 'b1' in encoder: 69 | self.pool_channels = [320, 192] 70 | self.in_channels = 320 
71 | elif 'b2' in encoder: 72 | self.pool_channels = [352, 208] 73 | self.in_channels = 352 74 | self.kernels = [(1,1), (1,1), (1,1),(32,32)] 75 | self.strides = [(1,1), (1,1), (1,1),(32,32)] 76 | self.conv6 = nn.Sequential( 77 | nn.Conv2d(in_channels = self.in_channels, out_channels = 256, kernel_size=(1,1)), 78 | nn.ReLU(inplace=True), 79 | nn.Dropout(0.8), 80 | nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size=(1,1)), 81 | nn.ReLU(inplace=True), 82 | nn.Dropout(0.8) 83 | ) 84 | self.table_decoder = TableDecoder(self.pool_channels, self.kernels, self.strides) 85 | self.column_decoder = ColumnDecoder(self.pool_channels, self.kernels, self.strides) 86 | 87 | def forward(self, x): 88 | pool3_out, pool4_out, pool5_out = self.base_model(x) 89 | conv_out = self.conv6(pool5_out) 90 | table_out = self.table_decoder(conv_out, pool3_out, pool4_out) 91 | column_out = self.column_decoder(conv_out, pool3_out, pool4_out) 92 | return table_out, column_out 93 | -------------------------------------------------------------------------------- /Model Implementation/model_testing.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import pytesseract 4 | import numpy as np 5 | import albumentations as A 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | from datetime import datetime 9 | from albumentations.pytorch import ToTensorV2 10 | from Training.configurations import MODEL_NAME 11 | from Training.tablenet_model import TableNet 12 | from Training.path_constants import PREDICTIONS, MODELS, TEST_IMAGES 13 | 14 | 15 | TRANSFORM = A.Compose([ 16 | A.Normalize( 17 | mean=[0.485, 0.456, 0.406], 18 | std=[0.229, 0.224, 0.225], 19 | max_pixel_value = 255, 20 | ), 21 | ToTensorV2() 22 | ]) 23 | 24 | def display_prediction(img, table = None, table_image = None, no_: bool = False): 25 | if no_: 26 | f1, ax = plt.subplots(1, 1, figsize = (7, 5)) 27 | ax.imshow(img) 28 | ax.set_title('Original Image') 29 | f1.suptitle('No Tables Detected') 30 | else: 31 | f2, ax = plt.subplots(1, 3, figsize = (15, 8)) 32 | ax[0].imshow(img) 33 | ax[0].set_title('Original Image') 34 | ax[1].imshow(table) 35 | ax[1].set_title('Image with Predicted Table') 36 | ax[2].imshow(table_image) 37 | ax[2].set_title('Predicted Table Example') 38 | plt.show() 39 | 40 | model_path = f"{MODELS}/{MODEL_NAME}" 41 | print(model_path) 42 | model = TableNet(encoder = 'densenet', use_pretrained_model = True, basemodel_requires_grad = True) 43 | model.eval() 44 | # Load checkpoint 45 | if torch.cuda.is_available(): 46 | model.load_state_dict(torch.load(model_path)['state_dict']) 47 | else: 48 | model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['state_dict']) 49 | 50 | # Final prediction function, using the mask fixing method too 51 | def predict(img_path): 52 | orig_image = Image.open(img_path).resize((1024, 1024)) 53 | test_img = np.array(orig_image.convert('LA').convert("RGB")) 54 | now = datetime.now() 55 | image = TRANSFORM(image = test_img)["image"] 56 | with torch.no_grad(): 57 | image = image.unsqueeze(0) 58 | table_out, _ = model(image) 59 | table_out = torch.sigmoid(table_out) 60 | # Remove gradients 61 | table_out = (table_out.detach().numpy().squeeze(0).transpose(1, 2, 0) > 0.5).astype(np.uint8) 62 | # Get contours of the mask to get number of tables 63 | contours, table_heirarchy = cv2.findContours(table_out, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 64 | table_contours = [] 65 | # Ref: 
https://www.pyimagesearch.com/2015/02/09/removing-contours-image-using-python-opencv/ 66 | # Remove bad contours 67 | for c in contours: 68 | if cv2.contourArea(c) > 3000: 69 | table_contours.append(c) 70 | if len(table_contours) == 0: 71 | print("No Table detected") 72 | table_boundRect = [None] * len(table_contours) 73 | for i, c in enumerate(table_contours): 74 | polygon = cv2.approxPolyDP(c, 3, True) 75 | table_boundRect[i] = cv2.boundingRect(polygon) 76 | # Table bounding Box 77 | table_boundRect.sort() 78 | orig_image = np.array(orig_image) 79 | # Draw bounding boxes 80 | color = (0, 0, 255) 81 | thickness = 4 82 | for x,y,w,h in table_boundRect: 83 | cv2.rectangle(orig_image, (x , y), (x + w, y + h), color, thickness) 84 | # Show Original image with the table bordered extra 85 | plt.figure(figsize = (10, 5)) 86 | plt.imshow(orig_image) 87 | end_time = datetime.now() 88 | difference = end_time - now 89 | time = "{}".format(difference) 90 | print(f"Time Taken on cpu: {time} secs") 91 | print("Predicted Tables") 92 | image = test_img[...,0].reshape(1024, 1024).astype(np.uint8) 93 | for i, (x, y, w, h) in enumerate(table_boundRect): 94 | image_crop = image[y : y + h, x : x + w] 95 | # Show only the table 96 | plt.figure(figsize = (7.5, 5)) 97 | plt.imshow(image_crop) 98 | cv2.imwrite(f"/{PREDICTIONS}/image_crop.png", image_crop) 99 | data = pytesseract.image_to_string(image_crop) 100 | 101 | _image_path = f'{TEST_IMAGES}/10.1.1.160.563_6.jpg' 102 | print(_image_path) 103 | df = predict(img_path = _image_path) 104 | df 105 | -------------------------------------------------------------------------------- /Model Implementation/tables/with tables.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with tables.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-110.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-110.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-111.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-120.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-159.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-159.png -------------------------------------------------------------------------------- /Model 
Implementation/tables/with/table-image-169.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-169.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-170.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-170.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-172.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-172.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-181.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-181.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-26.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-3.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-39.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-61.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-61.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-74.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-74.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-82.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-82.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-84.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-84.png -------------------------------------------------------------------------------- /Model Implementation/tables/with/table-image-99.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/with/table-image-99.png -------------------------------------------------------------------------------- /Model Implementation/tables/without tables.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/without tables.png -------------------------------------------------------------------------------- /Model Implementation/tables/without/table-image-102.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/without/table-image-102.png -------------------------------------------------------------------------------- /Model Implementation/tables/without/table-image-129.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/without/table-image-129.png -------------------------------------------------------------------------------- /Model Implementation/tables/without/table-image-50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/without/table-image-50.png -------------------------------------------------------------------------------- /Model Implementation/tables/without/table-image-78.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/without/table-image-78.png -------------------------------------------------------------------------------- /Model Implementation/tables/without/table-image-79.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LidorPrototype/TableNetTable2df/6ae4d2c686be2d760814bcb57a398fa2d7de434d/Model Implementation/tables/without/table-image-79.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TableNet using PyTorch 2 | 3 | In this repo, you can have an implementation of the TableNet with Pytorch 4 | 5 | ## Goal 6 | 7 | My goal is here to get a dataframe from an image, the image of 
a scanned document holding tabular data. I want to detect the tables in the image, crop them, and then extract the tabular data into a dataframe. 8 | 9 | ## Data: 10 | 11 | To populate the DummyDatabase folder you can refer to the following links: 12 | - [Marmot Dataset](https://www.icst.pku.edu.cn/cpdp/docs/20190424190300041510.zip) 13 | - [Marmot Extended dataset](https://drive.google.com/drive/folders/1QZiv5RKe3xlOBdTzuTVuYRxixemVIODp) 14 | 15 | ## Model: 16 | 17 | I will use a TableNet model with DenseNet121 as the main encoder. 18 | 19 | I tried different encoders such as VGG-19, ResNet, DenseNet121, efficientNet_B0, and efficientNet, and I got the best results with DenseNet121. 20 | 21 | > Note: the model itself is not uploaded because it's too big for GitHub uploads. 22 | 23 | ## Model Predictions: 24 | 25 | Predictions of the images in the folder DummyDatabase/test_images can be found in DummyDatabase/predictions. 26 | 27 | ## Improvement idea: 28 | 29 | The tables the model detects can be any of the following: 30 | 1) Tables with full gridlines 31 | 2) Tables with only horizontal/vertical gridlines 32 | 3) Tables with only parts of horizontal/vertical gridlines 33 | 4) Tables without any gridlines drawn 34 | 35 | So the idea is: no matter which of the above types the table is, remove all of the horizontal and vertical gridlines (if there are any), and then apply an OpenCV algorithm to detect the proper locations of all the gridlines and draw them artificially (the idea was implemented with help from StackOverflow). 36 | 37 | You can find this idea implemented in the folder called GridlinesImprovement. 38 | 39 | ## Extract Tabular Data using `pytesseract` 40 | 41 | Using the `pytesseract` library, extract and process the tabular data and convert it into a dataframe. 42 | 43 | _____________________________________________________________________________________________________________________________________ 44 | 45 | Author: Lidor ES 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytesseract 2 | opencv-python 3 | numpy 4 | pandas 5 | Pillow 6 | tqdm 7 | torch 8 | albumentations 9 | torchvision 10 | sklearn 11 | pytorch_model_summary 12 | matplotlib --------------------------------------------------------------------------------
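The GridlinesImprovement idea described in the README (strip whatever gridlines a table already has, then redraw a full grid) is implemented in `apply_gridlines_fix.py`, `remove_gridlines.py`, and `draw_gridlines_functions.py`, which are not reproduced above. As a rough illustration only, the sketch below shows one common OpenCV recipe for the removal half; the kernel lengths and the `remove_gridlines_sketch` name are assumptions rather than the repository's code, and the redraw step (e.g. locating row/column boundaries in the cleaned image) is left out.

```python
import cv2
import numpy as np

def remove_gridlines_sketch(gray: np.ndarray) -> np.ndarray:
    """Erase existing horizontal/vertical gridlines from an 8-bit grayscale table crop.

    The (40, 1) / (1, 40) kernel lengths are assumptions and would need
    tuning to the image resolution.
    """
    # Binarize with ink (text + lines) as white on black
    binary = cv2.adaptiveThreshold(~gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                   cv2.THRESH_BINARY, 15, -2)
    # Long, thin structuring elements keep only line-like structures
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
    vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
    horizontal = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)
    vertical = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
    lines = cv2.bitwise_or(horizontal, vertical)
    # Overwrite the detected line pixels with the background colour (white)
    cleaned = gray.copy()
    cleaned[lines > 0] = 255
    return cleaned
```

The morphological opening with a long horizontal (or vertical) kernel keeps only pixels that belong to long straight runs, i.e. gridlines, so overwriting those pixels removes the grid while leaving the text largely intact; a typical call would be `remove_gridlines_sketch(cv2.imread("table.png", cv2.IMREAD_GRAYSCALE))`.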
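The README's last step and the end of `model_testing.py` stop at `pytesseract.image_to_string`; the OCR-to-DataFrame conversion itself lives in `ExtractTabularData/image_to_df.py`, which is not shown here. Below is only a minimal sketch of one way such a conversion could look; the whitespace-splitting heuristic and the `table_crop_to_df` name are assumptions, not the repository's implementation.

```python
import re
import pandas as pd
import pytesseract
from PIL import Image

def table_crop_to_df(crop_path: str) -> pd.DataFrame:
    """Rough OCR-to-DataFrame conversion for a cropped table image.

    Assumes cells are separated by runs of 2+ spaces in the OCR output,
    which only holds for reasonably clean, whitespace-aligned tables.
    """
    text = pytesseract.image_to_string(Image.open(crop_path))
    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # Split on 2+ consecutive spaces so multi-word cells stay together
        rows.append(re.split(r"\s{2,}", line))
    if not rows:
        return pd.DataFrame()
    # Pad ragged rows to a common width and treat the first OCR line as the header
    width = max(len(r) for r in rows)
    rows = [r + [""] * (width - len(r)) for r in rows]
    header, body = rows[0], rows[1:]
    return pd.DataFrame(body, columns=header) if body else pd.DataFrame([header])
```

Such a helper could be applied to a cropped table image like the one `model_testing.py` saves under the predictions folder, yielding the dataframe the README's goal section describes.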