├── .gitignore ├── LICENCE ├── README.md ├── configs ├── base.yaml ├── models │ ├── e2e.yaml │ ├── edge.yaml │ └── gcn.yaml └── train.yaml ├── model.png ├── requirements.txt ├── setup.py ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── dataloader.py │ ├── download.py │ ├── feature_builder.py │ ├── graph_builder.py │ ├── preprocessing.py │ └── utils.py ├── inference.py ├── main.py ├── models │ ├── __init__.py │ ├── graphs.py │ └── unet │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py ├── paths.py ├── training │ ├── __init__.py │ ├── funsd.py │ ├── pau.py │ └── utils.py └── utils.py └── tutorial └── kie.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # system 2 | *.egg-info/ 3 | *__pycache__/ 4 | *.py[cod] 5 | *.eggs 6 | build 7 | .vscode 8 | cc.* 9 | *.DS_Store 10 | 11 | #user-defined 12 | DATA 13 | _old 14 | *.env 15 | outputs 16 | FUDGE 17 | configs/preprocessing.yaml 18 | esempi 19 | examples 20 | src/models/yolov5 21 | src/models/checkpoints 22 | prova* 23 | inference -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | Copyright (c) 2022 4 | Andrea Gemelli 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
9 | 10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 11 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

`Doc2Graph`

2 | 3 | ![model](model.png) 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/doc2graph-a-task-agnostic-document/entity-linking-on-funsd)](https://paperswithcode.com/sota/entity-linking-on-funsd?p=doc2graph-a-task-agnostic-document) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/doc2graph-a-task-agnostic-document/semantic-entity-labeling-on-funsd)](https://paperswithcode.com/sota/semantic-entity-labeling-on-funsd?p=doc2graph-a-task-agnostic-document) 6 | 7 | ![Python](https://img.shields.io/badge/python-3670A0?style=for-the-badge&logo=python&logoColor=ffdd54) ![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white) 8 | 9 | This library is the implementation of the paper [Doc2Graph: a Task Agnostic Document Understanding Framework based on Graph Neural Networks](https://arxiv.org/abs/2208.11168), accepted at [TiE @ ECCV 2022](https://sites.google.com/view/tie-eccv2022/accepted-papers?authuser=0). 10 | 11 | The model and pipeline aims at being task-agnostic on the domain of Document Understanding. It is an ongoing project, these are the steps already achieved and the ones we would like to implement in the future: 12 | 13 | - [x] Build a model based on GNNs to spot key-value relationships on forms 14 | - [x] Publish the preliminary results and the code 15 | - [x] Extend the framework to other document-related tasks 16 | - [x] Business documents Layout Analysis 17 | - [x] Table Detection 18 | - [ ] Let the user train Doc2Graph over private / other datasets using our dataloader 19 | - [ ] Transform Doc2Graph into a PyPI package 20 | 21 | Index: 22 | - [`Doc2Graph`](#doc2graph) 23 | - [News!](#news) 24 | - [Environment Setup](#environment-setup) 25 | - [Training](#training) 26 | - [Testing](#testing) 27 | - [Cite this project](#cite-this-project) 28 | 29 | --- 30 | ## News! 
31 | - 🔥 Added **inference** method: you can now use Doc2Graph directly on your documents by simply passing a path to them!
This call will output an image with the connected entities and a JSON / dictionary with all the useful information you need! 🤗 32 | ``` 33 | python src/main.py -addG -addT -addE -addV --gpu 0 --weights e2e-funsd-best.pt --inference --docs /path/to/document 34 | ``` 35 | 36 | - 🔥 Added **tutorial** folder: get to know how to use Doc2Graph from the tutorial notebooks! 37 | 38 | ## Environment Setup 39 | Set up the initial conda environment: 40 | 41 | ``` 42 | conda create -n doc2graph python=3.9 ipython cudatoolkit=11.3 -c anaconda && 43 | conda activate doc2graph && 44 | cd doc2graph 45 | ``` 46 | 47 | Then, install [setuptools-git-versioning](https://pypi.org/project/setuptools-git-versioning/) and the doc2graph package itself. The following has been tested only on Linux: for other operating systems, refer directly to the original [PyTorch](https://pytorch.org/get-started/previous-versions/) and [DGL](https://www.dgl.ai/pages/start.html) documentation. 48 | 49 | ``` 50 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url \ 51 | https://download.pytorch.org/whl/cu113 && 52 | pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html && 53 | pip install setuptools-git-versioning && pip install -e . && 54 | pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.3.0/en_core_web_lg-3.3.0.tar.gz 55 | ``` 56 | 57 | Finally, create the project folder structure and download the data: 58 | 59 | ``` 60 | python src/main.py --init 61 | ``` 62 | The script will download and set up: 63 | - FUNSD[^1] and its 'adjusted_annotations', which come from the work of Vu et al.[^3]. 64 | - The YOLO detection bounding boxes described in the paper (if you want to use YOLOv5-small to detect entities, see the script in `notebooks/YOLO.ipynb` and refer to [their GitHub](https://github.com/ultralytics/yolov5) for the installation; clone the repository into `src/models/yolov5`). 65 | - Pau Riba's[^2] dataset with our train / test split.
66 | 67 | [^1]: G. Jaume et al., FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents, ICDARW 2019 68 | [^2]: P. Riba et al., Table Detection in Invoice Documents by Graph Neural Networks, ICDAR 2019 69 | [^3]: Hieu M. Vu et al., REVISING FUNSD DATASET FOR KEY-VALUE DETECTION IN DOCUMENT IMAGES, arXiv preprint 2020 70 | 71 | **Checkpoints** 72 | You can download our model checkpoints [here](https://drive.google.com/file/d/15jKWYLTcb8VwE7XY_jcRvZTAmqy_ElJ_/view?usp=sharing). Place them into `src/models/checkpoints`. 73 | 74 | --- 75 | ## Training 76 | 1. To train our **Doc2Graph** model (using CPU), use: 77 | ``` 78 | python src/main.py [SETTINGS] 79 | ``` 80 | 2. Instead, to test a trained **Doc2Graph** model (using GPU) [weights can be one or more files]: 81 | ``` 82 | python src/main.py [SETTINGS] --gpu 0 --test --weights *.pt 83 | ``` 84 | The project can be customized either by directly changing the `configs/base.yaml` file or by providing these flags when calling `src/main.py`. 85 | 86 | **Features** 87 | - --add-geom: bool (to add positional features to graph nodes) 88 | - --add-embs: bool (to add textual features to graph nodes) 89 | - --add-hist: bool (to add text-histogram features to graph nodes) 90 | - --add-visual: bool (to add visual features to graph nodes) 91 | - --add-eweights: bool (to add polar relative coordinates between nodes to graph edges) 92 | 93 | **Data** 94 | - --src-data: string [FUNSD, PAU or CUSTOM] (CUSTOM still under dev) 95 | - --src-type: string [img, pdf] (if src_data is CUSTOM, still under dev) 96 | 97 | **Graphs** 98 | - --edge-type: string [fully, knn] (to change the kind of connectivity) 99 | - --node-granularity: string [gt, yolo, ocr] (choose the granularity of nodes to be used: gt (if given), ocr (words) or yolo (entities)) 100 | - --num-polar-bins: int [Default 8] (number of bins into which to discretize the space for edge polar features.
It must be a power of 2) 101 | 102 | **Inference (only for KiE)** 103 | - --inference: bool (run inference on the given document path(s)) 104 | - --docs: list (absolute path(s) to your document(s)) 105 | 106 | Edit `configs/train.yaml` directly for training settings or pass these flags to `src/main.py`. To create your own model (with different hyperparameters), copy one of the `configs/models/*.yaml` files. 107 | 108 | **Training/Testing** 109 | - --model: string [e2e, edge, gcn] (which model to use, i.e. which yaml file to load) 110 | - --gpu: int [Default -1] (which GPU to use. Set -1 to use CPU) 111 | - --test: true / false (skip training if true) 112 | - --weights: string(s) (relative path(s) of the weight file(s), if testing) 113 | 114 | ## Testing 115 | 116 | You can use our pretrained models over the test sets of FUNSD[^1] and Pau Riba's[^2] datasets. 117 | 118 | 1. On FUNSD, we were able to perform both Semantic Entity Labeling and Entity Linking: 119 | 120 | **E2E-FUNSD-GT**: 121 | ``` 122 | python src/main.py -addG -addT -addE -addV --gpu 0 --test --weights e2e-funsd-best.pt 123 | ``` 124 | 125 | **E2E-FUNSD-YOLO**: 126 | ``` 127 | python src/main.py -addG -addT -addE -addV --gpu 0 --test --weights e2e-funsd-best.pt --node-granularity yolo 128 | ``` 129 | 130 | 2.
On Pau Riba's dataset, we were able to perform both Layout Analysis and Table Detection: 131 | 132 | **E2E-PAU**: 133 | ``` 134 | python src/main.py -addG -addT -addE -addV --gpu 0 --test --weights e2e-pau-best.pt --src-data PAU --edge-type knn 135 | ``` 136 | 137 | --- 138 | ## Cite this project 139 | If you want to use our code in your project(s), please cite us: 140 | ``` 141 | @InProceedings{10.1007/978-3-031-25069-9_22, 142 | author="Gemelli, Andrea 143 | and Biswas, Sanket 144 | and Civitelli, Enrico 145 | and Llad{\'o}s, Josep 146 | and Marinai, Simone", 147 | editor="Karlinsky, Leonid 148 | and Michaeli, Tomer 149 | and Nishino, Ko", 150 | title="Doc2Graph: A Task Agnostic Document Understanding Framework Based on Graph Neural Networks", 151 | booktitle="Computer Vision -- ECCV 2022 Workshops", 152 | year="2023", 153 | publisher="Springer Nature Switzerland", 154 | address="Cham", 155 | pages="329--344", 156 | abstract="Geometric Deep Learning has recently attracted significant interest in a wide range of machine learning fields, including document analysis. The application of Graph Neural Networks (GNNs) has become crucial in various document-related tasks since they can unravel important structural patterns, fundamental in key information extraction processes. Previous works in the literature propose task-driven models and do not take into account the full power of graphs. We propose Doc2Graph, a task-agnostic document understanding framework based on a GNN model, to solve different tasks given different types of documents. We evaluated our approach on two challenging datasets for key information extraction in form understanding, invoice layout analysis and table detection.
Our code is freely accessible on https://github.com/andreagemelli/doc2graph.", 157 | isbn="978-3-031-25069-9" 158 | } 159 | ``` 160 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | # THIS YAML FILE IS USED AS A BASE TO PRODUCE THE 'PREPROCESSING.YAML' ONE. 2 | # IT IS MEANT TO BE USED AS A STARTING POINT AND AS DOCUMENTATION (see all the different choices a user has). 3 | 4 | LOADER: 5 | src_data: 6 | - FUNSD 7 | # - PAU 8 | # - CUSTOM 9 | FEATURES: 10 | add_embs: true 11 | add_eweights: true 12 | add_fudge: false 13 | add_geom: true 14 | add_hist: false 15 | add_visual: true 16 | GRAPHS: 17 | edge_type: 18 | - fully 19 | # - knn 20 | data_type: 21 | - img 22 | # - pdf 23 | node_granularity: 24 | - gt 25 | # - ocr 26 | # - yolo 27 | -------------------------------------------------------------------------------- /configs/models/e2e.yaml: -------------------------------------------------------------------------------- 1 | name: E2E 2 | dropout: 0.2 3 | hidden_dim: 300 4 | out_chunks: 300 5 | num_layers: 1 6 | doProject: True -------------------------------------------------------------------------------- /configs/models/edge.yaml: -------------------------------------------------------------------------------- 1 | name: EDGE 2 | dropout: 0.2 3 | num_layers: 1 4 | hidden_dim: 300 5 | out_chunks: 300 6 | doProject: True -------------------------------------------------------------------------------- /configs/models/gcn.yaml: -------------------------------------------------------------------------------- 1 | name: GCN 2 | dropout: 0.0 3 | hidden_dim: 64 4 | num_layers: 3 5 | attn: False 6 | out_chunks: 100 -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 15 2 | epochs: 100000
3 | lr: 1e-3 4 | weight_decay: 1e-4 5 | val_size : 0.1 6 | optimizer: Adam # or AdamW, SGD - NOT YET IMPLEMENTED 7 | scheduler: 8 | - ReduceLROnPlateau # CosineAnnealingLR, None- NOT YET IMPLEMENTED 9 | stopper_metric: acc # loss or acc 10 | seed: 42 11 | -------------------------------------------------------------------------------- /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrdict==2.0.1 2 | itsdangerous==2.0.1 3 | numpy>=1.22 4 | Pillow>=9.1.1 5 | PyYAML==6.0 6 | scikit_learn==1.0.2 7 | setuptools==65.5.1 8 | setuptools_git_versioning==1.9.2 9 | spacy==3.3.0 10 | pydantic==1.8.2 11 | python-dotenv==0.20.0 12 | gitpython>=3.1.30 13 | gitdb2==3.0.1 14 | matplotlib==3.5.2 15 | segmentation-models-pytorch==0.2.1 16 | fasttext==0.9.2 17 | networkx==2.8.4 18 | tensorboard==2.9.1 19 | opencv-python==4.6.0.66 20 | pytesseract==0.3.9 21 | wget==3.2 22 | easyocr==1.6.2 23 | pandas 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from setuptools_git_versioning import version_from_git 3 | import os 4 | 5 | HERE = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | def parse_requirements(file_content): 8 | lines = file_content.splitlines() 9 | return [line.strip() for line in lines if line and not line.startswith("#")] 10 | 11 | with open(os.path.join(HERE, "requirements.txt")) as f: 12 | requirements = parse_requirements(f.read()) 13 | 14 | with open("src/root.env", "w") as f: 15 | f.write(f"ROOT = '{HERE}'") 16 | 17 | setup( 18 | name='doc2graph', 19 | 
version=version_from_git(), 20 | packages=['doc2graph'], 21 | package_dir={'doc2graph': 'src'}, 22 | description='Repo to transform Documents to Graphs, performing several tasks on them', 23 | author='Andrea Gemelli', 24 | license='MIT', 25 | keywords="document analysis graph", 26 | setuptools_git_versioning={ 27 | "enabled": True, 28 | }, 29 | python_requires=">=3.7", 30 | setup_requires=["setuptools_git_versioning"], 31 | install_requires=requirements, 32 | ) 33 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from src.training import * 2 | from src.data import * 3 | from src.utils import * -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/dataloader.py: -------------------------------------------------------------------------------- 1 | from random import randint 2 | from typing import Tuple 3 | import torch 4 | import torch.utils.data as data 5 | import os 6 | import numpy as np 7 | import dgl 8 | from PIL import Image, ImageDraw 9 | 10 | from src.data.feature_builder import FeatureBuilder 11 | from src.data.graph_builder import GraphBuilder 12 | from src.utils import get_config 13 | 14 | class Document2Graph(data.Dataset): 15 | """This class convert documents (both images or pdfs) into graph structures. 16 | """ 17 | 18 | def __init__(self, name : str, src_path : str, device = str, output_dir = str): 19 | """ 20 | Args: 21 | name (str): should be one of the following: ['gt', 'img', 'pdf'] 22 | src_path (str): path to folder containing documents 23 | device (str): device to use. 
can be 'cpu' or 'cuda:n' 24 | output_dir (str): where to save printed graphs examples 25 | """ 26 | 27 | # initialize class 28 | if not os.path.isdir(src_path): raise Exception(f'src_path {src_path} does not exists\n -> please provide an existing path') 29 | 30 | self.name = name 31 | self.src_path = src_path 32 | self.cfg_preprocessing = get_config('preprocessing') 33 | self.src_data = self.cfg_preprocessing.LOADER.src_data 34 | self.GB = GraphBuilder() 35 | self.FB = FeatureBuilder(device) 36 | self.output_dir = output_dir 37 | 38 | # TODO: DO A DIFFERENT FILE 39 | self.COLORS = {'invoice_info': (150, 75, 0), 'receiver':(0,100,0), 'other':(128, 128, 128), 'supplier': (255, 0, 255), 'positions':(255,140,0), 'total':(0, 255, 255)} 40 | 41 | # get graphs 42 | self.graphs, self.node_labels, self.edge_labels, self.paths = self.__docs2graphs() 43 | 44 | # LABELS to numeric value 45 | # NODES 46 | if self.node_labels: 47 | self.node_unique_labels = np.unique(np.array([l for nl in self.node_labels for l in nl])) 48 | self.node_num_classes = len(self.node_unique_labels) 49 | self.node_num_features = self.graphs[0].ndata['feat'].shape[1] 50 | 51 | for idx, labels in enumerate(self.node_labels): 52 | self.graphs[idx].ndata['label'] = torch.tensor([np.where(target == self.node_unique_labels)[0][0] for target in labels], dtype=torch.int64) 53 | 54 | # EDGES 55 | if self.edge_labels: 56 | self.edge_unique_labels = np.unique(self.edge_labels[0]) 57 | self.edge_num_classes = len(self.edge_unique_labels) 58 | try: 59 | # TODO to be changed 60 | self.edge_num_features = self.graphs[0].edata['feat'].shape[1] 61 | except: 62 | self.edge_num_features = 0 63 | 64 | for idx, labels in enumerate(self.edge_labels): 65 | self.graphs[idx].edata['label'] = torch.tensor([np.where(target == self.edge_unique_labels)[0][0] for target in labels], dtype=torch.int64) 66 | 67 | def __getitem__(self, index: int) -> dgl.DGLGraph: 68 | """ Returns item (graph), given index 69 | 70 | Args: 71 | index 
(int): index of the item to be taken. 72 | """ 73 | return self.graphs[index] 74 | 75 | def __len__(self) -> int: 76 | """ Returns data length 77 | """ 78 | return len(self.graphs) 79 | 80 | def __docs2graphs(self) -> Tuple[list, list, list, list]: 81 | """It uses GraphBuilder and FeaturesBuilder objects to get graphs (and lables, if any) from source data. 82 | 83 | Returns: 84 | tuple (lists): DGLGraphs, nodes and edges label names, paths per each file 85 | """ 86 | graphs, node_labels, edge_labels, features = self.GB.get_graph(self.src_path, self.src_data) 87 | self.feature_chunks, self.num_mods = self.FB.add_features(graphs, features) 88 | return graphs, node_labels, edge_labels, features['paths'] 89 | 90 | def label2class(self, label : str, node=True) -> int: 91 | """ Transform a label (str) into its class number. 92 | 93 | Args: 94 | label (str): node or edge string label 95 | node (bool): either if taking node (true) ord edge (false) class number 96 | 97 | Returns: 98 | int: class number 99 | """ 100 | if node: 101 | return self.node_unique_labels[label] 102 | else: 103 | return self.edge_unique_labels[label] 104 | 105 | def get_info(self, num_graph=0) -> None: 106 | """ Print information regarding the data uploaded 107 | 108 | Args: 109 | num_graph (int): give one index to print one example graph information 110 | """ 111 | print(f"\n{self.name.upper()} dataset:\n-> graphs: {len(self.graphs)}\n-> node labels: {self.node_unique_labels}\n-> edge labels: {self.edge_unique_labels}\n-> node features: {self.node_num_features}") 112 | self.GB.get_info() 113 | self.FB.get_info() 114 | print(f"-> graph example: {self.graphs[num_graph]}") 115 | return 116 | 117 | def balance(self, cls = 'none', indices = None) -> None: 118 | """ Calls balance_edges() of GraphBuilder. 
119 | 120 | Args: 121 | cls (str): 122 | """ 123 | 124 | cls = int(np.where(cls == self.edge_unique_labels)[0][0]) 125 | if indices is None: 126 | for i, g in enumerate(self.graphs): 127 | self.graphs[i] = self.GB.balance_edges(g, self.edge_num_classes, cls = cls) 128 | else: 129 | for id in indices: 130 | self.graphs[id] = self.GB.balance_edges(self.graphs[id], self.edge_num_classes, cls = cls) 131 | 132 | return 133 | 134 | def get_chunks(self) -> list: 135 | """ get feature_chunks, meaning the length of different modalities (features) contributions inside nodes. 136 | 137 | Returns: 138 | feature_chunks (list) : list of feature chunks 139 | """ 140 | if len(self.feature_chunks) != self.num_mods: self.feature_chunks.pop(0) 141 | return self.feature_chunks 142 | 143 | def print_graph(self, num=None, node_labels=None, labels_ids=None, name='doc_graph', bidirect=True, regions=[], preds=None) -> Image: 144 | """ Print a given graph over its image document. 145 | 146 | Args: 147 | num (int): which graph / document to print 148 | node_labels (list): list of node labels 149 | labels_ids (list): list of labels ids to print 150 | name (str): name to give to output file 151 | bidirect (bool): either to print the graph bidirected or not 152 | regions (list): debug purposes for layout anaylis, if any available it prints it 153 | preds (list): if any, it prints model predictions 154 | 155 | Returns: 156 | graph_img (Image) : the drawn graph over the document 157 | """ 158 | if num is None: num = randint(0, self.__len__()-1) 159 | graph = self.graphs[num] 160 | graph_path = self.paths[num] 161 | graph_img = Image.open(graph_path).convert('RGB') 162 | if labels_ids is None: labels_ids = graph.edata['label'].nonzero().flatten().tolist() 163 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2) 164 | graph_draw = ImageDraw.Draw(graph_img) 165 | w, h = graph_img.size 166 | boxs = graph.ndata['geom'][:, :4].tolist() 167 | boxs = [[box[0]*w, box[1]*h, box[2]*w, 
box[3]*h] for box in boxs] 168 | 169 | if node_labels is not None: 170 | for b, box in enumerate(boxs): 171 | label = self.node_unique_labels[node_labels[b]] 172 | graph_draw.rectangle(box, outline=self.COLORS[label], width=2) 173 | else: 174 | for box in boxs: 175 | graph_draw.rectangle(box, outline='blue', width=2) 176 | 177 | for region in regions: 178 | color = self.COLORS[region[0]] 179 | graph_draw.rectangle(region[1], outline=color, width=4) 180 | 181 | if preds is not None: 182 | graph_draw.rectangle(preds, outline='green', width=4) 183 | 184 | u,v = graph.edges() 185 | for id in labels_ids: 186 | sc = center(boxs[u[id]]) 187 | ec = center(boxs[v[id]]) 188 | graph_draw.line((sc,ec), fill='violet', width=3) 189 | if bidirect: 190 | graph_draw.ellipse([(sc[0]-4,sc[1]-4), (sc[0]+4,sc[1]+4)], fill = 'green', outline='black') 191 | graph_draw.ellipse([(ec[0]-4,ec[1]-4), (ec[0]+4,ec[1]+4)], fill = 'red', outline='black') 192 | 193 | graph_img.save(self.output_dir / f'{name}.png') 194 | return graph_img 195 | -------------------------------------------------------------------------------- /src/data/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import requests 4 | import zipfile 5 | import wget 6 | from src.utils import create_folder 7 | 8 | from src.paths import CHECKPOINTS, DATA 9 | 10 | def download_url(url, save_path, chunk_size=128): 11 | r = requests.get(url, stream=True) 12 | with open(save_path, 'wb') as fd: 13 | for chunk in r.iter_content(chunk_size=chunk_size): 14 | fd.write(chunk) 15 | 16 | def funsd(): 17 | print("Downloading FUNSD") 18 | 19 | dlz = DATA / 'funsd.zip' 20 | download_url("https://guillaumejaume.github.io/FUNSD/dataset.zip", dlz) 21 | with zipfile.ZipFile(dlz, 'r') as zip_ref: 22 | zip_ref.extractall(DATA) 23 | os.remove(dlz) 24 | os.rename(DATA / 'dataset', DATA / 'FUNSD') 25 | 26 | # adjusted annotations 27 | aa_train = os.path.join(DATA / 
'adjusted_annotations.zip') 28 | wget.download(url="https://docs.google.com/uc?export=download&id=1cQE2dnLGh93u3xMUeGQRXM81tF2l0Zut", out=aa_train) 29 | with zipfile.ZipFile(aa_train, 'r') as zip_ref: 30 | zip_ref.extractall(DATA / 'FUNSD/training_data') 31 | os.remove(aa_train) 32 | 33 | aa_test = os.path.join(DATA / 'adjusted_annotations.zip') 34 | wget.download(url="https://docs.google.com/uc?export=download&id=18LXbRhjnkdsAvWBFhUdr7_44bfaBo0-v", out=aa_test) 35 | with zipfile.ZipFile(aa_test, 'r') as zip_ref: 36 | zip_ref.extractall(DATA / 'FUNSD/testing_data') 37 | os.remove(aa_test) 38 | 39 | # yolo_bbox 40 | yolo_train = os.path.join(DATA / 'yolo_bbox.zip') 41 | 42 | 43 | wget.download(url="https://docs.google.com/uc?export=download&id=1UzL5tYtBWDXk_nXj4KtoDMyBt7S3j-aS", out=yolo_train) 44 | with zipfile.ZipFile(yolo_train, 'r') as zip_ref: 45 | zip_ref.extractall(DATA / 'FUNSD/training_data') 46 | os.remove(yolo_train) 47 | 48 | yolo_test = os.path.join(DATA / 'yolo_bbox.zip') 49 | wget.download(url="https://docs.google.com/uc?export=download&id=1fWwbhfvINYFQmoPHpwlH8olyTyJPIe_-", out=yolo_test) 50 | with zipfile.ZipFile(yolo_test, 'r') as zip_ref: 51 | zip_ref.extractall(DATA / 'FUNSD/testing_data') 52 | os.remove(yolo_test) 53 | 54 | return 55 | 56 | def pau(): 57 | 58 | print("Downloading PAU") 59 | 60 | PAU = DATA / 'PAU' 61 | spl = os.path.join(DATA, 'pau_split.zip') 62 | wget.download(url="https://docs.google.com/uc?export=download&id=1NKlME13tRIDraSid7r3v_huTqFdHgo-G", out=spl) 63 | with zipfile.ZipFile(spl, 'r') as zip_ref: 64 | zip_ref.extractall(PAU) 65 | 66 | dlz = DATA / 'pau.zip' 67 | download_url("https://zenodo.org/record/3257319/files/dataset.zip", dlz) 68 | with zipfile.ZipFile(dlz, 'r') as zip_ref: 69 | zip_ref.extractall(PAU) 70 | 71 | create_folder(PAU / 'train') 72 | create_folder(PAU / 'test') 73 | create_folder(PAU / 'outliers') 74 | 75 | for split in ['train.txt', 'test.txt', 'outliers.txt']: 76 | with open(PAU / split, mode='r') as 
f: 77 | folder = split.split(".")[0] 78 | lines = f.readlines() 79 | for l in lines: 80 | l = l.rstrip('\n') 81 | src = PAU / l 82 | dst = PAU / '{}/{}'.format(folder, l) 83 | shutil.move(src, dst) 84 | 85 | os.remove(spl) 86 | os.remove(dlz) 87 | return 88 | 89 | def get_data(): 90 | funsd() 91 | pau() 92 | return 93 | 94 | -------------------------------------------------------------------------------- /src/data/feature_builder.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | import spacy 4 | import torch 5 | import torchvision 6 | from tqdm import tqdm 7 | from PIL import Image, ImageDraw 8 | import torchvision.transforms.functional as tvF 9 | 10 | from src.paths import CHECKPOINTS 11 | from src.models.unet import Unet 12 | from src.data.utils import to_bin 13 | from src.data.utils import polar, get_histogram 14 | from src.utils import get_config 15 | 16 | class FeatureBuilder(): 17 | 18 | def __init__(self, d : int = 'cpu'): 19 | """FeatureBuilder constructor 20 | 21 | Args: 22 | d (int): device number, if any (cpu or cuda:n) 23 | """ 24 | self.cfg_preprocessing = get_config('preprocessing') 25 | self.device = d 26 | self.add_geom = self.cfg_preprocessing.FEATURES.add_geom 27 | self.add_embs = self.cfg_preprocessing.FEATURES.add_embs 28 | self.add_hist = self.cfg_preprocessing.FEATURES.add_hist 29 | self.add_visual = self.cfg_preprocessing.FEATURES.add_visual 30 | self.add_eweights = self.cfg_preprocessing.FEATURES.add_eweights 31 | self.add_fudge = self.cfg_preprocessing.FEATURES.add_fudge 32 | self.num_polar_bins = self.cfg_preprocessing.FEATURES.num_polar_bins 33 | 34 | if self.add_embs: 35 | self.text_embedder = spacy.load('en_core_web_lg') 36 | 37 | if self.add_visual: 38 | self.visual_embedder = Unet(encoder_name="mobilenet_v2", encoder_weights=None, in_channels=1, classes=4) 39 | self.visual_embedder.load_state_dict(torch.load(CHECKPOINTS / 
'backbone_unet.pth')['weights']) 40 | self.visual_embedder = self.visual_embedder.encoder 41 | self.visual_embedder.to(d) 42 | 43 | self.sg = lambda rect, s : [rect[0]/s[0], rect[1]/s[1], rect[2]/s[0], rect[3]/s[1]] # scaling by img width and height 44 | 45 | def add_features(self, graphs : list, features : list) -> Tuple[list, int]: 46 | """ Add features to provided graphs 47 | 48 | Args: 49 | graphs (list) : list of DGLGraphs 50 | features (list) : list of features "sources", like text, positions and images 51 | 52 | Returns: 53 | chunks list and its lenght 54 | """ 55 | 56 | for id, g in enumerate(tqdm(graphs, desc='adding features')): 57 | 58 | # positional features 59 | size = Image.open(features['paths'][id]).size 60 | feats = [[] for _ in range(len(features['boxs'][id]))] 61 | geom = [self.sg(box, size) for box in features['boxs'][id]] 62 | chunks = [] 63 | 64 | # 'geometrical' features 65 | if self.add_geom: 66 | 67 | # TODO add 2d encoding like "LayoutLM*" 68 | [feats[idx].extend(self.sg(box, size)) for idx, box in enumerate(features['boxs'][id])] 69 | chunks.append(4) 70 | 71 | # HISTOGRAM OF TEXT 72 | if self.add_hist: 73 | 74 | [feats[idx].extend(hist) for idx, hist in enumerate(get_histogram(features['texts'][id]))] 75 | chunks.append(4) 76 | 77 | # textual features 78 | if self.add_embs: 79 | 80 | # LANGUAGE MODEL (SPACY) 81 | [feats[idx].extend(self.text_embedder(features['texts'][id][idx]).vector) for idx, _ in enumerate(feats)] 82 | chunks.append(len(self.text_embedder(features['texts'][id][0]).vector)) 83 | 84 | # visual features 85 | # https://pytorch.org/vision/stable/generated/torchvision.ops.roi_align.html?highlight=roi 86 | if self.add_visual: 87 | img = Image.open(features['paths'][id]) 88 | visual_emb = self.visual_embedder(tvF.to_tensor(img).unsqueeze_(0).to(self.device)) # output [batch, channels, dim1, dim2] 89 | bboxs = [torch.Tensor(b) for b in features['boxs'][id]] 90 | bboxs = [torch.stack(bboxs, dim=0).to(self.device)] 91 | h = 
[torchvision.ops.roi_align(input=ve, boxes=bboxs, spatial_scale=1/ min(size[1] / ve.shape[2] , size[0] / ve.shape[3]), output_size=1) for ve in visual_emb[1:]] 92 | h = torch.cat(h, dim=1) 93 | 94 | # VISUAL FEATURES (RESNET-IMAGENET) 95 | [feats[idx].extend(torch.flatten(h[idx]).tolist()) for idx, _ in enumerate(feats)] 96 | chunks.append(len(torch.flatten(h[0]).tolist())) 97 | 98 | if self.add_eweights: 99 | u, v = g.edges() 100 | srcs, dsts = u.tolist(), v.tolist() 101 | distances = [] 102 | angles = [] 103 | 104 | # TODO CHOOSE WHICH DISTANCE NORMALIZATION TO APPLY 105 | #! with fully connected simply normalized with max distance between distances 106 | # m = sqrt((size[0]*size[0] + size[1]*size[1])) 107 | # parable = lambda x : (-x+1)**4 108 | 109 | for pair in zip(srcs, dsts): 110 | dist, angle = polar(features['boxs'][id][pair[0]], features['boxs'][id][pair[1]]) 111 | distances.append(dist) 112 | angles.append(angle) 113 | 114 | m = max(distances) 115 | polar_coordinates = to_bin(distances, angles, self.num_polar_bins) 116 | g.edata['feat'] = polar_coordinates 117 | 118 | else: 119 | distances = ([0.0 for _ in range(g.number_of_edges())]) 120 | m = 1 121 | 122 | g.ndata['geom'] = torch.tensor(geom, dtype=torch.float32) 123 | g.ndata['feat'] = torch.tensor(feats, dtype=torch.float32) 124 | 125 | distances = torch.tensor([(1-d/m) for d in distances], dtype=torch.float32) 126 | tresh_dist = torch.where(distances > 0.9, torch.full_like(distances, 0.1), torch.zeros_like(distances)) 127 | g.edata['weights'] = tresh_dist 128 | 129 | norm = [] 130 | num_nodes = len(features['boxs'][id]) - 1 131 | for n in range(num_nodes + 1): 132 | neigs = torch.count_nonzero(tresh_dist[n*num_nodes:(n+1)*num_nodes]).tolist() 133 | try: norm.append([1. / neigs]) 134 | except: norm.append([1.]) 135 | g.ndata['norm'] = torch.tensor(norm, dtype=torch.float32) 136 | 137 | #! 
DEBUG PURPOSES TO VISUALIZE RANDOM GRAPH IMAGE FROM DATASET 138 | if False: 139 | if id == rand_id and self.add_eweights: 140 | print("\n\n### EXAMPLE ###") 141 | 142 | img_path = features['paths'][id] 143 | img = Image.open(img_path).convert('RGB') 144 | draw = ImageDraw.Draw(img) 145 | 146 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2) 147 | select = [random.randint(0, len(srcs)) for _ in range(10)] 148 | for p, pair in enumerate(zip(srcs, dsts)): 149 | if p in select: 150 | sc = center(features['boxs'][id][pair[0]]) 151 | ec = center(features['boxs'][id][pair[1]]) 152 | draw.line((sc, ec), fill='grey', width=3) 153 | middle_point = ((sc[0] + ec[0])/2,(sc[1] + ec[1])/2) 154 | draw.text(middle_point, str(angles[p]), fill='black') 155 | draw.rectangle(features['boxs'][id][pair[0]], fill='red') 156 | draw.rectangle(features['boxs'][id][pair[1]], fill='blue') 157 | 158 | img.save(f'esempi/FUNSD/edges.png') 159 | 160 | return chunks, len(chunks) 161 | 162 | def get_info(self): 163 | print(f"-> textual feats: {self.add_embs}\n-> visual feats: {self.add_visual}\n-> edge feats: {self.add_eweights}") 164 | 165 | 166 | -------------------------------------------------------------------------------- /src/data/graph_builder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from PIL import Image, ImageDraw 4 | from typing import Tuple 5 | import torch 6 | import dgl 7 | import random 8 | import numpy as np 9 | from tqdm import tqdm 10 | import xml.etree.ElementTree as ET 11 | import easyocr 12 | 13 | from src.data.preprocessing import load_predictions 14 | from src.data.utils import polar 15 | from src.paths import DATA, FUNSD_TEST 16 | from src.utils import get_config 17 | 18 | 19 | class GraphBuilder(): 20 | 21 | def __init__(self): 22 | self.cfg_preprocessing = get_config('preprocessing') 23 | self.edge_type = self.cfg_preprocessing.GRAPHS.edge_type 24 | self.data_type = 
self.cfg_preprocessing.GRAPHS.data_type 25 | self.node_granularity = self.cfg_preprocessing.GRAPHS.node_granularity 26 | random.seed = 42 27 | return 28 | 29 | def get_graph(self, src_path : str, src_data : str) -> Tuple[list, list, list, list]: 30 | """ Given the source, it returns a graph 31 | 32 | Args: 33 | src_path (str) : path to source data 34 | src_data (str) : either FUNSD, PAU or CUSTOM 35 | 36 | Returns: 37 | tuple (lists) : graphs, nodes and edge labels, features 38 | """ 39 | 40 | if src_data == 'FUNSD': 41 | return self.__fromFUNSD(src_path) 42 | elif src_data == 'PAU': 43 | return self.__fromPAU(src_path) 44 | elif src_data == 'CUSTOM': 45 | if self.data_type == 'img': 46 | return self.__fromIMG(src_path) 47 | elif self.data_type == 'pdf': 48 | return self.__fromPDF() 49 | else: 50 | raise Exception('GraphBuilder exception: data type invalid. Choose from ["img", "pdf"]') 51 | else: 52 | raise Exception('GraphBuilder exception: source data invalid. Choose from ["FUNSD", "PAU", "CUSTOM"]') 53 | 54 | def balance_edges(self, g : dgl.DGLGraph, cls=None ) -> dgl.DGLGraph: 55 | """ if cls (class) is not None, but an integer instead, balance that class to be equal to the sum of the other classes 56 | 57 | Args: 58 | g (DGLGraph) : a DGL graph 59 | cls (int) : class number, if any 60 | 61 | Returns: 62 | g (DGLGraph) : the new balanced graph 63 | """ 64 | 65 | edge_targets = g.edata['label'] 66 | u, v = g.all_edges(form='uv') 67 | edges_list = list() 68 | for e in zip(u.tolist(), v.tolist()): 69 | edges_list.append([e[0], e[1]]) 70 | 71 | if type(cls) is int: 72 | to_remove = (edge_targets == cls) 73 | indices_to_remove = to_remove.nonzero().flatten().tolist() 74 | 75 | for _ in range(int((edge_targets != cls).sum()/2)): 76 | indeces_to_save = [random.choice(indices_to_remove)] 77 | edge = edges_list[indeces_to_save[0]] 78 | 79 | for index in sorted(indeces_to_save, reverse=True): 80 | del indices_to_remove[indices_to_remove.index(index)] 81 | 82 | 
indices_to_remove = torch.flatten(torch.tensor(indices_to_remove, dtype=torch.int32)) 83 | g = dgl.remove_edges(g, indices_to_remove) 84 | return g 85 | 86 | else: 87 | raise Exception("Select a class to balance (an integer ranging from 0 to num_edge_classes).") 88 | 89 | def get_info(self): 90 | """ returns graph information 91 | """ 92 | print(f"-> edge type: {self.edge_type}") 93 | 94 | def fully_connected(self, ids : list) -> Tuple[list, list]: 95 | """ create fully connected graph 96 | 97 | Args: 98 | ids (list) : list of node indices 99 | 100 | Returns: 101 | u, v (lists) : lists of indices 102 | """ 103 | u, v = list(), list() 104 | for id in ids: 105 | u.extend([id for i in range(len(ids)) if i != id]) 106 | v.extend([i for i in range(len(ids)) if i != id]) 107 | return u, v 108 | 109 | def knn_connection(self, size : tuple, bboxs : list, k = 10) -> Tuple[list, list]: 110 | """ Given a list of bounding boxes, find for each of them their k nearest ones. 111 | 112 | Args: 113 | size (tuple) : width and height of the image 114 | bboxs (list) : list of bounding box coordinates 115 | k (int) : k of the knn algorithm 116 | 117 | Returns: 118 | u, v (lists) : lists of indices 119 | """ 120 | 121 | edges = [] 122 | width, height = size[0], size[1] 123 | 124 | # creating projections 125 | vertical_projections = [[] for i in range(width)] 126 | horizontal_projections = [[] for i in range(height)] 127 | for node_index, bbox in enumerate(bboxs): 128 | for hp in range(bbox[0], bbox[2]): 129 | if hp >= width: hp = width - 1 130 | vertical_projections[hp].append(node_index) 131 | for vp in range(bbox[1], bbox[3]): 132 | if vp >= height: vp = height - 1 133 | horizontal_projections[vp].append(node_index) 134 | 135 | def bound(a, ori=''): 136 | if a < 0 : return 0 137 | elif ori == 'h' and a > height: return height 138 | elif ori == 'w' and a > width: return width 139 | else: return a 140 | 141 | for node_index, node_bbox in enumerate(bboxs): 142 | neighbors = [] # collect 
list of neighbors 143 | window_multiplier = 2 # how much to look around bbox 144 | wider = (node_bbox[2] - node_bbox[0]) > (node_bbox[3] - node_bbox[1]) # if bbox wider than taller 145 | 146 | ### finding neighbors ### 147 | while(len(neighbors) < k and window_multiplier < 100): # keep enlarging the window until at least k bboxs are found or window too big 148 | vertical_bboxs = [] 149 | horizontal_bboxs = [] 150 | neighbors = [] 151 | 152 | if wider: 153 | h_offset = int((node_bbox[2] - node_bbox[0]) * window_multiplier/4) 154 | v_offset = int((node_bbox[3] - node_bbox[1]) * window_multiplier) 155 | else: 156 | h_offset = int((node_bbox[2] - node_bbox[0]) * window_multiplier) 157 | v_offset = int((node_bbox[3] - node_bbox[1]) * window_multiplier/4) 158 | 159 | window = [bound(node_bbox[0] - h_offset), 160 | bound(node_bbox[1] - v_offset), 161 | bound(node_bbox[2] + h_offset, 'w'), 162 | bound(node_bbox[3] + v_offset, 'h')] 163 | 164 | [vertical_bboxs.extend(d) for d in vertical_projections[window[0]:window[2]]] 165 | [horizontal_bboxs.extend(d) for d in horizontal_projections[window[1]:window[3]]] 166 | 167 | for v in set(vertical_bboxs): 168 | for h in set(horizontal_bboxs): 169 | if v == h: neighbors.append(v) 170 | 171 | window_multiplier += 1 # enlarge the window 172 | 173 | ### finding k nearest neighbors ### 174 | neighbors = list(set(neighbors)) 175 | if node_index in neighbors: 176 | neighbors.remove(node_index) 177 | neighbors_distances = [polar(node_bbox, bboxs[n])[0] for n in neighbors] 178 | for sd_num, sd_idx in enumerate(np.argsort(neighbors_distances)): 179 | if sd_num < k: 180 | if [node_index, neighbors[sd_idx]] not in edges and [neighbors[sd_idx], node_index] not in edges: 181 | edges.append([neighbors[sd_idx], node_index]) 182 | edges.append([node_index, neighbors[sd_idx]]) 183 | else: break 184 | 185 | return [e[0] for e in edges], [e[1] for e in edges] 186 | 187 | def __fromIMG(self, paths : list): 188 | 189 | graphs, node_labels, edge_labels 
= list(), list(), list() 190 | features = {'paths': paths, 'texts': [], 'boxs': []} 191 | 192 | for path in paths: 193 | reader = easyocr.Reader(['en']) #! TODO: in the future, handle multilanguage! 194 | result = reader.readtext(path, paragraph=True) 195 | img = Image.open(path).convert('RGB') 196 | draw = ImageDraw.Draw(img) 197 | boxs, texts = list(), list() 198 | 199 | for r in result: 200 | box = [int(r[0][0][0]), int(r[0][0][1]), int(r[0][2][0]), int(r[0][2][1])] 201 | draw.rectangle(box, outline='red', width=3) 202 | boxs.append(box) 203 | texts.append(r[1]) 204 | 205 | features['boxs'].append(boxs) 206 | features['texts'].append(texts) 207 | img.save('prova.png') 208 | 209 | if self.edge_type == 'fully': 210 | u, v = self.fully_connected(range(len(boxs))) 211 | elif self.edge_type == 'knn': 212 | u,v = self.knn_connection(Image.open(path).size, boxs) 213 | else: 214 | raise Exception('Other edge types still under development.') 215 | 216 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(boxs), idtype=torch.int32) 217 | graphs.append(g) 218 | 219 | return graphs, node_labels, edge_labels, features 220 | 221 | def __fromPDF(): 222 | #TODO: dev from PDF import of graphs 223 | return 224 | 225 | def __fromPAU(self, src: str) -> Tuple[list, list, list, list]: 226 | """ build graphs from Pau Riba's dataset 227 | 228 | Args: 229 | src (str) : path to where data is stored 230 | 231 | Returns: 232 | tuple (lists) : graphs, nodes and edge labels, features 233 | """ 234 | 235 | graphs, node_labels, edge_labels = list(), list(), list() 236 | features = {'paths': [], 'texts': [], 'boxs': []} 237 | 238 | for image in tqdm(os.listdir(src), desc='Creating graphs'): 239 | if not image.endswith('tif'): continue 240 | 241 | img_name = image.split('.')[0] 242 | file_gt = img_name + '_gt.xml' 243 | file_ocr = img_name + '_ocr.xml' 244 | 245 | if not os.path.isfile(os.path.join(src, file_gt)) or not os.path.isfile(os.path.join(src, file_ocr)): continue 246 | 
features['paths'].append(os.path.join(src, image)) 247 | 248 | # DOCUMENT REGIONS 249 | root = ET.parse(os.path.join(src, file_gt)).getroot() 250 | regions = [] 251 | for parent in root: 252 | if parent.tag.split("}")[1] == 'Page': 253 | for child in parent: 254 | region_label = child[0].attrib['value'] 255 | region_bbox = [int(child[1].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]), 256 | int(child[1].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]), 257 | int(child[1].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]), 258 | int(child[1].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])] 259 | regions.append([region_label, region_bbox]) 260 | 261 | # DOCUMENT TOKENS 262 | root = ET.parse(os.path.join(src, file_ocr)).getroot() 263 | tokens_bbox = [] 264 | tokens_text = [] 265 | nl = [] 266 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2) 267 | for parent in root: 268 | if parent.tag.split("}")[1] == 'Page': 269 | for child in parent: 270 | if child.tag.split("}")[1] == 'TextRegion': 271 | for elem in child: 272 | if elem.tag.split("}")[1] == 'TextLine': 273 | for word in elem: 274 | if word.tag.split("}")[1] == 'Word': 275 | word_bbox = [int(word[0].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]), 276 | int(word[0].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]), 277 | int(word[0].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]), 278 | int(word[0].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])] 279 | word_text = word[1][0].text 280 | c = center(word_bbox) 281 | for reg in regions: 282 | r = reg[1] 283 | if r[0] < c[0] < r[2] and r[1] < c[1] < r[3]: 284 | word_label = reg[0] 285 | break 286 | tokens_bbox.append(word_bbox) 287 | tokens_text.append(word_text) 288 | nl.append(word_label) 289 | 290 | features['boxs'].append(tokens_bbox) 291 | features['texts'].append(tokens_text) 292 | node_labels.append(nl) 293 | 294 | # getting edges 295 | if self.edge_type == 
'fully': 296 | u, v = self.fully_connected(range(len(tokens_bbox))) 297 | elif self.edge_type == 'knn': 298 | u,v = self.knn_connection(Image.open(os.path.join(src, image)).size, tokens_bbox) 299 | else: 300 | raise Exception('Other edge types still under development.') 301 | 302 | el = list() 303 | for e in zip(u, v): 304 | if (nl[e[0]] == nl[e[1]]) and (nl[e[0]] == 'positions' or nl[e[0]] == 'total'): 305 | el.append('table') 306 | else: el.append('none') 307 | edge_labels.append(el) 308 | 309 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(tokens_bbox), idtype=torch.int32) 310 | graphs.append(g) 311 | 312 | return graphs, node_labels, edge_labels, features 313 | 314 | def __fromFUNSD(self, src : str) -> Tuple[list, list, list, list]: 315 | """Parsing FUNSD annotation files 316 | 317 | Args: 318 | src (str) : path to where data is stored 319 | 320 | Returns: 321 | tuple (lists) : graphs, nodes and edge labels, features 322 | """ 323 | 324 | graphs, node_labels, edge_labels = list(), list(), list() 325 | features = {'paths': [], 'texts': [], 'boxs': []} 326 | # justOne = random.choice(os.listdir(os.path.join(src, 'adjusted_annotations'))).split(".")[0] 327 | 328 | if self.node_granularity == 'gt': 329 | for file in tqdm(os.listdir(os.path.join(src, 'adjusted_annotations')), desc='Creating graphs - GT'): 330 | 331 | img_name = f'{file.split(".")[0]}.png' 332 | img_path = os.path.join(src, 'images', img_name) 333 | features['paths'].append(img_path) 334 | 335 | with open(os.path.join(src, 'adjusted_annotations', file), 'r') as f: 336 | form = json.load(f)['form'] 337 | 338 | # getting infos 339 | boxs, texts, ids, nl = list(), list(), list(), list() 340 | pair_labels = list() 341 | 342 | for elem in form: 343 | boxs.append(elem['box']) 344 | texts.append(elem['text']) 345 | nl.append(elem['label']) 346 | ids.append(elem['id']) 347 | [pair_labels.append(pair) for pair in elem['linking']] 348 | 349 | for p, pair in enumerate(pair_labels): 350 | 
pair_labels[p] = [ids.index(pair[0]), ids.index(pair[1])] 351 | 352 | node_labels.append(nl) 353 | features['texts'].append(texts) 354 | features['boxs'].append(boxs) 355 | 356 | # getting edges 357 | if self.edge_type == 'fully': 358 | u, v = self.fully_connected(range(len(boxs))) 359 | elif self.edge_type == 'knn': 360 | u,v = self.knn_connection(Image.open(img_path).size, boxs) 361 | else: 362 | raise Exception('GraphBuilder exception: Other edge types still under development.') 363 | 364 | el = list() 365 | for e in zip(u, v): 366 | edge = [e[0], e[1]] 367 | if edge in pair_labels: el.append('pair') 368 | else: el.append('none') 369 | edge_labels.append(el) 370 | 371 | # creating graph 372 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(boxs), idtype=torch.int32) 373 | graphs.append(g) 374 | 375 | #! DEBUG PURPOSES TO VISUALIZE RANDOM GRAPH IMAGE FROM DATASET 376 | if False: 377 | if justOne == file.split(".")[0]: 378 | print("\n\n### EXAMPLE ###") 379 | print("Savin example:", img_name) 380 | 381 | edge_unique_labels = np.unique(el) 382 | g.edata['label'] = torch.tensor([np.where(target == edge_unique_labels)[0][0] for target in el]) 383 | 384 | g = self.balance_edges(g, 3, int(np.where('none' == edge_unique_labels)[0][0])) 385 | 386 | img_removed = Image.open(img_path).convert('RGB') 387 | draw_removed = ImageDraw.Draw(img_removed) 388 | 389 | for b, box in enumerate(boxs): 390 | if nl[b] == 'header': 391 | color = 'yellow' 392 | elif nl[b] == 'question': 393 | color = 'blue' 394 | elif nl[b] == 'answer': 395 | color = 'green' 396 | else: 397 | color = 'gray' 398 | draw_removed.rectangle(box, outline=color, width=3) 399 | 400 | u, v = g.all_edges() 401 | labels = g.edata['label'].tolist() 402 | u, v = u.tolist(), v.tolist() 403 | 404 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2) 405 | 406 | num_pair = 0 407 | num_none = 0 408 | 409 | for p, pair in enumerate(zip(u,v)): 410 | sc = center(boxs[pair[0]]) 411 | ec = 
center(boxs[pair[1]]) 412 | if labels[p] == int(np.where('pair' == edge_unique_labels)[0][0]): 413 | num_pair += 1 414 | color = 'violet' 415 | draw_removed.ellipse([(sc[0]-4,sc[1]-4), (sc[0]+4,sc[1]+4)], fill = 'green', outline='black') 416 | draw_removed.ellipse([(ec[0]-4,ec[1]-4), (ec[0]+4,ec[1]+4)], fill = 'red', outline='black') 417 | else: 418 | num_none += 1 419 | color='gray' 420 | draw_removed.line((sc,ec), fill=color, width=3) 421 | 422 | print("Balanced Links: None {} | Key-Value {}".format(num_none, num_pair)) 423 | img_removed.save(f'esempi/FUNSD/{img_name}_removed_graph.png') 424 | 425 | elif self.node_granularity == 'yolo': 426 | path_preds = os.path.join(src, 'yolo_bbox') 427 | path_images = os.path.join(src, 'images') 428 | path_gts = os.path.join(src, 'adjusted_annotations') 429 | all_paths, all_preds, all_links, all_labels, all_texts = load_predictions(path_preds, path_gts, path_images) 430 | for f, img_path in enumerate(tqdm(all_paths, desc='Creating graphs - YOLO')): 431 | 432 | features['paths'].append(img_path) 433 | features['boxs'].append(all_preds[f]) 434 | features['texts'].append(all_texts[f]) 435 | node_labels.append(all_labels[f]) 436 | pair_labels = all_links[f] 437 | 438 | # getting edges 439 | if self.edge_type == 'fully': 440 | u, v = self.fully_connected(range(len(features['boxs'][f]))) 441 | elif self.edge_type == 'knn': 442 | u,v = self.knn_connection(Image.open(img_path).size, features['boxs'][f]) 443 | else: 444 | raise Exception('GraphBuilder exception: Other edge types still under development.') 445 | 446 | el = list() 447 | for e in zip(u, v): 448 | edge = [e[0], e[1]] 449 | if edge in pair_labels: el.append('pair') 450 | else: el.append('none') 451 | edge_labels.append(el) 452 | 453 | # creating graph 454 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(features['boxs'][f]), idtype=torch.int32) 455 | graphs.append(g) 456 | else: 457 | #TODO develop OCR too 458 | raise Exception('GraphBuilder Exception: 
only \'gt\' or \'yolo\' available for FUNSD.') 459 | 460 | 461 | return graphs, node_labels, edge_labels, features -------------------------------------------------------------------------------- /src/data/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from scipy.optimize import linprog 5 | import os 6 | from PIL import ImageDraw, Image 7 | import json 8 | import pytesseract 9 | from pytesseract import Output 10 | 11 | from src.paths import DATA, FUNSD_TEST 12 | 13 | 14 | def scale_back(r, w, h): return [int(r[0]*w), 15 | int(r[1]*h), int(r[2]*w), int(r[3]*h)] 16 | 17 | 18 | def center(r): return ((r[0] + r[2]) / 2, (r[1] + r[3]) / 2) 19 | 20 | 21 | def isIn(c, r): 22 | if c[0] < r[0] or c[0] > r[2]: 23 | return False 24 | elif c[1] < r[1] or c[1] > r[3]: 25 | return False 26 | else: 27 | return True 28 | 29 | 30 | def match_pred_w_gt(bbox_preds: torch.Tensor, bbox_gts: torch.Tensor, links_pair: list): 31 | bbox_iou = torchvision.ops.box_iou(boxes1=bbox_preds, boxes2=bbox_gts) 32 | bbox_iou = bbox_iou.numpy() 33 | 34 | A_ub = np.zeros(shape=( 35 | bbox_iou.shape[0] + bbox_iou.shape[1], bbox_iou.shape[0] * bbox_iou.shape[1])) 36 | for r in range(bbox_iou.shape[0]): 37 | st = r * bbox_iou.shape[1] 38 | A_ub[r, st:st + bbox_iou.shape[1]] = 1 39 | for j in range(bbox_iou.shape[1]): 40 | r = j + bbox_iou.shape[0] 41 | A_ub[r, j::bbox_iou.shape[1]] = 1 42 | b_ub = np.ones(shape=A_ub.shape[0]) 43 | 44 | assignaments_score = linprog( 45 | c=-bbox_iou.reshape(-1), A_ub=A_ub, b_ub=b_ub, bounds=(0, 1), method="highs-ds") 46 | #assignaments_score = linprog(c=-bbox_iou.reshape(-1), bounds=(0, 1), method="highs-ds") 47 | # print(assignaments_score) 48 | if not assignaments_score.success: 49 | print("Optimization FAILED") 50 | assignaments_score = assignaments_score.x.reshape(bbox_iou.shape) 51 | assignaments_ids = assignaments_score.argmax(axis=1) 52 | 53 | # 
matched 54 | opt_assignaments = {} 55 | for idx in range(assignaments_score.shape[0]): 56 | if (bbox_iou[idx, assignaments_ids[idx]] > 0.5) and (assignaments_score[idx, assignaments_ids[idx]] > 0.9): 57 | opt_assignaments[idx] = assignaments_ids[idx] 58 | # unmatched predictions 59 | false_positive = [idx for idx in range( 60 | bbox_preds.shape[0]) if idx not in opt_assignaments] 61 | # unmatched gts 62 | false_negative = [idx for idx in range( 63 | bbox_gts.shape[0]) if idx not in opt_assignaments.values()] 64 | 65 | gt2pred = {v: k for k, v in opt_assignaments.items()} 66 | link_false_neg = [] 67 | for link in links_pair: 68 | if link[0] in false_negative or link[1] in false_negative: 69 | link_false_neg.append(link) 70 | 71 | if len(links_pair) != 0: 72 | rate = len(link_false_neg) / len(links_pair) 73 | else: 74 | rate = 0 75 | return {"pred2gt": opt_assignaments, "gt2pred": gt2pred, "false_positive": false_positive, "false_negative": false_negative, "n_link_fn": int(len(link_false_neg) / 2), "link_loss": rate, "entity_loss": len(false_positive) / (len(false_positive) + len(opt_assignaments.keys()))} 76 | 77 | 78 | def get_objects(path, mode): 79 | # TODO given a document, apply OCR or Yolo to detect either words or entities. 80 | return 81 | 82 | 83 | def load_predictions(path_preds, path_gts, path_images, debug=False): 84 | # TODO read txt file and pass bounding box to the other function. 
85 | 86 | boxs_preds = [] 87 | boxs_gts = [] 88 | links_gts = [] 89 | labels_gts = [] 90 | texts_ocr = [] 91 | all_paths = [] 92 | 93 | for img in os.listdir(path_images): 94 | all_paths.append(os.path.join(path_images, img)) 95 | w, h = Image.open(os.path.join(path_images, img)).size 96 | texts = pytesseract.image_to_data(Image.open( 97 | os.path.join(path_images, img)), output_type=Output.DICT) 98 | tp = [] 99 | n_elements = len(texts['level']) 100 | for t in range(n_elements): 101 | if int(texts['conf'][t]) > 50 and texts['text'][t] != ' ': 102 | b = [texts['left'][t], texts['top'][t], texts['left'][t] + 103 | texts['width'][t], texts['top'][t] + texts['height'][t]] 104 | tp.append([b, texts['text'][t]]) 105 | texts_ocr.append(tp) 106 | preds_name = img.split(".")[0] + '.txt' 107 | with open(os.path.join(path_preds, preds_name), 'r') as preds: 108 | lines = preds.readlines() 109 | boxs = list() 110 | for line in lines: 111 | scaled = scale_back([float(c) 112 | for c in line[:-1].split(" ")[1:]], w, h) 113 | sw, sh = scaled[2] / 2, scaled[3] / 2 114 | boxs.append([scaled[0] - sw, scaled[1] - sh, 115 | scaled[0] + sw, scaled[1] + sh]) 116 | boxs_preds.append(boxs) 117 | 118 | gts_name = img.split(".")[0] + '.json' 119 | with open(os.path.join(path_gts, gts_name), 'r') as f: 120 | form = json.load(f)['form'] 121 | boxs = list() 122 | pair_labels = [] 123 | ids = [] 124 | labels = [] 125 | for elem in form: 126 | boxs.append([float(e) for e in elem['box']]) 127 | ids.append(elem['id']) 128 | labels.append(elem['label']) 129 | [pair_labels.append(pair) for pair in elem['linking']] 130 | 131 | for p, pair in enumerate(pair_labels): 132 | pair_labels[p] = [ids.index(pair[0]), ids.index(pair[1])] 133 | 134 | boxs_gts.append(boxs) 135 | links_gts.append(pair_labels) 136 | labels_gts.append(labels) 137 | 138 | all_links = [] 139 | all_preds = [] 140 | all_labels = [] 141 | all_texts = [] 142 | dropped_links = 0 143 | dropped_entity = 0 144 | 145 | for p in 
range(len(boxs_preds)): 146 | d = match_pred_w_gt(torch.tensor( 147 | boxs_preds[p]), torch.tensor(boxs_gts[p]), links_gts[p]) 148 | dropped_links += d['link_loss'] 149 | dropped_entity += d['entity_loss'] 150 | links = list() 151 | 152 | for link in links_gts[p]: 153 | if link[0] in d['false_negative'] or link[1] in d['false_negative']: 154 | continue 155 | else: 156 | links.append([d['gt2pred'][link[0]], d['gt2pred'][link[1]]]) 157 | all_links.append(links) 158 | 159 | preds = [] 160 | labels = [] 161 | texts = [] 162 | for b, box in enumerate(boxs_preds[p]): 163 | if b in d['false_positive']: 164 | preds.append(box) 165 | labels.append('other') 166 | else: 167 | gt_id = d['pred2gt'][b] 168 | preds.append(box) 169 | labels.append(labels_gts[p][gt_id]) 170 | 171 | text = '' 172 | for tocr in texts_ocr[p]: 173 | if isIn(center(tocr[0]), box): 174 | text += tocr[1] + ' ' 175 | 176 | texts.append(text) 177 | 178 | all_preds.append(preds) 179 | all_labels.append(labels) 180 | all_texts.append(texts) 181 | print(dropped_links / len(boxs_preds), dropped_entity / len(boxs_preds)) 182 | 183 | if debug: 184 | # random.seed(35) 185 | # rand_idx = random.randint(0, len(os.listdir(path_images))) 186 | print(all_texts[0]) 187 | rand_idx = 0 188 | img = Image.open(os.path.join(path_images, os.listdir( 189 | path_images)[rand_idx])).convert('RGB') 190 | draw = ImageDraw.Draw(img) 191 | 192 | rand_boxs_preds = boxs_preds[rand_idx] 193 | rand_boxs_gts = boxs_gts[rand_idx] 194 | 195 | for box in rand_boxs_gts: 196 | draw.rectangle(box, outline='blue', width=3) 197 | for box in rand_boxs_preds: 198 | draw.rectangle(box, outline='red', width=3) 199 | 200 | d = match_pred_w_gt(torch.tensor(rand_boxs_preds), 201 | torch.tensor(rand_boxs_gts), links_gts[rand_idx]) 202 | print(d) 203 | for idx in d['pred2gt'].keys(): 204 | draw.rectangle(rand_boxs_preds[idx], outline='green', width=3) 205 | 206 | link_true_pos = list() 207 | link_false_neg = list() 208 | for link in links_gts[rand_idx]: 
209 | if link[0] in d['false_negative'] or link[1] in d['false_negative']: 210 | link_false_neg.append(link) 211 | start = rand_boxs_gts[link[0]] 212 | end = rand_boxs_gts[link[1]] 213 | draw.line((center(start), center(end)), fill='red', width=3) 214 | else: 215 | link_true_pos.append(link) 216 | start = rand_boxs_preds[d['gt2pred'][link[0]]] 217 | end = rand_boxs_preds[d['gt2pred'][link[1]]] 218 | draw.line((center(start), center(end)), fill='green', width=3) 219 | 220 | precision = 0 221 | recall = 0 222 | for idx, gt in enumerate(boxs_gts): 223 | d = match_pred_w_gt(torch.tensor( 224 | boxs_preds[idx]), torch.tensor(gt), links_gts[rand_idx]) 225 | bbox_true_positive = len(d["pred2gt"]) 226 | p = bbox_true_positive / \ 227 | (bbox_true_positive + len(d["false_positive"])) 228 | r = bbox_true_positive / \ 229 | (bbox_true_positive + len(d["false_negative"])) 230 | # f1 += (2 * p * r) / (p + r) 231 | precision += p 232 | recall += r 233 | 234 | precision = precision / len(boxs_gts) 235 | recall = recall / len(boxs_gts) 236 | f1 = (2 * precision * recall) / (precision + recall) 237 | # print(f1, precision, recall) 238 | 239 | img.save('prova.png') 240 | 241 | return all_paths, all_preds, all_links, all_labels, all_texts 242 | 243 | 244 | def save_results(): 245 | # TODO output json of matching and check with visualization of images. 
246 | return 247 | 248 | 249 | if __name__ == "__main__": 250 | path_preds = DATA / 'FUNSD' / 'test_bbox' 251 | path_images = FUNSD_TEST / 'images' 252 | path_gts = FUNSD_TEST / 'adjusted_annotations' 253 | load_predictions(path_preds, path_gts, path_images, debug=True) 254 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from typing import Tuple 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import math 7 | 8 | def polar(rect_src : list, rect_dst : list) -> Tuple[int, int]: 9 | """Compute distance and angle from src to dst bounding boxes (poolar coordinates considering the src as the center) 10 | Args: 11 | rect_src (list) : source rectangle coordinates 12 | rect_dst (list) : destination rectangle coordinates 13 | 14 | Returns: 15 | tuple (ints): distance and angle 16 | """ 17 | 18 | # check relative position 19 | left = (rect_dst[2] - rect_src[0]) <= 0 20 | bottom = (rect_src[3] - rect_dst[1]) <= 0 21 | right = (rect_src[2] - rect_dst[0]) <= 0 22 | top = (rect_dst[3] - rect_src[1]) <= 0 23 | 24 | vp_intersect = (rect_src[0] <= rect_dst[2] and rect_dst[0] <= rect_src[2]) # True if two rects "see" each other vertically, above or under 25 | hp_intersect = (rect_src[1] <= rect_dst[3] and rect_dst[1] <= rect_src[3]) # True if two rects "see" each other horizontally, right or left 26 | rect_intersect = vp_intersect and hp_intersect 27 | 28 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2) 29 | 30 | # evaluate reciprocal position 31 | sc = center(rect_src) 32 | ec = center(rect_dst) 33 | new_ec = (ec[0] - sc[0], ec[1] - sc[1]) 34 | angle = int(math.degrees(math.atan2(new_ec[1], new_ec[0])) % 360) 35 | 36 | if rect_intersect: 37 | return 0, angle 38 | elif top and left: 39 | a, b = (rect_dst[2] - rect_src[0]), (rect_dst[3] - rect_src[1]) 40 | return int(sqrt(a**2 + b**2)), angle 
41 | elif left and bottom: 42 | a, b = (rect_dst[2] - rect_src[0]), (rect_dst[1] - rect_src[3]) 43 | return int(sqrt(a**2 + b**2)), angle 44 | elif bottom and right: 45 | a, b = (rect_dst[0] - rect_src[2]), (rect_dst[1] - rect_src[3]) 46 | return int(sqrt(a**2 + b**2)), angle 47 | elif right and top: 48 | a, b = (rect_dst[0] - rect_src[2]), (rect_dst[3] - rect_src[1]) 49 | return int(sqrt(a**2 + b**2)), angle 50 | elif left: 51 | return (rect_src[0] - rect_dst[2]), angle 52 | elif right: 53 | return (rect_dst[0] - rect_src[2]), angle 54 | elif bottom: 55 | return (rect_dst[1] - rect_src[3]), angle 56 | elif top: 57 | return (rect_src[1] - rect_dst[3]), angle 58 | 59 | def transform_image(img_path : str, scale_image=1.0) -> torch.Tensor: 60 | """ Transform image to torch.Tensor 61 | 62 | Args: 63 | img_path (str) : where the image is stored 64 | scale_image (float) : how much scale the image 65 | """ 66 | 67 | np_img = cv2.imread(img_path, cv2.IMREAD_COLOR) 68 | width = int(np_img.shape[1] * scale_image) 69 | height = int(np_img.shape[0] * scale_image) 70 | new_size = (width, height) 71 | np_img = cv2.resize(np_img,new_size) 72 | img = cv2.cvtColor(np_img, cv2.COLOR_BGR2GRAY) 73 | img = img[None,None,:,:] 74 | img = img.astype(np.float32) 75 | img = torch.from_numpy(img) 76 | img = 1.0 - img / 128.0 77 | 78 | return img 79 | 80 | def get_histogram(contents : list) -> list: 81 | """Create histogram of content given a text. 82 | 83 | Args; 84 | contents (list) 85 | 86 | Returns: 87 | list of [x, y, z] - 3-dimension list with float values summing up to 1 where: 88 | - x is the % of literals inside the text 89 | - y is the % of numbers inside the text 90 | - z is the % of other symbols i.e. @, #, .., inside the text 91 | """ 92 | 93 | c_histograms = list() 94 | 95 | for token in contents: 96 | num_symbols = 0 # all 97 | num_literals = 0 # A, B etc. 98 | num_figures = 0 # 1, 2, etc. 99 | num_others = 0 # !, @, etc. 
100 | 101 | histogram = [0.0000, 0.0000, 0.0000, 0.0000] 102 | 103 | for symbol in token.replace(" ", ""): 104 | if symbol.isalpha(): 105 | num_literals += 1 106 | elif symbol.isdigit(): 107 | num_figures += 1 108 | else: 109 | num_others += 1 110 | num_symbols += 1 111 | 112 | if num_symbols != 0: 113 | histogram[0] = num_literals / num_symbols 114 | histogram[1] = num_figures / num_symbols 115 | histogram[2] = num_others / num_symbols 116 | 117 | # keep sum 1 after truncate 118 | if sum(histogram) != 1.0: 119 | diff = 1.0 - sum(histogram) 120 | m = max(histogram) + diff 121 | histogram[histogram.index(max(histogram))] = m 122 | 123 | # if symbols not recognized at all or empty, sum everything at 1 in the last 124 | if histogram[0:3] == [0.0,0.0,0.0]: 125 | histogram[3] = 1.0 126 | 127 | c_histograms.append(histogram) 128 | 129 | return c_histograms 130 | 131 | def to_bin(dist :int, angle : int, b=8) -> torch.Tensor: 132 | """ Discretize the space into equal "bins": return a distance and angle into a number between 0 and 1. 
133 | 134 | Args: 135 | dist (int): distance in terms of pixel, given by "polar()" util function 136 | angle (int): angle between 0 and 360, given by "polar()" util function 137 | b (int): number of bins, MUST be power of 2 138 | 139 | Returns: 140 | torch.Tensor: new distance and angle (binary encoded) 141 | 142 | """ 143 | def isPowerOfTwo(x): 144 | return (x and (not(x & (x - 1))) ) 145 | 146 | # dist 147 | assert isPowerOfTwo(b) 148 | m = max(dist) / b 149 | new_dist = [] 150 | for d in dist: 151 | bin = int(d / m) 152 | if bin >= b: bin = b - 1 153 | bin = [int(x) for x in list('{0:0b}'.format(bin))] 154 | while len(bin) < sqrt(b): bin.insert(0, 0) 155 | new_dist.append(bin) 156 | 157 | # angle 158 | amplitude = 360 / b 159 | new_angle = [] 160 | for a in angle: 161 | bin = (a - amplitude / 2) 162 | bin = int(bin / amplitude) 163 | bin = [int(x) for x in list('{0:0b}'.format(bin))] 164 | while len(bin) < sqrt(b): bin.insert(0, 0) 165 | new_angle.append(bin) 166 | 167 | return torch.cat([torch.tensor(new_dist, dtype=torch.float32), torch.tensor(new_angle, dtype=torch.float32)], dim=1) 168 | 169 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import os 4 | from PIL import Image, ImageDraw 5 | import json 6 | 7 | from src.data.feature_builder import FeatureBuilder 8 | from src.data.graph_builder import GraphBuilder 9 | from src.data.preprocessing import center 10 | from src.models.graphs import SetModel 11 | from src.paths import CHECKPOINTS, INFERENCE 12 | from src.training.utils import get_device 13 | from src.utils import create_folder 14 | 15 | pretrain = { 16 | 'funsd': {'node_num_classes': 4, 'edge_num_classes': 2}, 17 | 'pau': {'node_num_classes': 5, 'edge_num_classes': 2} 18 | } 19 | 20 | def inference(weights, paths, device=-1): 21 | # create a graph per each 
file 22 | print("Doc2Graph Inference:") 23 | device = get_device(device) 24 | 25 | print("-> Creating graphs ...") 26 | gb = GraphBuilder() 27 | graphs, _, _, features = gb.get_graph(paths, 'CUSTOM') 28 | 29 | # add embedded visual, text, layout etc. features to the graphs 30 | print("-> Creating features ...") 31 | fb = FeatureBuilder(d=device) 32 | chunks, _ = fb.add_features(graphs, features) 33 | 34 | # create the model 35 | print("-> Creating model ...") 36 | model = weights[0].split("-")[0] 37 | pre = weights[0].split("-")[1] 38 | sm = SetModel(name=model, device=device) 39 | info = pretrain[pre] 40 | model = sm.get_model(info['node_num_classes'], info['edge_num_classes'], chunks, False) 41 | model.load_state_dict(torch.load(CHECKPOINTS / weights[0])) 42 | model.eval() 43 | 44 | # predict on graphs 45 | print("Predicting:") 46 | with torch.no_grad(): 47 | for num, graph in enumerate(graphs): 48 | _, name = os.path.split(paths[num]) 49 | name = name.split(".")[0] 50 | print(f" -> {name}") 51 | 52 | n, e = model(graph.to(device), graph.ndata['feat'].to(device)) 53 | _, epreds = torch.max(F.softmax(e, dim=1), dim=1) 54 | _, npreds = torch.max(F.softmax(n, dim=1), dim=1) 55 | 56 | # save results 57 | links = (epreds == 1).nonzero(as_tuple=True)[0].tolist() 58 | u, v = graph.edges() 59 | entities = features['boxs'][num] 60 | contents = features['texts'][num] 61 | 62 | graph_img = Image.open(paths[num]).convert('RGB') 63 | graph_draw = ImageDraw.Draw(graph_img) 64 | 65 | result = [] 66 | for i, idx in enumerate(links): 67 | pair = {'key': {'text': contents[u[idx]], 'box': entities[u[idx]]}, 68 | 'value': {'text': contents[v[idx]], 'box': entities[v[idx]]}} 69 | result.append(pair) 70 | 71 | key_center = center(entities[u[idx]]) 72 | value_center = center(entities[v[idx]]) 73 | graph_draw.line((key_center, value_center), fill='violet', width=3) 74 | graph_draw.ellipse([(key_center[0]-4,key_center[1]-4), (key_center[0]+4,key_center[1]+4)], fill = 'green', 
outline='black') 75 | graph_draw.ellipse([(value_center[0]-4,value_center[1]-4), (value_center[0]+4,value_center[1]+4)], fill = 'red', outline='black') 76 | 77 | graph_img.save(INFERENCE / f'{name}.png') 78 | with open(INFERENCE / f'{name}.json', "w") as outfile: 79 | json.dump(result, outfile) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from src.data.download import get_data 4 | from src.inference import inference 5 | from src.training.funsd import train_funsd 6 | from src.utils import create_folder, project_tree, set_preprocessing 7 | from src.training.pau import train_pau 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description='Training') 11 | 12 | # init 13 | parser.add_argument('--init', action="store_true", 14 | help="download data and prepare folders") 15 | 16 | # features 17 | parser.add_argument('--add-geom', '-addG', action="store_true", 18 | help="add geometrical features to nodes") 19 | parser.add_argument('--add-embs', '-addT', action="store_true", 20 | help="add textual embeddings to nodes") 21 | parser.add_argument('--add-hist', '-addH', action="store_true", 22 | help="add histogram of contents to nodes") 23 | parser.add_argument('--add-visual', '-addV', action="store_true", 24 | help="add visual features to nodes") 25 | parser.add_argument('--add-eweights', '-addE', action="store_true", 26 | help="add edge features to graphs") 27 | # data 28 | parser.add_argument("--src-data", type=str, default='FUNSD', 29 | help="which data source to use. It can be FUNSD, PAU or CUSTOM") 30 | parser.add_argument("--data-type", type=str, default='img', 31 | help="if src-data is CUSTOM, define the data source type: img or pdf.") 32 | # graphs 33 | parser.add_argument("--edge-type", type=str, default='fully', 34 | help="choose the kind of connectivity in the graph. 
It can be: fully or knn.") 35 | parser.add_argument("--node-granularity", type=str, default='gt', 36 | help="choose the granularity of nodes to be used. It can be: gt (if given), ocr (words) or yolo (entities).") 37 | parser.add_argument("--num-polar-bins", type=int, default=8, 38 | help="number of bins into which discretize the space for edge polar features. It must be a power of 2: Default 8.") 39 | 40 | # training 41 | parser.add_argument("--model", type=str, default='e2e', 42 | help="which model to use, which yaml file to load: e2e, edge or gcn") 43 | parser.add_argument("--gpu", type=int, default=-1, 44 | help="which GPU to use. Set -1 to use CPU.") 45 | parser.add_argument('--test', action="store_true", 46 | help="skip training") 47 | parser.add_argument('--weights', '-w', nargs='+', type=str, default=None, 48 | help="provide a weights file relative path if testing") 49 | 50 | # inference 51 | parser.add_argument('--inference', action="store_true", 52 | help="use the model to predict on new, unseen examples") 53 | parser.add_argument('--docs', nargs='+', type=str, default=None, 54 | help="provide documents to do inference on them") 55 | 56 | args = parser.parse_args() 57 | print(args) 58 | 59 | if args.init: 60 | project_tree() 61 | get_data() 62 | print("Initialization completed!") 63 | 64 | else: 65 | set_preprocessing(args) 66 | if args.inference: 67 | create_folder('inference') 68 | inference(args.weights, args.docs, args.gpu) 69 | elif args.src_data == 'FUNSD': 70 | if args.test and args.weights == None: 71 | raise Exception("Main exception: Provide a weights file relative path! Or train a model first.") 72 | train_funsd(args) 73 | elif args.src_data == 'PAU': 74 | train_pau(args) 75 | elif args.src_data == 'CUSTOM': 76 | #TODO develop custom data preprocessing 77 | raise Exception('Main exception: "CUSTOM" source data still under development') 78 | else: 79 | raise Exception('Main exception: source data invalid. 
Choose from ["FUNSD", "PAU", "CUSTOM"]') 80 | 81 | return 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/graphs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import dgl.function as fn 4 | import math 5 | import torch.nn.functional as F 6 | 7 | from src.paths import CFGM 8 | from src.utils import get_config 9 | 10 | class SetModel(): 11 | def __init__(self, name='e2e', device = 'cpu'): 12 | """ Create a SetModel object, that handles dinamically different version of Doc2Graph Model. Default "end-to-end" (e2e) 13 | 14 | Args: 15 | name (str) : Which model to train / test. Default: e2e [gcn, edge]. 16 | 17 | Returns: 18 | SetModel object. 19 | """ 20 | 21 | self.cfg_model = get_config(CFGM / name) 22 | self.name = self.cfg_model.name 23 | self.total_params = 0 24 | self.device = device 25 | 26 | def get_name(self) -> str: 27 | """ Returns model name. 28 | """ 29 | return self.name 30 | 31 | def get_total_params(self) -> int: 32 | """ Returns number of model parameteres. 33 | """ 34 | return self.total_params 35 | 36 | def get_model(self, nodes : int, edges : int, chunks : list, verbatim : bool = True) -> nn.Module: 37 | """Return the DGL model defined in the setting file 38 | 39 | Args: 40 | nodes (int) : number of nodes target class 41 | edges (int) : number of edges target class 42 | chunks (list) : list of indeces of chunks 43 | 44 | Returns: 45 | A PyTorch nn.Module, your DGL model. 
46 | """ 47 | print("\n### MODEL ###") 48 | print(f"-> Using {self.name}") 49 | 50 | if self.name == 'GCN': 51 | m = NodeClassifier(chunks, self.cfg_model.out_chunks, nodes, self.cfg_model.num_layers, F.relu, False, self.device) 52 | 53 | elif self.name == 'EDGE': 54 | m = EdgeClassifier(edges, self.cfg_model.num_layers, self.cfg_model.dropout, chunks, self.cfg_model.out_chunks, self.cfg_model.hidden_dim, self.device, self.cfg_model.doProject) 55 | 56 | elif self.name == 'E2E': 57 | edge_pred_features = int((math.log2(get_config('preprocessing').FEATURES.num_polar_bins) + nodes)*2) 58 | m = E2E(nodes, edges, self.cfg_model.num_layers, self.cfg_model.dropout, chunks, self.cfg_model.out_chunks, self.cfg_model.hidden_dim, self.device, edge_pred_features, self.cfg_model.doProject) 59 | 60 | else: 61 | raise Exception(f"Error! Model {self.name} do not exists.") 62 | 63 | m.to(self.device) 64 | self.total_params = sum(p.numel() for p in m.parameters() if p.requires_grad) 65 | print(f"-> Total params: {self.total_params}") 66 | print("-> Device: " + str(next(m.parameters()).is_cuda) + "\n") 67 | if verbatim: print(m) 68 | 69 | return m 70 | 71 | ################ 72 | ##### GCNS ##### 73 | 74 | class NodeClassifier(nn.Module): 75 | def __init__(self, 76 | in_chunks, 77 | out_chunks, 78 | n_classes, 79 | n_layers, 80 | activation, 81 | dropout=0, 82 | use_pp=False, 83 | device='cuda:0'): 84 | super(NodeClassifier, self).__init__() 85 | 86 | self.projector = InputProjector(in_chunks, out_chunks, device) 87 | self.layers = nn.ModuleList() 88 | # self.dropout = nn.Dropout(dropout) 89 | self.n_layers = n_layers 90 | 91 | n_hidden = self.projector.get_out_lenght() 92 | 93 | # mp layers 94 | for i in range(0, n_layers - 1): 95 | self.layers.append(GcnSAGELayer(n_hidden, n_hidden, activation=activation, 96 | dropout=dropout, use_pp=False, use_lynorm=True)) 97 | 98 | self.layers.append(GcnSAGELayer(n_hidden, n_classes, activation=None, 99 | dropout=False, use_pp=False, 
use_lynorm=False)) 100 | 101 | def forward(self, g, h): 102 | 103 | h = self.projector(h) 104 | 105 | for l in range(self.n_layers): 106 | h = self.layers[l](g, h) 107 | 108 | return h 109 | 110 | ################ 111 | ##### EDGE ##### 112 | 113 | class EdgeClassifier(nn.Module): 114 | 115 | def __init__(self, edge_classes, m_layers, dropout, in_chunks, out_chunks, hidden_dim, device, doProject=True): 116 | super().__init__() 117 | 118 | # Project inputs into higher space 119 | self.projector = InputProjector(in_chunks, out_chunks, device, doProject) 120 | 121 | # Perform message passing 122 | m_hidden = self.projector.get_out_lenght() 123 | self.message_passing = nn.ModuleList() 124 | self.m_layers = m_layers 125 | for l in range(m_layers): 126 | self.message_passing.append(GcnSAGELayer(m_hidden, m_hidden, F.relu, 0.)) 127 | 128 | # Define edge predictori layer 129 | self.edge_pred = MLPPredictor(m_hidden, hidden_dim, edge_classes, dropout) 130 | 131 | def forward(self, g, h): 132 | 133 | h = self.projector(h) 134 | 135 | for l in range(self.m_layers): 136 | h = self.message_passing[l](g, h) 137 | 138 | e = self.edge_pred(g, h) 139 | 140 | return e 141 | 142 | ################ 143 | ###### E2E ##### 144 | 145 | class E2E(nn.Module): 146 | def __init__(self, node_classes, 147 | edge_classes, 148 | m_layers, 149 | dropout, 150 | in_chunks, 151 | out_chunks, 152 | hidden_dim, 153 | device, 154 | edge_pred_features, 155 | doProject=True): 156 | 157 | super().__init__() 158 | 159 | # Project inputs into higher space 160 | self.projector = InputProjector(in_chunks, out_chunks, device, doProject) 161 | 162 | # Perform message passing 163 | m_hidden = self.projector.get_out_lenght() 164 | self.message_passing = nn.ModuleList() 165 | # self.m_layers = m_layers 166 | # for l in range(m_layers): 167 | # self.message_passing.append(GcnSAGELayer(m_hidden, m_hidden, F.relu, 0.)) 168 | self.message_passing = GcnSAGELayer(m_hidden, m_hidden, F.relu, 0.) 
169 | 170 | # Define edge predictor layer 171 | self.edge_pred = MLPPredictor_E2E(m_hidden, hidden_dim, edge_classes, dropout, edge_pred_features) 172 | 173 | # Define node predictor layer 174 | node_pred = [] 175 | node_pred.append(nn.Linear(m_hidden, node_classes)) 176 | node_pred.append(nn.LayerNorm(node_classes)) 177 | self.node_pred = nn.Sequential(*node_pred) 178 | 179 | def forward(self, g, h): 180 | 181 | h = self.projector(h) 182 | # for l in range(self.m_layers): 183 | # h = self.message_passing[l](g, h) 184 | h = self.message_passing(g,h) 185 | n = self.node_pred(h) 186 | e = self.edge_pred(g, h, n) 187 | 188 | return n, e 189 | 190 | ################ 191 | ##### LYRS ##### 192 | 193 | class GcnSAGELayer(nn.Module): 194 | def __init__(self, 195 | in_feats, 196 | out_feats, 197 | activation, 198 | dropout, 199 | bias=True, 200 | use_pp=False, 201 | use_lynorm=True): 202 | super(GcnSAGELayer, self).__init__() 203 | self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias) 204 | self.activation = activation 205 | self.use_pp = use_pp 206 | if dropout: 207 | self.dropout = nn.Dropout(p=dropout) 208 | else: 209 | self.dropout = 0. 210 | if use_lynorm: 211 | self.lynorm = nn.LayerNorm(out_feats, elementwise_affine=True) 212 | else: 213 | self.lynorm = lambda x: x 214 | self.reset_parameters() 215 | 216 | def reset_parameters(self): 217 | stdv = 1. 
/ math.sqrt(self.linear.weight.size(1)) 218 | self.linear.weight.data.uniform_(-stdv, stdv) 219 | if self.linear.bias is not None: 220 | self.linear.bias.data.uniform_(-stdv, stdv) 221 | 222 | def forward(self, g, h): 223 | g = g.local_var() 224 | 225 | if not self.use_pp: 226 | # norm = self.get_norm(g) 227 | norm = g.ndata['norm'] 228 | g.ndata['h'] = h 229 | g.update_all(fn.u_mul_e('h', 'weights', 'm'), 230 | fn.sum(msg='m', out='h')) 231 | ah = g.ndata.pop('h') 232 | h = self.concat(h, ah, norm) 233 | 234 | if self.dropout: 235 | h = self.dropout(h) 236 | 237 | h = self.linear(h) 238 | h = self.lynorm(h) 239 | if self.activation: 240 | h = self.activation(h) 241 | return h 242 | 243 | def concat(self, h, ah, norm): 244 | ah = ah * norm 245 | h = torch.cat((h, ah), dim=1) 246 | return h 247 | 248 | def get_norm(self, g): 249 | norm = 1. / g.in_degrees().float().unsqueeze(1) 250 | norm[torch.isinf(norm)] = 0 251 | norm = norm.to(self.linear.weight.device) 252 | return norm 253 | 254 | class InputProjector(nn.Module): 255 | def __init__(self, in_chunks : list, out_chunks : int, device, doIt = True) -> None: 256 | super().__init__() 257 | 258 | if not doIt: 259 | self.output_length = sum(in_chunks) 260 | self.doIt = doIt 261 | return 262 | 263 | self.output_length = len(in_chunks)*out_chunks 264 | self.doIt = doIt 265 | self.chunks = in_chunks 266 | modules = [] 267 | self.device = device 268 | 269 | for chunk in in_chunks: 270 | chunk_module = [] 271 | chunk_module.append(nn.Linear(chunk, out_chunks)) 272 | chunk_module.append(nn.LayerNorm(out_chunks)) 273 | chunk_module.append(nn.ReLU()) 274 | modules.append(nn.Sequential(*chunk_module)) 275 | 276 | self.modalities = nn.Sequential(*modules) 277 | self.chunks.insert(0, 0) 278 | 279 | def get_out_lenght(self): 280 | return self.output_length 281 | 282 | def forward(self, h): 283 | 284 | if not self.doIt: 285 | return h 286 | 287 | mid = [] 288 | 289 | for name, module in self.modalities.named_children(): 290 | num 
= int(name) 291 | if num + 1 == len(self.chunks): break 292 | start = self.chunks[num] + sum(self.chunks[:num]) 293 | end = start + self.chunks[num+1] 294 | input = h[:, start:end].to(self.device) 295 | mid.append(module(input)) 296 | 297 | return torch.cat(mid, dim=1) 298 | 299 | class MLPPredictor(nn.Module): 300 | def __init__(self, in_features, hidden_dim, out_classes, dropout): 301 | super().__init__() 302 | self.out = out_classes 303 | self.W1 = nn.Linear(in_features*2, hidden_dim) 304 | self.norm = nn.LayerNorm(hidden_dim) 305 | self.W2 = nn.Linear(hidden_dim + 6, out_classes) 306 | self.drop = nn.Dropout(dropout) 307 | 308 | def apply_edges(self, edges): 309 | h_u = edges.src['h'] 310 | h_v = edges.dst['h'] 311 | polar = edges.data['feat'] 312 | 313 | x = F.relu(self.norm(self.W1(torch.cat((h_u, h_v), dim=1)))) 314 | x = torch.cat((x, polar), dim=1) 315 | score = self.drop(self.W2(x)) 316 | 317 | return {'score': score} 318 | 319 | def forward(self, graph, h): 320 | # h contains the node representations computed from the GNN defined 321 | # in the node classification section (Section 5.1). 
322 | with graph.local_scope(): 323 | graph.ndata['h'] = h 324 | graph.apply_edges(self.apply_edges) 325 | return graph.edata['score'] 326 | 327 | class MLPPredictor_E2E(nn.Module): 328 | def __init__(self, in_features, hidden_dim, out_classes, dropout, edge_pred_features): 329 | super().__init__() 330 | self.out = out_classes 331 | self.W1 = nn.Linear(in_features*2 + edge_pred_features, hidden_dim) 332 | self.norm = nn.LayerNorm(hidden_dim) 333 | self.W2 = nn.Linear(hidden_dim, out_classes) 334 | self.drop = nn.Dropout(dropout) 335 | 336 | def apply_edges(self, edges): 337 | h_u = edges.src['h'] 338 | h_v = edges.dst['h'] 339 | cls_u = F.softmax(edges.src['cls'], dim=1) 340 | cls_v = F.softmax(edges.dst['cls'], dim=1) 341 | polar = edges.data['feat'] 342 | 343 | x = F.relu(self.norm(self.W1(torch.cat((h_u, cls_u, polar, h_v, cls_v), dim=1)))) 344 | score = self.drop(self.W2(x)) 345 | 346 | return {'score': score} 347 | 348 | def forward(self, graph, h, cls): 349 | # h contains the node representations computed from the GNN defined 350 | # in the node classification section (Section 5.1). 
351 | with graph.local_scope(): 352 | graph.ndata['h'] = h 353 | graph.ndata['cls'] = cls 354 | graph.apply_edges(self.apply_edges) 355 | return graph.edata['score'] 356 | -------------------------------------------------------------------------------- /src/models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Unet -------------------------------------------------------------------------------- /src/models/unet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from segmentation_models_pytorch.base import modules as md 6 | 7 | 8 | class DecoderBlock(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | skip_channels, 13 | out_channels, 14 | use_batchnorm=True, 15 | attention_type=None, 16 | ): 17 | super().__init__() 18 | self.conv1 = md.Conv2dReLU( 19 | in_channels + skip_channels, 20 | out_channels, 21 | kernel_size=3, 22 | padding=1, 23 | use_batchnorm=use_batchnorm, 24 | ) 25 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) 26 | self.conv2 = md.Conv2dReLU( 27 | out_channels, 28 | out_channels, 29 | kernel_size=3, 30 | padding=1, 31 | use_batchnorm=use_batchnorm, 32 | ) 33 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 34 | 35 | def forward(self, x, skip=None): 36 | if skip is not None: 37 | x = F.interpolate(x, size=skip.shape[2:], mode="nearest") 38 | x = torch.cat([x, skip], dim=1) 39 | x = self.attention1(x) 40 | else: 41 | x = F.interpolate(x, scale_factor=2, mode="nearest") 42 | x = self.conv1(x) 43 | x = self.conv2(x) 44 | x = self.attention2(x) 45 | return x 46 | 47 | 48 | class CenterBlock(nn.Sequential): 49 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 50 | conv1 = md.Conv2dReLU( 51 | in_channels, 52 | out_channels, 53 | kernel_size=3, 54 | padding=1, 
55 | use_batchnorm=use_batchnorm, 56 | ) 57 | conv2 = md.Conv2dReLU( 58 | out_channels, 59 | out_channels, 60 | kernel_size=3, 61 | padding=1, 62 | use_batchnorm=use_batchnorm, 63 | ) 64 | super().__init__(conv1, conv2) 65 | 66 | 67 | class UnetDecoder(nn.Module): 68 | def __init__( 69 | self, 70 | encoder_channels, 71 | decoder_channels, 72 | n_blocks=5, 73 | use_batchnorm=True, 74 | attention_type=None, 75 | center=False, 76 | ): 77 | super().__init__() 78 | 79 | if n_blocks != len(decoder_channels): 80 | raise ValueError( 81 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 82 | n_blocks, len(decoder_channels) 83 | ) 84 | ) 85 | 86 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 87 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 88 | 89 | # computing blocks input and output channels 90 | head_channels = encoder_channels[0] 91 | in_channels = [head_channels] + list(decoder_channels[:-1]) 92 | skip_channels = list(encoder_channels[1:]) + [0] 93 | out_channels = decoder_channels 94 | 95 | if center: 96 | self.center = CenterBlock( 97 | head_channels, head_channels, use_batchnorm=use_batchnorm 98 | ) 99 | else: 100 | self.center = nn.Identity() 101 | 102 | # combine decoder keyword arguments 103 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 104 | blocks = [ 105 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 106 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 107 | ] 108 | self.blocks = nn.ModuleList(blocks) 109 | 110 | def forward(self, *features): 111 | 112 | features = features[1:] # remove first skip with same spatial resolution 113 | features = features[::-1] # reverse channels to start from head of encoder 114 | 115 | head = features[0] 116 | skips = features[1:] 117 | 118 | x = self.center(head) 119 | for i, decoder_block in enumerate(self.blocks): 120 | skip = skips[i] if i < 
len(skips) else None 121 | x = decoder_block(x, skip) 122 | 123 | return x 124 | -------------------------------------------------------------------------------- /src/models/unet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | from .decoder import UnetDecoder 3 | from segmentation_models_pytorch.encoders import get_encoder 4 | from segmentation_models_pytorch.base import SegmentationModel 5 | from segmentation_models_pytorch.base import SegmentationHead, ClassificationHead 6 | 7 | 8 | class Unet(SegmentationModel): 9 | """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* 10 | and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial 11 | resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* 12 | for fusing decoder blocks with skip connections. 13 | 14 | Args: 15 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 16 | to extract features of different spatial resolution 17 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 18 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 19 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 20 | Default is 5 21 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 22 | other pretrained weights (see table with available weights for each encoder_name) 23 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 24 | Length of the list should be the same as **encoder_depth** 25 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 26 | is used. 
If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 27 | Available options are **True, False, "inplace"** 28 | decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse**. 29 | SCSE paper - https://arxiv.org/abs/1808.08127 30 | in_channels: A number of input channels for the model, default is 3 (RGB images) 31 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 32 | activation: An activation function to apply after the final convolution layer. 33 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 34 | Default is **None** 35 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 36 | on top of encoder if **aux_params** is not **None** (default). Supported params: 37 | - classes (int): A number of classes 38 | - pooling (str): One of "max", "avg". Default is "avg" 39 | - dropout (float): Dropout factor in [0, 1) 40 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 41 | 42 | Returns: 43 | ``torch.nn.Module``: Unet 44 | 45 | .. 
_Unet: 46 | https://arxiv.org/abs/1505.04597 47 | 48 | """ 49 | 50 | def __init__( 51 | self, 52 | encoder_name: str = "resnet34", 53 | encoder_depth: int = 5, 54 | encoder_weights: Optional[str] = "imagenet", 55 | decoder_use_batchnorm: bool = True, 56 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 57 | decoder_attention_type: Optional[str] = None, 58 | in_channels: int = 3, 59 | classes: int = 1, 60 | activation: Optional[Union[str, callable]] = None, 61 | aux_params: Optional[dict] = None, 62 | ): 63 | super().__init__() 64 | 65 | self.encoder = get_encoder( 66 | encoder_name, 67 | in_channels=in_channels, 68 | depth=encoder_depth, 69 | weights=encoder_weights, 70 | ) 71 | 72 | self.decoder = UnetDecoder( 73 | encoder_channels=self.encoder.out_channels, 74 | decoder_channels=decoder_channels, 75 | n_blocks=encoder_depth, 76 | use_batchnorm=decoder_use_batchnorm, 77 | center=True if encoder_name.startswith("vgg") else False, 78 | attention_type=decoder_attention_type, 79 | ) 80 | 81 | self.segmentation_head = SegmentationHead( 82 | in_channels=decoder_channels[-1], 83 | out_channels=classes, 84 | activation=activation, 85 | kernel_size=3, 86 | ) 87 | 88 | if aux_params is not None: 89 | self.classification_head = ClassificationHead( 90 | in_channels=self.encoder.out_channels[-1], **aux_params 91 | ) 92 | else: 93 | self.classification_head = None 94 | 95 | self.name = "u-{}".format(encoder_name) 96 | self.initialize() 97 | -------------------------------------------------------------------------------- /src/paths.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from dotenv import dotenv_values 3 | import os 4 | 5 | # ROOT 6 | HERE = Path(os.path.dirname(os.path.abspath(__file__))) 7 | config = dotenv_values(HERE / "root.env") 8 | ROOT = Path(config['ROOT']) 9 | 10 | # PROJECT TREE 11 | DATA = ROOT / 'DATA' 12 | CONFIGS = ROOT / 'configs' 13 | CFGM = CONFIGS / 'models' 14 | OUTPUTS = 
ROOT / 'outputs' 15 | RUNS = OUTPUTS / 'runs' 16 | RESULTS = OUTPUTS / 'results' 17 | IMGS = OUTPUTS / 'images' 18 | TRAIN_SAMPLES = OUTPUTS / 'train_samples' 19 | TEST_SAMPLES = OUTPUTS / 'test_samples' 20 | TRAINING = ROOT / 'src' / 'training' 21 | MODELS = ROOT / 'src' / 'models' 22 | CHECKPOINTS = MODELS / 'checkpoints' 23 | INFERENCE = ROOT / 'inference' 24 | 25 | # FUNSD 26 | FUNSD_TRAIN = DATA / 'FUNSD' / 'training_data' 27 | FUNSD_TEST = DATA / 'FUNSD' / 'testing_data' 28 | 29 | # PAU 30 | PAU_TRAIN = DATA / 'PAU' / 'train' 31 | PAU_TEST = DATA / 'PAU' / 'test' 32 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/funsd.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from sklearn.model_selection import KFold, ShuffleSplit 3 | import torch 4 | from torch.nn import functional as F 5 | from random import shuffle, seed 6 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau 7 | import dgl 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torchvision import transforms 10 | import time 11 | from statistics import mean 12 | import numpy as np 13 | from PIL import Image 14 | 15 | from src.data.dataloader import Document2Graph 16 | from src.paths import * 17 | from src.models.graphs import SetModel 18 | from src.utils import get_config 19 | from src.training.utils import * 20 | from src.data.graph_builder import GraphBuilder 21 | 22 | def e2e(args): 23 | 24 | # configs 25 | start_training = time.time() 26 | cfg_train = get_config('train') 27 | seed(cfg_train.seed) 28 | device = get_device(args.gpu) 29 | sm = 
SetModel(name=args.model, device=device) 30 | 31 | if not args.test: 32 | ################* STEP 0: LOAD DATA ################ 33 | data = Document2Graph(name='FUNSD TRAIN', src_path=FUNSD_TRAIN, device = device, output_dir=TRAIN_SAMPLES) 34 | data.get_info() 35 | 36 | ss = KFold(n_splits=10, shuffle=True, random_state=cfg_train.seed) 37 | cv_indices = ss.split(data.graphs) 38 | 39 | models = [] 40 | train_index, val_index = next(ss.split(data.graphs)) 41 | 42 | for cvs in cv_indices: 43 | 44 | train_index, val_index = cvs 45 | 46 | # TRAIN 47 | train_graphs = [data.graphs[i] for i in train_index] 48 | tg = dgl.batch(train_graphs) 49 | tg = tg.int().to(device) 50 | 51 | val_graphs = [data.graphs[i] for i in val_index] 52 | vg = dgl.batch(val_graphs) 53 | vg = vg.int().to(device) 54 | 55 | ################* STEP 1: CREATE MODEL ################ 56 | model = sm.get_model(data.node_num_classes, data.edge_num_classes, data.get_chunks()) 57 | optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg_train.lr), weight_decay=float(cfg_train.weight_decay)) 58 | # scheduler = ReduceLROnPlateau(optimizer, 'max', patience=400, min_lr=1e-3, verbose=True, factor=0.01) 59 | # scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 60 | e = datetime.now() 61 | train_name = args.model + f'-{e.strftime("%Y%m%d-%H%M")}' 62 | models.append(train_name+'.pt') 63 | stopper = EarlyStopping(model, name=train_name, metric=cfg_train.stopper_metric, patience=2000) 64 | # writer = SummaryWriter(log_dir=RUNS) 65 | # convert_imgs = transforms.ToTensor() 66 | 67 | ################* STEP 2: TRAINING ################ 68 | print("\n### TRAINING ###") 69 | print(f"-> Training samples: {tg.batch_size}") 70 | print(f"-> Validation samples: {vg.batch_size}\n") 71 | 72 | # im_step = 0 73 | for epoch in range(cfg_train.epochs): 74 | 75 | #* TRAINING 76 | model.train() 77 | 78 | n_scores, e_scores = model(tg, tg.ndata['feat'].to(device)) 79 | n_loss = compute_crossentropy_loss(n_scores.to(device), 
tg.ndata['label'].to(device)) 80 | e_loss = compute_crossentropy_loss(e_scores.to(device), tg.edata['label'].to(device)) 81 | tot_loss = n_loss + e_loss 82 | macro, micro = get_f1(n_scores, tg.ndata['label'].to(device)) 83 | auc = compute_auc_mc(e_scores.to(device), tg.edata['label'].to(device)) 84 | 85 | 86 | optimizer.zero_grad() 87 | tot_loss.backward() 88 | n = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) 89 | optimizer.step() 90 | 91 | #* VALIDATION 92 | model.eval() 93 | with torch.no_grad(): 94 | val_n_scores, val_e_scores = model(vg, vg.ndata['feat'].to(device)) 95 | val_n_loss = compute_crossentropy_loss(val_n_scores.to(device), vg.ndata['label'].to(device)) 96 | val_e_loss = compute_crossentropy_loss(val_e_scores.to(device), vg.edata['label'].to(device)) 97 | val_tot_loss = val_n_loss + val_e_loss 98 | val_macro, _ = get_f1(val_n_scores, vg.ndata['label'].to(device)) 99 | val_auc = compute_auc_mc(val_e_scores.to(device), vg.edata['label'].to(device)) 100 | 101 | # scheduler.step(val_auc) 102 | # scheduler.step() 103 | 104 | #* PRINTING IMAGEs AND RESULTS 105 | 106 | print("Epoch {:05d} | TrainLoss {:.4f} | TrainF1-MACRO {:.4f} | TrainAUC-PR {:.4f} | ValLoss {:.4f} | ValF1-MACRO {:.4f} | ValAUC-PR {:.4f} |" 107 | .format(epoch, tot_loss.item(), macro, auc, val_tot_loss.item(), val_macro, val_auc)) 108 | 109 | if cfg_train.stopper_metric == 'loss': 110 | step_value = val_tot_loss.item() 111 | elif cfg_train.stopper_metric == 'acc': 112 | step_value = val_auc 113 | 114 | # if val_auc > best_val_auc: 115 | # best_val_auc = val_auc 116 | # best_model = train_name 117 | 118 | ss = stopper.step(step_value) 119 | 120 | # if ss == 'improved': 121 | # im_step = epoch 122 | # train_imgs = [] 123 | # for r in rand_tid: 124 | # start, end = 0, 0 125 | # for tid in train_index: 126 | # start = end 127 | # end += data.graphs[tid].num_edges() 128 | # if tid == r: break 129 | 130 | # _, targets = torch.max(F.log_softmax(e_scores[start:end], dim=1), 
dim=1) 131 | # kvp_ids = targets.nonzero().flatten().tolist() 132 | # train_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=kvp_ids, name=f'train_{r}'))[:, :, :700]) 133 | # # data.print_graph(num=r, name=f'train_labels_{r}') 134 | 135 | # val_imgs = [] 136 | # for r in rand_vid: 137 | # v_start, v_end = 0, 0 138 | # for vid in val_index: 139 | # v_start = v_end 140 | # v_end += data.graphs[vid].num_edges() 141 | # if vid == r: break 142 | 143 | # _, val_targets = torch.max(F.log_softmax(val_e_scores[v_start:v_end], dim=1), dim=1) 144 | # val_kvp_ids = val_targets.nonzero().flatten().tolist() 145 | # val_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=val_kvp_ids, name=f'val_{r}'))[:, :, :700]) 146 | # data.print_graph(num=r, name=f'val_labels_{r}') 147 | 148 | if ss == 'stop': 149 | break 150 | 151 | # writer.add_scalars('AUC-PR', {'train': auc, 'val': val_auc}, epoch) 152 | # writer.add_scalars('LOSS', {'train': tot_loss.item(), 'val': val_tot_loss.item()}, epoch) 153 | # writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch) 154 | 155 | # train_grid = torchvision.utils.make_grid(train_imgs) 156 | # writer.add_image('train_images', train_grid, im_step) 157 | # val_grid = torchvision.utils.make_grid(val_imgs) 158 | # writer.add_image('val_images', val_grid, im_step) 159 | 160 | else: 161 | ################* SKIP TRAINING ################ 162 | print("\n### SKIP TRAINING ###") 163 | print(f"-> loading {args.weights}") 164 | models = args.weights 165 | 166 | ################* STEP 3: TESTING ################ 167 | print("\n### TESTING ###") 168 | 169 | #? 
test 170 | test_data = Document2Graph(name='FUNSD TEST', src_path=FUNSD_TEST, device = device, output_dir=TEST_SAMPLES) 171 | test_data.get_info() 172 | 173 | model = sm.get_model(test_data.node_num_classes, test_data.edge_num_classes, test_data.get_chunks()) 174 | best_model = '' 175 | nodes_micro = [] 176 | edges_f1 = [] 177 | test_graph = dgl.batch(test_data.graphs).to(device) 178 | 179 | for m in models: 180 | model.load_state_dict(torch.load(CHECKPOINTS / m)) 181 | model.eval() 182 | with torch.no_grad(): 183 | 184 | n, e = model(test_graph, test_graph.ndata['feat'].to(device)) 185 | auc = compute_auc_mc(e.to(device), test_graph.edata['label'].to(device)) 186 | _, preds = torch.max(F.softmax(e, dim=1), dim=1) 187 | 188 | accuracy, f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label']) 189 | _, classes_f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'], per_class=True) 190 | edges_f1.append(classes_f1[1]) 191 | 192 | macro, micro = get_f1(n, test_graph.ndata['label'].to(device)) 193 | nodes_micro.append(micro) 194 | if classes_f1[1] >= max(edges_f1): 195 | best_model = m 196 | 197 | test_graph.edata['preds'] = preds 198 | 199 | ################* STEP 4: RESULTS ################ 200 | print("\n### RESULTS {} ###".format(m)) 201 | print("F1 Edges: None {:.4f} - Pairs {:.4f}".format(classes_f1[0], classes_f1[1])) 202 | print("F1 Nodes: Macro {:.4f} - Micro {:.4f}".format(macro, micro)) 203 | 204 | print(f"\n -> Loading best model {best_model}") 205 | model.load_state_dict(torch.load(CHECKPOINTS / best_model)) 206 | model.eval() 207 | with torch.no_grad(): 208 | 209 | n, e = model(test_graph, test_graph.ndata['feat'].to(device)) 210 | auc = compute_auc_mc(e.to(device), test_graph.edata['label'].to(device)) 211 | 212 | _, epreds = torch.max(F.softmax(e, dim=1), dim=1) 213 | _, npreds = torch.max(F.softmax(n, dim=1), dim=1) 214 | test_graph.edata['preds'] = epreds 215 | test_graph.ndata['preds'] = npreds 216 | test_graph.ndata['net'] = n 217 
| 218 | accuracy, f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label']) 219 | _, classes_f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label'], per_class=True) 220 | macro, micro = get_f1(n, test_graph.ndata['label'].to(device)) 221 | 222 | # ################* STEP 4: RESULTS ################ 223 | print("\n### BEST RESULTS ###") 224 | print("AUC {:.4f}".format(auc)) 225 | print("Accuracy {:.4f}".format(accuracy)) 226 | print("F1 Edges: Macro {:.4f} - Micro {:.4f}".format(f1[0], f1[1])) 227 | print("F1 Edges: None {:.4f} - Pairs {:.4f}".format(classes_f1[0], classes_f1[1])) 228 | print("F1 Nodes: Macro {:.4f} - Micro {:.4f}".format(macro, micro)) 229 | 230 | print("\n### AVG RESULTS ###") 231 | print("Semantic Entity Labeling: MEAN ", mean(nodes_micro), " STD: ", np.std(nodes_micro)) 232 | print("Entity Linking: MEAN ", mean(edges_f1),"STD", np.std(edges_f1)) 233 | 234 | if not args.test: 235 | feat_n, feat_e = get_features(args) 236 | #? if skipping training, no need to save anything 237 | model = get_config(CFGM / args.model) 238 | results = {'MODEL': { 239 | 'name': sm.get_name(), 240 | 'weights': best_model, 241 | 'net-params': sm.get_total_params(), 242 | 'num-layers': model.num_layers, 243 | 'projector-output': model.out_chunks, 244 | 'dropout': model.dropout, 245 | 'lastFC': model.hidden_dim 246 | }, 247 | 'FEATURES': { 248 | 'nodes': feat_n, 249 | 'edges': feat_e 250 | }, 251 | 'PARAMS': { 252 | 'start-lr': cfg_train.lr, 253 | 'weight-decay': cfg_train.weight_decay, 254 | 'seed': cfg_train.seed 255 | }, 256 | 'RESULTS': { 257 | 'val-loss': stopper.best_score, 258 | 'f1-scores': f1, 259 | 'f1-classes': classes_f1, 260 | 'nodes-f1': [macro, micro], 261 | 'std-pairs': np.std(edges_f1), 262 | 'mean-pairs': mean(edges_f1) 263 | }} 264 | save_test_results(train_name, results) 265 | 266 | print("END TRAINING:", time.time() - start_training) 267 | return {'LINKS [MAX, MEAN, STD]': [classes_f1[1], mean(edges_f1), np.std(edges_f1)], 'NODES 
[MAX, MEAN, STD]': [micro, mean(nodes_micro), np.std(nodes_micro)]} 268 | 269 | def entity_linking(args): 270 | 271 | # configs 272 | start_training = time.time() 273 | cfg_train = get_config('train') 274 | seed(cfg_train.seed) 275 | device = get_device(args.gpu) 276 | sm = SetModel(name=args.model, device=device) 277 | 278 | if not args.test: 279 | ################* STEP 0: LOAD DATA ################ 280 | data = Document2Graph(name='FUNSD TRAIN', src_path=FUNSD_TRAIN, device = device, output_dir=TRAIN_SAMPLES) 281 | data.get_info() 282 | 283 | ss = KFold(n_splits=10, shuffle=True, random_state=cfg_train.seed) 284 | cv_indices = ss.split(data.graphs) 285 | 286 | models = [] 287 | train_index, val_index = next(ss.split(data.graphs)) 288 | 289 | for cvs in cv_indices: 290 | 291 | train_index, val_index = cvs 292 | 293 | # TRAIN 294 | train_graphs = [data.graphs[i] for i in train_index] 295 | tg = dgl.batch(train_graphs) 296 | tg = tg.int().to(device) 297 | 298 | val_graphs = [data.graphs[i] for i in val_index] 299 | vg = dgl.batch(val_graphs) 300 | vg = vg.int().to(device) 301 | 302 | ################* STEP 1: CREATE MODEL ################ 303 | model = sm.get_model(None, 2, data.get_chunks()) 304 | optimizer = torch.optim.Adam(model.parameters(), lr=float(cfg_train.lr), weight_decay=float(cfg_train.weight_decay)) 305 | # scheduler = ReduceLROnPlateau(optimizer, 'max', patience=100, min_lr=1e-3, verbose=True, factor=0.01) 306 | # scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 307 | e = datetime.now() 308 | train_name = args.model + f'-{e.strftime("%Y%m%d-%H%M")}' 309 | models.append(train_name+'.pt') 310 | stopper = EarlyStopping(model, name=train_name, metric=cfg_train.stopper_metric, patience=1000) 311 | # writer = SummaryWriter(log_dir=RUNS) 312 | # convert_imgs = transforms.ToTensor() 313 | 314 | ################* STEP 2: TRAINING ################ 315 | print("\n### TRAINING ###") 316 | print(f"-> Training samples: {tg.batch_size}") 317 | print(f"-> 
Validation samples: {vg.batch_size}\n") 318 | 319 | # im_step = 0 320 | for epoch in range(cfg_train.epochs): 321 | 322 | #* TRAINING 323 | model.train() 324 | 325 | scores = model(tg, tg.ndata['feat'].to(device)) 326 | loss = compute_crossentropy_loss(scores.to(device), tg.edata['label'].to(device)) 327 | auc = compute_auc_mc(scores.to(device), tg.edata['label'].to(device)) 328 | 329 | optimizer.zero_grad() 330 | loss.backward() 331 | n = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) 332 | optimizer.step() 333 | 334 | #* VALIDATION 335 | model.eval() 336 | with torch.no_grad(): 337 | val_scores = model(vg, vg.ndata['feat'].to(device)) 338 | val_loss = compute_crossentropy_loss(val_scores.to(device), vg.edata['label'].to(device)) 339 | val_auc = compute_auc_mc(val_scores.to(device), vg.edata['label'].to(device)) 340 | 341 | # scheduler.step(val_auc) 342 | # scheduler.step() 343 | 344 | #* PRINTING IMAGEs AND RESULTS 345 | 346 | print("Epoch {:05d} | TrainLoss {:.4f} | TrainAUC-PR {:.4f} | ValLoss {:.4f} | ValAUC-PR {:.4f} |" 347 | .format(epoch, loss.item(), auc, val_loss.item(), val_auc)) 348 | 349 | if cfg_train.stopper_metric == 'loss': 350 | step_value = val_loss.item() 351 | elif cfg_train.stopper_metric == 'acc': 352 | step_value = val_auc 353 | 354 | ss = stopper.step(step_value) 355 | 356 | # if ss == 'improved': 357 | # im_step = epoch 358 | # train_imgs = [] 359 | # for r in rand_tid: 360 | # start, end = 0, 0 361 | # for tid in train_index: 362 | # start = end 363 | # end += data.graphs[tid].num_edges() 364 | # if tid == r: break 365 | 366 | # _, targets = torch.max(F.log_softmax(scores[start:end], dim=1), dim=1) 367 | # kvp_ids = targets.nonzero().flatten().tolist() 368 | # train_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=kvp_ids, name=f'train_{r}'))[:, :, :700]) 369 | # # data.print_graph(num=r, name=f'train_labels_{r}') 370 | 371 | # val_imgs = [] 372 | # for r in rand_vid: 373 | # v_start, v_end = 0, 0 374 | # for 
vid in val_index: 375 | # v_start = v_end 376 | # v_end += data.graphs[vid].num_edges() 377 | # if vid == r: break 378 | 379 | # _, val_targets = torch.max(F.log_softmax(val_scores[v_start:v_end], dim=1), dim=1) 380 | # val_kvp_ids = val_targets.nonzero().flatten().tolist() 381 | # val_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=val_kvp_ids, name=f'val_{r}'))[:, :, :700]) 382 | # data.print_graph(num=r, name=f'val_labels_{r}') 383 | 384 | if ss == 'stop': 385 | break 386 | 387 | # writer.add_scalars('AUC-PR', {'train': auc, 'val': val_auc}, epoch) 388 | # writer.add_scalars('LOSS', {'train': loss.item(), 'val': val_loss.item()}, epoch) 389 | # writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch) 390 | 391 | # train_grid = torchvision.utils.make_grid(train_imgs) 392 | # writer.add_image('train_images', train_grid, im_step) 393 | # val_grid = torchvision.utils.make_grid(val_imgs) 394 | # writer.add_image('val_images', val_grid, im_step) 395 | 396 | # print("LOADING: ", train_name+'.pt') 397 | # model.load_state_dict(torch.load(WEIGHTS / (train_name+'.pt'))) 398 | 399 | else: 400 | ################* SKIP TRAINING ################ 401 | print("\n### SKIP TRAINING ###") 402 | print(f"-> loading {args.weights}") 403 | models = args.weights 404 | 405 | 406 | ################* STEP 3: TESTING ################ 407 | print("\n### TESTING ###") 408 | 409 | #? 
test 410 | test_data = Document2Graph(name='FUNSD TEST', src_path=FUNSD_TEST, device = device, output_dir=TEST_SAMPLES) 411 | test_data.get_info() 412 | model = sm.get_model(None, 2, test_data.get_chunks()) 413 | best_model = '' 414 | pair_scores = [] 415 | test_graph = dgl.batch(test_data.graphs).to(device) 416 | 417 | 418 | for m in models: 419 | model.load_state_dict(torch.load(CHECKPOINTS / m)) 420 | model.eval() 421 | with torch.no_grad(): 422 | 423 | scores = model(test_graph, test_graph.ndata['feat'].to(device)) 424 | auc = compute_auc_mc(scores.to(device), test_graph.edata['label'].to(device)) 425 | 426 | _, preds = torch.max(F.softmax(scores, dim=1), dim=1) 427 | 428 | accuracy, f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label']) 429 | _, classes_f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'], per_class=True) 430 | 431 | pair_scores.append(classes_f1[1]) 432 | if classes_f1[1] >= max(pair_scores): 433 | best_model = m 434 | 435 | ################* STEP 4: RESULTS ################ 436 | print(f"\n### RESULTS {m} ###") 437 | print("F1 Score: Macro {:.4f} - Micro {:.4f}".format(f1[0], f1[1])) 438 | print("F1 Classes: None {:.4f} - Pairs {:.4f}".format(classes_f1[0], classes_f1[1])) 439 | 440 | print(f"\nLoading best model {best_model}") 441 | model.load_state_dict(torch.load(CHECKPOINTS / best_model)) 442 | model.eval() 443 | with torch.no_grad(): 444 | 445 | scores = model(test_graph, test_graph.ndata['feat'].to(device)) 446 | auc = compute_auc_mc(scores.to(device), test_graph.edata['label'].to(device)) 447 | 448 | _, preds = torch.max(F.softmax(scores, dim=1), dim=1) 449 | test_graph.edata['preds'] = preds 450 | 451 | accuracy, f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label']) 452 | _, classes_f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'], per_class=True) 453 | 454 | if not args.test: 455 | feat_n, feat_e = get_features(args) 456 | #? 
if skipping training, no need to save anything 457 | model = get_config(CFGM / args.model) 458 | results = {'MODEL': { 459 | 'name': sm.get_name(), 460 | 'weights': best_model, 461 | 'net-params': sm.get_total_params(), 462 | 'num-layers': model.num_layers, 463 | 'projector-output': model.out_chunks, 464 | 'dropout': model.dropout, 465 | 'lastFC': model.hidden_dim 466 | }, 467 | 'FEATURES': { 468 | 'nodes': feat_n, 469 | 'edges': feat_e 470 | }, 471 | 'PARAMS': { 472 | 'start-lr': cfg_train.lr, 473 | 'weight-decay': cfg_train.weight_decay, 474 | 'seed': cfg_train.seed 475 | }, 476 | 'RESULTS': { 477 | 'val-loss': stopper.best_score, 478 | 'f1-scores': f1, 479 | 'f1-classes': classes_f1, 480 | 'AUC-PR': auc, 481 | 'ACCURACY': accuracy, 482 | 'std-pairs': np.std(pair_scores), 483 | 'mean-pairs': mean(pair_scores) 484 | }} 485 | save_test_results(train_name, results) 486 | print("END TRAINING:", time.time() - start_training) 487 | 488 | return {'best_model': best_model, 'Pairs-F1': {'max': max(pair_scores), 'mean': mean(pair_scores), 'std': np.std(pair_scores)}} 489 | 490 | def train_funsd(args): 491 | 492 | if args.model == 'e2e': 493 | e2e(args) 494 | elif args.model == 'edge': 495 | entity_linking(args) 496 | else: 497 | raise Exception("Model selected does not exists. 
Choose 'e2e' or 'edge'.") 498 | return 499 | 500 | -------------------------------------------------------------------------------- /src/training/pau.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from sklearn.model_selection import KFold, ShuffleSplit 3 | import torch 4 | from torch.nn import functional as F 5 | from random import seed 6 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau 7 | import dgl 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torchvision import transforms 10 | import time 11 | from statistics import mean 12 | import numpy as np 13 | import xml.etree.ElementTree as ET 14 | from PIL import Image 15 | 16 | from src.data.dataloader import Document2Graph 17 | from src.data.preprocessing import match_pred_w_gt 18 | from src.paths import * 19 | from src.models.graphs import SetModel 20 | from src.utils import get_config 21 | from src.training.utils import * 22 | 23 | def e2e(args): 24 | 25 | # configs 26 | start_training = time.time() 27 | cfg_train = get_config('train') 28 | seed(cfg_train.seed) 29 | device = get_device(args.gpu) 30 | sm = SetModel(name=args.model, device=device) 31 | 32 | if not args.test: 33 | ################* STEP 0: LOAD DATA ################ 34 | data = Document2Graph(name='PAU TRAIN', src_path=PAU_TRAIN, device = device, output_dir=TRAIN_SAMPLES) 35 | data.get_info() 36 | 37 | ss = KFold(n_splits=7, shuffle=True, random_state=cfg_train.seed) 38 | cv_indices = ss.split(data.graphs) 39 | 40 | models = [] 41 | train_index, val_index = next(ss.split(data.graphs)) 42 | 43 | for cvs in cv_indices: 44 | 45 | train_index, val_index = cvs 46 | 47 | # TRAIN 48 | train_graphs = [data.graphs[i] for i in train_index] 49 | tg = dgl.batch(train_graphs) 50 | tg = tg.int().to(device) 51 | 52 | val_graphs = [data.graphs[i] for i in val_index] 53 | vg = dgl.batch(val_graphs) 54 | vg = vg.int().to(device) 55 | 56 | ################* STEP 
1: CREATE MODEL ################ 57 | model = sm.get_model(data.node_num_classes, data.edge_num_classes, data.get_chunks()) 58 | optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg_train.lr), weight_decay=float(cfg_train.weight_decay)) 59 | # scheduler = ReduceLROnPlateau(optimizer, 'max', patience=400, min_lr=1e-3, verbose=True, factor=0.01) 60 | # scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 61 | e = datetime.now() 62 | train_name = args.model + f'-{e.strftime("%Y%m%d-%H%M")}' 63 | models.append(train_name+'.pt') 64 | stopper = EarlyStopping(model, name=train_name, metric=cfg_train.stopper_metric, patience=2000) 65 | # writer = SummaryWriter(log_dir=RUNS) 66 | # convert_imgs = transforms.ToTensor() 67 | 68 | ################* STEP 2: TRAINING ################ 69 | print("\n### TRAINING ###") 70 | print(f"-> Training samples: {tg.batch_size}") 71 | print(f"-> Validation samples: {vg.batch_size}\n") 72 | 73 | # im_step = 0 74 | for epoch in range(cfg_train.epochs): 75 | 76 | #* TRAINING 77 | model.train() 78 | 79 | n_scores, e_scores = model(tg, tg.ndata['feat'].to(device)) 80 | n_loss = compute_crossentropy_loss(n_scores.to(device), tg.ndata['label'].to(device)) 81 | e_loss = compute_crossentropy_loss(e_scores.to(device), tg.edata['label'].to(device)) 82 | tot_loss = n_loss + e_loss 83 | 84 | macro, micro = get_f1(n_scores, tg.ndata['label'].to(device)) 85 | auc = compute_auc_mc(e_scores.to(device), tg.edata['label'].to(device)) 86 | 87 | optimizer.zero_grad() 88 | tot_loss.backward() 89 | n = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) 90 | optimizer.step() 91 | 92 | #* VALIDATION 93 | model.eval() 94 | with torch.no_grad(): 95 | val_n_scores, val_e_scores = model(vg, vg.ndata['feat'].to(device)) 96 | val_n_loss = compute_crossentropy_loss(val_n_scores.to(device), vg.ndata['label'].to(device)) 97 | val_e_loss = compute_crossentropy_loss(val_e_scores.to(device), vg.edata['label'].to(device)) 98 | val_tot_loss = val_n_loss + 
val_e_loss 99 | val_macro, val_micro = get_f1(val_n_scores, vg.ndata['label'].to(device)) 100 | val_auc = compute_auc_mc(val_e_scores.to(device), vg.edata['label'].to(device)) 101 | 102 | # scheduler.step(val_auc) 103 | # scheduler.step() 104 | 105 | #* PRINTING IMAGEs AND RESULTS 106 | 107 | print("Epoch {:05d} | TrainLoss {:.4f} | TrainF1-MACRO {:.4f} | TrainAUC-PR {:.4f} | ValLoss {:.4f} | ValF1-MACRO {:.4f} | ValAUC-PR {:.4f} |" 108 | .format(epoch, tot_loss.item(), macro, auc, val_tot_loss.item(), val_macro, val_auc)) 109 | 110 | if cfg_train.stopper_metric == 'loss': 111 | step_value = val_tot_loss.item() 112 | elif cfg_train.stopper_metric == 'acc': 113 | step_value = val_micro 114 | 115 | #if val_auc > best_val_auc: 116 | # best_val_auc = val_auc 117 | # best_model = train_name 118 | 119 | ss = stopper.step(step_value) 120 | 121 | # if ss == 'improved': 122 | # im_step = epoch 123 | # train_imgs = [] 124 | # for r in rand_tid: 125 | # start, end = 0, 0 126 | # for tid in train_index: 127 | # start = end 128 | # end += data.graphs[tid].num_edges() 129 | # if tid == r: break 130 | 131 | # _, targets = torch.max(F.log_softmax(e_scores[start:end], dim=1), dim=1) 132 | # kvp_ids = targets.nonzero().flatten().tolist() 133 | # train_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=kvp_ids, name=f'train_{r}'))[:, :, :700]) 134 | # # data.print_graph(num=r, name=f'train_labels_{r}') 135 | 136 | # val_imgs = [] 137 | # for r in rand_vid: 138 | # v_start, v_end = 0, 0 139 | # for vid in val_index: 140 | # v_start = v_end 141 | # v_end += data.graphs[vid].num_edges() 142 | # if vid == r: break 143 | 144 | # _, val_targets = torch.max(F.log_softmax(val_e_scores[v_start:v_end], dim=1), dim=1) 145 | # val_kvp_ids = val_targets.nonzero().flatten().tolist() 146 | # val_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=val_kvp_ids, name=f'val_{r}'))[:, :, :700]) 147 | # data.print_graph(num=r, name=f'val_labels_{r}') 148 | 149 | if ss == 'stop': 150 | 
break 151 | 152 | # writer.add_scalars('AUC-PR', {'train': auc, 'val': val_auc}, epoch) 153 | # writer.add_scalars('LOSS', {'train': tot_loss.item(), 'val': val_tot_loss.item()}, epoch) 154 | # writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch) 155 | 156 | # train_grid = torchvision.utils.make_grid(train_imgs) 157 | # writer.add_image('train_images', train_grid, im_step) 158 | # val_grid = torchvision.utils.make_grid(val_imgs) 159 | # writer.add_image('val_images', val_grid, im_step) 160 | 161 | # print("LOADING: ", best_model+'.pt') 162 | # model.load_state_dict(torch.load(WEIGHTS / (best_model+'.pt'))) 163 | 164 | else: 165 | ################* SKIP TRAINING ################ 166 | print("\n### SKIP TRAINING ###") 167 | print(f"-> loading {args.weights}") 168 | models = args.weights 169 | 170 | ################* STEP 3: TESTING ################ 171 | print("\n### TESTING ###") 172 | 173 | #? test 174 | test_data = Document2Graph(name='PAU TEST', src_path=PAU_TEST, device = device, output_dir=TEST_SAMPLES) 175 | test_data.get_info() 176 | model = sm.get_model(test_data.node_num_classes, test_data.edge_num_classes, test_data.get_chunks()) 177 | best_model = '' 178 | nodes_micro = [] 179 | edges_f1 = [] 180 | test_graph = dgl.batch(test_data.graphs).to(device) 181 | 182 | all_precisions = [] 183 | all_recalls = [] 184 | all_f1 = [] 185 | 186 | for i, m in enumerate(models): 187 | model.load_state_dict(torch.load(CHECKPOINTS / m)) 188 | model.eval() 189 | with torch.no_grad(): 190 | 191 | n, e = model(test_graph, test_graph.ndata['feat'].to(device)) 192 | auc = compute_auc_mc(e.to(device), test_graph.edata['label'].to(device)) 193 | _, epreds = torch.max(F.softmax(e, dim=1), dim=1) 194 | _, npreds = torch.max(F.softmax(n, dim=1), dim=1) 195 | 196 | accuracy, f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label']) 197 | _, classes_f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label'], per_class=True) 198 | 
edges_f1.append(classes_f1[1]) 199 | 200 | macro, micro = get_f1(n, test_graph.ndata['label'].to(device)) 201 | nodes_micro.append(micro) 202 | if micro >= max(nodes_micro): 203 | best_model = i 204 | 205 | test_graph.edata['preds'] = epreds 206 | test_graph.ndata['preds'] = npreds 207 | t_f1 = 0 208 | t_precision = 0 209 | t_recall = 0 210 | no_table = 0 211 | tables = 0 212 | 213 | for g, graph in enumerate(dgl.unbatch(test_graph)): 214 | etargets = graph.edata['preds'] 215 | ntargets = graph.ndata['preds'] 216 | kvp_ids = etargets.nonzero().flatten().tolist() 217 | 218 | table_g = dgl.edge_subgraph(graph, torch.tensor(kvp_ids, dtype=torch.int32).to(device)) 219 | table_nodes = table_g.ndata['geom'] 220 | try: 221 | table_topleft, _ = torch.min(table_nodes, 0) 222 | table_bottomright, _ = torch.max(table_nodes, 0) 223 | table = torch.cat([table_topleft[:2], table_bottomright[2:]], 0) 224 | except: 225 | table = None 226 | 227 | img_path = test_data.paths[g] 228 | w, h = Image.open(img_path).size 229 | gt_path = img_path.split(".")[0] 230 | 231 | root = ET.parse(gt_path + '_gt.xml').getroot() 232 | regions = [] 233 | for parent in root: 234 | if parent.tag.split("}")[1] == 'Page': 235 | for child in parent: 236 | # print(file_gt) 237 | region_label = child[0].attrib['value'] 238 | if region_label != 'positions': continue 239 | region_bbox = [int(child[1].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]), 240 | int(child[1].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]), 241 | int(child[1].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]), 242 | int(child[1].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])] 243 | regions.append([region_label, region_bbox]) 244 | 245 | table_regions = [region[1] for region in regions if region[0]=='positions'] 246 | if table is None and len(table_regions) !=0: 247 | t_f1 += 0 248 | t_precision += 0 249 | t_recall += 0 250 | tables += len(table_regions) 251 | elif table is None and 
len(table_regions) == 0: 252 | no_table -= 1 253 | continue 254 | elif table is not None and len(table_regions) ==0: 255 | t_f1 += 0 256 | t_precision += 0 257 | t_recall += 0 258 | no_table -= 1 259 | else: 260 | table = [[t[0]*w, t[1]*h, t[2]*w, t[3]*h] for t in [table.flatten().tolist()]][0] 261 | # d = match_pred_w_gt(torch.tensor(boxs_preds[idx]), torch.tensor(gt)) 262 | d = match_pred_w_gt(torch.tensor(table).view(-1, 4), torch.tensor(table_regions).view(-1, 4), []) 263 | bbox_true_positive = len(d["pred2gt"]) 264 | p = bbox_true_positive / (bbox_true_positive + len(d["false_positive"])) 265 | r = bbox_true_positive / (bbox_true_positive + len(d["false_negative"])) 266 | try: 267 | t_f1 += (2 * p * r) / (p + r) 268 | except: 269 | t_f1 += 0 270 | t_precision += p 271 | t_recall += r 272 | tables += len(table_regions) 273 | 274 | test_data.print_graph(num=g, node_labels = None, labels_ids=None, name=f'test_{g}', bidirect=False, regions=regions, preds=table) 275 | 276 | # test_data.print_graph(num=g, name=f'test_labels_{g}') 277 | t_recall = t_recall / (tables + no_table) 278 | t_precision = t_precision / (tables + no_table) 279 | t_f1 = (2 * t_precision * t_recall) / (t_precision + t_recall) 280 | all_precisions.append(t_precision) 281 | all_recalls.append(t_recall) 282 | all_f1.append(t_f1) 283 | 284 | ################* STEP 4: RESULTS ################ 285 | print("\n### RESULTS {} ###".format(m)) 286 | print("AUC {:.4f}".format(auc)) 287 | print("Accuracy {:.4f}".format(accuracy)) 288 | print("F1 Edges: Macro {:.4f} - Micro {:.4f}".format(f1[0], f1[1])) 289 | print("F1 Edges: None {:.4f} - Table {:.4f}".format(classes_f1[0], classes_f1[1])) 290 | print("F1 Nodes: Macro {:.4f} - Micro {:.4f}".format(macro, micro)) 291 | 292 | print("\nTABLE DETECTION") 293 | print("PRECISION [MAX, MEAN, STD]:", max(all_precisions), mean(all_precisions), np.std(all_precisions)) 294 | print("RECALLS [MAX, MEAN, STD]:", max(all_recalls), mean(all_recalls), np.std(all_recalls)) 
295 | print("F1s [MAX, MEAN, STD]:", max(all_f1), mean(all_f1), np.std(all_f1)) 296 | 297 | ################* BEST MODEL STATISTICS ################ 298 | print(f"\nLoading best model {models[best_model]}") 299 | print("F1 Edges: Table {:.4f}".format(edges_f1[best_model])) 300 | print("F1 Nodes: Micro {:.4f}".format(nodes_micro[best_model])) 301 | 302 | if not args.test: 303 | feat_n, feat_e = get_features(args) 304 | #? if skipping training, no need to save anything 305 | model = get_config(CFGM / args.model) 306 | results = {'MODEL': { 307 | 'name': sm.get_name(), 308 | 'weights': best_model, 309 | 'net-params': sm.get_total_params(), 310 | 'projector-output': model.out_chunks, 311 | 'dropout': model.dropout, 312 | 'lastFC': model.hidden_dim 313 | }, 314 | 'FEATURES': { 315 | 'nodes': feat_n, 316 | 'edges': feat_e 317 | }, 318 | 'PARAMS': { 319 | 'start-lr': cfg_train.lr, 320 | 'weight-decay': cfg_train.weight_decay, 321 | 'seed': cfg_train.seed 322 | }, 323 | 'RESULTS': { 324 | 'val-loss': stopper.best_score, 325 | 'f1-scores': f1, 326 | 'f1-classes': classes_f1, 327 | 'nodes-f1': [macro, micro], 328 | 'std-pairs': np.std(nodes_micro), 329 | 'mean-pairs': mean(nodes_micro), 330 | 'table-detection-precision': [max(all_precisions), mean(all_precisions), np.std(all_precisions)], 331 | 'table-detection-recall': [max(all_recalls), mean(all_recalls), np.std(all_recalls)], 332 | 'table-detection-f1': [max(all_f1), mean(all_f1), np.std(all_f1)] 333 | }} 334 | save_test_results(train_name, results) 335 | print("END TRAINING:", time.time() - start_training) 336 | 337 | return {'best_model': best_model, 'Nodes-F1': {'max': max(nodes_micro), 'mean': mean(nodes_micro), 'std': np.std(nodes_micro)}, 338 | 'Edges-F1': {'max': max(edges_f1), 'mean': mean(edges_f1), 'std': np.std(edges_f1)}, 339 | 'Table Detection': {'precision': [max(all_precisions), mean(all_precisions), np.std(all_precisions)], 'recall': [max(all_recalls), mean(all_recalls), np.std(all_recalls)], 'f1': 
[max(all_f1), mean(all_f1), np.std(all_f1)]}} 340 | 341 | def train_pau(args): 342 | 343 | if args.model == 'e2e': 344 | e2e(args) 345 | else: 346 | raise Exception("Model selected does not exists. Choose 'e2e'.") 347 | return 348 | 349 | """ 350 | ### OLD EVALUATIONS ### 351 | model.load_state_dict(torch.load(CHECKPOINTS / best_model)) 352 | model.eval() 353 | with torch.no_grad(): 354 | 355 | n, e= model(test_graph, test_graph.ndata['feat'].to(device)) 356 | # auc = compute_auc_mc(scores.to(device), test_graph.edata['label'].to(device)) 357 | 358 | _, epreds = torch.max(F.softmax(e, dim=1), dim=1) 359 | _, npreds = torch.max(F.softmax(n, dim=1), dim=1) 360 | 361 | accuracy, f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label']) 362 | _, classes_f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label'], per_class=True) 363 | macro, micro = get_f1(n, test_graph.ndata['label'].to(device)) 364 | 365 | test_graph.edata['preds'] = epreds 366 | test_graph.ndata['preds'] = npreds 367 | t_f1 = 0 368 | t_precision = 0 369 | t_recall = 0 370 | no_table = 0 371 | tables = 0 372 | 373 | for g, graph in enumerate(dgl.unbatch(test_graph)): 374 | etargets = graph.edata['preds'] 375 | ntargets = graph.ndata['preds'] 376 | kvp_ids = etargets.nonzero().flatten().tolist() 377 | 378 | table_g = dgl.edge_subgraph(graph, torch.tensor(kvp_ids, dtype=torch.int32).to(device)) 379 | table_nodes = table_g.ndata['geom'] 380 | try: 381 | table_topleft, _ = torch.min(table_nodes, 0) 382 | table_bottomright, _ = torch.max(table_nodes, 0) 383 | table = torch.cat([table_topleft[:2], table_bottomright[2:]], 0) 384 | except: 385 | table = None 386 | 387 | img_path = test_data.paths[g] 388 | w, h = Image.open(img_path).size 389 | gt_path = img_path.split(".")[0] 390 | 391 | root = ET.parse(gt_path + '_gt.xml').getroot() 392 | regions = [] 393 | for parent in root: 394 | if parent.tag.split("}")[1] == 'Page': 395 | for child in parent: 396 | # print(file_gt) 397 | region_label = 
child[0].attrib['value'] 398 | if region_label != 'positions': continue 399 | region_bbox = [int(child[1].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]), 400 | int(child[1].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]), 401 | int(child[1].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]), 402 | int(child[1].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])] 403 | regions.append([region_label, region_bbox]) 404 | 405 | table_regions = [region[1] for region in regions if region[0]=='positions'] 406 | if table is None and len(table_regions) !=0: 407 | t_f1 += 0 408 | t_precision += 0 409 | t_recall += 0 410 | tables += len(table_regions) 411 | elif table is None and len(table_regions) == 0: 412 | no_table -= 1 413 | continue 414 | elif table is not None and len(table_regions) ==0: 415 | t_f1 += 0 416 | t_precision += 0 417 | t_recall += 0 418 | no_table -= 1 419 | else: 420 | table = [[t[0]*w, t[1]*h, t[2]*w, t[3]*h] for t in [table.flatten().tolist()]][0] 421 | # d = match_pred_w_gt(torch.tensor(boxs_preds[idx]), torch.tensor(gt)) 422 | d = match_pred_w_gt(torch.tensor(table).view(-1, 4), torch.tensor(table_regions).view(-1, 4), []) 423 | bbox_true_positive = len(d["pred2gt"]) 424 | p = bbox_true_positive / (bbox_true_positive + len(d["false_positive"])) 425 | r = bbox_true_positive / (bbox_true_positive + len(d["false_negative"])) 426 | try: 427 | t_f1 += (2 * p * r) / (p + r) 428 | except: 429 | t_f1 += 0 430 | t_precision += p 431 | t_recall += r 432 | tables += len(table_regions) 433 | 434 | test_data.print_graph(num=g, node_labels = None, labels_ids=None, name=f'test_{g}', bidirect=False, regions=regions, preds=table) 435 | 436 | t_recall = t_recall / (tables + no_table) 437 | t_precision = t_precision / (tables + no_table) 438 | t_f1 = (2 * t_precision * t_recall) / (t_precision + t_recall) 439 | print(t_precision, t_recall, t_f1) 440 | """ 
-------------------------------------------------------------------------------- /src/training/utils.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from cProfile import label 3 | import os 4 | from pickletools import optimize 5 | from typing import Tuple 6 | from itsdangerous import json 7 | import pandas as pd 8 | import sklearn 9 | from sklearn.metrics import average_precision_score, confusion_matrix, f1_score, precision_recall_fscore_support, roc_auc_score, roc_curve 10 | from sklearn.utils import class_weight 11 | import torch 12 | import torch.nn 13 | import torch.nn.functional as F 14 | import dgl 15 | from datetime import datetime 16 | import shutil 17 | import yaml 18 | import numpy as np 19 | 20 | from src.paths import CHECKPOINTS, CONFIGS, OUTPUTS, RESULTS 21 | 22 | 23 | class EarlyStopping: 24 | """Early stop for training purposes, looking at validation loss. 25 | """ 26 | def __init__(self, model, name = '', metric = 'loss', patience=50): 27 | """Constructor. 28 | 29 | Args: 30 | model (DGLModel): graph or batch of graphs 31 | name (str): name for weights. 32 | metric (str): set the stopper, following loss ['loss'] or accuracy ['acc'] on validation 33 | patience (int, optional): if validation do not improve after 'patience' iters, it stops training. Defaults to 50. 34 | """ 35 | self.patience = patience 36 | self.counter = 0 37 | self.best_score = None 38 | self.early_stop = 'improved' 39 | self.model = model 40 | self.metric = metric 41 | e = datetime.now() 42 | if name == '': self.name = f'{e.strftime("%Y%m%d-%H%M")}' 43 | else: self.name = name 44 | 45 | def step(self, score : float) -> str: 46 | """ It does a step of the stopper. If metric does not encrease after a while, it stops the training. 47 | 48 | Args: 49 | score (float) : metric / value to keep track of. 50 | 51 | Returns 52 | A status used by traning to do things. 
53 | """ 54 | 55 | if self.best_score is None: 56 | self.best_score = score 57 | self.save_checkpoint() 58 | return 'improved' 59 | 60 | if self.metric == 'loss': 61 | if score > self.best_score: 62 | self.counter += 1 63 | print(f' !- Stop Counter {self.counter} / {self.patience}') 64 | self.early_stop = 'not-improved' 65 | if self.counter >= self.patience: 66 | self.early_stop = 'stop' 67 | else: 68 | print(f' !- Validation LOSS decreased from {self.best_score} -> {score}') 69 | self.best_score = score 70 | self.save_checkpoint() 71 | self.counter = 0 72 | self.early_stop = 'improved' 73 | 74 | elif self.metric == 'acc': 75 | if score <= self.best_score: 76 | self.counter += 1 77 | print(f' !- Stop Counter {self.counter} / {self.patience}') 78 | self.early_stop = 'not-improved' 79 | if self.counter >= self.patience: 80 | self.early_stop = 'stop' 81 | else: 82 | print(f' !- Validation ACCURACY increased from {self.best_score} -> {score}') 83 | self.best_score = score 84 | self.save_checkpoint() 85 | self.counter = 0 86 | self.early_stop = 'improved' 87 | 88 | else: 89 | raise Exception('EarlyStopping Error: metric provided not valid. Select between loss or acc') 90 | 91 | return self.early_stop 92 | 93 | def save_checkpoint(self) -> None: 94 | '''Saves model when validation acc increase.''' 95 | torch.save(self.model.state_dict(), CHECKPOINTS / f'{self.name}.pt') 96 | 97 | def save_best_results(best_params : dict, rm_logs : bool = False) -> None: 98 | """Save best results for cross validation. 99 | 100 | Args: 101 | best_params (dict): best parameters among k-fold cross validation. 102 | rm_logs (bool, optional): Remove tmp weights in output folder if True. Defaults to False. 
103 | """ 104 | models = OUTPUTS / 'tmp' 105 | output = CHECKPOINTS / best_params['model'] 106 | shutil.copyfile(models / best_params['model'], output) 107 | 108 | new_configs = CONFIGS / (best_params['model'].split(".")[0] + '.yaml') 109 | shutil.copyfile(CONFIGS / 'base.yaml', new_configs) 110 | 111 | with open(new_configs) as f: 112 | config = yaml.safe_load(f) 113 | 114 | config['MODEL']['num_layers'] = best_params['num_layers'] 115 | config['TRAIN']['batch_size'] = best_params['batch_size'] 116 | config['INFO'] = {'split': best_params['split'], 'val_loss': best_params['val_loss'], 'total_params': best_params['total_params']} 117 | 118 | with open(new_configs, 'w') as f: 119 | yaml.dump(config, f) 120 | 121 | if rm_logs and os.path.isdir(models): 122 | shutil.rmtree(models) 123 | 124 | return 125 | 126 | def save_test_results(filename : str, infos : dict) -> None: 127 | """Save test results. 128 | 129 | Args: 130 | filename (str): name of the file to save results of experiments 131 | infos (dict): what to save in the json file about training 132 | """ 133 | results = RESULTS / (filename + '.json') 134 | 135 | with open(results, 'w') as f: 136 | json.dump(infos, f) 137 | return 138 | 139 | def get_f1(logits : torch.Tensor, labels : torch.Tensor, per_class = False) -> tuple: 140 | """Returns Macro and Micro F1-score for given logits / labels. 
141 | 142 | Args: 143 | logits (torch.Tensor): model prediction logits 144 | labels (torch.Tensor): target labels 145 | 146 | Returns: 147 | tuple: macro-f1 and micro-f1 148 | """ 149 | _, indices = torch.max(logits, dim=1) 150 | indices = indices.cpu().detach().numpy() 151 | labels = labels.cpu().detach().numpy() 152 | if not per_class: 153 | return f1_score(labels, indices, average='macro'), f1_score(labels, indices, average='micro') 154 | else: 155 | return precision_recall_fscore_support(labels, indices, average=None)[2].tolist() 156 | 157 | def get_binary_accuracy_and_f1(classes, labels : torch.Tensor, per_class = False) -> Tuple[float, list]: 158 | 159 | correct = torch.sum(classes.flatten() == labels) 160 | accuracy = correct.item() * 1.0 / len(labels) 161 | classes = classes.detach().cpu().numpy() 162 | labels = labels.cpu().numpy() 163 | 164 | if not per_class: 165 | f1 = f1_score(labels, classes, average='macro'), f1_score(labels, classes, average='micro') 166 | else: 167 | f1 = precision_recall_fscore_support(labels, classes, average=None)[2].tolist() 168 | 169 | return accuracy, f1 170 | 171 | def accuracy(logits : torch.Tensor, labels : torch.Tensor) -> float: 172 | """Accuracy of the model. 173 | 174 | Args: 175 | logits (torch.Tensor): model prediction logits 176 | labels (torch.Tensor): target labels 177 | 178 | Returns: 179 | float: accuracy 180 | """ 181 | _, indices = torch.max(logits, dim=1) 182 | correct = torch.sum(indices == labels) 183 | return correct.item() * 1.0 / len(labels) 184 | 185 | def get_device(value : int) -> str: 186 | """Either to use cpu or gpu (and which one). 
187 | """ 188 | if value < 0: 189 | return 'cpu' 190 | else: 191 | return 'cuda:{}'.format(value) 192 | 193 | def get_features(args : ArgumentParser) -> Tuple[str, str]: 194 | """ Return description of the features used in the experiment 195 | 196 | Args: 197 | args (ArgumentParser) : your ArgumentParser 198 | """ 199 | feat_n = '' 200 | feat_e = 'false' 201 | 202 | if args.add_geom: 203 | feat_n += 'geom-' 204 | if args.add_embs: 205 | feat_n += 'text-' 206 | if args.add_visual: 207 | feat_n += 'visual-' 208 | if args.add_hist: 209 | feat_n += 'histogram-' 210 | if args.add_eweights: 211 | feat_e = 'true' 212 | 213 | return feat_n, feat_e 214 | 215 | def compute_crossentropy_loss(scores : torch.Tensor, labels : torch.Tensor): 216 | w = class_weight.compute_class_weight(class_weight='balanced', classes= np.unique(labels.cpu().numpy()), y=labels.cpu().numpy()) 217 | return torch.nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32).to('cuda:0'))(scores, labels) 218 | 219 | def compute_auc_mc(scores, labels): 220 | scores = scores.detach().cpu().numpy() 221 | labels = F.one_hot(labels).cpu().numpy() 222 | # return roc_auc_score(labels, scores) 223 | return average_precision_score(labels, scores) 224 | 225 | def find_optimal_cutoff(target, predicted): 226 | """ Find the optimal probability cutoff point for a classification model related to event rate 227 | Parameters 228 | ---------- 229 | target : Matrix with dependent or target data, where rows are observations 230 | 231 | predicted : Matrix with predicted data, where rows are observations 232 | 233 | Returns 234 | ------- 235 | list type, with optimal cutoff value 236 | 237 | """ 238 | fpr, tpr, threshold = roc_curve(target, predicted) 239 | i = np.arange(len(tpr)) 240 | roc = pd.DataFrame({'tf' : pd.Series(tpr-(1-fpr), index=i), 'threshold' : pd.Series(threshold, index=i)}) 241 | roc_t = roc.iloc[(roc.tf-0).abs().argsort()[:1]] 242 | 243 | return list(roc_t['threshold']) 244 | 245 | def 
generalized_f1_score(y_true, y_pred, match): 246 | # y_true = (y_nodes, y_link) 247 | # y_pred = (y_nodes, y_link) 248 | 249 | # # nodes 250 | # micro_f1_nodes, macro_f1_nodes = 0, 0 251 | 252 | nodes_confusion_mtx = confusion_matrix(y_true=y_true[0][list(match["pred2gt"].keys())], y_pred=y_pred[0][list(match["gt2pred"].values())], 253 | labels=[0, 1, 2, 3], normalize=None) 254 | print(nodes_confusion_mtx) 255 | ntp = [nodes_confusion_mtx[i, i] for i in range(nodes_confusion_mtx.shape[0])] 256 | nfp = [(nodes_confusion_mtx[:i, i].sum() + nodes_confusion_mtx[i+1:, i].sum()) for i in range(nodes_confusion_mtx.shape[0])] 257 | nfn = [(nodes_confusion_mtx[i, :i].sum() + nodes_confusion_mtx[i, i+1:].sum()) for i in range(nodes_confusion_mtx.shape[0])] 258 | 259 | macro_f1_nodes = np.mean([tp / (tp + (0.5 * (fp + len(match["false_positive"]) + fn + len(match["false_negative"])))) for (tp, fp, fn) in zip(ntp, nfp, nfn)]) 260 | micro_f1_nodes = np.sum(ntp) / (np.sum(ntp) + (0.5 * (np.sum(nfp) + len(match["false_positive"]) + np.sum(nfn) + len(match["false_negative"])))) 261 | 262 | # links 263 | # micro_f1_links, macro_f1_links = 0, 0 264 | 265 | # links2keep = [idx for idx, link in enumerate(y_true[1]) if (link[0] in match["pred2gt"].values()) and (link[1] in match["pred2gt"].values())] 266 | links_confusion_mtx = confusion_matrix(y_true=y_true[1], y_pred=y_pred[1], labels=[0, 1], normalize=None) 267 | 268 | ltp = [links_confusion_mtx[i, i] for i in range(links_confusion_mtx.shape[0])] 269 | lfp = [(links_confusion_mtx[:i, i].sum() + links_confusion_mtx[i+1:, i].sum()) for i in range(links_confusion_mtx.shape[0])] 270 | lfn = [(links_confusion_mtx[i, :i].sum() + links_confusion_mtx[i, i+1:].sum()) for i in range(links_confusion_mtx.shape[0])] 271 | 272 | # n = len(links2keep) + len(match["false_positive"]) - 1 273 | macro_f1_links = np.mean([tp / (tp + (0.5 * (fp + fn + match["n_link_fn"]))) for (tp, fp, fn) in zip(ltp, lfp, lfn)]) 274 | # if match['n_link_fn'] == 0: 
f1_pairs = None 275 | f1_pairs = [tp / (tp + (0.5 * (fp + fn + match["n_link_fn"])) + 1e-6) for (tp, fp, fn) in zip(ltp, lfp, lfn)][1] 276 | micro_f1_links = np.sum(ltp) / (np.sum(ltp) + (0.5 * (np.sum(lfp) + np.sum(lfn) + match["n_link_fn"])) + 1e-6) 277 | 278 | return ntp, nfp, nfn, ltp, lfp, lfn, {"nodes": {"micro_f1": micro_f1_nodes, "macro_f1": macro_f1_nodes}, "links": {"micro_f1": micro_f1_links, "macro_f1": macro_f1_links}, "pairs_f1": f1_pairs} 279 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | from attrdict import AttrDict 4 | import yaml 5 | 6 | from src.paths import * 7 | 8 | def create_folder(folder_path : str) -> None: 9 | """create a folder if not exists 10 | 11 | Args: 12 | folder_path (str): path 13 | """ 14 | if not os.path.exists(folder_path): 15 | os.mkdir(folder_path) 16 | 17 | return 18 | 19 | def get_config(name : str) -> AttrDict: 20 | """get yaml config file 21 | 22 | Args: 23 | name (str): yaml file name without extension 24 | 25 | Returns: 26 | AttrDict: config 27 | """ 28 | with open(CONFIGS / f'{name}.yaml') as fileobj: 29 | config = AttrDict(yaml.safe_load(fileobj)) 30 | return config 31 | 32 | def project_tree() -> None: 33 | """ Create the project tree folder 34 | """ 35 | create_folder(DATA) 36 | create_folder(OUTPUTS) 37 | create_folder(RUNS) 38 | create_folder(RESULTS) 39 | create_folder(TRAIN_SAMPLES) 40 | create_folder(TEST_SAMPLES) 41 | create_folder(CHECKPOINTS) 42 | return 43 | 44 | def set_preprocessing(args: ArgumentParser) -> None: 45 | """ Set preprocessings args 46 | 47 | Args: 48 | args (ArgumentParser): 49 | """ 50 | with open(CONFIGS / 'base.yaml') as fileobj: 51 | cfg_preprocessing = dict(yaml.safe_load(fileobj)) 52 | cfg_preprocessing['FEATURES']['add_geom'] = args.add_geom 53 | cfg_preprocessing['FEATURES']['add_embs'] = 
args.add_embs 54 | cfg_preprocessing['FEATURES']['add_hist'] = args.add_hist 55 | cfg_preprocessing['FEATURES']['add_visual'] = args.add_visual 56 | cfg_preprocessing['FEATURES']['add_eweights'] = args.add_eweights 57 | cfg_preprocessing['FEATURES']['num_polar_bins'] = args.num_polar_bins 58 | cfg_preprocessing['LOADER']['src_data'] = args.src_data 59 | cfg_preprocessing['GRAPHS']['data_type'] = args.data_type 60 | cfg_preprocessing['GRAPHS']['edge_type'] = args.edge_type 61 | cfg_preprocessing['GRAPHS']['node_granularity'] = args.node_granularity 62 | 63 | with open(CONFIGS / 'preprocessing.yaml', 'w') as f: 64 | yaml.dump(cfg_preprocessing, f) 65 | return 66 | --------------------------------------------------------------------------------