├── .gitignore
├── LICENSE
├── README.md
├── examples
│   └── eval.py
├── scitsr
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── rel_gen.py
│   │   └── utils.py
│   ├── eval.py
│   ├── graph.py
│   ├── model.py
│   ├── relation.py
│   ├── table.py
│   └── train.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019-present Zewen Chi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SciTSR
2 | 
3 | ## Introduction
4 | 
5 | SciTSR is a large-scale table structure recognition dataset, which contains 15,000 tables in PDF format together with their structure labels obtained from the LaTeX source files.
6 | 
7 | **Download link** is [here](https://drive.google.com/file/d/1qXaJblBg9sbPN0xknWsYls1aGGtlp4ZN/view?usp=sharing).
8 | 
9 | There are 15,000 examples in total, split into 12,000 for training and 3,000 for testing. We also provide a test subset that contains only complicated tables, called SciTSR-COMP. The indices of SciTSR-COMP are stored in `SciTSR-COMP.list`.
10 | 
11 | The statistics of the SciTSR dataset are as follows:
12 | 
13 | |                             | Train  | Test  |
14 | | --------------------------- | -----: | ----: |
15 | | \# Tables                   | 12,000 | 3,000 |
16 | | \# Complicated tables       | 2,885  | 716   |
17 | 
18 | ## Format and Example
19 | 
20 | The directory tree structure is as follows:
21 | 
22 | ```
23 | SciTSR
24 | ├── SciTSR-COMP.list
25 | ├── test
26 | │   ├── chunk
27 | │   ├── img
28 | │   ├── pdf
29 | │   └── structure
30 | └── train
31 |     ├── chunk
32 |     ├── img
33 |     ├── pdf
34 |     ├── rel
35 |     └── structure
36 | ```
37 | 
38 | The input PDF files are stored in `pdf`, and the structure labels are stored in the `structure` directory.
39 | 
40 | For convenience, we also provide the input in image format, stored in `img`; the images are converted from the PDFs with `pdfcairo`.
41 | 
42 | We also provide the extracted text chunks, stored in `chunk`, which are pre-processed with [Tabby](https://github.com/cellsrg/tabbypdf/).
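The files that belong to one table share the same ID across the sub-directories, so they can be paired up by file name. A minimal sketch of loading matched chunk and structure files (the `./SciTSR` path is only a placeholder for wherever the dataset was unpacked):

```python
import json
import os

# Pair each training example's chunk file with its structure label by shared ID.
# "./SciTSR" is a hypothetical path; adjust it to your local copy of the dataset.
root = "./SciTSR/train"
for file_name in os.listdir(os.path.join(root, "chunk")):
    fid = os.path.splitext(file_name)[0]
    with open(os.path.join(root, "chunk", fid + ".chunk")) as f:
        chunks = json.load(f)["chunks"]
    with open(os.path.join(root, "structure", fid + ".json")) as f:
        cells = json.load(f)["cells"]
    print(fid, len(chunks), "chunks,", len(cells), "cells")
```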
43 | 
44 | For the training data, we also provide the relation labels constructed for our GraphTSR model, which are generated by matching the chunks against the texts of the structure labels.
45 | 
46 | **Note that our pre-processed chunk and relation data may contain noise. The original input files are in PDF.**
47 | 
48 | ### Text Chunks
49 | 
50 | File: chunk/[ID].chunk
51 | 
52 | The `pos` array contains the `x1`, `x2`, `y1` and `y2` coordinates of the chunk (in PDF coordinates).
53 | 
54 | ```json
55 | {"chunks": [
56 |   {
57 |     "pos": [
58 |       147.96600341796875,
59 |       205.49998474121094,
60 |       475.7929992675781,
61 |       480.4206237792969
62 |     ],
63 |     "text": "Probability"
64 |   },
65 |   {
66 |     "pos": [
67 |       217.45510864257812,
68 |       290.6802673339844,
69 |       475.7929992675781,
70 |       480.4206237792969
71 |     ],
72 |     "text": "Generated Text"
73 |   },
74 |   ...
75 | ]}
76 | ```
77 | 
78 | ### Relations
79 | 
80 | File: rel/[ID].rel
81 | 
82 | A line of the form `CHUNK_ID_1 CHUNK_ID_2 RELATION_ID:NUM_BLANK` means that the relation between the CHUNK_ID_1-th chunk and the CHUNK_ID_2-th chunk is RELATION_ID, with NUM_BLANK blank cells between them.
83 | For RELATION_ID, 1 and 2 represent horizontal and vertical adjacency, respectively.
84 | 
85 | ```
86 | 0 1 1:0
87 | 1 2 1:0
88 | 0 9 2:0
89 | ...
90 | ```
91 | 
92 | ### Structure Labels
93 | 
94 | File: structure/[ID].json
95 | 
96 | A table is stored as a list of cells. For each cell, we provide its original TeX code, its content (split by spaces), and its position in the table (start/end row/column numbers, starting from 0).
97 | 
98 | ```json
99 | {"cells": [
100 |   {
101 |     "id": 21,
102 |     "tex": "959",
103 |     "content": [
104 |       "959"
105 |     ],
106 |     "start_row": 5,
107 |     "end_row": 5,
108 |     "start_col": 1,
109 |     "end_col": 1
110 |   },
111 |   {
112 |     "id": 1,
113 |     "tex": "Training set",
114 |     "content": [
115 |       "Training",
116 |       "set"
117 |     ],
118 |     "start_row": 0,
119 |     "end_row": 0,
120 |     "start_col": 1,
121 |     "end_col": 1
122 |   },
123 |   ...
124 | ]}
125 | ```
126 | 
127 | ## Implementation Details
128 | 
129 | ### Features
130 | 
131 | The code for the vertex and edge features is in `./scitsr/graph.py`.
132 | 
133 | You can get the vertex features via `Vertex(vid, chunk, tab_h, tab_w).features` and the edge features via `Edge(vertex1, vertex2).features`.
134 | 
135 | `tab_h` and `tab_w` denote the height (y-axis) and width (x-axis) of the table.
136 | 
137 | See `./scitsr/graph.py` for more details.
138 | 
139 | ### Evaluation
140 | 
141 | In the evaluation procedure, a table is first converted to a list of horizontally/vertically adjacent relations; the output relations are then compared with the ground-truth relations.
142 | 
143 | We release the evaluation scripts for comparing horizontally and vertically adjacent relations. In the following example (from `./examples/eval.py`), we show how to use the scripts to calculate the precision/recall/F1 for an output table.
144 | 
145 | 
146 | 
147 | ```python
148 | with open(json_path) as fp: json_obj = json.load(fp)
149 | # convert the structure labels (a table in json format) to a list of relations
150 | ground_truth_relations = json2Relations(json_obj, splitted_content=True)
151 | # your_relations should be a List of Relation.
152 | # Here we directly use the ground truth relations in the example.
153 | your_relations = ground_truth_relations
154 | precision, recall = eval_relations(
155 |     gt=[ground_truth_relations], res=[your_relations], cmp_blank=True)
156 | ```
157 | 
158 | Note: Your output tables should be represented as `List[Relation]`.
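For reference, a `Relation` can also be constructed directly from `scitsr.relation.Relation`; a minimal sketch (the texts and ids below are made-up placeholders):

```python
from scitsr.relation import Relation

# Direction codes as defined in scitsr.eval: 1 = horizontal, 2 = vertical.
your_relations = [
    Relation(from_text="Probability", to_text="Generated Text",
             direction=1, from_id=0, to_id=1, no_blanks=0),
]
```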
You can also store a table as a `Table` object and then convert it to `List[Relation]` by using `scitsr.eval.Table2Relations`. 159 | 160 | ## Citation 161 | 162 | Please cite the paper if you found the resources useful. 163 | 164 | ``` 165 | @article{chi2019complicated, 166 | title={Complicated Table Structure Recognition}, 167 | author={Chi, Zewen and Huang, Heyan and Xu, Heng-Da and Yu, Houjin and Yin, Wanxuan and Mao, Xian-Ling}, 168 | journal={arXiv preprint arXiv:1908.04729}, 169 | year={2019} 170 | } 171 | ``` 172 | -------------------------------------------------------------------------------- /examples/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from scitsr.eval import json2Relations, eval_relations 4 | 5 | 6 | def example(): 7 | json_path = "/home/czwin32768/res/SciTSR/SciTSR/test/structure/1010.1982v1.2.json" 8 | with open(json_path) as fp: json_obj = json.load(fp) 9 | ground_truth_relations = json2Relations(json_obj, splitted_content=True) 10 | # your_relations should be a List of Relation. 11 | # Here we directly use the ground truth relations as the results. 12 | your_relations = ground_truth_relations 13 | precision, recall = eval_relations( 14 | gt=[ground_truth_relations], res=[your_relations], cmp_blank=True) 15 | f1 = 2.0 * precision * recall / (precision + recall) if precision + recall > 0 else 0 16 | print("P: %.2f, R: %.2f, F1: %.2f" % (precision, recall, f1)) 17 | 18 | 19 | if __name__ == "__main__": 20 | example() -------------------------------------------------------------------------------- /scitsr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Academic-Hammer/SciTSR/79954b5143295162ceaf7e9d9af918a29fe12f55/scitsr/__init__.py -------------------------------------------------------------------------------- /scitsr/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Academic-Hammer/SciTSR/79954b5143295162ceaf7e9d9af918a29fe12f55/scitsr/data/__init__.py -------------------------------------------------------------------------------- /scitsr/data/loader.py: -------------------------------------------------------------------------------- 1 | """Load Real Data for Graph Attention Model 2 | Author: Heng-Da Xu 3 | Date Created: March 22, 2019 4 | Modified by: Heng-Da Xu 5 | Date Modified: March 22, 2019 6 | """ 7 | import json 8 | import os 9 | from typing import List 10 | import random 11 | 12 | from tqdm import tqdm 13 | from pprint import pprint 14 | import torch 15 | from torch.utils.data import Dataset 16 | 17 | from scitsr.data.utils import preprocessing, construct_knn_edges 18 | from scitsr.graph import Edge, Vertex 19 | from scitsr.table import Chunk 20 | from scitsr.eval import Relation 21 | 22 | 23 | class Data: 24 | 25 | def __init__(self, chunks, relations, cells=None, 26 | path=None, nodes=None, edges=None, 27 | adj=None, incidence=None, labels=None): 28 | self.chunks = chunks 29 | self.relations = relations 30 | self.cells = cells 31 | self.path = path 32 | self.nodes = nodes 33 | self.edges = edges 34 | self.adj = adj 35 | self.incidence = incidence 36 | self.labels = labels 37 | 38 | 39 | class TableDataset(Dataset): 40 | 41 | def __init__(self, dataset_dir, with_cells, trim=None, 42 | node_norm=None, edge_norm=None, exts=None): 43 | if exts is None: exts = ['chunk12', 'rel12'] 44 | raw_dataset = self.load_dataset( 45 | 
dataset_dir, with_cells, trim, exts=exts) 46 | raw_dataset = preprocessing(raw_dataset) 47 | 48 | dataset = [] 49 | for data in tqdm(raw_dataset, desc='TableDataset'): 50 | if len(data.chunks) <= 2 or len(data.relations) <= 2: 51 | continue 52 | data.nodes, data.edges, data.adj, data.incidence, data.labels = \ 53 | self.transform(data.chunks, data.relations) 54 | dataset.append(data) 55 | self.node_norm, self.edge_norm = self.feature_normalizaion( 56 | dataset, node_norm, edge_norm) 57 | self.dataset = dataset 58 | 59 | self.n_node_features = self.dataset[0].nodes.size(1) 60 | self.n_edge_features = self.dataset[0].edges.size(1) 61 | self.output_size = self.dataset[0].labels.max().item() + 1 62 | 63 | def shuffle(self): 64 | random.shuffle(self.dataset) 65 | 66 | def feature_normalizaion(self, dataset, node_param=None, edge_param=None): 67 | 68 | def _get_mean_std(features): 69 | mean = features.mean(dim=0, keepdim=True) 70 | std = features.std(dim=0, keepdim=True) 71 | return mean, std 72 | 73 | def _norm(features, mean, std, eps=1e-6): 74 | return (features - mean) / (std + 1e-6) 75 | 76 | # normalize edge features 77 | if edge_param is None: 78 | all_edge_features = torch.cat([data.edges for data in dataset]) 79 | edge_mean, edge_std = _get_mean_std(all_edge_features) 80 | else: edge_mean, edge_std = edge_param 81 | for data in dataset: 82 | data.edges = _norm(data.edges, edge_mean, edge_std) 83 | 84 | # normalize node features 85 | if node_param is None: 86 | all_node_features = torch.cat([data.nodes for data in dataset]) 87 | node_mean, node_std = _get_mean_std(all_node_features) 88 | else: node_mean, node_std = node_param 89 | for data in dataset: 90 | data.nodes = _norm(data.nodes, node_mean, node_std) 91 | 92 | return (node_mean, node_std), (edge_mean, edge_std) 93 | 94 | 95 | def transform(self, chunks, relations): 96 | vertexes = self.get_vertexes(chunks) 97 | nodes = self.get_vertex_features(vertexes) 98 | adj, incidence = self.get_adjcancy(relations, len(chunks)) 99 | edges = self.get_edges(relations, vertexes) 100 | labels = self.get_labels(relations) 101 | nodes, edges, adj, incidence, labels = \ 102 | self.to_tensors(nodes, edges, adj, incidence, labels) 103 | #nodes, edges = self.normlize(nodes), self.normlize(edges) 104 | return nodes, edges, adj, incidence, labels 105 | 106 | def load_dataset(self, dataset_dir, with_cells, trim=None, debug=False, exts=None): 107 | dataset, cells = [], [] 108 | if exts is None: exts = ['chunk', 'rel'] 109 | if with_cells: 110 | exts.append('json') 111 | sub_paths = self.get_sub_paths(dataset_dir, exts, trim=trim) 112 | for i, paths in enumerate(sub_paths): 113 | if debug and i > 50: 114 | break 115 | chunk_path = paths[0] 116 | relation_path = paths[1] 117 | 118 | chunks = self.load_chunks(chunk_path) 119 | # TODO handle big tables 120 | #if len(chunks) > 100 or len(chunks) == 0: continue 121 | relations = self.load_relations(relation_path) 122 | #new_chunks, new_rels = self.clean_chunk_rel(chunks, relations) 123 | #chunks, relations = new_chunks, new_rels 124 | 125 | if with_cells: 126 | cell_path = paths[2] 127 | with open(cell_path) as f: 128 | cell_json = json.load(f) 129 | else: 130 | cell_json = None 131 | 132 | dataset.append(Data( 133 | chunks=chunks, 134 | relations=relations, 135 | cells=cell_json, 136 | path=chunk_path, 137 | )) 138 | return dataset 139 | 140 | def clean_chunk_rel(self, chunks, relations): 141 | """Remove null chunks""" 142 | new_chunks = [] 143 | oldid2newid = [-1 for i in range(len(chunks))] 144 | for i, c in 
enumerate(chunks): 145 | if c.x2 == c.x1 or c.y2 == c.y1 or c.text == "": 146 | continue 147 | oldid2newid[i] = len(new_chunks) 148 | new_chunks.append(c) 149 | new_rels = [] 150 | for i, j, t in relations: 151 | ni = oldid2newid[i] 152 | nj = oldid2newid[j] 153 | if ni != -1 and nj != -1: new_rels.append((ni, nj, t)) 154 | return new_chunks, new_rels 155 | 156 | def load_chunks(self, chunk_path): 157 | with open(chunk_path, 'r') as f: 158 | chunks = json.load(f)['chunks'] 159 | # NOTE remove the chunk with 0 len 160 | ret = [] 161 | for chunk in chunks: 162 | if chunk["pos"][1] < chunk["pos"][0]: 163 | chunk["pos"][0], chunk["pos"][1] = chunk["pos"][1], chunk["pos"][0] 164 | print("Warning load illegal chunk.") 165 | c = Chunk.load_from_dict(chunk) 166 | #if c.x2 == c.x1 or c.y2 == c.y1 or c.text == "": 167 | # continue 168 | ret.append(c) 169 | return ret 170 | 171 | def load_relations(self, relation_path): 172 | with open(relation_path, 'r') as f: 173 | lines = f.readlines() 174 | relations = [] 175 | for line in lines: 176 | i, j, t = line.split('\t') 177 | i, j, t = int(i), int(j), int(t.split(':')[0]) 178 | relations.append((i, j, t)) 179 | return relations 180 | 181 | def get_sub_paths(self, root_dir: str, sub_names: List[str], trim=None): 182 | # Check the existence of directories 183 | assert os.path.isdir(root_dir) 184 | # TODO: sub_dirs redundancy 185 | sub_dirs = [] 186 | for sub_name in sub_names: 187 | sub_dir = os.path.join(root_dir, sub_name) 188 | assert os.path.isdir(sub_dir), '"%s" is not dir.' % sub_dir 189 | sub_dirs.append(sub_dir) 190 | 191 | paths = [] 192 | d = os.listdir(sub_dirs[0]) 193 | d = d[:trim] if trim else d 194 | for file_name in d: 195 | sub_paths = [os.path.join(sub_dirs[0], file_name)] 196 | name = os.path.splitext(file_name)[0] 197 | for ext in sub_names[1:]: 198 | sub_path = os.path.join(root_dir, ext, name + '.' 
+ ext) 199 | assert os.path.exists(sub_path) 200 | sub_paths.append(sub_path) 201 | paths.append(sub_paths) 202 | 203 | return paths 204 | 205 | def get_vertexes(self, chunks): 206 | coords_x, coords_y = [], [] 207 | for chunk in chunks: 208 | coords_x.append(chunk.x1) 209 | coords_x.append(chunk.x2) 210 | coords_y.append(chunk.y1) 211 | coords_y.append(chunk.y2) 212 | table_width = max(coords_x) - min(coords_x) 213 | table_height = max(coords_y) - min(coords_y) 214 | 215 | vertexes = [] 216 | for index, chunk in enumerate(chunks): 217 | vertex = Vertex(index, chunk, table_width, table_height) 218 | vertexes.append(vertex) 219 | return vertexes 220 | 221 | def get_vertex_features(self, vertexes): 222 | vertex_features = [] 223 | for vertex in vertexes: 224 | features = [v for v in vertex.get_features().values()] 225 | vertex_features.append(features) 226 | return vertex_features 227 | 228 | def get_adjcancy(self, relations, n_vertexes): 229 | n_edges = len(relations) 230 | adj = [[0] * n_vertexes for _ in range(n_vertexes)] 231 | incidence = [[0] * n_edges for _ in range(n_vertexes)] 232 | for idx, (i, j, _) in enumerate(relations): 233 | adj[i][j] = adj[j][i] = 1 234 | incidence[i][idx] = incidence[j][idx] = 1 235 | return adj, incidence 236 | 237 | def get_edges(self, relations, vertexes): 238 | edge_features = [] 239 | for i, j, _ in relations: 240 | edge = Edge(vertexes[i], vertexes[j]) 241 | features = [v for v in edge.get_features().values()] 242 | edge_features.append(features) 243 | return edge_features 244 | 245 | def get_labels(self, relations): 246 | labels = [label for id_a, id_b, label in relations] 247 | return labels 248 | 249 | def to_tensors(self, nodes, edges, adj, incidence, labels): 250 | nodes = torch.tensor(nodes, dtype=torch.float) 251 | edges = torch.tensor(edges, dtype=torch.float) 252 | adj = torch.tensor(adj, dtype=torch.long) 253 | incidence = torch.tensor(incidence, dtype=torch.long) 254 | labels = torch.tensor(labels, dtype=torch.long) 255 | return nodes, edges, adj, incidence, labels 256 | 257 | #TODO normalize over dataset? 
258 | def normlize(self, features): 259 | mean = features.mean(dim=0, keepdim=True) 260 | std = features.std(dim=0, keepdim=True) 261 | features = (features - mean) / (std + 1e-6) 262 | return features 263 | 264 | def __len__(self): 265 | return len(self.dataset) 266 | 267 | def __getitem__(self, idx): 268 | return self.dataset[idx] 269 | 270 | 271 | class TableInferDataset(TableDataset): 272 | 273 | def __init__(self, dataset_dir, trim=None, 274 | node_norm=None, edge_norm=None, exts=None): 275 | if exts is None: exts = ['chunk12', 'rel12'] 276 | raw_dataset = self.load_dataset( 277 | dataset_dir, True, trim, exts=exts) 278 | raw_dataset = preprocessing(raw_dataset) 279 | 280 | dataset = [] 281 | for data in tqdm(raw_dataset, desc='TableInferDataset'): 282 | if len(data.chunks) <= 2 or len(data.relations) <= 2: 283 | continue 284 | data.nodes, data.edges, data.adj, data.incidence, data.relations = \ 285 | self.transform(data.chunks) 286 | dataset.append(data) 287 | self.node_norm, self.edge_norm = self.feature_normalizaion( 288 | dataset, node_norm, edge_norm) 289 | self.dataset = dataset 290 | 291 | self.n_node_features = self.dataset[0].nodes.size(1) 292 | self.n_edge_features = self.dataset[0].edges.size(1) 293 | self.output_size = 3 294 | 295 | def transform(self, chunks): 296 | vertexes = self.get_vertexes(chunks) 297 | nodes = self.get_vertex_features(vertexes) 298 | relations = construct_knn_edges(chunks) 299 | if len(relations) <= 2: 300 | return None 301 | adj, incidence = self.get_adjcancy(relations, len(chunks)) 302 | edges = self.get_edges(relations, vertexes) 303 | nodes, edges, adj, incidence = \ 304 | self.to_tensors(nodes, edges, adj, incidence) 305 | #nodes, edges = self.normlize(nodes), self.normlize(edges) 306 | return nodes, edges, adj, incidence, relations 307 | 308 | def to_tensors(self, nodes, edges, adj, incidence): 309 | nodes = torch.tensor(nodes, dtype=torch.float) 310 | edges = torch.tensor(edges, dtype=torch.float) 311 | adj = torch.tensor(adj, dtype=torch.long) 312 | incidence = torch.tensor(incidence, dtype=torch.long) 313 | return nodes, edges, adj, incidence -------------------------------------------------------------------------------- /scitsr/data/rel_gen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Zewen Chi 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import json 9 | 10 | from collections import defaultdict 11 | from tqdm import tqdm 12 | from scitsr.table import Chunk 13 | from scitsr.eval import Table2Relations, normalize 14 | from scitsr.relation import Relation 15 | from scitsr.data import utils 16 | 17 | 18 | def dump_iters_as_tsv(filename, iterables, spliter="\t"): 19 | """ 20 | Dump iters as tsv. 21 | item1\titem2\t... (from iterable1) 22 | item1\titem2\t... (from iterable2) 23 | """ 24 | with open(filename, "w") as f: 25 | for iterable in iterables: 26 | iterable = [str(i) for i in iterable] 27 | f.write(spliter.join(iterable) + "\n") 28 | 29 | 30 | def match(src:dict, trg:dict, src_chunks, trg_chunks, fid): 31 | """Match chunks to latex cells w.r.t. 
the contents.""" 32 | sid2tid = {} 33 | print("--------%s---------------------------" % fid) 34 | for stxt, sids in src.items(): 35 | if stxt in trg: 36 | tids = trg[stxt] 37 | if len(sids) == 1 and len(tids) == 1: sid2tid[sids[0]] = tids[0] 38 | elif len(sids) == len(tids): 39 | schunks = [(sid, src_chunks[sid]) for sid in sids] 40 | tchunks = [(tid, trg_chunks[tid]) for tid in tids] 41 | sorted_sc = sorted(schunks, key=lambda x: (-x[1].y1, x[1].x1)) 42 | sorted_tc = sorted(tchunks, key=lambda x: (x[1].x1, x[1].y1)) 43 | for (sid, _), (tid, _) in zip(sorted_sc, sorted_tc): 44 | sid2tid[sid] = tid 45 | else: 46 | print("[W] length of sids and tids doesn't match") 47 | else: 48 | print("[W] no match for text %s" % stxt) 49 | print("-----------------------------------------------------------") 50 | return sid2tid 51 | 52 | 53 | def chunks2rel(ds_dir, rel_dir, chunk_ds="chunk", cell_ds="json"): 54 | if os.path.exists(rel_dir): 55 | print("%s exists." % rel_dir) 56 | return 57 | else: os.mkdir(rel_dir) 58 | 59 | skipped = 1 60 | 61 | for fid, (ch_json, cell_json) in tqdm(utils.ds_iter(ds_dir, [chunk_ds, cell_ds])): 62 | 63 | try: 64 | chunks = [Chunk.load_from_dict(cd) for cd in ch_json["chunks"]] 65 | table = utils.json2Table(cell_json, fid, splitted_content=True) 66 | except Exception as e: 67 | print(e) 68 | skipped += 1 69 | continue 70 | 71 | relations = Table2Relations(table) 72 | trg_chunks = table.cells 73 | rel_dict = {(r.from_id, r.to_id):r for r in relations} 74 | src_txt2id, trg_txt2id = defaultdict(list), defaultdict(list) 75 | for i, c in enumerate(chunks): src_txt2id[normalize(c.text)].append(i) 76 | for i, c in enumerate(trg_chunks): trg_txt2id[normalize(c.text)].append(i) 77 | 78 | sid2tid = match(src_txt2id, trg_txt2id, chunks, trg_chunks, fid) 79 | if sid2tid is None: continue 80 | tuples = [] 81 | for i in range(len(chunks)): 82 | if i in sid2tid: ti = sid2tid[i] 83 | else: continue 84 | for j in range(i + 1, len(chunks), 1): 85 | if j in sid2tid: tj = sid2tid[j] 86 | else: continue 87 | if (ti, tj) in rel_dict: tuples.append((i, j, rel_dict[(ti, tj)])) 88 | 89 | dump_iters_as_tsv(os.path.join(rel_dir, fid + ".rel"), tuples) 90 | 91 | 92 | if __name__ == "__main__": 93 | chunks2rel( 94 | ds_dir="/path/to/scitsr", 95 | rel_dir="/path/to/scitsr/rel", 96 | chunk_ds="chunk", 97 | cell_ds="json" 98 | ) 99 | -------------------------------------------------------------------------------- /scitsr/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Zewen Chi, Heng-Da Xu 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import json 9 | import random 10 | from typing import List 11 | 12 | from tqdm import tqdm 13 | 14 | from scitsr.table import Chunk, Table 15 | 16 | 17 | def ds_iter(ds_dir, ds_ls): 18 | ds_dir_ls = [os.path.join(ds_dir, ds) for ds in ds_ls] 19 | ds_dir_ext = [os.path.splitext(os.listdir(d)[0])[1] for d in ds_dir_ls] 20 | for fn in os.listdir(ds_dir_ls[0]): 21 | fid, _ = os.path.splitext(fn) 22 | fid_fn = [os.path.join( 23 | ds_dir_ls[i], fid + ds_dir_ext[i] 24 | ) for i,ds in enumerate(ds_dir_ls)] 25 | ret = [] 26 | try: 27 | for f in fid_fn: 28 | with open(f) as fp: 29 | ret.append(json.load(fp)) 30 | if len(ret) != len(ds_ls): 31 | print("[W] 1 instance skipped") 32 | else: 33 | yield fid, ret 34 | except: 35 | continue 36 | 37 | 38 | def json2Table(json_obj, tid="", splitted_content=False): 39 | """Construct a Table object from json object 40 | Args: 41 | json_obj: a json object 42 | Returns: 43 | a Table object 44 | """ 45 | jo = json_obj["cells"] 46 | row_n, col_n = 0, 0 47 | cells = [] 48 | for co in jo: 49 | content = co["content"] 50 | if content is None: continue 51 | if splitted_content: 52 | content = " ".join(content) 53 | else: 54 | content = content.strip() 55 | if content == "": continue 56 | start_row = co["start_row"] 57 | end_row = co["end_row"] 58 | start_col = co["start_col"] 59 | end_col = co["end_col"] 60 | row_n = max(row_n, end_row) 61 | col_n = max(col_n, end_col) 62 | cell = Chunk(content, (start_row, end_row, start_col, end_col)) 63 | cells.append(cell) 64 | return Table(row_n + 1, col_n + 1, cells, tid) 65 | 66 | def transform_coord(chunks): 67 | # Get table width and height 68 | coords_x, coords_y = [], [] 69 | for chunk in chunks: 70 | coords_x.append(chunk.x1) 71 | coords_x.append(chunk.x2) 72 | coords_y.append(chunk.y1) 73 | coords_y.append(chunk.y2) 74 | # table_width = max(coords_x) - min(coords_x) 75 | # table_height = max(coords_y) - min(coords_y) 76 | 77 | # Coordinate transformation for chunks 78 | table_min_x, table_max_y = min(coords_x), max(coords_y) 79 | chunks_new = [] 80 | for chunk in chunks: 81 | x1 = chunk.x1 - table_min_x 82 | x2 = chunk.x2 - table_min_x 83 | y1 = table_max_y - chunk.y2 84 | y2 = table_max_y - chunk.y1 85 | chunk_new = Chunk( 86 | text=chunk.text, 87 | pos=(x1, x2, y1, y2), 88 | ) 89 | chunks_new.append(chunk_new) 90 | 91 | # return table_width, table_height 92 | return chunks_new 93 | 94 | 95 | def _eul_dis(chunks, i, j): 96 | xi = (chunks[i].x1 + chunks[i].x2) / 2 97 | yi = (chunks[i].y1 + chunks[i].y2) / 2 98 | xj = (chunks[j].x1 + chunks[j].x2) / 2 99 | yj = (chunks[j].y1 + chunks[j].y2) / 2 100 | return (xj - xi)**2 + (yj-yi)**2 101 | 102 | 103 | def construct_knn_edges(chunks, k=20): 104 | relations = [] 105 | edges = set() 106 | for i in range(len(chunks)): 107 | _dis_ij = [] 108 | for j in range(len(chunks)): 109 | if j == i: continue 110 | _dis_ij.append((_eul_dis(chunks, i, j), j)) 111 | sorted_dis_ij = sorted(_dis_ij) 112 | for _, j in sorted_dis_ij[:k]: 113 | _i, _j = (i, j) if i < j else (j, i) 114 | if (_i, _j) not in edges: 115 | edges.add((_i, _j)) 116 | relations.append((_i, _j, 0)) 117 | return relations 118 | 119 | 120 | def add_knn_edges(chunks, relations, k=20, debug=False): 121 | """Add edges according to knn of vertexes. 
122 | """ 123 | edges = set() 124 | rel_recall = {} 125 | for i, j, _ in relations: 126 | edges.add((i, j) if i < j else (j, i)) 127 | rel_recall[(i, j) if i < j else (j, i)] = False 128 | for i in range(len(chunks)): 129 | _dis_ij = [] 130 | for j in range(len(chunks)): 131 | if j == i: continue 132 | _dis_ij.append((_eul_dis(chunks, i, j), j)) 133 | sorted_dis_ij = sorted(_dis_ij) 134 | for _, j in sorted_dis_ij[:k]: 135 | _i, _j = (i, j) if i < j else (j, i) 136 | if (_i, _j) in rel_recall: rel_recall[(_i, _j)] = True 137 | if (_i, _j) not in edges: 138 | edges.add((_i, _j)) 139 | relations.append((_i, _j, 0)) 140 | cnt = 0 141 | for _, val in rel_recall.items(): 142 | if val: cnt += 1 143 | recall = 0 if len(rel_recall) == 0 else cnt / len(rel_recall) 144 | if debug: 145 | print("add knn edge. recall:%.3f" % recall) 146 | return relations, recall 147 | 148 | 149 | def add_null_edges(chunks, relations): 150 | n_chunks = len(chunks) 151 | 152 | # Convert relations to adjcancy matrix 153 | adj = [[0] * n_chunks for _ in range(n_chunks)] 154 | for i, j, _ in relations: 155 | adj[i][j] = adj[j][i] = 1 156 | 157 | # Add null edges 158 | for i in range(n_chunks): 159 | x = (chunks[i].x1 + chunks[i].x2) / 2 160 | y = (chunks[i].y1 + chunks[i].y2) / 2 161 | for j in range(i + 1, n_chunks): 162 | if adj[i][j] == 1: 163 | continue 164 | xx = (chunks[j].x1 + chunks[j].x2) / 2 165 | yy = (chunks[j].y1 + chunks[j].y2) / 2 166 | if (xx - x)**2 + (yy - y)**2 > 30**2: 167 | continue 168 | adj[i][j] = adj[j][i] = 1 169 | relations.append((i, j, 0)) 170 | 171 | return relations 172 | 173 | 174 | def add_full_edges(chunks, relations): 175 | n_chunks = len(chunks) 176 | 177 | # Convert relations to adjcancy matrix 178 | adj = [[0] * n_chunks for _ in range(n_chunks)] 179 | for i, j, _ in relations: 180 | adj[i][j] = adj[j][i] = 1 181 | 182 | # Add null edges 183 | for i in range(n_chunks): 184 | x = (chunks[i].x1 + chunks[i].x2) / 2 185 | y = (chunks[i].y1 + chunks[i].y2) / 2 186 | for j in range(i + 1, n_chunks): 187 | if adj[i][j] == 1: 188 | continue 189 | adj[i][j] = adj[j][i] = 1 190 | relations.append((i, j, 0)) 191 | 192 | return relations 193 | 194 | 195 | def preprocessing(dataset, debug=True): 196 | # random.seed(0) 197 | dataset_new = [] 198 | edge_recall_sum = 0 199 | cnt = 0 200 | if debug: recall_path = [] 201 | for data in tqdm(dataset, desc='preprocessing'): 202 | data.chunks = transform_coord(data.chunks) 203 | #data.relations = add_null_edges(data.chunks, data.relations) 204 | data.relations, recall = add_knn_edges(data.chunks, data.relations) 205 | edge_recall_sum += recall 206 | cnt += 1 207 | if debug: recall_path.append((recall, data.path)) 208 | # data.relations = add_full_edges(data.chunks, data.relations) 209 | # random.shuffle(relations) 210 | print("edge recall:%.3f" % (edge_recall_sum / cnt)) 211 | return dataset -------------------------------------------------------------------------------- /scitsr/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Zewen Chi 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from typing import List 9 | 10 | from scitsr.relation import Relation 11 | from scitsr.table import Table, Chunk 12 | 13 | 14 | DIR_HORIZ = 1 15 | DIR_VERT = 2 16 | DIR_SAME_CELL = 3 17 | 18 | 19 | def normalize(s:str, rule=0): 20 | if rule == 0: 21 | s = s.replace("\r", "") 22 | s = s.replace("\n", "") 23 | s = s.replace(" ", "") 24 | s = s.replace("\t", "") 25 | return s.upper() 26 | else: 27 | raise NotImplementedError 28 | 29 | 30 | def eval_relations(gt:List[List], res:List[List], cmp_blank=True): 31 | """Evaluate results 32 | 33 | Args: 34 | gt: a list of list of Relation 35 | res: a list of list of Relation 36 | """ 37 | 38 | #TODO to know how to calculate the total recall and prec 39 | 40 | assert len(gt) == len(res) 41 | tot_prec = 0 42 | tot_recall = 0 43 | total = 0 44 | # print("evaluating result...") 45 | 46 | # for _gt, _res in tqdm(zip(gt, res)): 47 | # for _gt, _res in tqdm(zip(gt, res), total=len(gt), desc='eval'): 48 | idx, t = 0, len(gt) 49 | for _gt, _res in zip(gt, res): 50 | idx += 1 51 | print('Eval %d/%d (%d%%)' % (idx, t, idx / t * 100), ' ' * 45, end='\r') 52 | corr = compare_rel(_gt, _res, cmp_blank) 53 | precision = corr / len(_res) if len(_res) != 0 else 0 54 | recall = corr / len(_gt) if len(_gt) != 0 else 0 55 | tot_prec += precision 56 | tot_recall += recall 57 | total += 1 58 | # print() 59 | 60 | precision = tot_prec / total 61 | recall = tot_recall / total 62 | # print("Test on %d instances. Precision: %.2f, Recall: %.2f" % ( 63 | # total, precision, recall)) 64 | return precision, recall 65 | 66 | def compare_rel(gt_rel:List[Relation], res_rel:List[Relation], cmp_blank=True): 67 | count = 0 68 | 69 | #print("compare_rel =======================") 70 | #for gt in gt_rel: 71 | # print("rel gt:", gt.from_text, gt.to_text, gt.direction) 72 | #for gt in res_rel: 73 | # print("rel res:", gt.from_text, gt.to_text, gt.direction) 74 | #print("\n\n\n\n\n") 75 | 76 | dup_res_rel = [r for r in res_rel] 77 | for gt in gt_rel: 78 | to_rm = None 79 | for i, res in enumerate(dup_res_rel): 80 | if gt.equal(res, cmp_blank): 81 | to_rm = i 82 | count += 1 83 | break 84 | if to_rm is not None: 85 | dup_res_rel = dup_res_rel[:i] + dup_res_rel[i + 1:] 86 | 87 | return count 88 | 89 | def Table2Relations(t:Table): 90 | """Convert a Table object to a List of Relation. 
91 | """ 92 | ret = [] 93 | cl = t.coo2cell_id 94 | # remove duplicates with pair set 95 | used = set() 96 | 97 | # look right 98 | for r in range(t.row_n): 99 | for cFrom in range(t.col_n - 1): 100 | cTo = cFrom + 1 101 | loop = True 102 | while loop and cTo < t.col_n: 103 | fid, tid = cl[r][cFrom], cl[r][cTo] 104 | if fid != -1 and tid != -1 and fid != tid: 105 | if (fid, tid) not in used: 106 | ret.append(Relation( 107 | from_text=t.cells[fid].text, 108 | to_text=t.cells[tid].text, 109 | direction=DIR_HORIZ, 110 | from_id=fid, 111 | to_id=tid, 112 | no_blanks=cTo - cFrom - 1 113 | )) 114 | used.add((fid, tid)) 115 | loop = False 116 | else: 117 | if fid != -1 and tid != -1 and fid == tid: 118 | cFrom = cTo 119 | cTo += 1 120 | 121 | # look down 122 | for c in range(t.col_n): 123 | for rFrom in range(t.row_n - 1): 124 | rTo = rFrom + 1 125 | loop = True 126 | while loop and rTo < t.row_n: 127 | fid, tid = cl[rFrom][c], cl[rTo][c] 128 | if fid != -1 and tid != -1 and fid != tid: 129 | if (fid, tid) not in used: 130 | ret.append(Relation( 131 | from_text=t.cells[fid].text, 132 | to_text=t.cells[tid].text, 133 | direction=DIR_VERT, 134 | from_id=fid, 135 | to_id=tid, 136 | no_blanks=rTo - rFrom - 1 137 | )) 138 | used.add((fid, tid)) 139 | loop = False 140 | else: 141 | if fid != -1 and tid != -1 and fid == tid: 142 | rFrom = rTo 143 | rTo += 1 144 | 145 | return ret 146 | 147 | def json2Table(json_obj, tid="", splitted_content=False): 148 | """Construct a Table object from json object 149 | 150 | Args: 151 | json_obj: a json object 152 | Returns: 153 | a Table object 154 | """ 155 | jo = json_obj["cells"] 156 | row_n, col_n = 0, 0 157 | cells = [] 158 | for co in jo: 159 | content = co["content"] 160 | if content is None: continue 161 | if splitted_content: 162 | content = " ".join(content) 163 | else: 164 | content = content.strip() 165 | if content == "": continue 166 | start_row = co["start_row"] 167 | end_row = co["end_row"] 168 | start_col = co["start_col"] 169 | end_col = co["end_col"] 170 | row_n = max(row_n, end_row) 171 | col_n = max(col_n, end_col) 172 | cell = Chunk(content, (start_row, end_row, start_col, end_col)) 173 | cells.append(cell) 174 | return Table(row_n + 1, col_n + 1, cells, tid) 175 | 176 | def json2Relations(json_obj, splitted_content): 177 | return Table2Relations(json2Table(json_obj, "", splitted_content)) 178 | 179 | 180 | -------------------------------------------------------------------------------- /scitsr/graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Zewen Chi 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 
7 | import math
8 | 
9 | from typing import List
10 | from scitsr.table import Chunk
11 | 
12 | 
13 | class Vertex(object):
14 | 
15 |     def __init__(self, vid: int, chunk: Chunk, tab_h, tab_w):
16 |         """
17 |         Args:
18 |             vid: Vertex id
19 |             chunk: the chunk to extract features
20 |             tab_h: height of the table (y-axis)
21 |             tab_w: width of the table (x-axis)
22 |         """
23 |         self.vid = vid
24 |         self.tab_h = tab_h
25 |         self.tab_w = tab_w
26 |         self.chunk = chunk
27 |         self.features = self.get_features()
28 | 
29 |     def get_features(self):
30 |         return {
31 |             "x1": self.chunk.x1,
32 |             "x2": self.chunk.x2,
33 |             "y1": self.chunk.y1,
34 |             "y2": self.chunk.y2,
35 |             "x center": (self.chunk.x1 + self.chunk.x2) / 2,
36 |             "y center": (self.chunk.y1 + self.chunk.y2) / 2,
37 |             "relative x1": self.chunk.x1 / self.tab_w,
38 |             "relative x2": self.chunk.x2 / self.tab_w,
39 |             "relative y1": self.chunk.y1 / self.tab_h,
40 |             "relative y2": self.chunk.y2 / self.tab_h,
41 |             "relative x center": (self.chunk.x1 + self.chunk.x2) / 2 / self.tab_w,
42 |             "relative y center": (self.chunk.y1 + self.chunk.y2) / 2 / self.tab_h,
43 |             "height of chunk": self.chunk.y2 - self.chunk.y1,
44 |             "width of chunk": self.chunk.x2 - self.chunk.x1
45 |         }
46 | 
47 | 
48 | class Edge(object):
49 | 
50 |     def __init__(self, fr: Vertex, to: Vertex):
51 |         self.fr = fr
52 |         self.to = to
53 |         self.features = self.get_features()
54 | 
55 |     def get_features(self):
56 |         c1, c2 = self.fr.chunk, self.to.chunk
57 |         tab_h = self.fr.tab_h
58 |         tab_w = self.fr.tab_w
59 | 
60 |         # distance along x/y
61 |         x_dis, y_dis = 0, 0
62 |         # coincidence (overlap) along x/y
63 |         x_cncd, y_cncd = 0, 0
64 | 
65 |         if c1.x2 <= c2.x1:
66 |             x_dis = c2.x1 - c1.x2
67 |         elif c2.x2 <= c1.x1:
68 |             x_dis = c1.x1 - c2.x2
69 |         elif c1.x2 <= c2.x2:
70 |             if c1.x1 <= c2.x1:
71 |                 x_cncd = c1.x2 - c2.x1
72 |             else:
73 |                 x_cncd = c1.x2 - c1.x1
74 |         elif c1.x2 > c2.x2:
75 |             if c1.x1 <= c2.x1:
76 |                 x_cncd = c2.x2 - c2.x1
77 |             else:
78 |                 x_cncd = c2.x2 - c1.x1
79 | 
80 |         if c1.y2 <= c2.y1:
81 |             y_dis = c2.y1 - c1.y2
82 |         elif c2.y2 <= c1.y1:
83 |             y_dis = c1.y1 - c2.y2
84 |         elif c1.y2 <= c2.y2:
85 |             if c1.y1 <= c2.y1:
86 |                 y_cncd = c1.y2 - c2.y1
87 |             else:
88 |                 y_cncd = c1.y2 - c1.y1
89 |         elif c1.y2 > c2.y2:
90 |             if c1.y1 <= c2.y1:
91 |                 y_cncd = c2.y2 - c2.y1
92 |             else:
93 |                 y_cncd = c2.y2 - c1.y1
94 | 
95 |         c_h = (c1.y2 - c1.y1 + c2.y2 - c2.y1) / 2
96 |         c_w = (c1.x2 - c1.x1 + c2.x2 - c2.x1) / 2 + 1e-7
97 | 
98 |         c1x = (c1.x1 + c1.x2) / 2
99 |         c2x = (c2.x1 + c2.x2) / 2
100 |         c1y = (c1.y1 + c1.y2) / 2
101 |         c2y = (c2.y1 + c2.y2) / 2
102 |         c_x_dis = abs(c2x - c1x)
103 |         c_y_dis = abs(c2y - c1y)
104 | 
105 |         return {
106 |             "x distance": x_dis,
107 |             "y distance": y_dis,
108 |             "relative (table) x distance": x_dis / tab_w,
109 |             "relative (table) y distance": y_dis / tab_h,
110 |             "relative (chunk) x distance": x_dis / c_w,
111 |             "relative (chunk) y distance": y_dis / c_h,
112 |             "x coincide": x_cncd,
113 |             "y coincide": y_cncd,
114 |             "relative (table) x coincide": x_cncd / tab_w,
115 |             "relative (table) y coincide": y_cncd / tab_h,
116 |             "relative (chunk) x coincide": x_cncd / c_w,
117 |             "relative (chunk) y coincide": y_cncd / c_h,
118 |             "Euler distance": math.sqrt(x_dis**2 + y_dis**2),
119 |             "relative (table) Euler distance": math.sqrt((x_dis / tab_w)**2 + (y_dis / tab_h)**2),
120 |             "relative (chunk) Euler distance": math.sqrt((x_dis / c_w)**2 + (y_dis / c_h)**2),
121 |             "center x distance": c_x_dis,
122 |             "relative (table) center x distance": c_x_dis / tab_w,
123 |             "relative (chunk) center x distance": c_x_dis / c_w,
"center y distance": c_y_dis, 125 | "relative (table) center y distance": c_y_dis / tab_h, 126 | "relative (chunk) center y distance": c_y_dis / c_h, 127 | "center Euler distance": math.sqrt(c_x_dis**2 + c_y_dis**2), 128 | "relative (table) center Euler distance": math.sqrt((c_x_dis / tab_w)**2 + (c_y_dis / tab_h)**2), 129 | "relative (chunk) center Euler distance": math.sqrt((c_x_dis / c_w)**2 + (c_y_dis / c_h)**2), 130 | } 131 | 132 | 133 | class Graph(object): 134 | 135 | def __init__(self, E: List[Edge]=None, V: List[Vertex]=None, directed=False): 136 | self.E = E if E is not None else [] 137 | self.V = V if V is not None else [] 138 | self.directed = directed -------------------------------------------------------------------------------- /scitsr/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Heng-Da Xu 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import math 9 | import torch 10 | 11 | 12 | class Attention(torch.nn.Module): 13 | """Attention unit""" 14 | 15 | def __init__(self, size): 16 | super(Attention, self).__init__() 17 | self.size = size 18 | self.linear_q = torch.nn.Linear(size, size, bias=False) 19 | self.linear_k = torch.nn.Linear(size, size, bias=False) 20 | self.linear_v = torch.nn.Linear(size, size, bias=False) 21 | self.layer_norm_1 = torch.nn.LayerNorm(size) 22 | self.feed_forward = torch.nn.Sequential( 23 | torch.nn.Linear(size, size, bias=False), 24 | torch.nn.ReLU(), 25 | torch.nn.Linear(size, size, bias=False), 26 | ) 27 | self.layer_norm_2 = torch.nn.LayerNorm(size) 28 | self.d_rate = 0.4 29 | self.dropout = torch.nn.Dropout(self.d_rate) 30 | 31 | def masked_softmax(self, x, mask, dim): 32 | ex = torch.exp(x) 33 | masked_exp = ex * mask.float() 34 | masked_exp_sum = masked_exp.sum(dim=dim, keepdim=True) 35 | x = masked_exp / (masked_exp_sum + 1e-6) 36 | return x 37 | 38 | def forward(self, x, y, mask): 39 | """ 40 | Shapes: 41 | mask: [nodes/edges, edges/nodes] 42 | q: [nodes/edges, h] 43 | k: [edges/nodes, h] 44 | v: [edges/nodes, h] 45 | score: [nodes/edges, edges/nodes] 46 | x_atten: [nodes/edges, h] 47 | """ 48 | q = self.linear_q(x) 49 | k = self.linear_k(y) 50 | v = self.linear_v(y) 51 | score = torch.mm(q, k.t()) / math.sqrt(self.size) 52 | score = self.masked_softmax(score, mask, dim=1) 53 | x_atten = torch.mm(score, v) 54 | # dropout 55 | x_atten = self.dropout(x_atten) 56 | x = self.layer_norm_1(x + x_atten) 57 | x_linear = self.feed_forward(x) 58 | # dropout 59 | x_linear = self.dropout(x_linear) 60 | x = self.layer_norm_2(x + x_linear) 61 | return x 62 | 63 | 64 | class AttentionBlock(torch.nn.Module): 65 | """Attention Block""" 66 | 67 | def __init__(self, size): 68 | super(AttentionBlock, self).__init__() 69 | self.atten_e2v = Attention(size) 70 | self.atten_v2e = Attention(size) 71 | 72 | def forward(self, nodes, edges, adjacency, incidence): 73 | new_nodes = self.atten_e2v(nodes, edges, incidence) 74 | new_edges = self.atten_v2e(edges, nodes, incidence.t()) 75 | return new_nodes, new_edges 76 | 77 | 78 | class GraphAttention(torch.nn.Module): 79 | """Graph Attention Model""" 80 | 81 | def __init__(self, n_node_features, n_edge_features, \ 82 | hidden_size, output_size, n_blocks): 83 | super(GraphAttention, self).__init__() 84 | self.n_node_features = n_node_features 85 | self.n_edge_features = n_edge_features 86 | self.hidden_size = 
hidden_size 87 | self.n_blocks = n_blocks 88 | self.node_transform = torch.nn.Sequential( 89 | torch.nn.Linear(n_node_features, hidden_size), 90 | torch.nn.ReLU(), 91 | torch.nn.Linear(hidden_size, hidden_size), 92 | ) 93 | self.edge_transform = torch.nn.Sequential( 94 | torch.nn.Linear(n_edge_features, hidden_size), 95 | torch.nn.ReLU(), 96 | torch.nn.Linear(hidden_size, hidden_size), 97 | ) 98 | self.attention_blocks = torch.nn.ModuleList() 99 | for _ in range(n_blocks): 100 | self.attention_blocks.append(AttentionBlock(hidden_size)) 101 | self.output_linear = torch.nn.Linear(hidden_size, output_size) 102 | 103 | def forward(self, nodes, edges, adjacency, incidence): 104 | nodes = self.node_transform(nodes) 105 | edges = self.edge_transform(edges) 106 | for attention_block in self.attention_blocks: 107 | nodes, edges = attention_block(nodes, edges, adjacency, incidence) 108 | outputs = self.output_linear(edges) 109 | outputs = torch.nn.functional.softmax(outputs, dim=1) 110 | return outputs 111 | -------------------------------------------------------------------------------- /scitsr/relation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Zewen Chi 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | 9 | def normalize(s:str, rule=0): 10 | 11 | if rule == 0: 12 | s = s.replace("\r", "") 13 | s = s.replace("\n", "") 14 | s = s.replace(" ", "") 15 | s = s.replace("\t", "") 16 | return s.upper() 17 | else: 18 | raise NotImplementedError 19 | 20 | 21 | class Relation(object): 22 | 23 | def __init__(self, from_text, to_text, direction, from_id=0, to_id=0, no_blanks=0): 24 | self.from_text = from_text 25 | self.to_text = to_text 26 | self.direction = direction 27 | self.no_blanks = no_blanks 28 | self.from_id = from_id 29 | self.to_id = to_id 30 | 31 | def __eq__(self, rl): 32 | this_ft = normalize(self.from_text) 33 | this_tt = normalize(self.to_text) 34 | rl_ft = normalize(rl.from_text) 35 | rl_tt = normalize(rl.to_text) 36 | if len(this_ft) == 0 or len(this_tt) == 0 or \ 37 | len(rl_ft) == 0 or len(rl_tt) == 0: 38 | print("Warning: Text comparison of 0-length strings after normalization", 39 | file=sys.stderr) 40 | 41 | return this_ft == rl_ft and this_tt == rl_tt and \ 42 | self.direction == rl.direction and self.no_blanks == rl.no_blanks 43 | 44 | def equal(self, rl, cmp_blank=True): 45 | this_ft = normalize(self.from_text) 46 | this_tt = normalize(self.to_text) 47 | rl_ft = normalize(rl.from_text) 48 | rl_tt = normalize(rl.to_text) 49 | if len(this_ft) == 0 or len(this_tt) == 0 or \ 50 | len(rl_ft) == 0 or len(rl_tt) == 0: 51 | print("Warning: Text comparison of 0-length strings after normalization", 52 | file=sys.stderr) 53 | 54 | return this_ft == rl_ft and this_tt == rl_tt and \ 55 | self.direction == rl.direction and \ 56 | (self.no_blanks == rl.no_blanks if cmp_blank else True) 57 | 58 | def __str__(self): 59 | return "%d:%d" % (self.direction, self.no_blanks) -------------------------------------------------------------------------------- /scitsr/table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Zewen Chi 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import json 8 | 9 | from typing import Iterable, List, Tuple 10 | 11 | 12 | def load_chunks(chunk_path): 13 | with open(chunk_path, 'r') as f: 14 | chunks = json.load(f)['chunks'] 15 | # NOTE remove the chunk with 0 len 16 | ret = [] 17 | for chunk in chunks: 18 | if chunk["pos"][1] < chunk["pos"][0]: 19 | chunk["pos"][0], chunk["pos"][1] = chunk["pos"][1], chunk["pos"][0] 20 | print("Warning load illegal chunk.") 21 | c = Chunk.load_from_dict(chunk) 22 | #if c.x2 == c.x1 or c.y2 == c.y1 or c.text == "": 23 | # continue 24 | ret.append(c) 25 | return ret 26 | 27 | 28 | class Box(object): 29 | 30 | def __init__(self, pos): 31 | """pos: (x1, x2, y1, y2)""" 32 | self.set_pos(pos) 33 | 34 | def set_pos(self, pos): 35 | assert pos[0] <= pos[1] 36 | assert pos[2] <= pos[3] 37 | self.x1 = pos[0] 38 | self.x2 = pos[1] 39 | self.y1 = pos[2] 40 | self.y2 = pos[3] 41 | self.w = self.x2 - self.x1 42 | self.h = self.y2 - self.y1 43 | self.pos = pos 44 | 45 | def __lt__(self, other): 46 | return self.pos.__lt__(other.pos) 47 | 48 | def __contains__(self, other): 49 | if other.x1 >= self.x1 and other.x2 <= self.x2 and \ 50 | other.y1 >= self.y1 and other.y2 <= self.y2: 51 | return True 52 | return False 53 | 54 | def __str__(self): 55 | return 'Box(%d, %d, %d, %d)' % self.pos 56 | 57 | def __hash__(self): 58 | return self.pos.__hash__() 59 | 60 | 61 | class Chunk(Box): 62 | 63 | def __init__(self, text:str, pos:Tuple, size:float=0.0, cell_id=None): 64 | super(Chunk, self).__init__(pos) 65 | self.text = text 66 | self.size = size 67 | self.cell_id = cell_id 68 | 69 | def __str__(self): 70 | return 'Chunk(text="%s", pos=(%d, %d, %d, %d))' % (self.text, *self.pos) 71 | 72 | def __repr__(self): 73 | return self.__str__() 74 | 75 | def dump_as_json_obj(self): 76 | return {"text":self.text, "pos":self.pos, "cell_id":self.cell_id} 77 | 78 | @classmethod 79 | def load_from_dict(cls, d): 80 | assert type(d) == dict 81 | assert type(d["text"]) == str 82 | assert len(d["pos"]) == 4 83 | cell_id = d["cell_id"] if "cell_id" in d else None 84 | return cls(d["text"].strip(), d["pos"], cell_id=cell_id) 85 | 86 | 87 | class Table(object): 88 | 89 | """ 90 | The output of table segmentation. 91 | With the Table object, we can get the set of cells 92 | and their corresponding text. 93 | """ 94 | def __init__(self, row_n, col_n, cells:Iterable[Chunk]=None, tid=""): 95 | # NOTE the Chunk object here represents the coordinate of 96 | # the cell in the table. 
# NOTE x in cell object represents the row id
98 |         self.tid = tid
99 |         self.row_n = row_n
100 |         self.col_n = col_n
101 |         self.coo2cell_id = [
102 |             [ -1 for _ in range(col_n) ] for _ in range(row_n) ]
103 |         self.cells:List[Chunk] = []
104 |         for cell in cells:
105 |             self.add_cell(cell)
106 | 
107 |     def reverse(self, is_col=True):
108 |         cells = self.cells
109 |         self.cells = []
110 |         cell:Chunk = None
111 |         for cell in cells:
112 |             if is_col:
113 |                 _c = Chunk(cell.text, (
114 |                     self.row_n - 1 - cell.x2, self.row_n - 1 - cell.x1, cell.y1, cell.y2))
115 |             else:
116 |                 _c = Chunk(cell.text, (
117 |                     cell.x1, cell.x2, self.col_n - 1 - cell.y2, self.col_n - 1 - cell.y1))
118 |             self.add_cell(_c)
119 | 
120 |     def add_cell(self, cell:Chunk):
121 |         # TODO Check conflicts of cells
122 |         assert cell.y2 < self.col_n
123 |         assert cell.x2 < self.row_n
124 | 
125 |         for x in range(cell.x1, cell.x2 + 1, 1):
126 |             for y in range(cell.y1, cell.y2 + 1, 1):
127 |                 self.coo2cell_id[x][y] = len(self.cells)
128 |         self.cells.append(cell)
129 | 
130 |     def __getitem__(self, id_tuple):
131 |         row_id, col_id = id_tuple
132 |         assert row_id < self.row_n and col_id < self.col_n
133 |         return self.cells[self.coo2cell_id[row_id][col_id]]

--------------------------------------------------------------------------------
/scitsr/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Graph Attention Model Training
3 | Author: Heng-Da Xu
4 | Date Created: March 21, 2019
5 | Modified by: Heng-Da Xu, Zewen Chi
6 | Date Modified: March 23, 2019
7 | """
8 | import torch
9 | from tqdm import tqdm
10 | 
11 | from scitsr.data.loader import TableDataset, TableInferDataset, Data
12 | from scitsr.model import GraphAttention
13 | from scitsr.table import Chunk
14 | 
15 | 
16 | class Trainer:
17 |     """Trainer"""
18 | 
19 |     def __init__(self, model, train_dataset, test_dataset, infer_dataset,
20 |                  criterion, optimizer, n_epochs, device, weight_clipping):
21 |         self.model = model
22 |         self.train_dataset = train_dataset
23 |         self.test_dataset = test_dataset
24 |         self.infer_dataset = infer_dataset
25 |         self.criterion = criterion
26 |         self.optimizer = optimizer
27 |         self.n_epochs = n_epochs
28 |         self.weight_clipping = weight_clipping
29 |         self.device = device
30 |         self.empty = 2000
31 |         self.epoch_info_dict = {
32 |             'loss': None,
33 |             'acc': None,
34 |             't_acc': None,
35 |             'precision': None,
36 |             'recall': None,
37 |         }
38 | 
39 |     def _reset_epoch_info(self):
40 |         for k in self.epoch_info_dict:
41 |             self.epoch_info_dict[k] = None
42 | 
43 |     def _print_epoch_info(self, epoch, desc, **keywords):
44 |         self.epoch_info_dict.update(keywords)
45 |         print('[Epoch %2d] %s' % (epoch, desc), end='')
46 |         n_none = 0
47 |         for k, v in self.epoch_info_dict.items():
48 |             if self.epoch_info_dict[k] is not None:
49 |                 print(' | %s: %.3f' % (k, v), end='')
50 |             else:
51 |                 n_none += 1
52 |         #print(end='\n' if n_none == 0 else '\r')
53 |         print("")
54 | 
55 |     def train(self):
56 | 
57 |         print('Start training ...')
58 |         for epoch in range(1, self.n_epochs + 1):
59 |             self._reset_epoch_info()
60 | 
61 |             torch.cuda.empty_cache()
62 |             loss = self.train_epoch(epoch, self.train_dataset)
63 |             self._print_epoch_info(epoch, 'train', loss=loss)
64 | 
65 |             torch.cuda.empty_cache()
66 |             #train_acc = self.test_epoch(epoch, self.train_dataset)
67 |             #self._print_epoch_info(epoch, 'train', acc=train_acc)
68 | 
69 |             test_acc = self.test_epoch(epoch, self.test_dataset)
70 |             self._print_epoch_info(epoch, 'test', t_acc=test_acc)
71 | 
72 |         print('Training finished.')
73 |         return self.model
74 | 
75 |     def train_epoch(self, epoch, dataset, should_print=False):
76 |         self.model.train()
77 |         loss_list = []
78 |         for index, data in tqdm(enumerate(dataset)):
79 |             torch.cuda.empty_cache()
80 |             self._to_device(data)
81 |             # if index % 10 == 0:
82 |             percent = index / len(dataset) * 100
83 |             if should_print:
84 |                 print('[Epoch %d] Train | Data %d (%d%%): loss: | path: %s' % \
85 |                     (epoch, index, percent, data.path), ' ' * 20, end='\r')
86 |             # try:
87 |             outputs = self.model(data.nodes, data.edges, data.adj, data.incidence)
88 |             # except Exception as e:
89 |             #     print(e, data.path)
90 |             loss = self.criterion(outputs, data.labels)
91 |             loss_list.append(loss.item())
92 | 
93 |             if should_print:
94 |                 print('[Epoch %d] Train | Data %d (%d%%): loss: %.3f | path: %s' % \
95 |                     (epoch, index, percent, loss.item(), data.path), ' ' * 20, end='\n')
96 | 
97 |             self.optimizer.zero_grad()
98 |             loss.backward()
99 |             if self.weight_clipping is not None:
100 |                 torch.nn.utils.clip_grad_norm_(
101 |                     self.model.parameters(),
102 |                     max_norm=self.weight_clipping
103 |                 )
104 |             self.optimizer.step()
105 |         loss = sum(loss_list) / len(loss_list)
106 |         return loss
107 | 
108 |     def test_epoch(self, epoch, dataset, should_print=False, use_mask=True):
109 |         """
110 |         use_mask: mask the 0 label
111 |         """
112 |         self.model.eval()
113 |         acc_list = []
114 |         for index, data in tqdm(enumerate(dataset)):
115 |             self._to_device(data)
116 |             percent = index / len(dataset) * 100
117 |             if should_print:
118 |                 print('[Epoch %d] Test | Data %d (%d%%): acc: | path: %s' % \
119 |                     (epoch, index, percent, data.path), ' ' * 30, end='\r')
120 |             outputs = self.model(data.nodes, data.edges, data.adj, data.incidence)
121 |             _lab_len = len(data.labels)
122 |             if use_mask:
123 |                 for i in data.labels:
124 |                     if i == 0: _lab_len -= 1
125 |                 _labels = torch.LongTensor(
126 |                     [(-1 if i == 0 else i) for i in data.labels]).to(self.device)
127 |             else: _labels = data.labels
128 |             acc = (outputs.max(dim=1)[1] == _labels).float().sum().item() / _lab_len
129 |             acc_list.append(acc)
130 |             # if index % 10 == 0:
131 |             if should_print:
132 |                 print('[Epoch %d] Test | Data %d (%d%%): acc: %.3f | path: %s' % \
133 |                     (epoch, index, percent, acc, data.path), ' ' * 30, end='\n')
134 |         acc = sum(acc_list) / len(acc_list)
135 |         return acc
136 | 
137 |     def _to_device(self, data):
138 |         data.nodes = data.nodes.to(self.device)
139 |         data.edges = data.edges.to(self.device)
140 |         data.adj = data.adj.to(self.device)
141 |         data.incidence = data.incidence.to(self.device)
142 |         if data.labels is not None:
143 |             data.labels = data.labels.to(self.device)
144 | 
145 | 
146 | def patch_chunks(dataset_folder):
147 |     """
148 |     Patch all chunk files of the train & test datasets, which have a duplicated
149 |     last character in the last cell of every chunk file.
150 |     :param dataset_folder: path to the train or test dataset folder
151 |     :return: 1
152 |     """
153 |     import json
154 |     import os, shutil
155 |     from pathlib import Path
156 | 
157 |     shutil.move(os.path.join(dataset_folder, "chunk"), os.path.join(dataset_folder, "chunk-old"))
158 |     dir_ = Path(os.path.join(dataset_folder, "chunk-old"))
159 |     os.makedirs(os.path.join(dataset_folder, "chunk"), exist_ok=True)
160 | 
161 |     for chunk_path in dir_.iterdir():
162 |         # print(chunk_path)
163 |         with open(str(chunk_path), encoding="utf-8") as f:
164 |             chunks = json.load(f)['chunks']
165 |             chunks[-1]['text'] = chunks[-1]['text'][:-1]
166 | 
167 |         with open(str(chunk_path).replace("chunk-old", "chunk"), "w", encoding="utf-8") as ofile:
168
| json.dump({"chunks": chunks}, ofile) 169 | print("Input files patched, ready for the use") 170 | return 1 171 | 172 | 173 | if __name__ == '__main__': 174 | 175 | train_path = "/path/to/train_folder" 176 | test_path = "/path/to/test_folder/" 177 | patch_chunks(train_path) 178 | patch_chunks(test_path) 179 | 180 | train_dataset = TableDataset( 181 | train_path, with_cells=False, exts=["chunk", "rel"]) 182 | node_norm, edge_norm = train_dataset.node_norm, train_dataset.edge_norm 183 | infer_dataset = test_dataset = TableDataset( 184 | test_path, with_cells=True, node_norm=node_norm, 185 | edge_norm=edge_norm, exts=["chunk", "rel"]) 186 | #device = 'cuda:1' 187 | device = "cpu" 188 | 189 | # Hyper-parameters 190 | n_node_features = train_dataset.n_node_features 191 | n_edge_features = train_dataset.n_edge_features 192 | output_size = train_dataset.output_size 193 | 194 | #hidden_size = 64 195 | hidden_size = 4 196 | n_blocks = 3 197 | n_epochs = 15 198 | learning_rate = 0.0005 199 | weight_clipping = 1 200 | weight_decay = 1e-4 201 | random_seed = 0 202 | model_path = './gat-model.pt' 203 | 204 | # Random seed and device 205 | if random_seed is not None: 206 | torch.manual_seed(random_seed) 207 | if device.startswith('cuda'): 208 | torch.cuda.manual_seed(random_seed) 209 | device = torch.device(device) 210 | 211 | model = GraphAttention( 212 | n_node_features=n_node_features, 213 | n_edge_features=n_edge_features, 214 | hidden_size=hidden_size, 215 | output_size=output_size, 216 | n_blocks=n_blocks, 217 | ) 218 | model.to(device) 219 | criterion = torch.nn.CrossEntropyLoss( 220 | torch.FloatTensor([0.2, 1, 1]).to(device)) 221 | optimizer = torch.optim.Adam( 222 | model.parameters(), 223 | lr=learning_rate, 224 | weight_decay=weight_decay, 225 | ) 226 | trainer = Trainer( 227 | model=model, 228 | train_dataset=train_dataset, 229 | test_dataset=test_dataset, 230 | infer_dataset=infer_dataset, 231 | criterion=criterion, 232 | optimizer=optimizer, 233 | n_epochs=n_epochs, 234 | weight_clipping=weight_clipping, 235 | device=device, 236 | ) 237 | model = trainer.train() 238 | 239 | # Save model 240 | torch.save(model.state_dict(), model_path) 241 | print('The model has been saved to "%s".' % model_path) 242 | 243 | # Save node_norm, edge_norm 244 | node_mean_path = "./tmp_gat_node_mean" 245 | node_std_path = "./tmp_gat_node_std" 246 | edge_mean_path = "./tmp_gat_edge_mean" 247 | edge_std_path = "./tmp_gat_edge_std" 248 | node_mean, node_std = node_norm 249 | edge_mean, edge_std = edge_norm 250 | 251 | torch.save(node_mean, node_mean_path) 252 | torch.save(node_std, node_std_path) 253 | torch.save(edge_mean, edge_mean_path) 254 | torch.save(edge_std, edge_std_path) 255 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="scitsr", 5 | version="0.0.1", 6 | author="Zewen Chi", 7 | author_email="czw@bit.edu.cn", 8 | description="code for paper: complicated table structure recognition", 9 | url="https://github.com/Academic-Hammer/SciTSR", 10 | packages=setuptools.find_packages(), 11 | install_requires=[ 12 | #"ujson", 13 | #"tqdm", 14 | #"torch", 15 | #"torchtext" 16 | ], 17 | classifiers=( 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ) 22 | ) --------------------------------------------------------------------------------