├── .gitignore ├── LICENSE ├── README.md ├── _images ├── invoicenet.png ├── logo.png └── trainer.png ├── extractor.py ├── install.sh ├── invoicenet ├── __init__.py ├── acp │ ├── __init__.py │ ├── acp.py │ ├── data.py │ └── model.py ├── common │ ├── __init__.py │ ├── data.py │ ├── model.py │ ├── trainer.py │ └── util.py ├── gui │ ├── __init__.py │ ├── custom_widgets.py │ ├── extractor.py │ ├── help_box.py │ ├── trainer.py │ └── viewer.py └── parsing │ ├── __init__.py │ ├── amount_generator.py │ ├── data.py │ ├── date_generator.py │ ├── parser.py │ └── parsers.py ├── models └── parsers │ ├── amount │ ├── best.data-00000-of-00001 │ └── best.index │ └── date │ ├── best.data-00000-of-00001 │ └── best.index ├── predict.py ├── prepare_data.py ├── setup.py ├── train.py ├── train_parser.py ├── trainer.py └── widgets ├── begin.png ├── clear_all.png ├── clear_page.png ├── extract.png ├── first.png ├── fullscreen.png ├── help.png ├── help.txt ├── labels.png ├── last.png ├── logo.png ├── next.png ├── next_file.png ├── ocr.png ├── open_dir.png ├── open_dir_small.png ├── open_file.png ├── options.png ├── prepare.png ├── prev.png ├── prev_file.png ├── rotate.png ├── save_as.png ├── search.png ├── stop.png ├── text.png ├── text_focus.png ├── toolbar.png ├── zoomin.png └── zoomout.png /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.pyc 3 | .DS_Store 4 | invoicenet/acp/data/ 5 | invoicenet/parsing/data/ 6 | */__pycache__/* 7 | env/ 8 | data/ 9 | predictions/ 10 | processed_data/ 11 | invoices/ 12 | 13 | **/.vscode/ 14 | *.swp 15 | **/*.egg-info/ 16 | venv-*/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sarthak Mittal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files 
(the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![InvoiceNet Logo](_images/logo.png) 2 | 3 | -------------------------------------------------------------------------------- 4 | 5 | Deep neural network to extract intelligent information from invoice documents. 6 | 7 | **TL;DR** 8 | 9 | * An easy to use UI to view PDF/JPG/PNG invoices and extract information. 10 | * Train custom models using the Trainer UI on your own dataset. 11 | * Add or remove invoice fields as per your convenience. 12 | * Save the extracted information into your system with the click of a button. 13 | 14 | :star: We appreciate your star, it helps! 15 | 16 | The InvoiceNet logo was designed by [Sidhant Tibrewal](https://www.linkedin.com/in/sidhant-tibrewal-864058148/). 17 | [Check out](https://www.behance.net/tiber_sid) his work for some more beautiful designs. 
18 | 19 | --- 20 | 21 | ![InvoiceNet](_images/invoicenet.png) 22 | 23 | --- 24 | 25 | **DISCLAIMER**: 26 | 27 | Pre-trained models for some general invoice fields are not available right now but will soon be provided. 28 | The training GUI and data preparation scripts have been made available. 29 | 30 | Invoice documents contain sensitive information, which has made collecting a sizable dataset difficult. 31 | This makes it difficult for developers like us to train large-scale generalised models and make them available to the community. 32 | 33 | If you have a dataset of invoice documents that you are comfortable sharing with us, please reach out. 34 | We have the tools to create the first publicly-available large-scale invoice dataset along with a software platform for structured information extraction. 35 | 36 | --- 37 | 38 | ## Installation 39 | 40 | #### Ubuntu 20.04 41 | 42 | InvoiceNet has been developed and tested on **Ubuntu 20.04** with **CUDA Version: 11.8**, **cuDNN version: 8.9.7**, and **TensorFlow v2.13.1**. 43 | 44 | To install InvoiceNet on Ubuntu, run the following commands: 45 | 46 | ```bash 47 | git clone https://github.com/naiveHobo/InvoiceNet.git 48 | cd InvoiceNet/ 49 | 50 | # Run installation script 51 | ./install.sh 52 | ``` 53 | 54 | The install.sh script will install all the dependencies, create a virtual environment, and install InvoiceNet in the virtual environment. 55 | 56 | To be able to use InvoiceNet, you need to source the virtual environment that the package was installed in. 
57 | 58 | ```bash 59 | # Source virtual environment 60 | source env/bin/activate 61 | ``` 62 | 63 | #### Windows 10 64 | 65 | The recommended way is to install InvoiceNet along with its dependencies in an Anaconda environment: 66 | 67 | ```bash 68 | git clone https://github.com/naiveHobo/InvoiceNet.git 69 | cd InvoiceNet/ 70 | 71 | # Create conda environment and activate 72 | conda create --name invoicenet python=3.7 73 | conda activate invoicenet 74 | 75 | # Install InvoiceNet 76 | pip install . 77 | 78 | # Install poppler 79 | conda install -c conda-forge poppler 80 | ``` 81 | 82 | Some dependencies also need to be installed separately on Windows 10 before running InvoiceNet: 83 | 84 | - [Tesseract 5.0.0](https://github.com/UB-Mannheim/tesseract/wiki) 85 | - [ImageMagick 7.0.10](https://imagemagick.org/script/download.php#windows) 86 | - [Ghostscript 9.52](https://www.ghostscript.com/download/gsdnld.html) 87 | 88 | 89 | 90 | ## Data Preparation 91 | The training data must be arranged in a single directory. The invoice documents must be PDF files, and each invoice must have a corresponding JSON label file with the same base name. Your training data should be in the following format: 92 | 93 | ``` 94 | train_data/ 95 | invoice1.pdf 96 | invoice1.json 97 | nike-invoice.pdf 98 | nike-invoice.json 99 | 12345.pdf 100 | 12345.json 101 | ... 102 | ``` 103 | 104 | The JSON labels should have the following format: 105 | ``` 106 | { 107 | "vendor_name":"Nike", 108 | "invoice_date":"12-01-2017", 109 | "invoice_number":"R0007546449", 110 | "total_amount":"137.51", 111 | ... other fields 112 | } 113 | ``` 114 | 115 | To begin the data preparation process, click on the "Prepare Data" button in the GUI or follow the instructions below if you're using the CLI. 116 | 117 | 118 | ## Add Your Own Fields 119 | To add your own fields to InvoiceNet, open **invoicenet/\_\_init\_\_.py**. 
120 | 121 | There are 4 pre-defined field types: 122 | - **FIELD_TYPES["general"]** : General fields such as names, addresses, invoice numbers, etc. 123 | - **FIELD_TYPES["optional"]** : Optional fields that might not be present in all invoices. 124 | - **FIELD_TYPES["amount"]** : Fields that represent an amount. 125 | - **FIELD_TYPES["date"]** : Fields that represent a date. 126 | 127 | Choose the appropriate field type for your field and add a line like the ones below. 128 | 129 | ```python 130 | # Add the following line at the end of the file 131 | 132 | # For example, to add a field total_amount 133 | FIELDS["total_amount"] = FIELD_TYPES["amount"] 134 | 135 | # For example, to add a field invoice_date 136 | FIELDS["invoice_date"] = FIELD_TYPES["date"] 137 | 138 | # For example, to add a field tax_id (which might be optional) 139 | FIELDS["tax_id"] = FIELD_TYPES["optional"] 140 | 141 | # For example, to add a field vendor_name 142 | FIELDS["vendor_name"] = FIELD_TYPES["general"] 143 | ``` 144 | 145 | 146 | ## Using the GUI 147 | InvoiceNet provides a GUI to train a model on your data and to extract information from invoice documents using the trained model. 148 | 149 | ![Trainer](_images/trainer.png) 150 | 151 | 152 | Run the following command to launch the trainer GUI: 153 | 154 | ```bash 155 | python trainer.py 156 | ``` 157 | 158 | Run the following command to launch the extractor GUI: 159 | 160 | ```bash 161 | python extractor.py 162 | ``` 163 | 164 | You need to prepare the data for training first. 165 | You can do so by setting the **Data Folder** field to the directory containing your training data and then clicking the **Prepare Data** button. 166 | Once the data is prepared, you can start training by clicking the **Start** button. 
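Data preparation expects every PDF in the data folder to have a JSON label file with the same name, so it can help to check the folder before clicking **Prepare Data**. The following is a minimal sketch under stated assumptions, not part of InvoiceNet: `find_unlabeled_invoices` is a hypothetical helper, and it assumes a flat directory like the `train_data/` layout shown under Data Preparation.

```python
from pathlib import Path


def find_unlabeled_invoices(data_dir):
    """Return names of PDFs in data_dir that have no same-named JSON label file.

    Hypothetical helper for illustration only; not part of InvoiceNet.
    Assumes a flat directory (sub-directories are not searched).
    """
    root = Path(data_dir)
    return sorted(
        pdf.name
        for pdf in root.glob("*.pdf")
        # invoice1.pdf is considered labeled only if invoice1.json exists
        if not pdf.with_suffix(".json").is_file()
    )
```

Calling `find_unlabeled_invoices("train_data/")` returns an empty list when every invoice has a label file; any names it returns are PDFs that still need a JSON label.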
167 | 168 | 169 | ## Using the CLI 170 | 171 | ### Training 172 | 173 | Prepare the data for training first by running the following command: 174 | ```bash 175 | python prepare_data.py --data_dir train_data/ 176 | ``` 177 | 178 | Train InvoiceNet using the following command: 179 | ```bash 180 | python train.py --field enter-field-here --batch_size 8 181 | 182 | # For example, for field 'total_amount' 183 | python train.py --field total_amount --batch_size 8 184 | ``` 185 | 186 | --- 187 | 188 | ### Prediction 189 | To use a different OCR engine, change `ocr_engine` in the [create_ngrams](https://github.com/naiveHobo/InvoiceNet/blob/e883158a690726afd1de5b76b5810287013577c6/invoicenet/common/util.py#L193) function in `invoicenet/common/util.py` before running predict.py. 190 | 191 | --- 192 | 193 | #### Single invoice 194 | To extract a field from a single invoice file, run the following command: 195 | ```bash 196 | python predict.py --field enter-field-here --invoice path-to-invoice-file 197 | 198 | # For example, to extract field total_amount from an invoice file invoices/1.pdf 199 | python predict.py --field total_amount --invoice invoices/1.pdf 200 | ``` 201 | 202 | --- 203 | 204 | #### Multiple invoices 205 | To extract information from multiple invoices with a trained InvoiceNet model, place the PDF invoice documents in a single directory, in the following format: 206 | 207 | ``` 208 | predict_data/ 209 | invoice1.pdf 210 | invoice2.pdf 211 | ... 212 | ``` 213 | 214 | Run InvoiceNet using the following command: 215 | ```bash 216 | python predict.py --field enter-field-here --data_dir predict_data/ 217 | 218 | # For example, for field 'total_amount' 219 | python predict.py --field total_amount --data_dir predict_data/ 220 | ``` 221 | --- 222 | 223 | ## Reference 224 | This implementation is largely based on the work of R. 
Palm et al, who should be cited if this is used in a scientific publication (or the preceding conference papers): 225 | 226 | [1] Palm, Rasmus Berg, Florian Laws, and Ole Winther. **"Attend, Copy, Parse End-to-end information extraction from documents."** 2019 International Conference on Document Analysis and Recognition (ICDAR). IEEE, 2019. 227 | 228 | ```bibtex 229 | @inproceedings{palm2019attend, 230 | title={Attend, Copy, Parse End-to-end information extraction from documents}, 231 | author={Palm, Rasmus Berg and Laws, Florian and Winther, Ole}, 232 | booktitle={2019 International Conference on Document Analysis and Recognition (ICDAR)}, 233 | pages={329--336}, 234 | year={2019}, 235 | organization={IEEE} 236 | } 237 | ``` 238 | 239 | ### Note 240 | An implementation of an inferior (also slightly broken) invoice handling system based on the paper **"Cloudscan - A configuration-free invoice analysis system using recurrent neural networks."** is available [here](https://github.com/naiveHobo/InvoiceNet/tree/cloudscan). 241 | 242 | [2] Palm, Rasmus Berg, Ole Winther, and Florian Laws. **"Cloudscan - A configuration-free invoice analysis system using recurrent neural networks."** 2017 14th IAPR International Conference on Document Analysis and Recognition (ICDAR). Vol. 1. IEEE, 2017. 
243 | 244 | ```bibtex 245 | @inproceedings{palm2017cloudscan, 246 | title={Cloudscan-a configuration-free invoice analysis system using recurrent neural networks}, 247 | author={Palm, Rasmus Berg and Winther, Ole and Laws, Florian}, 248 | booktitle={2017 14th IAPR International Conference on Document Analysis and Recognition (ICDAR)}, 249 | volume={1}, 250 | pages={406--413}, 251 | year={2017}, 252 | organization={IEEE} 253 | } 254 | ``` 255 | -------------------------------------------------------------------------------- /_images/invoicenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/_images/invoicenet.png -------------------------------------------------------------------------------- /_images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/_images/logo.png -------------------------------------------------------------------------------- /_images/trainer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/_images/trainer.png -------------------------------------------------------------------------------- /extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do 
so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | from tkinter import Tk 22 | from invoicenet.gui.extractor import Extractor 23 | 24 | 25 | def main(): 26 | root = Tk() 27 | Extractor(root) 28 | root.mainloop() 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # install dependencies 4 | sudo apt install -y tesseract-ocr poppler-utils libxext-dev libsm-dev libxrender-dev 5 | 6 | # create virtual environment 7 | virtualenv env -p python3 8 | source env/bin/activate 9 | 10 | pip install . 
11 | -------------------------------------------------------------------------------- /invoicenet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | FIELD_TYPES = { 22 | "general": 0, 23 | "optional": 1, 24 | "amount": 2, 25 | "date": 3 26 | } 27 | 28 | FIELDS = dict() 29 | 30 | FIELDS["invoice_number"] = FIELD_TYPES["general"] 31 | FIELDS["vendor_name"] = FIELD_TYPES["general"] 32 | 33 | FIELDS["invoice_date"] = FIELD_TYPES["date"] 34 | 35 | FIELDS["net_amount"] = FIELD_TYPES["amount"] 36 | FIELDS["tax_amount"] = FIELD_TYPES["amount"] 37 | FIELDS["total_amount"] = FIELD_TYPES["amount"] 38 | -------------------------------------------------------------------------------- /invoicenet/acp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | -------------------------------------------------------------------------------- /invoicenet/acp/acp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | import os 23 | 24 | import tensorflow as tf 25 | 26 | from ..common.model import Model 27 | from .data import InvoiceData 28 | from .model import AttendCopyParseModel 29 | from .. 
import FIELD_TYPES, FIELDS 30 | from ..parsing.parsers import DateParser, AmountParser, NoOpParser, OptionalParser 31 | 32 | 33 | class AttendCopyParse(Model): 34 | 35 | def __init__(self, field, restore=False): 36 | self.field = field 37 | 38 | self.restore_all_path = './models/invoicenet/{}/best'.format(self.field) if restore else None 39 | os.makedirs("./models/invoicenet", exist_ok=True) 40 | 41 | if FIELDS[field] == FIELD_TYPES["optional"]: 42 | noop_parser = NoOpParser() 43 | parser = OptionalParser(noop_parser, 128) 44 | elif FIELDS[field] == FIELD_TYPES["amount"]: 45 | parser = AmountParser() 46 | elif FIELDS[field] == FIELD_TYPES["date"]: 47 | parser = DateParser() 48 | else: 49 | parser = NoOpParser() 50 | 51 | restore = parser.restore() 52 | if restore is not None: 53 | print("Restoring %s parser %s..." % (self.field, restore)) 54 | tf.train.Checkpoint(model=parser).read(restore).expect_partial() 55 | 56 | self.model = AttendCopyParseModel(parser=parser) 57 | 58 | self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy( 59 | from_logits=True, 60 | reduction=tf.keras.losses.Reduction.NONE) 61 | 62 | self.optimizer = tf.keras.optimizers.Nadam(learning_rate=3e-4) 63 | 64 | self.model.compile(self.optimizer) 65 | 66 | self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model) 67 | 68 | if self.restore_all_path: 69 | if not os.path.exists('./models/invoicenet/{}'.format(self.field)): 70 | raise Exception("No trained model available for the field '{}'".format(self.field)) 71 | print("Restoring all " + self.restore_all_path + "...") 72 | self.checkpoint.read(self.restore_all_path).expect_partial() 73 | 74 | def loss_func(self, y_true, y_pred): 75 | mask = tf.cast(tf.logical_not(tf.equal(y_true, InvoiceData.pad_idx)), dtype=tf.float32) # (bs, seq) 76 | label_cross_entropy = tf.reduce_sum( 77 | self.loss_object(y_true, y_pred) * mask, axis=1) / tf.reduce_sum(mask, axis=1) 78 | field_loss = tf.reduce_mean(label_cross_entropy) 79 | 
loss = field_loss + sum(self.model.losses) 80 | return loss 81 | 82 | @tf.function 83 | def train_step(self, inputs): 84 | inputs, targets = inputs[:-1], inputs[-1] 85 | with tf.GradientTape() as tape: 86 | predictions = self.model(inputs, training=True) 87 | loss = self.loss_func(targets, predictions) 88 | gradients = tape.gradient(loss, self.model.trainable_variables) 89 | self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) 90 | return loss 91 | 92 | @tf.function 93 | def val_step(self, inputs): 94 | inputs, targets = inputs[:-1], inputs[-1] 95 | predictions = self.model(inputs, training=False) 96 | loss = self.loss_func(targets, predictions) 97 | return loss 98 | 99 | def predict(self, paths): 100 | data = InvoiceData(field=self.field) 101 | shapes, types = data.shapes()[:-1], data.types()[:-1] 102 | 103 | def _transform(i, v, s, *args): 104 | return (tf.SparseTensor(i, v, s),) + args 105 | 106 | dataset = tf.data.Dataset.from_generator( 107 | data.generate_test_data(paths), 108 | types, 109 | shapes 110 | ).map(_transform) \ 111 | .batch(batch_size=1, drop_remainder=False) 112 | 113 | predictions = [] 114 | for sample in dataset: 115 | try: 116 | logits = self.model(sample, training=False) 117 | chars = tf.argmax(logits, axis=2, output_type=tf.int32).numpy() 118 | predictions.extend(data.array_to_str(chars)) 119 | except tf.errors.OutOfRangeError: 120 | break 121 | 122 | return predictions 123 | 124 | def save(self, name): 125 | self.checkpoint.write(file_prefix="./models/invoicenet/%s/%s" % (self.field, name)) 126 | 127 | def load(self, name): 128 | self.checkpoint.read(name) 129 | -------------------------------------------------------------------------------- /invoicenet/acp/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # 
of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | import re 23 | import glob 24 | import hashlib 25 | import json 26 | import random 27 | import string 28 | import pdf2image 29 | 30 | import numpy as np 31 | import tensorflow as tf 32 | from PIL import Image 33 | from decimal import Decimal 34 | 35 | from .. 
import FIELDS, FIELD_TYPES 36 | from ..common import util 37 | from ..common.data import Data, UnkDict 38 | 39 | random.seed(0) 40 | 41 | 42 | class InvoiceData(Data): 43 | 44 | im_size = 128, 128 45 | chars = ['<PAD>', '<EOS>', '<UNK>'] + list(string.printable) 46 | output_dict = UnkDict(chars) 47 | n_output = len(output_dict) 48 | pad_idx = 0 49 | eos_idx = 1 50 | unk_idx = 2 51 | word_hash_size = 2 ** 16 52 | pattern_hash_size = 2 ** 14 53 | seq_in = 4 * 32 54 | seq_amount = 16 55 | seq_date = 11 56 | seq_long = 128 57 | 58 | seq_out = { 59 | FIELD_TYPES["general"]: seq_long, 60 | FIELD_TYPES["optional"]: seq_long, 61 | FIELD_TYPES["amount"]: seq_amount, 62 | FIELD_TYPES["date"]: seq_date 63 | } 64 | 65 | n_memories = 4 66 | parses_idx = {'date': 0, 'amount': 1} 67 | 68 | def __init__(self, field, data_dir=None): 69 | self.field = field 70 | self.filenames = [] 71 | if data_dir: 72 | self.filenames = glob.glob(data_dir + "**/*.json", recursive=True) 73 | 74 | def types(self): 75 | return ( 76 | tf.int64, # i 77 | tf.float32, # v 78 | tf.int64, # s 79 | tf.float32, # pixels 80 | tf.int32, # word_indices 81 | tf.int32, # pattern_indices 82 | tf.int32, # char_indices 83 | tf.float32, # memory_mask 84 | tf.float32, # parses 85 | tf.int32 # target 86 | ) 87 | 88 | def shapes(self): 89 | return ( 90 | (None, 5), # i 91 | (None,), # v 92 | (None,), # s 93 | InvoiceData.im_size + (3,), # pixels 94 | InvoiceData.im_size, # word_indices 95 | InvoiceData.im_size, # pattern_indices 96 | InvoiceData.im_size, # char_indices 97 | InvoiceData.im_size, # memory_mask 98 | InvoiceData.im_size + (self.n_memories, 2), # parses 99 | (InvoiceData.seq_out[FIELDS[self.field]],) # target 100 | ) 101 | 102 | def _encode_ngrams(self, n_grams, height, width): 103 | v_ar = self.im_size[0] / height 104 | h_ar = self.im_size[1] / width 105 | 106 | max_v = self.im_size[0] - 1 107 | max_h = self.im_size[1] - 1 108 | 109 | pattern_indices = np.zeros(self.im_size) 110 | word_indices = np.zeros(self.im_size, 
np.int32) 111 | char_indices = np.zeros(self.im_size, np.int32) 112 | memory_mask = np.zeros(self.im_size, np.float32) 113 | 114 | parses = np.zeros(self.im_size + (self.n_memories, 2)) 115 | memory_indices = [] 116 | for n_gram in n_grams: 117 | words = n_gram["words"] 118 | m_idx = len(words) - 1 119 | word = words[0] 120 | 121 | left = min(round(word['left'] * h_ar), max_h) 122 | right = min(round(word['right'] * h_ar), max_h) 123 | top = min(round(word['top'] * v_ar), max_v) 124 | bottom = min(round(word['bottom'] * v_ar), max_v) 125 | 126 | for k, v in n_gram['parses'].items(): 127 | parses[top:bottom + 1, left:right + 1, m_idx, self.parses_idx[k]] = 1.0 128 | 129 | chars = " ".join([w['text'] for w in words])[:self.seq_in - 1] 130 | char_idx = [self.output_dict[c] for c in chars] + [self.eos_idx] 131 | char_pos = range(len(char_idx)) 132 | 133 | self.append_indices(top, bottom, left, right, m_idx, char_idx, char_pos, memory_indices) 134 | 135 | if len(words) == 1: 136 | text = words[0]['text'] 137 | memory_mask[top, left] = 1.0 138 | 139 | pattern = text 140 | pattern = re.sub(r"[A-Z]", "X", pattern) 141 | pattern = re.sub(r"[a-z]", "x", pattern) 142 | pattern = re.sub(r"[0-9]", "0", pattern) 143 | pattern = re.sub(r"[^Xx0]", "-", pattern) 144 | 145 | pattern_idx = (int(hashlib.md5(str.encode(pattern)).hexdigest(), 16) % (self.pattern_hash_size - 1)) + 1 146 | pattern_indices[top:bottom + 1, left:right + 1] = pattern_idx 147 | 148 | w_idx = (int(hashlib.md5(str.encode(text)).hexdigest(), 16) % (self.word_hash_size - 1)) + 1 149 | word_indices[top:bottom + 1, left:right + 1] = w_idx 150 | 151 | for cidx, p in zip(char_idx[:-1], np.linspace(left, right, len(char_idx[:-1]))): 152 | char_indices[top:bottom + 1, int(round(p))] = cidx 153 | 154 | assert len(memory_indices) > 0 155 | memory_values = [1.] 
* len(memory_indices) 156 | memory_dense_shape = self.im_size + (self.n_memories, self.seq_in, self.n_output) 157 | 158 | return ( 159 | word_indices, 160 | pattern_indices, 161 | char_indices, 162 | memory_mask, 163 | parses, 164 | memory_indices, 165 | memory_values, 166 | memory_dense_shape 167 | ) 168 | 169 | def append_indices(self, top, bottom, left, right, m_idx, char_idx, char_pos, indices): 170 | assert 0 <= m_idx < self.n_memories, m_idx 171 | 172 | assert top <= bottom, (top, bottom) 173 | assert left <= right, (left, right) 174 | assert 0 <= top < self.im_size[0], top 175 | assert 0 <= bottom < self.im_size[0], bottom 176 | assert 0 <= left < self.im_size[1], left 177 | assert 0 <= right < self.im_size[1], right 178 | 179 | for cp_idx, ci_idx in zip(char_pos, char_idx): 180 | assert 0 <= cp_idx < self.seq_in, cp_idx 181 | assert 0 <= ci_idx < self.n_output, ci_idx 182 | 183 | for cp_idx, ci_idx in zip(char_pos, char_idx): 184 | indices.append((top, left, m_idx, cp_idx, ci_idx)) 185 | 186 | def encode_image(self, page): 187 | im = Image.open(page["filename"]) 188 | im = im.convert('RGB').resize(self.im_size[::-1], Image.Resampling.LANCZOS) 189 | pixels = (np.asarray(im, np.float32) / 255. - 0.5) * 2. 
190 | return pixels 191 | 192 | @staticmethod 193 | def _preprocess_amount(value): 194 | return '{:f}'.format(Decimal(value).normalize()) 195 | 196 | def _load_document(self, doc_id): 197 | with open(doc_id, encoding="utf8") as fp: 198 | page = json.load(fp) 199 | 200 | pixels = self.encode_image(page) 201 | n_grams = page['nGrams'] 202 | 203 | word_indices, pattern_indices, char_indices, memory_mask, parses, i, v, s = self._encode_ngrams(n_grams, 204 | page['height'], 205 | page['width']) 206 | 207 | target = page['fields'][self.field] 208 | if FIELDS[self.field] == FIELD_TYPES["amount"]: 209 | target = self._preprocess_amount(target) 210 | target = InvoiceData.encode_sequence(target, self.seq_out[FIELDS[self.field]]) 211 | 212 | return i, v, s, pixels, word_indices, pattern_indices, char_indices, memory_mask, parses, target 213 | 214 | def array_to_str(self, arr): 215 | """ 216 | :param arr: (bs, seq) int32 217 | """ 218 | strs = [] 219 | for r in arr: 220 | s = "" 221 | for c in r: 222 | if c == self.eos_idx: 223 | break 224 | else: 225 | s += self.output_dict.idx2key(c) 226 | strs.append(s) 227 | return strs 228 | 229 | def sample_generator(self): 230 | exceptions = 0 231 | np.random.seed(0) 232 | random.shuffle(self.filenames) 233 | 234 | for i, doc_id in enumerate(self.filenames): 235 | try: 236 | yield self._load_document(doc_id.strip()) 237 | except GeneratorExit: 238 | return 239 | except Exception as exp: 240 | print("Exception: {} : {}".format(doc_id, exp)) 241 | exceptions += 1 242 | 243 | def _process_pdf(self, path): 244 | pixels = pdf2image.convert_from_path(path)[0] 245 | height = pixels.size[1] 246 | width = pixels.size[0] 247 | 248 | ngrams = util.create_ngrams(pixels, height, width) 249 | for ngram in ngrams: 250 | if "amount" in ngram["parses"]: 251 | ngram["parses"]["amount"] = util.normalize(ngram["parses"]["amount"], key="amount") 252 | if "date" in ngram["parses"]: 253 | ngram["parses"]["date"] = util.normalize(ngram["parses"]["date"], 
key="date") 254 | 255 | page = { 256 | "nGrams": ngrams, 257 | "height": height, 258 | "width": width, 259 | "filename": path 260 | } 261 | 262 | pixels = pixels.convert('RGB').resize(self.im_size[::-1], Image.Resampling.LANCZOS) 263 | pixels = (np.asarray(pixels, np.float32) / 255. - 0.5) * 2. 264 | 265 | n_grams = page['nGrams'] 266 | 267 | word_indices, pattern_indices, char_indices, memory_mask, parses, i, v, s = self._encode_ngrams(n_grams, 268 | page['height'], 269 | page['width']) 270 | 271 | return i, v, s, pixels, word_indices, pattern_indices, char_indices, memory_mask, parses 272 | 273 | def generate_test_data(self, paths: list): 274 | if not isinstance(paths, list): 275 | raise Exception("This function assumes the input is a list of paths") 276 | 277 | def _generator(): 278 | exceptions = 0 279 | for idx, path in enumerate(paths): 280 | try: 281 | yield self._process_pdf(path) 282 | except Exception as exp: 283 | print("Exception: {} : {}".format(path, exp)) 284 | exceptions += 1 285 | 286 | return _generator 287 | 288 | @staticmethod 289 | def encode_sequence(value, max_len): 290 | encoded = [InvoiceData.output_dict[c] for c in list(value)[:max_len - 1]] + [InvoiceData.eos_idx] 291 | encoded += [InvoiceData.pad_idx] * (max_len - len(encoded)) 292 | return encoded 293 | 294 | @staticmethod 295 | def create_dataset(data_dir, field, batch_size): 296 | data = InvoiceData(field=field, data_dir=data_dir) 297 | shapes, types = data.shapes(), data.types() 298 | 299 | def _transform(i, v, s, *args): 300 | return (tf.SparseTensor(i, v, s),) + args 301 | 302 | return tf.data.Dataset.from_generator( 303 | data.sample_generator, 304 | types, 305 | shapes 306 | ).map(_transform) \ 307 | .repeat(-1) \ 308 | .batch(batch_size=batch_size, drop_remainder=True) \ 309 | .prefetch(2) 310 | -------------------------------------------------------------------------------- /invoicenet/acp/model.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
21 | 22 | import tensorflow as tf 23 | 24 | from .data import InvoiceData 25 | 26 | 27 | class DilatedConvBlock(tf.keras.layers.Layer): 28 | 29 | def __init__(self, filters): 30 | super(DilatedConvBlock, self).__init__() 31 | self.conv_layers = [ 32 | tf.keras.layers.Conv2D( 33 | filters=filters, 34 | kernel_size=3, 35 | padding='same', 36 | dilation_rate=rate, 37 | kernel_regularizer=tf.keras.regularizers.l2(1e-4) 38 | ) 39 | for rate in [1, 2, 4, 8] 40 | ] 41 | 42 | def call(self, inputs, training=None, mask=None): 43 | return tf.concat([conv(inputs) for conv in self.conv_layers], axis=3) 44 | 45 | 46 | class AttendBlock(tf.keras.layers.Layer): 47 | 48 | def __init__(self, embed_size=32, frac_ce_loss=0.0001): 49 | super(AttendBlock, self).__init__() 50 | 51 | self.frac_ce_loss = frac_ce_loss 52 | self.embed_size = embed_size 53 | self.word_embed = tf.keras.layers.Embedding( 54 | input_dim=InvoiceData.word_hash_size, 55 | output_dim=embed_size, 56 | input_length=InvoiceData.im_size[0] * InvoiceData.im_size[1], 57 | name="word_embeddings") 58 | self.pattern_embed = tf.keras.layers.Embedding( 59 | input_dim=InvoiceData.pattern_hash_size, 60 | output_dim=embed_size, 61 | input_length=InvoiceData.im_size[0] * InvoiceData.im_size[1], 62 | name="pattern_embeddings") 63 | self.char_embed = tf.keras.layers.Embedding( 64 | input_dim=InvoiceData.n_output, 65 | output_dim=embed_size, 66 | input_length=InvoiceData.im_size[0] * InvoiceData.im_size[1], 67 | name="char_embeddings") 68 | 69 | self.conv_block = tf.keras.Sequential() 70 | for _ in range(4): 71 | self.conv_block.add(DilatedConvBlock(embed_size)) 72 | 73 | self.dropout = tf.keras.layers.Dropout(0.5) 74 | self.conv_att = tf.keras.layers.Conv2D( 75 | filters=InvoiceData.n_memories, 76 | kernel_size=3, 77 | padding='same', 78 | kernel_regularizer=tf.keras.regularizers.l2(1e-4) 79 | ) 80 | 81 | def call(self, inputs, training=None, mask=None): 82 | pixels, word_indices, pattern_indices, char_indices, memory_mask, parses 
= inputs 83 | 84 | # pixels: (bs, h, w) 85 | # word_indices: (bs, h, w) 86 | # pattern_indices: (bs, h, w) 87 | # char_indices: (bs, h, w) 88 | # memory_mask: (bs, h, w, m, l, d) 89 | # parses: (bs, h, w, 4, 2) 90 | 91 | bs = tf.shape(pixels)[0] 92 | h, w = InvoiceData.im_size[0], InvoiceData.im_size[1] 93 | 94 | X, Y = tf.meshgrid(tf.linspace(0.0, 1.0, InvoiceData.im_size[1]), tf.linspace(0.0, 1.0, InvoiceData.im_size[0])) 95 | X = tf.tile(X[None, ..., None], (bs, 1, 1, 1)) 96 | Y = tf.tile(Y[None, ..., None], (bs, 1, 1, 1)) 97 | 98 | word_embeddings = tf.reshape( 99 | self.word_embed(tf.reshape(word_indices, (bs, -1))), 100 | (bs, h, w, self.embed_size) 101 | ) 102 | 103 | pattern_embeddings = tf.reshape( 104 | self.pattern_embed(tf.reshape(pattern_indices, (bs, -1))), 105 | (bs, h, w, self.embed_size) 106 | ) 107 | 108 | char_embeddings = tf.reshape( 109 | self.char_embed(tf.reshape(char_indices, (bs, -1))), 110 | (bs, h, w, self.embed_size) 111 | ) 112 | 113 | pixels = tf.reshape(pixels, (bs, h, w, 3)) 114 | parses = tf.reshape(parses, (bs, h, w, InvoiceData.n_memories * 2)) 115 | memory_mask = tf.reshape(memory_mask, (bs, h, w, 1)) 116 | x = tf.concat([pixels, word_embeddings, pattern_embeddings, char_embeddings, parses, X, Y, memory_mask], 117 | axis=3) 118 | 119 | x = self.conv_block(x) 120 | x = self.dropout(x, training=training) 121 | 122 | pre_att_logits = x 123 | att_logits = self.conv_att(x) # (bs, h, w, n_memories) 124 | att_logits = memory_mask * att_logits - ( 125 | 1.0 - memory_mask) * 1000 # TODO only sum the memory_mask idx, in the softmax 126 | 127 | logits = tf.reshape(att_logits, (bs, -1)) # (bs, h * w * n_memories) 128 | logits -= tf.reduce_max(logits, axis=1, keepdims=True) 129 | lp = tf.math.log_softmax(logits, axis=1) # (bs, h * w * n_memories) 130 | p = tf.math.softmax(logits, axis=1) # (bs, h * w * n_memories) 131 | 132 | spatial_attention = tf.reshape(p, (bs, h * w * InvoiceData.n_memories, 1, 1)) # (bs, h * w * n_memories, 1, 1) 133 | 
134 | p_uniform = memory_mask / tf.reduce_sum(memory_mask, axis=(1, 2, 3), keepdims=True) 135 | cross_entropy_uniform = -tf.reduce_sum(p_uniform * tf.reshape(lp, (bs, h, w, InvoiceData.n_memories)), 136 | axis=(1, 2, 3)) # (bs, 1) 137 | 138 | cp = tf.reduce_sum(tf.reshape(p, (bs, h, w, InvoiceData.n_memories)), axis=3, keepdims=True) 139 | 140 | context = tf.reduce_sum(cp * pre_att_logits, axis=(1, 2)) # (bs, 4*n_hidden) 141 | 142 | self.add_loss(self.frac_ce_loss * tf.reduce_mean(cross_entropy_uniform)) 143 | 144 | return spatial_attention, context 145 | 146 | 147 | class AttendCopyParseModel(tf.keras.Model): 148 | """ 149 | You should pre-train this parser to parse dates otherwise it's hard to learn jointly. 150 | """ 151 | def __init__(self, parser): 152 | super(AttendCopyParseModel, self).__init__() 153 | self.parser = parser 154 | self.attend = AttendBlock(32) 155 | 156 | def call(self, inputs, training=None, mask=None): 157 | memories, pixels, word_indices, pattern_indices, char_indices, memory_mask, parses = inputs 158 | 159 | spatial_attention, context = self.attend(inputs=(pixels, 160 | word_indices, 161 | pattern_indices, 162 | char_indices, 163 | memory_mask, 164 | parses), 165 | training=training) 166 | 167 | # Copy 168 | memories = tf.sparse.reshape(memories, 169 | (-1, InvoiceData.im_size[0] * InvoiceData.im_size[1] * InvoiceData.n_memories, 170 | InvoiceData.seq_in, 171 | InvoiceData.n_output)) 172 | x = tf.reshape(tf.sparse.reduce_sum(spatial_attention * memories, axis=1), 173 | (-1, InvoiceData.seq_in, InvoiceData.n_output)) # (bs, seq_in, n_out) 174 | 175 | # Parse 176 | parsed = self.parser(inputs=(x, context), training=training) 177 | return parsed 178 | -------------------------------------------------------------------------------- /invoicenet/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to 
any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | -------------------------------------------------------------------------------- /invoicenet/common/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 
13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | 23 | class Data: 24 | def sample_generator(self): 25 | raise NotImplementedError 26 | 27 | def types(self): 28 | raise NotImplementedError 29 | 30 | def shapes(self): 31 | raise NotImplementedError 32 | 33 | def array_to_str(self, arr): 34 | raise NotImplementedError 35 | 36 | 37 | class UnkDict: 38 | unk = '<UNK>' 39 | 40 | def __init__(self, items): 41 | if self.unk not in items: 42 | raise ValueError("items must contain %s" % self.unk) 43 | 44 | self.delegate = dict([(c, i) for i, c in enumerate(items)]) 45 | self.rdict = {i: c for c, i in self.delegate.items()} 46 | 47 | def __getitem__(self, item): 48 | if item in self.delegate: 49 | return self.delegate[item] 50 | else: 51 | return self.delegate[self.unk] 52 | 53 | def __len__(self): 54 | return len(self.delegate) 55 | 56 | def idx2key(self, idx): 57 | return self.rdict[idx] 58 | -------------------------------------------------------------------------------- /invoicenet/common/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the 
Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | 23 | class Model: 24 | 25 | def train_step(self, inputs): 26 | raise NotImplementedError() 27 | 28 | def val_step(self, inputs): 29 | raise NotImplementedError 30 | 31 | def load(self, name): 32 | raise NotImplementedError 33 | 34 | def save(self, name): 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /invoicenet/common/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 
13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | import time 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | from invoicenet.common.model import Model 27 | 28 | 29 | def train(model: Model, 30 | train_data: tf.data.Dataset, 31 | val_data: tf.data.Dataset, 32 | total_steps=50000, 33 | early_stop_steps=0): 34 | 35 | print_interval = 20 36 | no_improvement_steps = 0 37 | best = float("inf") 38 | 39 | train_iter = iter(train_data) 40 | val_iter = iter(val_data) 41 | 42 | start = time.time() 43 | for step in range(total_steps): 44 | try: 45 | train_loss = model.train_step(next(train_iter)) 46 | except StopIteration: 47 | print("Couldn't find any training data! Have you prepared your training data?") 48 | print("Terminating...") 49 | break 50 | 51 | if not np.isfinite(train_loss): 52 | raise ValueError("NaN loss") 53 | 54 | if step % print_interval == 0: 55 | took = time.time() - start 56 | 57 | try: 58 | val_loss = model.val_step(next(val_iter)) 59 | except StopIteration: 60 | print("Couldn't find any validation data! 
Have you prepared your training data?") 61 | print("Terminating...") 62 | break 63 | 64 | print("[%d/%d | %.2f steps/s]: train loss: %.4f val loss: %.4f" % ( 65 | step, total_steps, (step + 1) / took, train_loss, val_loss)) 66 | if not np.isfinite(val_loss): 67 | raise ValueError("NaN loss") 68 | if val_loss < best: 69 | no_improvement_steps = 0 70 | best = val_loss 71 | model.save("best") 72 | elif early_stop_steps > 0: 73 | no_improvement_steps += print_interval 74 | if no_improvement_steps >= early_stop_steps: 75 | print("Validation loss has not improved for {} steps, terminating!".format(no_improvement_steps)) 76 | return 77 | -------------------------------------------------------------------------------- /invoicenet/common/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | import re 23 | import io 24 | import math 25 | import datetime 26 | import datefinder 27 | import pytesseract 28 | from pytesseract import Output 29 | 30 | import os 31 | from PIL import Image 32 | from googleapiclient.discovery import build 33 | import base64 34 | from google.cloud import vision 35 | 36 | ## API keys for google ocr 37 | os.environ["GOOGLE_APPLICATION_CREDENTIALS"]="google_api_keys.json" 38 | 39 | 40 | class TextParser: 41 | 42 | def __init__(self): 43 | self.template = dict() 44 | self.template['amount'] = [r'\d+[,\d]*\.\d+'] 45 | self.template['date'] = [r'\d{1,2}[\/\\\.\,-]\d{1,2}[\/\\\.\,-]\d{2,4}', 46 | r'\d{2,4}[\/\\\.\,-]\d{1,2}[\/\\\.\,-]\d{1,2}'] 47 | 48 | def parse(self, text, key): 49 | if key == 'date': 50 | try: 51 | matches = [date for date in datefinder.find_dates(text) if date <= datetime.datetime.today()] 52 | if matches: 53 | return True 54 | else: 55 | return False 56 | except Exception: 57 | return False 58 | if key not in self.template: 59 | return False 60 | for regex in self.template[key]: 61 | if re.findall(regex, text): 62 | return True 63 | return False 64 | 65 | def find(self, text, key): 66 | if key == 'date': 67 | try: 68 | matches = [date for date in datefinder.find_dates(text) if date <= datetime.datetime.today()] 69 | if len(matches) > 0: 70 | return [match.strftime('%m-%d-%Y') for match in matches] 71 | else: 72 | return [] 73 | except Exception: 74 | return [] 75 | values = [] 76 | if key not in self.template: 77 | return values 78 | for regex in self.template[key]: 79 | values.extend(re.findall(regex, text)) 80 | values = list(set(values)) 81 | return values 82 | 83 | def replace(self, text, new, key): 84 | if key not in 
self.template: 85 | return text 86 | for regex in self.template[key]: 87 | text = re.sub(regex, new, text) 88 | while '  ' in text: 89 | text = text.replace('  ', ' ') 90 | return text 91 | 92 | 93 | def extract_words(img, height, width, ocr_engine='pytesseract'): 94 | if ocr_engine == 'pytesseract': 95 | data = pytesseract.image_to_data(img, output_type=Output.DICT) 96 | n_boxes = len(data['text']) 97 | words = [ 98 | { 99 | 'text': data['text'][i], 100 | 'left': data['left'][i], 101 | 'top': data['top'][i], 102 | 'right': data['left'][i] + data['width'][i], 103 | 'bottom': data['top'][i] + data['height'][i] 104 | } 105 | for i in range(n_boxes) if data['text'][i] 106 | ] 107 | return words 108 | 109 | 110 | elif ocr_engine=='google_ocr': 111 | img_byte_arr = io.BytesIO() 112 | img.save(img_byte_arr, format='PNG') 113 | img_byte_arr = img_byte_arr.getvalue() 114 | client = vision.ImageAnnotatorClient() 115 | content=img_byte_arr 116 | image_ = vision.Image(content=content) 117 | response = client.text_detection(image=image_) 118 | texts = response.text_annotations 119 | 120 | words=[] 121 | first=True 122 | for text in texts: 123 | if first: 124 | first=False 125 | continue 126 | data={} 127 | data['text']=text.description 128 | x_vert=[] 129 | y_vert=[] 130 | for vertex in text.bounding_poly.vertices: 131 | x_vert.append(vertex.x) 132 | y_vert.append(vertex.y) 133 | data['left']=min(x_vert) 134 | data['right']=max(x_vert) 135 | data['top']=min(y_vert) 136 | data['bottom']=max(y_vert) 137 | words.append(data) 138 | return words 139 | 140 | 141 | elif ocr_engine == 'aws_textract': 142 | 143 | import boto3 144 | 145 | # use aws textract 146 | client = boto3.client('textract') 147 | 148 | # convert PpmImageFile to byte 149 | img_byte_arr = io.BytesIO() 150 | img.save(img_byte_arr, format='PNG') 151 | img_byte_arr = img_byte_arr.getvalue() 152 | 153 | # call aws-textract API 154 | response = client.detect_document_text(Document={'Bytes': img_byte_arr}) 155 | 156 | # get 
image width and height to convert normalized coordinates from response 157 | words = [ 158 | { 159 | 'text': data['Text'], 160 | 'left': math.floor((data['Geometry']['BoundingBox']['Left']) * width), 161 | 'top': math.floor((data['Geometry']['BoundingBox']['Top']) * height), 162 | 'right': math.ceil( 163 | (data['Geometry']['BoundingBox']['Left'] + data['Geometry']['BoundingBox']['Width']) * width), 164 | 'bottom': math.ceil( 165 | (data['Geometry']['BoundingBox']['Top'] + data['Geometry']['BoundingBox']['Height']) * height) 166 | } for data in response['Blocks'] if "Text" in data 167 | ] 168 | return words 169 | 170 | 171 | def divide_into_lines(words, height, width): 172 | cur = words[0] 173 | lines = [] 174 | line = [] 175 | for word in words: 176 | if ((word['top'] - cur['top']) / height) > 0.005: 177 | # if difference between y-coordinate of current word and previous word 178 | # is more than 0.5% of the height, consider the current word to be in the next line 179 | lines.append(line) 180 | line = [word] 181 | elif ((word['left'] - cur['right']) / width) > 0.05: 182 | # if difference between x-coordinate of current word and previous word 183 | # is more than 5% of the width, consider the current word to be in a different line 184 | lines.append(line) 185 | line = [word] 186 | else: 187 | line.append(word) 188 | cur = word 189 | lines.append(line) 190 | return lines 191 | 192 | 193 | def create_ngrams(img, height, width, length=4, ocr_engine='pytesseract'):  # change ocr_engine here when running predict.py 194 | words = extract_words(img, height=height, width=width, ocr_engine=ocr_engine) 195 | lines = divide_into_lines(words, height=img.size[1], width=img.size[0]) 196 | tokens = [line[i:i + N] for line in lines for N in range(1, length + 1) for i in range(len(line) - N + 1)] 197 | ngrams = [] 198 | parser = TextParser() 199 | 200 | for token in tokens: 201 | text = ' '.join([word['text'] for word in token]) 202 | ngram = { 203 | "words": token, 204 | "parses": 
{} 205 | } 206 | if parser.parse(text=text, key='date'): 207 | ngram["parses"]["date"] = parser.find(text=text, key='date')[0] 208 | elif parser.parse(text=text, key='amount'): 209 | ngram["parses"]["amount"] = parser.find(text=text, key='amount')[0] 210 | ngrams.append(ngram) 211 | 212 | return ngrams 213 | 214 | 215 | def normalize(text, key): 216 | if key == 'amount': 217 | text = text.replace(",", '') 218 | splits = text.split('.') 219 | if len(splits) == 1: 220 | text += ".00" 221 | else: 222 | text = splits[0] + '.' + splits[1][:2] 223 | else: 224 | matches = [date for date in datefinder.find_dates(text) if date <= datetime.datetime.today()] 225 | if matches: 226 | text = matches[0].strftime('%m-%d-%Y') 227 | return text 228 | -------------------------------------------------------------------------------- /invoicenet/gui/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | -------------------------------------------------------------------------------- /invoicenet/gui/custom_widgets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | import threading 22 | from tkinter import * 23 | from tkinter import scrolledtext 24 | from PIL import Image, ImageTk 25 | 26 | 27 | class StoppableThread(threading.Thread): 28 | 29 | def __init__(self, *args, **kwargs): 30 | super(StoppableThread, self).__init__(*args, **kwargs) 31 | self._stop_event = threading.Event() 32 | 33 | def stop(self): 34 | self._stop_event.set() 35 | 36 | def stopped(self): 37 | return self._stop_event.is_set() 38 | 39 | 40 | class ToolTip: 41 | 42 | def __init__(self, widget, text): 43 | self.waittime = 500 44 | self.wraplength = 180 45 | self.widget = widget 46 | self.text = text 47 | self.widget.bind("<Enter>", self.enter) 48 | self.widget.bind("<Leave>", self.leave) 49 | self.widget.bind("<ButtonPress>", self.leave) 50 | self.idx = None 51 | self.tw = None 52 | 53 | def enter(self, event=None): 54 | self.schedule() 55 | 56 | def leave(self, event=None): 57 | self.unschedule() 58 | self.hidetip() 59 | 60 | def schedule(self): 61 | self.unschedule() 62 | self.idx = self.widget.after(self.waittime, self.showtip) 63 | 64 | def unschedule(self): 65 | idx = self.idx 66 | self.idx = None 67 | if idx: 68 | self.widget.after_cancel(idx) 69 | 70 | def showtip(self, event=None): 71 | x, y, cx, cy = self.widget.bbox("insert") 72 | x += self.widget.winfo_rootx() + 25 73 | y += self.widget.winfo_rooty() + 20 74 | self.tw = Toplevel(self.widget) 75 | self.tw.wm_overrideredirect(True) 76 | self.tw.wm_geometry("+%d+%d" % (x, y)) 77 | label = Label(self.tw, text=self.text, justify='left', 78 | background="#ffffff", relief='solid', borderwidth=1, 79 | wraplength=self.wraplength) 80 | label.pack(ipadx=1) 81 | 82 | def hidetip(self): 83 | tw = self.tw 84 | self.tw = None 85 | if tw: 86 | tw.destroy() 87 | 88 | 89 | class MenuBox(Frame): 90 | 91 | def __init__(self, master=None, image_path=None, background='#303030', highlight='#558de8', **kw): 92 | Frame.__init__(self, master, **kw) 93 | 94 | self.menu_button = Menubutton(self, width=50, height=50, bg=background, bd=0, 
95 | highlightthickness=0, activebackground=highlight) 96 | 97 | if image_path: 98 | self.image = ImageTk.PhotoImage(Image.open(image_path)) 99 | self.menu_button.configure(image=self.image) 100 | 101 | self.menu = Menu(self.menu_button, tearoff=False, bg=background, 102 | fg='white', bd=2, activebackground=highlight) 103 | 104 | self.menu_button.config(menu=self.menu) 105 | self.menu_button.pack(side=LEFT) 106 | 107 | self.menu_button.bind("<Button-1>", lambda e: self.menu_button.event_generate('<<Invoke>>')) 108 | 109 | def add_item(self, title, func, seperator=False): 110 | self.menu.add_command(label=title, command=func) 111 | if seperator: 112 | self.menu.add_separator() 113 | 114 | 115 | class HoverButton(Button): 116 | 117 | def __init__(self, master, tool_tip=None, image_path=None, **kw): 118 | Button.__init__(self, master=master, **kw) 119 | self.defaultBackground = self["background"] 120 | self.bind("<Enter>", self.on_enter) 121 | self.bind("<Leave>", self.on_leave) 122 | if image_path: 123 | self.image = ImageTk.PhotoImage(Image.open(image_path)) 124 | self.configure(image=self.image) 125 | if tool_tip: 126 | ToolTip(self, text=tool_tip) 127 | 128 | def on_click(self, e): 129 | if self['background'] == self.defaultBackground: 130 | self['background'] = self['activebackground'] 131 | else: 132 | self['background'] = self.defaultBackground 133 | 134 | def on_enter(self, e): 135 | self['background'] = self['activebackground'] 136 | 137 | def on_leave(self, e): 138 | self['background'] = self.defaultBackground 139 | 140 | 141 | class DisplayCanvas(Frame): 142 | 143 | def __init__(self, master, background='#404040', highlight='#558de8', **kw): 144 | Frame.__init__(self, master, **kw) 145 | self.x = self.y = 0 146 | 147 | self.rowconfigure(0, weight=1) 148 | self.rowconfigure(1, weight=0) 149 | self.columnconfigure(0, weight=1) 150 | self.columnconfigure(1, weight=0) 151 | 152 | self.canvas = Canvas(self, bg=background, highlightbackground='#353535') 153 | self.canvas.grid(row=0, column=0, 
sticky='news') 154 | 155 | self.sbarv = Scrollbar(self, orient=VERTICAL, bg=background, highlightbackground=highlight) 156 | self.sbarh = Scrollbar(self, orient=HORIZONTAL, bg=background, highlightbackground=highlight) 157 | self.sbarv.config(command=self.canvas.yview) 158 | self.sbarh.config(command=self.canvas.xview) 159 | 160 | self.canvas.config(yscrollcommand=self.sbarv.set) 161 | self.canvas.config(xscrollcommand=self.sbarh.set) 162 | 163 | self.sbarv.grid(row=0, column=1, sticky=N+S) 164 | self.sbarh.grid(row=1, column=0, sticky=E+W) 165 | 166 | self.canvas.bind('<Left>', self.on_left) 167 | self.canvas.bind('<Right>', self.on_right) 168 | self.canvas.bind('<Up>', self.on_up) 169 | self.canvas.bind('<Down>', self.on_down) 170 | 171 | self.canvas.bind("<ButtonPress-1>", self.on_button_press) 172 | self.canvas.bind("<B1-Motion>", self.on_move_press) 173 | self.canvas.bind("<ButtonRelease-1>", self.on_button_release) 174 | 175 | self.rect = None 176 | self.image = None 177 | self.image_obj = None 178 | self.pil_image = None 179 | self.draw = False 180 | 181 | self.start_x = None 182 | self.start_y = None 183 | 184 | self.canvas.config(scrollregion=self.canvas.bbox("all")) 185 | 186 | def on_button_press(self, event): 187 | self.canvas.focus_set() 188 | self.start_x = self.canvas.canvasx(event.x) 189 | self.start_y = self.canvas.canvasy(event.y) 190 | 191 | if not self.rect and self.draw: 192 | self.rect = self.canvas.create_rectangle(self.x, self.y, 1, 1, outline='red') 193 | 194 | def on_move_press(self, event): 195 | cur_x = self.canvas.canvasx(event.x) 196 | cur_y = self.canvas.canvasy(event.y) 197 | 198 | w, h = self.canvas.winfo_width(), self.canvas.winfo_height() 199 | if event.x > 0.9*w: 200 | self.on_right() 201 | elif event.x < 0.1*w: 202 | self.on_left() 203 | if event.y > 0.9*h: 204 | self.on_down() 205 | elif event.y < 0.1*h: 206 | self.on_up() 207 | 208 | if self.draw: 209 | self.canvas.coords(self.rect, self.start_x, self.start_y, cur_x, cur_y) 210 | 211 | def on_left(self, event=None): 212 | 
self.canvas.xview_scroll(-1, 'units') 213 | 214 | def on_right(self, event=None): 215 | self.canvas.xview_scroll(1, 'units') 216 | 217 | def on_up(self, event=None): 218 | self.canvas.yview_scroll(-1, 'units') 219 | 220 | def on_down(self, event=None): 221 | self.canvas.yview_scroll(1, 'units') 222 | 223 | def on_button_release(self, event): 224 | pass 225 | 226 | def update_image(self, image): 227 | self.draw = True 228 | self.pil_image = image 229 | self.image = ImageTk.PhotoImage(image) 230 | if self.image_obj is None: 231 | self.image_obj = self.canvas.create_image(1, 1, image=self.image, anchor=CENTER) 232 | else: 233 | self.canvas.itemconfig(self.image_obj, image=self.image) 234 | self.sbarv.config(command=self.canvas.yview) 235 | self.sbarh.config(command=self.canvas.xview) 236 | self.canvas.config(scrollregion=self.canvas.bbox("all")) 237 | self.canvas.xview_moveto(0.0) 238 | self.canvas.yview_moveto(0.0) 239 | 240 | def clear(self): 241 | self.canvas.delete("all") 242 | self.image_obj = self.canvas.create_image(1, 1, image=self.image, anchor=CENTER) 243 | self.sbarv.config(command=self.canvas.yview) 244 | self.sbarh.config(command=self.canvas.xview) 245 | self.canvas.config(yscrollcommand=self.sbarv.set) 246 | self.canvas.config(xscrollcommand=self.sbarh.set) 247 | self.canvas.config(scrollregion=self.canvas.bbox("all")) 248 | self.rect = None 249 | 250 | def reset(self): 251 | self.canvas.delete("all") 252 | self.rect = None 253 | self.image = None 254 | self.image_obj = None 255 | self.pil_image = None 256 | self.draw = False 257 | 258 | def get_rect(self): 259 | w, h = self.pil_image.size 260 | x0, y0 = self.canvas.coords(self.image_obj) 261 | minx = x0 - w / 2.0 262 | miny = y0 - h / 2.0 263 | if self.rect: 264 | rect = self.canvas.coords(self.rect) 265 | rect = [rect[0] + abs(minx), rect[1] + abs(miny), rect[2] + abs(minx), rect[3] + abs(miny)] 266 | return rect 267 | else: 268 | return None 269 | 270 | 271 | class Logger(Frame): 272 | 273 | def 
__init__(self, master=None, height=18, disable=True, **kw): 274 | Frame.__init__(self, master, **kw) 275 | self.text = scrolledtext.ScrolledText(self, height=height, 276 | bg='#002b36', fg='#eee8d5', insertbackground='#eee8d5') 277 | self.text.pack(expand=True, padx=50) 278 | self.disable = disable 279 | if self.disable: 280 | self.text.configure(state='disabled') 281 | 282 | def log(self, msg): 283 | self.text.configure(state='normal') 284 | self.text.insert(END, msg + '\n') 285 | if self.disable: 286 | self.text.configure(state='disabled') 287 | self.text.yview(END) 288 | 289 | def get(self): 290 | return self.text.get("1.0", END) 291 | 292 | def clear(self): 293 | self.text.configure(state='normal') 294 | self.text.delete('1.0', END) 295 | if self.disable: 296 | self.text.configure(state='disabled') 297 | -------------------------------------------------------------------------------- /invoicenet/gui/extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | import os 22 | import io 23 | import re 24 | import subprocess 25 | import tempfile 26 | import PyPDF2 27 | import pdfplumber 28 | import pytesseract 29 | import simplejson 30 | from tkinter import * 31 | from tkinter import filedialog, messagebox 32 | from PIL import Image, ImageTk 33 | from wand.exceptions import WandException 34 | 35 | from .. import FIELDS 36 | from ..acp.acp import AttendCopyParse 37 | from .custom_widgets import MenuBox, HoverButton, Logger, StoppableThread 38 | from .help_box import HelpBox 39 | from .viewer import PDFViewer 40 | 41 | 42 | class Extractor(Frame): 43 | def __init__(self, master=None, **kw): 44 | Frame.__init__(self, master, **kw) 45 | self.background = '#303030' 46 | self.border_color = '#404040' 47 | self.checkbox_color = '#333333' 48 | self.highlight_color = '#558de8' 49 | self.pdf = None 50 | self.paths = list() 51 | self.pathidx = -1 52 | self.checkboxes = {} 53 | self.thread = None 54 | self.running = False 55 | self.save_dir = '.' 
56 | self._init_ui() 57 | 58 | def _init_ui(self): 59 | ws = self.master.winfo_screenwidth() 60 | hs = self.master.winfo_screenheight() 61 | h = hs - 100 62 | w = (int(h / 1.414) + 100) * 2 63 | x = (ws / 2) - (w / 2) 64 | y = (hs / 2) - (h / 2) 65 | self.master.geometry('%dx%d+%d+%d' % (w, h, x, y)) 66 | self.master.maxsize(w, h) 67 | self.master.minsize(w, h) 68 | self.master.title("InvoiceNet") 69 | 70 | self.pack(fill=BOTH, expand=True) 71 | 72 | self.rowconfigure(0, weight=1) 73 | self.columnconfigure(0, weight=0) 74 | self.columnconfigure(1, weight=1) 75 | self.columnconfigure(2, weight=0) 76 | 77 | tool_frame = Frame(self, bg=self.background, bd=0, relief=SUNKEN) 78 | self.viewer = PDFViewer(self) 79 | interface = Frame(self, bg=self.background, bd=0, relief=SUNKEN, 80 | highlightbackground=self.border_color, highlightthickness=1) 81 | 82 | tool_frame.grid(row=0, column=0, sticky='news') 83 | self.viewer.grid(row=0, column=1, sticky='news') 84 | interface.grid(row=0, column=2, sticky='news') 85 | 86 | # Tool Frame 87 | tool_frame.columnconfigure(0, weight=1) 88 | tool_frame.rowconfigure(0, weight=0) 89 | tool_frame.rowconfigure(1, weight=1) 90 | tool_frame.rowconfigure(2, weight=0) 91 | tool_frame.rowconfigure(3, weight=2) 92 | 93 | options = MenuBox(tool_frame, image_path=r'widgets/options.png', bg=self.background) 94 | options.grid(row=0, column=0) 95 | 96 | options.add_item('Open Files...', self._open_file) 97 | options.add_item('Open Directory...', self._open_dir, seperator=True) 98 | options.add_item('Set Save Directory...', self._set_save_path, seperator=True) 99 | options.add_item('Next File', self._next_file) 100 | options.add_item('Previous File', self._prev_file, seperator=True) 101 | options.add_item('Clear Page', self.viewer.clear) 102 | options.add_item('Search Text', self.viewer.search_text) 103 | options.add_item('Extract Text', self.viewer.extract_text) 104 | options.add_item('Run OCR', self._run_ocr, seperator=True) 105 | 
options.add_item('Clear Invoice Queue', self._clear_queue, seperator=True) 106 | options.add_item('Help...', self._help, seperator=True) 107 | options.add_item('Exit', self.master.quit) 108 | 109 | tools = Frame(tool_frame, bg=self.background, bd=0, relief=SUNKEN) 110 | tools.grid(row=2, column=0) 111 | 112 | HoverButton(tools, image_path=r'widgets/open_file.png', command=self._open_file, 113 | width=50, height=50, bg=self.background, bd=0, tool_tip="Open Files", 114 | highlightthickness=0, activebackground=self.highlight_color).pack(pady=2) 115 | HoverButton(tools, image_path=r'widgets/open_dir.png', command=self._open_dir, 116 | width=50, height=50, bg=self.background, bd=0, tool_tip="Open Directory", 117 | highlightthickness=0, activebackground=self.highlight_color).pack(pady=2) 118 | HoverButton(tools, image_path=r'widgets/save_as.png', command=self._set_save_path, 119 | width=50, height=50, bg=self.background, bd=0, tool_tip="Set Save Directory", 120 | highlightthickness=0, activebackground=self.highlight_color).pack(pady=2) 121 | HoverButton(tools, image_path=r'widgets/clear_page.png', command=self.viewer.clear, 122 | width=50, height=50, bg=self.background, bd=0, tool_tip="Clear Page", 123 | highlightthickness=0, activebackground=self.highlight_color).pack(pady=2) 124 | HoverButton(tools, image_path=r'widgets/search.png', command=self.viewer.search_text, 125 | width=50, height=50, bg=self.background, bd=0, tool_tip="Search Text", 126 | highlightthickness=0, activebackground=self.highlight_color).pack(pady=2) 127 | HoverButton(tools, image_path=r'widgets/extract.png', command=self.viewer.extract_text, 128 | width=50, height=50, bg=self.background, bd=0, tool_tip="Extract Text", 129 | highlightthickness=0, activebackground=self.highlight_color).pack(pady=2) 130 | HoverButton(tools, image_path=r'widgets/ocr.png', command=self._run_ocr, 131 | width=50, height=50, bg=self.background, bd=0, tool_tip="Run OCR", 132 | highlightthickness=0, 
activebackground=self.highlight_color).pack(pady=2) 133 | HoverButton(tools, image_path=r'widgets/clear_all.png', command=self._clear_queue, 134 | width=50, height=50, bg=self.background, bd=0, tool_tip="Clear Invoice Queue", 135 | highlightthickness=0, activebackground=self.highlight_color).pack(pady=2) 136 | 137 | file_frame = Frame(tools, width=50, height=50, bg=self.background, bd=0, relief=SUNKEN) 138 | file_frame.pack(pady=2) 139 | 140 | file_frame.columnconfigure(0, weight=1) 141 | file_frame.columnconfigure(1, weight=1) 142 | 143 | HoverButton(file_frame, image_path=r'widgets/prev_file.png', command=self._prev_file, 144 | width=25, height=50, bg=self.background, bd=0, tool_tip="Previous File", 145 | highlightthickness=0, activebackground=self.highlight_color).grid(row=0, column=0) 146 | HoverButton(file_frame, image_path=r'widgets/next_file.png', command=self._next_file, 147 | width=25, height=50, bg=self.background, bd=0, tool_tip="Next File", 148 | highlightthickness=0, activebackground=self.highlight_color).grid(row=0, column=1) 149 | 150 | self.doc_label = Label(file_frame, bg=self.background, bd=0, fg='white', font=("Arial", 8), text="0 of 0") 151 | self.doc_label.grid(row=1, column=0, columnspan=2, pady=4, sticky='news') 152 | 153 | HoverButton(tool_frame, image_path=r'widgets/help.png', command=self._help, 154 | width=50, height=50, bg=self.background, bd=0, tool_tip="Help", 155 | highlightthickness=0, activebackground=self.highlight_color).grid(row=3, column=0, sticky='s') 156 | 157 | # Interface Frame 158 | interface.columnconfigure(0, weight=1) 159 | interface.rowconfigure(0, weight=0) 160 | interface.rowconfigure(1, weight=1) 161 | interface.rowconfigure(2, weight=1) 162 | 163 | logo_frame = Frame(interface, bg=self.background, bd=0, relief=SUNKEN, 164 | highlightbackground=self.border_color, highlightthickness=1) 165 | param_frame = Frame(interface, bg=self.background, bd=0, relief=SUNKEN, 166 | highlightbackground=self.border_color, 
highlightthickness=1) 167 | main_frame = Frame(interface, bg=self.background, bd=0, relief=SUNKEN, 168 | highlightbackground=self.border_color, highlightthickness=1) 169 | 170 | logo_frame.grid(row=0, column=0, sticky='news') 171 | param_frame.grid(row=1, column=0, sticky='news') 172 | main_frame.grid(row=2, column=0, sticky='news') 173 | 174 | # Logo Frame 175 | logo_frame.columnconfigure(0, weight=1) 176 | logo_frame.columnconfigure(1, weight=0) 177 | logo_frame.columnconfigure(2, weight=0) 178 | logo_frame.columnconfigure(3, weight=1) 179 | logo_frame.rowconfigure(0, weight=1) 180 | 181 | self.logo_img = ImageTk.PhotoImage(Image.open(r'widgets/logo.png')) 182 | Label(logo_frame, bg=self.background, image=self.logo_img).grid(row=0, column=1, sticky='news', pady=10) 183 | Label(logo_frame, text="InvoiceNet", bg=self.background, 184 | fg="white", font=("Arial", 24, "bold")).grid(row=0, column=2, sticky='news', padx=20, pady=10) 185 | 186 | # Param Frame 187 | param_frame.columnconfigure(0, weight=1) 188 | param_frame.columnconfigure(1, weight=2) 189 | param_frame.columnconfigure(2, weight=1) 190 | param_frame.rowconfigure(0, weight=1) 191 | param_frame.rowconfigure(1, weight=2) 192 | param_frame.rowconfigure(2, weight=1) 193 | 194 | self.start_button = HoverButton(param_frame, image_path=r'widgets/begin.png', command=self._start, 195 | text='Extract', compound='center', font=("Arial", 10, "bold"), bd=0, 196 | bg=self.background, highlightthickness=0, activebackground=self.background) 197 | self.start_button.grid(row=2, column=1, pady=20, sticky='news') 198 | 199 | field_param = Frame(param_frame, bg=self.background, bd=0, relief=SUNKEN, 200 | highlightbackground=self.border_color, highlightthickness=0) 201 | 202 | field_param.grid(row=1, column=1, sticky='news') 203 | 204 | # Field Checkboxes 205 | field_frame = Frame(field_param, bg=self.checkbox_color, bd=0, relief=SUNKEN, 206 | highlightbackground=self.border_color, highlightthickness=1) 207 | 
field_frame.pack(expand=True, fill=BOTH) 208 | 209 | Label(field_frame, text="Field:", width=30, bg=self.checkbox_color, 210 | anchor='w', fg="white", font=("Arial", 12, "bold")).pack(side=TOP, fill=X, padx=5, pady=5) 211 | 212 | checkbox_frame = Frame(field_frame, bg=self.checkbox_color, bd=0, relief=SUNKEN, 213 | highlightbackground=self.border_color, highlightthickness=1) 214 | checkbox_frame.pack(expand=True, fill=BOTH, side=BOTTOM) 215 | 216 | checkbox_frame.columnconfigure(0, weight=1) 217 | checkbox_frame.columnconfigure(1, weight=1) 218 | checkbox_frame.columnconfigure(2, weight=1) 219 | checkbox_frame.columnconfigure(3, weight=1) 220 | for i in range(len(FIELDS) // 2): 221 | checkbox_frame.rowconfigure(i, weight=1) 222 | for idx, key in enumerate(FIELDS): 223 | self.checkboxes[key] = BooleanVar(checkbox_frame, value=False) 224 | state = False 225 | if os.path.exists('./models/invoicenet/'): 226 | state = key in os.listdir('./models/invoicenet/') 227 | 228 | Checkbutton(checkbox_frame, fg="black", bg=self.checkbox_color, 229 | activebackground=self.checkbox_color, variable=self.checkboxes[key], 230 | state="normal" if state else "disabled", highlightthickness=0).grid(row=idx // 2, 231 | column=2 if idx % 2 else 0, 232 | sticky='news', padx=(10, 0)) 233 | Label(checkbox_frame, text=key, bg=self.checkbox_color, 234 | fg="white", font=("Arial", 12, "bold")).grid(row=idx // 2, column=3 if idx % 2 else 1, sticky='nws') 235 | 236 | # Main Frame 237 | main_frame.columnconfigure(0, weight=1) 238 | main_frame.columnconfigure(1, weight=1) 239 | main_frame.columnconfigure(2, weight=1) 240 | main_frame.rowconfigure(0, weight=1) 241 | main_frame.rowconfigure(1, weight=1) 242 | main_frame.rowconfigure(2, weight=1) 243 | main_frame.rowconfigure(3, weight=1) 244 | 245 | self.logger = Logger(main_frame, disable=False, height=18, bg=self.background, bd=0, relief=SUNKEN) 246 | self.logger.grid(row=1, column=1, sticky='news') 247 | 248 | button_frame = Frame(main_frame, 
bg=self.background, bd=0, relief=SUNKEN, 249 | highlightbackground=self.border_color, highlightthickness=0) 250 | button_frame.grid(row=2, column=1, sticky='news') 251 | 252 | button_frame.rowconfigure(0, weight=1) 253 | button_frame.columnconfigure(0, weight=1) 254 | button_frame.columnconfigure(1, weight=0) 255 | button_frame.columnconfigure(2, weight=0) 256 | button_frame.columnconfigure(3, weight=1) 257 | 258 | HoverButton(button_frame, image_path=r'widgets/labels.png', command=self._save_info, text='Save Information', 259 | compound='center', font=("Arial", 10, "bold"), bd=0, bg=self.background, 260 | highlightthickness=0, activebackground=self.background).grid(row=0, column=1, padx=10) 261 | 262 | HoverButton(button_frame, image_path=r'widgets/labels.png', command=self._load_labels, text='Load Labels', 263 | compound='center', font=("Arial", 10, "bold"), bd=0, bg=self.background, 264 | highlightthickness=0, activebackground=self.background).grid(row=0, column=2, padx=10) 265 | 266 | def _extract(self): 267 | path = self.paths[self.pathidx] 268 | 269 | self.logger.clear() 270 | self.logger.log("Extracting information from '{}'...\n".format(path)) 271 | 272 | temp = None 273 | if path.split('.')[-1].lower() in ['jpg', 'png']: 274 | image = Image.open(path) 275 | pdf = pytesseract.image_to_pdf_or_hocr(image, extension='pdf') 276 | temp = tempfile.NamedTemporaryFile(suffix='.pdf') 277 | temp.write(pdf) 278 | path = temp.name 279 | 280 | predictions = {} 281 | for key in FIELDS: 282 | if self.checkboxes[key].get(): 283 | model = AttendCopyParse(field=key, restore=True) 284 | predictions[key] = model.predict(paths=[path])[0] 285 | 286 | if temp is not None: 287 | temp.close() 288 | 289 | self.viewer.label(labels=predictions) 290 | self.logger.log(simplejson.dumps(predictions, indent=2, sort_keys=True)) 291 | self.start_button.configure(state='normal') 292 | self.running = False 293 | 294 | def _start(self): 295 | if not self.paths: 296 | 
messagebox.showerror("Error", "No PDF has been loaded!") 297 | return 298 | 299 | selected = False 300 | for key in FIELDS: 301 | if self.checkboxes[key].get(): 302 | selected = True 303 | break 304 | 305 | if not selected: 306 | messagebox.showerror("Error", "No fields were selected!") 307 | return 308 | 309 | if not self.running: 310 | self.running = True 311 | self.thread = StoppableThread(target=self._extract) 312 | self.thread.daemon = True 313 | self.thread.start() 314 | self.start_button.configure(state='disabled') 315 | 316 | def _load_labels(self): 317 | if self.pdf is None: 318 | messagebox.showerror("Error", "Load an invoice first!") 319 | return 320 | 321 | label_file = filedialog.askopenfile(filetypes=[('JSON files', '*.json'), ('all files', '.*')], 322 | initialdir=self.save_dir, title="Select label file") 323 | if label_file is None: 324 | return 325 | 326 | try: 327 | labels = simplejson.load(label_file) 328 | except simplejson.errors.JSONDecodeError: 329 | messagebox.showerror("Error", "JSON file is invalid!") 330 | return 331 | 332 | match = re.findall(r"[^{]*{([^}]+)}", self.logger.get()) 333 | if match: 334 | log_data = simplejson.loads('{' + ''.join(match) + '}') 335 | for key in log_data.keys(): 336 | if key not in labels: 337 | labels[key] = log_data[key] 338 | 339 | self.viewer.label(labels=labels) 340 | self.logger.clear() 341 | self.logger.log("Extracting information from '{}'...\n".format(self.paths[self.pathidx])) 342 | self.logger.log(simplejson.dumps(labels, indent=2, sort_keys=True)) 343 | self.logger.log("\nLoaded labels from '{}'".format(label_file.name)) 344 | 345 | def _save_info(self): 346 | if self.pdf is None: 347 | messagebox.showerror("Error", "Load an invoice first!") 348 | return 349 | 350 | match = re.findall(r"[^{]*{([^}]+)}", self.logger.get()) 351 | 352 | if not match: 353 | messagebox.showerror("Error", "Could not parse a valid dictionary!") 354 | return 355 | 356 | try: 357 | new_labels = simplejson.loads('{' + 
''.join(match) + '}') 358 | except simplejson.errors.JSONDecodeError: 359 | messagebox.showerror("Error", "Could not parse a valid dictionary!") 360 | return 361 | 362 | path = os.path.join(self.save_dir, os.path.splitext(os.path.basename(self.paths[self.pathidx]))[0] + '.json') 363 | 364 | labels = {} 365 | if os.path.exists(path): 366 | with open(path, encoding="utf8") as fp: 367 | try: 368 | labels = simplejson.load(fp) 369 | self.logger.log("\n'{}' already exists, adding updated labels to this file".format(path)) 370 | except simplejson.errors.JSONDecodeError: 371 | pass 372 | 373 | for label in new_labels.keys(): 374 | labels[label] = new_labels[label] 375 | 376 | try: 377 | with open(path, 'w') as out: 378 | out.write(simplejson.dumps(labels, indent=2, sort_keys=True)) 379 | except simplejson.errors.JSONDecodeError: 380 | messagebox.showerror("Error", "Error occurred while writing JSON file!") 381 | return 382 | 383 | self.logger.log("\nWrote information to '{}'".format(path)) 384 | 385 | def _set_save_path(self): 386 | path = filedialog.askdirectory(title='Set Save Directory', initialdir=self.save_dir) 387 | if path == '' or not path: 388 | return 389 | if not os.path.exists(path): 390 | messagebox.showerror("Error", "Invalid directory!") 391 | return 392 | self.save_dir = path 393 | self.logger.log("Information will now be saved in '{}'".format(self.save_dir)) 394 | 395 | def _next_file(self): 396 | if self.pathidx == len(self.paths) - 1 or len(self.paths) == 0: 397 | return 398 | self.pathidx += 1 399 | self._load_file() 400 | 401 | def _prev_file(self): 402 | if self.pathidx == 0 or len(self.paths) == 0: 403 | return 404 | self.pathidx -= 1 405 | self._load_file() 406 | 407 | def _run_ocr(self): 408 | if self.pdf is None: 409 | return 410 | 411 | pdf_pages = list() 412 | for page in self.pdf.pages: 413 | image = page.to_image(resolution=100) 414 | pdf = pytesseract.image_to_pdf_or_hocr(image.original, extension='pdf') 415 | pdf_pages.append(pdf) 416 | 417 
| pdf_writer = PyPDF2.PdfFileWriter() 418 | for page in pdf_pages: 419 | pdf = PyPDF2.PdfFileReader(io.BytesIO(page)) 420 | pdf_writer.addPage(pdf.getPage(0)) 421 | 422 | pdf = io.BytesIO() 423 | pdf_writer.write(pdf) 424 | 425 | self.pdf = pdfplumber.load(pdf) 426 | self.viewer.display_pdf(self.pdf) 427 | 428 | def _load_file(self): 429 | self.viewer.clear() 430 | path = self.paths[self.pathidx] 431 | filename = os.path.basename(path) 432 | try: 433 | if filename.split('.')[-1].lower() in ['jpg', 'png']: 434 | image = Image.open(path) 435 | pdf = io.BytesIO(pytesseract.image_to_pdf_or_hocr(image, extension='pdf')) 436 | self.pdf = pdfplumber.load(pdf) 437 | else: 438 | self.pdf = pdfplumber.open(path) 439 | self.viewer.display_pdf(self.pdf) 440 | self.doc_label.configure(text="{} of {}".format(self.pathidx + 1, len(self.paths))) 441 | self.logger.clear() 442 | self.logger.log("Showing invoice '{}'".format(path)) 443 | except WandException: 444 | result = messagebox.askokcancel("Error", 445 | "ImageMagick Policy Error! Should InvoiceNet try to fix the error?") 446 | if result: 447 | result = self._fix_policy_error() 448 | if result: 449 | messagebox.showinfo("Policy Fixed!", "ImageMagick Policy Error fixed! Restart InvoiceNet.") 450 | else: 451 | messagebox.showerror("ImageMagick Policy Error", 452 | "Could not fix ImageMagick policy. 
Rejecting the current pdf file!") 453 | except (IndexError, IOError, TypeError): 454 | pass 455 | 456 | def _open_file(self): 457 | paths = filedialog.askopenfilenames(filetypes=[('PDF files', '*.pdf'), 458 | ('JPG files', '*.jpg'), 459 | ('PNG files', '*.png'), 460 | ('all files', '.*')], 461 | initialdir='.', 462 | title="Select files", multiple=True) 463 | if not paths or paths == '': 464 | return 465 | paths = [path for path in paths if os.path.basename(path).split('.')[-1].lower() in ['pdf', 'jpg', 'png']] 466 | self.paths = self.paths[:self.pathidx + 1] + paths + self.paths[self.pathidx + 1:] 467 | self.pathidx += 1 468 | self._load_file() 469 | 470 | def _open_dir(self): 471 | dir_name = filedialog.askdirectory(initialdir='.', title="Select Directory Containing Invoices") 472 | if not dir_name or dir_name == '': 473 | return 474 | paths = os.listdir(dir_name) 475 | paths = [os.path.join(dir_name, path) for path in paths 476 | if os.path.basename(path).split('.')[-1].lower() in ['pdf', 'jpg', 'png']] 477 | self.paths = self.paths[:self.pathidx + 1] + paths + self.paths[self.pathidx + 1:] 478 | if not self.paths: 479 | return 480 | self.pathidx += 1 481 | self._load_file() 482 | 483 | def _clear_queue(self): 484 | self.viewer.reset() 485 | self.logger.clear() 486 | self.pdf = None 487 | self.paths = list() 488 | self.pathidx = -1 489 | self.doc_label.configure(text="{} of {}".format(self.pathidx + 1, len(self.paths))) 490 | self.thread = None 491 | self.running = False 492 | 493 | def _help(self): 494 | ws = self.master.winfo_screenwidth() 495 | hs = self.master.winfo_screenheight() 496 | w, h = 600, 600 497 | x = (ws / 2) - (w / 2) 498 | y = (hs / 2) - (h / 2) 499 | help_frame = Toplevel(self) 500 | help_frame.title("Help") 501 | help_frame.configure(width=w, height=h, bg=self.background, relief=SUNKEN) 502 | help_frame.geometry('%dx%d+%d+%d' % (w, h, x, y)) 503 | help_frame.minsize(height=h, width=w) 504 | help_frame.maxsize(height=h, width=w) 505 | 
help_frame.rowconfigure(0, weight=1) 506 | help_frame.columnconfigure(0, weight=1) 507 | HelpBox(help_frame, width=w, height=h, bg=self.background, relief=SUNKEN).grid(row=0, column=0) 508 | 509 | @staticmethod 510 | def _fix_policy_error(): 511 | policy_path = "/etc/ImageMagick-6/policy.xml" 512 | 513 | if not os.path.isfile(policy_path): 514 | policy_path = "/etc/ImageMagick/policy.xml" 515 | 516 | if not os.path.exists(policy_path): 517 | return False 518 | 519 | try: 520 | with open(policy_path, 'r') as policy_file: 521 | data = policy_file.readlines() 522 | new_data = [] 523 | 524 | for line in data: 525 | if 'MVG' in line: 526 | line = '' 527 | elif 'PDF' in line: 528 | line = ' \n' 529 | elif '' in line: 530 | new_data.append(' \n') 531 | new_data.append(line) 532 | 533 | temp = tempfile.NamedTemporaryFile(mode='w', suffix='.xml') 534 | temp.writelines(new_data) 535 | subprocess.call(["sudo", "mv", temp.name, policy_path]) 536 | except (IndexError, IOError, TypeError): 537 | return False 538 | -------------------------------------------------------------------------------- /invoicenet/gui/help_box.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 
12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | from tkinter import * 22 | from PIL import Image, ImageTk 23 | 24 | from .custom_widgets import HoverButton 25 | 26 | 27 | class HelpBox(Frame): 28 | 29 | def __init__(self, master=None, background='#303030', **kw): 30 | Frame.__init__(self, master, **kw) 31 | self.columnconfigure(0, weight=1) 32 | self.rowconfigure(0, weight=0) 33 | self.rowconfigure(1, weight=0) 34 | self.rowconfigure(2, weight=0) 35 | 36 | Label(self, text="Meet InvoiceNet!", anchor='nw', width=100, 37 | font="OpenSans 22 bold", fg='white', bg=background, bd=2).grid(row=0, column=0, padx=20, pady=20) 38 | 39 | Label(self, text="Made with ❤ by naiveHobo", anchor='nw', width=100, 40 | font="OpenSans 10 bold", fg='white', bg=background, bd=2).grid(row=2, column=0, padx=20, pady=20) 41 | 42 | text_frame = Frame(self, height=440, width=550, bg=background, bd=2, relief=SUNKEN) 43 | text_frame.grid(row=1, column=0) 44 | 45 | text_frame.grid_propagate(False) 46 | 47 | text_frame.grid_rowconfigure(0, weight=1) 48 | text_frame.grid_columnconfigure(0, weight=1) 49 | 50 | text_box = Text(text_frame, borderwidth=3, relief="sunken", bg=background, 51 | fg='white', font="OpenSans 12", wrap='word') 52 | 53 | with open(r'widgets/help.txt', 'r') as infile: 54 | texts = infile.read() 55 | 56 | texts = [text.strip() + '\n\n\n' for text in texts.split('---')] 57 | 58 | text_box.insert('1.0', texts[0]) 59 | texts = texts[1:] 60 | 61 | paths = ['open_file.png', 'open_dir.png', 'save_as.png', 
'clear_page.png', 62 | 'search.png', 'extract.png', 'ocr.png', 'clear_all.png'] 63 | self.images = [ImageTk.PhotoImage(Image.open(r'widgets/' + path)) for path in paths] 64 | 65 | for text, image in zip(texts, self.images): 66 | text_box.image_create(END, image=image) 67 | text_box.insert(END, ' ' + text) 68 | 69 | self.images.extend([ImageTk.PhotoImage(Image.open(r'widgets/' + path)) 70 | for path in ['prev_file.png', 'next_file.png']]) 71 | 72 | text_box.image_create(END, image=self.images[-2]) 73 | text_box.image_create(END, image=self.images[-1]) 74 | text_box.insert(END, ' ' + texts[-5]) 75 | 76 | viewer_text = texts[-4].split('\n\n') 77 | text_box.insert(END, viewer_text[0] + '\n') 78 | self.images.append(ImageTk.PhotoImage(Image.open(r'widgets/toolbar.png'))) 79 | text_box.image_create(END, image=self.images[-1]) 80 | text_box.insert(END, '\n\n' + '\n\n'.join(viewer_text[1:])) 81 | 82 | for text, image in zip(texts[-3:], ['begin.png', 'labels.png', 'labels.png']): 83 | splits = text.strip().split('\n\n') 84 | btn = HoverButton(text_box, image_path=r'widgets/' + image, text=splits[0], 85 | compound='center', font=("Arial", 10, "bold"), bd=0, bg=background, 86 | highlightthickness=0, activebackground=background) 87 | text_box.window_create(END, window=btn) 88 | text_box.insert(END, '\n\n' + '\n\n'.join(splits[1:]) + '\n\n\n') 89 | 90 | text_box.config(state=DISABLED) 91 | text_box.grid(row=0, column=0, sticky="nsew", padx=2, pady=2) 92 | 93 | scroll_bar = Scrollbar(text_frame, command=text_box.yview, bg=background) 94 | scroll_bar.grid(row=0, column=1, sticky='nsew') 95 | 96 | text_box['yscrollcommand'] = scroll_bar.set 97 | -------------------------------------------------------------------------------- /invoicenet/gui/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software 
and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | import os 22 | import time 23 | import glob 24 | import random 25 | import pdf2image 26 | import simplejson 27 | import numpy as np 28 | from tqdm import tqdm 29 | from tkinter import * 30 | from tkinter import filedialog, messagebox 31 | from tkinter.ttk import Progressbar 32 | from PIL import Image, ImageTk 33 | 34 | from .. 
import FIELDS, FIELD_TYPES 35 | from ..common import util 36 | from ..acp.acp import AttendCopyParse 37 | from ..acp.data import InvoiceData 38 | from .custom_widgets import HoverButton, Logger, StoppableThread 39 | 40 | 41 | class Trainer(Frame): 42 | 43 | def __init__(self, master=None, **kw): 44 | Frame.__init__(self, master, **kw) 45 | self.background = '#303030' 46 | self.border_color = '#404040' 47 | self.args = { 48 | "data_dir": "", 49 | "prepared_data": "processed_data", 50 | "field": list(FIELDS.keys())[0], 51 | "batch_size": 4 52 | } 53 | self.textboxes = {} 54 | self.thread = None 55 | self.running = False 56 | self._init_ui() 57 | 58 | def _init_ui(self): 59 | ws = self.master.winfo_screenwidth() 60 | hs = self.master.winfo_screenheight() 61 | h = hs - 100 62 | w = int(h / 1.414) + 100 63 | x = (ws / 2) - (w / 2) 64 | y = (hs / 2) - (h / 2) 65 | self.master.geometry('%dx%d+%d+%d' % (w, h, x, y)) 66 | self.master.maxsize(w, h) 67 | self.master.minsize(w, h) 68 | self.master.title("InvoiceNet - Trainer") 69 | 70 | self.pack(fill=BOTH, expand=True) 71 | 72 | self.columnconfigure(0, weight=1) 73 | self.rowconfigure(0, weight=0) 74 | self.rowconfigure(1, weight=1) 75 | self.rowconfigure(2, weight=0) 76 | self.rowconfigure(3, weight=1) 77 | 78 | self.configure(bg=self.background, bd=0) 79 | 80 | logo_frame = Frame(self, bg=self.background, bd=0, relief=SUNKEN, 81 | highlightbackground=self.border_color, highlightthickness=1) 82 | param_frame = Frame(self, bg=self.background, bd=0, relief=SUNKEN, 83 | highlightbackground=self.border_color, highlightthickness=1) 84 | progress_frame = Frame(self, bg=self.background, bd=0, relief=SUNKEN, 85 | highlightbackground=self.border_color, highlightthickness=0) 86 | main_frame = Frame(self, bg=self.background, bd=0, relief=SUNKEN, 87 | highlightbackground=self.border_color, highlightthickness=1) 88 | 89 | logo_frame.grid(row=0, column=0, sticky='news') 90 | param_frame.grid(row=1, column=0, sticky='news') 91 | 
progress_frame.grid(row=2, column=0, sticky='news', padx=50, pady=(0, 20)) 92 | main_frame.grid(row=3, column=0, sticky='news') 93 | 94 | # Logo Frame 95 | logo_frame.columnconfigure(0, weight=1) 96 | logo_frame.columnconfigure(1, weight=0) 97 | logo_frame.columnconfigure(2, weight=0) 98 | logo_frame.columnconfigure(3, weight=1) 99 | logo_frame.rowconfigure(0, weight=1) 100 | 101 | self.logo_img = ImageTk.PhotoImage(Image.open(r'widgets/logo.png')) 102 | Label(logo_frame, bg=self.background, image=self.logo_img).grid(row=0, column=1, sticky='news', pady=10) 103 | Label(logo_frame, text="InvoiceNet", bg=self.background, 104 | fg="white", font=("Arial", 24, "bold")).grid(row=0, column=2, sticky='news', padx=20, pady=10) 105 | 106 | # Param Frame 107 | param_frame.columnconfigure(0, weight=1) 108 | param_frame.columnconfigure(1, weight=0) 109 | param_frame.columnconfigure(2, weight=0) 110 | param_frame.columnconfigure(3, weight=1) 111 | param_frame.rowconfigure(0, weight=1) 112 | param_frame.rowconfigure(1, weight=0) 113 | param_frame.rowconfigure(2, weight=0) 114 | param_frame.rowconfigure(3, weight=0) 115 | param_frame.rowconfigure(4, weight=1) 116 | 117 | data_param = Frame(param_frame, bg=self.background, bd=0, relief=SUNKEN, 118 | highlightbackground=self.border_color, highlightthickness=0) 119 | out_param = Frame(param_frame, bg=self.background, bd=0, relief=SUNKEN, 120 | highlightbackground=self.border_color, highlightthickness=0) 121 | field_param = Frame(param_frame, bg=self.background, bd=0, relief=SUNKEN, 122 | highlightbackground=self.border_color, highlightthickness=0) 123 | batch_param = Frame(param_frame, bg=self.background, bd=0, relief=SUNKEN, 124 | highlightbackground=self.border_color, highlightthickness=0) 125 | 126 | data_param.grid(row=1, column=1, pady=(0, 20), padx=20) 127 | out_param.grid(row=2, column=1, pady=20, padx=20) 128 | field_param.grid(row=1, column=2, pady=(0, 20), padx=20) 129 | batch_param.grid(row=2, column=2, pady=20, padx=20) 
130 | 131 | df = Frame(data_param, bg=self.background, bd=0, relief=SUNKEN, 132 | highlightbackground=self.border_color, highlightthickness=0) 133 | df.pack(side=TOP, fill=BOTH) 134 | 135 | Label(df, text="Data Folder:", bg=self.background, 136 | fg="white", font=("Arial", 8, "bold"), anchor='w').pack(side=LEFT, fill=BOTH) 137 | HoverButton(df, image_path=r'widgets/open_dir_small.png', command=lambda: self._open_dir("data_dir"), 138 | width=18, height=18, bg=self.background, bd=0, 139 | highlightthickness=0, activebackground='#558de8').pack(side=RIGHT) 140 | 141 | self.textboxes["data_dir"] = Text(data_param, height=1, width=20) 142 | self.textboxes["data_dir"].insert('1.0', self.args["data_dir"]) 143 | self.textboxes["data_dir"].pack(side=BOTTOM) 144 | 145 | of = Frame(out_param, bg=self.background, bd=0, relief=SUNKEN, 146 | highlightbackground=self.border_color, highlightthickness=0) 147 | of.pack(side=TOP, fill=BOTH) 148 | 149 | Label(of, text="Processed Data Folder:", bg=self.background, 150 | anchor='w', fg="white", font=("Arial", 8, "bold")).pack(side=LEFT, fill=BOTH) 151 | HoverButton(of, image_path=r'widgets/open_dir_small.png', command=lambda: self._open_dir("prepared_data"), 152 | width=18, height=18, bg=self.background, bd=0, 153 | highlightthickness=0, activebackground='#558de8').pack(side=RIGHT) 154 | 155 | self.textboxes["prepared_data"] = Text(out_param, height=1, width=20) 156 | self.textboxes["prepared_data"].insert('1.0', self.args["prepared_data"]) 157 | self.textboxes["prepared_data"].pack(side=BOTTOM) 158 | 159 | Label(field_param, text="Field:", bg=self.background, 160 | anchor='w', fg="white", font=("Arial", 8, "bold")).pack(side=TOP, fill=BOTH) 161 | self.field_text = StringVar(field_param) 162 | self.field_text.set(list(FIELDS.keys())[0]) 163 | 164 | keys = list(FIELDS.keys()) 165 | field_list = OptionMenu(field_param, self.field_text, *keys) 166 | field_list.configure(highlightthickness=0, width=20, bg='#ffffff') 167 | 
field_list.pack(side=BOTTOM) 168 | 169 | for key in keys: 170 | field_list['menu'].entryconfigure(key, state="normal") 171 | 172 | Label(batch_param, text="Batch Size:", bg=self.background, 173 | anchor='w', fg="white", font=("Arial", 8, "bold")).pack(side=TOP, fill=BOTH) 174 | self.batch_text = StringVar(batch_param) 175 | self.batch_text.set("4") 176 | batch_list = OptionMenu(batch_param, self.batch_text, *[str(2 ** i) for i in range(8)]) 177 | batch_list.configure(highlightthickness=0, width=20, bg='#ffffff') 178 | batch_list.pack(side=BOTTOM) 179 | 180 | HoverButton(param_frame, image_path=r'widgets/prepare.png', command=self._prepare_data, 181 | text='Prepare Data', compound='center', font=("Arial", 10, "bold"), bg=self.background, 182 | bd=0, highlightthickness=0, activebackground=self.background).grid(row=3, column=1, columnspan=2, 183 | padx=20, pady=(20, 0), 184 | sticky='news') 185 | 186 | # Progress Frame 187 | self.progress_label = Label(progress_frame, text="Preparing data:", bg=self.background, 188 | anchor='w', fg="white", font=("Arial", 8, "bold"), bd=0, highlightthickness=0) 189 | self.progress_label.pack(side=TOP, expand=True, fill=X, pady=(10, 5)) 190 | self.progressbar = Progressbar(progress_frame, orient=HORIZONTAL, length=100, mode='determinate') 191 | self.progressbar.pack(side=BOTTOM, expand=True, fill=X) 192 | 193 | # Main Frame 194 | main_frame.columnconfigure(0, weight=1) 195 | main_frame.rowconfigure(0, weight=1) 196 | main_frame.rowconfigure(1, weight=1) 197 | 198 | button_frame = Frame(main_frame, bg=self.background, bd=0, relief=SUNKEN, 199 | highlightbackground=self.border_color, highlightthickness=0) 200 | button_frame.grid(row=0, column=0, sticky='news') 201 | 202 | button_frame.rowconfigure(0, weight=1) 203 | button_frame.columnconfigure(0, weight=1) 204 | button_frame.columnconfigure(1, weight=0) 205 | button_frame.columnconfigure(2, weight=1) 206 | 207 | self.start_button = HoverButton(button_frame, 
image_path=r'widgets/begin.png', command=self._start, 208 | text='Start', compound='center', font=("Arial", 10, "bold"), bg=self.background, 209 | bd=0, highlightthickness=0, activebackground=self.background) 210 | self.stop_button = HoverButton(button_frame, image_path=r'widgets/stop.png', command=self._stop, 211 | text='Stop', compound='center', font=("Arial", 10, "bold"), bg=self.background, 212 | bd=0, highlightthickness=0, activebackground=self.background) 213 | 214 | self.start_button.grid(row=0, column=1) 215 | self.stop_button.grid(row=0, column=1) 216 | self.stop_button.grid_forget() 217 | 218 | self.logger = Logger(main_frame, height=18, bg=self.background, bd=0, relief=SUNKEN) 219 | self.logger.grid(row=1, column=0, sticky='news') 220 | 221 | def _train(self): 222 | train_data = InvoiceData.create_dataset( 223 | field=self.args["field"], 224 | data_dir=os.path.join(self.args["prepared_data"], 'train/'), 225 | batch_size=self.args["batch_size"] 226 | ) 227 | val_data = InvoiceData.create_dataset( 228 | field=self.args["field"], 229 | data_dir=os.path.join(self.args["prepared_data"], 'val/'), 230 | batch_size=self.args["batch_size"] 231 | ) 232 | 233 | restore = None 234 | if os.path.exists(os.path.join('./models/invoicenet/', self.args["field"])): 235 | restore = messagebox.askyesno( 236 | title="Restore", 237 | message="A checkpoint was found! Do you want to restore checkpoint for training?") 238 | 239 | restore = True if restore else False 240 | 241 | model = AttendCopyParse(field=self.args["field"], restore=restore) 242 | 243 | print_interval = 20 244 | early_stop_steps = 0 245 | best = float("inf") 246 | 247 | train_iter = iter(train_data) 248 | val_iter = iter(val_data) 249 | 250 | self.logger.log("Initializing training!") 251 | start = time.time() 252 | step = 0 253 | 254 | while True: 255 | 256 | try: 257 | train_loss = model.train_step(next(train_iter)) 258 | except StopIteration: 259 | self.logger.log("Couldn't find any training data! 
Have you prepared your training data?") 260 | self.logger.log("Terminating...") 261 | self.thread.stop() 262 | break 263 | 264 | if not np.isfinite(train_loss): 265 | raise ValueError("NaN loss") 266 | 267 | if step % print_interval == 0: 268 | took = time.time() - start 269 | 270 | try: 271 | val_loss = model.val_step(next(val_iter)) 272 | except StopIteration: 273 | self.logger.log("Couldn't find any validation data! Have you prepared your training data?") 274 | self.logger.log("Terminating...") 275 | self.thread.stop() 276 | break 277 | 278 | self.logger.log("[step: %d | %.2f steps/s]: train loss: %.4f val loss: %.4f" % ( 279 | step, (step + 1) / took, train_loss, val_loss)) 280 | if not np.isfinite(val_loss): 281 | self.logger.log("ERROR: NaN loss") 282 | self.thread.stop() 283 | if val_loss < best: 284 | early_stop_steps = 0 285 | best = val_loss 286 | model.save("best") 287 | else: 288 | early_stop_steps += print_interval 289 | if early_stop_steps >= 500: 290 | self.logger.log("Validation loss has not improved for 500 steps") 291 | self.thread.stop() 292 | 293 | step += 1 294 | 295 | if self.thread.stopped(): 296 | self.logger.log("Training terminated!") 297 | break 298 | 299 | self.running = False 300 | self.stop_button.grid_forget() 301 | self.start_button.grid(row=0, column=1) 302 | 303 | def _get_inputs(self): 304 | self.args["field"] = self.field_text.get() 305 | self.args["batch_size"] = int(self.batch_text.get()) 306 | self.args["data_dir"] = self.textboxes["data_dir"].get("1.0", 'end-1c') 307 | self.args["prepared_data"] = self.textboxes["prepared_data"].get("1.0", 'end-1c') 308 | if not self.args["prepared_data"].endswith('/'): 309 | self.args["prepared_data"] += '/' 310 | if self.args["data_dir"] == '': 311 | return 312 | if not self.args["data_dir"].endswith('/'): 313 | self.args["data_dir"] += '/' 314 | 315 | def _start(self): 316 | self._get_inputs() 317 | 318 | if not os.path.exists(self.args["prepared_data"]): 319 | 
messagebox.showerror("Error", "Prepared data folder does not exist!") 320 | return 321 | 322 | files = glob.glob(self.args["prepared_data"] + "**/*.json", recursive=True) 323 | if not files: 324 | messagebox.showerror("Error", 325 | "Could not find processed data in \"{}\". Did you prepare training data?".format( 326 | self.args["prepared_data"])) 327 | return 328 | if not self.running: 329 | self.running = True 330 | self.thread = StoppableThread(target=self._train) 331 | self.thread.daemon = True 332 | self.thread.start() 333 | self.start_button.grid_forget() 334 | self.stop_button.grid(row=0, column=1) 335 | 336 | def _stop(self): 337 | if self.running: 338 | self.thread.stop() 339 | self.running = False 340 | self.logger.log("Stopping training...") 341 | 342 | def _open_dir(self, key): 343 | dir_name = filedialog.askdirectory(initialdir='.', title="Select Directory Containing Invoices") 344 | if not dir_name: 345 | return 346 | self.args[key] = dir_name 347 | self.textboxes[key].delete('1.0', END) 348 | self.textboxes[key].insert('1.0', self.args[key]) 349 | 350 | def _prepare_data(self): 351 | self._get_inputs() 352 | 353 | if self.args["data_dir"] == '': 354 | messagebox.showerror("Error", "Data folder does not exist!") 355 | return 356 | 357 | if not os.path.exists(self.args["data_dir"]): 358 | messagebox.showerror("Error", "Data folder does not exist!") 359 | return 360 | 361 | self.progressbar["value"] = 0 362 | self.progress_label.configure(text="Preparing Data:") 363 | 364 | os.makedirs(os.path.join(self.args["prepared_data"], 'train'), exist_ok=True) 365 | os.makedirs(os.path.join(self.args["prepared_data"], 'val'), exist_ok=True) 366 | 367 | filenames = [os.path.abspath(f) for f in glob.glob(self.args["data_dir"] + "**/*.pdf", recursive=True)] 368 | random.shuffle(filenames) 369 | 370 | idx = int(len(filenames) * 0.2) 371 | train_files = filenames[idx:] 372 | val_files = filenames[:idx] 373 | 374 | self.logger.log("Total: {}".format(len(filenames))) 
375 | self.logger.log("Training: {}".format(len(train_files))) 376 | self.logger.log("Validation: {}".format(len(val_files))) 377 | 378 | if len(train_files) == 0 or len(val_files) == 0: 379 | messagebox.showwarning("Warning", "Training data is not enough to create training/validation splits") 380 | 381 | total_samples = len(filenames) 382 | sample_idx = 0 383 | for phase, filenames in [('train', train_files), ('val', val_files)]: 384 | self.logger.log("Preparing {} data...".format(phase)) 385 | for filename in tqdm(filenames): 386 | # try: 387 | page = pdf2image.convert_from_path(filename)[0] 388 | page.save(os.path.join(self.args["prepared_data"], phase, os.path.basename(filename)[:-3] + 'png')) 389 | 390 | height = page.size[1] 391 | width = page.size[0] 392 | 393 | ngrams = util.create_ngrams(page, height=height, width=width) 394 | for ngram in ngrams: 395 | if "amount" in ngram["parses"]: 396 | ngram["parses"]["amount"] = util.normalize(ngram["parses"]["amount"], key="amount") 397 | if "date" in ngram["parses"]: 398 | ngram["parses"]["date"] = util.normalize(ngram["parses"]["date"], key="date") 399 | 400 | with open(filename[:-3] + 'json', 'r') as fp: 401 | labels = simplejson.loads(fp.read()) 402 | 403 | fields = {} 404 | for field in FIELDS: 405 | if field in labels: 406 | if FIELDS[field] == FIELD_TYPES["amount"]: 407 | fields[field] = util.normalize(labels[field], key="amount") 408 | elif FIELDS[field] == FIELD_TYPES["date"]: 409 | fields[field] = util.normalize(labels[field], key="date") 410 | else: 411 | fields[field] = labels[field] 412 | else: 413 | fields[field] = '' 414 | 415 | data = { 416 | "fields": fields, 417 | "nGrams": ngrams, 418 | "height": height, 419 | "width": width, 420 | "filename": os.path.abspath( 421 | os.path.join(self.args["prepared_data"], phase, os.path.basename(filename)[:-3] + 'png')) 422 | } 423 | 424 | with open(os.path.join(self.args["prepared_data"], phase, os.path.basename(filename)[:-3] + 'json'), 425 | 'w') as fp: 426 | 
fp.write(simplejson.dumps(data, indent=2)) 427 | 428 | # except Exception as exp: 429 | # self.logger.log("Skipping {} : {}".format(filename, exp)) 430 | 431 | sample_idx += 1 432 | self.progress_label.configure(text="Preparing data [{}/{}]:".format(sample_idx, total_samples)) 433 | self.progressbar["value"] = (sample_idx / total_samples) * 100 434 | self.progressbar.update() 435 | 436 | self.progress_label.configure(text="Completed!") 437 | self.progressbar["value"] = 100 438 | self.progressbar.update() 439 | self.logger.log("Prepared data stored in '{}'".format(self.args["prepared_data"])) 440 | -------------------------------------------------------------------------------- /invoicenet/gui/viewer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
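# Note (illustrative sketch, not part of the original file): the zoom handlers
# in this viewer step the scale by 0.1 within the range 0.1 to 2.5. Repeated
# float addition drifts (0.1 + 0.2 is not exactly 0.3), so one possible way to
# keep the bounds reliable is to round and clamp each step. The helper names
# `step_zoom` and `zoom_percent` are hypothetical and exist only for this example.

```python
def step_zoom(scale, delta, lo=0.1, hi=2.5):
    """Return scale + delta, rounded to one decimal and clamped to [lo, hi]."""
    stepped = round(scale + delta, 1)
    return min(hi, max(lo, stepped))


def zoom_percent(scale):
    """Percentage for a zoom label; round() avoids truncating e.g. 229.99... to 229."""
    return round(scale * 100)


if __name__ == "__main__":
    scale = 1.0
    for _ in range(20):          # more steps than needed to reach the upper bound
        scale = step_zoom(scale, 0.1)
    print(scale, zoom_percent(scale))
```

# A handler could then assign `self.scale = step_zoom(self.scale, 0.1)` and skip
# the redraw when the value did not change.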
20 | 21 | import numpy as np 22 | import pdfplumber 23 | from tkinter import * 24 | from tkinter import simpledialog, messagebox 25 | 26 | from .. import FIELDS 27 | from .custom_widgets import HoverButton, DisplayCanvas 28 | 29 | 30 | class PDFViewer(Frame): 31 | 32 | def __init__(self, master=None, background='#303030', highlight='#558de8', **kw): 33 | Frame.__init__(self, master, **kw) 34 | self.background = background 35 | self.highlight = highlight 36 | self.pdf = None 37 | self.page = None 38 | self.total_pages = 0 39 | self.pageidx = 0 40 | self.scale = 1.0 41 | self.rotate = 0 42 | self.field_colors = {field: tuple(np.random.choice(range(256), size=3)) + (100,) for field in FIELDS.keys()} 43 | self._init_ui() 44 | 45 | def _init_ui(self): 46 | self.columnconfigure(0, weight=1) 47 | self.rowconfigure(0, weight=0) 48 | self.rowconfigure(1, weight=1) 49 | 50 | page_tools = Frame(self, bg=self.background, bd=0, relief=SUNKEN) 51 | canvas_frame = Frame(self, bg=self.background, bd=1, relief=SUNKEN) 52 | 53 | canvas_frame.grid(row=1, column=0, sticky='news') 54 | page_tools.grid(row=0, column=0, sticky='news') 55 | 56 | # Display Canvas 57 | self.canvas = DisplayCanvas(canvas_frame, cursor='cross') 58 | self.canvas.pack(fill=BOTH, expand=True) 59 | 60 | # Page Tools 61 | page_tools.rowconfigure(0, weight=1) 62 | page_tools.columnconfigure(0, weight=1) 63 | page_tools.columnconfigure(1, weight=0) 64 | page_tools.columnconfigure(2, weight=2) 65 | page_tools.columnconfigure(3, weight=0) 66 | page_tools.columnconfigure(4, weight=1) 67 | 68 | nav_frame = Frame(page_tools, bg=self.background, bd=0, relief=SUNKEN) 69 | nav_frame.grid(row=0, column=1, sticky='ns') 70 | 71 | HoverButton(nav_frame, image_path=r'widgets/first.png', 72 | command=self._first_page, bg=self.background, bd=0, 73 | highlightthickness=0, activebackground=self.highlight).pack(side=LEFT, expand=True) 74 | HoverButton(nav_frame, image_path=r'widgets/prev.png', 75 | command=self._prev_page, 
bg=self.background, bd=0, 76 | highlightthickness=0, activebackground=self.highlight).pack(side=LEFT, expand=True) 77 | 78 | self.page_label = Label(nav_frame, bg=self.background, bd=0, fg='white', font=('Arial', 8), 79 | text="Page {} of {}".format(self.pageidx, self.total_pages)) 80 | self.page_label.pack(side=LEFT, expand=True) 81 | 82 | HoverButton(nav_frame, image_path=r'widgets/next.png', 83 | command=self._next_page, bg=self.background, bd=0, 84 | highlightthickness=0, activebackground=self.highlight).pack(side=LEFT, expand=True) 85 | HoverButton(nav_frame, image_path=r'widgets/last.png', 86 | command=self._last_page, bg=self.background, bd=0, 87 | highlightthickness=0, activebackground=self.highlight).pack(side=LEFT, expand=True) 88 | 89 | zoom_frame = Frame(page_tools, bg=self.background, bd=0, relief=SUNKEN) 90 | zoom_frame.grid(row=0, column=3, sticky='ns') 91 | 92 | HoverButton(zoom_frame, image_path=r'widgets/rotate.png', 93 | command=self._rotate, bg=self.background, bd=0, 94 | highlightthickness=0, activebackground=self.highlight).pack(side=RIGHT, expand=True) 95 | HoverButton(zoom_frame, image_path=r'widgets/fullscreen.png', 96 | command=self._fit_to_screen, bg=self.background, bd=0, 97 | highlightthickness=0, activebackground=self.highlight).pack(side=RIGHT, expand=True) 98 | 99 | self.zoom_label = Label(zoom_frame, bg=self.background, bd=0, fg='white', font=('Arial', 8), 100 | text="Zoom {}%".format(int(self.scale * 100))) 101 | self.zoom_label.pack(side=RIGHT, expand=True) 102 | 103 | HoverButton(zoom_frame, image_path=r'widgets/zoomout.png', 104 | command=self._zoom_out, bg=self.background, bd=0, 105 | highlightthickness=0, activebackground=self.highlight).pack(side=RIGHT, expand=True) 106 | HoverButton(zoom_frame, image_path=r'widgets/zoomin.png', 107 | command=self._zoom_in, bg=self.background, bd=0, 108 | highlightthickness=0, activebackground=self.highlight).pack(side=RIGHT, expand=True) 109 | 110 | def reset(self): 111 | self.canvas.reset() 
112 | self.pdf = None 113 | self.page = None 114 | self.total_pages = 0 115 | self.pageidx = 0 116 | self.scale = 1.0 117 | self.rotate = 0 118 | self.page_label.configure(text="Page {} of {}".format(self.pageidx, self.total_pages)) 119 | self.zoom_label.configure(text="Zoom {}%".format(int(self.scale * 100))) 120 | 121 | def _zoom_in(self): 122 | if self.pdf is None or self.scale >= 2.5: 123 | return 124 | self.scale = round(self.scale + 0.1, 1) 125 | self._update_page() 126 | 127 | def _zoom_out(self): 128 | if self.pdf is None or self.scale <= 0.1: 129 | return 130 | self.scale = round(self.scale - 0.1, 1) 131 | self._update_page() 132 | 133 | def _fit_to_screen(self): 134 | if self.pdf is None or self.scale == 1.0: 135 | return 136 | self.scale = 1.0 137 | self._update_page() 138 | 139 | def _rotate(self): 140 | if self.pdf is None: 141 | return 142 | self.rotate = (self.rotate - 90) % 360 143 | self._update_page() 144 | 145 | def _next_page(self): 146 | if self.pdf is None or self.pageidx == self.total_pages: 147 | return 148 | self.pageidx += 1 149 | self._update_page() 150 | 151 | def _prev_page(self): 152 | if self.pdf is None or self.pageidx == 1: 153 | return 154 | self.pageidx -= 1 155 | self._update_page() 156 | 157 | def _last_page(self): 158 | if self.pdf is None or self.pageidx == self.total_pages: 159 | return 160 | self.pageidx = self.total_pages 161 | self._update_page() 162 | 163 | def _first_page(self): 164 | if self.pdf is None or self.pageidx == 1: 165 | return 166 | self.pageidx = 1 167 | self._update_page() 168 | 169 | def _update_page(self): 170 | page = self.pdf.pages[self.pageidx - 1] 171 | self.page = page.to_image(resolution=int(self.scale * 80)) 172 | image = self.page.original.rotate(self.rotate) 173 | self.canvas.update_image(image) 174 | self.page_label.configure(text="Page {} of {}".format(self.pageidx, self.total_pages)) 175 | self.zoom_label.configure(text="Zoom {}%".format(round(self.scale * 100))) 176 | 177 | def _reproject_bbox(self, bbox): 178 | bbox = 
[self.page.decimalize(x) for x in bbox] 179 | x0, y0, x1, y1 = bbox 180 | px0, py0 = self.page.page.bbox[:2] 181 | rx0, ry0 = self.page.root.bbox[:2] 182 | _x0 = (x0 / self.page.scale) - rx0 + px0 183 | _y0 = (y0 / self.page.scale) - ry0 + py0 184 | _x1 = (x1 / self.page.scale) - rx0 + px0 185 | _y1 = (y1 / self.page.scale) - ry0 + py0 186 | return [_x0, _y0, _x1, _y1] 187 | 188 | def display_pdf(self, pdf: pdfplumber.PDF): 189 | self.clear() 190 | try: 191 | self.pdf = pdf 192 | self.total_pages = len(self.pdf.pages) 193 | self.pageidx = 1 194 | self.scale = 1.0 195 | self.rotate = 0 196 | self._update_page() 197 | except (IndexError, IOError, TypeError): 198 | messagebox.showerror("Error", "Could not display PDF!") 199 | 200 | def search_text(self, text=None, fill=(0, 0, 255, 50)): 201 | if self.pdf is None: 202 | return 203 | 204 | if text is None: 205 | text = simpledialog.askstring('Search Text', 'Enter text to search:') 206 | if text == '' or text is None: 207 | return 208 | 209 | page = self.pdf.pages[self.pageidx - 1] 210 | image = page.to_image(resolution=int(self.scale * 80)) 211 | words = [w for w in page.extract_words() if text.lower() in w['text'].lower()] 212 | 213 | if words: 214 | image.draw_rects(words, fill=fill, stroke=(0, 0, 0, 200)) 215 | image = image.annotated.rotate(self.rotate) 216 | self.canvas.update_image(image) 217 | 218 | def label(self, labels=None): 219 | if self.pdf is None or labels is None: 220 | return 221 | 222 | page = self.pdf.pages[self.pageidx - 1] 223 | image = page.to_image(resolution=int(self.scale * 80)) 224 | 225 | for key in labels.keys(): 226 | if labels[key]: 227 | words = [w for w in page.extract_words() if labels[key].strip().lower() in w['text'].lower()] 228 | if words: 229 | image.draw_rects(words, fill=self.field_colors[key], stroke=(0, 0, 0, 200)) 230 | 231 | image = image.annotated.rotate(self.rotate) 232 | self.canvas.update_image(image) 233 | 234 | def extract_text(self): 235 | if self.pdf is None: 236 | 
return 237 | rect = self.canvas.get_rect() 238 | if rect is None: 239 | return 240 | self.clear() 241 | rect = self._reproject_bbox(rect) 242 | page = self.pdf.pages[self.pageidx - 1] 243 | words = page.extract_words() 244 | min_x = 1000000 245 | bbox = None 246 | for word in words: 247 | diff = abs(float(word['x0'] - rect[0])) + abs(float(word['top'] - rect[1])) \ 248 | + abs(float(word['x1'] - rect[2])) + abs(float(word['bottom'] - rect[3])) 249 | if diff < min_x: 250 | min_x = diff 251 | bbox = word 252 | 253 | if bbox is None: 254 | messagebox.showerror("Error", "Could not extract text! Try after running OCR on this invoice.") 255 | return 256 | 257 | image = page.to_image(resolution=int(self.scale * 80)) 258 | image.draw_rect(bbox) 259 | image = image.annotated.rotate(self.rotate) 260 | self.canvas.update_image(image) 261 | simpledialog.askstring("Extract Text", "Text Extracted:", initialvalue=bbox['text']) 262 | 263 | def clear(self): 264 | if self.pdf is None: 265 | return 266 | self.canvas.clear() 267 | self._update_page() 268 | -------------------------------------------------------------------------------- /invoicenet/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 
12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | -------------------------------------------------------------------------------- /invoicenet/parsing/amount_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
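# Note (illustrative sketch, not part of the original file): this generator
# writes tab-separated (source, target) pairs in which the source may carry a
# thousands comma and the target is the plain number with two decimals. The
# deterministic helpers below show the shape of one such pair; the names
# `group_thousands` and `make_pair` are hypothetical, and the random branching
# of the real script is deliberately left out.

```python
def group_thousands(number):
    """Insert a comma before the last three digits, e.g. 12345 -> '12,345'."""
    digits = str(number)
    if len(digits) > 3:
        digits = digits[:-3] + "," + digits[-3:]
    return digits


def make_pair(number, cents, grouped=True):
    """Build one (source, target) sample; the target is always plain 'N.CC'."""
    frac = "{:02d}".format(cents)            # zero-pad single-digit cents
    whole = group_thousands(number) if grouped else str(number)
    return whole + "." + frac, str(number) + "." + frac


if __name__ == "__main__":
    source, target = make_pair(12345, 7)
    print(source + "\t" + target)            # one line of the TSV file
```

# Keeping the formatting in a pure function like this makes it possible to
# check individual pairs without generating the whole dataset.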
20 | 21 | import os 22 | import random 23 | from tqdm import tqdm 24 | 25 | 26 | def main(): 27 | train_numbers = random.sample(range(1000, 10000), 5000) 28 | train_numbers += random.sample(range(1000, 10000), 5000) 29 | train_numbers += random.sample(range(1000, 10000), 5000) 30 | train_numbers += random.sample(range(10000, 100000), 5000) 31 | train_numbers += random.sample(range(100000, 1000000), 5000) 32 | for idx in range(1, 16): 33 | train_numbers += random.sample(range(0, 1000), 1000) 34 | 35 | test_numbers = random.sample(range(0, 1000), 1000) 36 | test_numbers += random.sample(range(0, 1000), 1000) 37 | test_numbers += random.sample(range(0, 10000), 1000) 38 | test_numbers += random.sample(range(0, 100000), 500) 39 | test_numbers += random.sample(range(0, 1000000), 500) 40 | 41 | random.shuffle(train_numbers) 42 | random.shuffle(test_numbers) 43 | 44 | os.makedirs(os.path.join("data", "amount"), exist_ok=True) 45 | 46 | for phase, numbers in [("train", train_numbers), ("valid", test_numbers)]: 47 | 48 | with open(os.path.join("data", "amount", "{}.tsv".format(phase)), "w") as fp: 49 | 50 | for number in tqdm(numbers, total=len(numbers)): 51 | 52 | if random.sample(range(1, 10000), 1)[0] % 20 == 0: 53 | if random.sample(range(1, 10000), 1)[0] % 10 == 0: 54 | fp.write(str(number) + "\t" + str(number) + ".00" + "\n") 55 | else: 56 | exp = "{}".format(random.sample(range(0, 100), 1)[0]) 57 | if len(exp) == 1: 58 | exp = "0" + exp 59 | fp.write(str(number) + "." + exp + "\t" + str(number) + "." 
+ exp + "\n") 60 | 61 | else: 62 | fnum = str(number) 63 | if len(fnum) > 3: 64 | fnum = fnum[:-3] + "," + fnum[-3:] 65 | if len(fnum) == 7: 66 | if random.sample(range(1, 10000), 1)[0] % 2 == 0: 67 | fnum = fnum[0] + ',' + fnum[1:] 68 | 69 | if random.sample(range(1, 10000), 1)[0] % 20 == 0: 70 | fp.write(fnum + "\t" + str(number) + ".00" + "\n") 71 | else: 72 | exp = "{}".format(random.sample(range(0, 100), 1)[0]) 73 | if len(exp) == 1: 74 | exp = "0" + exp 75 | fp.write(fnum + "." + exp + "\t" + str(number) + "." + exp + "\n") 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /invoicenet/parsing/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | import random 23 | from decimal import Decimal 24 | 25 | import tensorflow as tf 26 | 27 | from ..common.data import Data, UnkDict 28 | from ..acp.data import InvoiceData 29 | 30 | 31 | class ParseData(Data): 32 | chars = InvoiceData.chars 33 | output_dict = UnkDict(chars) 34 | n_output = len(output_dict) 35 | pad_idx = InvoiceData.pad_idx 36 | eos_idx = InvoiceData.eos_idx 37 | unk_idx = InvoiceData.unk_idx 38 | input_length = InvoiceData.seq_in 39 | 40 | def __init__(self, samples_fname, output_length): 41 | self.samples_fname = samples_fname 42 | self.output_length = output_length 43 | 44 | def types(self): 45 | # source, target 46 | return tf.int32, tf.int32 47 | 48 | def shapes(self): 49 | # source, target 50 | return self.input_length, self.output_length 51 | 52 | def array_to_str(self, arr): 53 | """ 54 | :param arr: (bs, seq) int32 55 | """ 56 | strs = [] 57 | for r in arr: 58 | s = "" 59 | for c in r: 60 | if c == self.eos_idx: 61 | break 62 | else: 63 | s += self.output_dict.idx2key(c) 64 | strs.append(s) 65 | return strs 66 | 67 | def sample_generator(self): 68 | with open(self.samples_fname) as samples_file: 69 | samples = samples_file.readlines() 70 | 71 | while True: 72 | for s in random.sample(samples, len(samples)): 73 | source, target = s.strip().split("\t") 74 | yield (InvoiceData.encode_sequence(source, self.input_length), 75 | InvoiceData.encode_sequence(target, self.output_length)) 76 | 77 | @staticmethod 78 | def create_dataset(path, output_length, batch_size): 79 | data = ParseData(path, output_length) 80 | 81 | def _transform(inputs, targets): 82 | return ( 83 | (tf.one_hot(inputs, ParseData.n_output), 84 | tf.zeros( 85 | (128,), 86 | 
dtype=tf.float32, 87 | name="empty_context") 88 | ), targets) 89 | 90 | return tf.data.Dataset.from_generator( 91 | data.sample_generator, 92 | data.types(), 93 | data.shapes() 94 | ).map(_transform) \ 95 | .repeat(-1) \ 96 | .batch(batch_size) \ 97 | .prefetch(16) 98 | -------------------------------------------------------------------------------- /invoicenet/parsing/date_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | import os 22 | import random 23 | import argparse 24 | import datetime 25 | from tqdm import tqdm 26 | 27 | years = list(range(2000, 2100)) 28 | months = list(range(1, 13)) 29 | days = list(range(1, 32)) 30 | 31 | year_formats = ['%y', '%Y'] 32 | month_formats = ['%m', '%b', '%B'] 33 | day_formats = ['%d'] 34 | 35 | date_formats = [ 36 | 'day/month/year', 37 | 'month/day/year', 38 | 'year/month/day', 39 | 'month/year/day', 40 | 'year/day/month', 41 | 'day/year/month', 42 | 'day-month-year', 43 | 'month-day-year', 44 | 'year-month-day', 45 | 'month-year-day', 46 | 'year-day-month', 47 | 'day-year-month', 48 | 'day.month.year', 49 | 'month.day.year', 50 | 'year.month.day', 51 | 'month.year.day', 52 | 'year.day.month', 53 | 'day.year.month', 54 | 'day,month,year', 55 | 'month,day,year', 56 | 'year,month,day', 57 | 'month,year,day', 58 | 'year,day,month', 59 | 'day,year,month', 60 | 'day.month,year', 61 | 'day,month.year', 62 | 'month.day,year', 63 | 'month,day.year', 64 | 'month.day,year', 65 | 'day\\month\\year', 66 | 'month\\day\\year', 67 | 'year\\month\\day', 68 | 'month\\year\\day', 69 | 'year\\day\\month', 70 | 'day\\year\\month', 71 | 'day month, year', 72 | 'month day, year', 73 | 'year, month day', 74 | 'year, day month', 75 | 'year day month', 76 | 'day month year', 77 | 'month day year' 78 | ] 79 | 80 | 81 | def main(): 82 | ap = argparse.ArgumentParser() 83 | 84 | ap.add_argument("--train_size", type=int, default=100000, 85 | help="number of training samples to generate") 86 | ap.add_argument("--test_size", type=int, default=20000, 87 | help="number of test samples to generate") 88 | 89 | args = ap.parse_args() 90 | 91 | train_days = random.choices(days, k=args.train_size) 92 | train_months = random.choices(months, k=args.train_size) 93 | train_years = random.choices(years, k=args.train_size) 94 | train_day_formats = random.choices(day_formats, k=args.train_size) 95 | train_month_formats = random.choices(month_formats, k=args.train_size) 96 | 
train_year_formats = random.choices(year_formats, k=args.train_size) 97 | train_date_formats = random.choices(date_formats, k=args.train_size) 98 | 99 | test_days = random.choices(days, k=args.test_size) 100 | test_months = random.choices(months, k=args.test_size) 101 | test_years = random.choices(years, k=args.test_size) 102 | test_day_formats = random.choices(day_formats, k=args.test_size) 103 | test_month_formats = random.choices(month_formats, k=args.test_size) 104 | test_year_formats = random.choices(year_formats, k=args.test_size) 105 | test_date_formats = random.choices(date_formats, k=args.test_size) 106 | 107 | os.makedirs(os.path.join("data", "date"), exist_ok=True) 108 | 109 | with open(os.path.join("data", "date", "train.tsv"), "w") as fp: 110 | for sample in tqdm(zip(train_date_formats, 111 | train_year_formats, train_month_formats, train_day_formats, 112 | train_years, train_months, train_days), total=args.train_size): 113 | if sample[5] not in [1, 3, 5, 7, 8, 10, 12]: 114 | if sample[5] == 2: 115 | day = random.sample(range(1, 29), k=1)[0] 116 | else: 117 | day = random.sample(range(1, 31), k=1)[0] 118 | else: 119 | day = sample[6] 120 | date = datetime.date(sample[4], sample[5], day) 121 | fp.write(date.strftime( 122 | sample[0].replace('year', sample[1]).replace('month', sample[2]).replace('day', sample[3]))) 123 | fp.write('\t') 124 | fp.write(date.strftime('%m-%d-%Y')) 125 | fp.write('\n') 126 | 127 | with open(os.path.join("data", "date", "valid.tsv"), "w") as fp: 128 | for sample in tqdm(zip(test_date_formats, 129 | test_year_formats, test_month_formats, test_day_formats, 130 | test_years, test_months, test_days), total=args.test_size): 131 | if sample[5] not in [1, 3, 5, 7, 8, 10, 12]: 132 | if sample[5] == 2: 133 | day = random.sample(range(1, 29), k=1)[0] 134 | else: 135 | day = random.sample(range(1, 31), k=1)[0] 136 | else: 137 | day = sample[6] 138 | date = datetime.date(sample[4], sample[5], day) 139 | fp.write(date.strftime( 140 | 
sample[0].replace('year', sample[1]).replace('month', sample[2]).replace('day', sample[3]))) 141 | fp.write('\t') 142 | fp.write(date.strftime('%m-%d-%Y')) 143 | fp.write('\n') 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /invoicenet/parsing/parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
21 | 22 | import tensorflow as tf 23 | 24 | from ..common.model import Model 25 | from .parsers import DateParser, AmountParser 26 | from .data import ParseData 27 | 28 | 29 | class Parser(Model): 30 | 31 | def __init__(self, field, restore=False): 32 | self.type = field 33 | self.continue_from = './models/parsers/{}/best'.format(self.type) if restore else None 34 | 35 | self.parser = {'amount': AmountParser(), 'date': DateParser()}[self.type] 36 | 37 | self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy( 38 | from_logits=True, 39 | reduction=tf.keras.losses.Reduction.NONE) 40 | 41 | self.optimizer = tf.keras.optimizers.Nadam(learning_rate=1e-4) 42 | 43 | self.parser.compile(self.optimizer) 44 | 45 | self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.parser) 46 | 47 | if self.continue_from: 48 | print("Restoring " + self.continue_from + "...") 49 | self.checkpoint.read(self.continue_from) 50 | 51 | def loss_func(self, y_true, y_pred): 52 | mask = tf.logical_not(tf.equal(y_true, ParseData.pad_idx)) 53 | label_cross_entropy = tf.reduce_mean( 54 | self.loss_object(y_true, y_pred) * tf.cast(mask, dtype=tf.float32)) / tf.math.log(2.) 
55 | return label_cross_entropy 56 | 57 | @tf.function 58 | def train_step(self, inputs): 59 | inputs, targets = inputs 60 | 61 | with tf.GradientTape() as tape: 62 | predictions = self.parser(inputs, training=True) 63 | loss = self.loss_func(targets, predictions) 64 | 65 | gradients = tape.gradient(loss, self.parser.trainable_variables) 66 | self.optimizer.apply_gradients(zip(gradients, self.parser.trainable_variables)) 67 | return loss 68 | 69 | @tf.function 70 | def val_step(self, inputs): 71 | inputs, targets = inputs 72 | predictions = self.parser(inputs, training=False) 73 | loss = self.loss_func(targets, predictions) 74 | return loss 75 | 76 | def save(self, name): 77 | self.checkpoint.write(file_prefix="./models/parsers/%s/%s" % (self.type, name)) 78 | 79 | def load(self, name): 80 | self.checkpoint.read(name) 81 | -------------------------------------------------------------------------------- /invoicenet/parsing/parsers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Tradeshift 2 | # Copyright (c) 2020 Sarthak Mittal 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | 22 | import os 23 | import tensorflow as tf 24 | 25 | from ..acp.data import InvoiceData 26 | 27 | 28 | class Parser(tf.keras.Model): 29 | def __init__(self): 30 | super(Parser, self).__init__() 31 | 32 | def restore(self): 33 | """ 34 | Must return a tuple of (scope, restore_file_path). 35 | """ 36 | raise NotImplementedError() 37 | 38 | 39 | class NoOpParser(Parser): 40 | def __init__(self): 41 | super(NoOpParser, self).__init__() 42 | 43 | def restore(self): 44 | return None 45 | 46 | def call(self, inputs, training=None, mask=None): 47 | x, context = inputs 48 | return x 49 | 50 | 51 | class OptionalParser(Parser): 52 | 53 | def __init__(self, delegate: Parser, seq_out): 54 | super(OptionalParser, self).__init__() 55 | self.seq_out = seq_out 56 | self.delegate = delegate 57 | self.dense_1 = tf.keras.layers.Dense(1) 58 | 59 | def restore(self): 60 | return self.delegate.restore() 61 | 62 | def call(self, inputs, training=None, mask=None): 63 | x, context = inputs 64 | parsed = self.delegate(inputs, training, mask) 65 | empty_answer = tf.fill([tf.shape(x)[0], self.seq_out], InvoiceData.eos_idx) 66 | empty_answer = tf.one_hot(empty_answer, InvoiceData.n_output) # (bs, seq_out, n_out) 67 | logit_empty = self.dense_1(context) # (bs, 1) 68 | return parsed + tf.expand_dims(logit_empty, axis=2) * empty_answer 69 | 70 | 71 | class AmountParser(Parser): 72 | """ 73 | You should pre-train this parser to parse amount otherwise it's hard to learn jointly. 
74 | """ 75 | def __init__(self): 76 | super(AmountParser, self).__init__() 77 | os.makedirs(r"./models/parsers/amount", exist_ok=True) 78 | 79 | self.encoder = tf.keras.layers.Bidirectional( 80 | tf.keras.layers.LSTM(128, unit_forget_bias=True, return_sequences=True), 81 | name="encoder") 82 | self.decoder = tf.keras.layers.LSTM( 83 | 128, unit_forget_bias=True, return_sequences=True, name="decoder") 84 | 85 | self.encoder_dense = tf.keras.layers.Dense(128) 86 | self.decoder_dense = tf.keras.layers.Dense(128) 87 | self.attention_dense = tf.keras.layers.Dense(1) 88 | self.prob_dense = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid) 89 | self.gen_dense = tf.keras.layers.Dense(InvoiceData.n_output) 90 | 91 | def restore(self): 92 | return r"./models/parsers/amount/best" 93 | 94 | def call(self, inputs, training=None, mask=None): 95 | x, context = inputs 96 | 97 | # encoder 98 | h_in = self.encoder(x) 99 | h_in = tf.expand_dims(h_in, axis=2) # (bs, seq_in, 1, 256) 100 | 101 | # decoder 102 | out_input = tf.zeros((tf.shape(x)[0], InvoiceData.seq_amount, 1)) 103 | h_out = self.decoder(out_input) 104 | h_out = tf.expand_dims(h_out, axis=1) # (bs, 1, seq_out, 128) 105 | 106 | # Bahdanau attention 107 | att = tf.math.tanh(self.decoder_dense(h_out) + self.encoder_dense(h_in)) # (bs, seq_in, seq_out, 128) 108 | att = self.attention_dense(att) # (bs, seq_in, seq_out, 1) 109 | att = tf.math.softmax(att, axis=1) # (bs, seq_in, seq_out, 1) 110 | 111 | attended_h = tf.reduce_sum(att * h_in, axis=1) # (bs, seq_out, 256) 112 | 113 | p_gen = self.prob_dense(attended_h) # (bs, seq_out, 1), sigmoid gate in [0, 1] 114 | p_copy = (1 - p_gen) 115 | 116 | # Generate 117 | gen = self.gen_dense(attended_h) # (bs, seq_out, n_out) 118 | 119 | # Copy 120 | copy = tf.math.log(tf.reduce_sum(att * tf.expand_dims(x, axis=2), axis=1) + 1e-8) # (bs, seq_out, n_out) 121 | 122 | output_logits = p_copy * copy + p_gen * gen 123 | return output_logits 124 | 125 | 126 | class DateParser(Parser): 127 | """ 128 |
You should pre-train this parser to parse dates otherwise it's hard to learn jointly. 129 | """ 130 | def __init__(self): 131 | super(DateParser, self).__init__() 132 | os.makedirs(r"./models/parsers/date", exist_ok=True) 133 | 134 | self.conv_block = tf.keras.Sequential() 135 | for _ in range(4): 136 | self.conv_block.add(tf.keras.layers.Conv1D(128, 3, padding='same', activation=tf.keras.activations.relu)) 137 | self.conv_block.add(tf.keras.layers.MaxPool1D(2, 2)) 138 | 139 | self.dense_block = tf.keras.Sequential() 140 | for _ in range(3): 141 | self.dense_block.add(tf.keras.layers.Dense(256, activation=tf.keras.activations.relu)) 142 | 143 | self.dropout = tf.keras.layers.Dropout(0.5) 144 | self.dense_out = tf.keras.layers.Dense(InvoiceData.seq_date * InvoiceData.n_output) 145 | 146 | def restore(self): 147 | return r"./models/parsers/date/best" 148 | 149 | def call(self, inputs, training=None, mask=None): 150 | x, context = inputs 151 | x = self.conv_block(x) # (bs, 8, 128) 152 | x = tf.reduce_sum(x, axis=1) # (bs, 128) 153 | x = tf.concat([x, context], axis=1) # (bs, 256) 154 | x = self.dense_block(x) # (bs, 256) 155 | x = self.dropout(x, training=training) # (bs, 256) 156 | x = self.dense_out(x) # (bs, seq_out * n_out) 157 | return tf.reshape(x, (-1, InvoiceData.seq_date, InvoiceData.n_output)) 158 | -------------------------------------------------------------------------------- /models/parsers/amount/best.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/models/parsers/amount/best.data-00000-of-00001 -------------------------------------------------------------------------------- /models/parsers/amount/best.index: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/models/parsers/amount/best.index -------------------------------------------------------------------------------- /models/parsers/date/best.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/models/parsers/date/best.data-00000-of-00001 -------------------------------------------------------------------------------- /models/parsers/date/best.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/models/parsers/date/best.index -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | import os 22 | import glob 23 | import json 24 | import argparse 25 | 26 | from invoicenet import FIELDS 27 | from invoicenet.acp.acp import AttendCopyParse 28 | 29 | 30 | def main(): 31 | ap = argparse.ArgumentParser() 32 | 33 | ap.add_argument("--field", nargs='+', type=str, required=True, choices=FIELDS.keys(), 34 | help="field(s) to extract from the invoices") 35 | ap.add_argument("--invoice", type=str, default=None, 36 | help="path to invoice pdf file") 37 | ap.add_argument("--data_dir", type=str, default='invoices/', 38 | help="path to directory containing invoice pdf files") 39 | ap.add_argument("--pred_dir", type=str, default='predictions/', 40 | help="path to directory where predictions should be stored") 41 | 42 | args = ap.parse_args() 43 | 44 | paths = [] 45 | fields = [] 46 | predictions = {} 47 | 48 | if args.invoice: 49 | if not os.path.exists(args.invoice): 50 | print("ERROR: Could not find file '{}'".format(args.invoice)) 51 | return 52 | if not args.invoice.endswith('.pdf'): 53 | print("ERROR: '{}' is not a PDF file".format(args.invoice)) 54 | return 55 | paths.append(args.invoice) 56 | else: 57 | paths = [os.path.abspath(f) for f in glob.glob(os.path.join(args.data_dir, "**", "*.pdf"), recursive=True)] 58 | 59 | if not os.path.exists('./models/invoicenet/'): 60 | print("Could not find any trained models!") 61 | return 62 | else: 63 | models = os.listdir('./models/invoicenet/') 64 | for field in args.field: 65 | if field in models: 66 | fields.append(field) 67 | else: 68 | print("Could not find a trained model for field '{}', skipping...".format(field)) 69 | 70 | for field in fields: 71 | print("\nExtracting field '{}' from {} invoices...\n".format(field, len(paths))) 72 | model =
AttendCopyParse(field=field, restore=True) 73 | predictions[field] = model.predict(paths=paths) 74 | 75 | os.makedirs(args.pred_dir, exist_ok=True) 76 | for idx, filename in enumerate(paths): 77 | filename = os.path.basename(filename)[:-3] + 'json' 78 | labels = {} 79 | if os.path.exists(os.path.join(args.pred_dir, filename)): 80 | with open(os.path.join(args.pred_dir, filename), 'r') as fp: 81 | try: 82 | labels = json.load(fp) 83 | except ValueError: # existing file is not valid JSON, start from an empty dict 84 | labels = {} 85 | with open(os.path.join(args.pred_dir, filename), 'w') as fp: 86 | print("\nFilename: {}".format(filename)) 87 | for field in predictions.keys(): 88 | labels[field] = predictions[field][idx] 89 | print(" {}: {}".format(field, labels[field])) 90 | fp.write(json.dumps(labels, indent=2)) 91 | print('\n') 92 | 93 | print("Predictions stored in '{}'".format(args.pred_dir)) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | 21 | import argparse 22 | import glob 23 | import os 24 | import pdf2image 25 | import simplejson 26 | import tqdm 27 | import multiprocessing as mp 28 | 29 | from invoicenet import FIELDS, FIELD_TYPES 30 | from invoicenet.common import util 31 | 32 | 33 | def process_file(filename, out_dir, phase, ocr_engine): 34 | try: 35 | page = pdf2image.convert_from_path(filename)[0] 36 | page.save(os.path.join(out_dir, phase, os.path.basename(filename)[:-3] + 'png')) 37 | 38 | height = page.size[1] 39 | width = page.size[0] 40 | 41 | ngrams = util.create_ngrams(page, height=height, width=width, ocr_engine=ocr_engine) 42 | for ngram in ngrams: 43 | if "amount" in ngram["parses"]: 44 | ngram["parses"]["amount"] = util.normalize(ngram["parses"]["amount"], key="amount") 45 | if "date" in ngram["parses"]: 46 | ngram["parses"]["date"] = util.normalize(ngram["parses"]["date"], key="date") 47 | 48 | with open(filename[:-3] + 'json', 'r') as fp: 49 | labels = simplejson.loads(fp.read()) 50 | 51 | fields = {} 52 | for field in FIELDS: 53 | if field in labels: 54 | if FIELDS[field] == FIELD_TYPES["amount"]: 55 | fields[field] = util.normalize(labels[field], key="amount") 56 | elif FIELDS[field] == FIELD_TYPES["date"]: 57 | fields[field] = util.normalize(labels[field], key="date") 58 | else: 59 | fields[field] = labels[field] 60 | else: 61 | fields[field] = '' 62 | 63 | data = { 64 | "fields": fields, 65 | "nGrams": ngrams, 66 | "height": height, 67 | "width": width, 68 | "filename": os.path.abspath( 69 | os.path.join(out_dir, phase, os.path.basename(filename)[:-3] + 'png')) 70 | } 71 | 72 | with open(os.path.join(out_dir, phase, os.path.basename(filename)[:-3] + 'json'), 'w') as fp: 73 | 
fp.write(simplejson.dumps(data, indent=2)) 74 | return True 75 | 76 | except Exception as exp: 77 | print("Skipping {} : {}".format(filename, exp)) 78 | return False 79 | 80 | 81 | def main(): 82 | ap = argparse.ArgumentParser() 83 | 84 | ap.add_argument("--data_dir", type=str, required=True, 85 | help="path to directory containing invoice document images") 86 | ap.add_argument("--out_dir", type=str, default='processed_data/', 87 | help="path to save prepared data") 88 | ap.add_argument("--val_size", type=float, default=0.2, 89 | help="validation split ratio") 90 | ap.add_argument("--cores", type=int, help='Number of virtual cores to parallelize over', 91 | default=max(1, (mp.cpu_count() - 2) // 2)) # To prevent IPC issues 92 | ap.add_argument("--ocr_engine", type=str, default='pytesseract', 93 | help='OCR used to extract text', choices=['pytesseract', 'aws_textract']) 94 | 95 | args = ap.parse_args() 96 | 97 | os.makedirs(os.path.join(args.out_dir, 'train'), exist_ok=True) 98 | os.makedirs(os.path.join(args.out_dir, 'val'), exist_ok=True) 99 | 100 | filenames = [os.path.abspath(f) for f in glob.glob(os.path.join(args.data_dir, "**", "*.pdf"), recursive=True)] 101 | 102 | idx = int(len(filenames) * args.val_size) 103 | train_files = filenames[idx:] 104 | val_files = filenames[:idx] 105 | 106 | print("Total: {}".format(len(filenames))) 107 | print("Training: {}".format(len(train_files))) 108 | print("Validation: {}".format(len(val_files))) 109 | 110 | for phase, filenames in [('train', train_files), ('val', val_files)]: 111 | print("Preparing {} data...".format(phase)) 112 | 113 | with tqdm.tqdm(total=len(filenames)) as pbar: 114 | pool = mp.Pool(args.cores) 115 | for filename in filenames: 116 | pool.apply_async(process_file, args=(filename, args.out_dir, phase, args.ocr_engine), 117 | callback=lambda _: pbar.update()) 118 | 119 | pool.close() 120 | pool.join() 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 |
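The train/validation split performed in main() of prepare_data.py can be illustrated in isolation. The sketch below is not part of the repository: `split_files` and its `seed` parameter are hypothetical names introduced here. It mirrors the slicing used in main() (the first `int(len(filenames) * val_size)` paths go to validation, the rest to training). As an assumption beyond the original script, it also sorts the paths and optionally shuffles them with a fixed seed, because `glob.glob` returns files in arbitrary order.

```python
import random


def split_files(filenames, val_size=0.2, seed=None):
    """Split paths into (train, val) using the same slicing as prepare_data.py.

    The first int(len * val_size) entries become validation data and the
    remainder becomes training data. Sorting and the optional seeded shuffle
    are additions made here (hypothetical) to get a reproducible split.
    """
    files = sorted(filenames)
    if seed is not None:
        # A private Random instance avoids touching the global RNG state.
        random.Random(seed).shuffle(files)
    idx = int(len(files) * val_size)
    return files[idx:], files[:idx]


if __name__ == "__main__":
    names = ["inv_%d.pdf" % i for i in range(10)]
    train, val = split_files(names, val_size=0.2)
    print(len(train), len(val))  # 8 2
```

With `seed=None` the split depends only on the sorted file names, so repeated runs over the same directory place each invoice in the same phase.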
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | import setuptools 22 | 23 | from distutils import ccompiler 24 | from distutils import sysconfig 25 | import json 26 | import os 27 | import shutil 28 | import subprocess 29 | import sys 30 | import tempfile 31 | 32 | 33 | def find_in_path(filename, paths): 34 | for p in paths: 35 | fp = os.path.join(p, filename) 36 | if os.path.exists(fp): 37 | return os.path.abspath(fp) 38 | 39 | return os.path.join(os.sep, 'usr', 'bin', filename) 40 | 41 | 42 | class InspectCudaException(Exception): 43 | pass 44 | 45 | 46 | def nvcc_compiler_settings(): 47 | """ Find nvcc and the CUDA installation """ 48 | 49 | search_paths = os.environ.get('PATH', '').split(os.pathsep) 50 | nvcc_path = find_in_path('nvcc', search_paths) 51 | default_cuda_path = os.path.join(os.sep, 'usr', 'local', 'cuda') 52 | cuda_path = os.environ.get('CUDA_PATH', default_cuda_path) 53 | 54 | nvcc_found = os.path.exists(nvcc_path) 55 | cuda_path_found = os.path.exists(cuda_path) 56 | 57 | # Can't find either NVCC or some CUDA_PATH 58 | if not nvcc_found and not cuda_path_found: 59 | raise InspectCudaException("Neither nvcc '{}' " 60 | "nor the CUDA_PATH '{}' was found!".format(nvcc_path, cuda_path)) 61 | 62 | # No NVCC, try to find it in the CUDA_PATH 63 | if not nvcc_found: 64 | print("nvcc compiler not found at '{}'.
" 65 | "Searching within the CUDA_PATH '{}'" 66 | .format(nvcc_path, cuda_path)) 67 | 68 | bin_dir = os.path.join(cuda_path, 'bin') 69 | nvcc_path = find_in_path('nvcc', [bin_dir]) 70 | nvcc_found = os.path.exists(nvcc_path) 71 | 72 | if not nvcc_found: 73 | raise InspectCudaException("nvcc not found in '{}' " 74 | "or under the CUDA_PATH at '{}' " 75 | .format(search_paths, cuda_path)) 76 | 77 | # No CUDA_PATH found, infer it from NVCC 78 | if not cuda_path_found: 79 | cuda_path = os.path.normpath( 80 | os.path.join(os.path.dirname(nvcc_path), "..")) 81 | 82 | print("CUDA_PATH not found, inferring it as '{}' " 83 | "from the nvcc location '{}'".format(cuda_path, nvcc_path)) 84 | 85 | cuda_path_found = True 86 | 87 | # Set up the compiler settings 88 | include_dirs = [] 89 | library_dirs = [] 90 | define_macros = [] 91 | 92 | if cuda_path_found: 93 | include_dirs.append(os.path.join(cuda_path, 'include')) 94 | if sys.platform == 'win32': 95 | library_dirs.append(os.path.join(cuda_path, 'bin')) 96 | library_dirs.append(os.path.join(cuda_path, 'lib', 'x64')) 97 | else: 98 | library_dirs.append(os.path.join(cuda_path, 'lib64')) 99 | library_dirs.append(os.path.join(cuda_path, 'lib')) 100 | if sys.platform == 'darwin': 101 | library_dirs.append(os.path.join(default_cuda_path, 'lib')) 102 | 103 | return { 104 | 'cuda_available': True, 105 | 'nvcc_path': nvcc_path, 106 | 'include_dirs': include_dirs, 107 | 'library_dirs': library_dirs, 108 | 'define_macros': define_macros, 109 | 'libraries': ['cudart', 'cuda'], 110 | 'language': 'c++', 111 | } 112 | 113 | 114 | def inspect_cuda_version_and_devices(compiler, settings): 115 | """ 116 | Poor man's deviceQuery.
Returns CUDA_VERSION information and 117 | CUDA device information in JSON format 118 | """ 119 | try: 120 | output = build_and_run(compiler, ''' 121 | #include <cuda.h> 122 | #include <stdio.h> 123 | __device__ void test(int * in, int * out) 124 | { 125 | int tid = blockIdx.x*blockDim.x + threadIdx.x; 126 | out[tid] = in[tid]; 127 | } 128 | int main(int argc, char* argv[]) { 129 | printf("{\\n"); 130 | printf(" \\"cuda_version\\": %d,\\n", CUDA_VERSION); 131 | printf(" \\"devices\\": [\\n"); 132 | int nr_of_devices = 0; 133 | cudaGetDeviceCount(&nr_of_devices); 134 | for(int d=0; d < nr_of_devices; ++d) 135 | { 136 | cudaDeviceProp p; 137 | cudaGetDeviceProperties(&p, d); 138 | printf(" {\\n"); 139 | bool last = (d == nr_of_devices-1); 140 | printf(" \\"name\\": \\"%s\\",\\n", p.name); 141 | printf(" \\"major\\": %d,\\n", p.major); 142 | printf(" \\"minor\\": %d,\\n", p.minor); 143 | printf(" \\"memory\\": %lu\\n", p.totalGlobalMem); 144 | printf(" }%s\\n", last ? "" : ","); 145 | } 146 | printf(" ]\\n"); 147 | printf("}\\n"); 148 | return 0; 149 | } 150 | ''', 151 | filename='test.cu', 152 | include_dirs=settings['include_dirs'], 153 | library_dirs=settings['library_dirs'], 154 | libraries=settings['libraries']) 155 | 156 | except Exception as exp: 157 | msg = ("Running the CUDA device check " 158 | "stub failed\n{}".format(str(exp))) 159 | raise InspectCudaException(msg) 160 | 161 | return output 162 | 163 | 164 | def build_and_run(compiler, source, filename, libraries=(), 165 | include_dirs=(), library_dirs=()): 166 | temp_dir = tempfile.mkdtemp() 167 | 168 | try: 169 | fname = os.path.join(temp_dir, filename) 170 | with open(fname, 'w') as f: 171 | f.write(source) 172 | 173 | objects = compiler.compile([fname], output_dir=temp_dir, 174 | include_dirs=include_dirs) 175 | 176 | try: 177 | postargs = ['/MANIFEST'] if sys.platform == 'win32' else [] 178 | compiler.link_executable(objects, 179 | os.path.join(temp_dir, 'a'), 180 | libraries=libraries, 181 | library_dirs=library_dirs,
182 | extra_postargs=postargs, 183 | target_lang='c++') 184 | except Exception as exp: 185 | msg = ('Cannot build a stub file.\n' 186 | 'Original error: {0}'.format(exp)) 187 | raise InspectCudaException(msg) 188 | 189 | try: 190 | out = subprocess.check_output(os.path.join(temp_dir, 'a')) 191 | return out 192 | 193 | except Exception as exp: 194 | msg = ('Cannot execute a stub file.\n' 195 | 'Original error: {0}'.format(exp)) 196 | raise InspectCudaException(msg) 197 | 198 | finally: 199 | shutil.rmtree(temp_dir, ignore_errors=True) 200 | 201 | 202 | def customize_compiler_for_nvcc(compiler, nvcc_settings): 203 | """inject deep into distutils to customize gcc/nvcc dispatch """ 204 | 205 | # tell the compiler it can process .cu files 206 | compiler.src_extensions.append('.cu') 207 | 208 | # save references to the default compiler_so and _compile methods 209 | default_compiler_so = compiler.compiler_so 210 | default_compile = compiler._compile 211 | 212 | # now redefine the _compile method. This gets executed for each 213 | # object but distutils doesn't have the ability to change compilers 214 | # based on source extension: we add it. 
215 | def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts): 216 | # Use NVCC for .cu files 217 | if os.path.splitext(src)[1] == '.cu': 218 | compiler.set_executable('compiler_so', nvcc_settings['nvcc_path']) 219 | 220 | default_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) 221 | # reset the default compiler_so, which we might have changed for cuda 222 | compiler.compiler_so = default_compiler_so 223 | 224 | # inject our redefined _compile method into the class 225 | compiler._compile = _compile 226 | 227 | 228 | def inspect_cuda(): 229 | """ Return cuda device information and nvcc/cuda setup """ 230 | nvcc_settings = nvcc_compiler_settings() 231 | sysconfig.get_config_vars() 232 | nvcc_compiler = ccompiler.new_compiler() 233 | sysconfig.customize_compiler(nvcc_compiler) 234 | customize_compiler_for_nvcc(nvcc_compiler, nvcc_settings) 235 | 236 | output = inspect_cuda_version_and_devices(nvcc_compiler, nvcc_settings) 237 | 238 | return json.loads(output), nvcc_settings 239 | 240 | 241 | tensorflow_package = 'tensorflow==2.13.1' 242 | numpy_package = 'numpy<=1.24.3,>=1.22' 243 | 244 | with open("README.md", "r") as fh: 245 | long_description = fh.read() 246 | 247 | setuptools.setup(name='InvoiceNet', 248 | version='0.1', 249 | description='A deep neural network to extract intelligent information from invoice documents', 250 | url='https://github.com/naiveHobo/InvoiceNet', 251 | author='Sarthak Mittal', 252 | author_email='sarthakmittal2608@gmail.com', 253 | license='MIT', 254 | long_description=long_description, 255 | long_description_content_type="text/markdown", 256 | packages=setuptools.find_packages(), 257 | install_requires=[ 258 | tensorflow_package, 259 | numpy_package, 260 | "six~=1.15.0", 261 | "datefinder==0.7.1", 262 | "opencv-python==4.5.1.48", 263 | "pdf2image==1.14.0", 264 | "pdfplumber==0.5.27", 265 | "PyPDF2==1.27.9", 266 | "pytesseract==0.3.7", 267 | "python-dateutil==2.8.1", 268 | "PyYAML==5.4.1", 269 | "simplejson==3.17.2", 270 | 
"tqdm==4.59.0", 271 | "google-api-python-client", 272 | "google-cloud-vision" 273 | ]) 274 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | import os 22 | import argparse 23 | 24 | from invoicenet import FIELDS 25 | from invoicenet.common import trainer 26 | from invoicenet.acp.acp import AttendCopyParse 27 | from invoicenet.acp.data import InvoiceData 28 | 29 | 30 | def main(): 31 | ap = argparse.ArgumentParser() 32 | 33 | ap.add_argument("--field", type=str, required=True, choices=FIELDS.keys(), 34 | help="field to train parser for") 35 | ap.add_argument("--batch_size", type=int, default=8, 36 | help="batch size for training") 37 | ap.add_argument("--restore", action="store_true", 38 | help="restore from checkpoint") 39 | ap.add_argument("--data_dir", type=str, default='processed_data/', 40 | help="path to directory containing prepared data") 41 | ap.add_argument("--steps", type=int, default=50000, 42 | help="maximum number of training steps") 43 | ap.add_argument("--early_stop_steps", type=int, default=0, 44 | help="stop training if validation doesn't improve " 45 | "for a given number of steps, disabled when 0 (default)") 46 | 47 | args = ap.parse_args() 48 | 49 | train_data = InvoiceData.create_dataset(field=args.field, 50 | data_dir=os.path.join(args.data_dir, 'train/'), 51 | batch_size=args.batch_size) 52 | val_data = InvoiceData.create_dataset(field=args.field, 53 | data_dir=os.path.join(args.data_dir, 'val/'), 54 | batch_size=args.batch_size) 55 | 56 | print("Training...") 57 | trainer.train( 58 | model=AttendCopyParse(field=args.field, restore=args.restore), 59 | train_data=train_data, 60 | val_data=val_data, 61 | total_steps=args.steps, 62 | early_stop_steps=args.early_stop_steps 63 | ) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /train_parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated 
documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | import argparse 22 | 23 | from invoicenet.common import trainer 24 | from invoicenet.parsing.parser import Parser 25 | from invoicenet.parsing.data import ParseData 26 | from invoicenet.acp.data import InvoiceData 27 | 28 | 29 | def main(): 30 | ap = argparse.ArgumentParser() 31 | 32 | ap.add_argument("--field", type=str, required=True, choices=["amount", "date"], 33 | help="field to train parser for") 34 | ap.add_argument("--batch_size", type=int, default=128, 35 | help="batch size for training") 36 | ap.add_argument("--restore", action="store_true", 37 | help="restore from checkpoint") 38 | ap.add_argument("--steps", type=int, default=50000, 39 | help="maximum number of training steps") 40 | ap.add_argument("--early_stop_steps", type=int, default=0, 41 | help="stop training if validation doesn't improve " 42 | "for a given number of steps, disabled when 0 (default)") 43 | 44 | args = ap.parse_args() 45 | 46 | output_length = {"date": InvoiceData.seq_date, "amount": InvoiceData.seq_amount}[args.field] 47 | 48 | train_data = ParseData.create_dataset( 49 | path='invoicenet/parsing/data/%s/train.tsv' % args.field, 50 | output_length=output_length, 51 | batch_size=args.batch_size) 52 | 53 | val_data = ParseData.create_dataset( 54 | path='invoicenet/parsing/data/%s/valid.tsv' % args.field, 55 | output_length=output_length, 56 | batch_size=args.batch_size) 57 | 58 | print("Training...") 59 | trainer.train( 60 | model=Parser(field=args.field, restore=args.restore), 61 | train_data=train_data, 62 | val_data=val_data, 63 | total_steps=args.steps, 64 | early_stop_steps=args.early_stop_steps 65 | ) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Sarthak Mittal 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of 
this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | from tkinter import Tk 22 | from invoicenet.gui.trainer import Trainer 23 | 24 | 25 | def main(): 26 | root = Tk() 27 | Trainer(root) 28 | root.mainloop() 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /widgets/begin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/begin.png -------------------------------------------------------------------------------- /widgets/clear_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/clear_all.png -------------------------------------------------------------------------------- /widgets/clear_page.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/clear_page.png -------------------------------------------------------------------------------- /widgets/extract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/extract.png -------------------------------------------------------------------------------- /widgets/first.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/first.png -------------------------------------------------------------------------------- /widgets/fullscreen.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/fullscreen.png -------------------------------------------------------------------------------- /widgets/help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/help.png -------------------------------------------------------------------------------- /widgets/help.txt: -------------------------------------------------------------------------------- 1 | Welcome to InvoiceNet help! 2 | 3 | InvoiceNet is a deep neural network to extract intelligent information from invoice documents. 4 | 5 | InvoiceNet has a lot to offer. Here is a short description of some of the tools you will have at your disposal. 6 | 7 | --- 8 | 9 | Open Files: 10 | 11 | The 'Open Files' function allows you to open one or multiple files into InvoiceNet. 12 | 13 | On clicking the button, you're prompted with a dialog box where you can select all the files (PDF, PNG, JPG) you want to open. Once you're done, you can cycle through each document in the list and extract information from each document. 14 | 15 | PDF files are opened directly into the viewer. However, when an image file is opened, Google's OCR engine, Tesseract, is run on the image to extract readable text. 16 | 17 | --- 18 | 19 | Open Directory: 20 | 21 | The 'Open Directory' function works in the same way as the 'Open Files' function, but it allows you to open all files inside a particular directory in a single shot. 22 | 23 | --- 24 | 25 | Set Save Directory: 26 | 27 | The 'Set Save Directory' function allows you to set the directory where extracted information will be saved when the 'Save Information' button is clicked. 28 | 29 | --- 30 | 31 | Clear Page: 32 | 33 | The 'Clear Page' option lets you clear any modification you made to the current page. 
34 | 35 | --- 36 | 37 | Search Text: 38 | 39 | The 'Search Text' option allows you to search for text in the current page. 40 | 41 | On clicking the button, you are prompted to enter the text you want to search. InvoiceNet then searches the entire page and marks any word that contains the text you searched for. 42 | 43 | If text search fails, it's possible that there is no extractable text in your invoice. In this case, try again after running the OCR on the selected invoice. 44 | 45 | --- 46 | 47 | Extract Text: 48 | 49 | The 'Extract Text' option allows you to draw a bounding box on the current page and get the text that you marked with the bounding box. 50 | 51 | On clicking the button, the viewer becomes drawable. You're expected to draw a bounding box around the text you want to extract from the page. Once done, click on the 'Extract Text' button again. 52 | 53 | On doing so, you will be prompted with the word that's closest to the bounding box you drew. 54 | 55 | If text extraction fails, it's possible that there is no extractable text in your invoice. In this case, try again after running the OCR on the selected invoice. 56 | 57 | --- 58 | 59 | Run OCR: 60 | 61 | The 'Run OCR' option allows you to run Google's Optical Character Recognition (OCR) engine, Tesseract, on the current invoice. 62 | 63 | This comes in handy when an image file is opened using InvoiceNet. 64 | 65 | Another scenario where you may have to use the OCR is when a PDF document contains embedded images and you need to extract text from the embedded images. 66 | 67 | --- 68 | 69 | Clear Invoice Queue: 70 | 71 | This option can be used to clear the list of invoices that have been loaded into InvoiceNet. 72 | 73 | --- 74 | 75 | Next/Previous File: 76 | 77 | The 'Next File' and 'Previous File' options allow you to cycle through the list of files that were opened using InvoiceNet. 78 | 79 | --- 80 | 81 | Viewer: 82 | 83 | The Page Tool Bar gives you functions to manipulate your documents. 
84 | - Use the 'Next Page' and 'Previous Page' buttons to cycle through the different pages in your invoice. 85 | - Use the 'Last Page' and 'First Page' buttons to directly go to the pages on the extreme ends. 86 | - Use the 'Zoom In', 'Zoom Out' and 'Fit-To-Screen' buttons to make the current page bigger or smaller. 87 | - Use the 'Rotate' button to rotate pages. 88 | 89 | --- 90 | 91 | Extract: 92 | 93 | Before extracting information from an invoice, you need to select the fields that should be extracted. The field checkboxes become active automatically if InvoiceNet is able to find a trained model for a field. 94 | 95 | After selecting the fields to be extracted, click on the 'Extract' button and InvoiceNet will try to extract those fields from the current invoice and display the extracted information in the logging space. 96 | 97 | --- 98 | 99 | Save Information: 100 | 101 | After information extraction, the logging space will show you the extracted fields in the form of a Python dictionary. You can edit the extracted fields here if there are any discrepancies. 102 | 103 | Once you are done, click on the 'Save Information' button and InvoiceNet will save your extracted information as a JSON file. If a JSON file with the same name already exists, the newly extracted fields will be added to this file. 104 | 105 | --- 106 | 107 | Load Labels: 108 | 109 | The 'Load Labels' button can be used to upload labels from an existing JSON file into the logging space. 
110 | -------------------------------------------------------------------------------- /widgets/labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/labels.png -------------------------------------------------------------------------------- /widgets/last.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/last.png -------------------------------------------------------------------------------- /widgets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/logo.png -------------------------------------------------------------------------------- /widgets/next.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/next.png -------------------------------------------------------------------------------- /widgets/next_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/next_file.png -------------------------------------------------------------------------------- /widgets/ocr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/ocr.png -------------------------------------------------------------------------------- /widgets/open_dir.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/open_dir.png -------------------------------------------------------------------------------- /widgets/open_dir_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/open_dir_small.png -------------------------------------------------------------------------------- /widgets/open_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/open_file.png -------------------------------------------------------------------------------- /widgets/options.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/options.png -------------------------------------------------------------------------------- /widgets/prepare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/prepare.png -------------------------------------------------------------------------------- /widgets/prev.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/prev.png -------------------------------------------------------------------------------- /widgets/prev_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/prev_file.png 
-------------------------------------------------------------------------------- /widgets/rotate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/rotate.png -------------------------------------------------------------------------------- /widgets/save_as.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/save_as.png -------------------------------------------------------------------------------- /widgets/search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/search.png -------------------------------------------------------------------------------- /widgets/stop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/stop.png -------------------------------------------------------------------------------- /widgets/text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/text.png -------------------------------------------------------------------------------- /widgets/text_focus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/text_focus.png -------------------------------------------------------------------------------- /widgets/toolbar.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/toolbar.png -------------------------------------------------------------------------------- /widgets/zoomin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/zoomin.png -------------------------------------------------------------------------------- /widgets/zoomout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naiveHobo/InvoiceNet/5765cde6a4b23a4855cf123d2b62e58f54b40e8b/widgets/zoomout.png --------------------------------------------------------------------------------