├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── images
│   │   └── docformer-architecture.png
│   └── index.md
├── examples
│   └── docformer_pl
│       ├── document_image_classification_on_rvl_cdip
│       │   └── 4.Document-image-classification-with-docformer.ipynb
│       ├── pre_training_task_on_idl_dataset
│       │   ├── docformer_pre_training_1_preparing_idl_dataset.py
│       │   ├── docformer_pre_training_2_preparing_idl_pytorch_dataset_for_docformer.py
│       │   └── docformer_pre_training_3_modeling_for_docformer.py
│       └── token_classification_on_funsd
│           ├── Token_Classification_Part_1.ipynb
│           ├── Token_Classification_Part_2.ipynb
│           └── Token_Classification_Part_3.ipynb
├── images
│   └── docformer-architecture.png
├── mkdocs.yml
├── scripts
│   ├── .coveragerc
│   ├── docs
│   ├── env
│   └── lint
├── setup.cfg
├── setup.py
├── src
│   └── docformer
│       ├── README.md
│       ├── dataset.py
│       ├── dataset_pytorch.py
│       ├── modeling.py
│       ├── modeling_pl.py
│       ├── train_accelerator.py
│       ├── train_accelerator_mlm_ir.py
│       └── utils.py
└── tests
    └── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .idea/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DocFormer - PyTorch
2 |
3 | 
4 |
5 | Implementation of [DocFormer: End-to-End Transformer for Document Understanding](https://arxiv.org/abs/2106.11539), a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU) 📄📄📄.
6 |
7 | DocFormer is a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU). In addition, DocFormer is pre-trained in an unsupervised fashion using carefully designed tasks which encourage multi-modal interaction. DocFormer uses text, vision and spatial features and combines them using a novel multi-modal self-attention layer. DocFormer also shares learned spatial embeddings across modalities which makes it easy for the model to correlate text to visual tokens and vice versa. DocFormer is evaluated on 4 different datasets each with strong baselines. DocFormer achieves state-of-the-art results on all of them, sometimes beating models 4x its size (in no. of parameters).
8 |
9 | The official implementation was not released by the authors.
10 |
11 | ## NOTE:
12 |
13 | I pre-trained DocFormer with the MLM objective on a subset of the [IDL Dataset](https://github.com/furkanbiten/idl_data). The weights are available [here](https://www.kaggle.com/code/akarshu121/downloading-docformer-weights), and the associated Kaggle notebook for fine-tuning on FUNSD is [here](https://www.kaggle.com/code/akarshu121/ckpt-docformer-for-token-classification-on-funsd/notebook?scriptVersionId=118952199).
14 |
15 | ## Install
16 |
17 | There can be issues importing `pytesseract`; to resolve them, install both the Python package and the Tesseract binary:
18 |
19 | ```bash
20 | pip install pytesseract
21 | sudo apt install tesseract-ocr
22 | ```
23 |
24 | Then clone the repository:
25 |
26 | ```bash
27 | git clone https://github.com/shabie/docformer.git
28 |
29 |
30 | ```
31 |
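Since the repository ships a `setup.py`, an editable install from the cloned directory should also work (standard setuptools usage, not an officially documented step):

```bash
cd docformer
pip install -e .
```
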
32 | ## Usage
33 |
34 | ```python
35 | import sys
36 | sys.path.extend(['docformer/src/docformer/'])
37 | import modeling, dataset
38 | from transformers import BertTokenizerFast
39 |
40 |
41 | config = {
42 | "coordinate_size": 96,
43 | "hidden_dropout_prob": 0.1,
44 | "hidden_size": 768,
45 | "image_feature_pool_shape": [7, 7, 256],
46 | "intermediate_ff_size_factor": 4,
47 | "max_2d_position_embeddings": 1000,
48 | "max_position_embeddings": 512,
49 | "max_relative_positions": 8,
50 | "num_attention_heads": 12,
51 | "num_hidden_layers": 12,
52 | "pad_token_id": 0,
53 | "shape_size": 96,
54 | "vocab_size": 30522,
55 | "layer_norm_eps": 1e-12,
56 | }
57 |
58 | fp = "filepath/to/the/image.tif"
59 |
60 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
61 | encoding = dataset.create_features(fp, tokenizer, add_batch_dim=True)
62 |
63 | feature_extractor = modeling.ExtractFeatures(config)
64 | docformer = modeling.DocFormerEncoder(config)
65 | v_bar, t_bar, v_bar_s, t_bar_s = feature_extractor(encoding)
66 | output = docformer(v_bar, t_bar, v_bar_s, t_bar_s) # shape (1, 512, 768)
67 | ```
68 |
69 | ## License
70 |
71 | MIT
72 |
73 | ## Maintainers
74 |
75 | - [uakarsh](https://github.com/uakarsh)
76 | - [shabie](https://github.com/shabie)
77 |
78 | ## Contribute
79 |
80 |
81 | ## Citations
82 |
83 | ```bibtex
84 | @InProceedings{Appalaraju_2021_ICCV,
85 | author = {Appalaraju, Srikar and Jasani, Bhavan and Kota, Bhargava Urala and Xie, Yusheng and Manmatha, R.},
86 | title = {DocFormer: End-to-End Transformer for Document Understanding},
87 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
88 | month = {October},
89 | year = {2021},
90 | pages = {993-1003}
91 | }
92 | ```
93 |
--------------------------------------------------------------------------------
/docs/images/docformer-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shabie/docformer/c7fcfdf71fb174784c3dba932b0f0daa6f05a92f/docs/images/docformer-architecture.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # DocFormer - PyTorch
2 |
3 | 
4 |
5 | Implementation of [DocFormer: End-to-End Transformer for Document Understanding](https://arxiv.org/abs/2106.11539), a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU) 📄📄📄.
6 |
7 | The official implementation was not released by the authors.
8 |
9 | ## Install
10 |
11 |
12 | ## Usage
13 |
14 | See `examples` for usage.
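
For a quick start, a minimal pre-processing sketch mirroring the top-level README (the image path is a placeholder, and the cloned `docformer/src/docformer` directory must be importable):

```python
import sys
sys.path.extend(['docformer/src/docformer/'])

import dataset
from transformers import BertTokenizerFast

# Pre-process one page image: OCR, tokenization, bounding boxes, resized/normalized image tensor.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoding = dataset.create_features("path/to/image.tif", tokenizer, add_batch_dim=True)
print(encoding.keys())
```

The full forward pass (config dict, `ExtractFeatures`, `DocFormerEncoder`) is shown in the top-level README.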
15 |
16 |
17 | ## Contribute
18 |
19 |
20 | ## Citations
21 |
22 | ```bibtex
23 | @misc{appalaraju2021docformer,
24 | title = {DocFormer: End-to-End Transformer for Document Understanding},
25 | author = {Srikar Appalaraju and Bhavan Jasani and Bhargava Urala Kota and Yusheng Xie and R. Manmatha},
26 | year = {2021},
27 | eprint = {2106.11539},
28 | archivePrefix = {arXiv},
29 | primaryClass = {cs.CV}
30 | }
31 | ```
32 |
--------------------------------------------------------------------------------
/examples/docformer_pl/document_image_classification_on_rvl_cdip/4.Document-image-classification-with-docformer.ipynb:
--------------------------------------------------------------------------------
1 | {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"}},"nbformat_minor":4,"nbformat":4,"cells":[{"source":"
","metadata":{},"cell_type":"markdown","outputs":[],"execution_count":0},{"cell_type":"markdown","source":"
\n\n\n## 1. Introduction: \n* This notebook is a tutorial to the multi-modal architecture DocFormer (mainly for the purpose of Document Understanding).\n* We would take in, the test-images of the RVL-CDIP Dataset, and then would train the model on a subset of the dataset\n* We would also be logging the metrics with the help of Weights and Biases","metadata":{}},{"cell_type":"markdown","source":"## A small Introduction about the Model:\n\n
\n\nDocFormer is a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU). In addition, DocFormer is pre-trained in an unsupervised fashion using carefully designed tasks which encourage multi-modal interaction. DocFormer uses text, vision and spatial features and combines them using a novel multi-modal self-attention layer. DocFormer also shares learned spatial embeddings across modalities which makes it easy for the model to correlate text to visual tokens and vice versa. DocFormer is evaluated on 4 different datasets each with strong baselines. DocFormer achieves state-of-the-art results on all of them, sometimes beating models 4x its size (in no. of parameters).\n\nFor more understanding of the model and its code implementation, one can visit [here](https://github.com/uakarsh/docformer). So, let us go on to see what this model has to offer\n\nThe report for this entire run is attached [here](https://wandb.ai/iakarshu/RVL%20CDIP%20with%20DocFormer%20New%20Version/reports/Performance-of-DocFormer-with-RVL-CDIP-Test-Dataset--VmlldzoyMTI3NTM4)\n\n
\n\n\n\n\n### An Interactive Demo for the same can be found on 🤗 space [here](https://huggingface.co/spaces/iakarshu/docformer_for_document_classification)\n\n### Installing the Libraries ⚙️:","metadata":{}},{"cell_type":"code","source":"## Installing the dependencies (might take some time)\n\n!pip install -q pytesseract\n!sudo apt install -q tesseract-ocr\n!pip install -q transformers\n!pip install -q pytorch-lightning\n!pip install -q einops\n!pip install -q tqdm\n!pip install -q 'Pillow==7.1.2'\n!pip install -q datasets\n!pip install wandb\n!pip install torchmetrics","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Cloning the repository\n!git clone https://github.com/uakarsh/docformer.git","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:22.079626Z","iopub.execute_input":"2022-06-07T07:08:22.080057Z","iopub.status.idle":"2022-06-07T07:08:24.114693Z","shell.execute_reply.started":"2022-06-07T07:08:22.080014Z","shell.execute_reply":"2022-06-07T07:08:24.113469Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Logging into wandb\n\nimport wandb\nfrom kaggle_secrets import UserSecretsClient\nuser_secrets = UserSecretsClient()\nsecret_value_0 = user_secrets.get_secret(\"wandb_api\")\nwandb.login(key=secret_value_0)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:24.145099Z","iopub.execute_input":"2022-06-07T07:08:24.146015Z","iopub.status.idle":"2022-06-07T07:08:25.65243Z","shell.execute_reply.started":"2022-06-07T07:08:24.145939Z","shell.execute_reply":"2022-06-07T07:08:25.651299Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 2. Libraries 📘:","metadata":{}},{"cell_type":"code","source":"## Importing the libraries\n\nimport warnings\nwarnings.simplefilter(\"ignore\", UserWarning)\nwarnings.simplefilter(\"ignore\", RuntimeWarning)\n\nimport os\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\nimport numpy as np\nimport pandas as pd\n\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import Dataset,DataLoader\n\nimport torch.nn.functional as F\nimport torchvision.models as models\n\n## Adding the path of docformer to system path\nimport sys\nsys.path.append('./docformer/src/docformer/')\n\n## Importing the functions from the DocFormer Repo\nfrom dataset import create_features\nfrom modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor\nfrom transformers import BertTokenizerFast","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:25.654629Z","iopub.execute_input":"2022-06-07T07:08:25.65569Z","iopub.status.idle":"2022-06-07T07:08:31.919208Z","shell.execute_reply.started":"2022-06-07T07:08:25.655654Z","shell.execute_reply":"2022-06-07T07:08:31.91827Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Hyperparameters\n\nseed = 42\ntarget_size = (500, 384)\n\n## Setting some hyperparameters\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\n\n## One can change this configuration and try out new combination\nconfig = {\n \"coordinate_size\": 96, ## (768/8), 8 for each of the 8 coordinates of x, y\n \"hidden_dropout_prob\": 0.1,\n \"hidden_size\": 768,\n \"image_feature_pool_shape\": [7, 7, 256],\n \"intermediate_ff_size_factor\": 4,\n \"max_2d_position_embeddings\": 1024,\n \"max_position_embeddings\": 128,\n \"max_relative_positions\": 8,\n \"num_attention_heads\": 12,\n \"num_hidden_layers\": 12,\n \"pad_token_id\": 0,\n 
\"shape_size\": 96,\n \"vocab_size\": 30522,\n \"layer_norm_eps\": 1e-12,\n}","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:31.920835Z","iopub.execute_input":"2022-06-07T07:08:31.921618Z","iopub.status.idle":"2022-06-07T07:08:31.989757Z","shell.execute_reply.started":"2022-06-07T07:08:31.921575Z","shell.execute_reply":"2022-06-07T07:08:31.988868Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## A small note 🗒️: \nHere, for the purpose of Demo I would be using only 250 Images per class, and would train the model on it. Definintely for a data hungry model such as transformers, such a small data is not enough, but let us see what are the results on it.","metadata":{}},{"cell_type":"code","source":"from tqdm.auto import tqdm\n\n## For the purpose of prediction\nid2label = []\nlabel2id = {}\n\ncurr_class = 0\n## Preparing the Dataset\nbase_directory = '../input/the-rvlcdip-dataset-test/test'\ndict_of_img_labels = {'img':[], 'label':[]}\n\nmax_sample_per_class = 250\n\nfor label in tqdm(os.listdir(base_directory)):\n img_path = os.path.join(base_directory, label)\n \n count = 0\n if label not in label2id:\n label2id[label] = curr_class\n curr_class+=1\n id2label.append(label)\n \n for img in os.listdir(img_path):\n if count>max_sample_per_class:\n break\n \n curr_img_path = os.path.join(img_path, img)\n dict_of_img_labels['img'].append(curr_img_path)\n dict_of_img_labels['label'].append(label2id[label])\n count+=1","metadata":{"execution":{"iopub.status.busy":"2022-06-07T18:38:57.319504Z","iopub.execute_input":"2022-06-07T18:38:57.319852Z","iopub.status.idle":"2022-06-07T18:39:02.430787Z","shell.execute_reply.started":"2022-06-07T18:38:57.319827Z","shell.execute_reply":"2022-06-07T18:39:02.429826Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"import pandas as pd\ndf = pd.DataFrame(dict_of_img_labels)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.612078Z","iopub.execute_input":"2022-06-07T07:08:36.612682Z","iopub.status.idle":"2022-06-07T07:08:36.621744Z","shell.execute_reply.started":"2022-06-07T07:08:36.612639Z","shell.execute_reply":"2022-06-07T07:08:36.620752Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from sklearn.model_selection import train_test_split as tts\ntrain_df, valid_df = tts(df, random_state = seed, stratify = df['label'], shuffle = True)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.623323Z","iopub.execute_input":"2022-06-07T07:08:36.623888Z","iopub.status.idle":"2022-06-07T07:08:36.973838Z","shell.execute_reply.started":"2022-06-07T07:08:36.623845Z","shell.execute_reply":"2022-06-07T07:08:36.972854Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"train_df = train_df.reset_index().drop(columns = ['index'], axis = 1)\nvalid_df = valid_df.reset_index().drop(columns = ['index'], axis = 1)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.976098Z","iopub.execute_input":"2022-06-07T07:08:36.976661Z","iopub.status.idle":"2022-06-07T07:08:36.990587Z","shell.execute_reply.started":"2022-06-07T07:08:36.976608Z","shell.execute_reply":"2022-06-07T07:08:36.989443Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 3. Making the dataset 💽:\n\nThe main idea behind making the dataset is, to pre-process the input into a given format, and then provide the input to the model. 
So, simply just the image path, and the other configurations, and boom 💥, you would get the desired pre-processed input","metadata":{}},{"cell_type":"code","source":"## Creating the dataset\n\nclass RVLCDIPData(Dataset):\n \n def __init__(self, image_list, label_list, target_size, tokenizer, max_len = 512, transform = None):\n \n self.image_list = image_list\n self.label_list = label_list\n self.target_size = target_size\n self.tokenizer = tokenizer\n self.max_len = max_len\n self.transform = transform\n \n def __len__(self):\n return len(self.image_list)\n \n def __getitem__(self, idx):\n img_path = self.image_list[idx]\n label = self.label_list[idx]\n \n ## More on this, in the repo mentioned previously\n final_encoding = create_features(\n img_path,\n self.tokenizer,\n add_batch_dim=False,\n target_size=self.target_size,\n max_seq_length=self.max_len,\n path_to_save=None,\n save_to_disk=False,\n apply_mask_for_mlm=False,\n extras_for_debugging=False,\n use_ocr = True\n )\n if self.transform is not None:\n ## Note that, ToTensor is already applied on the image\n final_encoding['resized_scaled_img'] = self.transform(final_encoding['resized_scaled_img'])\n \n \n keys_to_reshape = ['x_features', 'y_features', 'resized_and_aligned_bounding_boxes']\n for key in keys_to_reshape:\n final_encoding[key] = final_encoding[key][:self.max_len]\n \n final_encoding['label'] = torch.as_tensor(label).long()\n return final_encoding","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:36.992277Z","iopub.execute_input":"2022-06-07T07:08:36.993034Z","iopub.status.idle":"2022-06-07T07:08:37.006426Z","shell.execute_reply.started":"2022-06-07T07:08:36.992996Z","shell.execute_reply":"2022-06-07T07:08:37.005431Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Defining the tokenizer\ntokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:37.00823Z","iopub.execute_input":"2022-06-07T07:08:37.009227Z","iopub.status.idle":"2022-06-07T07:08:38.530076Z","shell.execute_reply.started":"2022-06-07T07:08:37.009184Z","shell.execute_reply":"2022-06-07T07:08:38.529039Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from torchvision import transforms\n\n## Normalization to these mean and std (I have seen some tutorials used this, and also in image reconstruction, so used it)\ntransform = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n ","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.531653Z","iopub.execute_input":"2022-06-07T07:08:38.53226Z","iopub.status.idle":"2022-06-07T07:08:38.538266Z","shell.execute_reply.started":"2022-06-07T07:08:38.532219Z","shell.execute_reply":"2022-06-07T07:08:38.537005Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"train_ds = RVLCDIPData(train_df['img'].tolist(), train_df['label'].tolist(),\n target_size, tokenizer, config['max_position_embeddings'], transform)\nval_ds = RVLCDIPData(valid_df['img'].tolist(), valid_df['label'].tolist(),\n target_size, tokenizer,config['max_position_embeddings'], 
transform)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.540023Z","iopub.execute_input":"2022-06-07T07:08:38.540791Z","iopub.status.idle":"2022-06-07T07:08:38.550323Z","shell.execute_reply.started":"2022-06-07T07:08:38.540746Z","shell.execute_reply":"2022-06-07T07:08:38.549236Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"### Collate Function:\n\nDefinitely collate function is an amazing function for using the dataloader as per our wish. More on collate function can be known from [here](https://stackoverflow.com/questions/65279115/how-to-use-collate-fn-with-dataloaders)","metadata":{}},{"cell_type":"code","source":"def collate_fn(data_bunch):\n\n '''\n A function for the dataloader to return a batch dict of given keys\n\n data_bunch: List of dictionary\n '''\n\n dict_data_bunch = {}\n\n for i in data_bunch:\n for (key, value) in i.items():\n if key not in dict_data_bunch:\n dict_data_bunch[key] = []\n dict_data_bunch[key].append(value)\n\n for key in list(dict_data_bunch.keys()):\n dict_data_bunch[key] = torch.stack(dict_data_bunch[key], axis = 0)\n\n return dict_data_bunch","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.552109Z","iopub.execute_input":"2022-06-07T07:08:38.552819Z","iopub.status.idle":"2022-06-07T07:08:38.566185Z","shell.execute_reply.started":"2022-06-07T07:08:38.552777Z","shell.execute_reply":"2022-06-07T07:08:38.561295Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 4. Defining the DataModule 📖\n\n* A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data:\n\n* A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) and predict_dataloader(s) along with the matching transforms and data processing/downloads steps required.\n\n\n","metadata":{}},{"cell_type":"code","source":"import pytorch_lightning as pl\n\nclass DataModule(pl.LightningDataModule):\n\n def __init__(self, train_dataset, val_dataset, batch_size = 4):\n\n super(DataModule, self).__init__()\n self.train_dataset = train_dataset\n self.val_dataset = val_dataset\n self.batch_size = batch_size\n\n def train_dataloader(self):\n return DataLoader(self.train_dataset, batch_size = self.batch_size, \n collate_fn = collate_fn, shuffle = True)\n \n def val_dataloader(self):\n return DataLoader(self.val_dataset, batch_size = self.batch_size,\n collate_fn = collate_fn, shuffle = False)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:38.567907Z","iopub.execute_input":"2022-06-07T07:08:38.568576Z","iopub.status.idle":"2022-06-07T07:08:39.574703Z","shell.execute_reply.started":"2022-06-07T07:08:38.568533Z","shell.execute_reply":"2022-06-07T07:08:39.573807Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"datamodule = DataModule(train_ds, val_ds)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:39.576481Z","iopub.execute_input":"2022-06-07T07:08:39.577345Z","iopub.status.idle":"2022-06-07T07:08:39.582506Z","shell.execute_reply.started":"2022-06-07T07:08:39.577298Z","shell.execute_reply":"2022-06-07T07:08:39.581255Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 5. Modeling Part 🏎️\n\n1. Firstly, we would define the pytorch model with our configurations, in which the class labels would be ranging from 0 to 15\n2. 
Secondly, we would encode it in the PyTorch Lightening module, and boom 💥 our work of defining the model is done","metadata":{}},{"cell_type":"code","source":"class DocFormerForClassification(nn.Module):\n \n def __init__(self, config):\n super(DocFormerForClassification, self).__init__()\n\n self.resnet = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])\n self.embeddings = DocFormerEmbeddings(config)\n self.lang_emb = LanguageFeatureExtractor()\n self.config = config\n self.dropout = nn.Dropout(config['hidden_dropout_prob'])\n self.linear_layer = nn.Linear(in_features = config['hidden_size'], out_features = len(id2label)) ## Number of Classes\n self.encoder = DocFormerEncoder(config)\n\n def forward(self, batch_dict):\n\n x_feat = batch_dict['x_features']\n y_feat = batch_dict['y_features']\n\n token = batch_dict['input_ids']\n img = batch_dict['resized_scaled_img']\n\n v_bar_s, t_bar_s = self.embeddings(x_feat,y_feat)\n v_bar = self.resnet(img)\n t_bar = self.lang_emb(token)\n out = self.encoder(t_bar,v_bar,t_bar_s,v_bar_s)\n out = self.linear_layer(out)\n out = out[:, 0, :]\n return out","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:08:39.58445Z","iopub.execute_input":"2022-06-07T07:08:39.584879Z","iopub.status.idle":"2022-06-07T07:08:39.597737Z","shell.execute_reply.started":"2022-06-07T07:08:39.584836Z","shell.execute_reply":"2022-06-07T07:08:39.596891Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## Defining pytorch lightning model\nfrom sklearn.metrics import accuracy_score, confusion_matrix\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport numpy as np\nimport torchmetrics\n\nclass DocFormer(pl.LightningModule):\n\n def __init__(self, config , lr = 5e-5):\n super(DocFormer, self).__init__()\n \n self.save_hyperparameters()\n self.config = config\n self.docformer = DocFormerForClassification(config)\n \n self.num_classes = len(id2label)\n self.train_accuracy_metric = torchmetrics.Accuracy()\n self.val_accuracy_metric = torchmetrics.Accuracy()\n self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)\n self.precision_macro_metric = torchmetrics.Precision(\n average=\"macro\", num_classes=self.num_classes\n )\n self.recall_macro_metric = torchmetrics.Recall(\n average=\"macro\", num_classes=self.num_classes\n )\n self.precision_micro_metric = torchmetrics.Precision(average=\"micro\")\n self.recall_micro_metric = torchmetrics.Recall(average=\"micro\")\n\n def forward(self, batch_dict):\n logits = self.docformer(batch_dict)\n return logits\n\n def training_step(self, batch, batch_idx):\n logits = self.forward(batch)\n\n loss = nn.CrossEntropyLoss()(logits, batch['label'])\n preds = torch.argmax(logits, 1)\n\n ## Calculating the accuracy score\n train_acc = self.train_accuracy_metric(preds, batch[\"label\"])\n\n ## Logging\n self.log('train/loss', loss,prog_bar = True, on_epoch=True, logger=True, on_step=True)\n self.log('train/acc', train_acc, prog_bar = True, on_epoch=True, logger=True, on_step=True)\n\n return loss\n \n def validation_step(self, batch, batch_idx):\n logits = self.forward(batch)\n loss = nn.CrossEntropyLoss()(logits, batch['label'])\n preds = torch.argmax(logits, 1)\n \n labels = batch['label']\n # Metrics\n valid_acc = self.val_accuracy_metric(preds, labels)\n precision_macro = self.precision_macro_metric(preds, labels)\n recall_macro = self.recall_macro_metric(preds, labels)\n precision_micro = self.precision_micro_metric(preds, labels)\n 
recall_micro = self.recall_micro_metric(preds, labels)\n f1 = self.f1_metric(preds, labels)\n\n # Logging metrics\n self.log(\"valid/loss\", loss, prog_bar=True, on_step=True, logger=True)\n self.log(\"valid/acc\", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/precision_macro\", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/recall_macro\", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/precision_micro\", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/recall_micro\", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)\n self.log(\"valid/f1\", f1, prog_bar=True, on_epoch=True)\n \n return {\"label\": batch['label'], \"logits\": logits}\n\n def validation_epoch_end(self, outputs):\n labels = torch.cat([x[\"label\"] for x in outputs])\n logits = torch.cat([x[\"logits\"] for x in outputs])\n preds = torch.argmax(logits, 1)\n\n wandb.log({\"cm\": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})\n self.logger.experiment.log(\n {\"roc\": wandb.plot.roc_curve(labels.cpu().numpy(), logits.cpu().numpy())}\n )\n \n def configure_optimizers(self):\n return torch.optim.AdamW(self.parameters(), lr = self.hparams['lr'])","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:09:12.786194Z","iopub.execute_input":"2022-06-07T07:09:12.786554Z","iopub.status.idle":"2022-06-07T07:09:12.869343Z","shell.execute_reply.started":"2022-06-07T07:09:12.786523Z","shell.execute_reply":"2022-06-07T07:09:12.86839Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 6. Summing it up and running the entire procedure 🏃","metadata":{}},{"cell_type":"code","source":"from pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.callbacks.early_stopping import EarlyStopping\nfrom pytorch_lightning.loggers import WandbLogger\n\ndef main():\n datamodule = DataModule(train_ds, val_ds)\n docformer = DocFormer(config)\n\n checkpoint_callback = ModelCheckpoint(\n dirpath=\"./models\", monitor=\"valid/loss\", mode=\"min\"\n )\n early_stopping_callback = EarlyStopping(\n monitor=\"valid/loss\", patience=3, verbose=True, mode=\"min\"\n )\n \n wandb.init(config=config, project=\"RVL CDIP with DocFormer New Version\")\n wandb_logger = WandbLogger(project=\"RVL CDIP with DocFormer New Version\", entity=\"iakarshu\")\n ## https://www.tutorialexample.com/implement-reproducibility-in-pytorch-lightning-pytorch-lightning-tutorial/\n pl.seed_everything(seed, workers=True)\n trainer = pl.Trainer(\n default_root_dir=\"logs\",\n gpus=(1 if torch.cuda.is_available() else 0),\n max_epochs=1,\n fast_dev_run=False,\n logger=wandb_logger,\n callbacks=[checkpoint_callback, early_stopping_callback],\n deterministic=True\n )\n trainer.fit(docformer, datamodule)","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:09:21.086225Z","iopub.execute_input":"2022-06-07T07:09:21.086667Z","iopub.status.idle":"2022-06-07T07:09:21.097492Z","shell.execute_reply.started":"2022-06-07T07:09:21.086636Z","shell.execute_reply":"2022-06-07T07:09:21.096609Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"if __name__ == \"__main__\":\n 
main()","metadata":{"execution":{"iopub.status.busy":"2022-06-07T07:09:21.297147Z","iopub.execute_input":"2022-06-07T07:09:21.299276Z","iopub.status.idle":"2022-06-07T07:13:58.468599Z","shell.execute_reply.started":"2022-06-07T07:09:21.299227Z","shell.execute_reply":"2022-06-07T07:13:58.467592Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## References:\n\n1. [MLOps Repo](https://github.com/graviraja/MLOps-Basics) (For the integration of model and data with PyTorch Lightening) \n2. [PyTorch Lightening Docs](https://pytorch-lightning.readthedocs.io/en/stable/index.html) For all the doubts and bugs\n3. [My Repo](https://github.com/uakarsh/docformer) For downloading the model and pre-processing steps\n4. Unspash for Images\n5. Google for other stuffs","metadata":{}},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}
--------------------------------------------------------------------------------
/examples/docformer_pl/pre_training_task_on_idl_dataset/docformer_pre_training_1_preparing_idl_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """DocFormer Pre-training : 1. Preparing IDL Dataset
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/17G4Io-2BOLx5YwKymgIwO_g14HovwbBb
8 | """
9 |
10 | ## Refer here for the dataset: https://github.com/furkanbiten/idl_data
11 | # (IDL dataset was also used in the pre-training of LaTr), might take time to download the dataset
12 |
13 |
14 | ## The below lines of code download the sample dataset
15 | # !wget http://datasets.cvc.uab.es/UCSF_IDL/Samples/ocr_imgs_sample.zip
16 | # !unzip /content/ocr_imgs_sample.zip
17 | # !rm /content/ocr_imgs_sample.zip
18 |
19 | # Commented out IPython magic to ensure Python compatibility.
20 | # ## Installing the dependencies (might take some time)
21 | #
22 | # %%capture
23 | # !pip install pytesseract
24 | # !sudo apt install tesseract-ocr
25 | # !pip install transformers
26 | # !pip install pytorch-lightning
27 | # !pip install einops
28 | # !pip install tqdm
29 | # !pip install 'Pillow==7.1.2'
30 | # !pip install PyPDF2
31 |
32 | ## Getting the JSON File
33 |
34 | import json
35 |
36 | ## For reading the PDFs
37 | from PyPDF2 import PdfReader
38 | import io
39 | from PIL import Image, ImageDraw
40 |
41 | ## Standard library
42 | import os
43 |
44 | pdf_path = "./sample/pdfs"
45 | ocr_path = "./sample/OCR"
46 |
47 | ## Image property
48 |
49 | resize_scale = (500, 500)
50 |
51 | from typing import List
52 |
53 | def normalize_box(box : List[int], width : int, height : int, size : tuple = resize_scale):
54 | """
55 | Takes a bounding box and rescales it to the target `size` grid (default: resize_scale).
56 | It works just like calculating a percentage, except against `size` instead of 100.
57 | """
58 | return [
59 | int(size[0] * (box[0] / width)),
60 | int(size[1] * (box[1] / height)),
61 | int(size[0] * (box[2] / width)),
62 | int(size[1] * (box[3] / height)),
63 | ]
64 |
65 | ## Function to get the images from the PDFs as well as the OCRs for the corresponding images
66 |
67 | def get_image_ocrs_from_path(pdf_file_path : str, ocr_file_path : str, resize_scale = resize_scale):
68 |
69 | ## Getting the image list, since the pdfs can contain many image
70 | reader = PdfReader(pdf_file_path)
71 | img_list = []
72 | for i in range(len(reader.pages)):
73 | page = reader.pages[i]
74 | for image_file_object in page.images:
75 |
76 | stream = io.BytesIO(image_file_object.data)
77 | img = Image.open(stream).convert("RGB")
78 | img_list.append(img)
79 |
80 | json_entry = json.load(open(ocr_file_path))[1]
81 | json_entry =[x for x in json_entry["Blocks"] if "Text" in x]
82 |
83 | pages = [x["Page"] for x in json_entry]
84 | ocrs = {pg : [] for pg in set(pages)}
85 |
86 | for entry in json_entry:
87 | bbox = entry["Geometry"]["BoundingBox"]
88 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"]
89 | bbox = [x, y, x + w, y + h]
90 | bbox = normalize_box(bbox, width = 1, height = 1, size = resize_scale)
91 | ocrs[entry["Page"]].append({"word" : entry["Text"], "bbox" : bbox})
92 |
93 | return img_list, ocrs
94 |
95 | # sample_pdf_folder = os.path.join(pdf_path, sorted(os.listdir(pdf_path))[0])
96 | # sample_ocr_folder = os.path.join(ocr_path, sorted(os.listdir(ocr_path))[0])
97 |
98 | # sample_pdf = os.path.join(sample_pdf_folder, sample_pdf_folder.split("/")[-1] + ".pdf")
99 | # sample_ocr = os.path.join(sample_ocr_folder, os.listdir(sample_ocr_folder)[0])
100 |
101 | # img_list, ocrs = get_image_ocrs_from_path(sample_pdf, sample_ocr)
102 |
103 | """## Preparing the Pytorch Dataset"""
104 |
105 | from tqdm.auto import tqdm
106 |
107 | img_list = []
108 | ocr_list = []
109 |
110 | pdf_files = sorted(os.listdir(pdf_path))[:30] ## Using only the first 30, since the Colab session crashes otherwise
111 | ocr_files = sorted(os.listdir(ocr_path))[:30]
112 |
113 | for pdf, ocr in tqdm(zip(pdf_files, ocr_files), total = len(pdf_files)):
114 | pdf = os.path.join(pdf_path, pdf, pdf + '.pdf')
115 | ocr = os.path.join(ocr_path, ocr)
116 | ocr = os.path.join(ocr, os.listdir(ocr)[0])
117 | img, ocrs = get_image_ocrs_from_path(pdf, ocr)
118 |
119 | for i in range(len(img)):
120 | img_list.append(img[i])
121 | ocr_list.append(ocrs[i+1]) ## Pages are 1, 2, 3 hence 0 + 1, 1 + 1, 2 + 1
122 |
123 | """## Visualizing the OCRs"""
124 |
125 | index = 43
126 | curr_img = img_list[index].resize(resize_scale)
127 | curr_ocr = ocr_list[index]
128 |
129 | # create rectangle image
130 | draw_on_img = ImageDraw.Draw(curr_img)
131 |
132 | for it in curr_ocr:
133 | box = it["bbox"]
134 | draw_on_img.rectangle(box, outline ="violet")
135 |
136 | curr_img
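
# Note: the bare `curr_img` expression above only renders inside a notebook/Colab cell.
# When running this file as a plain script, display or save the preview explicitly, e.g.:
#   curr_img.show()
#   curr_img.save("ocr_preview.png")  # the file name is just an example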
137 |
138 |
--------------------------------------------------------------------------------
/examples/docformer_pl/pre_training_task_on_idl_dataset/docformer_pre_training_2_preparing_idl_pytorch_dataset_for_docformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """DocFormer Pre-training : 2. Preparing IDL PyTorch Dataset for DocFormer
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/1dDMABz_EunCdg3S1neJ5fMPkRSmTcAfd
8 | """
9 |
10 | ## Refer here for the dataset: https://github.com/furkanbiten/idl_data
11 | # (IDL dataset was also used in the pre-training of LaTr), might take time to download the dataset
12 |
13 | # !wget http://datasets.cvc.uab.es/UCSF_IDL/Samples/ocr_imgs_sample.zip
14 | # !unzip /content/ocr_imgs_sample.zip
15 | # !rm /content/ocr_imgs_sample.zip
16 |
17 | # Commented out IPython magic to ensure Python compatibility.
18 | # ## Installing the dependencies (might take some time)
19 | #
20 | # %%capture
21 | # !pip install pytesseract
22 | # !sudo apt install tesseract-ocr
23 | # !pip install transformers
24 | # !pip install pytorch-lightning
25 | # !pip install einops
26 | # !pip install tqdm
27 | # !pip install 'Pillow==7.1.2'
28 | # !pip install PyPDF2
29 |
30 | ## Cloning the repository
31 | # !git clone https://github.com/uakarsh/docformer.git
32 |
33 | ## Getting the JSON File
34 | import json
35 |
36 | ## For reading the PDFs
37 | from PyPDF2 import PdfReader
38 | import io
39 | from PIL import Image, ImageDraw
40 |
41 | ## A bit of code taken from here : https://www.kaggle.com/code/akarshu121/docformer-for-token-classification-on-funsd/notebook
42 | import os
43 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
44 |
45 | ## PyTorch Libraries
46 | import torch
47 | from torchvision.transforms import ToTensor
48 | from torch.utils.data import Dataset, DataLoader
49 | import torch.nn.functional as F
50 | import torch.nn as nn
51 |
52 | ## Adding the path of docformer to system path
53 | import sys
54 | sys.path.append('./docformer/src/docformer/')
55 |
56 | ## Importing the functions from the DocFormer Repo
57 | from dataset import resize_align_bbox, get_centroid, get_pad_token_id_start_index, get_relative_distance
58 | from modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor
59 |
60 | ## Transformers library
61 | from transformers import BertTokenizerFast
62 |
63 | ## PyTorch Lightning library
64 | import pytorch_lightning as pl
65 | from pytorch_lightning.callbacks import ModelCheckpoint
66 |
67 | pdf_path = "./sample/pdfs"
68 | ocr_path = "./sample/OCR"
69 |
70 | ## Image property
71 |
72 | resize_scale = (500, 500)
73 |
74 | from typing import List
75 |
76 | def normalize_box(box : List[int], width : int, height : int, size : tuple = resize_scale):
77 | """
78 | Takes a bounding box and rescales it to the target `size` grid (default: resize_scale).
79 | It works just like calculating a percentage, except against `size` instead of 100.
80 | """
81 | return [
82 | int(size[0] * (box[0] / width)),
83 | int(size[1] * (box[1] / height)),
84 | int(size[0] * (box[2] / width)),
85 | int(size[1] * (box[3] / height)),
86 | ]
87 |
88 | ## Function to get the images from the PDFs as well as the OCRs for the corresponding images
89 |
90 | def get_image_ocrs_from_path(pdf_file_path : str, ocr_file_path : str, resize_scale = resize_scale):
91 |
92 | ## Getting the image list, since the pdfs can contain many image
93 | reader = PdfReader(pdf_file_path)
94 | img_list = []
95 | for i in range(len(reader.pages)):
96 | page = reader.pages[i]
97 | for image_file_object in page.images:
98 |
99 | stream = io.BytesIO(image_file_object.data)
100 | img = Image.open(stream).convert("RGB").resize(resize_scale)
101 | img_list.append(img)
102 |
103 | json_entry = json.load(open(ocr_file_path))[1]
104 | json_entry =[x for x in json_entry["Blocks"] if "Text" in x]
105 |
106 | pages = [x["Page"] for x in json_entry]
107 | ocrs = {pg : [] for pg in set(pages)}
108 |
109 | for entry in json_entry:
110 | bbox = entry["Geometry"]["BoundingBox"]
111 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"]
112 | bbox = [x, y, x + w, y + h]
113 | bbox = normalize_box(bbox, width = 1, height = 1, size = resize_scale)
114 | ocrs[entry["Page"]].append({"word" : entry["Text"], "bbox" : bbox})
115 |
116 | return img_list, ocrs
117 |
118 | # sample_pdf_folder = os.path.join(pdf_path, sorted(os.listdir(pdf_path))[0])
119 | # sample_ocr_folder = os.path.join(ocr_path, sorted(os.listdir(ocr_path))[0])
120 |
121 | # sample_pdf = os.path.join(sample_pdf_folder, sample_pdf_folder.split("/")[-1] + ".pdf")
122 | # sample_ocr = os.path.join(sample_ocr_folder, os.listdir(sample_ocr_folder)[0])
123 |
124 | # img_list, ocrs = get_image_ocrs_from_path(sample_pdf, sample_ocr)
125 |
126 | """## Preparing the Pytorch Dataset"""
127 |
128 | from tqdm.auto import tqdm
129 |
130 | img_list = []
131 | ocr_list = []
132 |
133 | pdf_files = sorted(os.listdir(pdf_path))[:30] ## Using only the first 30, since the Colab session crashes otherwise
134 | ocr_files = sorted(os.listdir(ocr_path))[:30]
135 |
136 | for pdf, ocr in tqdm(zip(pdf_files, ocr_files), total = len(pdf_files)):
137 | pdf = os.path.join(pdf_path, pdf, pdf + '.pdf')
138 | ocr = os.path.join(ocr_path, ocr)
139 | ocr = os.path.join(ocr, os.listdir(ocr)[0])
140 | img, ocrs = get_image_ocrs_from_path(pdf, ocr)
141 |
142 | for i in range(len(img)):
143 | img_list.append(img[i])
144 | ocr_list.append(ocrs[i+1]) ## Pages are 1, 2, 3 hence 0 + 1, 1 + 1, 2 + 1
145 |
146 | """## Visualizing the OCRs"""
147 |
148 | index = 17
149 | curr_img = img_list[index]
150 | curr_ocr = ocr_list[index]
151 |
152 | # create rectangle image
153 | draw_on_img = ImageDraw.Draw(curr_img)
154 |
155 | for it in curr_ocr:
156 | box = it["bbox"]
157 | draw_on_img.rectangle(box, outline ="violet")
158 |
159 | curr_img
160 |
161 | """## Creating features for DocFormer"""
162 |
163 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
164 |
165 | def get_tokens_with_boxes(unnormalized_word_boxes, word_ids,max_seq_len = 512, pad_token_box = [0, 0, 0, 0]):
166 |
167 | # Note: we deliberately do not assert len(unnormalized_word_boxes) == len(word_ids), since word_ids
168 | # may be longer and the padding positions have no corresponding bounding box
169 |
170 | unnormalized_token_boxes = []
171 |
172 | i = 0
173 | for word_idx in word_ids:
174 | if word_idx is None:
175 | break
176 | unnormalized_token_boxes.append(unnormalized_word_boxes[word_idx])
177 | i+=1
178 |
179 | # all remaining are padding tokens so why add them in a loop one by one
180 | num_pad_tokens = len(word_ids) - i - 1
181 | if num_pad_tokens > 0:
182 | unnormalized_token_boxes.extend([pad_token_box] * num_pad_tokens)
183 |
184 |
185 | if len(unnormalized_token_boxes)<max_seq_len:
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
19 | 'einops>=0.3',
20 | 'torch>=1.6',
21 | 'torchvision',
22 | 'pytesseract',
23 | 'transformers',
24 | 'pytesseract>=0.3.8',
25 | ],
26 | classifiers=[
27 | 'Development Status :: 4 - Beta',
28 | 'Intended Audience :: Developers',
29 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
30 | 'License :: OSI Approved :: MIT License',
31 | 'Programming Language :: Python :: 3.7',
32 | ],
33 | )
34 |
--------------------------------------------------------------------------------
/src/docformer/README.md:
--------------------------------------------------------------------------------
1 | # About the files:
2 |
3 |
4 | ```python
5 | 1. dataset.py
6 | ```
7 | * This file contains the functions required for pre-processing an image: running OCR to get the bounding box of each word, computing the relative distances between word coordinates, performing tokenization, plus optional arguments for saving the extracted features to disk, and more.
8 |
9 |
10 | * The code can be adapted to the task at hand; for example, image segmentation or document classification can be supported with small modifications to the ```create_features``` function (a usage sketch follows below).
11 |
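A minimal sketch of calling `create_features` directly (assuming `src/docformer` is on `sys.path`; the image path, words, and boxes below are placeholders, and the keyword names follow the function signature in dataset.py):

```python
from transformers import BertTokenizerFast
from dataset import create_features

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Default path: run pytesseract OCR on the image internally
encoding = create_features("path/to/page.png", tokenizer, add_batch_dim=True)

# If the words and bounding boxes are already available (e.g. from a labelled dataset),
# skip the OCR step and pass them in directly
encoding = create_features(
    "path/to/page.png",
    tokenizer,
    add_batch_dim=True,
    use_ocr=False,
    words=["Invoice", "No.", "1234"],
    bounding_box=[[10, 10, 90, 30], [95, 10, 130, 30], [135, 10, 200, 30]],
)
```
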
12 | ```python
13 | 2. dataset_pytorch.py
14 | ```
15 | * This file reuses the functions from ```dataset.py```, but wraps the features stored on disk in a PyTorch Dataset object; it can also be modified to add augmentations.
16 | * The Dataset object is needed for training the model in PyTorch; in TensorFlow, the NumPy features would be used instead of the Dataset and DataLoader objects.
17 |
18 | ```python
19 | 3. modeling.py
20 | ```
21 | * This file is the core of the repo. Its functions are written to stay as close to the paper as possible, with as few approximations as possible: it contains the ```multi-head attention```, the various embedding functions, and the other components described in the paper. To understand it properly, a good approach is to read the code and the paper side by side.
22 | * For task-specific use, one can import ```DocFormerEncoder``` and attach a task-specific head (see the sketch below). The last function, ```ShallowDecoder```, is a work in progress and is meant for the Image Reconstruction pre-training task mentioned in the paper.
23 |
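For instance, a minimal classification sketch in the spirit of the RVL-CDIP example notebook (the class name and `num_labels` are illustrative; `src/docformer` is assumed to be on `sys.path`, and `config` is the dict shown in the top-level README):

```python
import torch.nn as nn
from modeling import (DocFormerEncoder, DocFormerEmbeddings,
                      ResNetFeatureExtractor, LanguageFeatureExtractor)

class DocFormerWithHead(nn.Module):
    def __init__(self, config, num_labels):
        super().__init__()
        self.spatial_emb = DocFormerEmbeddings(config)            # shared spatial embeddings
        self.visual_emb = ResNetFeatureExtractor(hidden_dim=config["max_position_embeddings"])
        self.lang_emb = LanguageFeatureExtractor()                 # text embeddings
        self.encoder = DocFormerEncoder(config)
        self.head = nn.Linear(config["hidden_size"], num_labels)   # task-specific head

    def forward(self, batch):
        v_bar_s, t_bar_s = self.spatial_emb(batch["x_features"], batch["y_features"])
        v_bar = self.visual_emb(batch["resized_scaled_img"])
        t_bar = self.lang_emb(batch["input_ids"])
        out = self.encoder(t_bar, v_bar, t_bar_s, v_bar_s)         # (batch, seq_len, hidden)
        return self.head(out[:, 0, :])                             # classify from the first token
```
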
24 | ```python
25 | 4. modeling_pl.py
26 | ```
27 | * This file handles the parallelization of training and validation, making it easy to use multiple GPUs.
28 | * It integrates PyTorch Lightning with the DocFormer model, reducing boilerplate so that only the task-specific parts need to be written.
29 | * For task-specific requirements, modify the ```Model``` class, in particular the ```training``` and ```validation``` steps, where the task-specific loss functions are plugged in.
30 |
31 |
32 | ```python
33 | 5. train_accelerator.py
34 | and
35 | 6. train_accelerator_mlm_ir.py
36 | ```
37 | * These files are also training scripts, meant to reduce the amount of code that has to be written.
38 | * The difference is that they use the Hugging Face ```Accelerator``` to parallelize training across multiple GPUs.
39 | * ```train_accelerator.py``` contains the code for pre-training the model on the ```MLM``` task.
40 | * ```train_accelerator_mlm_ir.py``` contains the code for pre-training the model on the ```MLM and Image Reconstruction (IR)``` tasks; we are considering a single file that exposes task-specific training options.
41 |
42 |
43 |
44 | ```python
45 | 7. utils.py
46 | ```
47 |
48 | * This file contains the utility function for the unsupervised Text Describes Image (TDI) task mentioned in the DocFormer paper.
49 |
50 | How to use it?
51 | * Let us assume that all the entries of the dataset have been stored somewhere.
52 | * Pass the number of entries to the function ```labels_for_tdi```, which returns an index array as well as the labels. Then iterate over the indices and, for the i-th entry, add a new key to its dictionary (the DocFormer data format; see the ```create_features``` function in dataset.py) that points to the ```resized_scaled_img``` of the entry at the returned index.
53 |
54 | * i.e., in pseudocode, assume that
55 |
56 | ```python
57 |
58 | # l : a list of DocFormer data points (dicts, as produced by create_features in dataset.py)
59 | d_arr, labels = labels_for_tdi(n)
60 | for i, j in enumerate(d_arr):
61 | l[i]['d_resized_scaled_img'] = l[j]['resized_scaled_img']
62 | l[i]['label_for_tdi'] = labels[i]
63 | ```
64 |
65 | The rest follows by passing the argument `use_tdi` to the model.
66 |
67 | Using it with the model:
68 | To integrate TDI with the model, the following steps should help.
69 |
70 | Let us assume we want to pre-train with MLM + IR + TDI:
71 |
72 | 1. For the MLM and Image Reconstruction tasks, attach the corresponding heads, forward-propagate the batch, compute the weighted sum of the two losses, and store it.
73 | 2. For the TDI task, forward-propagate the same batch again with the argument use_tdi = True, compute the binary cross-entropy loss against the `label_for_tdi` key, add its weighted contribution to the stored loss, and then backpropagate. A sketch of this combined loss is given below.
74 |
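A sketch of the combined loss, assuming a model whose forward pass returns the MLM logits and the reconstructed image when `use_tdi=False`, and a single TDI logit when `use_tdi=True` (the loss weights and this calling convention are assumptions, not the exact API of the training scripts):

```python
import torch
import torch.nn as nn

def combined_pretraining_loss(model, batch, w_mlm=1.0, w_ir=1.0, w_tdi=1.0):
    # pass 1: MLM head + image-reconstruction head on the original (text, image) pair
    mlm_logits, reconstructed_img = model(batch, use_tdi=False)
    mlm_loss = nn.CrossEntropyLoss(ignore_index=0)(            # 0 == pad_token_id
        mlm_logits.view(-1, mlm_logits.size(-1)), batch["mlm_labels"].view(-1)
    )
    ir_loss = nn.SmoothL1Loss()(reconstructed_img, batch["resized_scaled_img"])
    loss = w_mlm * mlm_loss + w_ir * ir_loss

    # pass 2: forward the same batch again with use_tdi=True, so the model sees the
    # (possibly swapped) image stored under 'd_resized_scaled_img', and score whether
    # the text actually describes that image
    tdi_logit = model(batch, use_tdi=True)
    tdi_loss = nn.BCEWithLogitsLoss()(
        tdi_logit.view(-1), batch["label_for_tdi"].float().view(-1)
    )
    return loss + w_tdi * tdi_loss
```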
--------------------------------------------------------------------------------
/src/docformer/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import pickle
4 | from functools import lru_cache
5 | import pytesseract
6 | import numpy as np
7 | from PIL import Image
8 | import torch
9 | from torchvision.transforms import ToTensor
10 |
11 | PAD_TOKEN_BOX = [0, 0, 0, 0]
12 | GRID_SIZE = 1000
13 |
14 |
15 | def normalize_box(box, width, height, size=1000):
16 | """
17 | Takes a bounding box and normalizes it to a 1000 x 1000 grid. It works just like
18 | calculating a percentage, except against 1000 (the default `size`) instead of 100.
19 | """
20 | return [
21 | int(size * (box[0] / width)),
22 | int(size * (box[1] / height)),
23 | int(size * (box[2] / width)),
24 | int(size * (box[3] / height)),
25 | ]
26 |
27 |
28 | @lru_cache(maxsize=10)
29 | def resize_align_bbox(bbox, orig_w, orig_h, target_w, target_h):
30 | x_scale = target_w / orig_w
31 | y_scale = target_h / orig_h
32 | orig_left, orig_top, orig_right, orig_bottom = bbox
33 | x = int(np.round(orig_left * x_scale))
34 | y = int(np.round(orig_top * y_scale))
35 | xmax = int(np.round(orig_right * x_scale))
36 | ymax = int(np.round(orig_bottom * y_scale))
37 | return [x, y, xmax, ymax]
38 |
39 |
40 | def get_topleft_bottomright_coordinates(df_row):
41 | left, top, width, height = df_row["left"], df_row["top"], df_row["width"], df_row["height"]
42 | return [left, top, left + width, top + height]
43 |
44 |
45 | def apply_ocr(image_fp):
46 | """
47 | Returns words and its bounding boxes from an image
48 | """
49 | image = Image.open(image_fp)
50 | width, height = image.size
51 |
52 | ocr_df = pytesseract.image_to_data(image, output_type="data.frame")
53 | ocr_df = ocr_df.dropna().reset_index(drop=True)
54 | float_cols = ocr_df.select_dtypes("float").columns
55 | ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
56 | ocr_df = ocr_df.replace(r"^\s*$", np.nan, regex=True)
57 | ocr_df = ocr_df.dropna().reset_index(drop=True)
58 | words = list(ocr_df.text.apply(lambda x: str(x).strip()))
59 | actual_bboxes = ocr_df.apply(get_topleft_bottomright_coordinates, axis=1).values.tolist()
60 |
61 | # add as extra columns
62 | assert len(words) == len(actual_bboxes)
63 | return {"words": words, "bbox": actual_bboxes}
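
# Example of apply_ocr output (illustrative values only):
#   apply_ocr("page.png") -> {"words": ["Invoice", "No.", "1234"],
#                             "bbox": [[10, 12, 92, 30], [97, 12, 128, 30], [133, 12, 180, 30]]}
# where each box is [left, top, right, bottom] in pixel coordinates of the original image.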
64 |
65 | def get_tokens_with_boxes(unnormalized_word_boxes, pad_token_box, word_ids,max_seq_len = 512):
66 |
67 | # Note: we deliberately do not assert len(unnormalized_word_boxes) == len(word_ids), since word_ids
68 | # may be longer and the padding positions have no corresponding bounding box
69 |
70 | unnormalized_token_boxes = []
71 |
72 | for i, word_idx in enumerate(word_ids):
73 | if word_idx is None:
74 | break
75 | unnormalized_token_boxes.append(unnormalized_word_boxes[word_idx])
76 |
77 | # all remaining are padding tokens so why add them in a loop one by one
78 | num_pad_tokens = len(word_ids) - i - 1
79 | if num_pad_tokens > 0:
80 | unnormalized_token_boxes.extend([pad_token_box] * num_pad_tokens)
81 |
82 |
83 | if len(unnormalized_token_boxes)<max_seq_len:
112 | if i >= pad_tokens_start_idx:
113 | a_rel_x.append([0] * 8)
114 | a_rel_y.append([0] * 8)
115 | continue
116 |
117 | curr = bboxes[i]
118 | next = bboxes[i+1]
119 |
120 | a_rel_x.append(
121 | [
122 | curr[0], # top left x
123 | curr[2], # bottom right x
124 | curr[2] - curr[0], # width
125 | next[0] - curr[0], # diff top left x
126 | next[0] - curr[0], # diff bottom left x
127 | next[2] - curr[2], # diff top right x
128 | next[2] - curr[2], # diff bottom right x
129 | centroids[i+1][0] - centroids[i][0],
130 | ]
131 | )
132 |
133 | a_rel_y.append(
134 | [
135 | curr[1], # top left y
136 | curr[3], # bottom right y
137 | curr[3] - curr[1], # height
138 | next[1] - curr[1], # diff top left y
139 | next[3] - curr[3], # diff bottom left y
140 | next[1] - curr[1], # diff top right y
141 | next[3] - curr[3], # diff bottom right y
142 | centroids[i+1][1] - centroids[i][1],
143 | ]
144 | )
145 |
146 | # For the last word
147 |
148 | a_rel_x.append([0]*8)
149 | a_rel_y.append([0]*8)
150 |
151 |
152 | return a_rel_x, a_rel_y
153 |
154 |
155 |
156 | def apply_mask(inputs, tokenizer):
157 | inputs = torch.as_tensor(inputs)
158 | rand = torch.rand(inputs.shape)
159 | # where the random array is less than 0.15, we set true
160 | mask_arr = (rand < 0.15) * (inputs != tokenizer.cls_token_id) * (inputs != tokenizer.pad_token_id)
161 | # create selection from mask_arr
162 | selection = torch.flatten(mask_arr.nonzero()).tolist()
163 | # apply the selection to the input ids, replacing them with the [MASK] token id
164 | inputs[selection] = 103  # 103 == tokenizer.mask_token_id for bert-base-uncased
165 | return inputs
166 |
167 |
168 | def read_image_and_extract_text(image):
169 | original_image = Image.open(image).convert("RGB")
170 | return apply_ocr(image)
171 |
172 |
173 | def create_features(
174 | image,
175 | tokenizer,
176 | add_batch_dim=False,
177 | target_size=(500,384), # This was the resolution used by the authors
178 | max_seq_length=512,
179 | path_to_save=None,
180 | save_to_disk=False,
181 | apply_mask_for_mlm=False,
182 | extras_for_debugging=False,
183 | use_ocr = True,
184 | bounding_box = None,
185 | words = None
186 | ):
187 |
188 | # step 1: read original image and extract OCR entries
189 | try:
190 | original_image = Image.open(image).convert("RGB")
191 | except Exception:  # fall back to a blank white page if the image cannot be opened
192 | original_image = Image.new(mode = "RGB", size = ((500, 500)), color = (255, 255, 255))
193 |
194 | if (not use_ocr) and (bounding_box is None or words is None):
195 | raise Exception('Please provide the bounding box and words or pass the argument "use_ocr" = True')
196 |
197 | if use_ocr == True:
198 | entries = apply_ocr(image)
199 | bounding_box = entries["bbox"]
200 | words = entries["words"]
201 |
202 | CLS_TOKEN_BOX = [0, 0, *original_image.size]  # per the paper, the [CLS] token's box covers the whole image
203 | # step 2: resize image
204 | resized_image = original_image.resize(target_size)
205 |
206 | # step 3: normalize image to a grid of 1000 x 1000 (to avoid the problem of differently sized images)
207 | width, height = original_image.size
208 | normalized_word_boxes = [
209 | normalize_box(bbox, width, height, GRID_SIZE) for bbox in bounding_box
210 | ]
211 | assert len(words) == len(normalized_word_boxes), "Length of words != Length of normalized words"
212 |
213 | # step 4: tokenize words and get their bounding boxes (one word may split into multiple tokens)
214 | encoding = tokenizer(words,
215 | padding="max_length",
216 | max_length=max_seq_length,
217 | is_split_into_words=True,
218 | truncation=True,
219 | add_special_tokens=False)
220 |
221 | unnormalized_token_boxes = get_tokens_with_boxes(bounding_box,
222 | PAD_TOKEN_BOX,
223 | encoding.word_ids())
224 |
225 | # step 5: add special tokens and truncate seq. to maximum length
226 | unnormalized_token_boxes = [CLS_TOKEN_BOX] + unnormalized_token_boxes[:-1]
227 | # add CLS token manually to avoid autom. addition of SEP too (as in the paper)
228 | encoding["input_ids"] = [tokenizer.cls_token_id] + encoding["input_ids"][:-1]
229 |
230 | # step 6: Add bounding boxes to the encoding dict
231 | encoding["unnormalized_token_boxes"] = unnormalized_token_boxes
232 |
233 | # step 7: apply mask for the sake of pre-training
234 | if apply_mask_for_mlm:
235 | encoding["mlm_labels"] = encoding["input_ids"]
236 | encoding["input_ids"] = apply_mask(encoding["input_ids"], tokenizer)
237 | assert len(encoding["mlm_labels"]) == max_seq_length, "Length of mlm_labels != Length of max_seq_length"
238 |
239 | assert len(encoding["input_ids"]) == max_seq_length, "Length of input_ids != Length of max_seq_length"
240 | assert len(encoding["attention_mask"]) == max_seq_length, "Length of attention mask != Length of max_seq_length"
241 | assert len(encoding["token_type_ids"]) == max_seq_length, "Length of token type ids != Length of max_seq_length"
242 |
243 | # step 8: normalize the image
244 | encoding["resized_scaled_img"] = ToTensor()(resized_image)
245 |
246 | # step 9: (the MLM mask was already applied in step 7; applying it a second time here
247 | # would overwrite "mlm_labels" with the already-masked ids, so nothing more is done)
248 |
251 | # step 10: rescale and align the bounding boxes to match the resized image (target_size)
252 | resized_and_aligned_bboxes = []
253 |
254 | for bbox in unnormalized_token_boxes:
255 | # performing the normalization of the bounding box
256 | resized_and_aligned_bboxes.append(resize_align_bbox(tuple(bbox), *original_image.size, *target_size))
257 |
258 | encoding["resized_and_aligned_bounding_boxes"] = resized_and_aligned_bboxes
259 |
260 | # step 11: add the relative distances in the normalized grid
261 | bboxes_centroids = get_centroid(resized_and_aligned_bboxes)
262 | pad_token_start_index = get_pad_token_id_start_index(words, encoding, tokenizer)
263 | a_rel_x, a_rel_y = get_relative_distance(resized_and_aligned_bboxes, bboxes_centroids, pad_token_start_index)
264 |
265 | # step 12: convert all to tensors
266 | for k, v in encoding.items():
267 | encoding[k] = torch.as_tensor(encoding[k])
268 |
269 | encoding.update({
270 | "x_features": torch.as_tensor(a_rel_x, dtype=torch.int32),
271 | "y_features": torch.as_tensor(a_rel_y, dtype=torch.int32),
272 | })
273 |
274 | # step 13: add tokens for debugging
275 | if extras_for_debugging:
276 | input_ids = encoding["mlm_labels"] if apply_mask_for_mlm else encoding["input_ids"]
277 | encoding["tokens_without_padding"] = tokenizer.convert_ids_to_tokens(input_ids)
278 | encoding["words"] = words
279 |
280 |
281 | # step 14: add extra dim for batch
282 | if add_batch_dim:
283 | encoding["x_features"].unsqueeze_(0)
284 | encoding["y_features"].unsqueeze_(0)
285 | encoding["input_ids"].unsqueeze_(0)
286 | encoding["resized_scaled_img"].unsqueeze_(0)
287 |
288 | # step 15: save to disk
289 | if save_to_disk:
290 | os.makedirs(path_to_save, exist_ok=True)
291 | image_name = os.path.basename(image)
292 | with open(os.path.join(path_to_save, f"{image_name}.pickle"), "wb") as f:
293 | pickle.dump(encoding, f)
294 |
295 | # step 16: keys to keep; "resized_and_aligned_bounding_boxes" is included only so the caller
296 | # can check that the boxes are drawn correctly, and may be removed
297 | keys = ['resized_scaled_img', 'x_features', 'y_features', 'input_ids', 'resized_and_aligned_bounding_boxes']
298 |
299 | if apply_mask_for_mlm:
300 | keys.append('mlm_labels')
301 |
302 | final_encoding = {k:encoding[k] for k in keys}
303 |
304 | del encoding
305 | return final_encoding
306 |
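A hedged usage sketch for the pipeline above; the image path is hypothetical, `bert-base-uncased` is just one tokenizer compatible with the 512-token setup, and the OCR branch needs pytesseract plus a Tesseract install.

```python
from transformers import AutoTokenizer

# Hedged usage sketch of create_features; "sample_doc.png" is a hypothetical path.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoding = create_features(
    image="sample_doc.png",
    tokenizer=tokenizer,
    add_batch_dim=True,        # adds the leading batch dimension expected by the model
    apply_mask_for_mlm=True,   # also produces "mlm_labels" for pre-training
)
for key, value in encoding.items():
    print(key, tuple(value.shape))  # e.g. input_ids -> (1, 512), x_features -> (1, 512, 8)
```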
--------------------------------------------------------------------------------
/src/docformer/dataset_pytorch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision.models as models
9 | from PIL import Image
10 | from sklearn.model_selection import train_test_split as tts
11 | from torch.autograd import Variable
12 | from torch.utils.data import DataLoader, Dataset
13 | from torchvision.transforms import ToTensor
14 | from transformers import AutoModel, AutoTokenizer
15 |
16 | """## Base Dataset"""
17 |
18 | device = "cuda" if torch.cuda.is_available() else "cpu"
19 |
20 | class DocumentDatset(Dataset):
21 | def __init__(self, entries, pathToPickleFile, pretrain=True):
22 | self.pathToPickleFile = pathToPickleFile
23 | self.entries = entries
24 | self.pretrain = pretrain
25 |
26 | def __len__(self):
27 | return len(self.entries)
28 |
29 | def __getitem__(self, index):
30 |
31 | imageName = self.entries[index]
32 | encoding = os.path.join(self.pathToPickleFile, imageName)
33 |
34 | with open(encoding, "rb") as sample:
35 | encoding = pickle.load(sample)
36 |
37 | if self.pretrain:
38 |
39 | # When the dataset is used for pre-training, the label entries are not needed and some of
40 | # them cannot be collated into tensors, which would raise errors during training
41 | del encoding['category_labels'] # string labels cannot be stored in a pytorch tensor
42 | del encoding['numeric_labels'] # could instead be used for token classification (e.g. on the FUNSD dataset)
43 | del encoding['target_bbox'] # used for segmenting the different pieces of text in the image
44 | del encoding['resized_and_aligned_target_bbox'] # the same boxes, resized to match the rescaled image
45 |
46 | for i in list(encoding.keys()):
47 | encoding[i] = encoding[i].to(device)
48 |
49 | # Since the absolute value of the relative distance was taken, no offset needs to be added here, and training can proceed directly
50 | return encoding
51 |
52 |
53 | # pathToPickleFile = 'RVL-CDIP-PickleFiles/'
54 | # entries = os.listdir(pathToPickleFile)
55 | # train_entries,val_entries = tts(entries,test_size = 0.2)
56 | # train_dataset = DocumentDatset(train_entries,pathToPickleFile)
57 | # val_dataset = DocumentDatset(val_entries,pathToPickleFile)
58 |
59 |
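A hedged sketch of wiring the dataset into DataLoaders; the pickle folder is hypothetical and must contain files written by create_features (including the extra label entries deleted above).

```python
# Hedged sketch: wrap the pickle-file dataset in DataLoaders. The folder name is hypothetical.
path_to_pickles = "RVL-CDIP-PickleFiles/"
entries = os.listdir(path_to_pickles)
train_entries, val_entries = tts(entries, test_size=0.2)

train_loader = DataLoader(DocumentDatset(train_entries, path_to_pickles), batch_size=9, shuffle=True)
val_loader = DataLoader(DocumentDatset(val_entries, path_to_pickles), batch_size=9)

batch = next(iter(train_loader))  # a dict of stacked tensors, already moved to `device`
print({k: tuple(v.shape) for k, v in batch.items()})
```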
--------------------------------------------------------------------------------
/src/docformer/modeling.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision.models as models
6 | from einops import rearrange
7 | from torch import Tensor
8 |
9 | class PositionalEncoding(nn.Module):
10 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
11 | super().__init__()
12 | self.dropout = nn.Dropout(p=dropout)
13 | self.max_len = max_len
14 | self.d_model = d_model
15 | position = torch.arange(max_len).unsqueeze(1)
16 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
17 | pe = torch.zeros(1, max_len, d_model)
18 | pe[0, :, 0::2] = torch.sin(position * div_term)
19 | pe[0, :, 1::2] = torch.cos(position * div_term)
20 | self.register_buffer("pe", pe)
21 |
22 |
23 | def forward(self) -> Tensor:
24 | x = self.pe[0, : self.max_len]
25 | return self.dropout(x).unsqueeze(0)
26 |
27 |
28 | class ResNetFeatureExtractor(nn.Module):
29 | def __init__(self, hidden_dim = 512):
30 | super().__init__()
31 |
32 | # ResNet-50 backbone, used in DocFormer for visual feature extraction
33 |
34 | resnet50 = models.resnet50(pretrained=False)
35 | modules = list(resnet50.children())[:-2]
36 | self.resnet50 = nn.Sequential(*modules)
37 |
38 | # Applying convolution and linear layer
39 |
40 | self.conv1 = nn.Conv2d(2048, 768, 1)
41 | self.relu1 = F.relu
42 | self.linear1 = nn.Linear(192, hidden_dim)
43 |
44 | def forward(self, x):
45 | x = self.resnet50(x)
46 | x = self.conv1(x)
47 | x = self.relu1(x)
48 | x = rearrange(x, "b e w h -> b e (w h)") # b -> batch, e -> embedding dim, w -> width, h -> height
49 | x = self.linear1(x)
50 | x = rearrange(x, "b e s -> b s e") # b -> batch, e -> embedding dim, s -> sequence length
51 | return x
52 |
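A hedged shape check, assuming the (500, 384) page size used by create_features in dataset.py (this snippet is not part of the original file).

```python
import torch

# Hedged shape check: ToTensor on a 500x384 page gives (3, 384, 500); ResNet-50 downsamples
# by 32 to a 12x16 grid (192 visual tokens), and the final linear maps those 192 positions
# to `hidden_dim` sequence slots of 768-dim embeddings.
extractor = ResNetFeatureExtractor(hidden_dim=512)
image_batch = torch.randn(2, 3, 384, 500)
print(extractor(image_batch).shape)  # torch.Size([2, 512, 768])
```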
53 | class DocFormerEmbeddings(nn.Module):
54 | """Construct the embeddings from word, position and token_type embeddings."""
55 |
56 | def __init__(self, config):
57 | super(DocFormerEmbeddings, self).__init__()
58 |
59 | self.config = config
60 |
61 | self.position_embeddings_v = PositionalEncoding(
62 | d_model=config["hidden_size"],
63 | dropout=0.1,
64 | max_len=config["max_position_embeddings"],
65 | )
66 |
67 | self.x_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
68 | self.x_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
69 | self.w_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
70 | self.x_topleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
71 | self.x_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
72 | self.x_topright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
73 | self.x_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
74 | self.x_centroid_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
75 |
76 | self.y_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
77 | self.y_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
78 | self.h_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
79 | self.y_topleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
80 | self.y_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
81 | self.y_topright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
82 | self.y_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
83 | self.y_centroid_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
84 |
85 | self.position_embeddings_t = PositionalEncoding(
86 | d_model=config["hidden_size"],
87 | dropout=0.1,
88 | max_len=config["max_position_embeddings"],
89 | )
90 |
91 | self.x_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
92 | self.x_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
93 | self.w_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
94 | self.x_topleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"]+1, config["shape_size"])
95 | self.x_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"]+1, config["shape_size"])
96 | self.x_topright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
97 | self.x_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
98 | self.x_centroid_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
99 |
100 | self.y_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
101 | self.y_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
102 | self.h_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
103 | self.y_topleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
104 | self.y_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
105 | self.y_topright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
106 | self.y_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
107 | self.y_centroid_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
108 |
109 | self.LayerNorm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
110 | self.dropout = nn.Dropout(config["hidden_dropout_prob"])
111 |
112 |
113 |
114 | def forward(self, x_feature, y_feature):
115 |
116 | """
117 | Arguments:
118 | x_features of shape, (batch size, seq_len, 8)
119 | y_features of shape, (batch size, seq_len, 8)
120 | Outputs:
121 | (V-bar-s, T-bar-s) of shape (batch size, 512,768),(batch size, 512,768)
122 | What are the features:
123 | 0 -> top left x/y
124 | 1 -> bottom right x/y
125 | 2 -> width/height
126 | 3 -> diff top left x/y
127 | 4 -> diff bottom left x/y
128 | 5 -> diff top right x/y
129 | 6 -> diff bottom right x/y
130 | 7 -> centroids diff x/y
131 | """
132 |
133 |
134 | batch, seq_len = x_feature.shape[:-1]
135 | hidden_size = self.config["hidden_size"]
136 | num_feat = x_feature.shape[-1]
137 | sub_dim = hidden_size // num_feat
138 |
139 | # Clamping and adding a bias for handling negative values
140 | x_feature[:,:,3:] = torch.clamp(x_feature[:,:,3:],-self.config["max_2d_position_embeddings"],self.config["max_2d_position_embeddings"])
141 | x_feature[:,:,3:]+= self.config["max_2d_position_embeddings"]
142 |
143 | y_feature[:,:,3:] = torch.clamp(y_feature[:,:,3:],-self.config["max_2d_position_embeddings"],self.config["max_2d_position_embeddings"])
144 | y_feature[:,:,3:]+= self.config["max_2d_position_embeddings"]
145 |
146 | x_topleft_position_embeddings_v = self.x_topleft_position_embeddings_v(x_feature[:,:,0])
147 | x_bottomright_position_embeddings_v = self.x_bottomright_position_embeddings_v(x_feature[:,:,1])
148 | w_position_embeddings_v = self.w_position_embeddings_v(x_feature[:,:,2])
149 | x_topleft_distance_to_prev_embeddings_v = self.x_topleft_distance_to_prev_embeddings_v(x_feature[:,:,3])
150 | x_bottomleft_distance_to_prev_embeddings_v = self.x_bottomleft_distance_to_prev_embeddings_v(x_feature[:,:,4])
151 | x_topright_distance_to_prev_embeddings_v = self.x_topright_distance_to_prev_embeddings_v(x_feature[:,:,5])
152 | x_bottomright_distance_to_prev_embeddings_v = self.x_bottomright_distance_to_prev_embeddings_v(x_feature[:,:,6])
153 | x_centroid_distance_to_prev_embeddings_v = self.x_centroid_distance_to_prev_embeddings_v(x_feature[:,:,7])
154 |
155 | x_calculated_embedding_v = torch.cat(
156 | [
157 | x_topleft_position_embeddings_v,
158 | x_bottomright_position_embeddings_v,
159 | w_position_embeddings_v,
160 | x_topleft_distance_to_prev_embeddings_v,
161 | x_bottomleft_distance_to_prev_embeddings_v,
162 | x_topright_distance_to_prev_embeddings_v,
163 | x_bottomright_distance_to_prev_embeddings_v ,
164 | x_centroid_distance_to_prev_embeddings_v
165 | ],
166 | dim = -1
167 | )
168 |
169 | y_topleft_position_embeddings_v = self.y_topleft_position_embeddings_v(y_feature[:,:,0])
170 | y_bottomright_position_embeddings_v = self.y_bottomright_position_embeddings_v(y_feature[:,:,1])
171 | h_position_embeddings_v = self.h_position_embeddings_v(y_feature[:,:,2])
172 | y_topleft_distance_to_prev_embeddings_v = self.y_topleft_distance_to_prev_embeddings_v(y_feature[:,:,3])
173 | y_bottomleft_distance_to_prev_embeddings_v = self.y_bottomleft_distance_to_prev_embeddings_v(y_feature[:,:,4])
174 | y_topright_distance_to_prev_embeddings_v = self.y_topright_distance_to_prev_embeddings_v(y_feature[:,:,5])
175 | y_bottomright_distance_to_prev_embeddings_v = self.y_bottomright_distance_to_prev_embeddings_v(y_feature[:,:,6])
176 | y_centroid_distance_to_prev_embeddings_v = self.y_centroid_distance_to_prev_embeddings_v(y_feature[:,:,7])
177 |
178 | # x_calculated_embedding_v was already computed above and is reused here as-is
191 |
192 | y_calculated_embedding_v = torch.cat(
193 | [
194 | y_topleft_position_embeddings_v,
195 | y_bottomright_position_embeddings_v,
196 | h_position_embeddings_v,
197 | y_topleft_distance_to_prev_embeddings_v,
198 | y_bottomleft_distance_to_prev_embeddings_v,
199 | y_topright_distance_to_prev_embeddings_v,
200 | y_bottomright_distance_to_prev_embeddings_v ,
201 | y_centroid_distance_to_prev_embeddings_v
202 | ],
203 | dim = -1
204 | )
205 |
206 | v_bar_s = x_calculated_embedding_v + y_calculated_embedding_v + self.position_embeddings_v()
207 |
208 |
209 |
210 | x_topleft_position_embeddings_t = self.x_topleft_position_embeddings_t(x_feature[:,:,0])
211 | x_bottomright_position_embeddings_t = self.x_bottomright_position_embeddings_t(x_feature[:,:,1])
212 | w_position_embeddings_t = self.w_position_embeddings_t(x_feature[:,:,2])
213 | x_topleft_distance_to_prev_embeddings_t = self.x_topleft_distance_to_prev_embeddings_t(x_feature[:,:,3])
214 | x_bottomleft_distance_to_prev_embeddings_t = self.x_bottomleft_distance_to_prev_embeddings_t(x_feature[:,:,4])
215 | x_topright_distance_to_prev_embeddings_t = self.x_topright_distance_to_prev_embeddings_t(x_feature[:,:,5])
216 | x_bottomright_distance_to_prev_embeddings_t = self.x_bottomright_distance_to_prev_embeddings_t(x_feature[:,:,6])
217 | x_centroid_distance_to_prev_embeddings_t = self.x_centroid_distance_to_prev_embeddings_t(x_feature[:,:,7])
218 |
219 | x_calculated_embedding_t = torch.cat(
220 | [
221 | x_topleft_position_embeddings_t,
222 | x_bottomright_position_embeddings_t,
223 | w_position_embeddings_t,
224 | x_topleft_distance_to_prev_embeddings_t,
225 | x_bottomleft_distance_to_prev_embeddings_t,
226 | x_topright_distance_to_prev_embeddings_t,
227 | x_bottomright_distance_to_prev_embeddings_t ,
228 | x_centroid_distance_to_prev_embeddings_t
229 | ],
230 | dim = -1
231 | )
232 |
233 | y_topleft_position_embeddings_t = self.y_topleft_position_embeddings_t(y_feature[:,:,0])
234 | y_bottomright_position_embeddings_t = self.y_bottomright_position_embeddings_t(y_feature[:,:,1])
235 | h_position_embeddings_t = self.h_position_embeddings_t(y_feature[:,:,2])
236 | y_topleft_distance_to_prev_embeddings_t = self.y_topleft_distance_to_prev_embeddings_t(y_feature[:,:,3])
237 | y_bottomleft_distance_to_prev_embeddings_t = self.y_bottomleft_distance_to_prev_embeddings_t(y_feature[:,:,4])
238 | y_topright_distance_to_prev_embeddings_t = self.y_topright_distance_to_prev_embeddings_t(y_feature[:,:,5])
239 | y_bottomright_distance_to_prev_embeddings_t = self.y_bottomright_distance_to_prev_embeddings_t(y_feature[:,:,6])
240 | y_centroid_distance_to_prev_embeddings_t = self.y_centroid_distance_to_prev_embeddings_t(y_feature[:,:,7])
241 |
242 | x_calculated_embedding_t = torch.cat(
243 | [
244 | x_topleft_position_embeddings_t,
245 | x_bottomright_position_embeddings_t,
246 | w_position_embeddings_t,
247 | x_topleft_distance_to_prev_embeddings_t,
248 | x_bottomleft_distance_to_prev_embeddings_t,
249 | x_topright_distance_to_prev_embeddings_t,
250 | x_bottomright_distance_to_prev_embeddings_t ,
251 | x_centroid_distance_to_prev_embeddings_t
252 | ],
253 | dim = -1
254 | )
255 |
256 | y_calculated_embedding_t = torch.cat(
257 | [
258 | y_topleft_position_embeddings_t,
259 | y_bottomright_position_embeddings_t,
260 | h_position_embeddings_t,
261 | y_topleft_distance_to_prev_embeddings_t,
262 | y_bottomleft_distance_to_prev_embeddings_t,
263 | y_topright_distance_to_prev_embeddings_t,
264 | y_bottomright_distance_to_prev_embeddings_t ,
265 | y_centroid_distance_to_prev_embeddings_t
266 | ],
267 | dim = -1
268 | )
269 |
270 | t_bar_s = x_calculated_embedding_t + y_calculated_embedding_t + self.position_embeddings_t()
271 |
272 | return v_bar_s, t_bar_s
273 |
274 |
275 |
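A hedged shape sketch for the spatial branch, reusing the relevant keys of the config dict defined in modeling_pl.py; all-zero features are used only to confirm the output shapes.

```python
import torch

# Hedged sketch: all-zero spatial features just to confirm the output shapes.
config = {
    "coordinate_size": 96, "shape_size": 96, "hidden_size": 768,
    "max_2d_position_embeddings": 1024, "max_position_embeddings": 512,
    "hidden_dropout_prob": 0.1, "layer_norm_eps": 1e-12,
}
spatial = DocFormerEmbeddings(config)
x_feat = torch.zeros(2, 512, 8, dtype=torch.long)   # (batch, seq_len, 8 x-features)
y_feat = torch.zeros(2, 512, 8, dtype=torch.long)   # (batch, seq_len, 8 y-features)
v_bar_s, t_bar_s = spatial(x_feat, y_feat)
print(v_bar_s.shape, t_bar_s.shape)  # torch.Size([2, 512, 768]) torch.Size([2, 512, 768])
```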
276 | # fmt: off
277 | class PreNorm(nn.Module):
278 | def __init__(self, dim, fn):
279 | # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
280 | super().__init__()
281 | self.norm = nn.LayerNorm(dim)
282 | self.fn = fn
283 |
284 | def forward(self, x, **kwargs):
285 | return self.fn(self.norm(x), **kwargs)
286 |
287 |
288 | class PreNormAttn(nn.Module):
289 | def __init__(self, dim, fn):
290 | # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
291 | super().__init__()
292 |
293 | self.norm_t_bar = nn.LayerNorm(dim)
294 | self.norm_v_bar = nn.LayerNorm(dim)
295 | self.norm_t_bar_s = nn.LayerNorm(dim)
296 | self.norm_v_bar_s = nn.LayerNorm(dim)
297 | self.fn = fn
298 |
299 | def forward(self, t_bar, v_bar, t_bar_s, v_bar_s, **kwargs):
300 | return self.fn(self.norm_t_bar(t_bar),
301 | self.norm_v_bar(v_bar),
302 | self.norm_t_bar_s(t_bar_s),
303 | self.norm_v_bar_s(v_bar_s), **kwargs)
304 |
305 |
306 | class FeedForward(nn.Module):
307 | def __init__(self, dim, hidden_dim, dropout=0.):
308 | super().__init__()
309 | self.net = nn.Sequential(
310 | nn.Linear(dim, hidden_dim),
311 | nn.GELU(),
312 | nn.Dropout(dropout),
313 | nn.Linear(hidden_dim, dim),
314 | nn.Dropout(dropout)
315 | )
316 |
317 | def forward(self, x):
318 | return self.net(x)
319 |
320 |
321 | class RelativePosition(nn.Module):
322 |
323 | def __init__(self, num_units, max_relative_position, max_seq_length):
324 | super().__init__()
325 | self.num_units = num_units
326 | self.max_relative_position = max_relative_position
327 | self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
328 | self.max_length = max_seq_length
329 | range_vec_q = torch.arange(max_seq_length)
330 | range_vec_k = torch.arange(max_seq_length)
331 | distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
332 | distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
333 | final_mat = distance_mat_clipped + self.max_relative_position
334 | self.final_mat = torch.LongTensor(final_mat)
335 | nn.init.xavier_uniform_(self.embeddings_table)
336 |
337 | def forward(self, length_q, length_k):
338 | embeddings = self.embeddings_table[self.final_mat[:length_q, :length_k]]
339 | return embeddings
340 |
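A hedged toy example of the clipped relative-position table; the number of units, window size and lengths below are arbitrary.

```python
import torch

# Hedged toy example: pairwise relative-position embeddings for a short window.
rel = RelativePosition(num_units=64, max_relative_position=8, max_seq_length=512)
emb = rel(4, 4)       # embeddings for every (query, key) offset in a length-4 window
print(emb.shape)      # torch.Size([4, 4, 64])
```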
341 |
342 | class MultiModalAttentionLayer(nn.Module):
343 | def __init__(self, embed_dim, n_heads, max_relative_position, max_seq_length, dropout):
344 | super().__init__()
345 | assert embed_dim % n_heads == 0
346 |
347 | self.embed_dim = embed_dim
348 | self.n_heads = n_heads
349 | self.head_dim = embed_dim // n_heads
350 |
351 | self.relative_positions_text = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
352 | self.relative_positions_img = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
353 |
354 | # text qkv embeddings
355 | self.fc_k_text = nn.Linear(embed_dim, embed_dim)
356 | self.fc_q_text = nn.Linear(embed_dim, embed_dim)
357 | self.fc_v_text = nn.Linear(embed_dim, embed_dim)
358 |
359 | # image qkv embeddings
360 | self.fc_k_img = nn.Linear(embed_dim, embed_dim)
361 | self.fc_q_img = nn.Linear(embed_dim, embed_dim)
362 | self.fc_v_img = nn.Linear(embed_dim, embed_dim)
363 |
364 | # spatial qk embeddings (shared for visual and text)
365 | self.fc_k_spatial = nn.Linear(embed_dim, embed_dim)
366 | self.fc_q_spatial = nn.Linear(embed_dim, embed_dim)
367 |
368 | self.dropout = nn.Dropout(dropout)
369 |
370 | self.to_out = nn.Sequential(
371 | nn.Linear(embed_dim, embed_dim),
372 | nn.Dropout(dropout)
373 | )
374 | self.scale = embed_dim**0.5
375 |
376 | def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
381 | seq_length = text_feat.shape[1]
382 |
383 | # self attention of text
384 | # b -> batch, t -> time steps (l -> length has same meaning), head -> # of heads, k -> head dim.
385 | key_text_nh = rearrange(self.fc_k_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
386 | query_text_nh = rearrange(self.fc_q_text(text_feat), 'b l (head k) -> head b l k', head=self.n_heads)
387 | value_text_nh = rearrange(self.fc_v_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
388 | dots_text = torch.einsum('hblk,hbtk->hblt', query_text_nh, key_text_nh)
389 | dots_text = dots_text/ self.scale
390 |
391 | # 1D relative positions (query, key)
392 | rel_pos_embed_text = self.relative_positions_text(seq_length, seq_length)
393 | rel_pos_key_text = torch.einsum('bhrd,lrd->bhlr', key_text_nh, rel_pos_embed_text)
394 | rel_pos_query_text = torch.einsum('bhld,lrd->bhlr', query_text_nh, rel_pos_embed_text)
395 |
396 | # shared spatial <-> text hidden features
397 | key_spatial_text = self.fc_k_spatial(text_spatial_feat)
398 | query_spatial_text = self.fc_q_spatial(text_spatial_feat)
399 | key_spatial_text_nh = rearrange(key_spatial_text, 'b t (head k) -> head b t k', head=self.n_heads)
400 | query_spatial_text_nh = rearrange(query_spatial_text, 'b l (head k) -> head b l k', head=self.n_heads)
401 | dots_text_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_text_nh, key_spatial_text_nh)
402 | dots_text_spatial = dots_text_spatial/ self.scale
403 |
404 | # Line 38 of pseudo-code
405 | text_attn_scores = dots_text + rel_pos_key_text + rel_pos_query_text + dots_text_spatial
406 |
407 | # self-attention of image
408 | key_img_nh = rearrange(self.fc_k_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
409 | query_img_nh = rearrange(self.fc_q_img(img_feat), 'b l (head k) -> head b l k', head=self.n_heads)
410 | value_img_nh = rearrange(self.fc_v_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
411 | dots_img = torch.einsum('hblk,hbtk->hblt', query_img_nh, key_img_nh)
412 | dots_img = dots_img/ self.scale
413 |
414 | # 1D relative positions (query, key)
415 | rel_pos_embed_img = self.relative_positions_img(seq_length, seq_length)
416 | rel_pos_key_img = torch.einsum('bhrd,lrd->bhlr', key_img_nh, rel_pos_embed_img)
417 | rel_pos_query_img = torch.einsum('bhld,lrd->bhlr', query_img_nh, rel_pos_embed_img)
418 |
419 | # shared spatial <-> image features
420 | key_spatial_img = self.fc_k_spatial(img_spatial_feat)
421 | query_spatial_img = self.fc_q_spatial(img_spatial_feat)
422 | key_spatial_img_nh = rearrange(key_spatial_img, 'b t (head k) -> head b t k', head=self.n_heads)
423 | query_spatial_img_nh = rearrange(query_spatial_img, 'b l (head k) -> head b l k', head=self.n_heads)
424 | dots_img_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_img_nh, key_spatial_img_nh)
425 | dots_img_spatial = dots_img_spatial/ self.scale
426 |
427 | # Line 59 of pseudo-code
428 | img_attn_scores = dots_img + rel_pos_key_img + rel_pos_query_img + dots_img_spatial
429 |
430 | text_attn_probs = self.dropout(torch.softmax(text_attn_scores, dim=-1))
431 | img_attn_probs = self.dropout(torch.softmax(img_attn_scores, dim=-1))
432 |
433 | text_context = torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_text_nh)
434 | img_context = torch.einsum('hblt,hbtv->hblv', img_attn_probs, value_img_nh)
435 |
436 | context = text_context + img_context
437 |
438 | embeddings = rearrange(context, 'head b t d -> b t (head d)')
439 | return self.to_out(embeddings)
440 |
441 | class DocFormerEncoder(nn.Module):
442 | def __init__(self, config):
443 | super().__init__()
444 | self.config = config
445 | self.layers = nn.ModuleList([])
446 | for _ in range(config['num_hidden_layers']):
447 | encoder_block = nn.ModuleList([
448 | PreNormAttn(config['hidden_size'],
449 | MultiModalAttentionLayer(config['hidden_size'],
450 | config['num_attention_heads'],
451 | config['max_relative_positions'],
452 | config['max_position_embeddings'],
453 | config['hidden_dropout_prob'],
454 | )
455 | ),
456 | PreNorm(config['hidden_size'],
457 | FeedForward(config['hidden_size'],
458 | config['hidden_size'] * config['intermediate_ff_size_factor'],
459 | dropout=config['hidden_dropout_prob']))
460 | ])
461 | self.layers.append(encoder_block)
462 |
463 | def forward(
464 | self,
465 | text_feat, # text feat or output from last encoder block
466 | img_feat,
467 | text_spatial_feat,
468 | img_spatial_feat,
469 | ):
470 | # Fig 1 encoder part (skip conn for both attn & FF): https://arxiv.org/abs/1706.03762
471 | # TODO: ensure 1st skip conn (var "skip") in such a multimodal setting makes sense (most likely does)
472 | for attn, ff in self.layers:
473 | skip = text_feat + img_feat + text_spatial_feat + img_spatial_feat
474 | x = attn(text_feat, img_feat, text_spatial_feat, img_spatial_feat) + skip
475 | x = ff(x) + x
476 | text_feat = x
477 | return x
478 |
479 |
480 | class LanguageFeatureExtractor(nn.Module):
481 | def __init__(self):
482 | super().__init__()
483 | from transformers import LayoutLMForTokenClassification
484 | layoutlm_dummy = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=1)
485 | self.embedding_vector = nn.Embedding.from_pretrained(layoutlm_dummy.layoutlm.embeddings.word_embeddings.weight)
486 |
487 | def forward(self, x):
488 | return self.embedding_vector(x)
489 |
490 |
491 |
492 | class ExtractFeatures(nn.Module):
493 |
494 | '''
495 | Inputs: dictionary
496 | Output: v_bar, t_bar, v_bar_s, t_bar_s
497 | '''
498 |
499 | def __init__(self, config):
500 | super().__init__()
501 | self.visual_feature = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])
502 | self.language_feature = LanguageFeatureExtractor()
503 | self.spatial_feature = DocFormerEmbeddings(config)
504 |
505 | def forward(self, encoding):
506 |
507 | image = encoding['resized_scaled_img']
508 |
509 | language = encoding['input_ids']
510 | x_feature = encoding['x_features']
511 | y_feature = encoding['y_features']
512 |
513 | v_bar = self.visual_feature(image)
514 | t_bar = self.language_feature(language)
515 |
516 | v_bar_s, t_bar_s = self.spatial_feature(x_feature, y_feature)
517 |
518 | return v_bar, t_bar, v_bar_s, t_bar_s
519 |
520 |
521 |
522 | class DocFormer(nn.Module):
523 |
524 | '''
525 | Thin wrapper: the model simply consumes the dictionary produced by the create_features function
526 | '''
527 | def __init__(self, config):
528 | super().__init__()
529 | self.config = config
530 | self.extract_feature = ExtractFeatures(config)
531 | self.encoder = DocFormerEncoder(config)
532 | self.dropout = nn.Dropout(config['hidden_dropout_prob'])
533 |
534 | def forward(self, x):
535 | v_bar, t_bar, v_bar_s, t_bar_s = self.extract_feature(x)
536 | features = {'v_bar': v_bar, 't_bar': t_bar, 'v_bar_s': v_bar_s, 't_bar_s': t_bar_s}
537 | output = self.encoder(features['t_bar'], features['v_bar'], features['t_bar_s'], features['v_bar_s'])
538 | output = self.dropout(output)
539 | return output
540 |
541 |
542 |
543 |
544 |
545 |
546 |
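Putting the pieces together, a hedged end-to-end sketch on random inputs; it reuses the config dict from modeling_pl.py, and LanguageFeatureExtractor downloads the layoutlm-base-uncased word embeddings, so network access is assumed. None of this is part of the original file.

```python
import torch

# Hedged end-to-end sketch on random data (shapes follow create_features in dataset.py).
config = {
    "coordinate_size": 96, "shape_size": 96, "hidden_size": 768,
    "max_2d_position_embeddings": 1024, "max_position_embeddings": 512,
    "max_relative_positions": 8, "num_attention_heads": 12, "num_hidden_layers": 12,
    "intermediate_ff_size_factor": 3, "hidden_dropout_prob": 0.1,
    "layer_norm_eps": 1e-12, "vocab_size": 30522, "pad_token_id": 0,
}
model = DocFormer(config)
encoding = {
    "resized_scaled_img": torch.randn(1, 3, 384, 500),              # resized page
    "input_ids": torch.randint(0, config["vocab_size"], (1, 512)),  # token ids
    "x_features": torch.zeros(1, 512, 8, dtype=torch.long),         # spatial x-features
    "y_features": torch.zeros(1, 512, 8, dtype=torch.long),         # spatial y-features
}
with torch.no_grad():
    out = model(encoding)
print(out.shape)  # torch.Size([1, 512, 768]) -> one multimodal embedding per token position
```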
--------------------------------------------------------------------------------
/src/docformer/modeling_pl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision.models as models
10 | from PIL import Image
11 | from sklearn.model_selection import train_test_split as tts
12 | from torch.autograd import Variable
13 | from torch.utils.data import DataLoader, Dataset
14 | from torchvision.transforms import ToTensor
15 | from transformers import AutoModel, AutoTokenizer
16 | from modeling import *
17 |
18 | """## Base Dataset"""
19 |
20 | device = "cuda" if torch.cuda.is_available() else "cpu"
21 |
22 | """## Base Model"""
23 |
24 | config = {
25 | "coordinate_size": 96,
26 | "hidden_dropout_prob": 0.1,
27 | "hidden_size": 768,
28 | "image_feature_pool_shape": [7, 7, 256],
29 | "intermediate_ff_size_factor": 3, # default ought to be 4
30 | "max_2d_position_embeddings": 1024,
31 | "max_position_embeddings": 512,
32 | "max_relative_positions": 8,
33 | "num_attention_heads": 12,
34 | "num_hidden_layers": 12,
35 | "pad_token_id": 0,
36 | "shape_size": 96,
37 | "vocab_size": 30522,
38 | "layer_norm_eps": 1e-12,
39 | "batch_size":9
40 | }
41 |
42 |
43 | class Model(pl.LightningModule):
44 |
45 | def __init__(self,config,num_classes,lr = 5e-5):
46 |
47 | super().__init__()
48 | self.save_hyperparameters()
49 | self.docformer = DocFormerForClassification(config,num_classes)
50 |
51 | def forward(self,x):
52 | return self.docformer(x)
53 |
54 | def training_step(self,batch,batch_idx):
55 |
56 | # For pre-training there can be several targets, each with its own loss. For example, when MLM + IR
57 | # are trained together the model returns a dictionary, two criteria are defined (CrossEntropy for MLM,
58 | # L1 for image reconstruction), and their weighted sum is the total loss; only the final head of the DocFormer encoder needs to change.
59 |
60 |
61 | # Currently, we are performing only the MLM Part
62 | logits = self.forward(batch)
63 | criterion = torch.nn.CrossEntropyLoss()
64 | loss = criterion(logits.transpose(2,1), batch["mlm_labels"].long())
65 | self.log("train_loss",loss,prog_bar = True)
66 |
67 | def validation_step(self, batch, batch_idx):
68 |
69 | logits = self.forward(batch)
70 | b,size,classes = logits.shape
71 | criterion = torch.nn.CrossEntropyLoss()
72 | loss = criterion(logits.transpose(2,1), batch["mlm_labels"].long())
73 | val_acc = 100*(torch.argmax(logits,dim = -1)==batch["mlm_labels"].long()).float().sum()/(logits.shape[0]*logits.shape[1])
74 | val_acc = torch.tensor(val_acc)
75 | self.log("val_loss", loss, prog_bar=True)
76 | self.log("val_acc", val_acc, prog_bar=True)
77 |
78 | def configure_optimizers(self):
79 | return torch.optim.AdamW(self.parameters(), lr=self.hparams["lr"])
80 |
81 | """## Examples"""
82 |
83 | # pathToPickleFile = 'RVL-CDIP-PickleFiles/'
84 | # entries = os.listdir(pathToPickleFile)
85 | # data = DataModule(train_entries,val_entries,pathToPickleFile)
86 | # model = Model(config,num_classes= 30522).to(device)
87 | # trainer = pl.Trainer(gpus=(1 if torch.cuda.is_available() else 0),
88 | # max_epochs=10,
89 | # fast_dev_run=False,
90 | # logger=pl.loggers.TensorBoardLogger("logs/", name="rvl-cdip", version=1),)
91 | # trainer.fit(model,data)
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
/src/docformer/train_accelerator.py:
--------------------------------------------------------------------------------
1 | ## Dependencies
2 |
3 | from accelerate import Accelerator
4 | import accelerate
5 | import pytesseract
6 | import torchmetrics
7 | import math
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset, DataLoader
11 | import pandas as pd
12 | from PIL import Image
13 | import json
14 | import numpy as np
15 | from tqdm.auto import tqdm
16 | from torchvision.transforms import ToTensor
17 | import torch.nn.functional as F
18 | import torch.nn as nn
19 | import torchvision.models as models
20 | from einops import rearrange
21 | from einops import rearrange as rearr
22 | from sklearn.model_selection import train_test_split as tts
23 | from torch.autograd import Variable
24 | from torch.utils.data import DataLoader, Dataset
25 | from torchvision.transforms import ToTensor
26 | from modeling import DocFormer
27 |
28 | batch_size = 9
29 |
30 | class AverageMeter(object):
31 | """Computes and stores the average and current value"""
32 | def __init__(self):
33 | self.reset()
34 |
35 | def reset(self):
36 | self.val = 0
37 | self.sum = 0
38 | self.count = 0
39 |
40 | def update(self, val, n=1):
41 | self.val = val
42 | self.sum += val * n
43 | self.count += n
44 |
45 | @property
46 | def avg(self):
47 | return (self.sum / self.count) if self.count>0 else 0
48 |
49 | ## Loggers
50 | class Logger:
51 | def __init__(self, filename, format='csv'):
52 | self.filename = filename + '.' + format
53 | self._log = []
54 | self.format = format
55 |
56 | def save(self, log, epoch=None):
57 | log['epoch'] = epoch + 1
58 | self._log.append(log)
59 | if self.format == 'json':
60 | with open(self.filename, 'w') as f:
61 | json.dump(self._log, f)
62 | else:
63 | pd.DataFrame(self._log).to_csv(self.filename, index=False)
64 |
65 |
66 |
67 | def train_fn(data_loader, model, criterion, optimizer, epoch, device, scheduler=None):
68 | model.train()
69 | accelerator = Accelerator()
70 | model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
71 | loop = tqdm(data_loader, leave=True)
72 | log = None
73 | train_acc = torchmetrics.Accuracy()
75 |
76 | for batch in loop:
77 |
78 | input_ids = batch["input_ids"].to(device)
79 | attention_mask = batch["attention_mask"].to(device)
80 | labels = batch["mlm_labels"].to(device)
81 |
82 | # process
83 | outputs = model(batch)
84 | ce_loss = criterion(outputs.transpose(1,2), labels)
85 |
86 | if log is None:
87 | log = {}
88 | log["ce_loss"] = AverageMeter()
89 | log['accuracy'] = AverageMeter()
90 |
91 | optimizer.zero_grad()
92 | accelerator.backward(ce_loss)
93 | optimizer.step()
94 |
95 | if scheduler is not None:
96 | scheduler.step()
97 |
98 | log['accuracy'].update(train_acc(labels.cpu(),torch.argmax(outputs,-1).cpu()).item(),batch_size)
99 | log['ce_loss'].update(ce_loss.item())
100 | loop.set_postfix({k: v.avg for k, v in log.items()})
101 |
102 | return log
103 |
104 |
105 | # Function for the validation data loader
106 | def eval_fn(data_loader, model, criterion, device):
107 | model.eval()
108 | log = None
109 | val_acc = torchmetrics.Accuracy()
110 |
111 |
112 | with torch.no_grad():
113 | loop = tqdm(data_loader, total=len(data_loader), leave=True)
114 | for batch in loop:
115 |
116 | input_ids = batch["input_ids"].to(device)
117 | attention_mask = batch["attention_mask"].to(device)
118 | labels = batch["mlm_labels"].to(device)
119 | output = model(batch)
120 | ce_loss = criterion(output.transpose(1,2), labels)
121 |
122 | if log is None:
123 | log = {}
124 | log["ce_loss"] = AverageMeter()
125 | log['accuracy'] = AverageMeter()
126 |
127 | log['accuracy'].update(val_acc(labels.cpu(),torch.argmax(output,-1).cpu()).item(),batch_size)
128 | log['ce_loss'].update(ce_loss.item())
129 | loop.set_postfix({k: v.avg for k, v in log.items()})
130 | return log # ['total_loss']
131 |
132 | date = '20Oct'
133 |
134 |
135 | def run(config,train_dataloader,val_dataloader,device,epochs,path,classes,lr = 5e-5):
136 | logger = Logger(f"{path}/logs")
137 | model = DocFormerForClassification(config,classes).to(device)
138 | criterion = nn.CrossEntropyLoss()
139 | criterion = criterion.to(device)
140 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
141 | best_val_loss = 1e9
142 | header_printed = False
143 | batch_size = config['batch_size']
144 | for epoch in range(epochs):
145 | print("Training the model.....")
146 | train_log = train_fn(
147 | train_dataloader, model, criterion, optimizer, epoch, device, scheduler=None
148 | )
149 |
150 | print("Validating the model.....")
151 | valid_log = eval_fn(val_dataloader, model, criterion, device)
152 | log = {k: v.avg for k, v in train_log.items()}
153 | log.update({"V/" + k: v.avg for k, v in valid_log.items()})
154 | logger.save(log, epoch)
155 | keys = sorted(log.keys())
156 | if not header_printed:
157 | print(" ".join(map(lambda k: f"{k[:8]:8}", keys)))
158 | header_printed = True
159 | print(" ".join(map(lambda k: f"{log[k]:8.3f}"[:8], keys)))
160 | if log["V/ce_loss"] < best_val_loss:
161 | best_val_loss = log["V/ce_loss"]
162 | print("Best model found at epoch {}".format(epoch + 1))
163 | torch.save(model.state_dict(), f"{path}/docformer_best_{epoch}_{date}.pth")
164 |
--------------------------------------------------------------------------------
/src/docformer/train_accelerator_mlm_ir.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 |
5 | ## Dependencies
6 |
7 | import pickle
8 | import os
9 | from accelerate import Accelerator
10 | import accelerate
11 | import pytesseract
12 | import torchmetrics
13 | import math
14 | import numpy as np
15 | import torch
16 | from torch.utils.data import Dataset, DataLoader
17 | import pandas as pd
18 | from PIL import Image
19 | import json
20 | import numpy as np
21 | from tqdm.auto import tqdm
22 | from torchvision.transforms import ToTensor
23 | import torch.nn.functional as F
24 | import torch.nn as nn
25 | import torchvision.models as models
26 | from einops import rearrange
27 | from einops import rearrange as rearr
28 | from PIL import Image
29 | from sklearn.model_selection import train_test_split as tts
30 | from torch.autograd import Variable
31 | from torch.utils.data import DataLoader, Dataset
32 | from torchvision.transforms import ToTensor
33 | from modeling import *
34 |
35 | batch_size = 9
36 |
37 |
38 | weights = {'mlm':5,'ir':1,'tdi':5}
39 |
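A hedged, stand-alone illustration of how these weights combine the MLM and IR losses further down in train_fn; the tensors are random stand-ins for model outputs and batch targets, and only the loss arithmetic is the point.

```python
import torch
import torch.nn as nn

# Hedged illustration of the weighted MLM + IR objective (random tensors stand in for real data).
criterion_mlm = nn.CrossEntropyLoss()
criterion_ir = torch.nn.L1Loss()
mlm_logits = torch.randn(2, 30522, 512)            # (batch, vocab, seq), i.e. after transpose(1, 2)
mlm_target = torch.randint(0, 30522, (2, 512))     # masked-token labels
ir_pred = torch.rand(2, 3, 384, 500)               # reconstructed page
ir_target = torch.rand(2, 3, 384, 500)             # original resized page
total_loss = weights['mlm'] * criterion_mlm(mlm_logits, mlm_target) \
           + weights['ir'] * criterion_ir(ir_pred, ir_target)
print(total_loss.item())
```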
40 | class AverageMeter(object):
41 | """Computes and stores the average and current value"""
42 | def __init__(self):
43 | self.reset()
44 |
45 | def reset(self):
46 | self.val = 0
47 | self.sum = 0
48 | self.count = 0
49 |
50 | def update(self, val, n=1):
51 | self.val = val
52 | self.sum += val * n
53 | self.count += n
54 |
55 | @property
56 | def avg(self):
57 | return (self.sum / self.count) if self.count>0 else 0
58 |
59 | ## Loggers
60 | class Logger:
61 | def __init__(self, filename, format='csv'):
62 | self.filename = filename + '.' + format
63 | self._log = []
64 | self.format = format
65 |
66 | def save(self, log, epoch=None):
67 | log['epoch'] = epoch + 1
68 | self._log.append(log)
69 | if self.format == 'json':
70 | with open(self.filename, 'w') as f:
71 | json.dump(self._log, f)
72 | else:
73 | pd.DataFrame(self._log).to_csv(self.filename, index=False)
74 |
75 |
76 |
77 | def train_fn(data_loader, model, criterion1,criterion2, optimizer, epoch, device, scheduler=None,weights=weights):
78 | model.train()
79 | accelerator = Accelerator()
80 | model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
81 | loop = tqdm(data_loader, leave=True)
82 | log = None
83 | train_acc = torchmetrics.Accuracy()
85 |
86 | for batch in loop:
87 |
88 | input_ids = batch["input_ids"].to(device)
89 | attention_mask = batch["attention_mask"].to(device)
90 | labels1 = batch["mlm_labels"].to(device)
91 | labels2 = batch['resized_image'].to(device)
92 |
93 | # process
94 | outputs = model(batch)
95 | ce_loss = criterion1(outputs['mlm_labels'].transpose(1,2), labels1)
96 | ir_loss = criterion2(outputs['ir'],labels2)
97 |
98 | if log is None:
99 | log = {}
100 | log["ce_loss"] = AverageMeter()
101 | log['accuracy'] = AverageMeter()
102 | log['ir_loss'] = AverageMeter()
103 | log['total_loss'] = AverageMeter()
104 |
105 | total_loss = weights['mlm']*ce_loss + weights['ir']*ir_loss
106 | optimizer.zero_grad()
107 | accelerator.backward(total_loss)
108 | optimizer.step()
109 |
110 | if scheduler is not None:
111 | scheduler.step()
112 |
113 | log['accuracy'].update(train_acc(labels1.cpu(),torch.argmax(outputs['mlm_labels'],-1).cpu()).item(),batch_size)
114 | log['ce_loss'].update(ce_loss.item(),batch_size)
115 | log['ir_loss'].update(ir_loss.item(),batch_size)
116 | log['total_loss'].update(total_loss.item(),batch_size)
117 | loop.set_postfix({k: v.avg for k, v in log.items()})
118 |
119 | return log
120 |
121 |
122 | # Function for the validation data loader
123 | def eval_fn(data_loader, model, criterion1,criterion2, device,weights=weights):
124 | model.eval()
125 | log = None
126 | val_acc = torchmetrics.Accuracy()
127 |
128 |
129 | with torch.no_grad():
130 | loop = tqdm(data_loader, total=len(data_loader), leave=True)
131 | for batch in loop:
132 |
133 | input_ids = batch["input_ids"].to(device)
134 | attention_mask = batch["attention_mask"].to(device)
135 | labels1 = batch["mlm_labels"].to(device)
136 | labels2 = batch['resized_image'].to(device)
137 |
138 |
139 | output = model(batch)
140 | ce_loss = criterion1(output['mlm_labels'].transpose(1,2), labels1)
141 | ir_loss = criterion2(output['ir'],labels2)
142 | total_loss = weights['mlm']*ce_loss + weights['ir']*ir_loss
143 | if log is None:
144 | log = {}
145 | log["ce_loss"] = AverageMeter()
146 | log['accuracy'] = AverageMeter()
147 | log['ir_loss'] = AverageMeter()
148 | log['total_loss'] = AverageMeter()
149 |
150 | log['accuracy'].update(val_acc(labels1.cpu(),torch.argmax(output['mlm_labels'],-1).cpu()).item(),batch_size)
151 | log['ce_loss'].update(ce_loss.item(),batch_size)
152 | log['ir_loss'].update(ir_loss.item(),batch_size)
153 | log['total_loss'].update(total_loss.item(),batch_size)
154 | loop.set_postfix({k: v.avg for k, v in log.items()})
155 |
156 | return log # ['total_loss']
157 |
158 | date = '26Oct'
159 |
160 |
161 | def run(config,train_dataloader,val_dataloader,device,epochs,path,classes,lr = 5e-5,weights=weights):
162 | logger = Logger(f"{path}/logs")
163 | model = DocFormer_For_IR(config,classes).to(device)
164 | criterion1 = nn.CrossEntropyLoss().to(device)
165 | criterion2 = torch.nn.L1Loss().to(device)
166 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
167 | best_val_loss = 1e9
168 | header_printed = False
169 | batch_size = config['batch_size']
170 | for epoch in range(epochs):
171 | print("Training the model.....")
172 | train_log = train_fn(
173 | train_dataloader, model, criterion1,criterion2, optimizer, epoch, device, scheduler=None
174 | )
175 |
176 | print("Validating the model.....")
177 | valid_log = eval_fn(val_dataloader, model, criterion1,criterion2, device)
178 | log = {k: v.avg for k, v in train_log.items()}
179 | log.update({"V/" + k: v.avg for k, v in valid_log.items()})
180 | logger.save(log, epoch)
181 | keys = sorted(log.keys())
182 | if not header_printed:
183 | print(" ".join(map(lambda k: f"{k[:8]:8}", keys)))
184 | header_printed = True
185 | print(" ".join(map(lambda k: f"{log[k]:8.3f}"[:8], keys)))
186 | if log["V/total_loss"] < best_val_loss:
187 | best_val_loss = log["V/total_loss"]
188 | print("Best model found at epoch {}".format(epoch + 1))
189 | torch.save(model.state_dict(), f"{path}/docformer_best_{epoch}_{date}.pth")
190 |
--------------------------------------------------------------------------------
/src/docformer/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Utility File
3 |
4 | Created mainly to define the labels for the unsupervised "Text Describes Image" task from the DocFormer paper
5 | '''
6 |
7 | import numpy as np
8 | import random
9 | import json
10 | import os
11 |
12 | def labels_for_tdi(length,sample_ratio=0.05):
13 |
14 | '''
15 |
16 | Function for providing the labels for the task of
17 | Text Describes Image
18 |
19 | Input:
20 | * Length of the dataset, i.e the total number of data points
21 | * sample_ratio: The percentage of the total length, which you want to shuffle
22 |
23 | Output:
24 | * d_arr (dummy array): The array which contains the indexes of the images
25 | * Labels: Whether the image has been shuffled with some other image or not
26 |
27 |
28 | Example:
29 | d_arr,labels = labels_for_tdi(100)
30 |
31 | Explanation:
32 | Suppose the array is [1,2,3,4,5]. The steps are as follows:
33 |
34 | * Choose some arbitrary values (see the samples_to_be_changed variable), say [2,4]
35 | * Generate a permutation of those values and write it back over them (one permutation is [4,2],
36 | so the array becomes [1,4,3,2,5])
37 | * Then, wherever the original arr and d_arr agree, put a label of 1, else 0, so the labels array becomes [1,0,1,0,1]
38 |
39 |
40 | d_arr is returned because d_arr[i] is the index whose image becomes the d_resized_scaled_img of the ith encoding dictionary,
41 |
42 | i.e. if d_arr[i] == i (not shuffled), the d_resized_scaled_img of the ith entry is its own resized_scaled_img,
43 | else d_resized_scaled_img[i] = resized_scaled_img[d_arr[i]]
44 |
45 | '''
46 | samples_to_be_changed = int(sample_ratio*length)
47 | arr = np.arange(length)
48 | d_arr = arr.copy()
49 | labels = np.ones(length)
50 | sample_id = np.array(random.sample(list(arr), samples_to_be_changed))
51 | new_sample_id = np.random.permutation(sample_id)
52 | d_arr[sample_id]=new_sample_id
53 | labels = (arr==d_arr).astype(int)
54 |
55 | return d_arr,labels
56 |
57 |
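A hedged example of the helper; the printed arrays vary run to run because the shuffled indices are drawn at random.

```python
# Hedged example: shuffle roughly half of 10 indices so the effect is easy to see.
d_arr, labels = labels_for_tdi(10, sample_ratio=0.5)
print(d_arr)   # e.g. [0 1 7 3 4 5 6 2 8 9] -> images at positions 2 and 7 were swapped
print(labels)  # e.g. [1 1 0 1 1 1 1 0 1 1] -> 0 wherever the paired image is not the original
```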
58 | ## Purpose: Read a json file from the given path and return its contents as a dictionary
59 | def load_json_file(file_path):
60 | with open(file_path, 'r') as f:
61 | data = json.load(f)
62 | return data
63 |
64 | ## Purpose: Get the path of the first file in a directory with the given extension, e.g. .pdf or .tif
65 | def get_specific_file(path, last_entry = 'tif'):
66 | base_path = path
67 | for i in os.listdir(path):
68 | if i.endswith(last_entry):
69 | return os.path.join(base_path, i)
70 |
71 | return '-1'
72 |
73 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shabie/docformer/c7fcfdf71fb174784c3dba932b0f0daa6f05a92f/tests/__init__.py
--------------------------------------------------------------------------------