├── .gitignore ├── Aptfile ├── Fine_tuning_LayoutLMForSequenceClassification_on_RVL_CDIP.ipynb ├── LICENSE ├── Predictor.ipynb ├── Procfile ├── README.md ├── requirements.txt ├── saved_model └── config.json ├── setup.sh └── streamlit-app.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /Aptfile: -------------------------------------------------------------------------------- 1 | tesseract-ocr 2 | tesseract-ocr-eng -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Lucky Verma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Predictor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "capital-berkeley", 6 | "metadata": {}, 7 | "source": [ 8 | "# Legacy Import" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "functioning-maker", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import pandas as pd\n", 20 | "from PIL import Image, ImageDraw, ImageFont\n", 21 | "from transformers import LayoutLMForSequenceClassification, LayoutLMTokenizer\n", 22 | "import torch\n", 23 | "from torch.utils.data import Dataset, DataLoader\n", 24 | "import pytesseract\n", 25 | "from datasets import Features, Sequence, ClassLabel, Value, Array2D\n", 26 | "import numpy as np\n", 27 | "\n", 28 | "classes = [\"bill\", \"invoice\", \"others\", \"Purchase_Order\", \"remittance\"]" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "restricted-cedar", 34 | "metadata": {}, 35 | "source": [ 36 | "# Legacy Methods" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "german-modem", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "from datasets import Dataset\n", 47 | "\n", 48 | "def normalize_box(box, width, height):\n", 49 | " return [\n", 50 | " int(1000 * (box[0] / width)),\n", 51 | " int(1000 * (box[1] / height)),\n", 52 | " int(1000 * (box[2] / width)),\n", 53 | " int(1000 * (box[3] / height)),\n", 54 | " ]\n", 55 | "\n", 56 | "def apply_ocr(example):\n", 57 | " # get the image\n", 58 | " image = Image.open(example['image_path'])\n", 59 | "\n", 60 | " width, height = image.size\n", 61 | " \n", 62 | " # apply ocr to the image \n", 63 | " ocr_df = pytesseract.image_to_data(image, output_type='data.frame')\n", 64 | " float_cols = ocr_df.select_dtypes('float').columns\n", 65 | " ocr_df = ocr_df.dropna().reset_index(drop=True)\n", 66 | " ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)\n", 67 | " ocr_df = ocr_df.replace(r'^\\s*$', np.nan, regex=True)\n", 68 | " ocr_df = ocr_df.dropna().reset_index(drop=True)\n", 69 | "\n", 70 | " # get the words and actual (unnormalized) bounding boxes\n", 71 | " #words = [word for word in ocr_df.text if str(word) != 'nan'])\n", 72 | " words = list(ocr_df.text)\n", 73 | " words = [str(w) for w in words]\n", 74 | " coordinates = ocr_df[['left', 'top', 'width', 'height']]\n", 75 | " actual_boxes = []\n", 76 | " for idx, row in coordinates.iterrows():\n", 77 | " x, y, w, h = tuple(row) # the row comes in (left, top, width, height) format\n", 78 | " actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box \n", 79 | " actual_boxes.append(actual_box)\n", 80 | " \n", 81 | " # normalize the bounding boxes\n", 82 | " boxes = []\n", 83 | " for box in actual_boxes:\n", 84 | " boxes.append(normalize_box(box, width, height))\n", 85 | " \n", 86 | " # add as extra columns \n", 87 | " assert len(words) == len(boxes)\n", 88 | " example['words'] = words\n", 89 | " example['bbox'] = boxes\n", 90 | " return example\n" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "id": "mathematical-archives", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "tokenizer = LayoutLMTokenizer.from_pretrained(\"microsoft/layoutlm-base-uncased\")\n", 101 | "\n", 102 | "def encode_example(example, max_seq_length=512, 
pad_token_box=[0, 0, 0, 0]):\n", 103 | " words = example['words']\n", 104 | " normalized_word_boxes = example['bbox']\n", 105 | "\n", 106 | " assert len(words) == len(normalized_word_boxes)\n", 107 | "\n", 108 | " token_boxes = []\n", 109 | " for word, box in zip(words, normalized_word_boxes):\n", 110 | " word_tokens = tokenizer.tokenize(word)\n", 111 | " token_boxes.extend([box] * len(word_tokens))\n", 112 | " \n", 113 | " # Truncation of token_boxes\n", 114 | " special_tokens_count = 2 \n", 115 | " if len(token_boxes) > max_seq_length - special_tokens_count:\n", 116 | " token_boxes = token_boxes[: (max_seq_length - special_tokens_count)]\n", 117 | " \n", 118 | " # add bounding boxes of cls + sep tokens\n", 119 | " token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]\n", 120 | " \n", 121 | " encoding = tokenizer(' '.join(words), padding='max_length', truncation=True)\n", 122 | " # Padding of token_boxes up the bounding boxes to the sequence length.\n", 123 | " input_ids = tokenizer(' '.join(words), truncation=True)[\"input_ids\"]\n", 124 | " padding_length = max_seq_length - len(input_ids)\n", 125 | " token_boxes += [pad_token_box] * padding_length\n", 126 | " encoding['bbox'] = token_boxes\n", 127 | "\n", 128 | " assert len(encoding['input_ids']) == max_seq_length\n", 129 | " assert len(encoding['attention_mask']) == max_seq_length\n", 130 | " assert len(encoding['token_type_ids']) == max_seq_length\n", 131 | " assert len(encoding['bbox']) == max_seq_length\n", 132 | "\n", 133 | " return encoding" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 4, 139 | "id": "afraid-township", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# we need to define the features ourselves as the bbox of LayoutLM are an extra feature\n", 144 | "features = Features({\n", 145 | " 'input_ids': Sequence(feature=Value(dtype='int64')),\n", 146 | " 'bbox': Array2D(dtype=\"int64\", shape=(512, 4)),\n", 147 | " 'attention_mask': Sequence(Value(dtype='int64')),\n", 148 | " 'token_type_ids': Sequence(Value(dtype='int64')),\n", 149 | " 'image_path': Value(dtype='string'),\n", 150 | " 'words': Sequence(feature=Value(dtype='string')),\n", 151 | "})\n" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "analyzed-legend", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 5, 165 | "id": "intense-recall", 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "LayoutLMForSequenceClassification(\n", 172 | " (layoutlm): LayoutLMModel(\n", 173 | " (embeddings): LayoutLMEmbeddings(\n", 174 | " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", 175 | " (position_embeddings): Embedding(512, 768)\n", 176 | " (x_position_embeddings): Embedding(1024, 768)\n", 177 | " (y_position_embeddings): Embedding(1024, 768)\n", 178 | " (h_position_embeddings): Embedding(1024, 768)\n", 179 | " (w_position_embeddings): Embedding(1024, 768)\n", 180 | " (token_type_embeddings): Embedding(2, 768)\n", 181 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 182 | " (dropout): Dropout(p=0.1, inplace=False)\n", 183 | " )\n", 184 | " (encoder): LayoutLMEncoder(\n", 185 | " (layer): ModuleList(\n", 186 | " (0): LayoutLMLayer(\n", 187 | " (attention): LayoutLMAttention(\n", 188 | " (self): LayoutLMSelfAttention(\n", 189 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 190 | " 
(key): Linear(in_features=768, out_features=768, bias=True)\n", 191 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 192 | " (dropout): Dropout(p=0.1, inplace=False)\n", 193 | " )\n", 194 | " (output): LayoutLMSelfOutput(\n", 195 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 196 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 197 | " (dropout): Dropout(p=0.1, inplace=False)\n", 198 | " )\n", 199 | " )\n", 200 | " (intermediate): LayoutLMIntermediate(\n", 201 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 202 | " )\n", 203 | " (output): LayoutLMOutput(\n", 204 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 205 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 206 | " (dropout): Dropout(p=0.1, inplace=False)\n", 207 | " )\n", 208 | " )\n", 209 | " (1): LayoutLMLayer(\n", 210 | " (attention): LayoutLMAttention(\n", 211 | " (self): LayoutLMSelfAttention(\n", 212 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 213 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 214 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 215 | " (dropout): Dropout(p=0.1, inplace=False)\n", 216 | " )\n", 217 | " (output): LayoutLMSelfOutput(\n", 218 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 219 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 220 | " (dropout): Dropout(p=0.1, inplace=False)\n", 221 | " )\n", 222 | " )\n", 223 | " (intermediate): LayoutLMIntermediate(\n", 224 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 225 | " )\n", 226 | " (output): LayoutLMOutput(\n", 227 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 228 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 229 | " (dropout): Dropout(p=0.1, inplace=False)\n", 230 | " )\n", 231 | " )\n", 232 | " (2): LayoutLMLayer(\n", 233 | " (attention): LayoutLMAttention(\n", 234 | " (self): LayoutLMSelfAttention(\n", 235 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 236 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 237 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 238 | " (dropout): Dropout(p=0.1, inplace=False)\n", 239 | " )\n", 240 | " (output): LayoutLMSelfOutput(\n", 241 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 242 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 243 | " (dropout): Dropout(p=0.1, inplace=False)\n", 244 | " )\n", 245 | " )\n", 246 | " (intermediate): LayoutLMIntermediate(\n", 247 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 248 | " )\n", 249 | " (output): LayoutLMOutput(\n", 250 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 251 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 252 | " (dropout): Dropout(p=0.1, inplace=False)\n", 253 | " )\n", 254 | " )\n", 255 | " (3): LayoutLMLayer(\n", 256 | " (attention): LayoutLMAttention(\n", 257 | " (self): LayoutLMSelfAttention(\n", 258 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 259 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 260 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 261 | " (dropout): Dropout(p=0.1, inplace=False)\n", 262 | " )\n", 263 | " (output): LayoutLMSelfOutput(\n", 264 | " (dense): 
Linear(in_features=768, out_features=768, bias=True)\n", 265 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 266 | " (dropout): Dropout(p=0.1, inplace=False)\n", 267 | " )\n", 268 | " )\n", 269 | " (intermediate): LayoutLMIntermediate(\n", 270 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 271 | " )\n", 272 | " (output): LayoutLMOutput(\n", 273 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 274 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 275 | " (dropout): Dropout(p=0.1, inplace=False)\n", 276 | " )\n", 277 | " )\n", 278 | " (4): LayoutLMLayer(\n", 279 | " (attention): LayoutLMAttention(\n", 280 | " (self): LayoutLMSelfAttention(\n", 281 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 282 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 283 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 284 | " (dropout): Dropout(p=0.1, inplace=False)\n", 285 | " )\n", 286 | " (output): LayoutLMSelfOutput(\n", 287 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 288 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 289 | " (dropout): Dropout(p=0.1, inplace=False)\n", 290 | " )\n", 291 | " )\n", 292 | " (intermediate): LayoutLMIntermediate(\n", 293 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 294 | " )\n", 295 | " (output): LayoutLMOutput(\n", 296 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 297 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 298 | " (dropout): Dropout(p=0.1, inplace=False)\n", 299 | " )\n", 300 | " )\n", 301 | " (5): LayoutLMLayer(\n", 302 | " (attention): LayoutLMAttention(\n", 303 | " (self): LayoutLMSelfAttention(\n", 304 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 305 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 306 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 307 | " (dropout): Dropout(p=0.1, inplace=False)\n", 308 | " )\n", 309 | " (output): LayoutLMSelfOutput(\n", 310 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 311 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 312 | " (dropout): Dropout(p=0.1, inplace=False)\n", 313 | " )\n", 314 | " )\n", 315 | " (intermediate): LayoutLMIntermediate(\n", 316 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 317 | " )\n", 318 | " (output): LayoutLMOutput(\n", 319 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 320 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 321 | " (dropout): Dropout(p=0.1, inplace=False)\n", 322 | " )\n", 323 | " )\n", 324 | " (6): LayoutLMLayer(\n", 325 | " (attention): LayoutLMAttention(\n", 326 | " (self): LayoutLMSelfAttention(\n", 327 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 328 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 329 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 330 | " (dropout): Dropout(p=0.1, inplace=False)\n", 331 | " )\n", 332 | " (output): LayoutLMSelfOutput(\n", 333 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 334 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 335 | " (dropout): Dropout(p=0.1, inplace=False)\n", 336 | " )\n", 337 | " )\n", 338 | " (intermediate): LayoutLMIntermediate(\n", 339 | 
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 340 | " )\n", 341 | " (output): LayoutLMOutput(\n", 342 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 343 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 344 | " (dropout): Dropout(p=0.1, inplace=False)\n", 345 | " )\n", 346 | " )\n", 347 | " (7): LayoutLMLayer(\n", 348 | " (attention): LayoutLMAttention(\n", 349 | " (self): LayoutLMSelfAttention(\n", 350 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 351 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 352 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 353 | " (dropout): Dropout(p=0.1, inplace=False)\n", 354 | " )\n", 355 | " (output): LayoutLMSelfOutput(\n", 356 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 357 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 358 | " (dropout): Dropout(p=0.1, inplace=False)\n", 359 | " )\n", 360 | " )\n", 361 | " (intermediate): LayoutLMIntermediate(\n", 362 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 363 | " )\n", 364 | " (output): LayoutLMOutput(\n", 365 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 366 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 367 | " (dropout): Dropout(p=0.1, inplace=False)\n", 368 | " )\n", 369 | " )\n", 370 | " (8): LayoutLMLayer(\n", 371 | " (attention): LayoutLMAttention(\n", 372 | " (self): LayoutLMSelfAttention(\n", 373 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 374 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 375 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 376 | " (dropout): Dropout(p=0.1, inplace=False)\n", 377 | " )\n", 378 | " (output): LayoutLMSelfOutput(\n", 379 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 380 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 381 | " (dropout): Dropout(p=0.1, inplace=False)\n", 382 | " )\n", 383 | " )\n", 384 | " (intermediate): LayoutLMIntermediate(\n", 385 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 386 | " )\n", 387 | " (output): LayoutLMOutput(\n", 388 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 389 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 390 | " (dropout): Dropout(p=0.1, inplace=False)\n", 391 | " )\n", 392 | " )\n", 393 | " (9): LayoutLMLayer(\n", 394 | " (attention): LayoutLMAttention(\n", 395 | " (self): LayoutLMSelfAttention(\n", 396 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 397 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 398 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 399 | " (dropout): Dropout(p=0.1, inplace=False)\n", 400 | " )\n", 401 | " (output): LayoutLMSelfOutput(\n", 402 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 403 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 404 | " (dropout): Dropout(p=0.1, inplace=False)\n", 405 | " )\n", 406 | " )\n", 407 | " (intermediate): LayoutLMIntermediate(\n", 408 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 409 | " )\n", 410 | " (output): LayoutLMOutput(\n", 411 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 412 | " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", 413 | " (dropout): Dropout(p=0.1, inplace=False)\n", 414 | " )\n", 415 | " )\n", 416 | " (10): LayoutLMLayer(\n", 417 | " (attention): LayoutLMAttention(\n", 418 | " (self): LayoutLMSelfAttention(\n", 419 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 420 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 421 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 422 | " (dropout): Dropout(p=0.1, inplace=False)\n", 423 | " )\n", 424 | " (output): LayoutLMSelfOutput(\n", 425 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 426 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 427 | " (dropout): Dropout(p=0.1, inplace=False)\n", 428 | " )\n", 429 | " )\n", 430 | " (intermediate): LayoutLMIntermediate(\n", 431 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 432 | " )\n", 433 | " (output): LayoutLMOutput(\n", 434 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 435 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 436 | " (dropout): Dropout(p=0.1, inplace=False)\n", 437 | " )\n", 438 | " )\n", 439 | " (11): LayoutLMLayer(\n", 440 | " (attention): LayoutLMAttention(\n", 441 | " (self): LayoutLMSelfAttention(\n", 442 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 443 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 444 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 445 | " (dropout): Dropout(p=0.1, inplace=False)\n", 446 | " )\n", 447 | " (output): LayoutLMSelfOutput(\n", 448 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 449 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 450 | " (dropout): Dropout(p=0.1, inplace=False)\n", 451 | " )\n", 452 | " )\n", 453 | " (intermediate): LayoutLMIntermediate(\n", 454 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 455 | " )\n", 456 | " (output): LayoutLMOutput(\n", 457 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 458 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 459 | " (dropout): Dropout(p=0.1, inplace=False)\n", 460 | " )\n", 461 | " )\n", 462 | " )\n", 463 | " )\n", 464 | " (pooler): LayoutLMPooler(\n", 465 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 466 | " (activation): Tanh()\n", 467 | " )\n", 468 | " )\n", 469 | " (dropout): Dropout(p=0.1, inplace=False)\n", 470 | " (classifier): Linear(in_features=768, out_features=5, bias=True)\n", 471 | ")" 472 | ] 473 | }, 474 | "execution_count": 5, 475 | "metadata": {}, 476 | "output_type": "execute_result" 477 | } 478 | ], 479 | "source": [ 480 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 481 | "model = LayoutLMForSequenceClassification.from_pretrained(\"saved_model/run2\")\n", 482 | "model.to(device)" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "id": "answering-credits", 488 | "metadata": {}, 489 | "source": [ 490 | "# Data Processing Flow" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 11, 496 | "id": "involved-cycle", 497 | "metadata": {}, 498 | "outputs": [ 499 | { 500 | "name": "stdout", 501 | "output_type": "stream", 502 | "text": [ 503 | "test_data [] ['audacious.jpg', 'Developer-564x804.png']\n", 504 | "['audacious.jpg', 'Developer-564x804.png']\n" 505 | ] 506 | }, 507 | { 508 | "data": { 509 | "text/html": [ 
510 | "
\n", 511 | "\n", 524 | "\n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | "
image_path
0test_data/audacious.jpg
\n", 538 | "
" 539 | ], 540 | "text/plain": [ 541 | " image_path\n", 542 | "0 test_data/audacious.jpg" 543 | ] 544 | }, 545 | "execution_count": 11, 546 | "metadata": {}, 547 | "output_type": "execute_result" 548 | } 549 | ], 550 | "source": [ 551 | "images = []\n", 552 | "labels = []\n", 553 | "dataset_path = 'test_data'\n", 554 | "\n", 555 | "for label_folder, _, file_names in os.walk(dataset_path):\n", 556 | " print(label_folder, _, file_names)\n", 557 | " print(file_names)\n", 558 | " relative_image_names = []\n", 559 | " relative_image_names.append(dataset_path + \"/\" + file_names[0])\n", 560 | " images.extend(relative_image_names)\n", 561 | "test_data = pd.DataFrame.from_dict({'image_path': images})\n", 562 | "test_data.head()" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": 12, 568 | "id": "auburn-letter", 569 | "metadata": {}, 570 | "outputs": [ 571 | { 572 | "data": { 573 | "application/vnd.jupyter.widget-view+json": { 574 | "model_id": "66aba325525c4a6b86718dc6a7dd6fbb", 575 | "version_major": 2, 576 | "version_minor": 0 577 | }, 578 | "text/plain": [ 579 | " 0%| | 0/1 [00:00), hidden_states=None, attentions=None)\n" 663 | ] 664 | } 665 | ], 666 | "source": [ 667 | "input_ids = test_batch[\"input_ids\"].to(device)\n", 668 | "bbox = test_batch[\"bbox\"].to(device)\n", 669 | "attention_mask = test_batch[\"attention_mask\"].to(device)\n", 670 | "token_type_ids = test_batch[\"token_type_ids\"].to(device)\n", 671 | "\n", 672 | "# forward pass\n", 673 | "outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, \n", 674 | " token_type_ids=token_type_ids)\n", 675 | "\n", 676 | "# prediction = int(torch.max(outputs.data, 1)[1].numpy())\n", 677 | "print(outputs)" 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 17, 683 | "id": "documentary-hartford", 684 | "metadata": {}, 685 | "outputs": [ 686 | { 687 | "name": "stdout", 688 | "output_type": "stream", 689 | "text": [ 690 | "bill: 0%\n", 691 | "invoice: 0%\n", 692 | "others: 100%\n", 693 | "Purchase_Order: 0%\n", 694 | "remittance: 0%\n" 695 | ] 696 | } 697 | ], 698 | "source": [ 699 | "# import torch.nn.functional as F\n", 700 | "# pt_predictions = F.softmax(outputs[0], dim=-1)\n", 701 | "# pt_predictions\n", 702 | "\n", 703 | "classification_logits = outputs.logits\n", 704 | "classification_results = torch.softmax(classification_logits, dim=1).tolist()[0]\n", 705 | "for i in range(len(classes)):\n", 706 | " print(f\"{classes[i]}: {int(round(classification_results[i] * 100))}%\")" 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": 24, 712 | "id": "unsigned-luther", 713 | "metadata": {}, 714 | "outputs": [ 715 | { 716 | "name": "stdout", 717 | "output_type": "stream", 718 | "text": [ 719 | "{'bill': '0%', 'invoice': '0%', 'others': '100%', 'Purchase_Order': '0%', 'remittance': '0%'}\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "thisdict ={}\n", 725 | "for i in range(len(classes)):\n", 726 | " thisdict[classes[i]] = str(int(round(classification_results[i] * 100))) + \"%\"\n", 727 | "print(thisdict)\n" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 18, 733 | "id": "color-bones", 734 | "metadata": {}, 735 | "outputs": [ 736 | { 737 | "data": { 738 | "text/plain": [ 739 | "tensor([[5.5799e-06, 4.3818e-06, 9.9998e-01, 5.3273e-06, 3.8329e-06]],\n", 740 | " device='cuda:0', grad_fn=)" 741 | ] 742 | }, 743 | "execution_count": 18, 744 | "metadata": {}, 745 | "output_type": "execute_result" 746 | } 747 | ], 748 | 
"source": [ 749 | "import torch.nn.functional as F\n", 750 | "pt_predictions = F.softmax(outputs[0], dim=-1)\n", 751 | "pt_predictions" 752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": 19, 757 | "id": "helpful-seventh", 758 | "metadata": {}, 759 | "outputs": [ 760 | { 761 | "data": { 762 | "text/plain": [ 763 | "2" 764 | ] 765 | }, 766 | "execution_count": 19, 767 | "metadata": {}, 768 | "output_type": "execute_result" 769 | } 770 | ], 771 | "source": [ 772 | "predictions = outputs.logits.argmax(-1).squeeze().tolist()\n", 773 | "predictions" 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": null, 779 | "id": "extra-workstation", 780 | "metadata": {}, 781 | "outputs": [], 782 | "source": [ 783 | "# NATIVE T5\n", 784 | "\n", 785 | "# generated_answer = model.generate(input_ids, attention_mask=attention_mask, \n", 786 | "# max_length=decoder_max_len, top_p=0.98, top_k=50)\n", 787 | "# decoded_answer = tokenizer.decode(generated_answer.numpy()[0])\n", 788 | "# print(\"Answer: \", decoded_answer)" 789 | ] 790 | }, 791 | { 792 | "cell_type": "code", 793 | "execution_count": null, 794 | "id": "brilliant-uncertainty", 795 | "metadata": {}, 796 | "outputs": [], 797 | "source": [] 798 | } 799 | ], 800 | "metadata": { 801 | "kernelspec": { 802 | "display_name": "Python 3", 803 | "language": "python", 804 | "name": "python3" 805 | }, 806 | "language_info": { 807 | "codemirror_mode": { 808 | "name": "ipython", 809 | "version": 3 810 | }, 811 | "file_extension": ".py", 812 | "mimetype": "text/x-python", 813 | "name": "python", 814 | "nbconvert_exporter": "python", 815 | "pygments_lexer": "ipython3", 816 | "version": "3.7.10" 817 | } 818 | }, 819 | "nbformat": 4, 820 | "nbformat_minor": 5 821 | } 822 | -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: sh setup.sh && streamlit run app.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Document-Classification-using-LayoutLM 2 | This PyTorch implementation of LayoutLM paper by Microsoft demonstrate the SequenceClassfication task using HuggingFaceTransformers to classify types of Documents. 
3 | 4 | 5 | ## Star History 6 | 7 | [![Star History Chart](https://api.star-history.com/svg?repos=lucky-verma/Document-Classification-using-LayoutLM&type=Date)](https://star-history.com/#lucky-verma/Document-Classification-using-LayoutLM&Date) 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp==2.2.0 2 | altair==4.1.0 3 | anyio==2.2.0 4 | argon2-cffi==20.1.0 5 | astor==0.8.1 6 | async-generator==1.10 7 | atomicwrites==1.4.0 8 | attrs==20.3.0 9 | Babel==2.9.0 10 | backcall==0.2.0 11 | base58==2.1.0 12 | bleach==3.3.0 13 | blinker==1.4 14 | blis==0.4.1 15 | boto3==1.17.14 16 | botocore==1.20.14 17 | cached-property==1.5.2 18 | cachetools==4.2.1 19 | catalogue==1.0.0 20 | certifi==2020.12.5 21 | cffi==1.14.5 22 | chardet==4.0.0 23 | click==7.1.2 24 | colorama==0.4.4 25 | configparser==5.0.2 26 | cycler==0.10.0 27 | cymem==2.0.5 28 | datasets==1.5.0 29 | decorator==4.4.2 30 | defusedxml==0.7.1 31 | dill==0.3.3 32 | docker-pycreds==0.4.0 33 | entrypoints==0.3 34 | filelock==3.0.12 35 | fsspec==0.8.7 36 | gitdb==4.0.7 37 | GitPython==3.1.14 38 | h5py==3.1.0 39 | huggingface-hub==0.0.7 40 | idna==2.10 41 | importlib-metadata==3.10.0 42 | iniconfig==1.1.1 43 | ipykernel==5.5.3 44 | ipython==7.22.0 45 | ipython-genutils==0.2.0 46 | ipywidgets==7.6.3 47 | jedi==0.18.0 48 | Jinja2==2.11.3 49 | jmespath==0.10.0 50 | joblib==1.0.1 51 | json5==0.9.5 52 | jsonpickle==2.0.0 53 | jsonschema==3.2.0 54 | jupyter-client==6.1.12 55 | jupyter-core==4.7.1 56 | jupyter-packaging==0.7.12 57 | jupyter-server==1.5.1 58 | jupyterlab==3.0.12 59 | jupyterlab-pygments==0.1.2 60 | jupyterlab-server==2.4.0 61 | jupyterlab-widgets==1.0.0 62 | kiwisolver==1.3.1 63 | lmdb==1.1.1 64 | MarkupSafe==1.1.1 65 | matplotlib==3.4.1 66 | mistune==0.8.4 67 | more-itertools==8.7.0 68 | multiprocess==0.70.11.1 69 | murmurhash==1.0.5 70 | nbclassic==0.2.6 71 | nbclient==0.5.3 72 | nbconvert==6.0.7 73 | nbformat==5.1.2 74 | nest-asyncio==1.5.1 75 | nltk==3.5 76 | notebook==6.3.0 77 | numpy==1.20.2 78 | opencv-python==4.3.0.36 79 | overrides==3.1.0 80 | packaging==20.9 81 | pandas==1.0.5 82 | pandocfilters==1.4.3 83 | parso==0.8.1 84 | pathtools==0.1.2 85 | pickleshare==0.7.5 86 | Pillow==8.1.2 87 | plac==1.1.3 88 | plotly==4.14.3 89 | pluggy==0.13.1 90 | preshed==3.0.5 91 | prometheus-client==0.9.0 92 | promise==2.3 93 | prompt-toolkit==3.0.18 94 | protobuf==3.15.1 95 | psutil==5.8.0 96 | py==1.10.0 97 | pyarrow==3.0.0 98 | pycparser==2.20 99 | pydeck==0.6.1 100 | Pygments==2.8.1 101 | pyparsing==2.4.7 102 | pyrsistent==0.17.3 103 | pytesseract==0.3.7 104 | pytest==6.2.2 105 | python-dateutil==2.8.1 106 | pytz==2021.1 107 | pywin32==300 108 | pywinpty==0.5.7 109 | PyYAML==5.4.1 110 | pyzmq==22.0.3 111 | regex==2020.11.13 112 | requests==2.25.1 113 | retrying==1.3.3 114 | s3transfer==0.3.4 115 | sacremoses==0.0.43 116 | scikit-learn==0.24.1 117 | scipy==1.5.4 118 | Send2Trash==1.5.0 119 | sentencepiece==0.1.95 120 | sentry-sdk==1.0.0 121 | shortuuid==1.0.1 122 | six==1.15.0 123 | smmap==4.0.0 124 | sniffio==1.2.0 125 | spacy==2.2.4 126 | srsly==1.0.5 127 | streamlit==0.79.0 128 | subprocess32==3.5.4 129 | tabulate==0.8.7 130 | tensorboardX==2.1 131 | terminado==0.9.4 132 | testpath==0.4.4 133 | tez==0.1.2 134 | thinc==7.4.0 135 | threadpoolctl==2.1.0 136 | tokenizers==0.10.1 137 | toml==0.10.2 138 | toolz==0.11.1 139 | torch==1.8.1+cu102 140 | torchaudio==0.8.1 141 | torchtext==0.6.0 142 | 
torchvision==0.9.1+cu102 143 | tornado==6.1 144 | tqdm==4.59.0 145 | traitlets==5.0.5 146 | transformers==4.4.2 147 | typing-extensions==3.7.4.3 148 | tzlocal==2.1 149 | urllib3==1.26.4 150 | validators==0.18.2 151 | wandb==0.10.23 152 | wasabi==0.8.2 153 | watchdog==2.0.2 154 | wcwidth==0.2.5 155 | webencodings==0.5.1 156 | widgetsnbextension==3.5.1 157 | wincertstore==0.2 158 | xxhash==2.0.0 159 | zipp==3.4.1 160 | -------------------------------------------------------------------------------- /saved_model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "microsoft/layoutlm-base-uncased", 3 | "architectures": [ 4 | "LayoutLMForSequenceClassification" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "gradient_checkpointing": false, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "id2label": { 12 | "0": "LABEL_0", 13 | "1": "LABEL_1", 14 | "2": "LABEL_2", 15 | "3": "LABEL_3", 16 | "4": "LABEL_4" 17 | }, 18 | "initializer_range": 0.02, 19 | "intermediate_size": 3072, 20 | "label2id": { 21 | "LABEL_0": 0, 22 | "LABEL_1": 1, 23 | "LABEL_2": 2, 24 | "LABEL_3": 3, 25 | "LABEL_4": 4 26 | }, 27 | "layer_norm_eps": 1e-12, 28 | "max_2d_position_embeddings": 1024, 29 | "max_position_embeddings": 512, 30 | "model_type": "layoutlm", 31 | "num_attention_heads": 12, 32 | "num_hidden_layers": 12, 33 | "output_past": true, 34 | "pad_token_id": 0, 35 | "position_embedding_type": "absolute", 36 | "transformers_version": "4.4.2", 37 | "type_vocab_size": 2, 38 | "use_cache": true, 39 | "vocab_size": 30522 40 | } 41 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ~/.streamlit/ 2 | 3 | echo "\ 4 | [server]\n\ 5 | headless = true\n\ 6 | port = $PORT\n\ 7 | enableCORS = false\n\ 8 | \n\ 9 | " > ~/.streamlit/config.toml -------------------------------------------------------------------------------- /streamlit-app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from PIL import Image, ImageDraw, ImageFont 4 | from transformers import LayoutLMForSequenceClassification, LayoutLMTokenizer 5 | import torch 6 | import requests 7 | from torch.utils.data import Dataset, DataLoader 8 | import pytesseract 9 | from datasets import Features, Sequence, ClassLabel, Value, Array2D 10 | import numpy as np 11 | import streamlit as st 12 | from datasets import Dataset 13 | import plotly.figure_factory as ff 14 | import plotly.express as px 15 | from plotly.subplots import make_subplots 16 | import plotly.graph_objects as go 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | # Legacy method imports 21 | 22 | def normalize_box(box, width, height): 23 | return [ 24 | int(1000 * (box[0] / width)), 25 | int(1000 * (box[1] / height)), 26 | int(1000 * (box[2] / width)), 27 | int(1000 * (box[3] / height)), 28 | ] 29 | 30 | def apply_ocr(example): 31 | # get the image 32 | image = Image.open(example['image_path']) 33 | 34 | width, height = image.size 35 | 36 | # apply ocr to the image 37 | ocr_df = pytesseract.image_to_data(image, output_type='data.frame') 38 | float_cols = ocr_df.select_dtypes('float').columns 39 | ocr_df = ocr_df.dropna().reset_index(drop=True) 40 | ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int) 41 | ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True) 42 | ocr_df = 
ocr_df.dropna().reset_index(drop=True) 43 | 44 | # get the words and actual (unnormalized) bounding boxes 45 | #words = [word for word in ocr_df.text if str(word) != 'nan']) 46 | words = list(ocr_df.text) 47 | words = [str(w) for w in words] 48 | coordinates = ocr_df[['left', 'top', 'width', 'height']] 49 | actual_boxes = [] 50 | for idx, row in coordinates.iterrows(): 51 | x, y, w, h = tuple(row) # the row comes in (left, top, width, height) format 52 | actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box 53 | actual_boxes.append(actual_box) 54 | 55 | # normalize the bounding boxes 56 | boxes = [] 57 | for box in actual_boxes: 58 | boxes.append(normalize_box(box, width, height)) 59 | 60 | # add as extra columns 61 | assert len(words) == len(boxes) 62 | example['words'] = words 63 | example['bbox'] = boxes 64 | return example 65 | 66 | tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased") 67 | 68 | def encode_example(example, max_seq_length=512, pad_token_box=[0, 0, 0, 0]): 69 | words = example['words'] 70 | normalized_word_boxes = example['bbox'] 71 | 72 | assert len(words) == len(normalized_word_boxes) 73 | 74 | token_boxes = [] 75 | for word, box in zip(words, normalized_word_boxes): 76 | word_tokens = tokenizer.tokenize(word) 77 | token_boxes.extend([box] * len(word_tokens)) 78 | 79 | # Truncation of token_boxes 80 | special_tokens_count = 2 81 | if len(token_boxes) > max_seq_length - special_tokens_count: 82 | token_boxes = token_boxes[: (max_seq_length - special_tokens_count)] 83 | 84 | # add bounding boxes of cls + sep tokens 85 | token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]] 86 | 87 | encoding = tokenizer(' '.join(words), padding='max_length', truncation=True) 88 | # Padding of token_boxes up the bounding boxes to the sequence length. 
89 | input_ids = tokenizer(' '.join(words), truncation=True)["input_ids"] 90 | padding_length = max_seq_length - len(input_ids) 91 | token_boxes += [pad_token_box] * padding_length 92 | encoding['bbox'] = token_boxes 93 | 94 | assert len(encoding['input_ids']) == max_seq_length 95 | assert len(encoding['attention_mask']) == max_seq_length 96 | assert len(encoding['token_type_ids']) == max_seq_length 97 | assert len(encoding['bbox']) == max_seq_length 98 | 99 | return encoding 100 | 101 | # we need to define the features ourselves as the bbox of LayoutLM are an extra feature 102 | features = Features({ 103 | 'input_ids': Sequence(feature=Value(dtype='int64')), 104 | 'bbox': Array2D(dtype="int64", shape=(512, 4)), 105 | 'attention_mask': Sequence(Value(dtype='int64')), 106 | 'token_type_ids': Sequence(Value(dtype='int64')), 107 | 'image_path': Value(dtype='string'), 108 | 'words': Sequence(feature=Value(dtype='string')), 109 | }) 110 | 111 | classes = ["bill", "invoice", "others", "Purchase_Order", "remittance"] 112 | 113 | 114 | # Model Loading 115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 116 | @st.cache(allow_output_mutation=True) 117 | def load_model(): 118 | url = "https://vast-ml-models.s3-ap-southeast-2.amazonaws.com/Document-Classification-5-labels-final.bin" 119 | r = requests.get(url, allow_redirects=True) 120 | open('saved_model/pytorch_model.bin', 'wb').write(r.content) 121 | model = LayoutLMForSequenceClassification.from_pretrained("saved_model") 122 | return model 123 | 124 | load_model().to(device) 125 | 126 | # Data processing 127 | 128 | st.title('VAST: Document Classifier') 129 | st.header('Upload any document image') 130 | hide_streamlit_style = """ 131 | 135 | """ 136 | st.markdown(hide_streamlit_style, unsafe_allow_html=True) 137 | 138 | 139 | image = st.file_uploader('Upload here', type=['jpg', 'png', 'jpeg', 'webp']) 140 | 141 | if image is None: 142 | st.write("### Please upload your Invoice IMAGE") 143 | else: 144 | im = Image.open(image) 145 | rgb_im = im.convert('RGB') 146 | rgb_im.save('test_data/audacious.jpg') 147 | os.getcwd() 148 | test_data = pd.DataFrame.from_dict({'image_path': ['test_data/audacious.jpg']}) 149 | st.image(image, caption='your_doc', use_column_width=True) 150 | if st.button("Process"): 151 | st.spinner() 152 | with st.spinner(text='In progress'): 153 | test_dataset = Dataset.from_pandas(test_data) 154 | updated_test_dataset = test_dataset.map(apply_ocr) 155 | st.success('OCR Done') 156 | encoded_test_dataset = updated_test_dataset.map(lambda example: encode_example(example), 157 | features=features) 158 | encoded_test_dataset.set_format(type='torch', columns=['input_ids', 'bbox', 'attention_mask', 'token_type_ids']) 159 | test_dataloader = torch.utils.data.DataLoader(encoded_test_dataset, batch_size=1, shuffle=True) 160 | test_batch = next(iter(test_dataloader)) 161 | st.success('Encoding Data Done') 162 | input_ids = test_batch["input_ids"].to(device) 163 | bbox = test_batch["bbox"].to(device) 164 | attention_mask = test_batch["attention_mask"].to(device) 165 | token_type_ids = test_batch["token_type_ids"].to(device) 166 | 167 | # forward pass 168 | outputs = load_model()(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, 169 | token_type_ids=token_type_ids) 170 | 171 | classification_logits = outputs.logits 172 | classification_results = torch.softmax(classification_logits, dim=1).tolist()[0] 173 | 174 | # Show JSON output 175 | thisdict ={} 176 | for i in range(len(classes)): 177 | 
thisdict[classes[i]] = str(int(round(classification_results[i] * 100))) + "%" 178 | st.json(thisdict) 179 | 180 | # Show a Plotly bar chart of the per-class prediction percentages 181 | res_list = [] 182 | res_dict = {"Type of Document": classes, 183 | "Prediction Percent": res_list} 184 | for i in range(len(classes)): 185 | res_list.append(classification_results[i] * 100) 186 | total_dataframe = pd.DataFrame(res_dict) 187 | state_total_graph = px.bar( 188 | total_dataframe, 189 | x='Type of Document', 190 | y='Prediction Percent', 191 | color='Type of Document') 192 | st.plotly_chart(state_total_graph) 193 | 194 | 195 | st.success('Done') 196 | st.balloons() 197 | 198 | 199 | --------------------------------------------------------------------------------