├── .gitignore ├── LICENSE ├── README.md ├── docs ├── images │ └── docformer-architecture.png └── index.md ├── examples └── docformer_pl │ ├── document_image_classification_on_rvl_cdip │ └── 4.Document-image-classification-with-docformer.ipynb │ ├── pre_training_task_on_idl_dataset │ ├── docformer_pre_training_1_preparing_idl_dataset.py │ ├── docformer_pre_training_2_preparing_idl_pytorch_dataset_for_docformer.py │ └── docformer_pre_training_3_modeling_for_docformer.py │ └── token_classification_on_funsd │ ├── Token_Classification_Part_1.ipynb │ ├── Token_Classification_Part_2.ipynb │ └── Token_Classification_Part_3.ipynb ├── images └── docformer-architecture.png ├── mkdocs.yml ├── scripts ├── .coveragerc ├── docs ├── env └── lint ├── setup.cfg ├── setup.py ├── src └── docformer │ ├── README.md │ ├── dataset.py │ ├── dataset_pytorch.py │ ├── modeling.py │ ├── modeling_pl.py │ ├── train_accelerator.py │ ├── train_accelerator_mlm_ir.py │ └── utils.py └── tests └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DocFormer - PyTorch 2 | 3 | ![docformer architecture](images/docformer-architecture.png) 4 | 5 | Implementation of [DocFormer: End-to-End Transformer for Document Understanding](https://arxiv.org/abs/2106.11539), a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU) 📄📄📄. 6 | 7 | DocFormer is a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU). In addition, DocFormer is pre-trained in an unsupervised fashion using carefully designed tasks which encourage multi-modal interaction. DocFormer uses text, vision and spatial features and combines them using a novel multi-modal self-attention layer. DocFormer also shares learned spatial embeddings across modalities which makes it easy for the model to correlate text to visual tokens and vice versa. DocFormer is evaluated on 4 different datasets each with strong baselines. DocFormer achieves state-of-the-art results on all of them, sometimes beating models 4x its size (in no. of parameters). 8 | 9 | The official implementation was not released by the authors. 10 | 11 | ## NOTE: 12 | 13 | I tried to pre-train DocFormer on the task of MLM on a subset of [IDL Dataset](https://github.com/furkanbiten/idl_data). 
The weights are [here](https://www.kaggle.com/code/akarshu121/downloading-docformer-weights), and the associated kaggle notebook for fine-tuning on FUNSD is attached [here](https://www.kaggle.com/code/akarshu121/ckpt-docformer-for-token-classification-on-funsd/notebook?scriptVersionId=118952199) 14 | 15 | ## Install 16 | 17 | There might be some issues with the import of pytessaract, so in order to debug that, we need to write 18 | 19 | ```python 20 | pip install pytesseract 21 | sudo apt install tesseract-ocr 22 | ``` 23 | 24 | And then, 25 | 26 | ```python 27 | !git clone https://github.com/shabie/docformer.git 28 | 29 | 30 | ``` 31 | 32 | ## Usage 33 | 34 | ```python 35 | import sys 36 | sys.path.extend(['docformer/src/docformer/']) 37 | import modeling, dataset 38 | from transformers import BertTokenizerFast 39 | 40 | 41 | config = { 42 | "coordinate_size": 96, 43 | "hidden_dropout_prob": 0.1, 44 | "hidden_size": 768, 45 | "image_feature_pool_shape": [7, 7, 256], 46 | "intermediate_ff_size_factor": 4, 47 | "max_2d_position_embeddings": 1000, 48 | "max_position_embeddings": 512, 49 | "max_relative_positions": 8, 50 | "num_attention_heads": 12, 51 | "num_hidden_layers": 12, 52 | "pad_token_id": 0, 53 | "shape_size": 96, 54 | "vocab_size": 30522, 55 | "layer_norm_eps": 1e-12, 56 | } 57 | 58 | fp = "filepath/to/the/image.tif" 59 | 60 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 61 | encoding = dataset.create_features(fp, tokenizer, add_batch_dim=True) 62 | 63 | feature_extractor = modeling.ExtractFeatures(config) 64 | docformer = modeling.DocFormerEncoder(config) 65 | v_bar, t_bar, v_bar_s, t_bar_s = feature_extractor(encoding) 66 | output = docformer(v_bar, t_bar, v_bar_s, t_bar_s) # shape (1, 512, 768) 67 | ``` 68 | 69 | ## License 70 | 71 | MIT 72 | 73 | ## Maintainers 74 | 75 | - [uakarsh](https://github.com/uakarsh) 76 | - [shabie](https://github.com/shabie) 77 | 78 | ## Contribute 79 | 80 | 81 | ## Citations 82 | 83 | ```bibtex 84 | @InProceedings{Appalaraju_2021_ICCV, 85 | author = {Appalaraju, Srikar and Jasani, Bhavan and Kota, Bhargava Urala and Xie, Yusheng and Manmatha, R.}, 86 | title = {DocFormer: End-to-End Transformer for Document Understanding}, 87 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 88 | month = {October}, 89 | year = {2021}, 90 | pages = {993-1003} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /docs/images/docformer-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shabie/docformer/c7fcfdf71fb174784c3dba932b0f0daa6f05a92f/docs/images/docformer-architecture.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # DocFormer - PyTorch 2 | 3 | ![docformer architecture](images/docformer-architecture.png) 4 | 5 | Implementation of [DocFormer: End-to-End Transformer for Document Understanding](https://arxiv.org/abs/2106.11539), a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU) 📄📄📄. 6 | 7 | The official implementation was not released by the authors. 8 | 9 | ## Install 10 | 11 | 12 | ## Usage 13 | 14 | See `examples` for usage. 
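A minimal sketch of the basic flow, mirroring the usage example in the top-level README (it assumes the repository has been cloned so that `docformer/src/docformer/` can be put on `sys.path`):

```python
import sys
sys.path.extend(['docformer/src/docformer/'])
import modeling, dataset
from transformers import BertTokenizerFast

config = {
    "coordinate_size": 96,
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "image_feature_pool_shape": [7, 7, 256],
    "intermediate_ff_size_factor": 4,
    "max_2d_position_embeddings": 1000,
    "max_position_embeddings": 512,
    "max_relative_positions": 8,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 0,
    "shape_size": 96,
    "vocab_size": 30522,
    "layer_norm_eps": 1e-12,
}

fp = "filepath/to/the/image.tif"

# OCR the page, tokenize the words and build the spatial features
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoding = dataset.create_features(fp, tokenizer, add_batch_dim=True)

# Extract the visual/textual features plus their spatial embeddings, then encode
feature_extractor = modeling.ExtractFeatures(config)
docformer = modeling.DocFormerEncoder(config)
v_bar, t_bar, v_bar_s, t_bar_s = feature_extractor(encoding)
output = docformer(v_bar, t_bar, v_bar_s, t_bar_s)  # shape (1, 512, 768)
```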
15 | 16 | 17 | ## Contribute 18 | 19 | 20 | ## Citations 21 | 22 | ```bibtex 23 | @misc{appalaraju2021docformer, 24 | title = {DocFormer: End-to-End Transformer for Document Understanding}, 25 | author = {Srikar Appalaraju and Bhavan Jasani and Bhargava Urala Kota and Yusheng Xie and R. Manmatha}, 26 | year = {2021}, 27 | eprint = {2106.11539}, 28 | archivePrefix = {arXiv}, 29 | primaryClass = {cs.CV} 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /examples/docformer_pl/document_image_classification_on_rvl_cdip/4.Document-image-classification-with-docformer.ipynb: -------------------------------------------------------------------------------- 1 | {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"}},"nbformat_minor":4,"nbformat":4,"cells":[{"source":"\"Kaggle\"","metadata":{},"cell_type":"markdown","outputs":[],"execution_count":0},{"cell_type":"markdown","source":"\n\n\n## 1. Introduction: \n* This notebook is a tutorial to the multi-modal architecture DocFormer (mainly for the purpose of Document Understanding).\n* We would take in, the test-images of the RVL-CDIP Dataset, and then would train the model on a subset of the dataset\n* We would also be logging the metrics with the help of Weights and Biases","metadata":{}},{"cell_type":"markdown","source":"## A small Introduction about the Model:\n\n\n\nDocFormer is a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU). In addition, DocFormer is pre-trained in an unsupervised fashion using carefully designed tasks which encourage multi-modal interaction. DocFormer uses text, vision and spatial features and combines them using a novel multi-modal self-attention layer. DocFormer also shares learned spatial embeddings across modalities which makes it easy for the model to correlate text to visual tokens and vice versa. DocFormer is evaluated on 4 different datasets each with strong baselines. DocFormer achieves state-of-the-art results on all of them, sometimes beating models 4x its size (in no. of parameters).\n\nFor more understanding of the model and its code implementation, one can visit [here](https://github.com/uakarsh/docformer). 
So, let us go on to see what this model has to offer\n\nThe report for this entire run is attached [here](https://wandb.ai/iakarshu/RVL%20CDIP%20with%20DocFormer%20New%20Version/reports/Performance-of-DocFormer-with-RVL-CDIP-Test-Dataset--VmlldzoyMTI3NTM4)\n\n\n\n\n\n\n### An Interactive Demo for the same can be found on 🤗 space [here](https://huggingface.co/spaces/iakarshu/docformer_for_document_classification)\n\n### Installing the Libraries ⚙️:","metadata":{}},{"cell_type":"code","source":"## Installing the dependencies (might take some time)\n\n!pip install -q pytesseract\n!sudo apt install -q tesseract-ocr\n!pip install -q transformers\n!pip install -q pytorch-lightning\n!pip install -q einops\n!pip install -q tqdm\n!pip install -q 'Pillow==7.1.2'\n!pip install -q datasets\n!pip install wandb\n!pip install torchmetrics","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Cloning the repository\n!git clone https://github.com/uakarsh/docformer.git","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:22.079626Z","iopub.execute_input":"2022-06-07T07:08:22.080057Z","iopub.status.idle":"2022-06-07T07:08:24.114693Z","shell.execute_reply.started":"2022-06-07T07:08:22.080014Z","shell.execute_reply":"2022-06-07T07:08:24.113469Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Logging into wandb\n\nimport wandb\nfrom kaggle_secrets import UserSecretsClient\nuser_secrets = UserSecretsClient()\nsecret_value_0 = user_secrets.get_secret(\"wandb_api\")\nwandb.login(key=secret_value_0)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:24.145099Z","iopub.execute_input":"2022-06-07T07:08:24.146015Z","iopub.status.idle":"2022-06-07T07:08:25.65243Z","shell.execute_reply.started":"2022-06-07T07:08:24.145939Z","shell.execute_reply":"2022-06-07T07:08:25.651299Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 2. 
Libraries 📘:","metadata":{}},{"cell_type":"code","source":"## Importing the libraries\n\nimport warnings\nwarnings.simplefilter(\"ignore\", UserWarning)\nwarnings.simplefilter(\"ignore\", RuntimeWarning)\n\nimport os\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\nimport numpy as np\nimport pandas as pd\n\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import Dataset,DataLoader\n\nimport torch.nn.functional as F\nimport torchvision.models as models\n\n## Adding the path of docformer to system path\nimport sys\nsys.path.append('./docformer/src/docformer/')\n\n## Importing the functions from the DocFormer Repo\nfrom dataset import create_features\nfrom modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor\nfrom transformers import BertTokenizerFast","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:25.654629Z","iopub.execute_input":"2022-06-07T07:08:25.65569Z","iopub.status.idle":"2022-06-07T07:08:31.919208Z","shell.execute_reply.started":"2022-06-07T07:08:25.655654Z","shell.execute_reply":"2022-06-07T07:08:31.91827Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Hyperparameters\n\nseed = 42\ntarget_size = (500, 384)\n\n## Setting some hyperparameters\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\n\n## One can change this configuration and try out new combination\nconfig = {\n \"coordinate_size\": 96, ## (768/8), 8 for each of the 8 coordinates of x, y\n \"hidden_dropout_prob\": 0.1,\n \"hidden_size\": 768,\n \"image_feature_pool_shape\": [7, 7, 256],\n \"intermediate_ff_size_factor\": 4,\n \"max_2d_position_embeddings\": 1024,\n \"max_position_embeddings\": 128,\n \"max_relative_positions\": 8,\n \"num_attention_heads\": 12,\n \"num_hidden_layers\": 12,\n \"pad_token_id\": 0,\n \"shape_size\": 96,\n \"vocab_size\": 30522,\n \"layer_norm_eps\": 1e-12,\n}","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:31.920835Z","iopub.execute_input":"2022-06-07T07:08:31.921618Z","iopub.status.idle":"2022-06-07T07:08:31.989757Z","shell.execute_reply.started":"2022-06-07T07:08:31.921575Z","shell.execute_reply":"2022-06-07T07:08:31.988868Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## A small note 🗒️: \nHere, for the purpose of Demo I would be using only 250 Images per class, and would train the model on it. 
Definintely for a data hungry model such as transformers, such a small data is not enough, but let us see what are the results on it.","metadata":{}},{"cell_type":"code","source":"from tqdm.auto import tqdm\n\n## For the purpose of prediction\nid2label = []\nlabel2id = {}\n\ncurr_class = 0\n## Preparing the Dataset\nbase_directory = '../input/the-rvlcdip-dataset-test/test'\ndict_of_img_labels = {'img':[], 'label':[]}\n\nmax_sample_per_class = 250\n\nfor label in tqdm(os.listdir(base_directory)):\n img_path = os.path.join(base_directory, label)\n \n count = 0\n if label not in label2id:\n label2id[label] = curr_class\n curr_class+=1\n id2label.append(label)\n \n for img in os.listdir(img_path):\n if count>max_sample_per_class:\n break\n \n curr_img_path = os.path.join(img_path, img)\n dict_of_img_labels['img'].append(curr_img_path)\n dict_of_img_labels['label'].append(label2id[label])\n count+=1","metadata":{"execution":{"iopub.status.busy":"2022-06-07T18:38:57.319504Z","iopub.execute_input":"2022-06-07T18:38:57.319852Z","iopub.status.idle":"2022-06-07T18:39:02.430787Z","shell.execute_reply.started":"2022-06-07T18:38:57.319827Z","shell.execute_reply":"2022-06-07T18:39:02.429826Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"import pandas as pd\ndf = pd.DataFrame(dict_of_img_labels)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.612078Z","iopub.execute_input":"2022-06-07T07:08:36.612682Z","iopub.status.idle":"2022-06-07T07:08:36.621744Z","shell.execute_reply.started":"2022-06-07T07:08:36.612639Z","shell.execute_reply":"2022-06-07T07:08:36.620752Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from sklearn.model_selection import train_test_split as tts\ntrain_df, valid_df = tts(df, random_state = seed, stratify = df['label'], shuffle = True)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.623323Z","iopub.execute_input":"2022-06-07T07:08:36.623888Z","iopub.status.idle":"2022-06-07T07:08:36.973838Z","shell.execute_reply.started":"2022-06-07T07:08:36.623845Z","shell.execute_reply":"2022-06-07T07:08:36.972854Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"train_df = train_df.reset_index().drop(columns = ['index'], axis = 1)\nvalid_df = valid_df.reset_index().drop(columns = ['index'], axis = 1)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.976098Z","iopub.execute_input":"2022-06-07T07:08:36.976661Z","iopub.status.idle":"2022-06-07T07:08:36.990587Z","shell.execute_reply.started":"2022-06-07T07:08:36.976608Z","shell.execute_reply":"2022-06-07T07:08:36.989443Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 3. Making the dataset 💽:\n\nThe main idea behind making the dataset is, to pre-process the input into a given format, and then provide the input to the model. 
So, simply just the image path, and the other configurations, and boom 💥, you would get the desired pre-processed input","metadata":{}},{"cell_type":"code","source":"## Creating the dataset\n\nclass RVLCDIPData(Dataset):\n \n def __init__(self, image_list, label_list, target_size, tokenizer, max_len = 512, transform = None):\n \n self.image_list = image_list\n self.label_list = label_list\n self.target_size = target_size\n self.tokenizer = tokenizer\n self.max_len = max_len\n self.transform = transform\n \n def __len__(self):\n return len(self.image_list)\n \n def __getitem__(self, idx):\n img_path = self.image_list[idx]\n label = self.label_list[idx]\n \n ## More on this, in the repo mentioned previously\n final_encoding = create_features(\n img_path,\n self.tokenizer,\n add_batch_dim=False,\n target_size=self.target_size,\n max_seq_length=self.max_len,\n path_to_save=None,\n save_to_disk=False,\n apply_mask_for_mlm=False,\n extras_for_debugging=False,\n use_ocr = True\n )\n if self.transform is not None:\n ## Note that, ToTensor is already applied on the image\n final_encoding['resized_scaled_img'] = self.transform(final_encoding['resized_scaled_img'])\n \n \n keys_to_reshape = ['x_features', 'y_features', 'resized_and_aligned_bounding_boxes']\n for key in keys_to_reshape:\n final_encoding[key] = final_encoding[key][:self.max_len]\n \n final_encoding['label'] = torch.as_tensor(label).long()\n return final_encoding","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.992277Z","iopub.execute_input":"2022-06-07T07:08:36.993034Z","iopub.status.idle":"2022-06-07T07:08:37.006426Z","shell.execute_reply.started":"2022-06-07T07:08:36.992996Z","shell.execute_reply":"2022-06-07T07:08:37.005431Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Defining the tokenizer\ntokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:37.00823Z","iopub.execute_input":"2022-06-07T07:08:37.009227Z","iopub.status.idle":"2022-06-07T07:08:38.530076Z","shell.execute_reply.started":"2022-06-07T07:08:37.009184Z","shell.execute_reply":"2022-06-07T07:08:38.529039Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from torchvision import transforms\n\n## Normalization to these mean and std (I have seen some tutorials used this, and also in image reconstruction, so used it)\ntransform = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n ","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.531653Z","iopub.execute_input":"2022-06-07T07:08:38.53226Z","iopub.status.idle":"2022-06-07T07:08:38.538266Z","shell.execute_reply.started":"2022-06-07T07:08:38.532219Z","shell.execute_reply":"2022-06-07T07:08:38.537005Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"train_ds = RVLCDIPData(train_df['img'].tolist(), train_df['label'].tolist(),\n target_size, tokenizer, config['max_position_embeddings'], transform)\nval_ds = RVLCDIPData(valid_df['img'].tolist(), valid_df['label'].tolist(),\n target_size, tokenizer,config['max_position_embeddings'], 
transform)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.540023Z","iopub.execute_input":"2022-06-07T07:08:38.540791Z","iopub.status.idle":"2022-06-07T07:08:38.550323Z","shell.execute_reply.started":"2022-06-07T07:08:38.540746Z","shell.execute_reply":"2022-06-07T07:08:38.549236Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"### Collate Function:\n\nDefinitely collate function is an amazing function for using the dataloader as per our wish. More on collate function can be known from [here](https://stackoverflow.com/questions/65279115/how-to-use-collate-fn-with-dataloaders)","metadata":{}},{"cell_type":"code","source":"def collate_fn(data_bunch):\n\n '''\n A function for the dataloader to return a batch dict of given keys\n\n data_bunch: List of dictionary\n '''\n\n dict_data_bunch = {}\n\n for i in data_bunch:\n for (key, value) in i.items():\n if key not in dict_data_bunch:\n dict_data_bunch[key] = []\n dict_data_bunch[key].append(value)\n\n for key in list(dict_data_bunch.keys()):\n dict_data_bunch[key] = torch.stack(dict_data_bunch[key], axis = 0)\n\n return dict_data_bunch","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.552109Z","iopub.execute_input":"2022-06-07T07:08:38.552819Z","iopub.status.idle":"2022-06-07T07:08:38.566185Z","shell.execute_reply.started":"2022-06-07T07:08:38.552777Z","shell.execute_reply":"2022-06-07T07:08:38.561295Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 4. Defining the DataModule 📖\n\n* A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data:\n\n* A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) and predict_dataloader(s) along with the matching transforms and data processing/downloads steps required.\n\n\n","metadata":{}},{"cell_type":"code","source":"import pytorch_lightning as pl\n\nclass DataModule(pl.LightningDataModule):\n\n def __init__(self, train_dataset, val_dataset, batch_size = 4):\n\n super(DataModule, self).__init__()\n self.train_dataset = train_dataset\n self.val_dataset = val_dataset\n self.batch_size = batch_size\n\n def train_dataloader(self):\n return DataLoader(self.train_dataset, batch_size = self.batch_size, \n collate_fn = collate_fn, shuffle = True)\n \n def val_dataloader(self):\n return DataLoader(self.val_dataset, batch_size = self.batch_size,\n collate_fn = collate_fn, shuffle = False)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.567907Z","iopub.execute_input":"2022-06-07T07:08:38.568576Z","iopub.status.idle":"2022-06-07T07:08:39.574703Z","shell.execute_reply.started":"2022-06-07T07:08:38.568533Z","shell.execute_reply":"2022-06-07T07:08:39.573807Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"datamodule = DataModule(train_ds, val_ds)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:39.576481Z","iopub.execute_input":"2022-06-07T07:08:39.577345Z","iopub.status.idle":"2022-06-07T07:08:39.582506Z","shell.execute_reply.started":"2022-06-07T07:08:39.577298Z","shell.execute_reply":"2022-06-07T07:08:39.581255Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 5. Modeling Part 🏎️\n\n1. Firstly, we would define the pytorch model with our configurations, in which the class labels would be ranging from 0 to 15\n2. 
Secondly, we would encode it in the PyTorch Lightening module, and boom 💥 our work of defining the model is done","metadata":{}},{"cell_type":"code","source":"class DocFormerForClassification(nn.Module):\n \n def __init__(self, config):\n super(DocFormerForClassification, self).__init__()\n\n self.resnet = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])\n self.embeddings = DocFormerEmbeddings(config)\n self.lang_emb = LanguageFeatureExtractor()\n self.config = config\n self.dropout = nn.Dropout(config['hidden_dropout_prob'])\n self.linear_layer = nn.Linear(in_features = config['hidden_size'], out_features = len(id2label)) ## Number of Classes\n self.encoder = DocFormerEncoder(config)\n\n def forward(self, batch_dict):\n\n x_feat = batch_dict['x_features']\n y_feat = batch_dict['y_features']\n\n token = batch_dict['input_ids']\n img = batch_dict['resized_scaled_img']\n\n v_bar_s, t_bar_s = self.embeddings(x_feat,y_feat)\n v_bar = self.resnet(img)\n t_bar = self.lang_emb(token)\n out = self.encoder(t_bar,v_bar,t_bar_s,v_bar_s)\n out = self.linear_layer(out)\n out = out[:, 0, :]\n return out","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:39.58445Z","iopub.execute_input":"2022-06-07T07:08:39.584879Z","iopub.status.idle":"2022-06-07T07:08:39.597737Z","shell.execute_reply.started":"2022-06-07T07:08:39.584836Z","shell.execute_reply":"2022-06-07T07:08:39.596891Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Defining pytorch lightning model\nfrom sklearn.metrics import accuracy_score, confusion_matrix\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport numpy as np\nimport torchmetrics\n\nclass DocFormer(pl.LightningModule):\n\n def __init__(self, config , lr = 5e-5):\n super(DocFormer, self).__init__()\n \n self.save_hyperparameters()\n self.config = config\n self.docformer = DocFormerForClassification(config)\n \n self.num_classes = len(id2label)\n self.train_accuracy_metric = torchmetrics.Accuracy()\n self.val_accuracy_metric = torchmetrics.Accuracy()\n self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)\n self.precision_macro_metric = torchmetrics.Precision(\n average=\"macro\", num_classes=self.num_classes\n )\n self.recall_macro_metric = torchmetrics.Recall(\n average=\"macro\", num_classes=self.num_classes\n )\n self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n def forward(self, batch_dict):\n logits = self.docformer(batch_dict)\n return logits\n\n def training_step(self, batch, batch_idx):\n logits = self.forward(batch)\n\n loss = nn.CrossEntropyLoss()(logits, batch['label'])\n preds = torch.argmax(logits, 1)\n\n ## Calculating the accuracy score\n train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n\n ## Logging\n self.log('train/loss', loss,prog_bar = True, on_epoch=True, logger=True, on_step=True)\n self.log('train/acc', train_acc, prog_bar = True, on_epoch=True, logger=True, on_step=True)\n\n return loss\n \n def validation_step(self, batch, batch_idx):\n logits = self.forward(batch)\n loss = nn.CrossEntropyLoss()(logits, batch['label'])\n preds = torch.argmax(logits, 1)\n \n labels = batch['label']\n # Metrics\n valid_acc = self.val_accuracy_metric(preds, labels)\n precision_macro = self.precision_macro_metric(preds, labels)\n recall_macro = self.recall_macro_metric(preds, labels)\n precision_micro = self.precision_micro_metric(preds, labels)\n 
recall_micro = self.recall_micro_metric(preds, labels)\n f1 = self.f1_metric(preds, labels)\n\n # Logging metrics\n self.log(\"valid/loss\", loss, prog_bar=True, on_step=True, logger=True)\n self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n \n return {\"label\": batch['label'], \"logits\": logits}\n\n def validation_epoch_end(self, outputs):\n labels = torch.cat([x[\"label\"] for x in outputs])\n logits = torch.cat([x[\"logits\"] for x in outputs])\n preds = torch.argmax(logits, 1)\n\n wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})\n self.logger.experiment.log(\n {\"roc\": wandb.plot.roc_curve(labels.cpu().numpy(), logits.cpu().numpy())}\n )\n \n def configure_optimizers(self):\n return torch.optim.AdamW(self.parameters(), lr = self.hparams['lr'])","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:09:12.786194Z","iopub.execute_input":"2022-06-07T07:09:12.786554Z","iopub.status.idle":"2022-06-07T07:09:12.869343Z","shell.execute_reply.started":"2022-06-07T07:09:12.786523Z","shell.execute_reply":"2022-06-07T07:09:12.86839Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 6. Summing it up and running the entire procedure 🏃","metadata":{}},{"cell_type":"code","source":"from pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\ndef main():\n datamodule = DataModule(train_ds, val_ds)\n docformer = DocFormer(config)\n\n checkpoint_callback = ModelCheckpoint(\n dirpath=\"./models\", monitor=\"valid/loss\", mode=\"min\"\n )\n early_stopping_callback = EarlyStopping(\n monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n )\n \n wandb.init(config=config, project=\"RVL CDIP with DocFormer New Version\")\n wandb_logger = WandbLogger(project=\"RVL CDIP with DocFormer New Version\", entity=\"iakarshu\")\n ## https://www.tutorialexample.com/implement-reproducibility-in-pytorch-lightning-pytorch-lightning-tutorial/\n pl.seed_everything(seed, workers=True)\n trainer = pl.Trainer(\n default_root_dir=\"logs\",\n gpus=(1 if torch.cuda.is_available() else 0),\n max_epochs=1,\n fast_dev_run=False,\n logger=wandb_logger,\n callbacks=[checkpoint_callback, early_stopping_callback],\n deterministic=True\n )\n trainer.fit(docformer, datamodule)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:09:21.086225Z","iopub.execute_input":"2022-06-07T07:09:21.086667Z","iopub.status.idle":"2022-06-07T07:09:21.097492Z","shell.execute_reply.started":"2022-06-07T07:09:21.086636Z","shell.execute_reply":"2022-06-07T07:09:21.096609Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"if __name__ == \"__main__\":\n 
main()","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:09:21.297147Z","iopub.execute_input":"2022-06-07T07:09:21.299276Z","iopub.status.idle":"2022-06-07T07:13:58.468599Z","shell.execute_reply.started":"2022-06-07T07:09:21.299227Z","shell.execute_reply":"2022-06-07T07:13:58.467592Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## References:\n\n1. [MLOps Repo](https://github.com/graviraja/MLOps-Basics) (For the integration of model and data with PyTorch Lightening) \n2. [PyTorch Lightening Docs](https://pytorch-lightning.readthedocs.io/en/stable/index.html) For all the doubts and bugs\n3. [My Repo](https://github.com/uakarsh/docformer) For downloading the model and pre-processing steps\n4. Unspash for Images\n5. Google for other stuffs","metadata":{}},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]} -------------------------------------------------------------------------------- /examples/docformer_pl/pre_training_task_on_idl_dataset/docformer_pre_training_1_preparing_idl_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """DocFormer Pre-training : 1. Preparing IDL Dataset 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/17G4Io-2BOLx5YwKymgIwO_g14HovwbBb 8 | """ 9 | 10 | ## Refer here for the dataset: https://github.com/furkanbiten/idl_data 11 | # (IDL dataset was also used in the pre-training of LaTr), might take time to download the dataset 12 | 13 | 14 | ## The below lines of code download the sample dataset 15 | # !wget http://datasets.cvc.uab.es/UCSF_IDL/Samples/ocr_imgs_sample.zip 16 | # !unzip /content/ocr_imgs_sample.zip 17 | # !rm /content/ocr_imgs_sample.zip 18 | 19 | # Commented out IPython magic to ensure Python compatibility. 20 | # ## Installing the dependencies (might take some time) 21 | # 22 | # %%capture 23 | # !pip install pytesseract 24 | # !sudo apt install tesseract-ocr 25 | # !pip install transformers 26 | # !pip install pytorch-lightning 27 | # !pip install einops 28 | # !pip install tqdm 29 | # !pip install 'Pillow==7.1.2' 30 | # !pip install PyPDF2 31 | 32 | ## Getting the JSON File 33 | 34 | import json 35 | 36 | ## For reading the PDFs 37 | from PyPDF2 import PdfReader 38 | import io 39 | from PIL import Image, ImageDraw 40 | 41 | ## Standard library 42 | import os 43 | 44 | pdf_path = "./sample/pdfs" 45 | ocr_path = "./sample/OCR" 46 | 47 | ## Image property 48 | 49 | resize_scale = (500, 500) 50 | 51 | from typing import List 52 | 53 | def normalize_box(box : List[int], width : int, height : int, size : tuple = resize_scale): 54 | """ 55 | Takes a bounding box and normalizes it to a thousand pixels. If you notice it is 56 | just like calculating percentage except takes 1000 instead of 100. 
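    In this script the OCR bounding boxes are already scaled to the [0, 1] range, so the
    function is called with width=1, height=1 and size=resize_scale. For example, with
    resize_scale=(500, 500):
        normalize_box([0.1, 0.2, 0.3, 0.4], width=1, height=1, size=(500, 500)) -> [50, 100, 150, 200]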
57 | """ 58 | return [ 59 | int(size[0] * (box[0] / width)), 60 | int(size[1] * (box[1] / height)), 61 | int(size[0] * (box[2] / width)), 62 | int(size[1] * (box[3] / height)), 63 | ] 64 | 65 | ## Function to get the images from the PDFs as well as the OCRs for the corresponding images 66 | 67 | def get_image_ocrs_from_path(pdf_file_path : str, ocr_file_path : str, resize_scale = resize_scale): 68 | 69 | ## Getting the image list, since the pdfs can contain many image 70 | reader = PdfReader(pdf_file_path) 71 | img_list = [] 72 | for i in range(len(reader.pages)): 73 | page = reader.pages[i] 74 | for image_file_object in page.images: 75 | 76 | stream = io.BytesIO(image_file_object.data) 77 | img = Image.open(stream).convert("RGB") 78 | img_list.append(img) 79 | 80 | json_entry = json.load(open(ocr_file_path))[1] 81 | json_entry =[x for x in json_entry["Blocks"] if "Text" in x] 82 | 83 | pages = [x["Page"] for x in json_entry] 84 | ocrs = {pg : [] for pg in set(pages)} 85 | 86 | for entry in json_entry: 87 | bbox = entry["Geometry"]["BoundingBox"] 88 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 89 | bbox = [x, y, x + w, y + h] 90 | bbox = normalize_box(bbox, width = 1, height = 1, size = resize_scale) 91 | ocrs[entry["Page"]].append({"word" : entry["Text"], "bbox" : bbox}) 92 | 93 | return img_list, ocrs 94 | 95 | # sample_pdf_folder = os.path.join(pdf_path, sorted(os.listdir(pdf_path))[0]) 96 | # sample_ocr_folder = os.path.join(ocr_path, sorted(os.listdir(ocr_path))[0]) 97 | 98 | # sample_pdf = os.path.join(sample_pdf_folder, sample_pdf_folder.split("/")[-1] + ".pdf") 99 | # sample_ocr = os.path.join(sample_ocr_folder, os.listdir(sample_ocr_folder)[0]) 100 | 101 | # img_list, ocrs = get_image_ocrs_from_path(sample_pdf, sample_ocr) 102 | 103 | """## Preparing the Pytorch Dataset""" 104 | 105 | from tqdm.auto import tqdm 106 | 107 | img_list = [] 108 | ocr_list = [] 109 | 110 | pdf_files = sorted(os.listdir(pdf_path))[:30] ## Using only 30 since, google session gets crashed 111 | ocr_files = sorted(os.listdir(ocr_path))[:30] 112 | 113 | for pdf, ocr in tqdm(zip(pdf_files, ocr_files), total = len(pdf_files)): 114 | pdf = os.path.join(pdf_path, pdf, pdf + '.pdf') 115 | ocr = os.path.join(ocr_path, ocr) 116 | ocr = os.path.join(ocr, os.listdir(ocr)[0]) 117 | img, ocrs = get_image_ocrs_from_path(pdf, ocr) 118 | 119 | for i in range(len(img)): 120 | img_list.append(img[i]) 121 | ocr_list.append(ocrs[i+1]) ## Pages are 1, 2, 3 hence 0 + 1, 1 + 1, 2 + 1 122 | 123 | """## Visualizing the OCRs""" 124 | 125 | index = 43 126 | curr_img = img_list[index].resize(resize_scale) 127 | curr_ocr = ocr_list[index] 128 | 129 | # create rectangle image 130 | draw_on_img = ImageDraw.Draw(curr_img) 131 | 132 | for it in curr_ocr: 133 | box = it["bbox"] 134 | draw_on_img.rectangle(box, outline ="violet") 135 | 136 | curr_img 137 | 138 | -------------------------------------------------------------------------------- /examples/docformer_pl/pre_training_task_on_idl_dataset/docformer_pre_training_2_preparing_idl_pytorch_dataset_for_docformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """DocFormer Pre-training : 2. Preparing IDL PyTorch Dataset for DocFormer 3 | 4 | Automatically generated by Colaboratory. 
5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1dDMABz_EunCdg3S1neJ5fMPkRSmTcAfd 8 | """ 9 | 10 | ## Refer here for the dataset: https://github.com/furkanbiten/idl_data 11 | # (IDL dataset was also used in the pre-training of LaTr), might take time to download the dataset 12 | 13 | # !wget http://datasets.cvc.uab.es/UCSF_IDL/Samples/ocr_imgs_sample.zip 14 | # !unzip /content/ocr_imgs_sample.zip 15 | # !rm /content/ocr_imgs_sample.zip 16 | 17 | # Commented out IPython magic to ensure Python compatibility. 18 | # ## Installing the dependencies (might take some time) 19 | # 20 | # %%capture 21 | # !pip install pytesseract 22 | # !sudo apt install tesseract-ocr 23 | # !pip install transformers 24 | # !pip install pytorch-lightning 25 | # !pip install einops 26 | # !pip install tqdm 27 | # !pip install 'Pillow==7.1.2' 28 | # !pip install PyPDF2 29 | 30 | ## Cloning the repository 31 | # !git clone https://github.com/uakarsh/docformer.git 32 | 33 | ## Getting the JSON File 34 | import json 35 | 36 | ## For reading the PDFs 37 | from PyPDF2 import PdfReader 38 | import io 39 | from PIL import Image, ImageDraw 40 | 41 | ## A bit of code taken from here : https://www.kaggle.com/code/akarshu121/docformer-for-token-classification-on-funsd/notebook 42 | import os 43 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 44 | 45 | ## PyTorch Libraries 46 | import torch 47 | from torchvision.transforms import ToTensor 48 | from torch.utils.data import Dataset, DataLoader 49 | import torch.nn.functional as F 50 | import torch.nn as nn 51 | 52 | ## Adding the path of docformer to system path 53 | import sys 54 | sys.path.append('./docformer/src/docformer/') 55 | 56 | ## Importing the functions from the DocFormer Repo 57 | from dataset import resize_align_bbox, get_centroid, get_pad_token_id_start_index, get_relative_distance 58 | from modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor 59 | 60 | ## Transformer librarues 61 | from transformers import BertTokenizerFast 62 | 63 | ## PyTorch Lightning library 64 | import pytorch_lightning as pl 65 | from pytorch_lightning.callbacks import ModelCheckpoint 66 | 67 | pdf_path = "./sample/pdfs" 68 | ocr_path = "./sample/OCR" 69 | 70 | ## Image property 71 | 72 | resize_scale = (500, 500) 73 | 74 | from typing import List 75 | 76 | def normalize_box(box : List[int], width : int, height : int, size : tuple = resize_scale): 77 | """ 78 | Takes a bounding box and normalizes it to a thousand pixels. If you notice it is 79 | just like calculating percentage except takes 1000 instead of 100. 
80 | """ 81 | return [ 82 | int(size[0] * (box[0] / width)), 83 | int(size[1] * (box[1] / height)), 84 | int(size[0] * (box[2] / width)), 85 | int(size[1] * (box[3] / height)), 86 | ] 87 | 88 | ## Function to get the images from the PDFs as well as the OCRs for the corresponding images 89 | 90 | def get_image_ocrs_from_path(pdf_file_path : str, ocr_file_path : str, resize_scale = resize_scale): 91 | 92 | ## Getting the image list, since the pdfs can contain many image 93 | reader = PdfReader(pdf_file_path) 94 | img_list = [] 95 | for i in range(len(reader.pages)): 96 | page = reader.pages[i] 97 | for image_file_object in page.images: 98 | 99 | stream = io.BytesIO(image_file_object.data) 100 | img = Image.open(stream).convert("RGB").resize(resize_scale) 101 | img_list.append(img) 102 | 103 | json_entry = json.load(open(ocr_file_path))[1] 104 | json_entry =[x for x in json_entry["Blocks"] if "Text" in x] 105 | 106 | pages = [x["Page"] for x in json_entry] 107 | ocrs = {pg : [] for pg in set(pages)} 108 | 109 | for entry in json_entry: 110 | bbox = entry["Geometry"]["BoundingBox"] 111 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 112 | bbox = [x, y, x + w, y + h] 113 | bbox = normalize_box(bbox, width = 1, height = 1, size = resize_scale) 114 | ocrs[entry["Page"]].append({"word" : entry["Text"], "bbox" : bbox}) 115 | 116 | return img_list, ocrs 117 | 118 | # sample_pdf_folder = os.path.join(pdf_path, sorted(os.listdir(pdf_path))[0]) 119 | # sample_ocr_folder = os.path.join(ocr_path, sorted(os.listdir(ocr_path))[0]) 120 | 121 | # sample_pdf = os.path.join(sample_pdf_folder, sample_pdf_folder.split("/")[-1] + ".pdf") 122 | # sample_ocr = os.path.join(sample_ocr_folder, os.listdir(sample_ocr_folder)[0]) 123 | 124 | # img_list, ocrs = get_image_ocrs_from_path(sample_pdf, sample_ocr) 125 | 126 | """## Preparing the Pytorch Dataset""" 127 | 128 | from tqdm.auto import tqdm 129 | 130 | img_list = [] 131 | ocr_list = [] 132 | 133 | pdf_files = sorted(os.listdir(pdf_path))[:30] ## Using only 30 since, google session gets crashed 134 | ocr_files = sorted(os.listdir(ocr_path))[:30] 135 | 136 | for pdf, ocr in tqdm(zip(pdf_files, ocr_files), total = len(pdf_files)): 137 | pdf = os.path.join(pdf_path, pdf, pdf + '.pdf') 138 | ocr = os.path.join(ocr_path, ocr) 139 | ocr = os.path.join(ocr, os.listdir(ocr)[0]) 140 | img, ocrs = get_image_ocrs_from_path(pdf, ocr) 141 | 142 | for i in range(len(img)): 143 | img_list.append(img[i]) 144 | ocr_list.append(ocrs[i+1]) ## Pages are 1, 2, 3 hence 0 + 1, 1 + 1, 2 + 1 145 | 146 | """## Visualizing the OCRs""" 147 | 148 | index = 17 149 | curr_img = img_list[index] 150 | curr_ocr = ocr_list[index] 151 | 152 | # create rectangle image 153 | draw_on_img = ImageDraw.Draw(curr_img) 154 | 155 | for it in curr_ocr: 156 | box = it["bbox"] 157 | draw_on_img.rectangle(box, outline ="violet") 158 | 159 | curr_img 160 | 161 | """## Creating features for DocFormer""" 162 | 163 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 164 | 165 | def get_tokens_with_boxes(unnormalized_word_boxes, word_ids,max_seq_len = 512, pad_token_box = [0, 0, 0, 0]): 166 | 167 | # assert len(unnormalized_word_boxes) == len(word_ids), this should not be applied, since word_ids may have higher 168 | # length and the bbox corresponding to them may not exist 169 | 170 | unnormalized_token_boxes = [] 171 | 172 | i = 0 173 | for word_idx in word_ids: 174 | if word_idx is None: 175 | break 176 | unnormalized_token_boxes.append(unnormalized_word_boxes[word_idx]) 177 
| i+=1 178 | 179 | # all remaining are padding tokens so why add them in a loop one by one 180 | num_pad_tokens = len(word_ids) - i - 1 181 | if num_pad_tokens > 0: 182 | unnormalized_token_boxes.extend([pad_token_box] * num_pad_tokens) 183 | 184 | 185 | if len(unnormalized_token_boxes) 0: 167 | unnormalized_token_boxes.extend([pad_token_box] * num_pad_tokens) 168 | 169 | 170 | if len(unnormalized_token_boxes)=0.3', 20 | 'torch>=1.6', 21 | 'torchvision', 22 | 'pytesseract', 23 | 'transformers', 24 | 'pytesseract>=0.3.8', 25 | ], 26 | classifiers=[ 27 | 'Development Status :: 4 - Beta', 28 | 'Intended Audience :: Developers', 29 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 30 | 'License :: OSI Approved :: MIT License', 31 | 'Programming Language :: Python :: 3.7', 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /src/docformer/README.md: -------------------------------------------------------------------------------- 1 | # About the files: 2 | 3 | 4 | ```python 5 | 1. dataset.py 6 | ``` 7 | * This file contains the various functions, which are required for the pre-processing of the image as for an example, running the OCR for getting the bounding box of each word, getting the relative distance between the coordinates of the different word, performing the tokenization, and some optional arguments which are required for saving the different features on the disk, and many more!!! 8 | 9 | 10 | * The code can be modified as per the requirement, as for an example, if we have to perform a specific task of segmenting the image or classifying a document, with some modifications in the ```create_features``` function, it can be achieved. 11 | 12 | ```python 13 | 2. dataset_pytorch.py 14 | ``` 15 | * This file inherits the functions from the ```dataset.py```, however it creates a Dataset object of the file stored in the disk, and the function can be modified for some augmentations as well 16 | * The Dataset object is required for the purpose of training the model in PyTorch, however in TensorFlow, the numpy version would work instead of Dataset and DataLoader object 17 | 18 | ```python 19 | 3. modeling.py 20 | ``` 21 | * This file is the brain of everything in this repo, the file contains the various functions, which have been written with least approximation in mind, and as close to the paper, it contains the ```multi-head attention```, the various embedding functions, and a lot of stuffs, which are mentioned in the paper. In order, to understand this file properly, one of the suggestion is to, open the code and the paper side by side, and that would work. 22 | * And, for the task specific requirements, one can import ```DocFormerEncoder```, and attach one head for the task-specific requirement, however, the last function ```ShallowDecoder``` is a work in progress, which is for the Image Reconstruction purpose as mentioned in the paper (in the pre-training part) 23 | 24 | ```python 25 | 4. modeling_pl.py 26 | ``` 27 | * This file is basically, for the parallelization of the training and validation part, so that the utilization of multiple GPUs becomes easy 28 | * This file contains the integration of PyTorch Lightening, with the DocFormer model, so that the coding part becomes less and the task specific things can be done. 
29 | * For task specific requirements, one can modify the ```Model``` class, with the modification being in the ```training``` and ```validation``` step, where the specific loss functions can be integrated and thats it!!! 30 | 31 | 32 | ```python 33 | 5. train_accelerator.py 34 | and 35 | 6. train_accelerator_mlm_ir.py 36 | ``` 37 | * These files are also codes for the purpose of training so that the coding requirements becomes less. 38 | * The only thing is that, these code inherits the ```Accelerator``` of "Hugging Face", for the purpose of Parallelization of the task to multiple GPUs 39 | * ```train_accelerator.py``` contains the function of running the code of Pre-training the model with ```MLM``` task 40 | * ```train_accelerator_mlm_ir.py``` contains the function of running the code of Pre-training the model with ```MLM and Image Reconstruction (IR)``` task, however we are thinking of making a file which contains the options of training according to specific task 41 | 42 | 43 | 44 | ```python 45 | 7. utils.py 46 | ``` 47 | 48 | * File, which contains the utility function for performing the unsupervised task of Text Describe Image (as mentioned in the paper: DocFormer) 49 | 50 | How to use it? 51 | * Let us assume, that all the entries of the dataset have been stored somewhere. 52 | * Now, we can get the length of the entries, and that has to be passed to the function ```labels_for_tdi```, which would give the arguments as well as the labels, now, iterate through each of the arguments, and for ith entry, create a new entry in the dicitionary (data format for docformer, refer to dataset.py, create_features function), and map it to the resized_scaled_img of arr[i], 53 | 54 | * i.e in terms of pseduocode, 55 | Assume, that 56 | ```python 57 | 58 | l-> list of dictionary format, data points of docformer 59 | d_arr, labels = labels_for_tdi(n) 60 | for i, j in enumerate(d_arr): 61 | l[i]['d_resized_scaled_img'] = l[j]['resized_scaled_img'] 62 | l[i]['label_for_tdi'] = labels[i] 63 | ``` 64 | 65 | And then, the rest follows by passing the argument `use_tdi`, for the model 66 | 67 | Using it with the model: 68 | For the purpose of integrating TDI with model, the following instruction would be helpful: 69 | 70 | Let us assume, we want to do MLM + IR + TDI: 71 | 72 | 1. As for the first and the third task, attach a head, and forward propagate the data, calculate the weighted loss of these two task, and store it 73 | 2. In case of the second task, you have to forward propagate it again with the same dataset, but with the argument use_tdi = True, and calculate the binary cross entropy loss with the `label_for_tdi` key, and add the weighted sum of it to the stored loss, and then backpropagate it 74 | -------------------------------------------------------------------------------- /src/docformer/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import pickle 4 | from functools import lru_cache 5 | import pytesseract 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | from torchvision.transforms import ToTensor 10 | 11 | PAD_TOKEN_BOX = [0, 0, 0, 0] 12 | GRID_SIZE = 1000 13 | 14 | 15 | def normalize_box(box, width, height, size=1000): 16 | """ 17 | Takes a bounding box and normalizes it to a thousand pixels. If you notice it is 18 | just like calculating percentage except takes 1000 instead of 100. 
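    For example, a word box [10, 20, 30, 40] on a 200 x 100 pixel page gives
    normalize_box([10, 20, 30, 40], width=200, height=100, size=1000) -> [50, 200, 150, 400].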
19 | """ 20 | return [ 21 | int(size * (box[0] / width)), 22 | int(size * (box[1] / height)), 23 | int(size * (box[2] / width)), 24 | int(size * (box[3] / height)), 25 | ] 26 | 27 | 28 | @lru_cache(maxsize=10) 29 | def resize_align_bbox(bbox, orig_w, orig_h, target_w, target_h): 30 | x_scale = target_w / orig_w 31 | y_scale = target_h / orig_h 32 | orig_left, orig_top, orig_right, orig_bottom = bbox 33 | x = int(np.round(orig_left * x_scale)) 34 | y = int(np.round(orig_top * y_scale)) 35 | xmax = int(np.round(orig_right * x_scale)) 36 | ymax = int(np.round(orig_bottom * y_scale)) 37 | return [x, y, xmax, ymax] 38 | 39 | 40 | def get_topleft_bottomright_coordinates(df_row): 41 | left, top, width, height = df_row["left"], df_row["top"], df_row["width"], df_row["height"] 42 | return [left, top, left + width, top + height] 43 | 44 | 45 | def apply_ocr(image_fp): 46 | """ 47 | Returns words and its bounding boxes from an image 48 | """ 49 | image = Image.open(image_fp) 50 | width, height = image.size 51 | 52 | ocr_df = pytesseract.image_to_data(image, output_type="data.frame") 53 | ocr_df = ocr_df.dropna().reset_index(drop=True) 54 | float_cols = ocr_df.select_dtypes("float").columns 55 | ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int) 56 | ocr_df = ocr_df.replace(r"^\s*$", np.nan, regex=True) 57 | ocr_df = ocr_df.dropna().reset_index(drop=True) 58 | words = list(ocr_df.text.apply(lambda x: str(x).strip())) 59 | actual_bboxes = ocr_df.apply(get_topleft_bottomright_coordinates, axis=1).values.tolist() 60 | 61 | # add as extra columns 62 | assert len(words) == len(actual_bboxes) 63 | return {"words": words, "bbox": actual_bboxes} 64 | 65 | def get_tokens_with_boxes(unnormalized_word_boxes, pad_token_box, word_ids,max_seq_len = 512): 66 | 67 | # assert len(unnormalized_word_boxes) == len(word_ids), this should not be applied, since word_ids may have higher 68 | # length and the bbox corresponding to them may not exist 69 | 70 | unnormalized_token_boxes = [] 71 | 72 | for i, word_idx in enumerate(word_ids): 73 | if word_idx is None: 74 | break 75 | unnormalized_token_boxes.append(unnormalized_word_boxes[word_idx]) 76 | 77 | # all remaining are padding tokens so why add them in a loop one by one 78 | num_pad_tokens = len(word_ids) - i - 1 79 | if num_pad_tokens > 0: 80 | unnormalized_token_boxes.extend([pad_token_box] * num_pad_tokens) 81 | 82 | 83 | if len(unnormalized_token_boxes)= pad_tokens_start_idx: 113 | a_rel_x.append([0] * 8) 114 | a_rel_y.append([0] * 8) 115 | continue 116 | 117 | curr = bboxes[i] 118 | next = bboxes[i+1] 119 | 120 | a_rel_x.append( 121 | [ 122 | curr[0], # top left x 123 | curr[2], # bottom right x 124 | curr[2] - curr[0], # width 125 | next[0] - curr[0], # diff top left x 126 | next[0] - curr[0], # diff bottom left x 127 | next[2] - curr[2], # diff top right x 128 | next[2] - curr[2], # diff bottom right x 129 | centroids[i+1][0] - centroids[i][0], 130 | ] 131 | ) 132 | 133 | a_rel_y.append( 134 | [ 135 | curr[1], # top left y 136 | curr[3], # bottom right y 137 | curr[3] - curr[1], # height 138 | next[1] - curr[1], # diff top left y 139 | next[3] - curr[3], # diff bottom left y 140 | next[1] - curr[1], # diff top right y 141 | next[3] - curr[3], # diff bottom right y 142 | centroids[i+1][1] - centroids[i][1], 143 | ] 144 | ) 145 | 146 | # For the last word 147 | 148 | a_rel_x.append([0]*8) 149 | a_rel_y.append([0]*8) 150 | 151 | 152 | return a_rel_x, a_rel_y 153 | 154 | 155 | 156 | def apply_mask(inputs, tokenizer): 157 | inputs = torch.as_tensor(inputs) 
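    # The lines below implement the MLM masking used for pre-training: roughly 15% of the
    # token positions are picked at random, positions holding [CLS] or [PAD] are excluded,
    # and the selected ids are overwritten with 103, the [MASK] token id of bert-base-uncased.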
158 | rand = torch.rand(inputs.shape) 159 | # where the random array is less than 0.15, we set true 160 | mask_arr = (rand < 0.15) * (inputs != tokenizer.cls_token_id) * (inputs != tokenizer.pad_token_id) 161 | # create selection from mask_arr 162 | selection = torch.flatten(mask_arr.nonzero()).tolist() 163 | # apply selection pad_tokens_start_idx to inputs.input_ids, adding MASK tokens 164 | inputs[selection] = 103 165 | return inputs 166 | 167 | 168 | def read_image_and_extract_text(image): 169 | original_image = Image.open(image).convert("RGB") 170 | return apply_ocr(image) 171 | 172 | 173 | def create_features( 174 | image, 175 | tokenizer, 176 | add_batch_dim=False, 177 | target_size=(500,384), # This was the resolution used by the authors 178 | max_seq_length=512, 179 | path_to_save=None, 180 | save_to_disk=False, 181 | apply_mask_for_mlm=False, 182 | extras_for_debugging=False, 183 | use_ocr = True, 184 | bounding_box = None, 185 | words = None 186 | ): 187 | 188 | # step 1: read original image and extract OCR entries 189 | try: 190 | original_image = Image.open(image).convert("RGB") 191 | except: 192 | original_image = Image.new(mode = "RGB", size = ((500, 500)), color = (255, 255, 255)) 193 | 194 | if (use_ocr == False) and (bounding_box == None or words == None): 195 | raise Exception('Please provide the bounding box and words or pass the argument "use_ocr" = True') 196 | 197 | if use_ocr == True: 198 | entries = apply_ocr(image) 199 | bounding_box = entries["bbox"] 200 | words = entries["words"] 201 | 202 | CLS_TOKEN_BOX = [0, 0, *original_image.size] # Can be variable, but as per the paper, they have mentioned that it covers the whole image 203 | # step 2: resize image 204 | resized_image = original_image.resize(target_size) 205 | 206 | # step 3: normalize image to a grid of 1000 x 1000 (to avoid the problem of differently sized images) 207 | width, height = original_image.size 208 | normalized_word_boxes = [ 209 | normalize_box(bbox, width, height, GRID_SIZE) for bbox in bounding_box 210 | ] 211 | assert len(words) == len(normalized_word_boxes), "Length of words != Length of normalized words" 212 | 213 | # step 4: tokenize words and get their bounding boxes (one word may split into multiple tokens) 214 | encoding = tokenizer(words, 215 | padding="max_length", 216 | max_length=max_seq_length, 217 | is_split_into_words=True, 218 | truncation=True, 219 | add_special_tokens=False) 220 | 221 | unnormalized_token_boxes = get_tokens_with_boxes(bounding_box, 222 | PAD_TOKEN_BOX, 223 | encoding.word_ids()) 224 | 225 | # step 5: add special tokens and truncate seq. to maximum length 226 | unnormalized_token_boxes = [CLS_TOKEN_BOX] + unnormalized_token_boxes[:-1] 227 | # add CLS token manually to avoid autom. 
addition of SEP too (as in the paper) 228 | encoding["input_ids"] = [tokenizer.cls_token_id] + encoding["input_ids"][:-1] 229 | 230 | # step 6: Add bounding boxes to the encoding dict 231 | encoding["unnormalized_token_boxes"] = unnormalized_token_boxes 232 | 233 | # step 7: apply mask for the sake of pre-training 234 | if apply_mask_for_mlm: 235 | encoding["mlm_labels"] = encoding["input_ids"] 236 | encoding["input_ids"] = apply_mask(encoding["input_ids"], tokenizer) 237 | assert len(encoding["mlm_labels"]) == max_seq_length, "Length of mlm_labels != Length of max_seq_length" 238 | 239 | assert len(encoding["input_ids"]) == max_seq_length, "Length of input_ids != Length of max_seq_length" 240 | assert len(encoding["attention_mask"]) == max_seq_length, "Length of attention mask != Length of max_seq_length" 241 | assert len(encoding["token_type_ids"]) == max_seq_length, "Length of token type ids != Length of max_seq_length" 242 | 243 | # step 8: normalize the image 244 | encoding["resized_scaled_img"] = ToTensor()(resized_image) 245 | 246 | # step 9: apply mask for the sake of pre-training 247 | if apply_mask_for_mlm: 248 | encoding["mlm_labels"] = encoding["input_ids"] 249 | encoding["input_ids"] = apply_mask(encoding["input_ids"], tokenizer) 250 | 251 | # step 10: rescale and align the bounding boxes to match the resized image size (typically 224x224) 252 | resized_and_aligned_bboxes = [] 253 | 254 | for bbox in unnormalized_token_boxes: 255 | # performing the normalization of the bounding box 256 | resized_and_aligned_bboxes.append(resize_align_bbox(tuple(bbox), *original_image.size, *target_size)) 257 | 258 | encoding["resized_and_aligned_bounding_boxes"] = resized_and_aligned_bboxes 259 | 260 | # step 11: add the relative distances in the normalized grid 261 | bboxes_centroids = get_centroid(resized_and_aligned_bboxes) 262 | pad_token_start_index = get_pad_token_id_start_index(words, encoding, tokenizer) 263 | a_rel_x, a_rel_y = get_relative_distance(resized_and_aligned_bboxes, bboxes_centroids, pad_token_start_index) 264 | 265 | # step 12: convert all to tensors 266 | for k, v in encoding.items(): 267 | encoding[k] = torch.as_tensor(encoding[k]) 268 | 269 | encoding.update({ 270 | "x_features": torch.as_tensor(a_rel_x, dtype=torch.int32), 271 | "y_features": torch.as_tensor(a_rel_y, dtype=torch.int32), 272 | }) 273 | 274 | # step 13: add tokens for debugging 275 | if extras_for_debugging: 276 | input_ids = encoding["mlm_labels"] if apply_mask_for_mlm else encoding["input_ids"] 277 | encoding["tokens_without_padding"] = tokenizer.convert_ids_to_tokens(input_ids) 278 | encoding["words"] = words 279 | 280 | 281 | # step 14: add extra dim for batch 282 | if add_batch_dim: 283 | encoding["x_features"].unsqueeze_(0) 284 | encoding["y_features"].unsqueeze_(0) 285 | encoding["input_ids"].unsqueeze_(0) 286 | encoding["resized_scaled_img"].unsqueeze_(0) 287 | 288 | # step 15: save to disk 289 | if save_to_disk: 290 | os.makedirs(path_to_save, exist_ok=True) 291 | image_name = os.path.basename(image) 292 | with open(f"{path_to_save}{image_name}.pickle", "wb") as f: 293 | pickle.dump(encoding, f) 294 | 295 | # step 16: keys to keep, resized_and_aligned_bounding_boxes have been added for the purpose to test if the bounding boxes are drawn correctly or not, it maybe removed 296 | 297 | keys = ['resized_scaled_img', 'x_features','y_features','input_ids','resized_and_aligned_bounding_boxes'] 298 | 299 | if apply_mask_for_mlm: 300 | keys.append('mlm_labels') 301 | 302 | final_encoding = {k:encoding[k] 
for k in keys} 303 | 304 | del encoding 305 | return final_encoding 306 | -------------------------------------------------------------------------------- /src/docformer/dataset_pytorch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision.models as models 9 | from PIL import Image 10 | from sklearn.model_selection import train_test_split as tts 11 | from torch.autograd import Variable 12 | from torch.utils.data import DataLoader, Dataset 13 | from torchvision.transforms import ToTensor 14 | from transformers import AutoModel, AutoTokenizer 15 | 16 | """## Base Dataset""" 17 | 18 | device = "cuda" if torch.cuda.is_available() else "cpu" 19 | 20 | class DocumentDatset(Dataset): 21 | def __init__(self, entries, pathToPickleFile,pretrain= True): 22 | self.pathToPickleFile = pathToPickleFile 23 | self.entries = entries 24 | self.pretrain = pretrain 25 | 26 | def __len__(self): 27 | return len(self.entries) 28 | 29 | def __getitem__(self, index): 30 | 31 | imageName = self.entries[index] 32 | encoding = os.path.join(self.pathToPickleFile, imageName) 33 | 34 | with open(encoding, "rb") as sample: 35 | encoding = pickle.load(sample) 36 | 37 | if self.pretrain: 38 | 39 | # If the model is used for the purpose of pretraining, then there is no need for the other entries, since there would be some errors, while training 40 | 41 | del encoding['category_labels'] # Error would be created, because category label cannot be stored in the pytorch tensor 42 | del encoding['numeric_labels'] # Removed it, but this can be used for the purpose of the segmenting (as for an image_fp, in the FUNSD Dataset) 43 | del encoding['target_bbox'] # For the purpose of segmenting the different text in the image 44 | del encoding['resized_and_aligned_target_bbox'] # Resized version of the above bounding box, for 224x224 image 45 | 46 | for i in list(encoding.keys()): 47 | encoding[i] = encoding[i].to(device) 48 | 49 | # Since, we had taken the absolute value of the relative distance, we don't need to add any offset, and hence we can proceed with the model training 50 | return encoding 51 | 52 | 53 | # pathToPickleFile = 'RVL-CDIP-PickleFiles/' 54 | # entries = os.listdir(pathToPickleFile) 55 | # train_entries,val_entries = tts(entries,test_size = 0.2) 56 | # train_dataset = DocumentDatset(train_entries,pathToPickleFile) 57 | # val_dataset = DocumentDatset(val_entries,pathToPickleFile) 58 | 59 | -------------------------------------------------------------------------------- /src/docformer/modeling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | from einops import rearrange 7 | from torch import Tensor 8 | 9 | class PositionalEncoding(nn.Module): 10 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 11 | super().__init__() 12 | self.dropout = nn.Dropout(p=dropout) 13 | self.max_len = max_len 14 | self.d_model = d_model 15 | position = torch.arange(max_len).unsqueeze(1) 16 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 17 | pe = torch.zeros(1, max_len, d_model) 18 | pe[0, :, 0::2] = torch.sin(position * div_term) 19 | pe[0, :, 1::2] = torch.cos(position * div_term) 20 | self.register_buffer("pe", pe) 
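# Note: unlike the textbook Transformer module, forward() below takes no input; it simply returns
# the full sinusoidal table of shape (1, max_len, d_model) after dropout, which DocFormerEmbeddings
# then broadcast-adds to its (batch, seq_len, hidden_size) spatial embeddings.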
21 | 22 | 23 | def forward(self) -> Tensor: 24 | x = self.pe[0, : self.max_len] 25 | return self.dropout(x).unsqueeze(0) 26 | 27 | 28 | class ResNetFeatureExtractor(nn.Module): 29 | def __init__(self, hidden_dim = 512): 30 | super().__init__() 31 | 32 | # Making the resnet 50 model, which was used in the docformer for the purpose of visual feature extraction 33 | 34 | resnet50 = models.resnet50(pretrained=False) 35 | modules = list(resnet50.children())[:-2] 36 | self.resnet50 = nn.Sequential(*modules) 37 | 38 | # Applying convolution and linear layer 39 | 40 | self.conv1 = nn.Conv2d(2048, 768, 1) 41 | self.relu1 = F.relu 42 | self.linear1 = nn.Linear(192, hidden_dim) 43 | 44 | def forward(self, x): 45 | x = self.resnet50(x) 46 | x = self.conv1(x) 47 | x = self.relu1(x) 48 | x = rearrange(x, "b e w h -> b e (w h)") # b -> batch, e -> embedding dim, w -> width, h -> height 49 | x = self.linear1(x) 50 | x = rearrange(x, "b e s -> b s e") # b -> batch, e -> embedding dim, s -> sequence length 51 | return x 52 | 53 | class DocFormerEmbeddings(nn.Module): 54 | """Construct the embeddings from word, position and token_type embeddings.""" 55 | 56 | def __init__(self, config): 57 | super(DocFormerEmbeddings, self).__init__() 58 | 59 | self.config = config 60 | 61 | self.position_embeddings_v = PositionalEncoding( 62 | d_model=config["hidden_size"], 63 | dropout=0.1, 64 | max_len=config["max_position_embeddings"], 65 | ) 66 | 67 | self.x_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 68 | self.x_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 69 | self.w_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"]) 70 | self.x_topleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 71 | self.x_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 72 | self.x_topright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 73 | self.x_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 74 | self.x_centroid_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 75 | 76 | self.y_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 77 | self.y_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 78 | self.h_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"]) 79 | self.y_topleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 80 | self.y_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 81 | self.y_topright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 82 | self.y_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 83 | self.y_centroid_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 84 | 85 | self.position_embeddings_t = PositionalEncoding( 86 | 
d_model=config["hidden_size"], 87 | dropout=0.1, 88 | max_len=config["max_position_embeddings"], 89 | ) 90 | 91 | self.x_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 92 | self.x_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 93 | self.w_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"]) 94 | self.x_topleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"]+1, config["shape_size"]) 95 | self.x_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"]+1, config["shape_size"]) 96 | self.x_topright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 97 | self.x_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 98 | self.x_centroid_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 99 | 100 | self.y_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 101 | self.y_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"]) 102 | self.h_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"]) 103 | self.y_topleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 104 | self.y_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 105 | self.y_topright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 106 | self.y_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 107 | self.y_centroid_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"]) 108 | 109 | self.LayerNorm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"]) 110 | self.dropout = nn.Dropout(config["hidden_dropout_prob"]) 111 | 112 | 113 | 114 | def forward(self, x_feature, y_feature): 115 | 116 | """ 117 | Arguments: 118 | x_features of shape, (batch size, seq_len, 8) 119 | y_features of shape, (batch size, seq_len, 8) 120 | Outputs: 121 | (V-bar-s, T-bar-s) of shape (batch size, 512,768),(batch size, 512,768) 122 | What are the features: 123 | 0 -> top left x/y 124 | 1 -> bottom right x/y 125 | 2 -> width/height 126 | 3 -> diff top left x/y 127 | 4 -> diff bottom left x/y 128 | 5 -> diff top right x/y 129 | 6 -> diff bottom right x/y 130 | 7 -> centroids diff x/y 131 | """ 132 | 133 | 134 | batch, seq_len = x_feature.shape[:-1] 135 | hidden_size = self.config["hidden_size"] 136 | num_feat = x_feature.shape[-1] 137 | sub_dim = hidden_size // num_feat 138 | 139 | # Clamping and adding a bias for handling negative values 140 | x_feature[:,:,3:] = torch.clamp(x_feature[:,:,3:],-self.config["max_2d_position_embeddings"],self.config["max_2d_position_embeddings"]) 141 | x_feature[:,:,3:]+= self.config["max_2d_position_embeddings"] 142 | 143 | y_feature[:,:,3:] = torch.clamp(y_feature[:,:,3:],-self.config["max_2d_position_embeddings"],self.config["max_2d_position_embeddings"]) 144 | y_feature[:,:,3:]+= 
self.config["max_2d_position_embeddings"] 145 | 146 | x_topleft_position_embeddings_v = self.x_topleft_position_embeddings_v(x_feature[:,:,0]) 147 | x_bottomright_position_embeddings_v = self.x_bottomright_position_embeddings_v(x_feature[:,:,1]) 148 | w_position_embeddings_v = self.w_position_embeddings_v(x_feature[:,:,2]) 149 | x_topleft_distance_to_prev_embeddings_v = self.x_topleft_distance_to_prev_embeddings_v(x_feature[:,:,3]) 150 | x_bottomleft_distance_to_prev_embeddings_v = self.x_bottomleft_distance_to_prev_embeddings_v(x_feature[:,:,4]) 151 | x_topright_distance_to_prev_embeddings_v = self.x_topright_distance_to_prev_embeddings_v(x_feature[:,:,5]) 152 | x_bottomright_distance_to_prev_embeddings_v = self.x_bottomright_distance_to_prev_embeddings_v(x_feature[:,:,6]) 153 | x_centroid_distance_to_prev_embeddings_v = self.x_centroid_distance_to_prev_embeddings_v(x_feature[:,:,7]) 154 | 155 | x_calculated_embedding_v = torch.cat( 156 | [ 157 | x_topleft_position_embeddings_v, 158 | x_bottomright_position_embeddings_v, 159 | w_position_embeddings_v, 160 | x_topleft_distance_to_prev_embeddings_v, 161 | x_bottomleft_distance_to_prev_embeddings_v, 162 | x_topright_distance_to_prev_embeddings_v, 163 | x_bottomright_distance_to_prev_embeddings_v , 164 | x_centroid_distance_to_prev_embeddings_v 165 | ], 166 | dim = -1 167 | ) 168 | 169 | y_topleft_position_embeddings_v = self.y_topleft_position_embeddings_v(y_feature[:,:,0]) 170 | y_bottomright_position_embeddings_v = self.y_bottomright_position_embeddings_v(y_feature[:,:,1]) 171 | h_position_embeddings_v = self.h_position_embeddings_v(y_feature[:,:,2]) 172 | y_topleft_distance_to_prev_embeddings_v = self.y_topleft_distance_to_prev_embeddings_v(y_feature[:,:,3]) 173 | y_bottomleft_distance_to_prev_embeddings_v = self.y_bottomleft_distance_to_prev_embeddings_v(y_feature[:,:,4]) 174 | y_topright_distance_to_prev_embeddings_v = self.y_topright_distance_to_prev_embeddings_v(y_feature[:,:,5]) 175 | y_bottomright_distance_to_prev_embeddings_v = self.y_bottomright_distance_to_prev_embeddings_v(y_feature[:,:,6]) 176 | y_centroid_distance_to_prev_embeddings_v = self.y_centroid_distance_to_prev_embeddings_v(y_feature[:,:,7]) 177 | 178 | x_calculated_embedding_v = torch.cat( 179 | [ 180 | x_topleft_position_embeddings_v, 181 | x_bottomright_position_embeddings_v, 182 | w_position_embeddings_v, 183 | x_topleft_distance_to_prev_embeddings_v, 184 | x_bottomleft_distance_to_prev_embeddings_v, 185 | x_topright_distance_to_prev_embeddings_v, 186 | x_bottomright_distance_to_prev_embeddings_v , 187 | x_centroid_distance_to_prev_embeddings_v 188 | ], 189 | dim = -1 190 | ) 191 | 192 | y_calculated_embedding_v = torch.cat( 193 | [ 194 | y_topleft_position_embeddings_v, 195 | y_bottomright_position_embeddings_v, 196 | h_position_embeddings_v, 197 | y_topleft_distance_to_prev_embeddings_v, 198 | y_bottomleft_distance_to_prev_embeddings_v, 199 | y_topright_distance_to_prev_embeddings_v, 200 | y_bottomright_distance_to_prev_embeddings_v , 201 | y_centroid_distance_to_prev_embeddings_v 202 | ], 203 | dim = -1 204 | ) 205 | 206 | v_bar_s = x_calculated_embedding_v + y_calculated_embedding_v + self.position_embeddings_v() 207 | 208 | 209 | 210 | x_topleft_position_embeddings_t = self.x_topleft_position_embeddings_t(x_feature[:,:,0]) 211 | x_bottomright_position_embeddings_t = self.x_bottomright_position_embeddings_t(x_feature[:,:,1]) 212 | w_position_embeddings_t = self.w_position_embeddings_t(x_feature[:,:,2]) 213 | x_topleft_distance_to_prev_embeddings_t = 
self.x_topleft_distance_to_prev_embeddings_t(x_feature[:,:,3]) 214 | x_bottomleft_distance_to_prev_embeddings_t = self.x_bottomleft_distance_to_prev_embeddings_t(x_feature[:,:,4]) 215 | x_topright_distance_to_prev_embeddings_t = self.x_topright_distance_to_prev_embeddings_t(x_feature[:,:,5]) 216 | x_bottomright_distance_to_prev_embeddings_t = self.x_bottomright_distance_to_prev_embeddings_t(x_feature[:,:,6]) 217 | x_centroid_distance_to_prev_embeddings_t = self.x_centroid_distance_to_prev_embeddings_t(x_feature[:,:,7]) 218 | 219 | x_calculated_embedding_t = torch.cat( 220 | [ 221 | x_topleft_position_embeddings_t, 222 | x_bottomright_position_embeddings_t, 223 | w_position_embeddings_t, 224 | x_topleft_distance_to_prev_embeddings_t, 225 | x_bottomleft_distance_to_prev_embeddings_t, 226 | x_topright_distance_to_prev_embeddings_t, 227 | x_bottomright_distance_to_prev_embeddings_t , 228 | x_centroid_distance_to_prev_embeddings_t 229 | ], 230 | dim = -1 231 | ) 232 | 233 | y_topleft_position_embeddings_t = self.y_topleft_position_embeddings_t(y_feature[:,:,0]) 234 | y_bottomright_position_embeddings_t = self.y_bottomright_position_embeddings_t(y_feature[:,:,1]) 235 | h_position_embeddings_t = self.h_position_embeddings_t(y_feature[:,:,2]) 236 | y_topleft_distance_to_prev_embeddings_t = self.y_topleft_distance_to_prev_embeddings_t(y_feature[:,:,3]) 237 | y_bottomleft_distance_to_prev_embeddings_t = self.y_bottomleft_distance_to_prev_embeddings_t(y_feature[:,:,4]) 238 | y_topright_distance_to_prev_embeddings_t = self.y_topright_distance_to_prev_embeddings_t(y_feature[:,:,5]) 239 | y_bottomright_distance_to_prev_embeddings_t = self.y_bottomright_distance_to_prev_embeddings_t(y_feature[:,:,6]) 240 | y_centroid_distance_to_prev_embeddings_t = self.y_centroid_distance_to_prev_embeddings_t(y_feature[:,:,7]) 241 | 242 | x_calculated_embedding_t = torch.cat( 243 | [ 244 | x_topleft_position_embeddings_t, 245 | x_bottomright_position_embeddings_t, 246 | w_position_embeddings_t, 247 | x_topleft_distance_to_prev_embeddings_t, 248 | x_bottomleft_distance_to_prev_embeddings_t, 249 | x_topright_distance_to_prev_embeddings_t, 250 | x_bottomright_distance_to_prev_embeddings_t , 251 | x_centroid_distance_to_prev_embeddings_t 252 | ], 253 | dim = -1 254 | ) 255 | 256 | y_calculated_embedding_t = torch.cat( 257 | [ 258 | y_topleft_position_embeddings_t, 259 | y_bottomright_position_embeddings_t, 260 | h_position_embeddings_t, 261 | y_topleft_distance_to_prev_embeddings_t, 262 | y_bottomleft_distance_to_prev_embeddings_t, 263 | y_topright_distance_to_prev_embeddings_t, 264 | y_bottomright_distance_to_prev_embeddings_t , 265 | y_centroid_distance_to_prev_embeddings_t 266 | ], 267 | dim = -1 268 | ) 269 | 270 | t_bar_s = x_calculated_embedding_t + y_calculated_embedding_t + self.position_embeddings_t() 271 | 272 | return v_bar_s, t_bar_s 273 | 274 | 275 | 276 | # fmt: off 277 | class PreNorm(nn.Module): 278 | def __init__(self, dim, fn): 279 | # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf 280 | super().__init__() 281 | self.norm = nn.LayerNorm(dim) 282 | self.fn = fn 283 | 284 | def forward(self, x, **kwargs): 285 | return self.fn(self.norm(x), **kwargs) 286 | 287 | 288 | class PreNormAttn(nn.Module): 289 | def __init__(self, dim, fn): 290 | # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf 291 | super().__init__() 292 | 293 | self.norm_t_bar = nn.LayerNorm(dim) 294 | self.norm_v_bar = nn.LayerNorm(dim) 295 | self.norm_t_bar_s = nn.LayerNorm(dim) 296 | self.norm_v_bar_s = 
nn.LayerNorm(dim) 297 | self.fn = fn 298 | 299 | def forward(self, t_bar, v_bar, t_bar_s, v_bar_s, **kwargs): 300 | return self.fn(self.norm_t_bar(t_bar), 301 | self.norm_v_bar(v_bar), 302 | self.norm_t_bar_s(t_bar_s), 303 | self.norm_v_bar_s(v_bar_s), **kwargs) 304 | 305 | 306 | class FeedForward(nn.Module): 307 | def __init__(self, dim, hidden_dim, dropout=0.): 308 | super().__init__() 309 | self.net = nn.Sequential( 310 | nn.Linear(dim, hidden_dim), 311 | nn.GELU(), 312 | nn.Dropout(dropout), 313 | nn.Linear(hidden_dim, dim), 314 | nn.Dropout(dropout) 315 | ) 316 | 317 | def forward(self, x): 318 | return self.net(x) 319 | 320 | 321 | class RelativePosition(nn.Module): 322 | 323 | def __init__(self, num_units, max_relative_position, max_seq_length): 324 | super().__init__() 325 | self.num_units = num_units 326 | self.max_relative_position = max_relative_position 327 | self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)) 328 | self.max_length = max_seq_length 329 | range_vec_q = torch.arange(max_seq_length) 330 | range_vec_k = torch.arange(max_seq_length) 331 | distance_mat = range_vec_k[None, :] - range_vec_q[:, None] 332 | distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) 333 | final_mat = distance_mat_clipped + self.max_relative_position 334 | self.final_mat = torch.LongTensor(final_mat) 335 | nn.init.xavier_uniform_(self.embeddings_table) 336 | 337 | def forward(self, length_q, length_k): 338 | embeddings = self.embeddings_table[self.final_mat[:length_q, :length_k]] 339 | return embeddings 340 | 341 | 342 | class MultiModalAttentionLayer(nn.Module): 343 | def __init__(self, embed_dim, n_heads, max_relative_position, max_seq_length, dropout): 344 | super().__init__() 345 | assert embed_dim % n_heads == 0 346 | 347 | self.embed_dim = embed_dim 348 | self.n_heads = n_heads 349 | self.head_dim = embed_dim // n_heads 350 | 351 | self.relative_positions_text = RelativePosition(self.head_dim, max_relative_position, max_seq_length) 352 | self.relative_positions_img = RelativePosition(self.head_dim, max_relative_position, max_seq_length) 353 | 354 | # text qkv embeddings 355 | self.fc_k_text = nn.Linear(embed_dim, embed_dim) 356 | self.fc_q_text = nn.Linear(embed_dim, embed_dim) 357 | self.fc_v_text = nn.Linear(embed_dim, embed_dim) 358 | 359 | # image qkv embeddings 360 | self.fc_k_img = nn.Linear(embed_dim, embed_dim) 361 | self.fc_q_img = nn.Linear(embed_dim, embed_dim) 362 | self.fc_v_img = nn.Linear(embed_dim, embed_dim) 363 | 364 | # spatial qk embeddings (shared for visual and text) 365 | self.fc_k_spatial = nn.Linear(embed_dim, embed_dim) 366 | self.fc_q_spatial = nn.Linear(embed_dim, embed_dim) 367 | 368 | self.dropout = nn.Dropout(dropout) 369 | 370 | self.to_out = nn.Sequential( 371 | nn.Linear(embed_dim, embed_dim), 372 | nn.Dropout(dropout) 373 | ) 374 | self.scale = embed_dim**0.5 375 | 376 | def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat): 377 | text_feat = text_feat 378 | img_feat = img_feat 379 | text_spatial_feat = text_spatial_feat 380 | img_spatial_feat = img_spatial_feat 381 | seq_length = text_feat.shape[1] 382 | 383 | # self attention of text 384 | # b -> batch, t -> time steps (l -> length has same meaning), head -> # of heads, k -> head dim. 
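# For each modality the attention score below is a sum of three terms: (1) content self-attention
# from that modality's own queries/keys, (2) 1D relative-position terms from RelativePosition, and
# (3) spatial attention computed with the shared fc_q_spatial / fc_k_spatial projections;
# see the "Line 38 of pseudo-code" and "Line 59 of pseudo-code" comments below.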
385 | key_text_nh = rearrange(self.fc_k_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads) 386 | query_text_nh = rearrange(self.fc_q_text(text_feat), 'b l (head k) -> head b l k', head=self.n_heads) 387 | value_text_nh = rearrange(self.fc_v_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads) 388 | dots_text = torch.einsum('hblk,hbtk->hblt', query_text_nh, key_text_nh) 389 | dots_text = dots_text/ self.scale 390 | 391 | # 1D relative positions (query, key) 392 | rel_pos_embed_text = self.relative_positions_text(seq_length, seq_length) 393 | rel_pos_key_text = torch.einsum('bhrd,lrd->bhlr', key_text_nh, rel_pos_embed_text) 394 | rel_pos_query_text = torch.einsum('bhld,lrd->bhlr', query_text_nh, rel_pos_embed_text) 395 | 396 | # shared spatial <-> text hidden features 397 | key_spatial_text = self.fc_k_spatial(text_spatial_feat) 398 | query_spatial_text = self.fc_q_spatial(text_spatial_feat) 399 | key_spatial_text_nh = rearrange(key_spatial_text, 'b t (head k) -> head b t k', head=self.n_heads) 400 | query_spatial_text_nh = rearrange(query_spatial_text, 'b l (head k) -> head b l k', head=self.n_heads) 401 | dots_text_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_text_nh, key_spatial_text_nh) 402 | dots_text_spatial = dots_text_spatial/ self.scale 403 | 404 | # Line 38 of pseudo-code 405 | text_attn_scores = dots_text + rel_pos_key_text + rel_pos_query_text + dots_text_spatial 406 | 407 | # self-attention of image 408 | key_img_nh = rearrange(self.fc_k_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads) 409 | query_img_nh = rearrange(self.fc_q_img(img_feat), 'b l (head k) -> head b l k', head=self.n_heads) 410 | value_img_nh = rearrange(self.fc_v_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads) 411 | dots_img = torch.einsum('hblk,hbtk->hblt', query_img_nh, key_img_nh) 412 | dots_img = dots_img/ self.scale 413 | 414 | # 1D relative positions (query, key) 415 | rel_pos_embed_img = self.relative_positions_img(seq_length, seq_length) 416 | rel_pos_key_img = torch.einsum('bhrd,lrd->bhlr', key_img_nh, rel_pos_embed_text) 417 | rel_pos_query_img = torch.einsum('bhld,lrd->bhlr', query_img_nh, rel_pos_embed_text) 418 | 419 | # shared spatial <-> image features 420 | key_spatial_img = self.fc_k_spatial(img_spatial_feat) 421 | query_spatial_img = self.fc_q_spatial(img_spatial_feat) 422 | key_spatial_img_nh = rearrange(key_spatial_img, 'b t (head k) -> head b t k', head=self.n_heads) 423 | query_spatial_img_nh = rearrange(query_spatial_img, 'b l (head k) -> head b l k', head=self.n_heads) 424 | dots_img_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_img_nh, key_spatial_img_nh) 425 | dots_img_spatial = dots_img_spatial/ self.scale 426 | 427 | # Line 59 of pseudo-code 428 | img_attn_scores = dots_img + rel_pos_key_img + rel_pos_query_img + dots_img_spatial 429 | 430 | text_attn_probs = self.dropout(torch.softmax(text_attn_scores, dim=-1)) 431 | img_attn_probs = self.dropout(torch.softmax(img_attn_scores, dim=-1)) 432 | 433 | text_context = torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_text_nh) 434 | img_context = torch.einsum('hblt,hbtv->hblv', img_attn_probs, value_img_nh) 435 | 436 | context = text_context + img_context 437 | 438 | embeddings = rearrange(context, 'head b t d -> b t (head d)') 439 | return self.to_out(embeddings) 440 | 441 | class DocFormerEncoder(nn.Module): 442 | def __init__(self, config): 443 | super().__init__() 444 | self.config = config 445 | self.layers = nn.ModuleList([]) 446 | for _ in 
range(config['num_hidden_layers']):
447 |             encoder_block = nn.ModuleList([
448 |                 PreNormAttn(config['hidden_size'],
449 |                             MultiModalAttentionLayer(config['hidden_size'],
450 |                                                      config['num_attention_heads'],
451 |                                                      config['max_relative_positions'],
452 |                                                      config['max_position_embeddings'],
453 |                                                      config['hidden_dropout_prob'],
454 |                                                      )
455 |                             ),
456 |                 PreNorm(config['hidden_size'],
457 |                         FeedForward(config['hidden_size'],
458 |                                     config['hidden_size'] * config['intermediate_ff_size_factor'],
459 |                                     dropout=config['hidden_dropout_prob']))
460 |             ])
461 |             self.layers.append(encoder_block)
462 | 
463 |     def forward(
464 |         self,
465 |         text_feat,  # text feat or output from last encoder block
466 |         img_feat,
467 |         text_spatial_feat,
468 |         img_spatial_feat,
469 |     ):
470 |         # Fig 1 encoder part (skip conn for both attn & FF): https://arxiv.org/abs/1706.03762
471 |         # TODO: ensure 1st skip conn (var "skip") in such a multimodal setting makes sense (most likely does)
472 |         for attn, ff in self.layers:
473 |             skip = text_feat + img_feat + text_spatial_feat + img_spatial_feat
474 |             x = attn(text_feat, img_feat, text_spatial_feat, img_spatial_feat) + skip
475 |             x = ff(x) + x
476 |             text_feat = x
477 |         return x
478 | 
479 | 
480 | class LanguageFeatureExtractor(nn.Module):
481 |     def __init__(self):
482 |         super().__init__()
483 |         from transformers import LayoutLMForTokenClassification
484 |         layoutlm_dummy = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=1)
485 |         self.embedding_vector = nn.Embedding.from_pretrained(layoutlm_dummy.layoutlm.embeddings.word_embeddings.weight)
486 | 
487 |     def forward(self, x):
488 |         return self.embedding_vector(x)
489 | 
490 | 
491 | 
492 | class ExtractFeatures(nn.Module):
493 | 
494 |     '''
495 |     Inputs: dictionary
496 |     Output: v_bar, t_bar, v_bar_s, t_bar_s
497 |     '''
498 | 
499 |     def __init__(self, config):
500 |         super().__init__()
501 |         self.visual_feature = ResNetFeatureExtractor(hidden_dim=config['max_position_embeddings'])
502 |         self.language_feature = LanguageFeatureExtractor()
503 |         self.spatial_feature = DocFormerEmbeddings(config)
504 | 
505 |     def forward(self, encoding):
506 | 
507 |         image = encoding['resized_scaled_img']
508 | 
509 |         language = encoding['input_ids']
510 |         x_feature = encoding['x_features']
511 |         y_feature = encoding['y_features']
512 | 
513 |         v_bar = self.visual_feature(image)
514 |         t_bar = self.language_feature(language)
515 | 
516 |         v_bar_s, t_bar_s = self.spatial_feature(x_feature, y_feature)
517 | 
518 |         return v_bar, t_bar, v_bar_s, t_bar_s
519 | 
520 | 
521 | 
522 | class DocFormer(nn.Module):
523 | 
524 |     '''
525 |     Easy boilerplate: the model simply takes as input the dictionary obtained from the create_features function
526 |     '''
527 |     def __init__(self, config):
528 |         super().__init__()
529 |         self.config = config
530 |         self.extract_feature = ExtractFeatures(config)
531 |         self.encoder = DocFormerEncoder(config)
532 |         self.dropout = nn.Dropout(config['hidden_dropout_prob'])
533 | 
534 |     def forward(self, x):
535 |         v_bar, t_bar, v_bar_s, t_bar_s = self.extract_feature(x)  # ExtractFeatures.forward takes only the feature dictionary
536 |         features = {'v_bar': v_bar, 't_bar': t_bar, 'v_bar_s': v_bar_s, 't_bar_s': t_bar_s}
537 |         output = self.encoder(features['t_bar'], features['v_bar'], features['t_bar_s'], features['v_bar_s'])
538 |         output = self.dropout(output)
539 |         return output
540 | 
541 | 
542 | 
543 | 
544 | 
545 | 
546 | 
--------------------------------------------------------------------------------
/src/docformer/modeling_pl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import math
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision.models as models
10 | from PIL import Image
11 | from sklearn.model_selection import train_test_split as tts
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader, Dataset
14 | from torchvision.transforms import ToTensor
15 | from transformers import AutoModel, AutoTokenizer
16 | from modeling import *
17 | 
18 | """## Base Dataset"""
19 | 
20 | device = "cuda" if torch.cuda.is_available() else "cpu"
21 | 
22 | """## Base Model"""
23 | 
24 | config = {
25 |     "coordinate_size": 96,
26 |     "hidden_dropout_prob": 0.1,
27 |     "hidden_size": 768,
28 |     "image_feature_pool_shape": [7, 7, 256],
29 |     "intermediate_ff_size_factor": 3,  # default ought to be 4
30 |     "max_2d_position_embeddings": 1024,
31 |     "max_position_embeddings": 512,
32 |     "max_relative_positions": 8,
33 |     "num_attention_heads": 12,
34 |     "num_hidden_layers": 12,
35 |     "pad_token_id": 0,
36 |     "shape_size": 96,
37 |     "vocab_size": 30522,
38 |     "layer_norm_eps": 1e-12,
39 |     "batch_size": 9
40 | }
41 | 
42 | 
43 | class Model(pl.LightningModule):
44 | 
45 |     def __init__(self, config, num_classes, lr=5e-5):
46 | 
47 |         super().__init__()
48 |         self.save_hyperparameters()
49 |         self.docformer = DocFormerForClassification(config, num_classes)
50 | 
51 |     def forward(self, x):
52 |         return self.docformer(x)
53 | 
54 |     def training_step(self, batch, batch_idx):
55 | 
56 |         # For pre-training there can be multiple targets: if, for example, MLM + IR are done together,
57 |         # the model returns a dictionary, and we need two criteria (CrossEntropy and L1 loss) whose
58 |         # weighted sum is the total loss; only the final head on top of the DocFormer encoder changes.
59 | 
60 | 
61 |         # Currently, we are performing only the MLM part
62 |         logits = self.forward(batch)
63 |         criterion = torch.nn.CrossEntropyLoss()
64 |         loss = criterion(logits.transpose(2, 1), batch["mlm_labels"].long())
65 |         self.log("train_loss", loss, prog_bar=True)
66 |         return loss  # Lightning uses the returned loss for the backward/optimizer step
67 |     def validation_step(self, batch, batch_idx):
68 | 
69 |         logits = self.forward(batch)
70 |         b, size, classes = logits.shape
71 |         criterion = torch.nn.CrossEntropyLoss()
72 |         loss = criterion(logits.transpose(2, 1), batch["mlm_labels"].long())
73 |         val_acc = 100 * (torch.argmax(logits, dim=-1) == batch["mlm_labels"].long()).float().sum() / (logits.shape[0] * logits.shape[1])
74 |         val_acc = torch.tensor(val_acc)
75 |         self.log("val_loss", loss, prog_bar=True)
76 |         self.log("val_acc", val_acc, prog_bar=True)
77 | 
78 |     def configure_optimizers(self):
79 |         return torch.optim.AdamW(self.parameters(), lr=self.hparams["lr"])
80 | 
81 | """## Examples"""
82 | 
83 | # pathToPickleFile = 'RVL-CDIP-PickleFiles/'
84 | # entries = os.listdir(pathToPickleFile)
85 | # data = DataModule(train_entries,val_entries,pathToPickleFile)
86 | # model = Model(config,num_classes= 30522).to(device)
87 | # trainer = pl.Trainer(gpus=(1 if torch.cuda.is_available() else 0),
88 | #                      max_epochs=10,
89 | #                      fast_dev_run=False,
90 | #                      logger=pl.loggers.TensorBoardLogger("logs/", name="rvl-cdip", version=1),)
91 | # trainer.fit(model,data)
92 | 
93 | 
94 | 
95 | 
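The LightningModule above and the train_accelerator scripts further below both instantiate DocFormerForClassification(config, num_classes), which is not defined in any of the files included here. The following is only a minimal sketch of what such a head could look like, assuming it simply wraps the DocFormer encoder from modeling.py with a per-token linear classifier (for MLM pre-training, num_classes would be the vocabulary size, 30522); the author's actual head may differ.

import torch.nn as nn

from modeling import DocFormer


class DocFormerForClassification(nn.Module):
    """Hypothetical token-level classification head on top of the DocFormer encoder (sketch only)."""

    def __init__(self, config, num_classes):
        super().__init__()
        self.docformer = DocFormer(config)                          # multi-modal encoder defined in modeling.py
        self.head = nn.Linear(config["hidden_size"], num_classes)   # per-token classifier

    def forward(self, batch):
        hidden = self.docformer(batch)   # (batch, seq_len, hidden_size)
        return self.head(hidden)         # (batch, seq_len, num_classes)

Under this assumption the logits have shape (batch, 512, num_classes), which is what the criterion(logits.transpose(2, 1), batch["mlm_labels"]) calls in modeling_pl.py and train_accelerator.py expect.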
-------------------------------------------------------------------------------- /src/docformer/train_accelerator.py: -------------------------------------------------------------------------------- 1 | ## Dependencies 2 | 3 | from accelerate import Accelerator 4 | import accelerate 5 | import pytesseract 6 | import torchmetrics 7 | import math 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | import pandas as pd 12 | from PIL import Image 13 | import json 14 | import numpy as np 15 | from tqdm.auto import tqdm 16 | from torchvision.transforms import ToTensor 17 | import torch.nn.functional as F 18 | import torch.nn as nn 19 | import torchvision.models as models 20 | from einops import rearrange 21 | from einops import rearrange as rearr 22 | from sklearn.model_selection import train_test_split as tts 23 | from torch.autograd import Variable 24 | from torch.utils.data import DataLoader, Dataset 25 | from torchvision.transforms import ToTensor 26 | from modeling import DocFormer 27 | 28 | batch_size = 9 29 | 30 | class AverageMeter(object): 31 | """Computes and stores the average and current value""" 32 | def __init__(self): 33 | self.reset() 34 | 35 | def reset(self): 36 | self.val = 0 37 | self.sum = 0 38 | self.count = 0 39 | 40 | def update(self, val, n=1): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | 45 | @property 46 | def avg(self): 47 | return (self.sum / self.count) if self.count>0 else 0 48 | 49 | ## Loggers 50 | class Logger: 51 | def __init__(self, filename, format='csv'): 52 | self.filename = filename + '.' + format 53 | self._log = [] 54 | self.format = format 55 | 56 | def save(self, log, epoch=None): 57 | log['epoch'] = epoch + 1 58 | self._log.append(log) 59 | if self.format == 'json': 60 | with open(self.filename, 'w') as f: 61 | json.dump(self._log, f) 62 | else: 63 | pd.DataFrame(self._log).to_csv(self.filename, index=False) 64 | 65 | 66 | 67 | def train_fn(data_loader, model, criterion, optimizer, epoch, device, scheduler=None): 68 | model.train() 69 | accelerator = Accelerator() 70 | model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader) 71 | loop = tqdm(data_loader, leave=True) 72 | log = None 73 | train_acc = torchmetrics.Accuracy() 74 | loop = tqdm(data_loader) 75 | 76 | for batch in loop: 77 | 78 | input_ids = batch["input_ids"].to(device) 79 | attention_mask = batch["attention_mask"].to(device) 80 | labels = batch["mlm_labels"].to(device) 81 | 82 | # process 83 | outputs = model(batch) 84 | ce_loss = criterion(outputs.transpose(1,2), labels) 85 | 86 | if log is None: 87 | log = {} 88 | log["ce_loss"] = AverageMeter() 89 | log['accuracy'] = AverageMeter() 90 | 91 | optimizer.zero_grad() 92 | accelerator.backward(ce_loss) 93 | optimizer.step() 94 | 95 | if scheduler is not None: 96 | scheduler.step() 97 | 98 | log['accuracy'].update(train_acc(labels.cpu(),torch.argmax(outputs,-1).cpu()).item(),batch_size) 99 | log['ce_loss'].update(ce_loss.item()) 100 | loop.set_postfix({k: v.avg for k, v in log.items()}) 101 | 102 | return log 103 | 104 | 105 | # Function for the validation data loader 106 | def eval_fn(data_loader, model, criterion, device): 107 | model.eval() 108 | log = None 109 | val_acc = torchmetrics.Accuracy() 110 | 111 | 112 | with torch.no_grad(): 113 | loop = tqdm(data_loader, total=len(data_loader), leave=True) 114 | for batch in loop: 115 | 116 | input_ids = batch["input_ids"].to(device) 117 | attention_mask = batch["attention_mask"].to(device) 118 | 
labels = batch["mlm_labels"].to(device) 119 | output = model(batch) 120 | ce_loss = criterion(output.transpose(1,2), labels) 121 | 122 | if log is None: 123 | log = {} 124 | log["ce_loss"] = AverageMeter() 125 | log['accuracy'] = AverageMeter() 126 | 127 | log['accuracy'].update(val_acc(labels.cpu(),torch.argmax(output,-1).cpu()).item(),batch_size) 128 | log['ce_loss'].update(ce_loss.item()) 129 | loop.set_postfix({k: v.avg for k, v in log.items()}) 130 | return log # ['total_loss'] 131 | 132 | date = '20Oct' 133 | 134 | 135 | def run(config,train_dataloader,val_dataloader,device,epochs,path,classes,lr = 5e-5): 136 | logger = Logger(f"{path}/logs") 137 | model = DocFormerForClassification(config,classes).to(device) 138 | criterion = nn.CrossEntropyLoss() 139 | criterion = criterion.to(device) 140 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr) 141 | best_val_loss = 1e9 142 | header_printed = False 143 | batch_size = config['batch_size'] 144 | for epoch in range(epochs): 145 | print("Training the model.....") 146 | train_log = train_fn( 147 | train_dataloader, model, criterion, optimizer, epoch, device, scheduler=None 148 | ) 149 | 150 | print("Validating the model.....") 151 | valid_log = eval_fn(val_dataloader, model, criterion, device) 152 | log = {k: v.avg for k, v in train_log.items()} 153 | log.update({"V/" + k: v.avg for k, v in valid_log.items()}) 154 | logger.save(log, epoch) 155 | keys = sorted(log.keys()) 156 | if not header_printed: 157 | print(" ".join(map(lambda k: f"{k[:8]:8}", keys))) 158 | header_printed = True 159 | print(" ".join(map(lambda k: f"{log[k]:8.3f}"[:8], keys))) 160 | if log["V/ce_loss"] < best_val_loss: 161 | best_val_loss = log["V/ce_loss"] 162 | print("Best model found at epoch {}".format(epoch + 1)) 163 | torch.save(model.state_dict(), f"{path}/docformer_best_{epoch}_{date}.pth") 164 | -------------------------------------------------------------------------------- /src/docformer/train_accelerator_mlm_ir.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | ## Dependencies 6 | 7 | import pickle 8 | import os 9 | from accelerate import Accelerator 10 | import accelerate 11 | import pytesseract 12 | import torchmetrics 13 | import math 14 | import numpy as np 15 | import torch 16 | from torch.utils.data import Dataset, DataLoader 17 | import pandas as pd 18 | from PIL import Image 19 | import json 20 | import numpy as np 21 | from tqdm.auto import tqdm 22 | from torchvision.transforms import ToTensor 23 | import torch.nn.functional as F 24 | import torch.nn as nn 25 | import torchvision.models as models 26 | from einops import rearrange 27 | from einops import rearrange as rearr 28 | from PIL import Image 29 | from sklearn.model_selection import train_test_split as tts 30 | from torch.autograd import Variable 31 | from torch.utils.data import DataLoader, Dataset 32 | from torchvision.transforms import ToTensor 33 | from modeling import * 34 | 35 | batch_size = 9 36 | 37 | 38 | weights = {'mlm':5,'ir':1,'tdi':5} 39 | 40 | class AverageMeter(object): 41 | """Computes and stores the average and current value""" 42 | def __init__(self): 43 | self.reset() 44 | 45 | def reset(self): 46 | self.val = 0 47 | self.sum = 0 48 | self.count = 0 49 | 50 | def update(self, val, n=1): 51 | self.val = val 52 | self.sum += val * n 53 | self.count += n 54 | 55 | @property 56 | def avg(self): 57 | return (self.sum / self.count) if self.count>0 else 0 58 | 59 | ## Loggers 60 | class Logger: 61 | 
def __init__(self, filename, format='csv'): 62 | self.filename = filename + '.' + format 63 | self._log = [] 64 | self.format = format 65 | 66 | def save(self, log, epoch=None): 67 | log['epoch'] = epoch + 1 68 | self._log.append(log) 69 | if self.format == 'json': 70 | with open(self.filename, 'w') as f: 71 | json.dump(self._log, f) 72 | else: 73 | pd.DataFrame(self._log).to_csv(self.filename, index=False) 74 | 75 | 76 | 77 | def train_fn(data_loader, model, criterion1,criterion2, optimizer, epoch, device, scheduler=None,weights=weights): 78 | model.train() 79 | accelerator = Accelerator() 80 | model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader) 81 | loop = tqdm(data_loader, leave=True) 82 | log = None 83 | train_acc = torchmetrics.Accuracy() 84 | loop = tqdm(data_loader) 85 | 86 | for batch in loop: 87 | 88 | input_ids = batch["input_ids"].to(device) 89 | attention_mask = batch["attention_mask"].to(device) 90 | labels1 = batch["mlm_labels"].to(device) 91 | labels2 = batch['resized_image'].to(device) 92 | 93 | # process 94 | outputs = model(batch) 95 | ce_loss = criterion1(outputs['mlm_labels'].transpose(1,2), labels1) 96 | ir_loss = criterion2(outputs['ir'],labels2) 97 | 98 | if log is None: 99 | log = {} 100 | log["ce_loss"] = AverageMeter() 101 | log['accuracy'] = AverageMeter() 102 | log['ir_loss'] = AverageMeter() 103 | log['total_loss'] = AverageMeter() 104 | 105 | total_loss = weights['mlm']*ce_loss + weights['ir']*ir_loss 106 | optimizer.zero_grad() 107 | accelerator.backward(total_loss) 108 | optimizer.step() 109 | 110 | if scheduler is not None: 111 | scheduler.step() 112 | 113 | log['accuracy'].update(train_acc(labels1.cpu(),torch.argmax(outputs['mlm_labels'],-1).cpu()).item(),batch_size) 114 | log['ce_loss'].update(ce_loss.item(),batch_size) 115 | log['ir_loss'].update(ir_loss.item(),batch_size) 116 | log['total_loss'].update(total_loss.item(),batch_size) 117 | loop.set_postfix({k: v.avg for k, v in log.items()}) 118 | 119 | return log 120 | 121 | 122 | # Function for the validation data loader 123 | def eval_fn(data_loader, model, criterion1,criterion2, device,weights=weights): 124 | model.eval() 125 | log = None 126 | val_acc = torchmetrics.Accuracy() 127 | 128 | 129 | with torch.no_grad(): 130 | loop = tqdm(data_loader, total=len(data_loader), leave=True) 131 | for batch in loop: 132 | 133 | input_ids = batch["input_ids"].to(device) 134 | attention_mask = batch["attention_mask"].to(device) 135 | labels1 = batch["mlm_labels"].to(device) 136 | labels2 = batch['resized_image'].to(device) 137 | 138 | 139 | output = model(batch) 140 | ce_loss = criterion1(output['mlm_labels'].transpose(1,2), labels1) 141 | ir_loss = criterion2(output['ir'],labels2) 142 | total_loss = weights['mlm']*ce_loss + weights['ir']*ir_loss 143 | if log is None: 144 | log = {} 145 | log["ce_loss"] = AverageMeter() 146 | log['accuracy'] = AverageMeter() 147 | log['ir_loss'] = AverageMeter() 148 | log['total_loss'] = AverageMeter() 149 | 150 | log['accuracy'].update(val_acc(labels1.cpu(),torch.argmax(output['mlm_labels'],-1).cpu()).item(),batch_size) 151 | log['ce_loss'].update(ce_loss.item(),batch_size) 152 | log['ir_loss'].update(ir_loss.item(),batch_size) 153 | log['total_loss'].update(total_loss.item(),batch_size) 154 | loop.set_postfix({k: v.avg for k, v in log.items()}) 155 | 156 | return log # ['total_loss'] 157 | 158 | date = '26Oct' 159 | 160 | 161 | def run(config,train_dataloader,val_dataloader,device,epochs,path,classes,lr = 5e-5,weights=weights): 162 | logger = 
Logger(f"{path}/logs") 163 | model = DocFormer_For_IR(config,classes).to(device) 164 | criterion1 = nn.CrossEntropyLoss().to(device) 165 | criterion2 = torch.nn.L1Loss().to(device) 166 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr) 167 | best_val_loss = 1e9 168 | header_printed = False 169 | batch_size = config['batch_size'] 170 | for epoch in range(epochs): 171 | print("Training the model.....") 172 | train_log = train_fn( 173 | train_dataloader, model, criterion1,criterion2, optimizer, epoch, device, scheduler=None 174 | ) 175 | 176 | print("Validating the model.....") 177 | valid_log = eval_fn(val_dataloader, model, criterion1,criterion2, device) 178 | log = {k: v.avg for k, v in train_log.items()} 179 | log.update({"V/" + k: v.avg for k, v in valid_log.items()}) 180 | logger.save(log, epoch) 181 | keys = sorted(log.keys()) 182 | if not header_printed: 183 | print(" ".join(map(lambda k: f"{k[:8]:8}", keys))) 184 | header_printed = True 185 | print(" ".join(map(lambda k: f"{log[k]:8.3f}"[:8], keys))) 186 | if log["V/total_loss"] < best_val_loss: 187 | best_val_loss = log["V/total_loss"] 188 | print("Best model found at epoch {}".format(epoch + 1)) 189 | torch.save(model.state_dict(), f"{path}/docformer_best_{epoch}_{date}.pth") 190 | -------------------------------------------------------------------------------- /src/docformer/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Utility File 3 | 4 | Created basically for the purpose of defining the labels for the unsupervised task of "Text Describe Image", as in the paper DocFormer 5 | ''' 6 | 7 | import numpy as np 8 | import random 9 | import json 10 | import os 11 | 12 | def labels_for_tdi(length,sample_ratio=0.05): 13 | 14 | ''' 15 | 16 | Function for providing the labels for the task of 17 | Text Describes Image 18 | 19 | Input: 20 | * Length of the dataset, i.e the total number of data points 21 | * sample_ratio: The percentage of the total length, which you want to shuffle 22 | 23 | Output: 24 | * d_arr (dummy array): The array which contains the indexes of the images 25 | * Labels: Whether the image has been shuffled with some other image or not 26 | 27 | 28 | Example: 29 | d_arr,labels = labels_for_tdi(100) 30 | 31 | Explanation: 32 | Suppose, the array is [1,2,3,4,5], so, the steps are as follows: 33 | 34 | * Choose some arbitrary values, (refer the samples_to_be_changed variable) (let us suppose [2,4]) 35 | * Generate the permutation of the same, and replace the arbitary values with their permutations (one permutation can be [4,2], and hence the array becomes 36 | [1,4,3,2,5] 37 | * And then, if the original arr and the d_arr's arguments matches, put a label of 1, else put 0, hence the labels array becomes [1,0,1,0,1] 38 | 39 | 40 | The purpose of returning d_arr is, because the d_arr[i] would be the argument, which is responsible for becoming the d_resized_scaled_img of ith encoding dictinary vector 41 | 42 | i.e if d_arr[i] == i (means not shuffled), the d_resized_scaled_img of ith entry would be same else resized_scaled_img, 43 | else d_sized_scaled_img[i] = resized_scaled_img[d_arr[i]] 44 | 45 | ''' 46 | samples_to_be_changed = int(sample_ratio*length) 47 | arr = np.arange(length) 48 | d_arr = arr.copy() 49 | labels = np.ones(length) 50 | sample_id = np.array(random.sample(list(arr), samples_to_be_changed)) 51 | new_sample_id = np.random.permutation(sample_id) 52 | d_arr[sample_id]=new_sample_id 53 | labels = (arr==d_arr).astype(int) 54 | 55 | return 
d_arr, labels
56 | 
57 | 
58 | ## Purpose: Read the json file at the given path and return its contents as a dictionary
59 | def load_json_file(file_path):
60 |     with open(file_path, 'r') as f:
61 |         data = json.load(f)
62 |     return data
63 | 
64 | ## Purpose: Get the path of the first file of a specific type (e.g. .pdf, .tif) in a directory; returns '-1' if none is found
65 | def get_specific_file(path, last_entry='tif'):
66 |     base_path = path
67 |     for i in os.listdir(path):
68 |         if i.endswith(last_entry):
69 |             return os.path.join(base_path, i)
70 | 
71 |     return '-1'
72 | 
73 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shabie/docformer/c7fcfdf71fb174784c3dba932b0f0daa6f05a92f/tests/__init__.py
--------------------------------------------------------------------------------
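Finally, a short usage sketch for labels_for_tdi from utils.py, following the explanation in its docstring. The variable encodings (a list of feature dictionaries produced by create_features, each holding a resized_scaled_img tensor) is an assumed name used only for illustration.

from utils import labels_for_tdi

# encodings: list of feature dicts returned by create_features(...) for each document image
# (assumed to exist already; see the note above)
d_arr, labels = labels_for_tdi(length=len(encodings), sample_ratio=0.05)

for i, enc in enumerate(encodings):
    # Image paired with the i-th document's text for the Text-Describes-Image task:
    # its own image when labels[i] == 1, the image of another document when labels[i] == 0.
    enc["d_resized_scaled_img"] = encodings[d_arr[i]]["resized_scaled_img"]
    enc["tdi_label"] = int(labels[i])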