├── .gitignore
├── Aptfile
├── Fine_tuning_LayoutLMForSequenceClassification_on_RVL_CDIP.ipynb
├── LICENSE
├── Predictor.ipynb
├── Procfile
├── README.md
├── requirements.txt
├── saved_model
│   └── config.json
├── setup.sh
└── streamlit-app.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/Aptfile:
--------------------------------------------------------------------------------
1 | tesseract-ocr
2 | tesseract-ocr-eng
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Lucky Verma
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Predictor.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "capital-berkeley",
6 | "metadata": {},
7 | "source": [
8 | "# Legacy Import"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "functioning-maker",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import os\n",
19 | "import pandas as pd\n",
20 | "from PIL import Image, ImageDraw, ImageFont\n",
21 | "from transformers import LayoutLMForSequenceClassification, LayoutLMTokenizer\n",
22 | "import torch\n",
23 | "from torch.utils.data import Dataset, DataLoader\n",
24 | "import pytesseract\n",
25 | "from datasets import Features, Sequence, ClassLabel, Value, Array2D\n",
26 | "import numpy as np\n",
27 | "\n",
28 | "classes = [\"bill\", \"invoice\", \"others\", \"Purchase_Order\", \"remittance\"]"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "id": "restricted-cedar",
34 | "metadata": {},
35 | "source": [
36 | "# Legacy Methods"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 2,
42 | "id": "german-modem",
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "from datasets import Dataset\n",
47 | "\n",
48 | "def normalize_box(box, width, height):\n",
49 | " return [\n",
50 | " int(1000 * (box[0] / width)),\n",
51 | " int(1000 * (box[1] / height)),\n",
52 | " int(1000 * (box[2] / width)),\n",
53 | " int(1000 * (box[3] / height)),\n",
54 | " ]\n",
55 | "\n",
56 | "def apply_ocr(example):\n",
57 | " # get the image\n",
58 | " image = Image.open(example['image_path'])\n",
59 | "\n",
60 | " width, height = image.size\n",
61 | " \n",
62 | " # apply ocr to the image \n",
63 | " ocr_df = pytesseract.image_to_data(image, output_type='data.frame')\n",
64 | " float_cols = ocr_df.select_dtypes('float').columns\n",
65 | " ocr_df = ocr_df.dropna().reset_index(drop=True)\n",
66 | " ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)\n",
67 | " ocr_df = ocr_df.replace(r'^\\s*$', np.nan, regex=True)\n",
68 | " ocr_df = ocr_df.dropna().reset_index(drop=True)\n",
69 | "\n",
70 | " # get the words and actual (unnormalized) bounding boxes\n",
71 | " #words = [word for word in ocr_df.text if str(word) != 'nan'])\n",
72 | " words = list(ocr_df.text)\n",
73 | " words = [str(w) for w in words]\n",
74 | " coordinates = ocr_df[['left', 'top', 'width', 'height']]\n",
75 | " actual_boxes = []\n",
76 | " for idx, row in coordinates.iterrows():\n",
77 | " x, y, w, h = tuple(row) # the row comes in (left, top, width, height) format\n",
78 | " actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box \n",
79 | " actual_boxes.append(actual_box)\n",
80 | " \n",
81 | " # normalize the bounding boxes\n",
82 | " boxes = []\n",
83 | " for box in actual_boxes:\n",
84 | " boxes.append(normalize_box(box, width, height))\n",
85 | " \n",
86 | " # add as extra columns \n",
87 | " assert len(words) == len(boxes)\n",
88 | " example['words'] = words\n",
89 | " example['bbox'] = boxes\n",
90 | " return example\n"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 3,
96 | "id": "mathematical-archives",
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "tokenizer = LayoutLMTokenizer.from_pretrained(\"microsoft/layoutlm-base-uncased\")\n",
101 | "\n",
102 | "def encode_example(example, max_seq_length=512, pad_token_box=[0, 0, 0, 0]):\n",
103 | " words = example['words']\n",
104 | " normalized_word_boxes = example['bbox']\n",
105 | "\n",
106 | " assert len(words) == len(normalized_word_boxes)\n",
107 | "\n",
108 | " token_boxes = []\n",
109 | " for word, box in zip(words, normalized_word_boxes):\n",
110 | " word_tokens = tokenizer.tokenize(word)\n",
111 | " token_boxes.extend([box] * len(word_tokens))\n",
112 | " \n",
113 | " # Truncation of token_boxes\n",
114 | " special_tokens_count = 2 \n",
115 | " if len(token_boxes) > max_seq_length - special_tokens_count:\n",
116 | " token_boxes = token_boxes[: (max_seq_length - special_tokens_count)]\n",
117 | " \n",
118 | " # add bounding boxes of cls + sep tokens\n",
119 | " token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]\n",
120 | " \n",
121 | " encoding = tokenizer(' '.join(words), padding='max_length', truncation=True)\n",
122 | " # Padding of token_boxes up the bounding boxes to the sequence length.\n",
123 | " input_ids = tokenizer(' '.join(words), truncation=True)[\"input_ids\"]\n",
124 | " padding_length = max_seq_length - len(input_ids)\n",
125 | " token_boxes += [pad_token_box] * padding_length\n",
126 | " encoding['bbox'] = token_boxes\n",
127 | "\n",
128 | " assert len(encoding['input_ids']) == max_seq_length\n",
129 | " assert len(encoding['attention_mask']) == max_seq_length\n",
130 | " assert len(encoding['token_type_ids']) == max_seq_length\n",
131 | " assert len(encoding['bbox']) == max_seq_length\n",
132 | "\n",
133 | " return encoding"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 4,
139 | "id": "afraid-township",
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 |     "# we need to define the features ourselves because LayoutLM's bbox input is an extra feature\n",
144 | "features = Features({\n",
145 | " 'input_ids': Sequence(feature=Value(dtype='int64')),\n",
146 | " 'bbox': Array2D(dtype=\"int64\", shape=(512, 4)),\n",
147 | " 'attention_mask': Sequence(Value(dtype='int64')),\n",
148 | " 'token_type_ids': Sequence(Value(dtype='int64')),\n",
149 | " 'image_path': Value(dtype='string'),\n",
150 | " 'words': Sequence(feature=Value(dtype='string')),\n",
151 | "})\n"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "id": "analyzed-legend",
158 | "metadata": {},
159 | "outputs": [],
160 | "source": []
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 5,
165 | "id": "intense-recall",
166 | "metadata": {},
167 | "outputs": [
168 | {
169 | "data": {
170 | "text/plain": [
171 | "LayoutLMForSequenceClassification(\n",
172 | " (layoutlm): LayoutLMModel(\n",
173 | " (embeddings): LayoutLMEmbeddings(\n",
174 | " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
175 | " (position_embeddings): Embedding(512, 768)\n",
176 | " (x_position_embeddings): Embedding(1024, 768)\n",
177 | " (y_position_embeddings): Embedding(1024, 768)\n",
178 | " (h_position_embeddings): Embedding(1024, 768)\n",
179 | " (w_position_embeddings): Embedding(1024, 768)\n",
180 | " (token_type_embeddings): Embedding(2, 768)\n",
181 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
182 | " (dropout): Dropout(p=0.1, inplace=False)\n",
183 | " )\n",
184 | " (encoder): LayoutLMEncoder(\n",
185 | " (layer): ModuleList(\n",
186 | " (0): LayoutLMLayer(\n",
187 | " (attention): LayoutLMAttention(\n",
188 | " (self): LayoutLMSelfAttention(\n",
189 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
190 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
191 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
192 | " (dropout): Dropout(p=0.1, inplace=False)\n",
193 | " )\n",
194 | " (output): LayoutLMSelfOutput(\n",
195 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
196 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
197 | " (dropout): Dropout(p=0.1, inplace=False)\n",
198 | " )\n",
199 | " )\n",
200 | " (intermediate): LayoutLMIntermediate(\n",
201 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
202 | " )\n",
203 | " (output): LayoutLMOutput(\n",
204 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
205 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
206 | " (dropout): Dropout(p=0.1, inplace=False)\n",
207 | " )\n",
208 | " )\n",
209 | " (1): LayoutLMLayer(\n",
210 | " (attention): LayoutLMAttention(\n",
211 | " (self): LayoutLMSelfAttention(\n",
212 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
213 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
214 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
215 | " (dropout): Dropout(p=0.1, inplace=False)\n",
216 | " )\n",
217 | " (output): LayoutLMSelfOutput(\n",
218 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
219 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
220 | " (dropout): Dropout(p=0.1, inplace=False)\n",
221 | " )\n",
222 | " )\n",
223 | " (intermediate): LayoutLMIntermediate(\n",
224 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
225 | " )\n",
226 | " (output): LayoutLMOutput(\n",
227 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
228 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
229 | " (dropout): Dropout(p=0.1, inplace=False)\n",
230 | " )\n",
231 | " )\n",
232 | " (2): LayoutLMLayer(\n",
233 | " (attention): LayoutLMAttention(\n",
234 | " (self): LayoutLMSelfAttention(\n",
235 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
236 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
237 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
238 | " (dropout): Dropout(p=0.1, inplace=False)\n",
239 | " )\n",
240 | " (output): LayoutLMSelfOutput(\n",
241 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
242 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
243 | " (dropout): Dropout(p=0.1, inplace=False)\n",
244 | " )\n",
245 | " )\n",
246 | " (intermediate): LayoutLMIntermediate(\n",
247 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
248 | " )\n",
249 | " (output): LayoutLMOutput(\n",
250 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
251 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
252 | " (dropout): Dropout(p=0.1, inplace=False)\n",
253 | " )\n",
254 | " )\n",
255 | " (3): LayoutLMLayer(\n",
256 | " (attention): LayoutLMAttention(\n",
257 | " (self): LayoutLMSelfAttention(\n",
258 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
259 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
260 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
261 | " (dropout): Dropout(p=0.1, inplace=False)\n",
262 | " )\n",
263 | " (output): LayoutLMSelfOutput(\n",
264 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
265 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
266 | " (dropout): Dropout(p=0.1, inplace=False)\n",
267 | " )\n",
268 | " )\n",
269 | " (intermediate): LayoutLMIntermediate(\n",
270 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
271 | " )\n",
272 | " (output): LayoutLMOutput(\n",
273 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
274 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
275 | " (dropout): Dropout(p=0.1, inplace=False)\n",
276 | " )\n",
277 | " )\n",
278 | " (4): LayoutLMLayer(\n",
279 | " (attention): LayoutLMAttention(\n",
280 | " (self): LayoutLMSelfAttention(\n",
281 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
282 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
283 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
284 | " (dropout): Dropout(p=0.1, inplace=False)\n",
285 | " )\n",
286 | " (output): LayoutLMSelfOutput(\n",
287 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
288 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
289 | " (dropout): Dropout(p=0.1, inplace=False)\n",
290 | " )\n",
291 | " )\n",
292 | " (intermediate): LayoutLMIntermediate(\n",
293 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
294 | " )\n",
295 | " (output): LayoutLMOutput(\n",
296 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
297 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
298 | " (dropout): Dropout(p=0.1, inplace=False)\n",
299 | " )\n",
300 | " )\n",
301 | " (5): LayoutLMLayer(\n",
302 | " (attention): LayoutLMAttention(\n",
303 | " (self): LayoutLMSelfAttention(\n",
304 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
305 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
306 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
307 | " (dropout): Dropout(p=0.1, inplace=False)\n",
308 | " )\n",
309 | " (output): LayoutLMSelfOutput(\n",
310 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
311 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
312 | " (dropout): Dropout(p=0.1, inplace=False)\n",
313 | " )\n",
314 | " )\n",
315 | " (intermediate): LayoutLMIntermediate(\n",
316 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
317 | " )\n",
318 | " (output): LayoutLMOutput(\n",
319 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
320 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
321 | " (dropout): Dropout(p=0.1, inplace=False)\n",
322 | " )\n",
323 | " )\n",
324 | " (6): LayoutLMLayer(\n",
325 | " (attention): LayoutLMAttention(\n",
326 | " (self): LayoutLMSelfAttention(\n",
327 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
328 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
329 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
330 | " (dropout): Dropout(p=0.1, inplace=False)\n",
331 | " )\n",
332 | " (output): LayoutLMSelfOutput(\n",
333 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
334 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
335 | " (dropout): Dropout(p=0.1, inplace=False)\n",
336 | " )\n",
337 | " )\n",
338 | " (intermediate): LayoutLMIntermediate(\n",
339 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
340 | " )\n",
341 | " (output): LayoutLMOutput(\n",
342 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
343 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
344 | " (dropout): Dropout(p=0.1, inplace=False)\n",
345 | " )\n",
346 | " )\n",
347 | " (7): LayoutLMLayer(\n",
348 | " (attention): LayoutLMAttention(\n",
349 | " (self): LayoutLMSelfAttention(\n",
350 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
351 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
352 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
353 | " (dropout): Dropout(p=0.1, inplace=False)\n",
354 | " )\n",
355 | " (output): LayoutLMSelfOutput(\n",
356 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
357 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
358 | " (dropout): Dropout(p=0.1, inplace=False)\n",
359 | " )\n",
360 | " )\n",
361 | " (intermediate): LayoutLMIntermediate(\n",
362 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
363 | " )\n",
364 | " (output): LayoutLMOutput(\n",
365 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
366 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
367 | " (dropout): Dropout(p=0.1, inplace=False)\n",
368 | " )\n",
369 | " )\n",
370 | " (8): LayoutLMLayer(\n",
371 | " (attention): LayoutLMAttention(\n",
372 | " (self): LayoutLMSelfAttention(\n",
373 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
374 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
375 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
376 | " (dropout): Dropout(p=0.1, inplace=False)\n",
377 | " )\n",
378 | " (output): LayoutLMSelfOutput(\n",
379 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
380 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
381 | " (dropout): Dropout(p=0.1, inplace=False)\n",
382 | " )\n",
383 | " )\n",
384 | " (intermediate): LayoutLMIntermediate(\n",
385 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
386 | " )\n",
387 | " (output): LayoutLMOutput(\n",
388 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
389 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
390 | " (dropout): Dropout(p=0.1, inplace=False)\n",
391 | " )\n",
392 | " )\n",
393 | " (9): LayoutLMLayer(\n",
394 | " (attention): LayoutLMAttention(\n",
395 | " (self): LayoutLMSelfAttention(\n",
396 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
397 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
398 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
399 | " (dropout): Dropout(p=0.1, inplace=False)\n",
400 | " )\n",
401 | " (output): LayoutLMSelfOutput(\n",
402 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
403 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
404 | " (dropout): Dropout(p=0.1, inplace=False)\n",
405 | " )\n",
406 | " )\n",
407 | " (intermediate): LayoutLMIntermediate(\n",
408 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
409 | " )\n",
410 | " (output): LayoutLMOutput(\n",
411 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
412 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
413 | " (dropout): Dropout(p=0.1, inplace=False)\n",
414 | " )\n",
415 | " )\n",
416 | " (10): LayoutLMLayer(\n",
417 | " (attention): LayoutLMAttention(\n",
418 | " (self): LayoutLMSelfAttention(\n",
419 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
420 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
421 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
422 | " (dropout): Dropout(p=0.1, inplace=False)\n",
423 | " )\n",
424 | " (output): LayoutLMSelfOutput(\n",
425 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
426 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
427 | " (dropout): Dropout(p=0.1, inplace=False)\n",
428 | " )\n",
429 | " )\n",
430 | " (intermediate): LayoutLMIntermediate(\n",
431 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
432 | " )\n",
433 | " (output): LayoutLMOutput(\n",
434 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
435 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
436 | " (dropout): Dropout(p=0.1, inplace=False)\n",
437 | " )\n",
438 | " )\n",
439 | " (11): LayoutLMLayer(\n",
440 | " (attention): LayoutLMAttention(\n",
441 | " (self): LayoutLMSelfAttention(\n",
442 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
443 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
444 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
445 | " (dropout): Dropout(p=0.1, inplace=False)\n",
446 | " )\n",
447 | " (output): LayoutLMSelfOutput(\n",
448 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
449 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
450 | " (dropout): Dropout(p=0.1, inplace=False)\n",
451 | " )\n",
452 | " )\n",
453 | " (intermediate): LayoutLMIntermediate(\n",
454 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
455 | " )\n",
456 | " (output): LayoutLMOutput(\n",
457 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
458 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
459 | " (dropout): Dropout(p=0.1, inplace=False)\n",
460 | " )\n",
461 | " )\n",
462 | " )\n",
463 | " )\n",
464 | " (pooler): LayoutLMPooler(\n",
465 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
466 | " (activation): Tanh()\n",
467 | " )\n",
468 | " )\n",
469 | " (dropout): Dropout(p=0.1, inplace=False)\n",
470 | " (classifier): Linear(in_features=768, out_features=5, bias=True)\n",
471 | ")"
472 | ]
473 | },
474 | "execution_count": 5,
475 | "metadata": {},
476 | "output_type": "execute_result"
477 | }
478 | ],
479 | "source": [
480 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
481 | "model = LayoutLMForSequenceClassification.from_pretrained(\"saved_model/run2\")\n",
482 | "model.to(device)"
483 | ]
484 | },
485 | {
486 | "cell_type": "markdown",
487 | "id": "answering-credits",
488 | "metadata": {},
489 | "source": [
490 | "# Data Processing Flow"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": 11,
496 | "id": "involved-cycle",
497 | "metadata": {},
498 | "outputs": [
499 | {
500 | "name": "stdout",
501 | "output_type": "stream",
502 | "text": [
503 | "test_data [] ['audacious.jpg', 'Developer-564x804.png']\n",
504 | "['audacious.jpg', 'Developer-564x804.png']\n"
505 | ]
506 | },
507 | {
508 | "data": {
540 | "text/plain": [
541 | " image_path\n",
542 | "0 test_data/audacious.jpg"
543 | ]
544 | },
545 | "execution_count": 11,
546 | "metadata": {},
547 | "output_type": "execute_result"
548 | }
549 | ],
550 | "source": [
551 | "images = []\n",
552 | "labels = []\n",
553 | "dataset_path = 'test_data'\n",
554 | "\n",
555 | "for label_folder, _, file_names in os.walk(dataset_path):\n",
556 | " print(label_folder, _, file_names)\n",
557 | " print(file_names)\n",
558 | " relative_image_names = []\n",
559 | " relative_image_names.append(dataset_path + \"/\" + file_names[0])\n",
560 | " images.extend(relative_image_names)\n",
561 | "test_data = pd.DataFrame.from_dict({'image_path': images})\n",
562 | "test_data.head()"
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "execution_count": 12,
568 | "id": "auburn-letter",
569 | "metadata": {},
570 | "outputs": [
571 | {
572 | "data": {
573 | "application/vnd.jupyter.widget-view+json": {
574 | "model_id": "66aba325525c4a6b86718dc6a7dd6fbb",
575 | "version_major": 2,
576 | "version_minor": 0
577 | },
578 | "text/plain": [
579 | " 0%| | 0/1 [00:00, ?ex/s]"
580 | ]
581 | },
582 | "metadata": {},
583 | "output_type": "display_data"
584 | }
585 | ],
586 | "source": [
587 | "test_dataset = Dataset.from_pandas(test_data)\n",
588 | "updated_test_dataset = test_dataset.map(apply_ocr)"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": 13,
594 | "id": "together-techno",
595 | "metadata": {},
596 | "outputs": [
597 | {
598 | "name": "stdout",
599 | "output_type": "stream",
600 | "text": [
601 | "578\n",
602 | "['Essay', 'Writing', 'in', '10', 'Words', '1.', 'Brief', 'Take', 'time', 'to', 'fully', 'understand', 'the', 'brief', 'you', 'have', 'been', 'given.', 'Read', 'it', 'several', 'times', 'and', 'put', 'it', 'into', 'your', 'own', 'words', 'before', 'starting.', 'Ensure', 'your', 'essay', 'states', 'the', 'brief', 'in', 'your', 'introduction,', 'addresses', 'it', 'in', 'your', 'main', 'body', 'and', 'answers', 'it', 'in', 'your', 'conclusion.', 'This', 'will', 'enable', 'your', 'reader', 'to', 'know', 'exactly', 'what', 'your', 'essay', 'is', 'about,', 'how', 'you', 'have', 'addressed', 'it', 'and', 'what', 'you', 'have', 'concluded', 'on', 'this', 'subject.', 'If', 'your', 'brief', 'has', 'several', 'parts', 'to', 'it,', 'divide', 'your', 'word', 'count', 'between', 'them', 'according', 'to', 'their', 'relative', 'importance', 'and', 'structure', 'your', 'main', 'body', 'so', 'that', 'you', 'address', 'each', 'part', 'in', 'turn.', '2.', 'Evidence', 'Collect', 'evidence', 'on', 'your', 'subject.', 'Every', 'claim', 'you', 'make', 'in', 'your', 'essay', 'should', 'be', 'backed', 'up', 'by', 'evidence.', 'Never', 'air', 'your', 'beliefs', 'or', 'opinions', 'without', 'evidence', 'unless', 'specifically', 'asked', 'to', 'in', 'your', 'brief.', 'Ensure', 'the', 'quality', 'and', 'relevance', 'of', 'your', 'evidence', 'is', 'as', 'high', 'as', 'possible.', 'Organise', 'your', 'evidence', 'conceptually', '(e.g.', 'use', 'a', 'mind', 'map).', 'Collect', 'about', 'twice', 'as', 'much', 'evidence', 'as', 'you', 'are', 'planning', 'to', 'use.', 'Prioritise', 'the', 'importance', 'of', 'your', 'evidence,', '.g.', 'the', 'top', '10%,', 'the', 'next', '40%,', 'and', 'the', 'other', '50%.', 'Spend', 'about', 'the', 'same', 'number', 'of', 'words', 'discussing', 'the', 'top', '10%', 'as', 'the', 'next', '40%', 'but', \"don't\", 'refer', 'to', 'the', 'other', '50%.', 'Where', 'possible,', 'cite', 'your', 'top', '10%', 'in', 'more', 'depth', 'in', 'one', 'place', 'rather', 'than', 'using', 'several', 'short', 'citations', 'so', 'that', 'your', 'reader', 'can', 'clearly', 'see', 'you', 'are', 'prioritising', 'this', 'evidence.', 'Avoid', 'regurgitating', 'evidence', 'by', 'evaluating', 'it:', \"don't\", 'ignore', 'conflicting', 'evidence', 'and', 'come', 'to', 'a', 'personal', 'view', 'based', 'on', 'the', 'evidence', 'with', 'an', 'appropriate', 'degree', 'of', 'confidence', 'for', 'each', 'area', 'you', 'cover.', '3.', 'Paragraph', 'Write', 'your', 'essay', 'In', 'paragraphs.', 'Separate', 'each', 'paragraph', 'with', 'a', 'blank', 'line.', 'The', 'average', 'paragraph', 'length', 'should', 'be', 'about', '125', 'words', 'and', '4', 'or', 'more', 'sentences.', 'Paragraphs', 'are', 'the', 'basic', 'building', 'blocks', 'of', 'essays.', 'If', 'you', 'can', 'write', 'good', 'paragraphs', 'you', 'are', 'more', 'than', 'half', 'way', 'to', 'becoming', 'a', 'good', 'essay', 'writer.', 'Short', 'paragraphs', 'indicate', 'ideas', 'which', 'are', 'not', 'fully', 'developed.', '4.', 'Topic', 'Each', 'of', 'your', 'paragraphs', 'should', 'introduce,', 'develop', 'and', 'conclude', 'a', 'single', 'topic.', 'Your', 'reader', 'should', 'clearly', 'be', 'able', 'to', 'understand', 'what', 'each', 'paragraph', 'is', 'about', 'by', 'reading', 'the', 'first', 'one', 'o', 'two', 'sentences.', 'If', 'not,', 'then', 'divide', 'your', 'paragraphs', 'up', 'to', 'avoid', 'rambling.', 'Successive', 'paragraph', 'topics', 'should', 'flow', 'in', 'a', 'logical', 'sequence.', '5.', 'Point', 'Each', 'paragraph', 'should', 
'make', 'one', 'main', 'point', '(claim', 'backed', 'up', 'by', 'evidence).', 'When', 'introducing', 'and', 'describing', 'evidence', 'your', 'points', 'should', 'generally', 'come', 'at', 'the', 'beginning', 'of', 'your', 'paragraphs', 'but', 'when', 'discussing', 'ideas', 'in', 'more', 'depth', 'they', 'should', 'generally', 'come', 'at', 'the', 'end.', 'Your', 'points', 'are', 'very', 'important', 'as', 'they', 'are', 'the', 'essence', 'of', 'what', 'you', 'are', 'trying', 'to', 'say', 'to', 'your', 'reader.', '6.', 'Introduction', 'Start', 'your', 'essay', 'with', 'an', 'introduction.', 'Use', 'about', '10%', 'of', 'your', 'word', 'count.', 'Your', 'introduction', 'should', 'contain', 'three', 'stages:', '«', 'Territory', '—', 'establish', 'the', 'wider', 'context', 'of', 'your', 'essay', 'subject', 'by', 'making', 'some', 'simple,', 'generally', 'accepted', 'observations', 'which', 'a', 'non-specialist', 'reader', 'can', 'understand.', 'For', 'longer', 'essays,', 'then', 'start', 'to', 'focus', 'towards', 'the', 'essay', 'subject.', '«', 'Niche', '—', 'explain', 'what', 'your', 'essay', 'is', 'about', 'and', 'why', 'it', 'is', 'important/interesting', '«', 'Occupy', '-', 'explain', 'how', 'you', 'are', 'going', 'to', 'address', 'the', 'subject', 'of', 'your', 'essay', 'in', 'your', 'main', 'body', '(this', 'stage', 'also', 'signals', 'to', 'your', 'reader', 'that', 'your', 'introduction', 'is', 'ending', 'and', 'your', 'main', 'body', 'is', 'about', 'to', 'start).', 'This', 'is', 'saying', 'what', 'you', 'are', 'going', 'to', 'say.']\n"
603 | ]
604 | }
605 | ],
606 | "source": [
607 | "import pandas as pd\n",
608 | "df = pd.DataFrame.from_dict(updated_test_dataset)\n",
609 | "print(len(df[\"words\"][0]))\n",
610 | "print(df[\"words\"][0])"
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "execution_count": 14,
616 | "id": "adjusted-magnitude",
617 | "metadata": {},
618 | "outputs": [
619 | {
620 | "data": {
621 | "application/vnd.jupyter.widget-view+json": {
622 | "model_id": "1c21942666604cccb4ec65905867fdf3",
623 | "version_major": 2,
624 | "version_minor": 0
625 | },
626 | "text/plain": [
627 | " 0%| | 0/1 [00:00, ?ex/s]"
628 | ]
629 | },
630 | "metadata": {},
631 | "output_type": "display_data"
632 | }
633 | ],
634 | "source": [
635 | "encoded_test_dataset = updated_test_dataset.map(lambda example: encode_example(example), \n",
636 | " features=features)"
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "execution_count": 15,
642 | "id": "fallen-hammer",
643 | "metadata": {},
644 | "outputs": [],
645 | "source": [
646 | "encoded_test_dataset.set_format(type='torch', columns=['input_ids', 'bbox', 'attention_mask', 'token_type_ids'])\n",
647 | "test_dataloader = torch.utils.data.DataLoader(encoded_test_dataset, batch_size=1, shuffle=True)\n",
648 | "test_batch = next(iter(test_dataloader))"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": 16,
654 | "id": "accessory-thanksgiving",
655 | "metadata": {},
656 | "outputs": [
657 | {
658 | "name": "stdout",
659 | "output_type": "stream",
660 | "text": [
661 | "SequenceClassifierOutput(loss=None, logits=tensor([[-2.4451, -2.6868, 9.6513, -2.4914, -2.8206]], device='cuda:0',\n",
662 | " grad_fn=), hidden_states=None, attentions=None)\n"
663 | ]
664 | }
665 | ],
666 | "source": [
667 | "input_ids = test_batch[\"input_ids\"].to(device)\n",
668 | "bbox = test_batch[\"bbox\"].to(device)\n",
669 | "attention_mask = test_batch[\"attention_mask\"].to(device)\n",
670 | "token_type_ids = test_batch[\"token_type_ids\"].to(device)\n",
671 | "\n",
672 | "# forward pass\n",
673 | "outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, \n",
674 | " token_type_ids=token_type_ids)\n",
675 | "\n",
676 | "# prediction = int(torch.max(outputs.data, 1)[1].numpy())\n",
677 | "print(outputs)"
678 | ]
679 | },
680 | {
681 | "cell_type": "code",
682 | "execution_count": 17,
683 | "id": "documentary-hartford",
684 | "metadata": {},
685 | "outputs": [
686 | {
687 | "name": "stdout",
688 | "output_type": "stream",
689 | "text": [
690 | "bill: 0%\n",
691 | "invoice: 0%\n",
692 | "others: 100%\n",
693 | "Purchase_Order: 0%\n",
694 | "remittance: 0%\n"
695 | ]
696 | }
697 | ],
698 | "source": [
699 | "# import torch.nn.functional as F\n",
700 | "# pt_predictions = F.softmax(outputs[0], dim=-1)\n",
701 | "# pt_predictions\n",
702 | "\n",
703 | "classification_logits = outputs.logits\n",
704 | "classification_results = torch.softmax(classification_logits, dim=1).tolist()[0]\n",
705 | "for i in range(len(classes)):\n",
706 | " print(f\"{classes[i]}: {int(round(classification_results[i] * 100))}%\")"
707 | ]
708 | },
709 | {
710 | "cell_type": "code",
711 | "execution_count": 24,
712 | "id": "unsigned-luther",
713 | "metadata": {},
714 | "outputs": [
715 | {
716 | "name": "stdout",
717 | "output_type": "stream",
718 | "text": [
719 | "{'bill': '0%', 'invoice': '0%', 'others': '100%', 'Purchase_Order': '0%', 'remittance': '0%'}\n"
720 | ]
721 | }
722 | ],
723 | "source": [
724 |     "thisdict = {}\n",
725 | "for i in range(len(classes)):\n",
726 | " thisdict[classes[i]] = str(int(round(classification_results[i] * 100))) + \"%\"\n",
727 | "print(thisdict)\n"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": 18,
733 | "id": "color-bones",
734 | "metadata": {},
735 | "outputs": [
736 | {
737 | "data": {
738 | "text/plain": [
739 | "tensor([[5.5799e-06, 4.3818e-06, 9.9998e-01, 5.3273e-06, 3.8329e-06]],\n",
740 | " device='cuda:0', grad_fn=)"
741 | ]
742 | },
743 | "execution_count": 18,
744 | "metadata": {},
745 | "output_type": "execute_result"
746 | }
747 | ],
748 | "source": [
749 | "import torch.nn.functional as F\n",
750 | "pt_predictions = F.softmax(outputs[0], dim=-1)\n",
751 | "pt_predictions"
752 | ]
753 | },
754 | {
755 | "cell_type": "code",
756 | "execution_count": 19,
757 | "id": "helpful-seventh",
758 | "metadata": {},
759 | "outputs": [
760 | {
761 | "data": {
762 | "text/plain": [
763 | "2"
764 | ]
765 | },
766 | "execution_count": 19,
767 | "metadata": {},
768 | "output_type": "execute_result"
769 | }
770 | ],
771 | "source": [
772 | "predictions = outputs.logits.argmax(-1).squeeze().tolist()\n",
773 | "predictions"
774 | ]
775 | },
776 | {
777 | "cell_type": "code",
778 | "execution_count": null,
779 | "id": "extra-workstation",
780 | "metadata": {},
781 | "outputs": [],
782 | "source": [
783 | "# NATIVE T5\n",
784 | "\n",
785 | "# generated_answer = model.generate(input_ids, attention_mask=attention_mask, \n",
786 | "# max_length=decoder_max_len, top_p=0.98, top_k=50)\n",
787 | "# decoded_answer = tokenizer.decode(generated_answer.numpy()[0])\n",
788 | "# print(\"Answer: \", decoded_answer)"
789 | ]
790 | },
791 | {
792 | "cell_type": "code",
793 | "execution_count": null,
794 | "id": "brilliant-uncertainty",
795 | "metadata": {},
796 | "outputs": [],
797 | "source": []
798 | }
799 | ],
800 | "metadata": {
801 | "kernelspec": {
802 | "display_name": "Python 3",
803 | "language": "python",
804 | "name": "python3"
805 | },
806 | "language_info": {
807 | "codemirror_mode": {
808 | "name": "ipython",
809 | "version": 3
810 | },
811 | "file_extension": ".py",
812 | "mimetype": "text/x-python",
813 | "name": "python",
814 | "nbconvert_exporter": "python",
815 | "pygments_lexer": "ipython3",
816 | "version": "3.7.10"
817 | }
818 | },
819 | "nbformat": 4,
820 | "nbformat_minor": 5
821 | }
822 |
--------------------------------------------------------------------------------
/Procfile:
--------------------------------------------------------------------------------
1 | web: sh setup.sh && streamlit run streamlit-app.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Document-Classification-using-LayoutLM
2 | This PyTorch project demonstrates sequence classification with Microsoft's LayoutLM, using Hugging Face Transformers to classify document images into five types: bill, invoice, purchase order, remittance, and others. A minimal inference sketch is included below.
3 |
4 |
5 | ## Star History
6 |
7 | [](https://star-history.com/#lucky-verma/Document-Classification-using-LayoutLM&Date)
8 |
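9 | ## Running the app
10 |
11 | Install the Python dependencies from `requirements.txt` and the Tesseract OCR packages listed in `Aptfile`, then launch the front end with `streamlit run streamlit-app.py` (`Procfile` and `setup.sh` do the equivalent on Heroku).
12 |
13 | ## Inference sketch
14 |
15 | A minimal sketch of the OCR, encoding and classification flow used in `Predictor.ipynb` and `streamlit-app.py`. It assumes a fine-tuned checkpoint under `saved_model/`; the image path below is a placeholder, so adapt both to your setup.
16 |
17 | ```python
18 | import pytesseract
19 | import torch
20 | from PIL import Image
21 | from transformers import LayoutLMForSequenceClassification, LayoutLMTokenizer
22 |
23 | classes = ["bill", "invoice", "others", "Purchase_Order", "remittance"]
24 | tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
25 | model = LayoutLMForSequenceClassification.from_pretrained("saved_model")
26 | model.eval()
27 |
28 | image = Image.open("test_data/sample.jpg").convert("RGB")  # placeholder path
29 | width, height = image.size
30 |
31 | # OCR: words plus pixel boxes, normalized to the 0-1000 grid LayoutLM expects
32 | ocr = pytesseract.image_to_data(image, output_type="data.frame").dropna()
33 | words = [str(w) for w in ocr.text]
34 | boxes = [[int(1000 * l / width), int(1000 * t / height),
35 |           int(1000 * (l + w) / width), int(1000 * (t + h) / height)]
36 |          for l, t, w, h in zip(ocr.left, ocr.top, ocr.width, ocr.height)]
37 |
38 | # Tokenize, repeating each word's box for every sub-word token
39 | encoding = tokenizer(" ".join(words), return_tensors="pt",
40 |                      padding="max_length", truncation=True, max_length=512)
41 | token_boxes = []
42 | for word, box in zip(words, boxes):
43 |     token_boxes.extend([box] * len(tokenizer.tokenize(word)))
44 | token_boxes = [[0, 0, 0, 0]] + token_boxes[:510] + [[1000, 1000, 1000, 1000]]
45 | token_boxes += [[0, 0, 0, 0]] * (512 - len(token_boxes))
46 |
47 | with torch.no_grad():
48 |     logits = model(bbox=torch.tensor([token_boxes]), **encoding).logits
49 | probs = logits.softmax(-1).squeeze().tolist()
50 | print({c: f"{p:.1%}" for c, p in zip(classes, probs)})
51 | ```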
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | allennlp==2.2.0
2 | altair==4.1.0
3 | anyio==2.2.0
4 | argon2-cffi==20.1.0
5 | astor==0.8.1
6 | async-generator==1.10
7 | atomicwrites==1.4.0
8 | attrs==20.3.0
9 | Babel==2.9.0
10 | backcall==0.2.0
11 | base58==2.1.0
12 | bleach==3.3.0
13 | blinker==1.4
14 | blis==0.4.1
15 | boto3==1.17.14
16 | botocore==1.20.14
17 | cached-property==1.5.2
18 | cachetools==4.2.1
19 | catalogue==1.0.0
20 | certifi==2020.12.5
21 | cffi==1.14.5
22 | chardet==4.0.0
23 | click==7.1.2
24 | colorama==0.4.4
25 | configparser==5.0.2
26 | cycler==0.10.0
27 | cymem==2.0.5
28 | datasets==1.5.0
29 | decorator==4.4.2
30 | defusedxml==0.7.1
31 | dill==0.3.3
32 | docker-pycreds==0.4.0
33 | entrypoints==0.3
34 | filelock==3.0.12
35 | fsspec==0.8.7
36 | gitdb==4.0.7
37 | GitPython==3.1.14
38 | h5py==3.1.0
39 | huggingface-hub==0.0.7
40 | idna==2.10
41 | importlib-metadata==3.10.0
42 | iniconfig==1.1.1
43 | ipykernel==5.5.3
44 | ipython==7.22.0
45 | ipython-genutils==0.2.0
46 | ipywidgets==7.6.3
47 | jedi==0.18.0
48 | Jinja2==2.11.3
49 | jmespath==0.10.0
50 | joblib==1.0.1
51 | json5==0.9.5
52 | jsonpickle==2.0.0
53 | jsonschema==3.2.0
54 | jupyter-client==6.1.12
55 | jupyter-core==4.7.1
56 | jupyter-packaging==0.7.12
57 | jupyter-server==1.5.1
58 | jupyterlab==3.0.12
59 | jupyterlab-pygments==0.1.2
60 | jupyterlab-server==2.4.0
61 | jupyterlab-widgets==1.0.0
62 | kiwisolver==1.3.1
63 | lmdb==1.1.1
64 | MarkupSafe==1.1.1
65 | matplotlib==3.4.1
66 | mistune==0.8.4
67 | more-itertools==8.7.0
68 | multiprocess==0.70.11.1
69 | murmurhash==1.0.5
70 | nbclassic==0.2.6
71 | nbclient==0.5.3
72 | nbconvert==6.0.7
73 | nbformat==5.1.2
74 | nest-asyncio==1.5.1
75 | nltk==3.5
76 | notebook==6.3.0
77 | numpy==1.20.2
78 | opencv-python==4.3.0.36
79 | overrides==3.1.0
80 | packaging==20.9
81 | pandas==1.0.5
82 | pandocfilters==1.4.3
83 | parso==0.8.1
84 | pathtools==0.1.2
85 | pickleshare==0.7.5
86 | Pillow==8.1.2
87 | plac==1.1.3
88 | plotly==4.14.3
89 | pluggy==0.13.1
90 | preshed==3.0.5
91 | prometheus-client==0.9.0
92 | promise==2.3
93 | prompt-toolkit==3.0.18
94 | protobuf==3.15.1
95 | psutil==5.8.0
96 | py==1.10.0
97 | pyarrow==3.0.0
98 | pycparser==2.20
99 | pydeck==0.6.1
100 | Pygments==2.8.1
101 | pyparsing==2.4.7
102 | pyrsistent==0.17.3
103 | pytesseract==0.3.7
104 | pytest==6.2.2
105 | python-dateutil==2.8.1
106 | pytz==2021.1
107 | pywin32==300
108 | pywinpty==0.5.7
109 | PyYAML==5.4.1
110 | pyzmq==22.0.3
111 | regex==2020.11.13
112 | requests==2.25.1
113 | retrying==1.3.3
114 | s3transfer==0.3.4
115 | sacremoses==0.0.43
116 | scikit-learn==0.24.1
117 | scipy==1.5.4
118 | Send2Trash==1.5.0
119 | sentencepiece==0.1.95
120 | sentry-sdk==1.0.0
121 | shortuuid==1.0.1
122 | six==1.15.0
123 | smmap==4.0.0
124 | sniffio==1.2.0
125 | spacy==2.2.4
126 | srsly==1.0.5
127 | streamlit==0.79.0
128 | subprocess32==3.5.4
129 | tabulate==0.8.7
130 | tensorboardX==2.1
131 | terminado==0.9.4
132 | testpath==0.4.4
133 | tez==0.1.2
134 | thinc==7.4.0
135 | threadpoolctl==2.1.0
136 | tokenizers==0.10.1
137 | toml==0.10.2
138 | toolz==0.11.1
139 | torch==1.8.1+cu102
140 | torchaudio==0.8.1
141 | torchtext==0.6.0
142 | torchvision==0.9.1+cu102
143 | tornado==6.1
144 | tqdm==4.59.0
145 | traitlets==5.0.5
146 | transformers==4.4.2
147 | typing-extensions==3.7.4.3
148 | tzlocal==2.1
149 | urllib3==1.26.4
150 | validators==0.18.2
151 | wandb==0.10.23
152 | wasabi==0.8.2
153 | watchdog==2.0.2
154 | wcwidth==0.2.5
155 | webencodings==0.5.1
156 | widgetsnbextension==3.5.1
157 | wincertstore==0.2
158 | xxhash==2.0.0
159 | zipp==3.4.1
160 |
--------------------------------------------------------------------------------
/saved_model/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "microsoft/layoutlm-base-uncased",
3 | "architectures": [
4 | "LayoutLMForSequenceClassification"
5 | ],
6 | "attention_probs_dropout_prob": 0.1,
7 | "gradient_checkpointing": false,
8 | "hidden_act": "gelu",
9 | "hidden_dropout_prob": 0.1,
10 | "hidden_size": 768,
11 | "id2label": {
12 | "0": "LABEL_0",
13 | "1": "LABEL_1",
14 | "2": "LABEL_2",
15 | "3": "LABEL_3",
16 | "4": "LABEL_4"
17 | },
18 | "initializer_range": 0.02,
19 | "intermediate_size": 3072,
20 | "label2id": {
21 | "LABEL_0": 0,
22 | "LABEL_1": 1,
23 | "LABEL_2": 2,
24 | "LABEL_3": 3,
25 | "LABEL_4": 4
26 | },
27 | "layer_norm_eps": 1e-12,
28 | "max_2d_position_embeddings": 1024,
29 | "max_position_embeddings": 512,
30 | "model_type": "layoutlm",
31 | "num_attention_heads": 12,
32 | "num_hidden_layers": 12,
33 | "output_past": true,
34 | "pad_token_id": 0,
35 | "position_embedding_type": "absolute",
36 | "transformers_version": "4.4.2",
37 | "type_vocab_size": 2,
38 | "use_cache": true,
39 | "vocab_size": 30522
40 | }
41 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ~/.streamlit/
2 |
3 | echo "\
4 | [server]\n\
5 | headless = true\n\
6 | port = $PORT\n\
7 | enableCORS = false\n\
8 | \n\
9 | " > ~/.streamlit/config.toml
--------------------------------------------------------------------------------
/streamlit-app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from PIL import Image, ImageDraw, ImageFont
4 | from transformers import LayoutLMForSequenceClassification, LayoutLMTokenizer
5 | import torch
6 | import requests
7 | from torch.utils.data import Dataset, DataLoader
8 | import pytesseract
9 | from datasets import Features, Sequence, ClassLabel, Value, Array2D
10 | import numpy as np
11 | import streamlit as st
12 | from datasets import Dataset
13 | import plotly.figure_factory as ff
14 | import plotly.express as px
15 | from plotly.subplots import make_subplots
16 | import plotly.graph_objects as go
17 | import matplotlib.pyplot as plt
18 |
19 |
20 | # Legacy method imports
21 |
22 | def normalize_box(box, width, height):
23 | return [
24 | int(1000 * (box[0] / width)),
25 | int(1000 * (box[1] / height)),
26 | int(1000 * (box[2] / width)),
27 | int(1000 * (box[3] / height)),
28 | ]
29 |
30 | def apply_ocr(example):
31 | # get the image
32 | image = Image.open(example['image_path'])
33 |
34 | width, height = image.size
35 |
36 | # apply ocr to the image
37 | ocr_df = pytesseract.image_to_data(image, output_type='data.frame')
38 | float_cols = ocr_df.select_dtypes('float').columns
39 | ocr_df = ocr_df.dropna().reset_index(drop=True)
40 | ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
41 | ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
42 | ocr_df = ocr_df.dropna().reset_index(drop=True)
43 |
44 | # get the words and actual (unnormalized) bounding boxes
45 | #words = [word for word in ocr_df.text if str(word) != 'nan'])
46 | words = list(ocr_df.text)
47 | words = [str(w) for w in words]
48 | coordinates = ocr_df[['left', 'top', 'width', 'height']]
49 | actual_boxes = []
50 | for idx, row in coordinates.iterrows():
51 | x, y, w, h = tuple(row) # the row comes in (left, top, width, height) format
52 | actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box
53 | actual_boxes.append(actual_box)
54 |
55 | # normalize the bounding boxes
56 | boxes = []
57 | for box in actual_boxes:
58 | boxes.append(normalize_box(box, width, height))
59 |
60 | # add as extra columns
61 | assert len(words) == len(boxes)
62 | example['words'] = words
63 | example['bbox'] = boxes
64 | return example
65 |
66 | tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
67 |
68 | def encode_example(example, max_seq_length=512, pad_token_box=[0, 0, 0, 0]):
69 | words = example['words']
70 | normalized_word_boxes = example['bbox']
71 |
72 | assert len(words) == len(normalized_word_boxes)
73 |
74 | token_boxes = []
75 | for word, box in zip(words, normalized_word_boxes):
76 | word_tokens = tokenizer.tokenize(word)
77 | token_boxes.extend([box] * len(word_tokens))
78 |
79 | # Truncation of token_boxes
80 | special_tokens_count = 2
81 | if len(token_boxes) > max_seq_length - special_tokens_count:
82 | token_boxes = token_boxes[: (max_seq_length - special_tokens_count)]
83 |
84 | # add bounding boxes of cls + sep tokens
85 | token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
86 |
87 | encoding = tokenizer(' '.join(words), padding='max_length', truncation=True)
88 |     # Pad token_boxes (the bounding boxes) up to the sequence length.
89 | input_ids = tokenizer(' '.join(words), truncation=True)["input_ids"]
90 | padding_length = max_seq_length - len(input_ids)
91 | token_boxes += [pad_token_box] * padding_length
92 | encoding['bbox'] = token_boxes
93 |
94 | assert len(encoding['input_ids']) == max_seq_length
95 | assert len(encoding['attention_mask']) == max_seq_length
96 | assert len(encoding['token_type_ids']) == max_seq_length
97 | assert len(encoding['bbox']) == max_seq_length
98 |
99 | return encoding
100 |
101 | # we need to define the features ourselves because LayoutLM's bbox input is an extra feature
102 | features = Features({
103 | 'input_ids': Sequence(feature=Value(dtype='int64')),
104 | 'bbox': Array2D(dtype="int64", shape=(512, 4)),
105 | 'attention_mask': Sequence(Value(dtype='int64')),
106 | 'token_type_ids': Sequence(Value(dtype='int64')),
107 | 'image_path': Value(dtype='string'),
108 | 'words': Sequence(feature=Value(dtype='string')),
109 | })
110 |
111 | classes = ["bill", "invoice", "others", "Purchase_Order", "remittance"]
112 |
113 |
114 | # Model Loading
115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116 | @st.cache(allow_output_mutation=True)
117 | def load_model():
118 | url = "https://vast-ml-models.s3-ap-southeast-2.amazonaws.com/Document-Classification-5-labels-final.bin"
119 | r = requests.get(url, allow_redirects=True)
120 | open('saved_model/pytorch_model.bin', 'wb').write(r.content)
121 | model = LayoutLMForSequenceClassification.from_pretrained("saved_model")
122 | return model
123 |
124 | load_model().to(device)
125 |
126 | # Data processing
127 |
128 | st.title('VAST: Document Classifier')
129 | st.header('Upload any document image')
130 | hide_streamlit_style = """
131 | <style>
132 | #MainMenu {visibility: hidden;}
133 | footer {visibility: hidden;}
134 | </style>
135 | """
136 | st.markdown(hide_streamlit_style, unsafe_allow_html=True)
137 |
138 |
139 | image = st.file_uploader('Upload here', type=['jpg', 'png', 'jpeg', 'webp'])
140 |
141 | if image is None:
142 |     st.write("### Please upload a document image")
143 | else:
144 | im = Image.open(image)
145 |     rgb_im = im.convert('RGB')
146 |     os.makedirs('test_data', exist_ok=True)  # make sure the scratch folder exists
147 |     rgb_im.save('test_data/audacious.jpg')
148 | test_data = pd.DataFrame.from_dict({'image_path': ['test_data/audacious.jpg']})
149 | st.image(image, caption='your_doc', use_column_width=True)
150 | if st.button("Process"):
151 | st.spinner()
152 | with st.spinner(text='In progress'):
153 | test_dataset = Dataset.from_pandas(test_data)
154 | updated_test_dataset = test_dataset.map(apply_ocr)
155 | st.success('OCR Done')
156 | encoded_test_dataset = updated_test_dataset.map(lambda example: encode_example(example),
157 | features=features)
158 | encoded_test_dataset.set_format(type='torch', columns=['input_ids', 'bbox', 'attention_mask', 'token_type_ids'])
159 | test_dataloader = torch.utils.data.DataLoader(encoded_test_dataset, batch_size=1, shuffle=True)
160 | test_batch = next(iter(test_dataloader))
161 | st.success('Encoding Data Done')
162 | input_ids = test_batch["input_ids"].to(device)
163 | bbox = test_batch["bbox"].to(device)
164 | attention_mask = test_batch["attention_mask"].to(device)
165 | token_type_ids = test_batch["token_type_ids"].to(device)
166 |
167 | # forward pass
168 | outputs = load_model()(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask,
169 | token_type_ids=token_type_ids)
170 |
171 | classification_logits = outputs.logits
172 | classification_results = torch.softmax(classification_logits, dim=1).tolist()[0]
173 |
174 | # Show JSON output
175 |             thisdict = {}
176 | for i in range(len(classes)):
177 | thisdict[classes[i]] = str(int(round(classification_results[i] * 100))) + "%"
178 | st.json(thisdict)
179 |
180 |             # Show a Plotly bar chart of the per-class prediction percentages
181 |             res_list = []
182 |             res_dict = {"Type of Document": classes,
183 |                         "Prediction Percent": res_list}
184 |             for i in range(len(classes)):
185 |                 res_list.append(classification_results[i] * 100)
186 |             total_dataframe = pd.DataFrame(res_dict)
187 |             state_total_graph = px.bar(
188 |                 total_dataframe,
189 |                 x='Type of Document',
190 |                 y='Prediction Percent',
191 |                 color='Type of Document')
192 |             st.plotly_chart(state_total_graph)
194 |
195 |
196 | st.success('Done')
197 | st.balloons()
198 |
199 |
--------------------------------------------------------------------------------