├── .gitignore
├── LICENCE
├── README.md
├── configs
│   ├── base.yaml
│   ├── models
│   │   ├── e2e.yaml
│   │   ├── edge.yaml
│   │   └── gcn.yaml
│   └── train.yaml
├── model.png
├── requirements.txt
├── setup.py
├── src
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── download.py
│   │   ├── feature_builder.py
│   │   ├── graph_builder.py
│   │   ├── preprocessing.py
│   │   └── utils.py
│   ├── inference.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── graphs.py
│   │   └── unet
│   │       ├── __init__.py
│   │       ├── decoder.py
│   │       └── model.py
│   ├── paths.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── funsd.py
│   │   ├── pau.py
│   │   └── utils.py
│   └── utils.py
└── tutorial
    └── kie.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # system
2 | *.egg-info/
3 | *__pycache__/
4 | *.py[cod]
5 | *.eggs
6 | build
7 | .vscode
8 | cc.*
9 | *.DS_Store
10 |
11 | #user-defined
12 | DATA
13 | _old
14 | *.env
15 | outputs
16 | FUDGE
17 | configs/preprocessing.yaml
18 | esempi
19 | examples
20 | src/models/yolov5
21 | src/models/checkpoints
22 | prova*
23 | inference
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 |
2 | The MIT License (MIT)
3 | Copyright (c) 2022
4 | Andrea Gemelli
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 |
8 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
9 |
10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
11 |
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `Doc2Graph`
2 |
3 | 
4 |
5 | [](https://paperswithcode.com/sota/entity-linking-on-funsd?p=doc2graph-a-task-agnostic-document) [](https://paperswithcode.com/sota/semantic-entity-labeling-on-funsd?p=doc2graph-a-task-agnostic-document)
6 |
7 |  
8 |
9 | This library is the implementation of the paper [Doc2Graph: a Task Agnostic Document Understanding Framework based on Graph Neural Networks](https://arxiv.org/abs/2208.11168), accepted at [TiE @ ECCV 2022](https://sites.google.com/view/tie-eccv2022/accepted-papers?authuser=0).
10 |
11 | The model and pipeline aim at being task-agnostic in the domain of Document Understanding. It is an ongoing project; these are the steps already achieved and the ones we would like to implement in the future:
12 |
13 | - [x] Build a model based on GNNs to spot key-value relationships on forms
14 | - [x] Publish the preliminary results and the code
15 | - [x] Extend the framework to other document-related tasks
16 | - [x] Business documents Layout Analysis
17 | - [x] Table Detection
18 | - [ ] Let the user train Doc2Graph over private / other datasets using our dataloader
19 | - [ ] Transform Doc2Graph into a PyPI package
20 |
21 | Index:
22 | - [`Doc2Graph`](#doc2graph)
23 | - [News!](#news)
24 | - [Environment Setup](#environment-setup)
25 | - [Training](#training)
26 | - [Testing](#testing)
27 | - [Cite this project](#cite-this-project)
28 |
29 | ---
30 | ## News!
31 | - 🔥 Added **inference** method: you can now use Doc2Graph directly on your documents simply by passing a path to them! This call will output an image with the connected entities and a json / dictionary with all the useful information you need! 🤗
32 | ```
33 | python src/main.py -addG -addT -addE -addV --gpu 0 --weights e2e-funsd-best.pt --inference --docs /path/to/document
34 | ```
35 |
36 | - 🔥 Added **tutorial** folder: get to know how to use Doc2Graph from the tutorial notebooks!
37 |
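If you prefer working from Python rather than the CLI, here is a minimal sketch of loading a dataset with the `Document2Graph` dataloader (a sketch only: it assumes the environment below has been set up, `python src/main.py --init` has been run, and the paths shown are placeholders):

```
from pathlib import Path
from src.data.dataloader import Document2Graph

# build graphs from a folder of FUNSD documents (placeholder path)
data = Document2Graph(name='FUNSD TRAIN',
                      src_path='DATA/FUNSD/training_data',
                      device='cpu',
                      output_dir=Path('outputs'))
data.get_info()  # summary of graphs, node / edge labels and feature sizes
g = data[0]      # a dgl.DGLGraph with node features in g.ndata['feat']
```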
38 | ## Environment Setup
39 | Setup the initial conda environment
40 |
41 | ```
42 | conda create -n doc2graph python=3.9 ipython cudatoolkit=11.3 -c anaconda &&
43 | conda activate doc2graph &&
44 | cd doc2graph
45 | ```
46 |
47 | Then, install [setuptools-git-versioning](https://pypi.org/project/setuptools-git-versioning/) and the doc2graph package itself. The following has been tested only on Linux; for installation on other OSes, refer directly to the original [PyTorch](https://pytorch.org/get-started/previous-versions/) and [DGL](https://www.dgl.ai/pages/start.html) documentation.
48 |
49 | ```
50 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 \
51 |     --extra-index-url https://download.pytorch.org/whl/cu113 &&
52 | pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html &&
53 | pip install setuptools-git-versioning && pip install -e . &&
54 | pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.3.0/en_core_web_lg-3.3.0.tar.gz
55 | ```
56 |
57 | Finally, create the project folder structure and download data:
58 |
59 | ```
60 | python src/main.py --init
61 | ```
62 | The script will download and set up:
63 | - FUNSD[^1] and its 'adjusted_annotations', provided by the work of [^3].
64 | - The YOLO detection bounding boxes described in the paper. If you want to use YOLOv5-small to detect entities yourself (script in `notebooks/YOLO.ipynb`), refer to [their github](https://github.com/ultralytics/yolov5) for the installation and clone the repository into `src/models/yolov5`.
65 | - Pau Riba's[^2] dataset with our train / test split.
66 |
67 | [^1]: G. Jaume et al., FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents, ICDARW 2019
68 | [^2]: P. Riba et al, Table Detection in Invoice Documents by Graph Neural Networks, ICDAR 2019
69 | [^3]: Hieu M. Vu et al., REVISING FUNSD DATASET FOR KEY-VALUE DETECTION IN DOCUMENT IMAGES, arXiv preprint 2020
70 |
71 | **Checkpoints**
72 | You can download our model checkpoints [here](https://drive.google.com/file/d/15jKWYLTcb8VwE7XY_jcRvZTAmqy_ElJ_/view?usp=sharing). Place them into `src/models/checkpoints`.
73 |
74 | ---
75 | ## Training
76 | 1. To train our **Doc2Graph** model (using CPU) use:
77 | ```
78 | python src/main.py [SETTINGS]
79 | ```
80 | 2. Instead, to test a trained **Doc2Graph** model (using GPU) [weights can be one or more files]:
81 | ```
82 | python src/main.py [SETTINGS] --gpu 0 --test --weights *.pt
83 | ```
84 | The project can be customized either by changing the `configs/base.yaml` file directly or by providing these flags when calling `src/main.py`.
85 |
86 | **Features**
87 | - --add-geom: bool (to add positional features to graph nodes)
88 | - --add-embs: bool (to add textual features to graph nodes)
89 | - --add-hist: bool (to add text histogram features to graph nodes)
90 | - --add-visual: bool (to add visual features to graph nodes)
91 | - --add-eweights: bool (to add polar relative coordinates between nodes to graph edges)
92 |
93 | **Data**
94 | - --src-data: string [FUNSD, PAU or CUSTOM] (CUSTOM still under dev)
95 | - --src-type: string [img, pdf] (if src_data is CUSTOM, still under dev)
96 |
97 | **Graphs**
98 | - --edge-type: string [fully, knn] (to change the kind of connectivity)
99 | - --node-granularity: string [gt, yolo, ocr] (choose the granularity of nodes to be used, gt (if given), ocr (words) or yolo (entities))
100 | - --num-polar-bins: int [Default 8] (number of bins into which to discretize the space for edge polar features. It must be a power of 2)
101 |
102 | **Inference (only for KiE)**
103 | - --inference: bool (run inference on given document/s path/s)
104 | - --docs: list (absolute path(s) to your document(s))
105 |
106 | Edit `configs/train.yaml` directly for training settings or pass these flags to `src/main.py`. To create your own model (changing hyperparameters), copy one of the `configs/models/*.yaml` files, as sketched below.
107 |
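For example, a sketch of that workflow (the file name `custom.yaml` is hypothetical, and this assumes `--model` resolves the YAML in `configs/models/` by its file name):

```
cp configs/models/e2e.yaml configs/models/custom.yaml
# edit hidden_dim, num_layers, dropout, ... in custom.yaml, then point --model at it
python src/main.py --model custom
```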
108 | **Training/Testing**
109 | - --model: string [e2e, edge, gcn] (which model to use, which yaml file to load)
110 | - --gpu: int [Default -1] (which GPU to use. Set -1 to use CPU)
111 | - --test: true / false (skip training if true)
112 | - --weights: string(s) (provide weight file(s) relative path(s), if testing)
113 |
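For instance, a training run on FUNSD with all node and edge features enabled might look like the following (an illustrative flag combination; anything not passed falls back to `configs/base.yaml` and `configs/train.yaml`):

```
python src/main.py -addG -addT -addE -addV --model e2e --src-data FUNSD --edge-type fully --gpu 0
```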
114 | ## Testing
115 |
116 | You can use our pretrained models over the test sets of FUNSD[^1] and Pau Riba's[^2] datasets.
117 |
118 | 1. On FUNSD we were able to perform both Semantic Entity Labeling and Entity Linking:
119 |
120 | **E2E-FUNSD-GT**:
121 | ```
122 | python src/main.py -addG -addT -addE -addV --gpu 0 --test --weights e2e-funsd-best.pt
123 | ```
124 |
125 | **E2E-FUNSD-YOLO**:
126 | ```
127 | python src/main.py -addG -addT -addE -addV --gpu 0 --test --weights e2e-funsd-best.pt --node-granularity yolo
128 | ```
129 |
130 | 2. On Pau Riba's dataset, we were able to perform both Layout Analysis and Table Detection:
131 |
132 | **E2E-PAU**:
133 | ```
134 | python src/main.py -addG -addT -addE -addV --gpu 0 --test --weights e2e-pau-best.pt --src-data PAU --edge-type knn
135 | ```
136 |
137 | ---
138 | ## Cite this project
139 | If you want to use our code in your project(s), please cite us:
140 | ```
141 | @InProceedings{10.1007/978-3-031-25069-9_22,
142 | author="Gemelli, Andrea
143 | and Biswas, Sanket
144 | and Civitelli, Enrico
145 | and Llad{\'o}s, Josep
146 | and Marinai, Simone",
147 | editor="Karlinsky, Leonid
148 | and Michaeli, Tomer
149 | and Nishino, Ko",
150 | title="Doc2Graph: A Task Agnostic Document Understanding Framework Based on Graph Neural Networks",
151 | booktitle="Computer Vision -- ECCV 2022 Workshops",
152 | year="2023",
153 | publisher="Springer Nature Switzerland",
154 | address="Cham",
155 | pages="329--344",
156 | abstract="Geometric Deep Learning has recently attracted significant interest in a wide range of machine learning fields, including document analysis. The application of Graph Neural Networks (GNNs) has become crucial in various document-related tasks since they can unravel important structural patterns, fundamental in key information extraction processes. Previous works in the literature propose task-driven models and do not take into account the full power of graphs. We propose Doc2Graph, a task-agnostic document understanding framework based on a GNN model, to solve different tasks given different types of documents. We evaluated our approach on two challenging datasets for key information extraction in form understanding, invoice layout analysis and table detection. Our code is freely accessible on https://github.com/andreagemelli/doc2graph.",
157 | isbn="978-3-031-25069-9"
158 | }
159 | ```
160 |
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | # THIS YAML SETTING IS USED AS A BASE TO PRODUCE THE 'PREPROCESSING.YAML' ONE.
2 | # IT IS MEANT TO BE USED AS A STARTING POINT AND AS DOCUMENTATION (see all the different choices a user has).
3 |
4 | LOADER:
5 | src_data:
6 | - FUNSD
7 | # - PAU
8 | # - CUSTOM
9 | FEATURES:
10 | add_embs: true
11 | add_eweights: true
12 | add_fudge: false
13 | add_geom: true
14 | add_hist: false
15 | add_visual: true
16 | GRAPHS:
17 | edge_type:
18 | - fully
19 | # - knn
20 | data_type:
21 | - img
22 | # - pdf
23 | node_granularity:
24 | - gt
25 | # - ocr
26 | # - yolo
27 |
--------------------------------------------------------------------------------
/configs/models/e2e.yaml:
--------------------------------------------------------------------------------
1 | name: E2E
2 | dropout: 0.2
3 | hidden_dim: 300
4 | out_chunks: 300
5 | num_layers: 1
6 | doProject: True
--------------------------------------------------------------------------------
/configs/models/edge.yaml:
--------------------------------------------------------------------------------
1 | name: EDGE
2 | dropout: 0.2
3 | num_layers: 1
4 | hidden_dim: 300
5 | out_chunks: 300
6 | doProject: True
--------------------------------------------------------------------------------
/configs/models/gcn.yaml:
--------------------------------------------------------------------------------
1 | name: GCN
2 | dropout: 0.0
3 | hidden_dim: 64
4 | num_layers: 3
5 | attn: False
6 | out_chunks: 100
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 15
2 | epochs: 100000
3 | lr: 1e-3
4 | weight_decay: 1e-4
5 | val_size : 0.1
6 | optimizer: Adam # or AdamW, SGD - NOT YET IMPLEMENTED
7 | scheduler:
8 | - ReduceLROnPlateau # CosineAnnealingLR, None- NOT YET IMPLEMENTED
9 | stopper_metric: acc # loss or acc
10 | seed: 42
11 |
--------------------------------------------------------------------------------
/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/model.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | attrdict==2.0.1
2 | itsdangerous==2.0.1
3 | numpy>=1.22
4 | Pillow>=9.1.1
5 | PyYAML==6.0
6 | scikit_learn==1.0.2
7 | setuptools==65.5.1
8 | setuptools_git_versioning==1.9.2
9 | spacy==3.3.0
10 | pydantic==1.8.2
11 | python-dotenv==0.20.0
12 | gitpython>=3.1.30
13 | gitdb2==3.0.1
14 | matplotlib==3.5.2
15 | segmentation-models-pytorch==0.2.1
16 | fasttext==0.9.2
17 | networkx==2.8.4
18 | tensorboard==2.9.1
19 | opencv-python==4.6.0.66
20 | pytesseract==0.3.9
21 | wget==3.2
22 | easyocr==1.6.2
23 | pandas
24 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | from setuptools_git_versioning import version_from_git
3 | import os
4 |
5 | HERE = os.path.dirname(os.path.abspath(__file__))
6 |
7 | def parse_requirements(file_content):
8 | lines = file_content.splitlines()
9 | return [line.strip() for line in lines if line and not line.startswith("#")]
10 |
11 | with open(os.path.join(HERE, "requirements.txt")) as f:
12 | requirements = parse_requirements(f.read())
13 |
14 | with open("src/root.env", "w") as f:
15 | f.write(f"ROOT = '{HERE}'")
16 |
17 | setup(
18 | name='doc2graph',
19 | version=version_from_git(),
20 | packages=['doc2graph'],
21 | package_dir={'doc2graph': 'src'},
22 | description='Repo to transform Documents to Graphs, performing several tasks on them',
23 | author='Andrea Gemelli',
24 | license='MIT',
25 | keywords="document analysis graph",
26 | setuptools_git_versioning={
27 | "enabled": True,
28 | },
29 | python_requires=">=3.7",
30 | setup_requires=["setuptools_git_versioning"],
31 | install_requires=requirements,
32 | )
33 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from src.training import *
2 | from src.data import *
3 | from src.utils import *
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/src/data/__init__.py
--------------------------------------------------------------------------------
/src/data/dataloader.py:
--------------------------------------------------------------------------------
1 | from random import randint
2 | from typing import Tuple
3 | import torch
4 | import torch.utils.data as data
5 | import os
6 | import numpy as np
7 | import dgl
8 | from PIL import Image, ImageDraw
9 |
10 | from src.data.feature_builder import FeatureBuilder
11 | from src.data.graph_builder import GraphBuilder
12 | from src.utils import get_config
13 |
14 | class Document2Graph(data.Dataset):
15 | """This class convert documents (both images or pdfs) into graph structures.
16 | """
17 |
18 | def __init__(self, name : str, src_path : str, device : str, output_dir : str):
19 | """
20 | Args:
21 | name (str): should be one of the following: ['gt', 'img', 'pdf']
22 | src_path (str): path to folder containing documents
23 | device (str): device to use. can be 'cpu' or 'cuda:n'
24 | output_dir (str): where to save printed graphs examples
25 | """
26 |
27 | # initialize class
28 | if not os.path.isdir(src_path): raise Exception(f'src_path {src_path} does not exist\n -> please provide an existing path')
29 |
30 | self.name = name
31 | self.src_path = src_path
32 | self.cfg_preprocessing = get_config('preprocessing')
33 | self.src_data = self.cfg_preprocessing.LOADER.src_data
34 | self.GB = GraphBuilder()
35 | self.FB = FeatureBuilder(device)
36 | self.output_dir = output_dir
37 |
38 | # TODO: DO A DIFFERENT FILE
39 | self.COLORS = {'invoice_info': (150, 75, 0), 'receiver':(0,100,0), 'other':(128, 128, 128), 'supplier': (255, 0, 255), 'positions':(255,140,0), 'total':(0, 255, 255)}
40 |
41 | # get graphs
42 | self.graphs, self.node_labels, self.edge_labels, self.paths = self.__docs2graphs()
43 |
44 | # LABELS to numeric value
45 | # NODES
46 | if self.node_labels:
47 | self.node_unique_labels = np.unique(np.array([l for nl in self.node_labels for l in nl]))
48 | self.node_num_classes = len(self.node_unique_labels)
49 | self.node_num_features = self.graphs[0].ndata['feat'].shape[1]
50 |
51 | for idx, labels in enumerate(self.node_labels):
52 | self.graphs[idx].ndata['label'] = torch.tensor([np.where(target == self.node_unique_labels)[0][0] for target in labels], dtype=torch.int64)
53 |
54 | # EDGES
55 | if self.edge_labels:
56 | self.edge_unique_labels = np.unique(self.edge_labels[0])
57 | self.edge_num_classes = len(self.edge_unique_labels)
58 | try:
59 | # TODO to be changed
60 | self.edge_num_features = self.graphs[0].edata['feat'].shape[1]
61 | except:
62 | self.edge_num_features = 0
63 |
64 | for idx, labels in enumerate(self.edge_labels):
65 | self.graphs[idx].edata['label'] = torch.tensor([np.where(target == self.edge_unique_labels)[0][0] for target in labels], dtype=torch.int64)
66 |
67 | def __getitem__(self, index: int) -> dgl.DGLGraph:
68 | """ Returns item (graph), given index
69 |
70 | Args:
71 | index (int): index of the item to be taken.
72 | """
73 | return self.graphs[index]
74 |
75 | def __len__(self) -> int:
76 | """ Returns data length
77 | """
78 | return len(self.graphs)
79 |
80 | def __docs2graphs(self) -> Tuple[list, list, list, list]:
81 | """It uses GraphBuilder and FeaturesBuilder objects to get graphs (and lables, if any) from source data.
82 |
83 | Returns:
84 | tuple (lists): DGLGraphs, nodes and edges label names, paths per each file
85 | """
86 | graphs, node_labels, edge_labels, features = self.GB.get_graph(self.src_path, self.src_data)
87 | self.feature_chunks, self.num_mods = self.FB.add_features(graphs, features)
88 | return graphs, node_labels, edge_labels, features['paths']
89 |
90 | def label2class(self, label : str, node=True) -> int:
91 | """ Transform a label (str) into its class number.
92 |
93 | Args:
94 | label (str): node or edge string label
95 | node (bool): whether to take the node (true) or edge (false) class number
96 |
97 | Returns:
98 | int: class number
99 | """
100 | if node:
101 | return self.node_unique_labels[label]
102 | else:
103 | return self.edge_unique_labels[label]
104 |
105 | def get_info(self, num_graph=0) -> None:
106 | """ Print information regarding the data uploaded
107 |
108 | Args:
109 | num_graph (int): give one index to print one example graph information
110 | """
111 | print(f"\n{self.name.upper()} dataset:\n-> graphs: {len(self.graphs)}\n-> node labels: {self.node_unique_labels}\n-> edge labels: {self.edge_unique_labels}\n-> node features: {self.node_num_features}")
112 | self.GB.get_info()
113 | self.FB.get_info()
114 | print(f"-> graph example: {self.graphs[num_graph]}")
115 | return
116 |
117 | def balance(self, cls = 'none', indices = None) -> None:
118 | """ Calls balance_edges() of GraphBuilder.
119 |
120 | Args:
121 | cls (str): name of the edge class to balance (default 'none')
122 | """
123 |
124 | cls = int(np.where(cls == self.edge_unique_labels)[0][0])
125 | if indices is None:
126 | for i, g in enumerate(self.graphs):
127 | self.graphs[i] = self.GB.balance_edges(g, self.edge_num_classes, cls = cls)
128 | else:
129 | for id in indices:
130 | self.graphs[id] = self.GB.balance_edges(self.graphs[id], self.edge_num_classes, cls = cls)
131 |
132 | return
133 |
134 | def get_chunks(self) -> list:
135 | """ get feature_chunks, meaning the length of different modalities (features) contributions inside nodes.
136 |
137 | Returns:
138 | feature_chunks (list) : list of feature chunks
139 | """
140 | if len(self.feature_chunks) != self.num_mods: self.feature_chunks.pop(0)
141 | return self.feature_chunks
142 |
143 | def print_graph(self, num=None, node_labels=None, labels_ids=None, name='doc_graph', bidirect=True, regions=[], preds=None) -> Image:
144 | """ Print a given graph over its image document.
145 |
146 | Args:
147 | num (int): which graph / document to print
148 | node_labels (list): list of node labels
149 | labels_ids (list): list of labels ids to print
150 | name (str): name to give to output file
151 | bidirect (bool): whether to print the graph as bidirected or not
152 | regions (list): for layout analysis debug purposes; if available, they are printed
153 | preds (list): if any, it prints model predictions
154 |
155 | Returns:
156 | graph_img (Image) : the drawn graph over the document
157 | """
158 | if num is None: num = randint(0, self.__len__()-1)
159 | graph = self.graphs[num]
160 | graph_path = self.paths[num]
161 | graph_img = Image.open(graph_path).convert('RGB')
162 | if labels_ids is None: labels_ids = graph.edata['label'].nonzero().flatten().tolist()
163 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2)
164 | graph_draw = ImageDraw.Draw(graph_img)
165 | w, h = graph_img.size
166 | boxs = graph.ndata['geom'][:, :4].tolist()
167 | boxs = [[box[0]*w, box[1]*h, box[2]*w, box[3]*h] for box in boxs]
168 |
169 | if node_labels is not None:
170 | for b, box in enumerate(boxs):
171 | label = self.node_unique_labels[node_labels[b]]
172 | graph_draw.rectangle(box, outline=self.COLORS[label], width=2)
173 | else:
174 | for box in boxs:
175 | graph_draw.rectangle(box, outline='blue', width=2)
176 |
177 | for region in regions:
178 | color = self.COLORS[region[0]]
179 | graph_draw.rectangle(region[1], outline=color, width=4)
180 |
181 | if preds is not None:
182 | graph_draw.rectangle(preds, outline='green', width=4)
183 |
184 | u,v = graph.edges()
185 | for id in labels_ids:
186 | sc = center(boxs[u[id]])
187 | ec = center(boxs[v[id]])
188 | graph_draw.line((sc,ec), fill='violet', width=3)
189 | if bidirect:
190 | graph_draw.ellipse([(sc[0]-4,sc[1]-4), (sc[0]+4,sc[1]+4)], fill = 'green', outline='black')
191 | graph_draw.ellipse([(ec[0]-4,ec[1]-4), (ec[0]+4,ec[1]+4)], fill = 'red', outline='black')
192 |
193 | graph_img.save(self.output_dir / f'{name}.png')
194 | return graph_img
195 |
--------------------------------------------------------------------------------
/src/data/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import requests
4 | import zipfile
5 | import wget
6 | from src.utils import create_folder
7 |
8 | from src.paths import CHECKPOINTS, DATA
9 |
10 | def download_url(url, save_path, chunk_size=128):
11 | r = requests.get(url, stream=True)
12 | with open(save_path, 'wb') as fd:
13 | for chunk in r.iter_content(chunk_size=chunk_size):
14 | fd.write(chunk)
15 |
16 | def funsd():
17 | print("Downloading FUNSD")
18 |
19 | dlz = DATA / 'funsd.zip'
20 | download_url("https://guillaumejaume.github.io/FUNSD/dataset.zip", dlz)
21 | with zipfile.ZipFile(dlz, 'r') as zip_ref:
22 | zip_ref.extractall(DATA)
23 | os.remove(dlz)
24 | os.rename(DATA / 'dataset', DATA / 'FUNSD')
25 |
26 | # adjusted annotations
27 | aa_train = os.path.join(DATA / 'adjusted_annotations.zip')
28 | wget.download(url="https://docs.google.com/uc?export=download&id=1cQE2dnLGh93u3xMUeGQRXM81tF2l0Zut", out=aa_train)
29 | with zipfile.ZipFile(aa_train, 'r') as zip_ref:
30 | zip_ref.extractall(DATA / 'FUNSD/training_data')
31 | os.remove(aa_train)
32 |
33 | aa_test = os.path.join(DATA / 'adjusted_annotations.zip')
34 | wget.download(url="https://docs.google.com/uc?export=download&id=18LXbRhjnkdsAvWBFhUdr7_44bfaBo0-v", out=aa_test)
35 | with zipfile.ZipFile(aa_test, 'r') as zip_ref:
36 | zip_ref.extractall(DATA / 'FUNSD/testing_data')
37 | os.remove(aa_test)
38 |
39 | # yolo_bbox
40 | yolo_train = os.path.join(DATA / 'yolo_bbox.zip')
41 |
42 |
43 | wget.download(url="https://docs.google.com/uc?export=download&id=1UzL5tYtBWDXk_nXj4KtoDMyBt7S3j-aS", out=yolo_train)
44 | with zipfile.ZipFile(yolo_train, 'r') as zip_ref:
45 | zip_ref.extractall(DATA / 'FUNSD/training_data')
46 | os.remove(yolo_train)
47 |
48 | yolo_test = os.path.join(DATA / 'yolo_bbox.zip')
49 | wget.download(url="https://docs.google.com/uc?export=download&id=1fWwbhfvINYFQmoPHpwlH8olyTyJPIe_-", out=yolo_test)
50 | with zipfile.ZipFile(yolo_test, 'r') as zip_ref:
51 | zip_ref.extractall(DATA / 'FUNSD/testing_data')
52 | os.remove(yolo_test)
53 |
54 | return
55 |
56 | def pau():
57 |
58 | print("Downloading PAU")
59 |
60 | PAU = DATA / 'PAU'
61 | spl = os.path.join(DATA, 'pau_split.zip')
62 | wget.download(url="https://docs.google.com/uc?export=download&id=1NKlME13tRIDraSid7r3v_huTqFdHgo-G", out=spl)
63 | with zipfile.ZipFile(spl, 'r') as zip_ref:
64 | zip_ref.extractall(PAU)
65 |
66 | dlz = DATA / 'pau.zip'
67 | download_url("https://zenodo.org/record/3257319/files/dataset.zip", dlz)
68 | with zipfile.ZipFile(dlz, 'r') as zip_ref:
69 | zip_ref.extractall(PAU)
70 |
71 | create_folder(PAU / 'train')
72 | create_folder(PAU / 'test')
73 | create_folder(PAU / 'outliers')
74 |
75 | for split in ['train.txt', 'test.txt', 'outliers.txt']:
76 | with open(PAU / split, mode='r') as f:
77 | folder = split.split(".")[0]
78 | lines = f.readlines()
79 | for l in lines:
80 | l = l.rstrip('\n')
81 | src = PAU / l
82 | dst = PAU / '{}/{}'.format(folder, l)
83 | shutil.move(src, dst)
84 |
85 | os.remove(spl)
86 | os.remove(dlz)
87 | return
88 |
89 | def get_data():
90 | funsd()
91 | pau()
92 | return
93 |
94 |
--------------------------------------------------------------------------------
/src/data/feature_builder.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Tuple
3 | import spacy
4 | import torch
5 | import torchvision
6 | from tqdm import tqdm
7 | from PIL import Image, ImageDraw
8 | import torchvision.transforms.functional as tvF
9 |
10 | from src.paths import CHECKPOINTS
11 | from src.models.unet import Unet
12 | from src.data.utils import to_bin
13 | from src.data.utils import polar, get_histogram
14 | from src.utils import get_config
15 |
16 | class FeatureBuilder():
17 |
18 | def __init__(self, d : str = 'cpu'):
19 | """FeatureBuilder constructor
20 |
21 | Args:
22 | d (str): device to use, if any ('cpu' or 'cuda:n')
23 | """
24 | self.cfg_preprocessing = get_config('preprocessing')
25 | self.device = d
26 | self.add_geom = self.cfg_preprocessing.FEATURES.add_geom
27 | self.add_embs = self.cfg_preprocessing.FEATURES.add_embs
28 | self.add_hist = self.cfg_preprocessing.FEATURES.add_hist
29 | self.add_visual = self.cfg_preprocessing.FEATURES.add_visual
30 | self.add_eweights = self.cfg_preprocessing.FEATURES.add_eweights
31 | self.add_fudge = self.cfg_preprocessing.FEATURES.add_fudge
32 | self.num_polar_bins = self.cfg_preprocessing.FEATURES.num_polar_bins
33 |
34 | if self.add_embs:
35 | self.text_embedder = spacy.load('en_core_web_lg')
36 |
37 | if self.add_visual:
38 | self.visual_embedder = Unet(encoder_name="mobilenet_v2", encoder_weights=None, in_channels=1, classes=4)
39 | self.visual_embedder.load_state_dict(torch.load(CHECKPOINTS / 'backbone_unet.pth')['weights'])
40 | self.visual_embedder = self.visual_embedder.encoder
41 | self.visual_embedder.to(d)
42 |
43 | self.sg = lambda rect, s : [rect[0]/s[0], rect[1]/s[1], rect[2]/s[0], rect[3]/s[1]] # scaling by img width and height
44 |
45 | def add_features(self, graphs : list, features : list) -> Tuple[list, int]:
46 | """ Add features to provided graphs
47 |
48 | Args:
49 | graphs (list) : list of DGLGraphs
50 | features (list) : list of features "sources", like text, positions and images
51 |
52 | Returns:
53 | chunks list and its length
54 | """
55 |
56 | for id, g in enumerate(tqdm(graphs, desc='adding features')):
57 |
58 | # positional features
59 | size = Image.open(features['paths'][id]).size
60 | feats = [[] for _ in range(len(features['boxs'][id]))]
61 | geom = [self.sg(box, size) for box in features['boxs'][id]]
62 | chunks = []
63 |
64 | # 'geometrical' features
65 | if self.add_geom:
66 |
67 | # TODO add 2d encoding like "LayoutLM*"
68 | [feats[idx].extend(self.sg(box, size)) for idx, box in enumerate(features['boxs'][id])]
69 | chunks.append(4)
70 |
71 | # HISTOGRAM OF TEXT
72 | if self.add_hist:
73 |
74 | [feats[idx].extend(hist) for idx, hist in enumerate(get_histogram(features['texts'][id]))]
75 | chunks.append(4)
76 |
77 | # textual features
78 | if self.add_embs:
79 |
80 | # LANGUAGE MODEL (SPACY)
81 | [feats[idx].extend(self.text_embedder(features['texts'][id][idx]).vector) for idx, _ in enumerate(feats)]
82 | chunks.append(len(self.text_embedder(features['texts'][id][0]).vector))
83 |
84 | # visual features
85 | # https://pytorch.org/vision/stable/generated/torchvision.ops.roi_align.html?highlight=roi
86 | if self.add_visual:
87 | img = Image.open(features['paths'][id])
88 | visual_emb = self.visual_embedder(tvF.to_tensor(img).unsqueeze_(0).to(self.device)) # output [batch, channels, dim1, dim2]
89 | bboxs = [torch.Tensor(b) for b in features['boxs'][id]]
90 | bboxs = [torch.stack(bboxs, dim=0).to(self.device)]
91 | h = [torchvision.ops.roi_align(input=ve, boxes=bboxs, spatial_scale=1/ min(size[1] / ve.shape[2] , size[0] / ve.shape[3]), output_size=1) for ve in visual_emb[1:]]
92 | h = torch.cat(h, dim=1)
93 |
94 | # VISUAL FEATURES (U-Net / MobileNetV2 encoder)
95 | [feats[idx].extend(torch.flatten(h[idx]).tolist()) for idx, _ in enumerate(feats)]
96 | chunks.append(len(torch.flatten(h[0]).tolist()))
97 |
98 | if self.add_eweights:
99 | u, v = g.edges()
100 | srcs, dsts = u.tolist(), v.tolist()
101 | distances = []
102 | angles = []
103 |
104 | # TODO CHOOSE WHICH DISTANCE NORMALIZATION TO APPLY
105 | #! with fully connected simply normalized with max distance between distances
106 | # m = sqrt((size[0]*size[0] + size[1]*size[1]))
107 | # parable = lambda x : (-x+1)**4
108 |
109 | for pair in zip(srcs, dsts):
110 | dist, angle = polar(features['boxs'][id][pair[0]], features['boxs'][id][pair[1]])
111 | distances.append(dist)
112 | angles.append(angle)
113 |
114 | m = max(distances)
115 | polar_coordinates = to_bin(distances, angles, self.num_polar_bins)
116 | g.edata['feat'] = polar_coordinates
117 |
118 | else:
119 | distances = ([0.0 for _ in range(g.number_of_edges())])
120 | m = 1
121 |
122 | g.ndata['geom'] = torch.tensor(geom, dtype=torch.float32)
123 | g.ndata['feat'] = torch.tensor(feats, dtype=torch.float32)
124 |
125 | distances = torch.tensor([(1-d/m) for d in distances], dtype=torch.float32)
126 | tresh_dist = torch.where(distances > 0.9, torch.full_like(distances, 0.1), torch.zeros_like(distances))
127 | g.edata['weights'] = tresh_dist
128 |
129 | norm = []
130 | num_nodes = len(features['boxs'][id]) - 1
131 | for n in range(num_nodes + 1):
132 | neigs = torch.count_nonzero(tresh_dist[n*num_nodes:(n+1)*num_nodes]).tolist()
133 | try: norm.append([1. / neigs])
134 | except: norm.append([1.])
135 | g.ndata['norm'] = torch.tensor(norm, dtype=torch.float32)
136 |
137 | #! DEBUG PURPOSES TO VISUALIZE RANDOM GRAPH IMAGE FROM DATASET
138 | if False:
139 | if id == rand_id and self.add_eweights:
140 | print("\n\n### EXAMPLE ###")
141 |
142 | img_path = features['paths'][id]
143 | img = Image.open(img_path).convert('RGB')
144 | draw = ImageDraw.Draw(img)
145 |
146 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2)
147 | select = [random.randint(0, len(srcs)) for _ in range(10)]
148 | for p, pair in enumerate(zip(srcs, dsts)):
149 | if p in select:
150 | sc = center(features['boxs'][id][pair[0]])
151 | ec = center(features['boxs'][id][pair[1]])
152 | draw.line((sc, ec), fill='grey', width=3)
153 | middle_point = ((sc[0] + ec[0])/2,(sc[1] + ec[1])/2)
154 | draw.text(middle_point, str(angles[p]), fill='black')
155 | draw.rectangle(features['boxs'][id][pair[0]], fill='red')
156 | draw.rectangle(features['boxs'][id][pair[1]], fill='blue')
157 |
158 | img.save(f'esempi/FUNSD/edges.png')
159 |
160 | return chunks, len(chunks)
161 |
162 | def get_info(self):
163 | print(f"-> textual feats: {self.add_embs}\n-> visual feats: {self.add_visual}\n-> edge feats: {self.add_eweights}")
164 |
165 |
166 |
--------------------------------------------------------------------------------
/src/data/graph_builder.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from PIL import Image, ImageDraw
4 | from typing import Tuple
5 | import torch
6 | import dgl
7 | import random
8 | import numpy as np
9 | from tqdm import tqdm
10 | import xml.etree.ElementTree as ET
11 | import easyocr
12 |
13 | from src.data.preprocessing import load_predictions
14 | from src.data.utils import polar
15 | from src.paths import DATA, FUNSD_TEST
16 | from src.utils import get_config
17 |
18 |
19 | class GraphBuilder():
20 |
21 | def __init__(self):
22 | self.cfg_preprocessing = get_config('preprocessing')
23 | self.edge_type = self.cfg_preprocessing.GRAPHS.edge_type
24 | self.data_type = self.cfg_preprocessing.GRAPHS.data_type
25 | self.node_granularity = self.cfg_preprocessing.GRAPHS.node_granularity
26 | random.seed(42)
27 | return
28 |
29 | def get_graph(self, src_path : str, src_data : str) -> Tuple[list, list, list, list]:
30 | """ Given the source, it returns a graph
31 |
32 | Args:
33 | src_path (str) : path to source data
34 | src_data (str) : either FUNSD, PAU or CUSTOM
35 |
36 | Returns:
37 | tuple (lists) : graphs, nodes and edge labels, features
38 | """
39 |
40 | if src_data == 'FUNSD':
41 | return self.__fromFUNSD(src_path)
42 | elif src_data == 'PAU':
43 | return self.__fromPAU(src_path)
44 | elif src_data == 'CUSTOM':
45 | if self.data_type == 'img':
46 | return self.__fromIMG(src_path)
47 | elif self.data_type == 'pdf':
48 | return self.__fromPDF()
49 | else:
50 | raise Exception('GraphBuilder exception: data type invalid. Choose from ["img", "pdf"]')
51 | else:
52 | raise Exception('GraphBuilder exception: source data invalid. Choose from ["FUNSD", "PAU", "CUSTOM"]')
53 |
54 | def balance_edges(self, g : dgl.DGLGraph, num_classes : int, cls=None) -> dgl.DGLGraph:
55 | """ if cls (class) is not None, but an integer instead, balance that class to be equal to the sum of the other classes
56 |
57 | Args:
58 | g (DGLGraph) : a DGL graph
59 | cls (int) : class number, if any
60 |
61 | Returns:
62 | g (DGLGraph) : the new balanced graph
63 | """
64 |
65 | edge_targets = g.edata['label']
66 | u, v = g.all_edges(form='uv')
67 | edges_list = list()
68 | for e in zip(u.tolist(), v.tolist()):
69 | edges_list.append([e[0], e[1]])
70 |
71 | if type(cls) is int:
72 | to_remove = (edge_targets == cls)
73 | indices_to_remove = to_remove.nonzero().flatten().tolist()
74 |
75 | for _ in range(int((edge_targets != cls).sum()/2)):
76 | indeces_to_save = [random.choice(indices_to_remove)]
77 | edge = edges_list[indeces_to_save[0]]
78 |
79 | for index in sorted(indeces_to_save, reverse=True):
80 | del indices_to_remove[indices_to_remove.index(index)]
81 |
82 | indices_to_remove = torch.flatten(torch.tensor(indices_to_remove, dtype=torch.int32))
83 | g = dgl.remove_edges(g, indices_to_remove)
84 | return g
85 |
86 | else:
87 | raise Exception("Select a class to balance (an integer ranging from 0 to num_edge_classes).")
88 |
89 | def get_info(self):
90 | """ returns graph information
91 | """
92 | print(f"-> edge type: {self.edge_type}")
93 |
94 | def fully_connected(self, ids : list) -> Tuple[list, list]:
95 | """ create fully connected graph
96 |
97 | Args:
98 | ids (list) : list of node indices
99 |
100 | Returns:
101 | u, v (lists) : lists of indices
102 | """
103 | u, v = list(), list()
104 | for id in ids:
105 | u.extend([id for i in range(len(ids)) if i != id])
106 | v.extend([i for i in range(len(ids)) if i != id])
107 | return u, v
108 |
109 | def knn_connection(self, size : tuple, bboxs : list, k = 10) -> Tuple[list, list]:
110 | """ Given a list of bounding boxes, find for each of them their k nearest ones.
111 |
112 | Args:
113 | size (tuple) : width and height of the image
114 | bboxs (list) : list of bounding box coordinates
115 | k (int) : k of the knn algorithm
116 |
117 | Returns:
118 | u, v (lists) : lists of indices
119 | """
120 |
121 | edges = []
122 | width, height = size[0], size[1]
123 |
124 | # creating projections
125 | vertical_projections = [[] for i in range(width)]
126 | horizontal_projections = [[] for i in range(height)]
127 | for node_index, bbox in enumerate(bboxs):
128 | for hp in range(bbox[0], bbox[2]):
129 | if hp >= width: hp = width - 1
130 | vertical_projections[hp].append(node_index)
131 | for vp in range(bbox[1], bbox[3]):
132 | if vp >= height: vp = height - 1
133 | horizontal_projections[vp].append(node_index)
134 |
135 | def bound(a, ori=''):
136 | if a < 0 : return 0
137 | elif ori == 'h' and a > height: return height
138 | elif ori == 'w' and a > width: return width
139 | else: return a
140 |
141 | for node_index, node_bbox in enumerate(bboxs):
142 | neighbors = [] # collect list of neighbors
143 | window_multiplier = 2 # how much to look around bbox
144 | wider = (node_bbox[2] - node_bbox[0]) > (node_bbox[3] - node_bbox[1]) # if bbox wider than taller
145 |
146 | ### finding neighbors ###
147 | while(len(neighbors) < k and window_multiplier < 100): # keep enlarging the window until at least k bboxs are found or window too big
148 | vertical_bboxs = []
149 | horizontal_bboxs = []
150 | neighbors = []
151 |
152 | if wider:
153 | h_offset = int((node_bbox[2] - node_bbox[0]) * window_multiplier/4)
154 | v_offset = int((node_bbox[3] - node_bbox[1]) * window_multiplier)
155 | else:
156 | h_offset = int((node_bbox[2] - node_bbox[0]) * window_multiplier)
157 | v_offset = int((node_bbox[3] - node_bbox[1]) * window_multiplier/4)
158 |
159 | window = [bound(node_bbox[0] - h_offset),
160 | bound(node_bbox[1] - v_offset),
161 | bound(node_bbox[2] + h_offset, 'w'),
162 | bound(node_bbox[3] + v_offset, 'h')]
163 |
164 | [vertical_bboxs.extend(d) for d in vertical_projections[window[0]:window[2]]]
165 | [horizontal_bboxs.extend(d) for d in horizontal_projections[window[1]:window[3]]]
166 |
167 | for v in set(vertical_bboxs):
168 | for h in set(horizontal_bboxs):
169 | if v == h: neighbors.append(v)
170 |
171 | window_multiplier += 1 # enlarge the window
172 |
173 | ### finding k nearest neighbors ###
174 | neighbors = list(set(neighbors))
175 | if node_index in neighbors:
176 | neighbors.remove(node_index)
177 | neighbors_distances = [polar(node_bbox, bboxs[n])[0] for n in neighbors]
178 | for sd_num, sd_idx in enumerate(np.argsort(neighbors_distances)):
179 | if sd_num < k:
180 | if [node_index, neighbors[sd_idx]] not in edges and [neighbors[sd_idx], node_index] not in edges:
181 | edges.append([neighbors[sd_idx], node_index])
182 | edges.append([node_index, neighbors[sd_idx]])
183 | else: break
184 |
185 | return [e[0] for e in edges], [e[1] for e in edges]
186 |
187 | def __fromIMG(self, paths : list):
188 |
189 | graphs, node_labels, edge_labels = list(), list(), list()
190 | features = {'paths': paths, 'texts': [], 'boxs': []}
191 |
192 | for path in paths:
193 | reader = easyocr.Reader(['en']) #! TODO: in the future, handle multilanguage!
194 | result = reader.readtext(path, paragraph=True)
195 | img = Image.open(path).convert('RGB')
196 | draw = ImageDraw.Draw(img)
197 | boxs, texts = list(), list()
198 |
199 | for r in result:
200 | box = [int(r[0][0][0]), int(r[0][0][1]), int(r[0][2][0]), int(r[0][2][1])]
201 | draw.rectangle(box, outline='red', width=3)
202 | boxs.append(box)
203 | texts.append(r[1])
204 |
205 | features['boxs'].append(boxs)
206 | features['texts'].append(texts)
207 | img.save('prova.png')
208 |
209 | if self.edge_type == 'fully':
210 | u, v = self.fully_connected(range(len(boxs)))
211 | elif self.edge_type == 'knn':
212 | u,v = self.knn_connection(Image.open(path).size, boxs)
213 | else:
214 | raise Exception('Other edge types still under development.')
215 |
216 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(boxs), idtype=torch.int32)
217 | graphs.append(g)
218 |
219 | return graphs, node_labels, edge_labels, features
220 |
221 | def __fromPDF(self):
222 | #TODO: dev from PDF import of graphs
223 | return
224 |
225 | def __fromPAU(self, src: str) -> Tuple[list, list, list, list]:
226 | """ build graphs from Pau Riba's dataset
227 |
228 | Args:
229 | src (str) : path to where data is stored
230 |
231 | Returns:
232 | tuple (lists) : graphs, nodes and edge labels, features
233 | """
234 |
235 | graphs, node_labels, edge_labels = list(), list(), list()
236 | features = {'paths': [], 'texts': [], 'boxs': []}
237 |
238 | for image in tqdm(os.listdir(src), desc='Creating graphs'):
239 | if not image.endswith('tif'): continue
240 |
241 | img_name = image.split('.')[0]
242 | file_gt = img_name + '_gt.xml'
243 | file_ocr = img_name + '_ocr.xml'
244 |
245 | if not os.path.isfile(os.path.join(src, file_gt)) or not os.path.isfile(os.path.join(src, file_ocr)): continue
246 | features['paths'].append(os.path.join(src, image))
247 |
248 | # DOCUMENT REGIONS
249 | root = ET.parse(os.path.join(src, file_gt)).getroot()
250 | regions = []
251 | for parent in root:
252 | if parent.tag.split("}")[1] == 'Page':
253 | for child in parent:
254 | region_label = child[0].attrib['value']
255 | region_bbox = [int(child[1].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]),
256 | int(child[1].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]),
257 | int(child[1].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]),
258 | int(child[1].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])]
259 | regions.append([region_label, region_bbox])
260 |
261 | # DOCUMENT TOKENS
262 | root = ET.parse(os.path.join(src, file_ocr)).getroot()
263 | tokens_bbox = []
264 | tokens_text = []
265 | nl = []
266 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2)
267 | for parent in root:
268 | if parent.tag.split("}")[1] == 'Page':
269 | for child in parent:
270 | if child.tag.split("}")[1] == 'TextRegion':
271 | for elem in child:
272 | if elem.tag.split("}")[1] == 'TextLine':
273 | for word in elem:
274 | if word.tag.split("}")[1] == 'Word':
275 | word_bbox = [int(word[0].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]),
276 | int(word[0].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]),
277 | int(word[0].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]),
278 | int(word[0].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])]
279 | word_text = word[1][0].text
280 | c = center(word_bbox)
281 | for reg in regions:
282 | r = reg[1]
283 | if r[0] < c[0] < r[2] and r[1] < c[1] < r[3]:
284 | word_label = reg[0]
285 | break
286 | tokens_bbox.append(word_bbox)
287 | tokens_text.append(word_text)
288 | nl.append(word_label)
289 |
290 | features['boxs'].append(tokens_bbox)
291 | features['texts'].append(tokens_text)
292 | node_labels.append(nl)
293 |
294 | # getting edges
295 | if self.edge_type == 'fully':
296 | u, v = self.fully_connected(range(len(tokens_bbox)))
297 | elif self.edge_type == 'knn':
298 | u,v = self.knn_connection(Image.open(os.path.join(src, image)).size, tokens_bbox)
299 | else:
300 | raise Exception('Other edge types still under development.')
301 |
302 | el = list()
303 | for e in zip(u, v):
304 | if (nl[e[0]] == nl[e[1]]) and (nl[e[0]] == 'positions' or nl[e[0]] == 'total'):
305 | el.append('table')
306 | else: el.append('none')
307 | edge_labels.append(el)
308 |
309 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(tokens_bbox), idtype=torch.int32)
310 | graphs.append(g)
311 |
312 | return graphs, node_labels, edge_labels, features
313 |
314 | def __fromFUNSD(self, src : str) -> Tuple[list, list, list, list]:
315 | """Parsing FUNSD annotation files
316 |
317 | Args:
318 | src (str) : path to where data is stored
319 |
320 | Returns:
321 | tuple (lists) : graphs, nodes and edge labels, features
322 | """
323 |
324 | graphs, node_labels, edge_labels = list(), list(), list()
325 | features = {'paths': [], 'texts': [], 'boxs': []}
326 | # justOne = random.choice(os.listdir(os.path.join(src, 'adjusted_annotations'))).split(".")[0]
327 |
328 | if self.node_granularity == 'gt':
329 | for file in tqdm(os.listdir(os.path.join(src, 'adjusted_annotations')), desc='Creating graphs - GT'):
330 |
331 | img_name = f'{file.split(".")[0]}.png'
332 | img_path = os.path.join(src, 'images', img_name)
333 | features['paths'].append(img_path)
334 |
335 | with open(os.path.join(src, 'adjusted_annotations', file), 'r') as f:
336 | form = json.load(f)['form']
337 |
338 | # getting infos
339 | boxs, texts, ids, nl = list(), list(), list(), list()
340 | pair_labels = list()
341 |
342 | for elem in form:
343 | boxs.append(elem['box'])
344 | texts.append(elem['text'])
345 | nl.append(elem['label'])
346 | ids.append(elem['id'])
347 | [pair_labels.append(pair) for pair in elem['linking']]
348 |
349 | for p, pair in enumerate(pair_labels):
350 | pair_labels[p] = [ids.index(pair[0]), ids.index(pair[1])]
351 |
352 | node_labels.append(nl)
353 | features['texts'].append(texts)
354 | features['boxs'].append(boxs)
355 |
356 | # getting edges
357 | if self.edge_type == 'fully':
358 | u, v = self.fully_connected(range(len(boxs)))
359 | elif self.edge_type == 'knn':
360 | u,v = self.knn_connection(Image.open(img_path).size, boxs)
361 | else:
362 | raise Exception('GraphBuilder exception: Other edge types still under development.')
363 |
364 | el = list()
365 | for e in zip(u, v):
366 | edge = [e[0], e[1]]
367 | if edge in pair_labels: el.append('pair')
368 | else: el.append('none')
369 | edge_labels.append(el)
370 |
371 | # creating graph
372 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(boxs), idtype=torch.int32)
373 | graphs.append(g)
374 |
375 | #! DEBUG PURPOSES TO VISUALIZE RANDOM GRAPH IMAGE FROM DATASET
376 | if False:
377 | if justOne == file.split(".")[0]:
378 | print("\n\n### EXAMPLE ###")
379 | print("Savin example:", img_name)
380 |
381 | edge_unique_labels = np.unique(el)
382 | g.edata['label'] = torch.tensor([np.where(target == edge_unique_labels)[0][0] for target in el])
383 |
384 | g = self.balance_edges(g, 3, int(np.where('none' == edge_unique_labels)[0][0]))
385 |
386 | img_removed = Image.open(img_path).convert('RGB')
387 | draw_removed = ImageDraw.Draw(img_removed)
388 |
389 | for b, box in enumerate(boxs):
390 | if nl[b] == 'header':
391 | color = 'yellow'
392 | elif nl[b] == 'question':
393 | color = 'blue'
394 | elif nl[b] == 'answer':
395 | color = 'green'
396 | else:
397 | color = 'gray'
398 | draw_removed.rectangle(box, outline=color, width=3)
399 |
400 | u, v = g.all_edges()
401 | labels = g.edata['label'].tolist()
402 | u, v = u.tolist(), v.tolist()
403 |
404 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2)
405 |
406 | num_pair = 0
407 | num_none = 0
408 |
409 | for p, pair in enumerate(zip(u,v)):
410 | sc = center(boxs[pair[0]])
411 | ec = center(boxs[pair[1]])
412 | if labels[p] == int(np.where('pair' == edge_unique_labels)[0][0]):
413 | num_pair += 1
414 | color = 'violet'
415 | draw_removed.ellipse([(sc[0]-4,sc[1]-4), (sc[0]+4,sc[1]+4)], fill = 'green', outline='black')
416 | draw_removed.ellipse([(ec[0]-4,ec[1]-4), (ec[0]+4,ec[1]+4)], fill = 'red', outline='black')
417 | else:
418 | num_none += 1
419 | color='gray'
420 | draw_removed.line((sc,ec), fill=color, width=3)
421 |
422 | print("Balanced Links: None {} | Key-Value {}".format(num_none, num_pair))
423 | img_removed.save(f'esempi/FUNSD/{img_name}_removed_graph.png')
424 |
425 | elif self.node_granularity == 'yolo':
426 | path_preds = os.path.join(src, 'yolo_bbox')
427 | path_images = os.path.join(src, 'images')
428 | path_gts = os.path.join(src, 'adjusted_annotations')
429 | all_paths, all_preds, all_links, all_labels, all_texts = load_predictions(path_preds, path_gts, path_images)
430 | for f, img_path in enumerate(tqdm(all_paths, desc='Creating graphs - YOLO')):
431 |
432 | features['paths'].append(img_path)
433 | features['boxs'].append(all_preds[f])
434 | features['texts'].append(all_texts[f])
435 | node_labels.append(all_labels[f])
436 | pair_labels = all_links[f]
437 |
438 | # getting edges
439 | if self.edge_type == 'fully':
440 | u, v = self.fully_connected(range(len(features['boxs'][f])))
441 | elif self.edge_type == 'knn':
442 | u,v = self.knn_connection(Image.open(img_path).size, features['boxs'][f])
443 | else:
444 | raise Exception('GraphBuilder exception: Other edge types still under development.')
445 |
446 | el = list()
447 | for e in zip(u, v):
448 | edge = [e[0], e[1]]
449 | if edge in pair_labels: el.append('pair')
450 | else: el.append('none')
451 | edge_labels.append(el)
452 |
453 | # creating graph
454 | g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(features['boxs'][f]), idtype=torch.int32)
455 | graphs.append(g)
456 | else:
457 | #TODO develop OCR too
458 | raise Exception('GraphBuilder Exception: only \'gt\' or \'yolo\' available for FUNSD.')
459 |
460 |
461 | return graphs, node_labels, edge_labels, features
--------------------------------------------------------------------------------
/src/data/preprocessing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | from scipy.optimize import linprog
5 | import os
6 | from PIL import ImageDraw, Image
7 | import json
8 | import pytesseract
9 | from pytesseract import Output
10 |
11 | from src.paths import DATA, FUNSD_TEST
12 |
13 |
14 | def scale_back(r, w, h): return [int(r[0]*w),
15 | int(r[1]*h), int(r[2]*w), int(r[3]*h)]
16 |
17 |
18 | def center(r): return ((r[0] + r[2]) / 2, (r[1] + r[3]) / 2)
19 |
20 |
21 | def isIn(c, r):
22 | if c[0] < r[0] or c[0] > r[2]:
23 | return False
24 | elif c[1] < r[1] or c[1] > r[3]:
25 | return False
26 | else:
27 | return True
28 |
29 |
30 | def match_pred_w_gt(bbox_preds: torch.Tensor, bbox_gts: torch.Tensor, links_pair: list):
31 | bbox_iou = torchvision.ops.box_iou(boxes1=bbox_preds, boxes2=bbox_gts)
32 | bbox_iou = bbox_iou.numpy()
33 |
34 | A_ub = np.zeros(shape=(
35 | bbox_iou.shape[0] + bbox_iou.shape[1], bbox_iou.shape[0] * bbox_iou.shape[1]))
36 | for r in range(bbox_iou.shape[0]):
37 | st = r * bbox_iou.shape[1]
38 | A_ub[r, st:st + bbox_iou.shape[1]] = 1
39 | for j in range(bbox_iou.shape[1]):
40 | r = j + bbox_iou.shape[0]
41 | A_ub[r, j::bbox_iou.shape[1]] = 1
42 | b_ub = np.ones(shape=A_ub.shape[0])
43 |
44 | assignaments_score = linprog(
45 | c=-bbox_iou.reshape(-1), A_ub=A_ub, b_ub=b_ub, bounds=(0, 1), method="highs-ds")
46 | #assignaments_score = linprog(c=-bbox_iou.reshape(-1), bounds=(0, 1), method="highs-ds")
47 | # print(assignaments_score)
48 | if not assignaments_score.success:
49 | print("Optimization FAILED")
50 | assignaments_score = assignaments_score.x.reshape(bbox_iou.shape)
51 | assignaments_ids = assignaments_score.argmax(axis=1)
52 |
53 | # matched
54 | opt_assignaments = {}
55 | for idx in range(assignaments_score.shape[0]):
56 | if (bbox_iou[idx, assignaments_ids[idx]] > 0.5) and (assignaments_score[idx, assignaments_ids[idx]] > 0.9):
57 | opt_assignaments[idx] = assignaments_ids[idx]
58 | # unmatched predictions
59 | false_positive = [idx for idx in range(
60 | bbox_preds.shape[0]) if idx not in opt_assignaments]
61 | # unmatched gts
62 | false_negative = [idx for idx in range(
63 | bbox_gts.shape[0]) if idx not in opt_assignaments.values()]
64 |
65 | gt2pred = {v: k for k, v in opt_assignaments.items()}
66 | link_false_neg = []
67 | for link in links_pair:
68 | if link[0] in false_negative or link[1] in false_negative:
69 | link_false_neg.append(link)
70 |
71 | if len(links_pair) != 0:
72 | rate = len(link_false_neg) / len(links_pair)
73 | else:
74 | rate = 0
75 | return {"pred2gt": opt_assignaments, "gt2pred": gt2pred, "false_positive": false_positive, "false_negative": false_negative, "n_link_fn": int(len(link_false_neg) / 2), "link_loss": rate, "entity_loss": len(false_positive) / (len(false_positive) + len(opt_assignaments.keys()))}
76 |
77 |
78 | def get_objects(path, mode):
79 | # TODO given a document, apply OCR or Yolo to detect either words or entities.
80 | return
81 |
82 |
83 | def load_predictions(path_preds, path_gts, path_images, debug=False):
84 | # TODO read txt file and pass bounding box to the other function.
85 |
86 | boxs_preds = []
87 | boxs_gts = []
88 | links_gts = []
89 | labels_gts = []
90 | texts_ocr = []
91 | all_paths = []
92 |
93 | for img in os.listdir(path_images):
94 | all_paths.append(os.path.join(path_images, img))
95 | w, h = Image.open(os.path.join(path_images, img)).size
96 | texts = pytesseract.image_to_data(Image.open(
97 | os.path.join(path_images, img)), output_type=Output.DICT)
98 | tp = []
99 | n_elements = len(texts['level'])
100 | for t in range(n_elements):
101 | if int(texts['conf'][t]) > 50 and texts['text'][t] != ' ':
102 | b = [texts['left'][t], texts['top'][t], texts['left'][t] +
103 | texts['width'][t], texts['top'][t] + texts['height'][t]]
104 | tp.append([b, texts['text'][t]])
105 | texts_ocr.append(tp)
106 | preds_name = img.split(".")[0] + '.txt'
107 | with open(os.path.join(path_preds, preds_name), 'r') as preds:
108 | lines = preds.readlines()
109 | boxs = list()
110 | for line in lines:
111 | scaled = scale_back([float(c)
112 | for c in line[:-1].split(" ")[1:]], w, h)
113 | sw, sh = scaled[2] / 2, scaled[3] / 2
114 | boxs.append([scaled[0] - sw, scaled[1] - sh,
115 | scaled[0] + sw, scaled[1] + sh])
116 | boxs_preds.append(boxs)
117 |
118 | gts_name = img.split(".")[0] + '.json'
119 | with open(os.path.join(path_gts, gts_name), 'r') as f:
120 | form = json.load(f)['form']
121 | boxs = list()
122 | pair_labels = []
123 | ids = []
124 | labels = []
125 | for elem in form:
126 | boxs.append([float(e) for e in elem['box']])
127 | ids.append(elem['id'])
128 | labels.append(elem['label'])
129 | [pair_labels.append(pair) for pair in elem['linking']]
130 |
131 | for p, pair in enumerate(pair_labels):
132 | pair_labels[p] = [ids.index(pair[0]), ids.index(pair[1])]
133 |
134 | boxs_gts.append(boxs)
135 | links_gts.append(pair_labels)
136 | labels_gts.append(labels)
137 |
138 | all_links = []
139 | all_preds = []
140 | all_labels = []
141 | all_texts = []
142 | dropped_links = 0
143 | dropped_entity = 0
144 |
145 | for p in range(len(boxs_preds)):
146 | d = match_pred_w_gt(torch.tensor(
147 | boxs_preds[p]), torch.tensor(boxs_gts[p]), links_gts[p])
148 | dropped_links += d['link_loss']
149 | dropped_entity += d['entity_loss']
150 | links = list()
151 |
152 | for link in links_gts[p]:
153 | if link[0] in d['false_negative'] or link[1] in d['false_negative']:
154 | continue
155 | else:
156 | links.append([d['gt2pred'][link[0]], d['gt2pred'][link[1]]])
157 | all_links.append(links)
158 |
159 | preds = []
160 | labels = []
161 | texts = []
162 | for b, box in enumerate(boxs_preds[p]):
163 | if b in d['false_positive']:
164 | preds.append(box)
165 | labels.append('other')
166 | else:
167 | gt_id = d['pred2gt'][b]
168 | preds.append(box)
169 | labels.append(labels_gts[p][gt_id])
170 |
171 | text = ''
172 | for tocr in texts_ocr[p]:
173 | if isIn(center(tocr[0]), box):
174 | text += tocr[1] + ' '
175 |
176 | texts.append(text)
177 |
178 | all_preds.append(preds)
179 | all_labels.append(labels)
180 | all_texts.append(texts)
181 | print(dropped_links / len(boxs_preds), dropped_entity / len(boxs_preds))
182 |
183 | if debug:
184 | # random.seed(35)
185 | # rand_idx = random.randint(0, len(os.listdir(path_images)))
186 | print(all_texts[0])
187 | rand_idx = 0
188 | img = Image.open(os.path.join(path_images, os.listdir(
189 | path_images)[rand_idx])).convert('RGB')
190 | draw = ImageDraw.Draw(img)
191 |
192 | rand_boxs_preds = boxs_preds[rand_idx]
193 | rand_boxs_gts = boxs_gts[rand_idx]
194 |
195 | for box in rand_boxs_gts:
196 | draw.rectangle(box, outline='blue', width=3)
197 | for box in rand_boxs_preds:
198 | draw.rectangle(box, outline='red', width=3)
199 |
200 | d = match_pred_w_gt(torch.tensor(rand_boxs_preds),
201 | torch.tensor(rand_boxs_gts), links_gts[rand_idx])
202 | print(d)
203 | for idx in d['pred2gt'].keys():
204 | draw.rectangle(rand_boxs_preds[idx], outline='green', width=3)
205 |
206 | link_true_pos = list()
207 | link_false_neg = list()
208 | for link in links_gts[rand_idx]:
209 | if link[0] in d['false_negative'] or link[1] in d['false_negative']:
210 | link_false_neg.append(link)
211 | start = rand_boxs_gts[link[0]]
212 | end = rand_boxs_gts[link[1]]
213 | draw.line((center(start), center(end)), fill='red', width=3)
214 | else:
215 | link_true_pos.append(link)
216 | start = rand_boxs_preds[d['gt2pred'][link[0]]]
217 | end = rand_boxs_preds[d['gt2pred'][link[1]]]
218 | draw.line((center(start), center(end)), fill='green', width=3)
219 |
220 | precision = 0
221 | recall = 0
222 | for idx, gt in enumerate(boxs_gts):
223 | d = match_pred_w_gt(torch.tensor(
224 |             boxs_preds[idx]), torch.tensor(gt), links_gts[idx])
225 | bbox_true_positive = len(d["pred2gt"])
226 | p = bbox_true_positive / \
227 | (bbox_true_positive + len(d["false_positive"]))
228 | r = bbox_true_positive / \
229 | (bbox_true_positive + len(d["false_negative"]))
230 | # f1 += (2 * p * r) / (p + r)
231 | precision += p
232 | recall += r
233 |
234 | precision = precision / len(boxs_gts)
235 | recall = recall / len(boxs_gts)
236 | f1 = (2 * precision * recall) / (precision + recall)
237 | # print(f1, precision, recall)
238 |
239 | img.save('prova.png')
240 |
241 | return all_paths, all_preds, all_links, all_labels, all_texts
242 |
243 |
244 | def save_results():
245 | # TODO output json of matching and check with visualization of images.
246 | return
247 |
248 |
249 | if __name__ == "__main__":
250 | path_preds = DATA / 'FUNSD' / 'test_bbox'
251 | path_images = FUNSD_TEST / 'images'
252 | path_gts = FUNSD_TEST / 'adjusted_annotations'
253 | load_predictions(path_preds, path_gts, path_images, debug=True)
254 |
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from typing import Tuple
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import math
7 |
8 | def polar(rect_src : list, rect_dst : list) -> Tuple[int, int]:
9 |     """Compute distance and angle from src to dst bounding boxes (polar coordinates, considering src as the center)
10 | Args:
11 | rect_src (list) : source rectangle coordinates
12 | rect_dst (list) : destination rectangle coordinates
13 |
14 | Returns:
15 | tuple (ints): distance and angle
16 | """
17 |
18 | # check relative position
19 | left = (rect_dst[2] - rect_src[0]) <= 0
20 | bottom = (rect_src[3] - rect_dst[1]) <= 0
21 | right = (rect_src[2] - rect_dst[0]) <= 0
22 | top = (rect_dst[3] - rect_src[1]) <= 0
23 |
24 | vp_intersect = (rect_src[0] <= rect_dst[2] and rect_dst[0] <= rect_src[2]) # True if two rects "see" each other vertically, above or under
25 | hp_intersect = (rect_src[1] <= rect_dst[3] and rect_dst[1] <= rect_src[3]) # True if two rects "see" each other horizontally, right or left
26 | rect_intersect = vp_intersect and hp_intersect
27 |
28 | center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2)
29 |
30 | # evaluate reciprocal position
31 | sc = center(rect_src)
32 | ec = center(rect_dst)
33 | new_ec = (ec[0] - sc[0], ec[1] - sc[1])
34 | angle = int(math.degrees(math.atan2(new_ec[1], new_ec[0])) % 360)
35 |
36 | if rect_intersect:
37 | return 0, angle
38 | elif top and left:
39 | a, b = (rect_dst[2] - rect_src[0]), (rect_dst[3] - rect_src[1])
40 | return int(sqrt(a**2 + b**2)), angle
41 | elif left and bottom:
42 | a, b = (rect_dst[2] - rect_src[0]), (rect_dst[1] - rect_src[3])
43 | return int(sqrt(a**2 + b**2)), angle
44 | elif bottom and right:
45 | a, b = (rect_dst[0] - rect_src[2]), (rect_dst[1] - rect_src[3])
46 | return int(sqrt(a**2 + b**2)), angle
47 | elif right and top:
48 | a, b = (rect_dst[0] - rect_src[2]), (rect_dst[3] - rect_src[1])
49 | return int(sqrt(a**2 + b**2)), angle
50 | elif left:
51 | return (rect_src[0] - rect_dst[2]), angle
52 | elif right:
53 | return (rect_dst[0] - rect_src[2]), angle
54 | elif bottom:
55 | return (rect_dst[1] - rect_src[3]), angle
56 | elif top:
57 | return (rect_src[1] - rect_dst[3]), angle
58 |
59 | def transform_image(img_path : str, scale_image=1.0) -> torch.Tensor:
60 | """ Transform image to torch.Tensor
61 |
62 | Args:
63 | img_path (str) : where the image is stored
64 |         scale_image (float) : how much to scale the image
65 | """
66 |
67 | np_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
68 | width = int(np_img.shape[1] * scale_image)
69 | height = int(np_img.shape[0] * scale_image)
70 | new_size = (width, height)
71 | np_img = cv2.resize(np_img,new_size)
72 | img = cv2.cvtColor(np_img, cv2.COLOR_BGR2GRAY)
73 | img = img[None,None,:,:]
74 | img = img.astype(np.float32)
75 | img = torch.from_numpy(img)
76 | img = 1.0 - img / 128.0
77 |
78 | return img
79 |
80 | def get_histogram(contents : list) -> list:
81 | """Create histogram of content given a text.
82 |
83 |     Args:
84 | contents (list)
85 |
86 | Returns:
87 |         list of [x, y, z, w] - 4-dimensional list with float values summing up to 1 where:
88 |         - x is the % of letters inside the text
89 |         - y is the % of digits inside the text
90 |         - z is the % of other symbols (e.g. @, #, .) and w is 1.0 only if the text is empty or nothing is recognized
91 | """
92 |
93 | c_histograms = list()
94 |
95 | for token in contents:
96 | num_symbols = 0 # all
97 | num_literals = 0 # A, B etc.
98 | num_figures = 0 # 1, 2, etc.
99 | num_others = 0 # !, @, etc.
100 |
101 | histogram = [0.0000, 0.0000, 0.0000, 0.0000]
102 |
103 | for symbol in token.replace(" ", ""):
104 | if symbol.isalpha():
105 | num_literals += 1
106 | elif symbol.isdigit():
107 | num_figures += 1
108 | else:
109 | num_others += 1
110 | num_symbols += 1
111 |
112 | if num_symbols != 0:
113 | histogram[0] = num_literals / num_symbols
114 | histogram[1] = num_figures / num_symbols
115 | histogram[2] = num_others / num_symbols
116 |
117 | # keep sum 1 after truncate
118 | if sum(histogram) != 1.0:
119 | diff = 1.0 - sum(histogram)
120 | m = max(histogram) + diff
121 | histogram[histogram.index(max(histogram))] = m
122 |
123 | # if symbols not recognized at all or empty, sum everything at 1 in the last
124 | if histogram[0:3] == [0.0,0.0,0.0]:
125 | histogram[3] = 1.0
126 |
127 | c_histograms.append(histogram)
128 |
129 | return c_histograms
130 |
131 | def to_bin(dist :int, angle : int, b=8) -> torch.Tensor:
132 |     """ Discretize the space into equal "bins": map each distance and angle to a binary-encoded bin index.
133 |
134 | Args:
135 | dist (int): distance in terms of pixel, given by "polar()" util function
136 | angle (int): angle between 0 and 360, given by "polar()" util function
137 | b (int): number of bins, MUST be power of 2
138 |
139 | Returns:
140 | torch.Tensor: new distance and angle (binary encoded)
141 |
142 | """
143 | def isPowerOfTwo(x):
144 | return (x and (not(x & (x - 1))) )
145 |
146 | # dist
147 | assert isPowerOfTwo(b)
148 | m = max(dist) / b
149 | new_dist = []
150 | for d in dist:
151 | bin = int(d / m)
152 | if bin >= b: bin = b - 1
153 | bin = [int(x) for x in list('{0:0b}'.format(bin))]
154 |         while len(bin) < math.log2(b): bin.insert(0, 0)
155 | new_dist.append(bin)
156 |
157 | # angle
158 | amplitude = 360 / b
159 | new_angle = []
160 | for a in angle:
161 | bin = (a - amplitude / 2)
162 | bin = int(bin / amplitude)
163 | bin = [int(x) for x in list('{0:0b}'.format(bin))]
164 |         while len(bin) < math.log2(b): bin.insert(0, 0)
165 | new_angle.append(bin)
166 |
167 | return torch.cat([torch.tensor(new_dist, dtype=torch.float32), torch.tensor(new_angle, dtype=torch.float32)], dim=1)
168 |
169 |
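
A minimal usage sketch of the helpers above (the boxes and text below are made-up values, not taken from any dataset):

    import torch
    from src.data.utils import polar, to_bin, get_histogram

    # two hypothetical word boxes in [x0, y0, x1, y1] pixel coordinates
    key_box   = [100, 100, 200, 130]
    value_box = [260, 100, 420, 130]

    dist, angle = polar(key_box, value_box)          # -> (60, 0): dst lies to the right of src
    edge_feat = to_bin([dist], [angle], b=8)         # -> tensor of shape [1, 6]: 3 distance bits + 3 angle bits
    node_feat = get_histogram(["INVOICE N. 1234"])   # -> approx. [[0.62, 0.31, 0.08, 0.0]] letters/digits/others/none
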
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | import os
4 | from PIL import Image, ImageDraw
5 | import json
6 |
7 | from src.data.feature_builder import FeatureBuilder
8 | from src.data.graph_builder import GraphBuilder
9 | from src.data.preprocessing import center
10 | from src.models.graphs import SetModel
11 | from src.paths import CHECKPOINTS, INFERENCE
12 | from src.training.utils import get_device
13 | from src.utils import create_folder
14 |
15 | pretrain = {
16 | 'funsd': {'node_num_classes': 4, 'edge_num_classes': 2},
17 | 'pau': {'node_num_classes': 5, 'edge_num_classes': 2}
18 | }
19 |
20 | def inference(weights, paths, device=-1):
21 | # create a graph per each file
22 | print("Doc2Graph Inference:")
23 | device = get_device(device)
24 |
25 | print("-> Creating graphs ...")
26 | gb = GraphBuilder()
27 | graphs, _, _, features = gb.get_graph(paths, 'CUSTOM')
28 |
29 | # add embedded visual, text, layout etc. features to the graphs
30 | print("-> Creating features ...")
31 | fb = FeatureBuilder(d=device)
32 | chunks, _ = fb.add_features(graphs, features)
33 |
34 | # create the model
35 | print("-> Creating model ...")
36 | model = weights[0].split("-")[0]
37 | pre = weights[0].split("-")[1]
38 | sm = SetModel(name=model, device=device)
39 | info = pretrain[pre]
40 | model = sm.get_model(info['node_num_classes'], info['edge_num_classes'], chunks, False)
41 | model.load_state_dict(torch.load(CHECKPOINTS / weights[0]))
42 | model.eval()
43 |
44 | # predict on graphs
45 | print("Predicting:")
46 | with torch.no_grad():
47 | for num, graph in enumerate(graphs):
48 | _, name = os.path.split(paths[num])
49 | name = name.split(".")[0]
50 | print(f" -> {name}")
51 |
52 | n, e = model(graph.to(device), graph.ndata['feat'].to(device))
53 | _, epreds = torch.max(F.softmax(e, dim=1), dim=1)
54 | _, npreds = torch.max(F.softmax(n, dim=1), dim=1)
55 |
56 | # save results
57 | links = (epreds == 1).nonzero(as_tuple=True)[0].tolist()
58 | u, v = graph.edges()
59 | entities = features['boxs'][num]
60 | contents = features['texts'][num]
61 |
62 | graph_img = Image.open(paths[num]).convert('RGB')
63 | graph_draw = ImageDraw.Draw(graph_img)
64 |
65 | result = []
66 | for i, idx in enumerate(links):
67 | pair = {'key': {'text': contents[u[idx]], 'box': entities[u[idx]]},
68 | 'value': {'text': contents[v[idx]], 'box': entities[v[idx]]}}
69 | result.append(pair)
70 |
71 | key_center = center(entities[u[idx]])
72 | value_center = center(entities[v[idx]])
73 | graph_draw.line((key_center, value_center), fill='violet', width=3)
74 | graph_draw.ellipse([(key_center[0]-4,key_center[1]-4), (key_center[0]+4,key_center[1]+4)], fill = 'green', outline='black')
75 | graph_draw.ellipse([(value_center[0]-4,value_center[1]-4), (value_center[0]+4,value_center[1]+4)], fill = 'red', outline='black')
76 |
77 | graph_img.save(INFERENCE / f'{name}.png')
78 | with open(INFERENCE / f'{name}.json', "w") as outfile:
79 | json.dump(result, outfile)
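
A minimal sketch of calling this entry point directly (the checkpoint name below is hypothetical; it only has to follow the "<model>-<pretrain>-..." naming convention parsed above and exist under src/models/checkpoints):

    from src.inference import inference

    # run on CPU (device=-1) over one document image
    inference(weights=['e2e-funsd-weights.pt'],
              paths=['DATA/my_form.png'],
              device=-1)

Results are written to the inference folder as <name>.json (key/value pairs with boxes and texts) and <name>.png (predicted links drawn in violet).
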
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from src.data.download import get_data
4 | from src.inference import inference
5 | from src.training.funsd import train_funsd
6 | from src.utils import create_folder, project_tree, set_preprocessing
7 | from src.training.pau import train_pau
8 |
9 | def main():
10 | parser = argparse.ArgumentParser(description='Training')
11 |
12 | # init
13 | parser.add_argument('--init', action="store_true",
14 | help="download data and prepare folders")
15 |
16 | # features
17 | parser.add_argument('--add-geom', '-addG', action="store_true",
18 | help="add geometrical features to nodes")
19 | parser.add_argument('--add-embs', '-addT', action="store_true",
20 | help="add textual embeddings to nodes")
21 | parser.add_argument('--add-hist', '-addH', action="store_true",
22 | help="add histogram of contents to nodes")
23 | parser.add_argument('--add-visual', '-addV', action="store_true",
24 | help="add visual features to nodes")
25 | parser.add_argument('--add-eweights', '-addE', action="store_true",
26 | help="add edge features to graphs")
27 | # data
28 | parser.add_argument("--src-data", type=str, default='FUNSD',
29 | help="which data source to use. It can be FUNSD, PAU or CUSTOM")
30 | parser.add_argument("--data-type", type=str, default='img',
31 | help="if src-data is CUSTOM, define the data source type: img or pdf.")
32 | # graphs
33 | parser.add_argument("--edge-type", type=str, default='fully',
34 | help="choose the kind of connectivity in the graph. It can be: fully or knn.")
35 | parser.add_argument("--node-granularity", type=str, default='gt',
36 | help="choose the granularity of nodes to be used. It can be: gt (if given), ocr (words) or yolo (entities).")
37 | parser.add_argument("--num-polar-bins", type=int, default=8,
38 | help="number of bins into which discretize the space for edge polar features. It must be a power of 2: Default 8.")
39 |
40 | # training
41 | parser.add_argument("--model", type=str, default='e2e',
42 | help="which model to use, which yaml file to load: e2e, edge or gcn")
43 | parser.add_argument("--gpu", type=int, default=-1,
44 | help="which GPU to use. Set -1 to use CPU.")
45 | parser.add_argument('--test', action="store_true",
46 | help="skip training")
47 | parser.add_argument('--weights', '-w', nargs='+', type=str, default=None,
48 | help="provide a weights file relative path if testing")
49 |
50 | # inference
51 | parser.add_argument('--inference', action="store_true",
52 | help="use the model to predict on new, unseen examples")
53 | parser.add_argument('--docs', nargs='+', type=str, default=None,
54 | help="provide documents to do inference on them")
55 |
56 | args = parser.parse_args()
57 | print(args)
58 |
59 | if args.init:
60 | project_tree()
61 | get_data()
62 | print("Initialization completed!")
63 |
64 | else:
65 | set_preprocessing(args)
66 | if args.inference:
67 | create_folder('inference')
68 | inference(args.weights, args.docs, args.gpu)
69 | elif args.src_data == 'FUNSD':
70 |             if args.test and args.weights is None:
71 | raise Exception("Main exception: Provide a weights file relative path! Or train a model first.")
72 | train_funsd(args)
73 | elif args.src_data == 'PAU':
74 | train_pau(args)
75 | elif args.src_data == 'CUSTOM':
76 | #TODO develop custom data preprocessing
77 | raise Exception('Main exception: "CUSTOM" source data still under development')
78 | else:
79 | raise Exception('Main exception: source data invalid. Choose from ["FUNSD", "PAU", "CUSTOM"]')
80 |
81 | return
82 |
83 | if __name__ == '__main__':
84 | main()
85 |
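
One possible invocation, sketched in Python by emulating the CLI defined above (the flag combination is illustrative; run with --init once beforehand to download the data and create the project folders):

    import sys
    from src.main import main

    # equivalent to: python -m src.main --add-geom --add-embs --add-visual --add-eweights --src-data FUNSD --model e2e
    sys.argv = ['main', '--add-geom', '--add-embs', '--add-visual', '--add-eweights',
                '--src-data', 'FUNSD', '--model', 'e2e']
    main()
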
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/graphs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import dgl.function as fn
4 | import math
5 | import torch.nn.functional as F
6 |
7 | from src.paths import CFGM
8 | from src.utils import get_config
9 |
10 | class SetModel():
11 | def __init__(self, name='e2e', device = 'cpu'):
12 |         """ Create a SetModel object that dynamically handles different versions of the Doc2Graph model. Default "end-to-end" (e2e)
13 |
14 | Args:
15 | name (str) : Which model to train / test. Default: e2e [gcn, edge].
16 |
17 | Returns:
18 | SetModel object.
19 | """
20 |
21 | self.cfg_model = get_config(CFGM / name)
22 | self.name = self.cfg_model.name
23 | self.total_params = 0
24 | self.device = device
25 |
26 | def get_name(self) -> str:
27 | """ Returns model name.
28 | """
29 | return self.name
30 |
31 | def get_total_params(self) -> int:
32 |         """ Returns the number of model parameters.
33 | """
34 | return self.total_params
35 |
36 | def get_model(self, nodes : int, edges : int, chunks : list, verbatim : bool = True) -> nn.Module:
37 | """Return the DGL model defined in the setting file
38 |
39 | Args:
40 | nodes (int) : number of nodes target class
41 | edges (int) : number of edges target class
42 |             chunks (list) : list of input feature chunk sizes
43 |
44 | Returns:
45 | A PyTorch nn.Module, your DGL model.
46 | """
47 | print("\n### MODEL ###")
48 | print(f"-> Using {self.name}")
49 |
50 | if self.name == 'GCN':
51 | m = NodeClassifier(chunks, self.cfg_model.out_chunks, nodes, self.cfg_model.num_layers, F.relu, False, self.device)
52 |
53 | elif self.name == 'EDGE':
54 | m = EdgeClassifier(edges, self.cfg_model.num_layers, self.cfg_model.dropout, chunks, self.cfg_model.out_chunks, self.cfg_model.hidden_dim, self.device, self.cfg_model.doProject)
55 |
56 | elif self.name == 'E2E':
57 | edge_pred_features = int((math.log2(get_config('preprocessing').FEATURES.num_polar_bins) + nodes)*2)
58 | m = E2E(nodes, edges, self.cfg_model.num_layers, self.cfg_model.dropout, chunks, self.cfg_model.out_chunks, self.cfg_model.hidden_dim, self.device, edge_pred_features, self.cfg_model.doProject)
59 |
60 | else:
61 |             raise Exception(f"Error! Model {self.name} does not exist.")
62 |
63 | m.to(self.device)
64 | self.total_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
65 | print(f"-> Total params: {self.total_params}")
66 | print("-> Device: " + str(next(m.parameters()).is_cuda) + "\n")
67 | if verbatim: print(m)
68 |
69 | return m
70 |
71 | ################
72 | ##### GCNS #####
73 |
74 | class NodeClassifier(nn.Module):
75 | def __init__(self,
76 | in_chunks,
77 | out_chunks,
78 | n_classes,
79 | n_layers,
80 | activation,
81 | dropout=0,
82 | use_pp=False,
83 | device='cuda:0'):
84 | super(NodeClassifier, self).__init__()
85 |
86 | self.projector = InputProjector(in_chunks, out_chunks, device)
87 | self.layers = nn.ModuleList()
88 | # self.dropout = nn.Dropout(dropout)
89 | self.n_layers = n_layers
90 |
91 | n_hidden = self.projector.get_out_lenght()
92 |
93 | # mp layers
94 | for i in range(0, n_layers - 1):
95 | self.layers.append(GcnSAGELayer(n_hidden, n_hidden, activation=activation,
96 | dropout=dropout, use_pp=False, use_lynorm=True))
97 |
98 | self.layers.append(GcnSAGELayer(n_hidden, n_classes, activation=None,
99 | dropout=False, use_pp=False, use_lynorm=False))
100 |
101 | def forward(self, g, h):
102 |
103 | h = self.projector(h)
104 |
105 | for l in range(self.n_layers):
106 | h = self.layers[l](g, h)
107 |
108 | return h
109 |
110 | ################
111 | ##### EDGE #####
112 |
113 | class EdgeClassifier(nn.Module):
114 |
115 | def __init__(self, edge_classes, m_layers, dropout, in_chunks, out_chunks, hidden_dim, device, doProject=True):
116 | super().__init__()
117 |
118 | # Project inputs into higher space
119 | self.projector = InputProjector(in_chunks, out_chunks, device, doProject)
120 |
121 | # Perform message passing
122 | m_hidden = self.projector.get_out_lenght()
123 | self.message_passing = nn.ModuleList()
124 | self.m_layers = m_layers
125 | for l in range(m_layers):
126 | self.message_passing.append(GcnSAGELayer(m_hidden, m_hidden, F.relu, 0.))
127 |
128 |         # Define edge predictor layer
129 | self.edge_pred = MLPPredictor(m_hidden, hidden_dim, edge_classes, dropout)
130 |
131 | def forward(self, g, h):
132 |
133 | h = self.projector(h)
134 |
135 | for l in range(self.m_layers):
136 | h = self.message_passing[l](g, h)
137 |
138 | e = self.edge_pred(g, h)
139 |
140 | return e
141 |
142 | ################
143 | ###### E2E #####
144 |
145 | class E2E(nn.Module):
146 | def __init__(self, node_classes,
147 | edge_classes,
148 | m_layers,
149 | dropout,
150 | in_chunks,
151 | out_chunks,
152 | hidden_dim,
153 | device,
154 | edge_pred_features,
155 | doProject=True):
156 |
157 | super().__init__()
158 |
159 | # Project inputs into higher space
160 | self.projector = InputProjector(in_chunks, out_chunks, device, doProject)
161 |
162 | # Perform message passing
163 | m_hidden = self.projector.get_out_lenght()
164 | self.message_passing = nn.ModuleList()
165 | # self.m_layers = m_layers
166 | # for l in range(m_layers):
167 | # self.message_passing.append(GcnSAGELayer(m_hidden, m_hidden, F.relu, 0.))
168 | self.message_passing = GcnSAGELayer(m_hidden, m_hidden, F.relu, 0.)
169 |
170 | # Define edge predictor layer
171 | self.edge_pred = MLPPredictor_E2E(m_hidden, hidden_dim, edge_classes, dropout, edge_pred_features)
172 |
173 | # Define node predictor layer
174 | node_pred = []
175 | node_pred.append(nn.Linear(m_hidden, node_classes))
176 | node_pred.append(nn.LayerNorm(node_classes))
177 | self.node_pred = nn.Sequential(*node_pred)
178 |
179 | def forward(self, g, h):
180 |
181 | h = self.projector(h)
182 | # for l in range(self.m_layers):
183 | # h = self.message_passing[l](g, h)
184 | h = self.message_passing(g,h)
185 | n = self.node_pred(h)
186 | e = self.edge_pred(g, h, n)
187 |
188 | return n, e
189 |
190 | ################
191 | ##### LYRS #####
192 |
193 | class GcnSAGELayer(nn.Module):
194 | def __init__(self,
195 | in_feats,
196 | out_feats,
197 | activation,
198 | dropout,
199 | bias=True,
200 | use_pp=False,
201 | use_lynorm=True):
202 | super(GcnSAGELayer, self).__init__()
203 | self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias)
204 | self.activation = activation
205 | self.use_pp = use_pp
206 | if dropout:
207 | self.dropout = nn.Dropout(p=dropout)
208 | else:
209 | self.dropout = 0.
210 | if use_lynorm:
211 | self.lynorm = nn.LayerNorm(out_feats, elementwise_affine=True)
212 | else:
213 | self.lynorm = lambda x: x
214 | self.reset_parameters()
215 |
216 | def reset_parameters(self):
217 | stdv = 1. / math.sqrt(self.linear.weight.size(1))
218 | self.linear.weight.data.uniform_(-stdv, stdv)
219 | if self.linear.bias is not None:
220 | self.linear.bias.data.uniform_(-stdv, stdv)
221 |
222 | def forward(self, g, h):
223 | g = g.local_var()
224 |
225 | if not self.use_pp:
226 | # norm = self.get_norm(g)
227 | norm = g.ndata['norm']
228 | g.ndata['h'] = h
229 | g.update_all(fn.u_mul_e('h', 'weights', 'm'),
230 | fn.sum(msg='m', out='h'))
231 | ah = g.ndata.pop('h')
232 | h = self.concat(h, ah, norm)
233 |
234 | if self.dropout:
235 | h = self.dropout(h)
236 |
237 | h = self.linear(h)
238 | h = self.lynorm(h)
239 | if self.activation:
240 | h = self.activation(h)
241 | return h
242 |
243 | def concat(self, h, ah, norm):
244 | ah = ah * norm
245 | h = torch.cat((h, ah), dim=1)
246 | return h
247 |
248 | def get_norm(self, g):
249 | norm = 1. / g.in_degrees().float().unsqueeze(1)
250 | norm[torch.isinf(norm)] = 0
251 | norm = norm.to(self.linear.weight.device)
252 | return norm
253 |
254 | class InputProjector(nn.Module):
255 | def __init__(self, in_chunks : list, out_chunks : int, device, doIt = True) -> None:
256 | super().__init__()
257 |
258 | if not doIt:
259 | self.output_length = sum(in_chunks)
260 | self.doIt = doIt
261 | return
262 |
263 | self.output_length = len(in_chunks)*out_chunks
264 | self.doIt = doIt
265 | self.chunks = in_chunks
266 | modules = []
267 | self.device = device
268 |
269 | for chunk in in_chunks:
270 | chunk_module = []
271 | chunk_module.append(nn.Linear(chunk, out_chunks))
272 | chunk_module.append(nn.LayerNorm(out_chunks))
273 | chunk_module.append(nn.ReLU())
274 | modules.append(nn.Sequential(*chunk_module))
275 |
276 | self.modalities = nn.Sequential(*modules)
277 | self.chunks.insert(0, 0)
278 |
279 | def get_out_lenght(self):
280 | return self.output_length
281 |
282 | def forward(self, h):
283 |
284 | if not self.doIt:
285 | return h
286 |
287 | mid = []
288 |
289 | for name, module in self.modalities.named_children():
290 | num = int(name)
291 | if num + 1 == len(self.chunks): break
292 | start = self.chunks[num] + sum(self.chunks[:num])
293 | end = start + self.chunks[num+1]
294 | input = h[:, start:end].to(self.device)
295 | mid.append(module(input))
296 |
297 | return torch.cat(mid, dim=1)
298 |
299 | class MLPPredictor(nn.Module):
300 | def __init__(self, in_features, hidden_dim, out_classes, dropout):
301 | super().__init__()
302 | self.out = out_classes
303 | self.W1 = nn.Linear(in_features*2, hidden_dim)
304 | self.norm = nn.LayerNorm(hidden_dim)
305 | self.W2 = nn.Linear(hidden_dim + 6, out_classes)
306 | self.drop = nn.Dropout(dropout)
307 |
308 | def apply_edges(self, edges):
309 | h_u = edges.src['h']
310 | h_v = edges.dst['h']
311 | polar = edges.data['feat']
312 |
313 | x = F.relu(self.norm(self.W1(torch.cat((h_u, h_v), dim=1))))
314 | x = torch.cat((x, polar), dim=1)
315 | score = self.drop(self.W2(x))
316 |
317 | return {'score': score}
318 |
319 | def forward(self, graph, h):
320 | # h contains the node representations computed from the GNN defined
321 | # in the node classification section (Section 5.1).
322 | with graph.local_scope():
323 | graph.ndata['h'] = h
324 | graph.apply_edges(self.apply_edges)
325 | return graph.edata['score']
326 |
327 | class MLPPredictor_E2E(nn.Module):
328 | def __init__(self, in_features, hidden_dim, out_classes, dropout, edge_pred_features):
329 | super().__init__()
330 | self.out = out_classes
331 | self.W1 = nn.Linear(in_features*2 + edge_pred_features, hidden_dim)
332 | self.norm = nn.LayerNorm(hidden_dim)
333 | self.W2 = nn.Linear(hidden_dim, out_classes)
334 | self.drop = nn.Dropout(dropout)
335 |
336 | def apply_edges(self, edges):
337 | h_u = edges.src['h']
338 | h_v = edges.dst['h']
339 | cls_u = F.softmax(edges.src['cls'], dim=1)
340 | cls_v = F.softmax(edges.dst['cls'], dim=1)
341 | polar = edges.data['feat']
342 |
343 | x = F.relu(self.norm(self.W1(torch.cat((h_u, cls_u, polar, h_v, cls_v), dim=1))))
344 | score = self.drop(self.W2(x))
345 |
346 | return {'score': score}
347 |
348 | def forward(self, graph, h, cls):
349 | # h contains the node representations computed from the GNN defined
350 | # in the node classification section (Section 5.1).
351 | with graph.local_scope():
352 | graph.ndata['h'] = h
353 | graph.ndata['cls'] = cls
354 | graph.apply_edges(self.apply_edges)
355 | return graph.edata['score']
356 |
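
A small self-contained sketch of the E2E edge predictor on a toy DGL graph (all sizes are illustrative; with 4 node classes and 8 polar bins, edge_pred_features = (log2(8) + 4) * 2 = 14, matching the computation in SetModel.get_model):

    import dgl
    import torch
    from src.models.graphs import MLPPredictor_E2E

    g = dgl.graph(([0, 1], [1, 2]))                  # 3 nodes, 2 directed edges
    h = torch.randn(3, 16)                           # projected node features (m_hidden = 16)
    cls = torch.randn(3, 4)                          # raw node class scores (4 classes)
    g.edata['feat'] = torch.randn(2, 6)              # binary-encoded polar features (3 + 3 bits for b = 8)

    pred = MLPPredictor_E2E(in_features=16, hidden_dim=32, out_classes=2,
                            dropout=0.2, edge_pred_features=14)
    scores = pred(g, h, cls)                         # -> edge scores of shape [2, 2]
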
--------------------------------------------------------------------------------
/src/models/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Unet
--------------------------------------------------------------------------------
/src/models/unet/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from segmentation_models_pytorch.base import modules as md
6 |
7 |
8 | class DecoderBlock(nn.Module):
9 | def __init__(
10 | self,
11 | in_channels,
12 | skip_channels,
13 | out_channels,
14 | use_batchnorm=True,
15 | attention_type=None,
16 | ):
17 | super().__init__()
18 | self.conv1 = md.Conv2dReLU(
19 | in_channels + skip_channels,
20 | out_channels,
21 | kernel_size=3,
22 | padding=1,
23 | use_batchnorm=use_batchnorm,
24 | )
25 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
26 | self.conv2 = md.Conv2dReLU(
27 | out_channels,
28 | out_channels,
29 | kernel_size=3,
30 | padding=1,
31 | use_batchnorm=use_batchnorm,
32 | )
33 | self.attention2 = md.Attention(attention_type, in_channels=out_channels)
34 |
35 | def forward(self, x, skip=None):
36 | if skip is not None:
37 | x = F.interpolate(x, size=skip.shape[2:], mode="nearest")
38 | x = torch.cat([x, skip], dim=1)
39 | x = self.attention1(x)
40 | else:
41 | x = F.interpolate(x, scale_factor=2, mode="nearest")
42 | x = self.conv1(x)
43 | x = self.conv2(x)
44 | x = self.attention2(x)
45 | return x
46 |
47 |
48 | class CenterBlock(nn.Sequential):
49 | def __init__(self, in_channels, out_channels, use_batchnorm=True):
50 | conv1 = md.Conv2dReLU(
51 | in_channels,
52 | out_channels,
53 | kernel_size=3,
54 | padding=1,
55 | use_batchnorm=use_batchnorm,
56 | )
57 | conv2 = md.Conv2dReLU(
58 | out_channels,
59 | out_channels,
60 | kernel_size=3,
61 | padding=1,
62 | use_batchnorm=use_batchnorm,
63 | )
64 | super().__init__(conv1, conv2)
65 |
66 |
67 | class UnetDecoder(nn.Module):
68 | def __init__(
69 | self,
70 | encoder_channels,
71 | decoder_channels,
72 | n_blocks=5,
73 | use_batchnorm=True,
74 | attention_type=None,
75 | center=False,
76 | ):
77 | super().__init__()
78 |
79 | if n_blocks != len(decoder_channels):
80 | raise ValueError(
81 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
82 | n_blocks, len(decoder_channels)
83 | )
84 | )
85 |
86 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
87 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
88 |
89 | # computing blocks input and output channels
90 | head_channels = encoder_channels[0]
91 | in_channels = [head_channels] + list(decoder_channels[:-1])
92 | skip_channels = list(encoder_channels[1:]) + [0]
93 | out_channels = decoder_channels
94 |
95 | if center:
96 | self.center = CenterBlock(
97 | head_channels, head_channels, use_batchnorm=use_batchnorm
98 | )
99 | else:
100 | self.center = nn.Identity()
101 |
102 | # combine decoder keyword arguments
103 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
104 | blocks = [
105 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
106 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
107 | ]
108 | self.blocks = nn.ModuleList(blocks)
109 |
110 | def forward(self, *features):
111 |
112 | features = features[1:] # remove first skip with same spatial resolution
113 | features = features[::-1] # reverse channels to start from head of encoder
114 |
115 | head = features[0]
116 | skips = features[1:]
117 |
118 | x = self.center(head)
119 | for i, decoder_block in enumerate(self.blocks):
120 | skip = skips[i] if i < len(skips) else None
121 | x = decoder_block(x, skip)
122 |
123 | return x
124 |
--------------------------------------------------------------------------------
/src/models/unet/model.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, List
2 | from .decoder import UnetDecoder
3 | from segmentation_models_pytorch.encoders import get_encoder
4 | from segmentation_models_pytorch.base import SegmentationModel
5 | from segmentation_models_pytorch.base import SegmentationHead, ClassificationHead
6 |
7 |
8 | class Unet(SegmentationModel):
9 |     """Unet_ is a fully convolutional neural network for image semantic segmentation. It consists of *encoder*
10 |     and *decoder* parts connected with *skip connections*. The encoder extracts features of different spatial
11 |     resolution (skip connections) which are used by the decoder to build an accurate segmentation mask. It uses *concatenation*
12 |     for fusing decoder blocks with skip connections.
13 |
14 | Args:
15 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
16 | to extract features of different spatial resolution
17 |         encoder_depth: Number of stages used in the encoder, in range [3, 5]. Each stage generates features
18 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
19 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
20 | Default is 5
21 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
22 | other pretrained weights (see table with available weights for each encoder_name)
23 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
24 | Length of the list should be the same as **encoder_depth**
25 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
26 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
27 | Available options are **True, False, "inplace"**
28 | decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse**.
29 | SCSE paper - https://arxiv.org/abs/1808.08127
30 | in_channels: A number of input channels for the model, default is 3 (RGB images)
31 | classes: A number of classes for output mask (or you can think as a number of channels of output mask)
32 | activation: An activation function to apply after the final convolution layer.
33 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
34 | Default is **None**
35 |         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
36 | on top of encoder if **aux_params** is not **None** (default). Supported params:
37 | - classes (int): A number of classes
38 | - pooling (str): One of "max", "avg". Default is "avg"
39 | - dropout (float): Dropout factor in [0, 1)
40 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits)
41 |
42 | Returns:
43 | ``torch.nn.Module``: Unet
44 |
45 | .. _Unet:
46 | https://arxiv.org/abs/1505.04597
47 |
48 | """
49 |
50 | def __init__(
51 | self,
52 | encoder_name: str = "resnet34",
53 | encoder_depth: int = 5,
54 | encoder_weights: Optional[str] = "imagenet",
55 | decoder_use_batchnorm: bool = True,
56 | decoder_channels: List[int] = (256, 128, 64, 32, 16),
57 | decoder_attention_type: Optional[str] = None,
58 | in_channels: int = 3,
59 | classes: int = 1,
60 | activation: Optional[Union[str, callable]] = None,
61 | aux_params: Optional[dict] = None,
62 | ):
63 | super().__init__()
64 |
65 | self.encoder = get_encoder(
66 | encoder_name,
67 | in_channels=in_channels,
68 | depth=encoder_depth,
69 | weights=encoder_weights,
70 | )
71 |
72 | self.decoder = UnetDecoder(
73 | encoder_channels=self.encoder.out_channels,
74 | decoder_channels=decoder_channels,
75 | n_blocks=encoder_depth,
76 | use_batchnorm=decoder_use_batchnorm,
77 | center=True if encoder_name.startswith("vgg") else False,
78 | attention_type=decoder_attention_type,
79 | )
80 |
81 | self.segmentation_head = SegmentationHead(
82 | in_channels=decoder_channels[-1],
83 | out_channels=classes,
84 | activation=activation,
85 | kernel_size=3,
86 | )
87 |
88 | if aux_params is not None:
89 | self.classification_head = ClassificationHead(
90 | in_channels=self.encoder.out_channels[-1], **aux_params
91 | )
92 | else:
93 | self.classification_head = None
94 |
95 | self.name = "u-{}".format(encoder_name)
96 | self.initialize()
97 |
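
A minimal sketch of instantiating this Unet variant on its own (assuming segmentation_models_pytorch is installed; channel and class counts are illustrative):

    import torch
    from src.models.unet import Unet

    model = Unet(encoder_name="resnet34", encoder_weights=None, in_channels=1, classes=1)
    x = torch.randn(1, 1, 256, 256)                  # e.g. a grayscale document crop (H, W divisible by 32)
    mask = model(x)                                  # -> tensor of shape [1, 1, 256, 256]
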
--------------------------------------------------------------------------------
/src/paths.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from dotenv import dotenv_values
3 | import os
4 |
5 | # ROOT
6 | HERE = Path(os.path.dirname(os.path.abspath(__file__)))
7 | config = dotenv_values(HERE / "root.env")
8 | ROOT = Path(config['ROOT'])
9 |
10 | # PROJECT TREE
11 | DATA = ROOT / 'DATA'
12 | CONFIGS = ROOT / 'configs'
13 | CFGM = CONFIGS / 'models'
14 | OUTPUTS = ROOT / 'outputs'
15 | RUNS = OUTPUTS / 'runs'
16 | RESULTS = OUTPUTS / 'results'
17 | IMGS = OUTPUTS / 'images'
18 | TRAIN_SAMPLES = OUTPUTS / 'train_samples'
19 | TEST_SAMPLES = OUTPUTS / 'test_samples'
20 | TRAINING = ROOT / 'src' / 'training'
21 | MODELS = ROOT / 'src' / 'models'
22 | CHECKPOINTS = MODELS / 'checkpoints'
23 | INFERENCE = ROOT / 'inference'
24 |
25 | # FUNSD
26 | FUNSD_TRAIN = DATA / 'FUNSD' / 'training_data'
27 | FUNSD_TEST = DATA / 'FUNSD' / 'testing_data'
28 |
29 | # PAU
30 | PAU_TRAIN = DATA / 'PAU' / 'train'
31 | PAU_TEST = DATA / 'PAU' / 'test'
32 |
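
Note that this module expects a root.env file next to paths.py defining the project root. A minimal sketch (the path below is illustrative):

    ROOT=/home/user/doc2graph
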
--------------------------------------------------------------------------------
/src/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreagemelli/doc2graph/99ac9e694b450b54b998d8a43063f0da942989b2/src/training/__init__.py
--------------------------------------------------------------------------------
/src/training/funsd.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from sklearn.model_selection import KFold, ShuffleSplit
3 | import torch
4 | from torch.nn import functional as F
5 | from random import shuffle, seed
6 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
7 | import dgl
8 | from torch.utils.tensorboard import SummaryWriter
9 | from torchvision import transforms
10 | import time
11 | from statistics import mean
12 | import numpy as np
13 | from PIL import Image
14 |
15 | from src.data.dataloader import Document2Graph
16 | from src.paths import *
17 | from src.models.graphs import SetModel
18 | from src.utils import get_config
19 | from src.training.utils import *
20 | from src.data.graph_builder import GraphBuilder
21 |
22 | def e2e(args):
23 |
24 | # configs
25 | start_training = time.time()
26 | cfg_train = get_config('train')
27 | seed(cfg_train.seed)
28 | device = get_device(args.gpu)
29 | sm = SetModel(name=args.model, device=device)
30 |
31 | if not args.test:
32 | ################* STEP 0: LOAD DATA ################
33 | data = Document2Graph(name='FUNSD TRAIN', src_path=FUNSD_TRAIN, device = device, output_dir=TRAIN_SAMPLES)
34 | data.get_info()
35 |
36 | ss = KFold(n_splits=10, shuffle=True, random_state=cfg_train.seed)
37 | cv_indices = ss.split(data.graphs)
38 |
39 | models = []
40 | train_index, val_index = next(ss.split(data.graphs))
41 |
42 | for cvs in cv_indices:
43 |
44 | train_index, val_index = cvs
45 |
46 | # TRAIN
47 | train_graphs = [data.graphs[i] for i in train_index]
48 | tg = dgl.batch(train_graphs)
49 | tg = tg.int().to(device)
50 |
51 | val_graphs = [data.graphs[i] for i in val_index]
52 | vg = dgl.batch(val_graphs)
53 | vg = vg.int().to(device)
54 |
55 | ################* STEP 1: CREATE MODEL ################
56 | model = sm.get_model(data.node_num_classes, data.edge_num_classes, data.get_chunks())
57 | optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg_train.lr), weight_decay=float(cfg_train.weight_decay))
58 | # scheduler = ReduceLROnPlateau(optimizer, 'max', patience=400, min_lr=1e-3, verbose=True, factor=0.01)
59 | # scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
60 | e = datetime.now()
61 | train_name = args.model + f'-{e.strftime("%Y%m%d-%H%M")}'
62 | models.append(train_name+'.pt')
63 | stopper = EarlyStopping(model, name=train_name, metric=cfg_train.stopper_metric, patience=2000)
64 | # writer = SummaryWriter(log_dir=RUNS)
65 | # convert_imgs = transforms.ToTensor()
66 |
67 | ################* STEP 2: TRAINING ################
68 | print("\n### TRAINING ###")
69 | print(f"-> Training samples: {tg.batch_size}")
70 | print(f"-> Validation samples: {vg.batch_size}\n")
71 |
72 | # im_step = 0
73 | for epoch in range(cfg_train.epochs):
74 |
75 | #* TRAINING
76 | model.train()
77 |
78 | n_scores, e_scores = model(tg, tg.ndata['feat'].to(device))
79 | n_loss = compute_crossentropy_loss(n_scores.to(device), tg.ndata['label'].to(device))
80 | e_loss = compute_crossentropy_loss(e_scores.to(device), tg.edata['label'].to(device))
81 | tot_loss = n_loss + e_loss
82 | macro, micro = get_f1(n_scores, tg.ndata['label'].to(device))
83 | auc = compute_auc_mc(e_scores.to(device), tg.edata['label'].to(device))
84 |
85 |
86 | optimizer.zero_grad()
87 | tot_loss.backward()
88 | n = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
89 | optimizer.step()
90 |
91 | #* VALIDATION
92 | model.eval()
93 | with torch.no_grad():
94 | val_n_scores, val_e_scores = model(vg, vg.ndata['feat'].to(device))
95 | val_n_loss = compute_crossentropy_loss(val_n_scores.to(device), vg.ndata['label'].to(device))
96 | val_e_loss = compute_crossentropy_loss(val_e_scores.to(device), vg.edata['label'].to(device))
97 | val_tot_loss = val_n_loss + val_e_loss
98 | val_macro, _ = get_f1(val_n_scores, vg.ndata['label'].to(device))
99 | val_auc = compute_auc_mc(val_e_scores.to(device), vg.edata['label'].to(device))
100 |
101 | # scheduler.step(val_auc)
102 | # scheduler.step()
103 |
104 | #* PRINTING IMAGEs AND RESULTS
105 |
106 | print("Epoch {:05d} | TrainLoss {:.4f} | TrainF1-MACRO {:.4f} | TrainAUC-PR {:.4f} | ValLoss {:.4f} | ValF1-MACRO {:.4f} | ValAUC-PR {:.4f} |"
107 | .format(epoch, tot_loss.item(), macro, auc, val_tot_loss.item(), val_macro, val_auc))
108 |
109 | if cfg_train.stopper_metric == 'loss':
110 | step_value = val_tot_loss.item()
111 | elif cfg_train.stopper_metric == 'acc':
112 | step_value = val_auc
113 |
114 | # if val_auc > best_val_auc:
115 | # best_val_auc = val_auc
116 | # best_model = train_name
117 |
118 | ss = stopper.step(step_value)
119 |
120 | # if ss == 'improved':
121 | # im_step = epoch
122 | # train_imgs = []
123 | # for r in rand_tid:
124 | # start, end = 0, 0
125 | # for tid in train_index:
126 | # start = end
127 | # end += data.graphs[tid].num_edges()
128 | # if tid == r: break
129 |
130 | # _, targets = torch.max(F.log_softmax(e_scores[start:end], dim=1), dim=1)
131 | # kvp_ids = targets.nonzero().flatten().tolist()
132 | # train_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=kvp_ids, name=f'train_{r}'))[:, :, :700])
133 | # # data.print_graph(num=r, name=f'train_labels_{r}')
134 |
135 | # val_imgs = []
136 | # for r in rand_vid:
137 | # v_start, v_end = 0, 0
138 | # for vid in val_index:
139 | # v_start = v_end
140 | # v_end += data.graphs[vid].num_edges()
141 | # if vid == r: break
142 |
143 | # _, val_targets = torch.max(F.log_softmax(val_e_scores[v_start:v_end], dim=1), dim=1)
144 | # val_kvp_ids = val_targets.nonzero().flatten().tolist()
145 | # val_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=val_kvp_ids, name=f'val_{r}'))[:, :, :700])
146 | # data.print_graph(num=r, name=f'val_labels_{r}')
147 |
148 | if ss == 'stop':
149 | break
150 |
151 | # writer.add_scalars('AUC-PR', {'train': auc, 'val': val_auc}, epoch)
152 | # writer.add_scalars('LOSS', {'train': tot_loss.item(), 'val': val_tot_loss.item()}, epoch)
153 | # writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)
154 |
155 | # train_grid = torchvision.utils.make_grid(train_imgs)
156 | # writer.add_image('train_images', train_grid, im_step)
157 | # val_grid = torchvision.utils.make_grid(val_imgs)
158 | # writer.add_image('val_images', val_grid, im_step)
159 |
160 | else:
161 | ################* SKIP TRAINING ################
162 | print("\n### SKIP TRAINING ###")
163 | print(f"-> loading {args.weights}")
164 | models = args.weights
165 |
166 | ################* STEP 3: TESTING ################
167 | print("\n### TESTING ###")
168 |
169 | #? test
170 | test_data = Document2Graph(name='FUNSD TEST', src_path=FUNSD_TEST, device = device, output_dir=TEST_SAMPLES)
171 | test_data.get_info()
172 |
173 | model = sm.get_model(test_data.node_num_classes, test_data.edge_num_classes, test_data.get_chunks())
174 | best_model = ''
175 | nodes_micro = []
176 | edges_f1 = []
177 | test_graph = dgl.batch(test_data.graphs).to(device)
178 |
179 | for m in models:
180 | model.load_state_dict(torch.load(CHECKPOINTS / m))
181 | model.eval()
182 | with torch.no_grad():
183 |
184 | n, e = model(test_graph, test_graph.ndata['feat'].to(device))
185 | auc = compute_auc_mc(e.to(device), test_graph.edata['label'].to(device))
186 | _, preds = torch.max(F.softmax(e, dim=1), dim=1)
187 |
188 | accuracy, f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'])
189 | _, classes_f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'], per_class=True)
190 | edges_f1.append(classes_f1[1])
191 |
192 | macro, micro = get_f1(n, test_graph.ndata['label'].to(device))
193 | nodes_micro.append(micro)
194 | if classes_f1[1] >= max(edges_f1):
195 | best_model = m
196 |
197 | test_graph.edata['preds'] = preds
198 |
199 | ################* STEP 4: RESULTS ################
200 | print("\n### RESULTS {} ###".format(m))
201 | print("F1 Edges: None {:.4f} - Pairs {:.4f}".format(classes_f1[0], classes_f1[1]))
202 | print("F1 Nodes: Macro {:.4f} - Micro {:.4f}".format(macro, micro))
203 |
204 | print(f"\n -> Loading best model {best_model}")
205 | model.load_state_dict(torch.load(CHECKPOINTS / best_model))
206 | model.eval()
207 | with torch.no_grad():
208 |
209 | n, e = model(test_graph, test_graph.ndata['feat'].to(device))
210 | auc = compute_auc_mc(e.to(device), test_graph.edata['label'].to(device))
211 |
212 | _, epreds = torch.max(F.softmax(e, dim=1), dim=1)
213 | _, npreds = torch.max(F.softmax(n, dim=1), dim=1)
214 | test_graph.edata['preds'] = epreds
215 | test_graph.ndata['preds'] = npreds
216 | test_graph.ndata['net'] = n
217 |
218 | accuracy, f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label'])
219 | _, classes_f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label'], per_class=True)
220 | macro, micro = get_f1(n, test_graph.ndata['label'].to(device))
221 |
222 | # ################* STEP 4: RESULTS ################
223 | print("\n### BEST RESULTS ###")
224 | print("AUC {:.4f}".format(auc))
225 | print("Accuracy {:.4f}".format(accuracy))
226 | print("F1 Edges: Macro {:.4f} - Micro {:.4f}".format(f1[0], f1[1]))
227 | print("F1 Edges: None {:.4f} - Pairs {:.4f}".format(classes_f1[0], classes_f1[1]))
228 | print("F1 Nodes: Macro {:.4f} - Micro {:.4f}".format(macro, micro))
229 |
230 | print("\n### AVG RESULTS ###")
231 | print("Semantic Entity Labeling: MEAN ", mean(nodes_micro), " STD: ", np.std(nodes_micro))
232 | print("Entity Linking: MEAN ", mean(edges_f1),"STD", np.std(edges_f1))
233 |
234 | if not args.test:
235 | feat_n, feat_e = get_features(args)
236 | #? if skipping training, no need to save anything
237 | model = get_config(CFGM / args.model)
238 | results = {'MODEL': {
239 | 'name': sm.get_name(),
240 | 'weights': best_model,
241 | 'net-params': sm.get_total_params(),
242 | 'num-layers': model.num_layers,
243 | 'projector-output': model.out_chunks,
244 | 'dropout': model.dropout,
245 | 'lastFC': model.hidden_dim
246 | },
247 | 'FEATURES': {
248 | 'nodes': feat_n,
249 | 'edges': feat_e
250 | },
251 | 'PARAMS': {
252 | 'start-lr': cfg_train.lr,
253 | 'weight-decay': cfg_train.weight_decay,
254 | 'seed': cfg_train.seed
255 | },
256 | 'RESULTS': {
257 | 'val-loss': stopper.best_score,
258 | 'f1-scores': f1,
259 | 'f1-classes': classes_f1,
260 | 'nodes-f1': [macro, micro],
261 | 'std-pairs': np.std(edges_f1),
262 | 'mean-pairs': mean(edges_f1)
263 | }}
264 | save_test_results(train_name, results)
265 |
266 | print("END TRAINING:", time.time() - start_training)
267 | return {'LINKS [MAX, MEAN, STD]': [classes_f1[1], mean(edges_f1), np.std(edges_f1)], 'NODES [MAX, MEAN, STD]': [micro, mean(nodes_micro), np.std(nodes_micro)]}
268 |
269 | def entity_linking(args):
270 |
271 | # configs
272 | start_training = time.time()
273 | cfg_train = get_config('train')
274 | seed(cfg_train.seed)
275 | device = get_device(args.gpu)
276 | sm = SetModel(name=args.model, device=device)
277 |
278 | if not args.test:
279 | ################* STEP 0: LOAD DATA ################
280 | data = Document2Graph(name='FUNSD TRAIN', src_path=FUNSD_TRAIN, device = device, output_dir=TRAIN_SAMPLES)
281 | data.get_info()
282 |
283 | ss = KFold(n_splits=10, shuffle=True, random_state=cfg_train.seed)
284 | cv_indices = ss.split(data.graphs)
285 |
286 | models = []
287 | train_index, val_index = next(ss.split(data.graphs))
288 |
289 | for cvs in cv_indices:
290 |
291 | train_index, val_index = cvs
292 |
293 | # TRAIN
294 | train_graphs = [data.graphs[i] for i in train_index]
295 | tg = dgl.batch(train_graphs)
296 | tg = tg.int().to(device)
297 |
298 | val_graphs = [data.graphs[i] for i in val_index]
299 | vg = dgl.batch(val_graphs)
300 | vg = vg.int().to(device)
301 |
302 | ################* STEP 1: CREATE MODEL ################
303 | model = sm.get_model(None, 2, data.get_chunks())
304 | optimizer = torch.optim.Adam(model.parameters(), lr=float(cfg_train.lr), weight_decay=float(cfg_train.weight_decay))
305 | # scheduler = ReduceLROnPlateau(optimizer, 'max', patience=100, min_lr=1e-3, verbose=True, factor=0.01)
306 | # scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
307 | e = datetime.now()
308 | train_name = args.model + f'-{e.strftime("%Y%m%d-%H%M")}'
309 | models.append(train_name+'.pt')
310 | stopper = EarlyStopping(model, name=train_name, metric=cfg_train.stopper_metric, patience=1000)
311 | # writer = SummaryWriter(log_dir=RUNS)
312 | # convert_imgs = transforms.ToTensor()
313 |
314 | ################* STEP 2: TRAINING ################
315 | print("\n### TRAINING ###")
316 | print(f"-> Training samples: {tg.batch_size}")
317 | print(f"-> Validation samples: {vg.batch_size}\n")
318 |
319 | # im_step = 0
320 | for epoch in range(cfg_train.epochs):
321 |
322 | #* TRAINING
323 | model.train()
324 |
325 | scores = model(tg, tg.ndata['feat'].to(device))
326 | loss = compute_crossentropy_loss(scores.to(device), tg.edata['label'].to(device))
327 | auc = compute_auc_mc(scores.to(device), tg.edata['label'].to(device))
328 |
329 | optimizer.zero_grad()
330 | loss.backward()
331 | n = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
332 | optimizer.step()
333 |
334 | #* VALIDATION
335 | model.eval()
336 | with torch.no_grad():
337 | val_scores = model(vg, vg.ndata['feat'].to(device))
338 | val_loss = compute_crossentropy_loss(val_scores.to(device), vg.edata['label'].to(device))
339 | val_auc = compute_auc_mc(val_scores.to(device), vg.edata['label'].to(device))
340 |
341 | # scheduler.step(val_auc)
342 | # scheduler.step()
343 |
344 | #* PRINTING IMAGEs AND RESULTS
345 |
346 | print("Epoch {:05d} | TrainLoss {:.4f} | TrainAUC-PR {:.4f} | ValLoss {:.4f} | ValAUC-PR {:.4f} |"
347 | .format(epoch, loss.item(), auc, val_loss.item(), val_auc))
348 |
349 | if cfg_train.stopper_metric == 'loss':
350 | step_value = val_loss.item()
351 | elif cfg_train.stopper_metric == 'acc':
352 | step_value = val_auc
353 |
354 | ss = stopper.step(step_value)
355 |
356 | # if ss == 'improved':
357 | # im_step = epoch
358 | # train_imgs = []
359 | # for r in rand_tid:
360 | # start, end = 0, 0
361 | # for tid in train_index:
362 | # start = end
363 | # end += data.graphs[tid].num_edges()
364 | # if tid == r: break
365 |
366 | # _, targets = torch.max(F.log_softmax(scores[start:end], dim=1), dim=1)
367 | # kvp_ids = targets.nonzero().flatten().tolist()
368 | # train_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=kvp_ids, name=f'train_{r}'))[:, :, :700])
369 | # # data.print_graph(num=r, name=f'train_labels_{r}')
370 |
371 | # val_imgs = []
372 | # for r in rand_vid:
373 | # v_start, v_end = 0, 0
374 | # for vid in val_index:
375 | # v_start = v_end
376 | # v_end += data.graphs[vid].num_edges()
377 | # if vid == r: break
378 |
379 | # _, val_targets = torch.max(F.log_softmax(val_scores[v_start:v_end], dim=1), dim=1)
380 | # val_kvp_ids = val_targets.nonzero().flatten().tolist()
381 | # val_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=val_kvp_ids, name=f'val_{r}'))[:, :, :700])
382 | # data.print_graph(num=r, name=f'val_labels_{r}')
383 |
384 | if ss == 'stop':
385 | break
386 |
387 | # writer.add_scalars('AUC-PR', {'train': auc, 'val': val_auc}, epoch)
388 | # writer.add_scalars('LOSS', {'train': loss.item(), 'val': val_loss.item()}, epoch)
389 | # writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)
390 |
391 | # train_grid = torchvision.utils.make_grid(train_imgs)
392 | # writer.add_image('train_images', train_grid, im_step)
393 | # val_grid = torchvision.utils.make_grid(val_imgs)
394 | # writer.add_image('val_images', val_grid, im_step)
395 |
396 | # print("LOADING: ", train_name+'.pt')
397 | # model.load_state_dict(torch.load(WEIGHTS / (train_name+'.pt')))
398 |
399 | else:
400 | ################* SKIP TRAINING ################
401 | print("\n### SKIP TRAINING ###")
402 | print(f"-> loading {args.weights}")
403 | models = args.weights
404 |
405 |
406 | ################* STEP 3: TESTING ################
407 | print("\n### TESTING ###")
408 |
409 | #? test
410 | test_data = Document2Graph(name='FUNSD TEST', src_path=FUNSD_TEST, device = device, output_dir=TEST_SAMPLES)
411 | test_data.get_info()
412 | model = sm.get_model(None, 2, test_data.get_chunks())
413 | best_model = ''
414 | pair_scores = []
415 | test_graph = dgl.batch(test_data.graphs).to(device)
416 |
417 |
418 | for m in models:
419 | model.load_state_dict(torch.load(CHECKPOINTS / m))
420 | model.eval()
421 | with torch.no_grad():
422 |
423 | scores = model(test_graph, test_graph.ndata['feat'].to(device))
424 | auc = compute_auc_mc(scores.to(device), test_graph.edata['label'].to(device))
425 |
426 | _, preds = torch.max(F.softmax(scores, dim=1), dim=1)
427 |
428 | accuracy, f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'])
429 | _, classes_f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'], per_class=True)
430 |
431 | pair_scores.append(classes_f1[1])
432 | if classes_f1[1] >= max(pair_scores):
433 | best_model = m
434 |
435 | ################* STEP 4: RESULTS ################
436 | print(f"\n### RESULTS {m} ###")
437 | print("F1 Score: Macro {:.4f} - Micro {:.4f}".format(f1[0], f1[1]))
438 | print("F1 Classes: None {:.4f} - Pairs {:.4f}".format(classes_f1[0], classes_f1[1]))
439 |
440 | print(f"\nLoading best model {best_model}")
441 | model.load_state_dict(torch.load(CHECKPOINTS / best_model))
442 | model.eval()
443 | with torch.no_grad():
444 |
445 | scores = model(test_graph, test_graph.ndata['feat'].to(device))
446 | auc = compute_auc_mc(scores.to(device), test_graph.edata['label'].to(device))
447 |
448 | _, preds = torch.max(F.softmax(scores, dim=1), dim=1)
449 | test_graph.edata['preds'] = preds
450 |
451 | accuracy, f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'])
452 | _, classes_f1 = get_binary_accuracy_and_f1(preds, test_graph.edata['label'], per_class=True)
453 |
454 | if not args.test:
455 | feat_n, feat_e = get_features(args)
456 | #? if skipping training, no need to save anything
457 | model = get_config(CFGM / args.model)
458 | results = {'MODEL': {
459 | 'name': sm.get_name(),
460 | 'weights': best_model,
461 | 'net-params': sm.get_total_params(),
462 | 'num-layers': model.num_layers,
463 | 'projector-output': model.out_chunks,
464 | 'dropout': model.dropout,
465 | 'lastFC': model.hidden_dim
466 | },
467 | 'FEATURES': {
468 | 'nodes': feat_n,
469 | 'edges': feat_e
470 | },
471 | 'PARAMS': {
472 | 'start-lr': cfg_train.lr,
473 | 'weight-decay': cfg_train.weight_decay,
474 | 'seed': cfg_train.seed
475 | },
476 | 'RESULTS': {
477 | 'val-loss': stopper.best_score,
478 | 'f1-scores': f1,
479 | 'f1-classes': classes_f1,
480 | 'AUC-PR': auc,
481 | 'ACCURACY': accuracy,
482 | 'std-pairs': np.std(pair_scores),
483 | 'mean-pairs': mean(pair_scores)
484 | }}
485 | save_test_results(train_name, results)
486 | print("END TRAINING:", time.time() - start_training)
487 |
488 | return {'best_model': best_model, 'Pairs-F1': {'max': max(pair_scores), 'mean': mean(pair_scores), 'std': np.std(pair_scores)}}
489 |
490 | def train_funsd(args):
491 |
492 | if args.model == 'e2e':
493 | e2e(args)
494 | elif args.model == 'edge':
495 | entity_linking(args)
496 | else:
497 |         raise Exception("The selected model does not exist. Choose 'e2e' or 'edge'.")
498 | return
499 |
500 |
--------------------------------------------------------------------------------
/src/training/pau.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from sklearn.model_selection import KFold, ShuffleSplit
3 | import torch
4 | from torch.nn import functional as F
5 | from random import seed
6 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
7 | import dgl
8 | from torch.utils.tensorboard import SummaryWriter
9 | from torchvision import transforms
10 | import time
11 | from statistics import mean
12 | import numpy as np
13 | import xml.etree.ElementTree as ET
14 | from PIL import Image
15 |
16 | from src.data.dataloader import Document2Graph
17 | from src.data.preprocessing import match_pred_w_gt
18 | from src.paths import *
19 | from src.models.graphs import SetModel
20 | from src.utils import get_config
21 | from src.training.utils import *
22 |
23 | def e2e(args):
24 |
25 | # configs
26 | start_training = time.time()
27 | cfg_train = get_config('train')
28 | seed(cfg_train.seed)
29 | device = get_device(args.gpu)
30 | sm = SetModel(name=args.model, device=device)
31 |
32 | if not args.test:
33 | ################* STEP 0: LOAD DATA ################
34 | data = Document2Graph(name='PAU TRAIN', src_path=PAU_TRAIN, device = device, output_dir=TRAIN_SAMPLES)
35 | data.get_info()
36 |
37 | ss = KFold(n_splits=7, shuffle=True, random_state=cfg_train.seed)
38 | cv_indices = ss.split(data.graphs)
39 |
40 | models = []
41 | train_index, val_index = next(ss.split(data.graphs))
42 |
43 | for cvs in cv_indices:
44 |
45 | train_index, val_index = cvs
46 |
47 | # TRAIN
48 | train_graphs = [data.graphs[i] for i in train_index]
49 | tg = dgl.batch(train_graphs)
50 | tg = tg.int().to(device)
51 |
52 | val_graphs = [data.graphs[i] for i in val_index]
53 | vg = dgl.batch(val_graphs)
54 | vg = vg.int().to(device)
55 |
56 | ################* STEP 1: CREATE MODEL ################
57 | model = sm.get_model(data.node_num_classes, data.edge_num_classes, data.get_chunks())
58 | optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg_train.lr), weight_decay=float(cfg_train.weight_decay))
59 | # scheduler = ReduceLROnPlateau(optimizer, 'max', patience=400, min_lr=1e-3, verbose=True, factor=0.01)
60 | # scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
61 | e = datetime.now()
62 | train_name = args.model + f'-{e.strftime("%Y%m%d-%H%M")}'
63 | models.append(train_name+'.pt')
64 | stopper = EarlyStopping(model, name=train_name, metric=cfg_train.stopper_metric, patience=2000)
65 | # writer = SummaryWriter(log_dir=RUNS)
66 | # convert_imgs = transforms.ToTensor()
67 |
68 | ################* STEP 2: TRAINING ################
69 | print("\n### TRAINING ###")
70 | print(f"-> Training samples: {tg.batch_size}")
71 | print(f"-> Validation samples: {vg.batch_size}\n")
72 |
73 | # im_step = 0
74 | for epoch in range(cfg_train.epochs):
75 |
76 | #* TRAINING
77 | model.train()
78 |
79 | n_scores, e_scores = model(tg, tg.ndata['feat'].to(device))
80 | n_loss = compute_crossentropy_loss(n_scores.to(device), tg.ndata['label'].to(device))
81 | e_loss = compute_crossentropy_loss(e_scores.to(device), tg.edata['label'].to(device))
82 | tot_loss = n_loss + e_loss
83 |
84 | macro, micro = get_f1(n_scores, tg.ndata['label'].to(device))
85 | auc = compute_auc_mc(e_scores.to(device), tg.edata['label'].to(device))
86 |
87 | optimizer.zero_grad()
88 | tot_loss.backward()
89 | n = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
90 | optimizer.step()
91 |
92 | #* VALIDATION
93 | model.eval()
94 | with torch.no_grad():
95 | val_n_scores, val_e_scores = model(vg, vg.ndata['feat'].to(device))
96 | val_n_loss = compute_crossentropy_loss(val_n_scores.to(device), vg.ndata['label'].to(device))
97 | val_e_loss = compute_crossentropy_loss(val_e_scores.to(device), vg.edata['label'].to(device))
98 | val_tot_loss = val_n_loss + val_e_loss
99 | val_macro, val_micro = get_f1(val_n_scores, vg.ndata['label'].to(device))
100 | val_auc = compute_auc_mc(val_e_scores.to(device), vg.edata['label'].to(device))
101 |
102 | # scheduler.step(val_auc)
103 | # scheduler.step()
104 |
105 | #* PRINTING IMAGEs AND RESULTS
106 |
107 | print("Epoch {:05d} | TrainLoss {:.4f} | TrainF1-MACRO {:.4f} | TrainAUC-PR {:.4f} | ValLoss {:.4f} | ValF1-MACRO {:.4f} | ValAUC-PR {:.4f} |"
108 | .format(epoch, tot_loss.item(), macro, auc, val_tot_loss.item(), val_macro, val_auc))
109 |
110 | if cfg_train.stopper_metric == 'loss':
111 | step_value = val_tot_loss.item()
112 | elif cfg_train.stopper_metric == 'acc':
113 | step_value = val_micro
114 |
115 | #if val_auc > best_val_auc:
116 | # best_val_auc = val_auc
117 | # best_model = train_name
118 |
119 | ss = stopper.step(step_value)
120 |
121 | # if ss == 'improved':
122 | # im_step = epoch
123 | # train_imgs = []
124 | # for r in rand_tid:
125 | # start, end = 0, 0
126 | # for tid in train_index:
127 | # start = end
128 | # end += data.graphs[tid].num_edges()
129 | # if tid == r: break
130 |
131 | # _, targets = torch.max(F.log_softmax(e_scores[start:end], dim=1), dim=1)
132 | # kvp_ids = targets.nonzero().flatten().tolist()
133 | # train_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=kvp_ids, name=f'train_{r}'))[:, :, :700])
134 | # # data.print_graph(num=r, name=f'train_labels_{r}')
135 |
136 | # val_imgs = []
137 | # for r in rand_vid:
138 | # v_start, v_end = 0, 0
139 | # for vid in val_index:
140 | # v_start = v_end
141 | # v_end += data.graphs[vid].num_edges()
142 | # if vid == r: break
143 |
144 | # _, val_targets = torch.max(F.log_softmax(val_e_scores[v_start:v_end], dim=1), dim=1)
145 | # val_kvp_ids = val_targets.nonzero().flatten().tolist()
146 | # val_imgs.append(convert_imgs(data.print_graph(num=r, labels_ids=val_kvp_ids, name=f'val_{r}'))[:, :, :700])
147 | # data.print_graph(num=r, name=f'val_labels_{r}')
148 |
149 | if ss == 'stop':
150 | break
151 |
152 | # writer.add_scalars('AUC-PR', {'train': auc, 'val': val_auc}, epoch)
153 | # writer.add_scalars('LOSS', {'train': tot_loss.item(), 'val': val_tot_loss.item()}, epoch)
154 | # writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)
155 |
156 | # train_grid = torchvision.utils.make_grid(train_imgs)
157 | # writer.add_image('train_images', train_grid, im_step)
158 | # val_grid = torchvision.utils.make_grid(val_imgs)
159 | # writer.add_image('val_images', val_grid, im_step)
160 |
161 | # print("LOADING: ", best_model+'.pt')
162 | # model.load_state_dict(torch.load(WEIGHTS / (best_model+'.pt')))
163 |
164 | else:
165 | ################* SKIP TRAINING ################
166 | print("\n### SKIP TRAINING ###")
167 | print(f"-> loading {args.weights}")
168 | models = args.weights
169 |
170 | ################* STEP 3: TESTING ################
171 | print("\n### TESTING ###")
172 |
173 | #? test
174 | test_data = Document2Graph(name='PAU TEST', src_path=PAU_TEST, device = device, output_dir=TEST_SAMPLES)
175 | test_data.get_info()
176 | model = sm.get_model(test_data.node_num_classes, test_data.edge_num_classes, test_data.get_chunks())
177 | best_model = ''
178 | nodes_micro = []
179 | edges_f1 = []
180 | test_graph = dgl.batch(test_data.graphs).to(device)
181 |
182 | all_precisions = []
183 | all_recalls = []
184 | all_f1 = []
185 |
186 | for i, m in enumerate(models):
187 | model.load_state_dict(torch.load(CHECKPOINTS / m))
188 | model.eval()
189 | with torch.no_grad():
190 |
191 | n, e = model(test_graph, test_graph.ndata['feat'].to(device))
192 | auc = compute_auc_mc(e.to(device), test_graph.edata['label'].to(device))
193 | _, epreds = torch.max(F.softmax(e, dim=1), dim=1)
194 | _, npreds = torch.max(F.softmax(n, dim=1), dim=1)
195 |
196 | accuracy, f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label'])
197 | _, classes_f1 = get_binary_accuracy_and_f1(epreds, test_graph.edata['label'], per_class=True)
198 | edges_f1.append(classes_f1[1])
199 |
200 | macro, micro = get_f1(n, test_graph.ndata['label'].to(device))
201 | nodes_micro.append(micro)
202 | if micro >= max(nodes_micro):
203 | best_model = i
204 |
205 | test_graph.edata['preds'] = epreds
206 | test_graph.ndata['preds'] = npreds
207 | t_f1 = 0
208 | t_precision = 0
209 | t_recall = 0
210 | no_table = 0
211 | tables = 0
212 |
213 | for g, graph in enumerate(dgl.unbatch(test_graph)):
214 | etargets = graph.edata['preds']
215 | ntargets = graph.ndata['preds']
216 | kvp_ids = etargets.nonzero().flatten().tolist()
217 |
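                # Predicted table box: take the nodes touched by edges classified as
                # table links and use the min/max of their geometry as a tight bounding
                # box (rescaled to pixel coordinates with (w, h) further below).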
218 | table_g = dgl.edge_subgraph(graph, torch.tensor(kvp_ids, dtype=torch.int32).to(device))
219 | table_nodes = table_g.ndata['geom']
220 | try:
221 | table_topleft, _ = torch.min(table_nodes, 0)
222 | table_bottomright, _ = torch.max(table_nodes, 0)
223 | table = torch.cat([table_topleft[:2], table_bottomright[2:]], 0)
224 |                 except Exception:  # no edges predicted as table links: empty subgraph, no table box
225 | table = None
226 |
227 | img_path = test_data.paths[g]
228 | w, h = Image.open(img_path).size
229 | gt_path = img_path.split(".")[0]
230 |
231 | root = ET.parse(gt_path + '_gt.xml').getroot()
232 | regions = []
233 | for parent in root:
234 | if parent.tag.split("}")[1] == 'Page':
235 | for child in parent:
236 | # print(file_gt)
237 | region_label = child[0].attrib['value']
238 | if region_label != 'positions': continue
239 | region_bbox = [int(child[1].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]),
240 | int(child[1].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]),
241 | int(child[1].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]),
242 | int(child[1].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])]
243 | regions.append([region_label, region_bbox])
244 |
245 | table_regions = [region[1] for region in regions if region[0]=='positions']
246 | if table is None and len(table_regions) !=0:
247 | t_f1 += 0
248 | t_precision += 0
249 | t_recall += 0
250 | tables += len(table_regions)
251 | elif table is None and len(table_regions) == 0:
252 | no_table -= 1
253 | continue
254 | elif table is not None and len(table_regions) ==0:
255 | t_f1 += 0
256 | t_precision += 0
257 | t_recall += 0
258 | no_table -= 1
259 | else:
260 | table = [[t[0]*w, t[1]*h, t[2]*w, t[3]*h] for t in [table.flatten().tolist()]][0]
261 | # d = match_pred_w_gt(torch.tensor(boxs_preds[idx]), torch.tensor(gt))
262 | d = match_pred_w_gt(torch.tensor(table).view(-1, 4), torch.tensor(table_regions).view(-1, 4), [])
263 | bbox_true_positive = len(d["pred2gt"])
264 | p = bbox_true_positive / (bbox_true_positive + len(d["false_positive"]))
265 | r = bbox_true_positive / (bbox_true_positive + len(d["false_negative"]))
266 | try:
267 | t_f1 += (2 * p * r) / (p + r)
268 |                     except ZeroDivisionError:  # p + r == 0
269 | t_f1 += 0
270 | t_precision += p
271 | t_recall += r
272 | tables += len(table_regions)
273 |
274 | test_data.print_graph(num=g, node_labels = None, labels_ids=None, name=f'test_{g}', bidirect=False, regions=regions, preds=table)
275 |
276 | # test_data.print_graph(num=g, name=f'test_labels_{g}')
277 | t_recall = t_recall / (tables + no_table)
278 | t_precision = t_precision / (tables + no_table)
279 | t_f1 = (2 * t_precision * t_recall) / (t_precision + t_recall)
280 | all_precisions.append(t_precision)
281 | all_recalls.append(t_recall)
282 | all_f1.append(t_f1)
283 |
284 | ################* STEP 4: RESULTS ################
285 | print("\n### RESULTS {} ###".format(m))
286 | print("AUC {:.4f}".format(auc))
287 | print("Accuracy {:.4f}".format(accuracy))
288 | print("F1 Edges: Macro {:.4f} - Micro {:.4f}".format(f1[0], f1[1]))
289 | print("F1 Edges: None {:.4f} - Table {:.4f}".format(classes_f1[0], classes_f1[1]))
290 | print("F1 Nodes: Macro {:.4f} - Micro {:.4f}".format(macro, micro))
291 |
292 | print("\nTABLE DETECTION")
293 | print("PRECISION [MAX, MEAN, STD]:", max(all_precisions), mean(all_precisions), np.std(all_precisions))
294 | print("RECALLS [MAX, MEAN, STD]:", max(all_recalls), mean(all_recalls), np.std(all_recalls))
295 | print("F1s [MAX, MEAN, STD]:", max(all_f1), mean(all_f1), np.std(all_f1))
296 |
297 | ################* BEST MODEL STATISTICS ################
298 | print(f"\nLoading best model {models[best_model]}")
299 | print("F1 Edges: Table {:.4f}".format(edges_f1[best_model]))
300 | print("F1 Nodes: Micro {:.4f}".format(nodes_micro[best_model]))
301 |
302 | if not args.test:
303 | feat_n, feat_e = get_features(args)
304 | #? if skipping training, no need to save anything
305 | model = get_config(CFGM / args.model)
306 | results = {'MODEL': {
307 | 'name': sm.get_name(),
308 |                     'weights': models[best_model],
309 | 'net-params': sm.get_total_params(),
310 | 'projector-output': model.out_chunks,
311 | 'dropout': model.dropout,
312 | 'lastFC': model.hidden_dim
313 | },
314 | 'FEATURES': {
315 | 'nodes': feat_n,
316 | 'edges': feat_e
317 | },
318 | 'PARAMS': {
319 | 'start-lr': cfg_train.lr,
320 | 'weight-decay': cfg_train.weight_decay,
321 | 'seed': cfg_train.seed
322 | },
323 | 'RESULTS': {
324 | 'val-loss': stopper.best_score,
325 | 'f1-scores': f1,
326 | 'f1-classes': classes_f1,
327 | 'nodes-f1': [macro, micro],
328 | 'std-pairs': np.std(nodes_micro),
329 | 'mean-pairs': mean(nodes_micro),
330 | 'table-detection-precision': [max(all_precisions), mean(all_precisions), np.std(all_precisions)],
331 | 'table-detection-recall': [max(all_recalls), mean(all_recalls), np.std(all_recalls)],
332 | 'table-detection-f1': [max(all_f1), mean(all_f1), np.std(all_f1)]
333 | }}
334 | save_test_results(train_name, results)
335 | print("END TRAINING:", time.time() - start_training)
336 |
337 | return {'best_model': best_model, 'Nodes-F1': {'max': max(nodes_micro), 'mean': mean(nodes_micro), 'std': np.std(nodes_micro)},
338 | 'Edges-F1': {'max': max(edges_f1), 'mean': mean(edges_f1), 'std': np.std(edges_f1)},
339 | 'Table Detection': {'precision': [max(all_precisions), mean(all_precisions), np.std(all_precisions)], 'recall': [max(all_recalls), mean(all_recalls), np.std(all_recalls)], 'f1': [max(all_f1), mean(all_f1), np.std(all_f1)]}}
340 |
341 | def train_pau(args):
342 |
343 | if args.model == 'e2e':
344 | e2e(args)
345 | else:
346 |         raise Exception("Model selected does not exist. Choose 'e2e'.")
347 | return
348 |
--------------------------------------------------------------------------------
/src/training/utils.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import json
3 | import os
4 | from typing import Tuple
5 | 
6 | # third-party dependencies
7 | import pandas as pd
8 | import sklearn
9 | from sklearn.metrics import average_precision_score, confusion_matrix, f1_score, precision_recall_fscore_support, roc_auc_score, roc_curve
10 | from sklearn.utils import class_weight
11 | import torch
12 | import torch.nn
13 | import torch.nn.functional as F
14 | import dgl
15 | from datetime import datetime
16 | import shutil
17 | import yaml
18 | import numpy as np
19 |
20 | from src.paths import CHECKPOINTS, CONFIGS, OUTPUTS, RESULTS
21 |
22 |
23 | class EarlyStopping:
24 |     """Early stopping handler: tracks a validation metric (loss or accuracy) and checkpoints on improvement.
25 | """
26 | def __init__(self, model, name = '', metric = 'loss', patience=50):
27 | """Constructor.
28 |
29 | Args:
30 |             model (torch.nn.Module): model whose weights are checkpointed whenever the tracked metric improves
31 |             name (str): name used for the saved weights file.
32 |             metric (str): validation metric to track, either 'loss' or 'acc'
33 |             patience (int, optional): stop training if the metric does not improve for 'patience' consecutive steps. Defaults to 50.
34 | """
35 | self.patience = patience
36 | self.counter = 0
37 | self.best_score = None
38 | self.early_stop = 'improved'
39 | self.model = model
40 | self.metric = metric
41 | e = datetime.now()
42 | if name == '': self.name = f'{e.strftime("%Y%m%d-%H%M")}'
43 | else: self.name = name
44 |
45 | def step(self, score : float) -> str:
46 |         """ Perform one step of the stopper: track the metric and decide whether to keep training.
47 | 
48 |         Args:
49 |             score (float) : metric / value to keep track of.
50 | 
51 |         Returns:
52 |             str: one of 'improved', 'not-improved' or 'stop', consumed by the training loop.
53 | """
54 |
55 | if self.best_score is None:
56 | self.best_score = score
57 | self.save_checkpoint()
58 | return 'improved'
59 |
60 | if self.metric == 'loss':
61 | if score > self.best_score:
62 | self.counter += 1
63 | print(f' !- Stop Counter {self.counter} / {self.patience}')
64 | self.early_stop = 'not-improved'
65 | if self.counter >= self.patience:
66 | self.early_stop = 'stop'
67 | else:
68 | print(f' !- Validation LOSS decreased from {self.best_score} -> {score}')
69 | self.best_score = score
70 | self.save_checkpoint()
71 | self.counter = 0
72 | self.early_stop = 'improved'
73 |
74 | elif self.metric == 'acc':
75 | if score <= self.best_score:
76 | self.counter += 1
77 | print(f' !- Stop Counter {self.counter} / {self.patience}')
78 | self.early_stop = 'not-improved'
79 | if self.counter >= self.patience:
80 | self.early_stop = 'stop'
81 | else:
82 | print(f' !- Validation ACCURACY increased from {self.best_score} -> {score}')
83 | self.best_score = score
84 | self.save_checkpoint()
85 | self.counter = 0
86 | self.early_stop = 'improved'
87 |
88 | else:
89 |             raise Exception("EarlyStopping Error: metric provided is not valid. Choose 'loss' or 'acc'.")
90 |
91 | return self.early_stop
92 |
93 | def save_checkpoint(self) -> None:
94 |         '''Saves model weights when the tracked validation metric improves.'''
95 | torch.save(self.model.state_dict(), CHECKPOINTS / f'{self.name}.pt')
96 |
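# A minimal usage sketch of EarlyStopping (illustrative only; `epochs` and `val_loss`
# are placeholders): step() is called once per epoch with the tracked validation value
# and returns 'improved', 'not-improved' or 'stop'.
#
#   stopper = EarlyStopping(model, name='my-run', metric='loss', patience=50)
#   for epoch in range(epochs):
#       ...  # forward / backward / optimizer.step()
#       if stopper.step(val_loss.item()) == 'stop':
#           break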
97 | def save_best_results(best_params : dict, rm_logs : bool = False) -> None:
98 | """Save best results for cross validation.
99 |
100 | Args:
101 | best_params (dict): best parameters among k-fold cross validation.
102 | rm_logs (bool, optional): Remove tmp weights in output folder if True. Defaults to False.
103 | """
104 | models = OUTPUTS / 'tmp'
105 | output = CHECKPOINTS / best_params['model']
106 | shutil.copyfile(models / best_params['model'], output)
107 |
108 | new_configs = CONFIGS / (best_params['model'].split(".")[0] + '.yaml')
109 | shutil.copyfile(CONFIGS / 'base.yaml', new_configs)
110 |
111 | with open(new_configs) as f:
112 | config = yaml.safe_load(f)
113 |
114 | config['MODEL']['num_layers'] = best_params['num_layers']
115 | config['TRAIN']['batch_size'] = best_params['batch_size']
116 | config['INFO'] = {'split': best_params['split'], 'val_loss': best_params['val_loss'], 'total_params': best_params['total_params']}
117 |
118 | with open(new_configs, 'w') as f:
119 | yaml.dump(config, f)
120 |
121 | if rm_logs and os.path.isdir(models):
122 | shutil.rmtree(models)
123 |
124 | return
125 |
126 | def save_test_results(filename : str, infos : dict) -> None:
127 | """Save test results.
128 |
129 | Args:
130 | filename (str): name of the file to save results of experiments
131 | infos (dict): what to save in the json file about training
132 | """
133 | results = RESULTS / (filename + '.json')
134 |
135 | with open(results, 'w') as f:
136 | json.dump(infos, f)
137 | return
138 |
139 | def get_f1(logits : torch.Tensor, labels : torch.Tensor, per_class = False) -> tuple:
140 |     """Returns Macro and Micro F1-score for given logits / labels.
141 | 
142 |     Args:
143 |         logits (torch.Tensor): model prediction logits
144 |         labels (torch.Tensor): target labels
145 |         per_class (bool, optional): if True, return one F1 score per class. Defaults to False.
146 |     Returns:
147 |         tuple: macro-f1 and micro-f1, or a list of per-class F1 scores if per_class is True
148 |     """
149 | _, indices = torch.max(logits, dim=1)
150 | indices = indices.cpu().detach().numpy()
151 | labels = labels.cpu().detach().numpy()
152 | if not per_class:
153 | return f1_score(labels, indices, average='macro'), f1_score(labels, indices, average='micro')
154 | else:
155 | return precision_recall_fscore_support(labels, indices, average=None)[2].tolist()
156 |
157 | def get_binary_accuracy_and_f1(classes, labels : torch.Tensor, per_class = False) -> Tuple[float, list]:
158 |
159 | correct = torch.sum(classes.flatten() == labels)
160 | accuracy = correct.item() * 1.0 / len(labels)
161 | classes = classes.detach().cpu().numpy()
162 | labels = labels.cpu().numpy()
163 |
164 | if not per_class:
165 | f1 = f1_score(labels, classes, average='macro'), f1_score(labels, classes, average='micro')
166 | else:
167 | f1 = precision_recall_fscore_support(labels, classes, average=None)[2].tolist()
168 |
169 | return accuracy, f1
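# Illustrative use of get_binary_accuracy_and_f1 (tensors are placeholders): with binary
# edge labels, per_class=True returns one F1 per class id, which is how the training
# scripts read the F1 of the positive ("pair" / "table") class.
#
#   acc, (macro_f1, micro_f1) = get_binary_accuracy_and_f1(preds, labels)
#   _, classes_f1 = get_binary_accuracy_and_f1(preds, labels, per_class=True)
#   pair_f1 = classes_f1[1]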
170 |
171 | def accuracy(logits : torch.Tensor, labels : torch.Tensor) -> float:
172 | """Accuracy of the model.
173 |
174 | Args:
175 | logits (torch.Tensor): model prediction logits
176 | labels (torch.Tensor): target labels
177 |
178 | Returns:
179 | float: accuracy
180 | """
181 | _, indices = torch.max(logits, dim=1)
182 | correct = torch.sum(indices == labels)
183 | return correct.item() * 1.0 / len(labels)
184 |
185 | def get_device(value : int) -> str:
186 |     """Return the device to use: 'cpu' if value < 0, otherwise 'cuda:{value}'.
187 | """
188 | if value < 0:
189 | return 'cpu'
190 | else:
191 | return 'cuda:{}'.format(value)
192 |
193 | def get_features(args : ArgumentParser) -> Tuple[str, str]:
194 | """ Return description of the features used in the experiment
195 |
196 | Args:
197 | args (ArgumentParser) : your ArgumentParser
198 | """
199 | feat_n = ''
200 | feat_e = 'false'
201 |
202 | if args.add_geom:
203 | feat_n += 'geom-'
204 | if args.add_embs:
205 | feat_n += 'text-'
206 | if args.add_visual:
207 | feat_n += 'visual-'
208 | if args.add_hist:
209 | feat_n += 'histogram-'
210 | if args.add_eweights:
211 | feat_e = 'true'
212 |
213 | return feat_n, feat_e
214 |
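# Note (based on scikit-learn's documented behaviour of class_weight='balanced'): each
# class c receives weight n_samples / (n_classes * n_samples_c), so under-represented
# classes (e.g. key-value pairs) contribute proportionally more to the loss below.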
215 | def compute_crossentropy_loss(scores : torch.Tensor, labels : torch.Tensor):
216 | w = class_weight.compute_class_weight(class_weight='balanced', classes= np.unique(labels.cpu().numpy()), y=labels.cpu().numpy())
217 |     return torch.nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32).to(scores.device))(scores, labels)
218 |
219 | def compute_auc_mc(scores, labels):
220 | scores = scores.detach().cpu().numpy()
221 | labels = F.one_hot(labels).cpu().numpy()
222 | # return roc_auc_score(labels, scores)
223 | return average_precision_score(labels, scores)
224 |
225 | def find_optimal_cutoff(target, predicted):
226 | """ Find the optimal probability cutoff point for a classification model related to event rate
227 | Parameters
228 | ----------
229 | target : Matrix with dependent or target data, where rows are observations
230 |
231 | predicted : Matrix with predicted data, where rows are observations
232 |
233 | Returns
234 | -------
235 | list type, with optimal cutoff value
236 |
237 | """
238 | fpr, tpr, threshold = roc_curve(target, predicted)
239 | i = np.arange(len(tpr))
240 | roc = pd.DataFrame({'tf' : pd.Series(tpr-(1-fpr), index=i), 'threshold' : pd.Series(threshold, index=i)})
241 | roc_t = roc.iloc[(roc.tf-0).abs().argsort()[:1]]
242 |
243 | return list(roc_t['threshold'])
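# Illustrative use of find_optimal_cutoff (binary case; names are placeholders): the
# returned threshold is the point on the ROC curve where tpr is closest to 1 - fpr,
# i.e. where sensitivity and specificity are balanced.
#
#   probs = F.softmax(scores, dim=1)[:, 1].detach().cpu().numpy()
#   threshold = find_optimal_cutoff(labels.cpu().numpy(), probs)[0]
#   preds = (probs >= threshold).astype(int)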
244 |
245 | def generalized_f1_score(y_true, y_pred, match):
246 | # y_true = (y_nodes, y_link)
247 | # y_pred = (y_nodes, y_link)
248 |
249 | # # nodes
250 | # micro_f1_nodes, macro_f1_nodes = 0, 0
251 |
252 | nodes_confusion_mtx = confusion_matrix(y_true=y_true[0][list(match["pred2gt"].keys())], y_pred=y_pred[0][list(match["gt2pred"].values())],
253 | labels=[0, 1, 2, 3], normalize=None)
254 | print(nodes_confusion_mtx)
255 | ntp = [nodes_confusion_mtx[i, i] for i in range(nodes_confusion_mtx.shape[0])]
256 | nfp = [(nodes_confusion_mtx[:i, i].sum() + nodes_confusion_mtx[i+1:, i].sum()) for i in range(nodes_confusion_mtx.shape[0])]
257 | nfn = [(nodes_confusion_mtx[i, :i].sum() + nodes_confusion_mtx[i, i+1:].sum()) for i in range(nodes_confusion_mtx.shape[0])]
258 |
259 | macro_f1_nodes = np.mean([tp / (tp + (0.5 * (fp + len(match["false_positive"]) + fn + len(match["false_negative"])))) for (tp, fp, fn) in zip(ntp, nfp, nfn)])
260 | micro_f1_nodes = np.sum(ntp) / (np.sum(ntp) + (0.5 * (np.sum(nfp) + len(match["false_positive"]) + np.sum(nfn) + len(match["false_negative"]))))
261 |
262 | # links
263 | # micro_f1_links, macro_f1_links = 0, 0
264 |
265 | # links2keep = [idx for idx, link in enumerate(y_true[1]) if (link[0] in match["pred2gt"].values()) and (link[1] in match["pred2gt"].values())]
266 | links_confusion_mtx = confusion_matrix(y_true=y_true[1], y_pred=y_pred[1], labels=[0, 1], normalize=None)
267 |
268 | ltp = [links_confusion_mtx[i, i] for i in range(links_confusion_mtx.shape[0])]
269 | lfp = [(links_confusion_mtx[:i, i].sum() + links_confusion_mtx[i+1:, i].sum()) for i in range(links_confusion_mtx.shape[0])]
270 | lfn = [(links_confusion_mtx[i, :i].sum() + links_confusion_mtx[i, i+1:].sum()) for i in range(links_confusion_mtx.shape[0])]
271 |
272 | # n = len(links2keep) + len(match["false_positive"]) - 1
273 | macro_f1_links = np.mean([tp / (tp + (0.5 * (fp + fn + match["n_link_fn"]))) for (tp, fp, fn) in zip(ltp, lfp, lfn)])
274 | # if match['n_link_fn'] == 0: f1_pairs = None
275 | f1_pairs = [tp / (tp + (0.5 * (fp + fn + match["n_link_fn"])) + 1e-6) for (tp, fp, fn) in zip(ltp, lfp, lfn)][1]
276 | micro_f1_links = np.sum(ltp) / (np.sum(ltp) + (0.5 * (np.sum(lfp) + np.sum(lfn) + match["n_link_fn"])) + 1e-6)
277 |
278 | return ntp, nfp, nfn, ltp, lfp, lfn, {"nodes": {"micro_f1": micro_f1_nodes, "macro_f1": macro_f1_nodes}, "links": {"micro_f1": micro_f1_links, "macro_f1": macro_f1_links}, "pairs_f1": f1_pairs}
279 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | from attrdict import AttrDict
4 | import yaml
5 |
6 | from src.paths import *
7 |
8 | def create_folder(folder_path : str) -> None:
9 |     """Create a folder if it does not exist.
10 | 
11 |     Args:
12 |         folder_path (str): path of the folder to create
13 | """
14 | if not os.path.exists(folder_path):
15 | os.mkdir(folder_path)
16 |
17 | return
18 |
19 | def get_config(name : str) -> AttrDict:
20 |     """Load a yaml config file from the CONFIGS folder and return it as an AttrDict.
21 |
22 | Args:
23 | name (str): yaml file name without extension
24 |
25 | Returns:
26 | AttrDict: config
27 | """
28 | with open(CONFIGS / f'{name}.yaml') as fileobj:
29 | config = AttrDict(yaml.safe_load(fileobj))
30 | return config
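# Illustrative usage (the keys shown are the ones read by the training scripts,
# assuming they are defined in configs/train.yaml):
#
#   cfg_train = get_config('train')
#   lr = float(cfg_train.lr)
#   epochs = cfg_train.epochs
#   metric = cfg_train.stopper_metric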
31 |
32 | def project_tree() -> None:
33 | """ Create the project tree folder
34 | """
35 | create_folder(DATA)
36 | create_folder(OUTPUTS)
37 | create_folder(RUNS)
38 | create_folder(RESULTS)
39 | create_folder(TRAIN_SAMPLES)
40 | create_folder(TEST_SAMPLES)
41 | create_folder(CHECKPOINTS)
42 | return
43 |
44 | def set_preprocessing(args: ArgumentParser) -> None:
45 |     """ Set preprocessing args and write them to configs/preprocessing.yaml
46 | 
47 |     Args:
48 |         args (ArgumentParser): parsed command-line arguments with feature, loader and graph options
49 | """
50 | with open(CONFIGS / 'base.yaml') as fileobj:
51 | cfg_preprocessing = dict(yaml.safe_load(fileobj))
52 | cfg_preprocessing['FEATURES']['add_geom'] = args.add_geom
53 | cfg_preprocessing['FEATURES']['add_embs'] = args.add_embs
54 | cfg_preprocessing['FEATURES']['add_hist'] = args.add_hist
55 | cfg_preprocessing['FEATURES']['add_visual'] = args.add_visual
56 | cfg_preprocessing['FEATURES']['add_eweights'] = args.add_eweights
57 | cfg_preprocessing['FEATURES']['num_polar_bins'] = args.num_polar_bins
58 | cfg_preprocessing['LOADER']['src_data'] = args.src_data
59 | cfg_preprocessing['GRAPHS']['data_type'] = args.data_type
60 | cfg_preprocessing['GRAPHS']['edge_type'] = args.edge_type
61 | cfg_preprocessing['GRAPHS']['node_granularity'] = args.node_granularity
62 |
63 | with open(CONFIGS / 'preprocessing.yaml', 'w') as f:
64 | yaml.dump(cfg_preprocessing, f)
65 | return
66 |
--------------------------------------------------------------------------------