├── .gitignore ├── README.md ├── beam_search.py ├── commons.py ├── configs └── linear_transformer.yaml ├── data_utils.py ├── inference.py ├── losses.py ├── models.py ├── processing_pubtabnet.py ├── resources └── ltiayn.png ├── score.py ├── srcs ├── PubTabNet │ ├── ICDAR_SLR_competition │ │ ├── example.png │ │ └── val_mini.zip │ ├── LICENSE.md │ ├── README.md │ ├── examples │ │ ├── PMC1626454_002_00.png │ │ ├── PMC2753619_002_00.png │ │ ├── PMC2759935_007_01.png │ │ ├── PMC2838834_005_00.png │ │ ├── PMC3519711_003_00.png │ │ ├── PMC3826085_003_00.png │ │ ├── PMC3907710_006_00.png │ │ ├── PMC4003957_018_00.png │ │ ├── PMC4172848_007_00.png │ │ ├── PMC4517499_004_00.png │ │ ├── PMC4682394_003_00.png │ │ ├── PMC4776821_005_00.png │ │ ├── PMC4840965_004_00.png │ │ ├── PMC5134617_013_00.png │ │ ├── PMC5198506_004_00.png │ │ ├── PMC5332562_005_00.png │ │ ├── PMC5402779_004_00.png │ │ ├── PMC5577841_001_00.png │ │ ├── PMC5679144_002_01.png │ │ ├── PMC5897438_004_00.png │ │ ├── PubTabNet_Examples.jsonl │ │ └── utils.py │ ├── exploring_PubTabNet_dataset.ipynb │ └── src │ │ ├── LICENSE │ │ ├── README.md │ │ ├── demo.ipynb │ │ ├── metric.py │ │ ├── parallel.py │ │ ├── requirements.txt │ │ ├── sample_gt.json │ │ └── sample_pred.json └── __init__.py ├── test.ipynb ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | __pycache__ 4 | .ipynb_checkpoints 5 | Untitled.ipynb 6 | *.npy 7 | *.pyc 8 | test 9 | train 10 | val 11 | checkpoint/model* 12 | data/*.txt 13 | *.pptx 14 | demo_data 15 | export 16 | postprocessing.png 17 | hard_data 18 | .vscode 19 | outputs 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Linear Transformer for Table Recognition 2 | 3 | ## Introduction 4 | 5 | This is the code repository for participation in [ICDAR2021 Competition on scientific literature parsing - Task B: Table recognition](https://icdar2021.org/competitions/competition-on-scientific-literature-parsing/) (Team Name: LTIAYN = Kaen Context). 6 | 7 | 8 | 9 |
10 | <img src="resources/ltiayn.png"> 
 11 | 
12 | 13 | 14 | 15 | - Dataset: [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet) 16 | - Metric: [Tree-Edit-Distance-based Similarity(TEDS)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 17 | - Baseline: [Image-based table recognition: data, model, and evaluation](https://arxiv.org/abs/1911.10683) 18 | 19 | 20 | ## 0. Before Training 21 | 22 | 1. change the prefined data directory '/data/private/datasets/pubtabnet' to your own data directory in 'processing_pubtabnet.py', 'configs/linear_transformer.yaml' 23 | 2. `python processing_pubtabnet.py` 24 | 25 | 26 | ## 1. Training 27 | 28 | ``` bash 29 | python train.py model_dir=base 30 | ``` 31 | 32 | 33 | ## 2. After Training 34 | 35 | 1. inference 36 | 37 | ```bash 38 | python inference.py -m "./outputs/base/" -i "/data/private/datasets/pubtabnet/val/" -o "./results/val1" -nt 16 -ni 0 -na 20 39 | python inference.py -m "./outputs/base/" -i "/data/private/datasets/pubtabnet/val/" -o "./results/val1" -nt 16 -ni 1 -na 20 40 | ... 41 | python inference.py -m "./outputs/base/" -i "/data/private/datasets/pubtabnet/val/" -o "./results/val1" -nt 16 -ni 15 -na 20 42 | ``` 43 | 44 | 2. evalution 45 | 46 | ```bash 47 | python score.py 48 | ``` 49 | -------------------------------------------------------------------------------- /beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import tqdm 4 | 5 | # Assuming EOS_ID is 2 6 | EOS_ID = 2 7 | # Default value for INF 8 | INF = 1. * 1e7 9 | 10 | 11 | def _merge_beam_dim(tensor): 12 | """Reshapes first two dimensions in to single dimension. 13 | Args: 14 | tensor: Tensor to reshape of shape [A, B, ...] 15 | Returns: 16 | Reshaped tensor of shape [A*B, ...] 17 | """ 18 | shape = list(tensor.shape) 19 | shape[0] *= shape[1] # batch -> batch * beam_size 20 | shape.pop(1) # Remove beam dim 21 | return tensor.reshape(shape) 22 | 23 | 24 | def _unmerge_beam_dim(tensor, batch_size, beam_size): 25 | """Reshapes first dimension back to [batch_size, beam_size]. 26 | Args: 27 | tensor: Tensor to reshape of shape [batch_size*beam_size, ...] 28 | batch_size: Tensor, original batch size. 29 | beam_size: int, original beam size. 30 | Returns: 31 | Reshaped tensor of shape [batch_size, beam_size, ...] 32 | """ 33 | shape = list(tensor.shape) 34 | new_shape = [batch_size] + [beam_size] + shape[1:] 35 | return tensor.reshape(new_shape) 36 | 37 | 38 | def _expand_to_beam_size(tensor, beam_size): 39 | """Tiles a given tensor by beam_size. 40 | Args: 41 | tensor: tensor to tile [batch_size, ...] 42 | beam_size: How much to tile the tensor by. 43 | Returns: 44 | Tiled tensor [batch_size, beam_size, ...] 45 | """ 46 | tensor = tensor.unsqueeze(1) 47 | tile_dims = [1] * len(tensor.shape) 48 | tile_dims[1] = beam_size 49 | return tensor.repeat(tile_dims) 50 | 51 | 52 | def _gather_coordinates(tensor, coordinates): 53 | batch_size, *_ = tensor.shape 54 | beam_size = coordinates.size(0) // batch_size 55 | tensor_flat = _merge_beam_dim(tensor) 56 | tensor_gather = torch.index_select(tensor_flat, 0, coordinates) 57 | tensor = _unmerge_beam_dim(tensor_gather, batch_size, beam_size) 58 | return tensor 59 | 60 | 61 | def compute_batch_indices(batch_size, beam_size): 62 | """Computes the i'th coordinate that contains the batch index for gathers. 63 | Batch pos is a tensor like [[0,0,0,0],[1,1,1,1],..]. It says which 64 | batch the beam item is in. This will create the i of the i,j coordinate 65 | needed for the gather. 
66 | Args: 67 | batch_size: Batch size 68 | beam_size: Size of the beam. 69 | Returns: 70 | batch_pos: [batch_size, beam_size] tensor of ids 71 | """ 72 | batch_pos = torch.arange(batch_size * beam_size) // beam_size 73 | batch_pos = batch_pos.reshape([batch_size, beam_size]) 74 | return batch_pos 75 | 76 | 77 | def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags, 78 | beam_size, batch_size, 79 | states_to_gather=None): 80 | """Given sequences and scores, will gather the top k=beam size sequences. 81 | This function is used to grow alive, and finished. It takes sequences, 82 | scores, and flags, and returns the top k from sequences, scores_to_gather, 83 | and flags based on the values in scores. 84 | This method permits easy introspection using tfdbg. It adds three named ops 85 | that are prefixed by `prefix`: 86 | - _topk_seq: the tensor for topk_seq returned by this method. 87 | - _topk_flags: the tensor for topk_finished_flags returned by this method. 88 | - _topk_scores: the tensor for tokp_gathered_scores returned by this method. 89 | Args: 90 | sequences: Tensor of sequences that we need to gather from. 91 | [batch_size, beam_size, seq_length] 92 | scores: Tensor of scores for each sequence in sequences. 93 | [batch_size, beam_size]. We will use these to compute the topk. 94 | scores_to_gather: Tensor of scores for each sequence in sequences. 95 | [batch_size, beam_size]. We will return the gathered scores from here. 96 | Scores to gather is different from scores because for grow_alive, we will 97 | need to return log_probs, while for grow_finished, we will need to return 98 | the length penalized scores. 99 | flags: Tensor of bools for sequences that say whether a sequence has reached 100 | EOS or not 101 | beam_size: int 102 | prefix: string that will prefix unique names for the ops run. 103 | states_to_gather: dict (possibly nested) of decoding states. 104 | Returns: 105 | Tuple of 106 | (topk_seq [batch_size, beam_size, decode_length], 107 | topk_gathered_scores [batch_size, beam_size], 108 | topk_finished_flags[batch_size, beam_size]) 109 | """ 110 | # sort scores 111 | _, topk_indexes = torch.topk(scores, k=beam_size) 112 | # The next three steps are to create coordinates for the gather to pull 113 | # out the topk sequences from sequences based on scores. 114 | # batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..]. It says which 115 | # batch the beam item is in. This will create the i of the i,j coordinate 116 | # needed for the gather 117 | batch_pos = compute_batch_indices(batch_size, beam_size).to(device=scores.device) 118 | 119 | # top coordinates will give us the actual coordinates to do the gather. 120 | # top coordinates is a sequence of dimension batch * beam, each of which 121 | # contains the gathering coordinate. 122 | top_coordinates = (batch_pos * scores.size(1) + topk_indexes).view(-1) 123 | 124 | # Gather up the highest scoring sequences. For each operation added, give it 125 | # a concrete name to simplify observing these operations with tfdbg. Clients 126 | # can capture these tensors by watching these node names. 
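# For example, with batch_size=2, beam_size=2 and scores of shape [2, 4]:
# batch_pos = [[0, 0], [1, 1]], and if topk_indexes = [[3, 1], [0, 2]] then
# top_coordinates = (batch_pos * 4 + topk_indexes).view(-1) = [3, 1, 4, 6],
# i.e. flat indices into the tensors once their [batch, beam] leading dims are merged.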
127 | topk_seq = _gather_coordinates(sequences, top_coordinates) 128 | topk_flags = _gather_coordinates(flags, top_coordinates) 129 | topk_gathered_scores = _gather_coordinates(scores_to_gather, top_coordinates) 130 | if states_to_gather: 131 | for state in states_to_gather: 132 | for k, v in state.items(): 133 | state[k] = _gather_coordinates(v, top_coordinates) 134 | topk_gathered_states = states_to_gather 135 | return topk_seq, topk_gathered_scores, topk_flags, topk_gathered_states 136 | 137 | 138 | def beam_search(symbols_to_logits_fn, 139 | initial_ids, 140 | beam_size, 141 | decode_length, 142 | vocab_size, 143 | alpha, 144 | states=None, 145 | eos_id=EOS_ID, 146 | stop_early=True): 147 | """Beam search with length penalties. 148 | Requires a function that can take the currently decoded symbols and return 149 | the logits for the next symbol. The implementation is inspired by 150 | https://arxiv.org/abs/1609.08144. 151 | When running, the beam search steps can be visualized by using tfdbg to watch 152 | the operations generating the output ids for each beam step. These operations 153 | have the pattern: 154 | (alive|finished)_topk_(seq,scores) 155 | Operations marked `alive` represent the new beam sequences that will be 156 | processed in the next step. Operations marked `finished` represent the 157 | completed beam sequences, which may be padded with 0s if no beams finished. 158 | Operations marked `seq` store the full beam sequence for the time step. 159 | Operations marked `scores` store the sequence's final log scores. 160 | The beam search steps will be processed sequentially in order, so when 161 | capturing observed from these operations, tensors, clients can make 162 | assumptions about which step is being recorded. 163 | WARNING: Assumes 2nd dimension of tensors in `states` and not invariant, this 164 | means that the shape of the 2nd dimension of these tensors will not be 165 | available (i.e. set to None) inside symbols_to_logits_fn. 166 | Args: 167 | symbols_to_logits_fn: Interface to the model, to provide logits. 168 | Shoud take [batch_size, decoded_ids] and return [batch_size, vocab_size] 169 | initial_ids: Ids to start off the decoding, this will be the first thing 170 | handed to symbols_to_logits_fn (after expanding to beam size) 171 | [batch_size] 172 | beam_size: Size of the beam. 173 | decode_length: Number of steps to decode for. 174 | vocab_size: Size of the vocab, must equal the size of the logits returned by 175 | symbols_to_logits_fn 176 | alpha: alpha for length penalty. 177 | states: dict (possibly nested) of decoding states. 178 | eos_id: ID for end of sentence. 179 | stop_early: a boolean - stop once best sequence is provably determined. 180 | Returns: 181 | Tuple of 182 | (decoded beams [batch_size, beam_size, decode_length] 183 | decoding probabilities [batch_size, beam_size]) 184 | """ 185 | batch_size = initial_ids.shape[0] 186 | 187 | # Assume initial_ids are prob 1.0 188 | initial_log_probs = torch.Tensor( 189 | [[0.] 
+ [-float("inf")] * (beam_size - 1)] 190 | ).to(device=initial_ids.device) 191 | # Expand to beam_size (batch_size, beam_size) 192 | alive_log_probs = initial_log_probs.repeat([batch_size, 1]) 193 | 194 | # Expand each batch and state to beam_size 195 | alive_seq = _expand_to_beam_size(initial_ids, beam_size) 196 | alive_seq = alive_seq.unsqueeze(2) # (batch_size, beam_size, 1) 197 | if states: 198 | for state in states: 199 | for k, v in state.items(): 200 | state[k] = _expand_to_beam_size(v, beam_size) 201 | else: 202 | states = None 203 | 204 | # Finished will keep track of all the sequences that have finished so far 205 | # Finished log probs will be negative infinity in the beginning 206 | # finished_flags will keep track of booleans 207 | finished_seq = torch.zeros_like(alive_seq) 208 | # Setting the scores of the initial to negative infinity. 209 | finished_scores = torch.ones_like(alive_log_probs) * -INF 210 | finished_flags = torch.zeros_like(alive_log_probs, dtype=torch.bool) 211 | 212 | def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq, 213 | curr_scores, curr_finished): 214 | """Given sequences and scores, will gather the top k=beam size sequences. 215 | Args: 216 | finished_seq: Current finished sequences. 217 | [batch_size, beam_size, current_decoded_length] 218 | finished_scores: scores for each of these sequences. 219 | [batch_size, beam_size] 220 | finished_flags: finished bools for each of these sequences. 221 | [batch_size, beam_size] 222 | curr_seq: current topk sequence that has been grown by one position. 223 | [batch_size, beam_size, current_decoded_length] 224 | curr_scores: scores for each of these sequences. [batch_size, beam_size] 225 | curr_finished: Finished flags for each of these sequences. 226 | [batch_size, beam_size] 227 | Returns: 228 | Tuple of 229 | (Topk sequences based on scores, 230 | log probs of these sequences, 231 | Finished flags of these sequences) 232 | """ 233 | # First append a column of 0'ids to finished to make the same length with 234 | # finished scores 235 | finished_seq = torch.cat( 236 | [finished_seq, 237 | torch.zeros([batch_size, beam_size, 1], 238 | dtype=finished_seq.dtype, device=finished_seq.device) 239 | ], 2) 240 | 241 | # Set the scores of the unfinished seq in curr_seq to large negative 242 | # values 243 | curr_scores = curr_scores + (1. - curr_finished.to(dtype=curr_scores.dtype)) * -INF 244 | # concatenating the sequences and scores along beam axis 245 | curr_finished_seq = torch.cat([finished_seq, curr_seq], 1) 246 | curr_finished_scores = torch.cat([finished_scores, curr_scores], 1) 247 | curr_finished_flags = torch.cat([finished_flags, curr_finished], 1) 248 | return compute_topk_scores_and_seq( 249 | curr_finished_seq, curr_finished_scores, curr_finished_scores, 250 | curr_finished_flags, beam_size, batch_size) 251 | 252 | def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states): 253 | """Given sequences and scores, will gather the top k=beam size sequences. 254 | Args: 255 | curr_seq: current topk sequence that has been grown by one position. 256 | [batch_size, beam_size, i+1] 257 | curr_scores: scores for each of these sequences. [batch_size, beam_size] 258 | curr_log_probs: log probs for each of these sequences. 259 | [batch_size, beam_size] 260 | curr_finished: Finished flags for each of these sequences. 261 | [batch_size, beam_size] 262 | states: dict (possibly nested) of decoding states. 
263 | Returns: 264 | Tuple of 265 | (Topk sequences based on scores, 266 | log probs of these sequences, 267 | Finished flags of these sequences) 268 | """ 269 | # Set the scores of the finished seq in curr_seq to large negative 270 | # values 271 | curr_scores = curr_scores + curr_finished.to(dtype=curr_scores.dtype) * -INF 272 | return compute_topk_scores_and_seq(curr_seq, curr_scores, curr_log_probs, 273 | curr_finished, beam_size, batch_size, 274 | states) 275 | 276 | def grow_topk(i, alive_seq, alive_log_probs, states): 277 | r"""Inner beam search loop. 278 | This function takes the current alive sequences, and grows them to topk 279 | sequences where k = 2*beam. We use 2*beam because, we could have beam_size 280 | number of sequences that might hit and there will be no alive 281 | sequences to continue. With 2*beam_size, this will not happen. This relies 282 | on the assumption the vocab size is > beam size. If this is true, we'll 283 | have at least beam_size non extensions if we extract the next top 284 | 2*beam words. 285 | Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to 286 | https://arxiv.org/abs/1609.08144. 287 | Args: 288 | i: loop index 289 | alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1] 290 | alive_log_probs: probabilities of these sequences. [batch_size, beam_size] 291 | states: dict (possibly nested) of decoding states. 292 | Returns: 293 | Tuple of 294 | (Topk sequences extended by the next word, 295 | The log probs of these sequences, 296 | The scores with length penalty of these sequences, 297 | Flags indicating which of these sequences have finished decoding, 298 | dict of transformed decoding states) 299 | """ 300 | # Get the logits for all the possible next symbols 301 | flat_ids = alive_seq.reshape([batch_size * beam_size, -1]) 302 | 303 | # (batch_size * beam_size, decoded_length) 304 | if states: 305 | for state in states: 306 | for k, v in state.items(): 307 | state[k] = _merge_beam_dim(v) 308 | flat_states = states 309 | flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states) 310 | for state in flat_states: 311 | for k, v in state.items(): 312 | state[k] = _unmerge_beam_dim(v, batch_size, beam_size) 313 | states = flat_states 314 | else: 315 | flat_logits = symbols_to_logits_fn(flat_ids) 316 | 317 | logits = flat_logits.reshape([batch_size, beam_size, -1]) 318 | 319 | # Convert logits to normalized log probs 320 | candidate_log_probs = torch.log_softmax(logits, -1) 321 | 322 | # Multiply the probabilities by the current probabilities of the beam. 323 | # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1) 324 | log_probs = candidate_log_probs + alive_log_probs.unsqueeze(2) 325 | 326 | length_penalty = math.pow(((5. + i + 1) / 6.), alpha) 327 | 328 | curr_scores = log_probs / length_penalty 329 | # Flatten out (beam_size, vocab_size) probs in to a list of possibilities 330 | flat_curr_scores = curr_scores.reshape([-1, beam_size * vocab_size]) 331 | 332 | topk_scores, topk_ids = torch.topk(flat_curr_scores, k=beam_size * 2) 333 | 334 | # Recovering the log probs because we will need to send them back 335 | topk_log_probs = topk_scores * length_penalty 336 | 337 | # Work out what beam the top probs are in. 338 | topk_beam_index = topk_ids // vocab_size 339 | topk_ids %= vocab_size # Unflatten the ids 340 | 341 | # The next three steps are to create coordinates for the gather to pull 342 | # out the correct sequences from id's that we need to grow. 
343 | # We will also use the coordinates to gather the booleans of the beam items 344 | # that survived. 345 | batch_pos = compute_batch_indices(batch_size, beam_size * 2).to(device=alive_seq.device) 346 | 347 | # top coordinates will give us the actual coordinates to do the gather. 348 | # top coordinates is a sequence of dimension batch * beam, each of which 349 | # contains the gathering coordinate. 350 | top_coordinates = (batch_pos * beam_size + topk_beam_index).view(-1) 351 | 352 | # Gather up the most probable 2*beams both for the ids and finished_in_alive 353 | # bools 354 | topk_seq = _gather_coordinates(alive_seq, top_coordinates) 355 | if states: 356 | for state in states: 357 | for k, v in state.items(): 358 | state[k] = _gather_coordinates(v, top_coordinates) 359 | 360 | # Append the most probable alive 361 | topk_seq = torch.cat([topk_seq, topk_ids.unsqueeze(2)], 2) 362 | 363 | topk_finished = torch.eq(topk_ids, eos_id) 364 | 365 | return topk_seq, topk_log_probs, topk_scores, topk_finished, states 366 | 367 | def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores, 368 | finished_flags, states): 369 | """Inner beam search loop. 370 | There are three groups of tensors, alive, finished, and topk. 371 | The alive group contains information about the current alive sequences 372 | The topk group contains information about alive + topk current decoded words 373 | the finished group contains information about finished sentences, that is, 374 | the ones that have decoded to . These are what we return. 375 | The general beam search algorithm is as follows: 376 | While we haven't terminated (pls look at termination condition) 377 | 1. Grow the current alive to get beam*2 topk sequences 378 | 2. Among the topk, keep the top beam_size ones that haven't reached EOS 379 | into alive 380 | 3. Among the topk, keep the top beam_size ones have reached EOS into 381 | finished 382 | Repeat 383 | To make things simple with using fixed size tensors, we will end 384 | up inserting unfinished sequences into finished in the beginning. To stop 385 | that we add -ve INF to the score of the unfinished sequence so that when a 386 | true finished sequence does appear, it will have a higher score than all the 387 | unfinished ones. 388 | Args: 389 | i: loop index 390 | alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1] 391 | alive_log_probs: probabilities of the beams. [batch_size, beam_size] 392 | finished_seq: Current finished sequences. 393 | [batch_size, beam_size, i+1] 394 | finished_scores: scores for each of these sequences. 395 | [batch_size, beam_size] 396 | finished_flags: finished bools for each of these sequences. 397 | [batch_size, beam_size] 398 | states: dict (possibly nested) of decoding states. 399 | Returns: 400 | Tuple of 401 | (Incremented loop index 402 | New alive sequences, 403 | Log probs of the alive sequences, 404 | New finished sequences, 405 | Scores of the new finished sequences, 406 | Flags indicating which sequence in finished as reached EOS, 407 | dict of final decoding states) 408 | """ 409 | 410 | # Each inner loop, we carry out three steps: 411 | # 1. Get the current topk items. 412 | # 2. Extract the ones that have finished and haven't finished 413 | # 3. Recompute the contents of finished based on scores. 
414 | topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk( 415 | i, alive_seq, alive_log_probs, states) 416 | alive_seq, alive_log_probs, _, states = grow_alive( 417 | topk_seq, topk_scores, topk_log_probs, topk_finished, states) 418 | finished_seq, finished_scores, finished_flags, _ = grow_finished( 419 | finished_seq, finished_scores, finished_flags, topk_seq, topk_scores, 420 | topk_finished) 421 | 422 | return (alive_seq, alive_log_probs, finished_seq, finished_scores, 423 | finished_flags, states) 424 | 425 | def _is_finished(alive_log_probs, finished_scores, finished_in_finished): 426 | """Checking termination condition. 427 | We terminate when we decoded up to decode_length or the lowest scoring item 428 | in finished has a greater score that the highest prob item in alive divided 429 | by the max length penalty 430 | Args: 431 | alive_log_probs: probabilities of the beams. [batch_size, beam_size] 432 | finished_scores: scores for each of these sequences. 433 | [batch_size, beam_size] 434 | finished_in_finished: finished bools for each of these sequences. 435 | [batch_size, beam_size] 436 | Returns: 437 | Bool. 438 | """ 439 | max_length_penalty = math.pow(((5. + decode_length) / 6.), alpha) 440 | # The best possible score of the most likely alive sequence. 441 | upper_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty 442 | 443 | # Now to compute the lowest score of a finished sequence in finished 444 | # If the sequence isn't finished, we multiply it's score by 0. since 445 | # scores are all -ve, taking the min will give us the score of the lowest 446 | # finished item. 447 | lowest_score_of_finished_in_finished = torch.min( 448 | finished_scores * finished_in_finished.to(dtype=finished_scores.dtype), 1)[0] 449 | # If none of the sequences have finished, then the min will be 0 and 450 | # we have to replace it by -ve INF if it is. The score of any seq in alive 451 | # will be much higher than -ve INF and the termination condition will not 452 | # be met. 453 | lowest_score_of_finished_in_finished += ( 454 | (1. - torch.any( 455 | finished_in_finished, 456 | 1).to(dtype=lowest_score_of_finished_in_finished.dtype)) * -INF) 457 | 458 | bound_is_met = torch.all( 459 | lowest_score_of_finished_in_finished > upper_bound_alive_scores) 460 | 461 | return bound_is_met 462 | 463 | for i in tqdm.tqdm(range(decode_length)): 464 | (alive_seq, alive_log_probs, finished_seq, finished_scores, 465 | finished_flags, states) = inner_loop(i, alive_seq, alive_log_probs, 466 | finished_seq, finished_scores, finished_flags, states) 467 | if stop_early and _is_finished(alive_log_probs, finished_scores, finished_flags): 468 | break 469 | 470 | # Accounting for corner case: It's possible that no sequence in alive for a 471 | # particular batch item ever reached EOS. In that case, we should just copy 472 | # the contents of alive for that batch item. torch.any(finished_flags, 1) 473 | # if 0, means that no sequence for that batch index had reached EOS. We need 474 | # to do the same for the scores as well. 
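# torch.any(finished_flags, 1) has shape [batch_size]; the view below expands it
# to [batch_size, 1, 1] so the condition broadcasts over the beam and sequence
# dimensions, copying the alive beams for any batch item that never produced EOS.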
475 | finished_seq = torch.where( 476 | torch.any(finished_flags, 1).view(batch_size, *([1] * (finished_seq.dim() - 1))), 477 | finished_seq, 478 | alive_seq) 479 | finished_scores = torch.where( 480 | torch.any(finished_flags, 1).view(batch_size, *([1] * (finished_seq.dim() - 1))), 481 | finished_scores, 482 | alive_log_probs) 483 | return finished_seq, finished_scores -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def grad_norm(parameters, norm_type=2): 9 | if isinstance(parameters, torch.Tensor): 10 | parameters = [parameters] 11 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 12 | norm_type = float(norm_type) 13 | total_norm = 0 14 | for p in parameters: 15 | param_norm = p.grad.data.norm(norm_type) 16 | total_norm += param_norm.item() ** norm_type 17 | total_norm = total_norm ** (1. / norm_type) 18 | return total_norm 19 | -------------------------------------------------------------------------------- /configs/linear_transformer.yaml: -------------------------------------------------------------------------------- 1 | model_dir: .hydra 2 | hydra: 3 | run: 4 | dir: outputs/${model_dir} 5 | train: 6 | seed: 123 7 | fp16_run: false 8 | log_interval: 1000 9 | eval_interval: 1000 10 | num_tokens: 32768 # 16384 11 | epochs: 10000 12 | learning_rate: 1e-4 13 | betas: [0.9, 0.98] 14 | eps: 1e-9 15 | lamb: 0 16 | 17 | data: 18 | vocab_path: /data/private/datasets/pubtabnet/annotations/vocab.txt 19 | training_file_path: /data/private/datasets/pubtabnet/annotations/train.json 20 | validation_file_path: /data/private/datasets/pubtabnet/annotations/val.json 21 | patch_length: 8 22 | 23 | model: 24 | hidden_channels: 512 # 1024 25 | filter_channels: 2048 # 4096 26 | n_heads: 8 # 16 27 | n_layers: 12 28 | p_dropout: 0.1 # 0.3 29 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import json 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | 11 | 12 | class ImageTextLoader(torch.utils.data.Dataset): 13 | """ 14 | Load image, text pairs 15 | """ 16 | def __init__(self, file_path, cfg): 17 | self.cfg = cfg 18 | self.image_paths, self.texts, self.lengths, \ 19 | self.image_heights, self.image_widths, self.text_lengths = self._build(file_path) 20 | self.vocab = self._load_vocab() 21 | 22 | def _build(self, file_path): 23 | with open(file_path, "r") as f: 24 | data = json.load(f) 25 | 26 | image_paths = [] 27 | texts = [] 28 | lengths = [] 29 | image_heights = [] 30 | image_widths = [] 31 | text_lengths = [] 32 | for elm in data: 33 | image_paths.append(elm['image_path']) 34 | texts.append(elm['text']) 35 | w, h = [math.ceil(x / self.cfg.patch_length) for x in elm['image_size']] 36 | t = elm['num_tokens'] 37 | 38 | lengths.append(h * w + t) 39 | image_heights.append(h) 40 | image_widths.append(w) 41 | text_lengths.append(t) 42 | return image_paths, texts, lengths, image_heights, image_widths, text_lengths 43 | 44 | def _load_vocab(self): 45 | with open(self.cfg.vocab_path) as f: 46 | words = [x.replace('\n', '') for x in f.readlines()] 47 | vocab = {word: idx for idx, word in 
enumerate(words)} 48 | return vocab 49 | 50 | def get_items(self, index): 51 | patch_length = self.cfg.patch_length 52 | h, w = self.image_heights[index], self.image_widths[index] 53 | c = 3 54 | 55 | image = Image.open(self.image_paths[index]).convert('RGB') 56 | image = (np.asarray(image, dtype=np.float32) / 255) * 2 - 1 57 | image = torch.from_numpy(image) 58 | image = torch.nn.functional.pad(image, [ 59 | 0, 0, 60 | 0, (patch_length - (image.shape[1] % patch_length)) % patch_length, 61 | 0, (patch_length - (image.shape[0] % patch_length)) % patch_length 62 | ]) 63 | image = image.view([h, patch_length, w, patch_length, c]) 64 | image = image.permute(0, 2, 4, 1, 3) 65 | image = image.reshape(h * w, c * (patch_length ** 2)) 66 | 67 | text = torch.LongTensor([self.vocab[w] for w in self.texts[index]]) 68 | length = self.lengths[index] 69 | image_height = h 70 | image_width = w 71 | text_length = self.text_lengths[index] 72 | return (image, text, length, image_height, image_width, text_length) 73 | 74 | def __getitem__(self, index): 75 | return self.get_items(index) 76 | 77 | def __len__(self): 78 | return len(self.image_paths) 79 | 80 | 81 | class ImageTextCollate(): 82 | """ Zero-pads model inputs 83 | """ 84 | def __call__(self, batch): 85 | """Collate's training batch from image and text info 86 | Inputs: 87 | - batch: [img, txt, t_tot, h_img, w_img, t_txt] 88 | 89 | Outputs: 90 | - (img_padded, txt_padded, mask_img, mask_txt, pos_r, pos_c, pos_t) 91 | """ 92 | max_len = max(x[2] for x in batch) 93 | b = len(batch) 94 | c = batch[0][0].size(1) # image patch size 95 | 96 | img_padded = torch.FloatTensor(b, max_len, c) 97 | txt_padded = torch.LongTensor(b, max_len) 98 | mask_img = torch.FloatTensor(b, max_len, 1) 99 | mask_txt = torch.FloatTensor(b, max_len, 1) 100 | pos_r = torch.FloatTensor(b, max_len-1, 1) # for teacher forcing 101 | pos_c = torch.FloatTensor(b, max_len-1, 1) # for teacher forcing 102 | pos_t = torch.FloatTensor(b, max_len-1, 1) # for teacher forcing 103 | 104 | img_padded.zero_() 105 | txt_padded.zero_() 106 | mask_img.zero_() 107 | mask_txt.zero_() 108 | pos_r.zero_() 109 | pos_c.zero_() 110 | pos_t.zero_() 111 | for i in range(b): 112 | img, txt, t_tot, h_img, w_img, t_txt = batch[i] 113 | t_img = img.size(0) 114 | 115 | img_padded[i, :t_img] = img 116 | txt_padded[i, t_img:t_tot] = txt 117 | mask_img[i, :t_img] = 1 118 | mask_txt[i, t_img:t_tot] = 1 119 | pos_r[i, :t_img] = torch.arange(h_img).unsqueeze(-1).repeat(1, w_img).view(-1, 1) 120 | pos_c[i, :t_img] = torch.arange(w_img).repeat(h_img).view(-1, 1) 121 | pos_t[i, t_img:t_tot-1] = torch.arange(t_txt-1, dtype=torch.float).view(-1, 1) 122 | return img_padded, txt_padded, mask_img, mask_txt, pos_r, pos_c, pos_t 123 | 124 | 125 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 126 | """ 127 | Maintain similar total token sizes in a batch. 128 | 1) choose the minimum highly composite number among which is larger than given num_tokens. 129 | 2) automatically set bucket boundaries and batch_sizes s.t. boundary * batch_size = the highly composite number. 
130 | 3) merge buckets that contain smaller number of elements than batch_sizes 131 | """ 132 | def __init__(self, dataset, num_tokens=2**16, num_replicas=None, rank=None, shuffle=True): 133 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 134 | highly_composite_numbers = [ 135 | 1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680, 136 | 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440, 137 | 83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280, 138 | 720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480, 139 | 7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400, 140 | 36756720, 43243200, 61261200, 73513440, 110270160 141 | ] 142 | 143 | self.lengths = dataset.lengths 144 | self.num_tokens = min([i for i in highly_composite_numbers if i >= num_tokens]) 145 | print("%s: num_tokens is changed from %d to %d." % (self.__class__.__name__, num_tokens, self.num_tokens)) 146 | 147 | self.buckets, self.num_samples_per_bucket, self.batch_sizes = self._create_buckets() 148 | self.total_size = sum(self.num_samples_per_bucket) 149 | self.num_samples = self.total_size // self.num_replicas 150 | self.num_batches = sum([self.num_samples_per_bucket[i] // (self.batch_sizes[i] * self.num_replicas) for i in range(len(self.batch_sizes))]) 151 | 152 | def _create_buckets(self): 153 | boundaries, batch_sizes = [], [] 154 | for i in range(1, self.num_tokens + 1): 155 | q, r = divmod(self.num_tokens, i) 156 | if r == 0: 157 | boundaries.append(i) 158 | if i != 1: 159 | batch_sizes.append(q) 160 | buckets = [[] for _ in range(len(boundaries) - 1)] 161 | for i in range(len(self.lengths)): 162 | length = self.lengths[i] 163 | idx_bucket = self._bisect(length, boundaries) 164 | if idx_bucket != -1: 165 | buckets[idx_bucket].append(i) 166 | for i in range(len(buckets) - 1, 0, -1): 167 | if len(buckets[i]) == 0: 168 | buckets.pop(i) 169 | batch_sizes.pop(i) 170 | 171 | buckets_new = [] 172 | batch_sizes_new = [] 173 | bucket = [] 174 | for i in range(len(buckets) - 1): 175 | bucket += buckets[i] 176 | if len(bucket) >= batch_sizes[i] * self.num_replicas: 177 | buckets_new.append(bucket) 178 | bucket = [] 179 | batch_sizes_new.append(batch_sizes[i]) 180 | buckets_new.append(bucket + buckets[-1]) 181 | batch_sizes_new.append(batch_sizes[-1]) 182 | buckets = buckets_new 183 | batch_sizes = batch_sizes_new 184 | 185 | num_samples_per_bucket = [] 186 | for i in range(len(buckets)): 187 | len_bucket = len(buckets[i]) 188 | total_batch_size = self.num_replicas * batch_sizes[i] 189 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size 190 | num_samples_per_bucket.append(len_bucket + rem) 191 | return buckets, num_samples_per_bucket, batch_sizes 192 | 193 | def __iter__(self): 194 | # deterministically shuffle based on epoch 195 | g = torch.Generator() 196 | g.manual_seed(self.epoch) 197 | 198 | indices = [] 199 | if self.shuffle: 200 | for bucket in self.buckets: 201 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 202 | else: 203 | for bucket in self.buckets: 204 | indices.append(list(range(len(bucket)))) 205 | 206 | batches = [] 207 | for i in range(len(self.buckets)): 208 | batch_size = self.batch_sizes[i] 209 | bucket = self.buckets[i] 210 | len_bucket = len(bucket) 211 | ids_bucket = indices[i] 212 | num_samples_bucket = self.num_samples_per_bucket[i] 213 | 214 | # add extra samples to make it evenly divisible 215 | rem = num_samples_bucket - len_bucket 
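# For example, if len_bucket = 5 and rem = 7, the ids are repeated once more
# (7 // 5 == 1) and the first two ids (7 % 5 == 2) are appended, giving
# 5 + 5 + 2 = 12 ids so the bucket splits evenly across replicas and batch size.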
216 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] 217 | 218 | # subsample 219 | ids_bucket = ids_bucket[self.rank::self.num_replicas] 220 | 221 | # batching 222 | for j in range(len(ids_bucket) // batch_size): 223 | batch = [bucket[idx] for idx in ids_bucket[j*batch_size:(j+1)*batch_size]] 224 | batches.append(batch) 225 | 226 | if self.shuffle: 227 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 228 | batches = [batches[i] for i in batch_ids] 229 | self.batches = batches 230 | 231 | assert sum([len(x) for x in self.batches]) == self.num_samples 232 | assert len(self.batches) == self.num_batches 233 | return iter(self.batches) 234 | 235 | def _bisect(self, x, boundaries, lo=0, hi=None): 236 | if hi is None: 237 | hi = len(boundaries) - 1 238 | 239 | if hi > lo: 240 | mid = (hi + lo) // 2 241 | if boundaries[mid] < x and x <= boundaries[mid+1]: 242 | return mid 243 | elif x <= boundaries[mid]: 244 | return self._bisect(x, boundaries, lo, mid) 245 | else: 246 | return self._bisect(x, boundaries, mid + 1, hi) 247 | else: 248 | return -1 249 | 250 | def __len__(self): 251 | return self.num_batches 252 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import glob 4 | import json 5 | import yaml 6 | import argparse 7 | from pathlib import Path 8 | from PIL import Image 9 | import numpy as np 10 | import torch 11 | from torch import nn, optim 12 | from torch.nn import functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | import utils 16 | import commons 17 | from models import TableRecognizer, get_positional_encoding 18 | import beam_search 19 | 20 | class HParams(dict): 21 | def __getattr__(self, name): 22 | value = self[name] 23 | if isinstance(value, dict): 24 | value = HParams(value) 25 | return value 26 | 27 | class ImageLoader(torch.utils.data.Dataset): 28 | """ 29 | Load image 30 | """ 31 | def __init__(self, dir_path, cfg, n_tot, n_idx): 32 | self.cfg = cfg 33 | self.n_tot = n_tot 34 | self.n_idx = n_idx 35 | self.image_paths, self.lengths, \ 36 | self.image_heights, self.image_widths = self._build(dir_path) 37 | self.vocab = self._load_vocab() 38 | 39 | def _build(self, dir_path): 40 | image_paths = glob.glob(os.path.join(dir_path, "*.png")) 41 | 42 | tuple_list = [] 43 | for image_path in image_paths: 44 | image = Image.open(image_path) 45 | w, h = [math.ceil(x / self.cfg.patch_length) for x in image.size] 46 | tuple_list.append((image_path, h*w, h, w)) 47 | tuple_list.sort(key=lambda x: x[1], reverse=True) 48 | 49 | image_paths = [] 50 | lengths = [] 51 | image_heights = [] 52 | image_widths = [] 53 | for image_path, length, h, w in tuple_list[self.n_idx::self.n_tot]: 54 | image_paths.append(image_path) 55 | lengths.append(length) 56 | image_heights.append(h) 57 | image_widths.append(w) 58 | return image_paths, lengths, image_heights, image_widths 59 | 60 | def _load_vocab(self): 61 | with open(self.cfg.vocab_path) as f: 62 | words = [x.replace('\n', '') for x in f.readlines()] 63 | vocab = {word: idx for idx, word in enumerate(words)} 64 | return vocab 65 | 66 | def get_items(self, index): 67 | patch_length = self.cfg.patch_length 68 | h, w = self.image_heights[index], self.image_widths[index] 69 | c = 3 70 | 71 | image = Image.open(self.image_paths[index]).convert('RGB') 72 | image = (np.asarray(image, dtype=np.float32) / 255) * 2 - 1 73 | 
image = torch.from_numpy(image) 74 | image = torch.nn.functional.pad(image, [ 75 | 0, 0, 76 | 0, (patch_length - (image.shape[1] % patch_length)) % patch_length, 77 | 0, (patch_length - (image.shape[0] % patch_length)) % patch_length 78 | ]) 79 | image = image.view([h, patch_length, w, patch_length, c]) 80 | image = image.permute(0, 2, 4, 1, 3) 81 | image = image.reshape(h * w, c * (patch_length ** 2)) 82 | 83 | length = self.lengths[index] 84 | image_height = h 85 | image_width = w 86 | return (image, length, image_height, image_width) 87 | 88 | def __getitem__(self, index): 89 | return self.get_items(index) 90 | 91 | def __len__(self): 92 | return len(self.image_paths) 93 | 94 | 95 | class ImageCollate(): 96 | """ Zero-pads model inputs 97 | """ 98 | def __call__(self, batch): 99 | """Collate's training batch from image and text info 100 | Inputs: 101 | - batch: [img, t_tot, h_img, w_img] 102 | 103 | Outputs: 104 | - (img_padded, mask_img, pos_r, pos_c) 105 | """ 106 | max_len = max(x[1] for x in batch) 107 | b = len(batch) 108 | c = batch[0][0].size(1) # image patch size 109 | 110 | img_padded = torch.FloatTensor(b, max_len, c) 111 | mask_img = torch.FloatTensor(b, max_len, 1) 112 | pos_r = torch.FloatTensor(b, max_len, 1) 113 | pos_c = torch.FloatTensor(b, max_len, 1) 114 | 115 | img_padded.zero_() 116 | mask_img.zero_() 117 | pos_r.zero_() 118 | pos_c.zero_() 119 | for i in range(b): 120 | img, t_tot, h_img, w_img = batch[i] 121 | 122 | img_padded[i, :t_tot] = img 123 | mask_img[i, :t_tot] = 1 124 | pos_r[i, :t_tot] = torch.arange(h_img).unsqueeze(-1).repeat(1, w_img).view(-1, 1) 125 | pos_c[i, :t_tot] = torch.arange(w_img).repeat(h_img).view(-1, 1) 126 | return img_padded, mask_img, pos_r, pos_c 127 | 128 | def load_checkpoints(dir_path, model, regex="model_*.pth", n=1): 129 | f_list = glob.glob(os.path.join(dir_path, regex)) 130 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 131 | 132 | for i, fname in enumerate(f_list): 133 | idx_last = i 134 | if fname.find("_919000.pth") != -1: 135 | break 136 | f_list = f_list[:idx_last+1] 137 | 138 | f_list = f_list[-n:] 139 | 140 | saved_state_dict = {} 141 | for i, checkpoint_path in enumerate(f_list): 142 | print(checkpoint_path) 143 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 144 | for k, v in checkpoint_dict['model'].items(): 145 | if i == 0: 146 | saved_state_dict[k] = v / len(f_list) 147 | else: 148 | saved_state_dict[k] += v / len(f_list) 149 | 150 | if hasattr(model, 'module'): 151 | state_dict = model.module.state_dict() 152 | else: 153 | state_dict = model.state_dict() 154 | new_state_dict= {} 155 | for k, v in state_dict.items(): 156 | new_state_dict[k] = saved_state_dict[k] 157 | if hasattr(model, 'module'): 158 | model.module.load_state_dict(new_state_dict) 159 | else: 160 | model.load_state_dict(new_state_dict) 161 | 162 | def inference(self, x_img, mask_img, pos_r, pos_c, idx_start=1, idx_end=2, max_decode_len=10000, beam_size=1, top_beams=1, alpha=1., n_toks=5000): 163 | from tqdm import tqdm 164 | with torch.no_grad(): 165 | b = x_img.size(0) 166 | nh = self.n_heads 167 | d = self.hidden_channels // self.n_heads 168 | dtype = x_img.dtype 169 | device = x_img.device 170 | 171 | x_emb_img = self.emb_img(x_img, mask_img, pos_r, pos_c) 172 | cache = [{ 173 | "kv": [], 174 | "k_cum": [] 175 | } for _ in range(self.n_layers) 176 | ] 177 | n_split = max(n_toks // x_img.size(1), 1) 178 | n_iter = math.ceil(b / n_split) 179 | for i in range(n_iter): 180 | print("%05d" % i, end='\r') 181 | 
x_emb_img_iter = x_emb_img[i*n_split:(i+1)*n_split] 182 | mask_img_iter = mask_img[i*n_split:(i+1)*n_split] 183 | b_iter = x_emb_img_iter.size(0) 184 | 185 | cache_each = [{ 186 | "kv": torch.zeros(b_iter, 1, nh, d, d).to(dtype=torch.float, device=device), 187 | "k_cum": torch.zeros(b_iter, 1, nh, d).to(dtype=torch.float, device=device) 188 | } for _ in range(self.n_layers) 189 | ] 190 | _ = self.enc(x_emb_img_iter, mask_img_iter, cache_each) 191 | for l in range(self.n_layers): 192 | cache[l]["kv"].append(cache_each[l]["kv"].clone()) 193 | cache[l]["k_cum"].append(cache_each[l]["k_cum"].clone()) 194 | for l in range(self.n_layers): 195 | cache[l]["kv"] = torch.cat(cache[l]["kv"], 0) 196 | cache[l]["k_cum"] = torch.cat(cache[l]["k_cum"], 0) 197 | 198 | pos_enc = get_positional_encoding( 199 | torch.arange(max_decode_len).view(1,-1,1).to(device=device), 200 | self.hidden_channels 201 | ) 202 | 203 | if beam_size == 1: 204 | finished = torch.BoolTensor(b,1).to(device=device).fill_(False) 205 | idx = torch.zeros(b,1).long().to(device=device) + idx_start 206 | ids = [] 207 | for i in tqdm(range(max_decode_len)): 208 | x_emb_txt = self.emb_txt.emb(idx) + pos_enc[:,i:i+1] 209 | x = self.enc(x_emb_txt, None, cache) 210 | logit_txt = self.proj_txt(x) 211 | idx = torch.argmax(logit_txt, -1) 212 | ids.append(idx) 213 | finished |= torch.eq(idx, idx_end) 214 | if torch.all(finished): 215 | break 216 | return ids 217 | else: 218 | def symbols_to_logits_fn(ids, i, cache): 219 | x_emb_txt = self.emb_txt.emb(ids[:,i:i+1]) + pos_enc[:,i:i+1] 220 | x = self.enc(x_emb_txt, None, cache) 221 | logit_txt = self.proj_txt(x) 222 | return logit_txt, cache 223 | initial_ids = torch.zeros(b).long().to(device=device) + idx_start 224 | decoded_ids, scores = beam_search.beam_search( 225 | symbols_to_logits_fn, 226 | initial_ids, 227 | beam_size, 228 | max_decode_len, 229 | self.n_vocab, 230 | alpha, 231 | states=cache, 232 | eos_id=idx_end, 233 | stop_early=(top_beams == 1)) 234 | 235 | if top_beams == 1: 236 | decoded_ids = decoded_ids[:, 0, 1:] 237 | scores = scores[:, 0] 238 | else: 239 | decoded_ids = decoded_ids[:, :top_beams, 1:] 240 | scores = scores[:, :top_beams] 241 | return decoded_ids, scores 242 | 243 | 244 | if __name__ == "__main__": 245 | # python inference.py -m "./outputs/base/" -i "/data/private/datasets/pubtabnet/val/" -o "./results/val1" -nt 16 -ni 0 -na 20 246 | # ... 
247 | # python inference.py -m "./outputs/base/" -i "/data/private/datasets/pubtabnet/val/" -o "./results/val1" -nt 16 -ni 15 -na 20 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument("--model_dir", "-m", type=str, help="model directory") 250 | parser.add_argument("--image_dir", "-i", type=str, help="image directory") 251 | parser.add_argument("--out_dir", "-o", type=str, help="output directory") 252 | parser.add_argument("--n_tot", "-nt", type=int, help="total number of processes") 253 | parser.add_argument("--n_idx", "-ni", type=int, help="index of current process") 254 | parser.add_argument("--n_avg", "-na", type=int, default=20, help="number of checkpoints to be averaged") 255 | args = parser.parse_args() 256 | 257 | 258 | with open(os.path.join(args.model_dir, ".hydra/config.yaml"), "r") as f: 259 | hps = HParams(yaml.full_load(f)) 260 | 261 | dataset = ImageLoader(args.image_dir, hps.data, args.n_tot, args.n_idx) 262 | collate_fn = ImageCollate() 263 | loader = DataLoader(dataset, num_workers=8, shuffle=False, pin_memory=False, 264 | collate_fn=collate_fn, batch_size=2**6) 265 | vocab_inv = {v: k for k, v in dataset.vocab.items()} 266 | 267 | model = TableRecognizer( 268 | len(dataset.vocab), 269 | 3 * (hps.data.patch_length ** 2), 270 | **hps.model).cuda().eval() 271 | 272 | load_checkpoints(args.model_dir, model, "model_*.pth", args.n_avg) 273 | 274 | prefix = '' 275 | postfix = '
' 276 | html_strings = [] 277 | with torch.no_grad(): 278 | for i, elms in enumerate(loader): 279 | print(i) 280 | (img, mask_img, pos_r, pos_c) = elms 281 | img = img.cuda() 282 | mask_img = mask_img.cuda() 283 | pos_r = pos_r.cuda() 284 | pos_c = pos_c.cuda() 285 | 286 | ret, _ = inference(model, img, mask_img, pos_r, pos_c, beam_size=32, alpha=0.6, max_decode_len=min(10000, math.ceil(4.5 * img.shape[1]))) 287 | ret = ret.cpu().numpy() 288 | for j, r in enumerate(ret): 289 | try: 290 | eos_pos = list(r).index(2) 291 | r = r[:eos_pos] 292 | except: 293 | pass 294 | html_string = prefix + "".join([vocab_inv[x] for x in r]) + postfix 295 | html_strings.append(html_string) 296 | 297 | image_names = [x.split("/")[-1] for x in dataset.image_paths] 298 | 299 | Path(args.out_dir).mkdir(parents=True, exist_ok=True) 300 | with open(os.path.join(args.out_dir, "out_%d_of_%d.json" % (args.n_idx, args.n_tot)), 'w', encoding='utf-8') as f: 301 | json.dump({img: txt for img, txt in zip(image_names, html_strings)}, f) 302 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | def loss_fn_img(src, tgt, mask): 6 | src = src.float() 7 | tgt = tgt.float() 8 | c = src.size(-1) 9 | 10 | loss = (src - tgt) ** 2 11 | return (loss * mask).sum() / (c * mask.sum()) 12 | 13 | 14 | def loss_fn_txt(src, tgt, mask): 15 | src = src.transpose(1,2).float() 16 | loss = F.cross_entropy(src, tgt, reduction='none') 17 | return (loss * mask).sum() / (mask.sum()) 18 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from fast_transformers.attention import CausalLinearAttention 7 | from fast_transformers.attention.causal_linear_attention import causal_linear 8 | from fast_transformers.masking import LengthMask, TriangularCausalMask 9 | from fast_transformers.feature_maps import elu_feature_map 10 | 11 | 12 | def get_positional_encoding(position, channels, min_timescale=1.0, max_timescale=1.0e4): 13 | num_timescales = channels // 2 14 | log_timescale_increment = ( 15 | math.log(float(max_timescale) / float(min_timescale)) / 16 | (num_timescales - 1)) 17 | inv_timescales = min_timescale * torch.exp( 18 | torch.arange(num_timescales, dtype=position.dtype, device=position.device) * -log_timescale_increment) 19 | scaled_time = position * inv_timescales.view(1, 1, -1) 20 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], -1) 21 | signal = F.pad(signal, [0, channels % 2]) 22 | return signal 23 | 24 | 25 | class ImageEmbedding(nn.Module): 26 | def __init__(self, in_channels, out_channels): 27 | super().__init__() 28 | self.in_channels = in_channels 29 | self.out_channels = out_channels 30 | self.emb = nn.Linear(in_channels, out_channels) 31 | 32 | def forward(self, x, x_mask, pos_r, pos_c): 33 | half_channels = self.out_channels // 2 34 | x = self.emb(x) 35 | x_pos_r = get_positional_encoding(pos_r, half_channels) 36 | x_pos_c = get_positional_encoding(pos_c, self.out_channels - half_channels) 37 | x_pos = torch.cat([x_pos_r, x_pos_c], -1) 38 | x_emb = (x + x_pos) * x_mask 39 | return x_emb 40 | 41 | 42 | class TextEmbedding(nn.Module): 43 | def __init__(self, n_vocab, out_channels): 44 | 
super().__init__() 45 | self.n_vocab = n_vocab 46 | self.out_channels = out_channels 47 | self.emb = nn.Embedding(n_vocab, out_channels) 48 | 49 | def forward(self, x, x_mask, pos_t): 50 | x = self.emb(x) 51 | x_pos_t = get_positional_encoding(pos_t, self.out_channels) 52 | x_emb = (x + x_pos_t) * x_mask 53 | return x_emb 54 | 55 | 56 | class CausalLinearAttentionAMP(CausalLinearAttention): 57 | 58 | def forward(self, queries, keys, values, query_mask=None, 59 | key_mask=None, cache=None): 60 | self.feature_map.new_feature_map() 61 | Q = self.feature_map.forward_queries(queries) 62 | K = self.feature_map.forward_keys(keys) 63 | 64 | if cache is not None: 65 | with torch.cuda.amp.autocast(enabled=False): 66 | Q = Q.float() # [b, t, nh ,d] 67 | K = K.float() # [b, t, nh, d] 68 | if key_mask is not None: 69 | K = K * key_mask[:, :, :, None] 70 | 71 | values = values.float() # [b, t, nh, d] 72 | Q_p = Q.unsqueeze(-2) # [b, t, nh, 1, d] 73 | K_p = K.unsqueeze(-1) # [b, t, nh, d, 1] 74 | values_p = values.unsqueeze(-2) # [b, t, nh, 1, d] 75 | kv_cum = cache['kv'] + (K_p * values_p).cumsum(1) # [b, t, nh, d, d] 76 | K_cum = cache['k_cum'] + K.cumsum(1) # [b, t, nh, d] 77 | cache['kv'] = kv_cum[:,-1:] # [b, 1, nh, d, d] 78 | cache['k_cum'] = K_cum[:,-1:] # [b, 1, nh, d] 79 | 80 | V = (Q_p @ kv_cum).squeeze(-2) # [b, t, nh, d] 81 | Z = 1/(torch.sum(Q * K_cum, -1) + self.eps) # [b, t, nh, d], [b, t, nh, d] 82 | out = V * Z[:, :, :, None] # [b, t, nh, d], [b, t, nh, 1] 83 | out = out.to(dtype=queries.dtype) 84 | return out 85 | else: 86 | K = K * key_mask[:, :, :, None] 87 | Q, K = self._make_sizes_compatible(Q, K) 88 | 89 | with torch.cuda.amp.autocast(enabled=False): 90 | Q = Q.float() 91 | K = K.float() 92 | values = values.float() 93 | 94 | # Compute the normalizers 95 | Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps) 96 | if getattr(self, "save_attn", False): 97 | self.attn_map = (Q.permute(0,2,1,3).contiguous() @ K.permute(0,2,3,1).contiguous()).tril()*Z.permute(0,2,1).unsqueeze(-1) 98 | 99 | # Compute the unnormalized result 100 | V = causal_linear( 101 | Q, 102 | K, 103 | values 104 | ) 105 | out = V * Z[:, :, :, None] 106 | out = out.to(dtype=queries.dtype) 107 | return out 108 | 109 | def set_save_attn(self, v): 110 | self.save_attn = v 111 | 112 | 113 | class AttentionLayer(nn.Module): 114 | """Implement the attention layer. Namely project the inputs to multi-head 115 | queries, keys and values, call the attention implementation and then 116 | reproject the output. 117 | 118 | It can be thought of as a decorator (see decorator design patter) of an 119 | attention layer. 120 | 121 | Arguments 122 | --------- 123 | attention: Specific inner attention implementation that just computes a 124 | weighted average of values given a similarity of queries and 125 | keys. 
126 | d_model: The input feature dimensionality 127 | n_heads: The number of heads for the multi head attention 128 | d_keys: The dimensionality of the keys/queries 129 | (default: d_model/n_heads) 130 | d_values: The dimensionality of the values (default: d_model/n_heads) 131 | """ 132 | def __init__(self, attention, d_model, n_heads, d_keys=None, 133 | d_values=None): 134 | super().__init__() 135 | 136 | # Fill d_keys and d_values 137 | d_keys = d_keys or (d_model//n_heads) 138 | d_values = d_values or (d_model//n_heads) 139 | 140 | self.inner_attention = attention 141 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 142 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 143 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 144 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 145 | self.n_heads = n_heads 146 | 147 | def forward(self, queries, keys, values, query_mask=None, 148 | key_mask=None, cache=None): 149 | """Apply attention to the passed in queries/keys/values after 150 | projecting them to multiple heads. 151 | 152 | In the argument description we make use of the following sizes 153 | 154 | - N: the batch size 155 | - L: The maximum length of the queries 156 | - S: The maximum length of the keys (the actual length per sequence 157 | is given by the length mask) 158 | - D: The input feature dimensionality passed in the constructor as 159 | 'd_model' 160 | 161 | Arguments 162 | --------- 163 | queries: (N, L, D) The tensor containing the queries 164 | keys: (N, S, D) The tensor containing the keys 165 | values: (N, S, D) The tensor containing the values 166 | 167 | Returns 168 | ------- 169 | The new value for each query as a tensor of shape (N, L, D). 170 | """ 171 | # Extract the dimensions into local variables 172 | N, L, _ = queries.shape 173 | _, S, _ = keys.shape 174 | H = self.n_heads 175 | 176 | # Project the queries/keys/values 177 | queries = self.query_projection(queries).view(N, L, H, -1) 178 | keys = self.key_projection(keys).view(N, S, H, -1) 179 | values = self.value_projection(values).view(N, S, H, -1) 180 | 181 | # Compute the attention 182 | new_values = self.inner_attention( 183 | queries, 184 | keys, 185 | values, 186 | query_mask, 187 | key_mask, 188 | cache=cache 189 | ).view(N, L, -1) 190 | 191 | # Project the output and return 192 | return self.out_projection(new_values) 193 | 194 | 195 | class TransformerEncoderLayer(nn.Module): 196 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, 197 | activation="relu"): 198 | super().__init__() 199 | d_ff = d_ff or 4*d_model 200 | self.attention = attention 201 | self.linear1 = nn.Linear(d_model, d_ff) 202 | self.linear2 = nn.Linear(d_ff, d_model) 203 | self.norm1 = nn.LayerNorm(d_model) 204 | self.norm2 = nn.LayerNorm(d_model) 205 | self.dropout = nn.Dropout(dropout) 206 | self.activation = F.relu if activation == "relu" else F.gelu 207 | 208 | def forward(self, x, x_mask=None, cache=None): 209 | """Apply the transformer encoder to the input x. 210 | 211 | Arguments 212 | --------- 213 | x: The input features of shape (N, L, E) where N is the batch size, 214 | L is the sequence length (padded) and E is d_model passed in the 215 | constructor. 
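x_mask: Optional mask of shape (N, L, 1), with ones at valid positions and
    zeros at padding; it is passed to the attention as the key/query mask.
cache: Optional per-layer dict holding the running linear-attention state
    ("kv" and "k_cum"), updated in place during incremental decoding.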
216 | """ 217 | # Run self attention and add it to the input 218 | x = x + self.dropout(self.attention( 219 | x, x, x, 220 | query_mask=x_mask, 221 | key_mask=x_mask, 222 | cache=cache 223 | )) 224 | 225 | # Run the fully connected part of the layer 226 | y = x = self.norm1(x) 227 | y = self.dropout(self.activation(self.linear1(y))) 228 | y = self.dropout(self.linear2(y)) 229 | 230 | return self.norm2(x+y) 231 | 232 | 233 | class CausalLinearTransformerEncoder(nn.Module): 234 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, 235 | p_dropout=0.1, activation="gelu", feature_map=elu_feature_map): 236 | 237 | super().__init__() 238 | self.layers = nn.ModuleList([ 239 | TransformerEncoderLayer( 240 | AttentionLayer( 241 | CausalLinearAttentionAMP(hidden_channels, feature_map), 242 | hidden_channels, 243 | n_heads), 244 | hidden_channels, 245 | filter_channels, 246 | p_dropout, 247 | activation 248 | ) 249 | for i in range(n_layers) 250 | ]) 251 | 252 | def forward(self, x, x_mask=None, cache=None): 253 | cache_l = None 254 | 255 | # Apply all the transformers 256 | for i, layer in enumerate(self.layers): 257 | if cache is not None: 258 | cache_l = cache[i] 259 | x = layer(x, x_mask=x_mask, cache=cache_l) 260 | 261 | return x 262 | 263 | 264 | class TableRecognizer(nn.Module): 265 | def __init__(self, n_vocab, img_channels, hidden_channels, filter_channels, n_heads, n_layers, p_dropout=.1): 266 | super().__init__() 267 | self.n_vocab = n_vocab 268 | self.img_channels = img_channels 269 | self.hidden_channels = hidden_channels 270 | self.filter_channels = filter_channels 271 | self.n_heads = n_heads 272 | self.n_layers = n_layers 273 | self.p_dropout = p_dropout 274 | 275 | self.emb_img = ImageEmbedding(img_channels, hidden_channels) 276 | self.emb_txt = TextEmbedding(n_vocab, hidden_channels) 277 | self.enc = CausalLinearTransformerEncoder( 278 | hidden_channels, 279 | filter_channels, 280 | n_heads, 281 | n_layers, 282 | p_dropout) 283 | 284 | self.proj_img = nn.Linear(hidden_channels, img_channels) 285 | self.proj_txt = nn.Linear(hidden_channels, n_vocab) 286 | 287 | def forward(self, x_img, x_txt, mask_img, mask_txt, pos_r, pos_c, pos_t): 288 | x_mask = mask_img + mask_txt 289 | 290 | x_emb_img = self.emb_img(x_img, mask_img, pos_r, pos_c) 291 | x_emb_txt = self.emb_txt(x_txt, mask_txt, pos_t) 292 | x_emb = x_emb_img + x_emb_txt 293 | 294 | x = self.enc(x_emb, x_mask) 295 | logit_img = self.proj_img(x) 296 | logit_txt = self.proj_txt(x) 297 | return logit_img, logit_txt 298 | 299 | def inference(self, x_img, mask_img, pos_r, pos_c, idx_start=1, idx_end=2, max_decode_len=10000): 300 | from tqdm import tqdm 301 | with torch.no_grad(): 302 | b = x_img.size(0) 303 | nh = self.n_heads 304 | d = self.hidden_channels // self.n_heads 305 | dtype = x_img.dtype 306 | device = x_img.device 307 | 308 | cache = [{ 309 | "kv": torch.zeros(b, 1, nh, d, d).to(dtype=torch.float, device=device), 310 | "k_cum": torch.zeros(b, 1, nh, d).to(dtype=torch.float, device=device) 311 | } for _ in range(self.n_layers) 312 | ] 313 | x_emb_img = self.emb_img(x_img, mask_img, pos_r, pos_c) 314 | _ = self.enc(x_emb_img, mask_img, cache) 315 | 316 | pos_enc = get_positional_encoding( 317 | torch.arange(max_decode_len).view(1,-1,1).to(device=device), 318 | self.hidden_channels 319 | ) 320 | finished = torch.BoolTensor(b,1).to(device=device).fill_(False) 321 | idx = torch.zeros(b,1).long().to(device=device) + idx_start 322 | ids = [] 323 | for i in tqdm(range(max_decode_len)): 324 | x_emb_txt = 
self.emb_txt.emb(idx) + pos_enc[:,i:i+1] 325 | x = self.enc(x_emb_txt, None, cache) 326 | logit_txt = self.proj_txt(x) 327 | idx = torch.argmax(logit_txt, -1) 328 | ids.append(idx) 329 | finished |= torch.eq(idx, idx_end) 330 | if torch.all(finished): 331 | break 332 | return ids 333 | -------------------------------------------------------------------------------- /processing_pubtabnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | 7 | def preprocess_json(dataset_dir, output_dir): 8 | json_path = os.path.join(dataset_dir, 'PubTabNet_2.0.0.jsonl') 9 | 10 | dicts = {split: [] for split in ['train', 'val', 'test']} 11 | with open(json_path, 'r', encoding='utf-8') as f: 12 | for line in tqdm(f, total=509892): 13 | data = json.loads(line) 14 | 15 | data_new = dict() 16 | image_path = os.path.join(dataset_dir, data['split'], data['filename']) 17 | image = Image.open(image_path) 18 | data_new['image_path'] = image_path 19 | data_new['image_size'] = image.size 20 | # start text 21 | nd = len([x for x in data['html']['structure']['tokens'] if x == '</td>']) 22 | nc = len(data['html']['cells']) 23 | alert_msg = "The number of td (%d) is not equal to the number of cells (%d)." % (nd, nc) 24 | assert nd == nc, alert_msg 25 | 26 | data_new['text'] = [''] 27 | cnt_cell = 0 28 | for struct_tok in data['html']['structure']['tokens']: 29 | if struct_tok == '</td>': 30 | cell = data['html']['cells'][cnt_cell] 31 | cnt_cell += 1 32 | data_new['text'] += cell['tokens'] 33 | data_new['text'].append(struct_tok) 34 | # end text 35 | data_new['num_tokens'] = len(data_new['text']) 36 | dicts[data['split']].append(data_new) 37 | 38 | for k, v in dicts.items(): 39 | output_path = os.path.join(output_dir, k + ".json") 40 | with open(output_path, 'w', encoding='utf-8') as out: 41 | json.dump(v, out) 42 | 43 | 44 | def generate_vocab(dataset_dir, output_dir): 45 | json_path = os.path.join(dataset_dir, "PubTabNet_2.0.0.jsonl") 46 | tokens = {key: set() for key in ['structure', 'cell']} 47 | with open(json_path, 'r', encoding='utf-8') as f: 48 | for line in tqdm(f, total=509892): 49 | data = json.loads(line) 50 | tokens['structure'].update(data['html']['structure']['tokens']) 51 | for cell in data['html']['cells']: 52 | tokens['cell'].update(cell['tokens']) 53 | 54 | print('\nsize of structure_tokens: ', len(tokens['structure'])) 55 | print('size of cell_tokens: ', len(tokens['cell'])) 56 | 57 | tokens['cell'] = tokens['cell'].difference(tokens['structure']) 58 | tokens_total = [] 59 | for key, value in tokens.items(): 60 | tokens_total.extend(sorted(list(value))) 61 | 62 | vocab_path = os.path.join(output_dir, 'vocab.txt') 63 | 64 | with open(vocab_path, 'w', encoding='utf-8') as out: 65 | out.write('\n\n\n\n') 66 | for token in tokens_total: 67 | out.write(token + '\n') 68 | 69 | 70 | if __name__ == '__main__': 71 | dataset_dir = '/data/private/datasets/pubtabnet' 72 | output_dir = '/data/private/datasets/pubtabnet/annotations' 73 | generate_vocab(dataset_dir, output_dir) 74 | preprocess_json(dataset_dir, output_dir) 75 | -------------------------------------------------------------------------------- /resources/ltiayn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/resources/ltiayn.png
-------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | import glob 5 | import numpy as np 6 | from srcs.PubTabNet.src.metric import TEDS 7 | from srcs.PubTabNet.src.parallel import parallel_process 8 | 9 | 10 | f_list = glob.glob("./results/val1/*.json") 11 | pred = {} 12 | for f_name in f_list: 13 | with open(f_name, "r") as f: 14 | pred.update(json.load(f)) 15 | 16 | with open("/data/private/datasets/pubtabnet/annotations/val.json", "r") as f: 17 | data = json.load(f) 18 | true = [(x['image_path'].split("/")[-1], "".join(x['text'][1:-1])) for x in data if x['image_path'].split("/")[-1] in pred] 19 | 20 | true_sorted = sorted(true, key=lambda x: len(x[1])) 21 | 22 | teds = TEDS(n_jobs=48) 23 | 24 | html_strings_pred = [pred[x[0]] for x in true_sorted] 25 | prefix = '<html><body><table>' 26 | postfix = '</table></body></html>'
27 | html_strings_tgt = [prefix + x[1] + postfix for x in true_sorted] 28 | 29 | inputs = [{"pred": pred, "true": true} for pred, true in zip(html_strings_pred, html_strings_tgt)] 30 | scores = parallel_process(inputs, teds.evaluate, use_kwargs=True, n_jobs=teds.n_jobs, front_num=1) 31 | print(np.mean(scores)) 32 | -------------------------------------------------------------------------------- /srcs/PubTabNet/ICDAR_SLR_competition/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/ICDAR_SLR_competition/example.png -------------------------------------------------------------------------------- /srcs/PubTabNet/ICDAR_SLR_competition/val_mini.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/ICDAR_SLR_competition/val_mini.zip -------------------------------------------------------------------------------- /srcs/PubTabNet/LICENSE.md: -------------------------------------------------------------------------------- 1 | The annotations in this dataset belong to IBM and are licensed under a [Community Data License Agreement – Permissive – Version 1.0 License](https://cdla.io/permissive-1-0/). 2 | 3 | ## Images 4 | IBM does not own the copyright of the images. Use of the images must abide by the [PMC Open Access Subset Terms of Use](https://www.ncbi.nlm.nih.gov/pmc/tools/openftlist/). 5 | -------------------------------------------------------------------------------- /srcs/PubTabNet/README.md: -------------------------------------------------------------------------------- 1 | # PubTabNet 2 | 3 | PubTabNet is a large dataset for image-based table recognition, containing 568k+ images of tabular data annotated with the corresponding HTML representation of the tables. The table images are extracted from the scientific publications included in the [PubMed Central Open Access Subset (commercial use collection)](https://www.ncbi.nlm.nih.gov/pmc/tools/openftlist/). Table regions are identified by matching the PDF format and the XML format of the articles in the PubMed Central Open Access Subset. More details are available in our paper ["Image-based table recognition: data, model, and evaluation"](https://arxiv.org/abs/1911.10683). 4 | 5 | ## Headlines 6 | 7 | `21/July/2020` - PubTabNet 2.0.0 is released, where the position (bounding box) of non-empty cells is added into the annotation. The annotation file is also changed from `json` format to `jsonl` format to reduce the requirement on large RAM. 8 | 9 | `20/Jul/2020` - PubTabNet is used in [ICDAR 2021 Competition on Scientific Literature Parsing](https://github.com/IBM/ICDAR2021-SLP) ([Task B on Table Recognition](https://aieval.draco.res.ibm.com/challenge/40/overview)) 10 | 11 | `03/July/2020` - `Image-based table recognition: data, model, and evaluation` is accepted by ECCV20. 12 | 13 | `01/July/2020` - Code of **T**ree-**Edit**-**D**istance-based **S**imilarity (TEDS) metric is [released](src). 14 | 15 | ## Updates in progress 16 | 17 | ### Encoder-dual-decoder model 18 | 19 | In our paper, we proposed a new encoder-dual-decoder architecture, which was trained on PubTabNet and can accurately reconstruct the HTML representation of complex tables solely relying on image input.
Due to legal constraints, the source code of the model will not be released. 20 | 21 | ### Ground truth of test set 22 | 23 | The ground truth of the test set will not be released, as we want to keep it for a future competition. We will offer a service for people to submit and evaluate their results soon. 24 | 25 | ## Getting data 26 | 27 | Images and annotations can be downloaded [here](https://developer.ibm.com/exchanges/data/all/pubtabnet/). If you want to download the data from the command line, you can use curl or wget: 28 | 29 | ``` 30 | curl -o <YOUR_TARGET_DIR>/PubTabNet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-pubtabnet/2.0.0/pubtabnet.tar.gz 31 | ``` 32 | 33 | ``` 34 | wget -O <YOUR_TARGET_DIR>/PubTabNet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-pubtabnet/2.0.0/pubtabnet.tar.gz 35 | ``` 36 | 37 | ## Annotation structure 38 | 39 | The annotation is in the jsonl (JSON Lines) format, where each line contains the annotation of one sample with the following structure: 40 | 41 | 42 | ``` 43 | { 44 | 'filename': str, 45 | 'split': str, 46 | 'imgid': int, 47 | 'html': { 48 | 'structure': {'tokens': [str]}, 49 | 'cells': [ 50 | { 51 | 'tokens': [str], 52 | 'bbox': [x0, y0, x1, y1] # only non-empty cells have this attribute 53 | } 54 | ] 55 | } 56 | } 57 | ``` 58 | 59 | ## Cite us 60 | 61 | ``` 62 | @article{zhong2019image, 63 | title={Image-based table recognition: data, model, and evaluation}, 64 | author={Zhong, Xu and ShafieiBavani, Elaheh and Yepes, Antonio Jimeno}, 65 | journal={arXiv preprint arXiv:1911.10683}, 66 | year={2019} 67 | } 68 | ``` 69 | 70 | ## Examples 71 | 72 | A [Jupyter notebook](./exploring_PubTabNet_dataset.ipynb) is provided to inspect the annotations of 20 sample tables. 73 | 74 | 75 | ## Related links 76 | 77 | [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) is a large dataset of document images, of which the layout is annotated with both bounding boxes and polygonal segmentations. The source of the documents is [PubMed Central Open Access Subset (commercial use collection)](https://www.ncbi.nlm.nih.gov/pmc/tools/openftlist/). The annotations are automatically generated by matching the PDF format and the XML format of the articles in the PubMed Central Open Access Subset. More details are available in our paper ["PubLayNet: largest dataset ever for document layout analysis."](https://arxiv.org/abs/1908.07836), which was awarded the [best paper at ICDAR 2019](http://icdar2019.org/award/)!
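For reference, the following is a minimal sketch of reading the annotation file described in the "Annotation structure" section above. It assumes `PubTabNet_2.0.0.jsonl` sits in the current directory; the field names follow the structure shown in that section.

```python
import json

# Read PubTabNet_2.0.0.jsonl one sample per line (jsonl keeps memory usage low).
with open('PubTabNet_2.0.0.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        sample = json.loads(line)
        filename = sample['filename']          # image file name
        split = sample['split']                # 'train', 'val' or 'test'
        structure_tokens = sample['html']['structure']['tokens']
        cells = sample['html']['cells']        # each cell has 'tokens'; non-empty cells also carry 'bbox'
        print(filename, split, len(structure_tokens), len(cells))
        break  # remove this line to scan the whole file
```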
78 | -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC1626454_002_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC1626454_002_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC2753619_002_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC2753619_002_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC2759935_007_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC2759935_007_01.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC2838834_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC2838834_005_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC3519711_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC3519711_003_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC3826085_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC3826085_003_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC3907710_006_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC3907710_006_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC4003957_018_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC4003957_018_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC4172848_007_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC4172848_007_00.png -------------------------------------------------------------------------------- 
/srcs/PubTabNet/examples/PMC4517499_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC4517499_004_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC4682394_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC4682394_003_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC4776821_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC4776821_005_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC4840965_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC4840965_004_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC5134617_013_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC5134617_013_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC5198506_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC5198506_004_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC5332562_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC5332562_005_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC5402779_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC5402779_004_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC5577841_001_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC5577841_001_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC5679144_002_01.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC5679144_002_01.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/PMC5897438_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/PubTabNet/examples/PMC5897438_004_00.png -------------------------------------------------------------------------------- /srcs/PubTabNet/examples/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from bs4 import BeautifulSoup as bs 3 | 4 | def format_html(img): 5 | ''' Formats HTML code from tokenized annotation of img 6 | ''' 7 | html_string = ''' 8 | 9 | 10 | 16 | 17 | 18 | 19 | %s 20 |
21 | 22 | ''' % ''.join(img['html']['structure']['tokens']) 23 | cell_nodes = list(re.finditer(r'(]*>)()', html_string)) 24 | assert len(cell_nodes) == len(img['html']['cells']), 'Number of cells defined in tags does not match the length of cells' 25 | cells = [''.join(c['tokens']) for c in img['html']['cells']] 26 | offset = 0 27 | for n, cell in zip(cell_nodes, cells): 28 | html_string = html_string[:n.end(1) + offset] + cell + html_string[n.start(2) + offset:] 29 | offset += len(cell) 30 | # prettify the html 31 | soup = bs(html_string) 32 | html_string = soup.prettify() 33 | return html_string 34 | 35 | 36 | if __name__ == '__main__': 37 | import json 38 | import sys 39 | f = sys.argv[1] 40 | with open(f, 'r') as fp: 41 | annotations = json.load(fp) 42 | for img in annotations['images']: 43 | html_string = format_html(img) 44 | print(html_string) 45 | -------------------------------------------------------------------------------- /srcs/PubTabNet/src/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /srcs/PubTabNet/src/README.md: -------------------------------------------------------------------------------- 1 | # Tree-Edit-Distance-based Similarity (TEDS) 2 | 3 | Evaluation metric for table recognition. This metric measures both the structure similarity and the cell content similarity between the prediction and the ground truth. The score is normalized between 0 and 1, where 1 means perfect matching. 
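Before the detailed instructions below, a minimal usage sketch (the two HTML strings here are placeholders; the full worked example is in [the demo](demo.ipynb)):

```python
from metric import TEDS

# Placeholder prediction / ground-truth HTML for a single table.
pred_html = '<html><body><table><tr><td>cell</td></tr></table></body></html>'
true_html = '<html><body><table><tr><td>cell</td></tr></table></body></html>'

teds = TEDS(n_jobs=1)  # structure_only=True would compare table structure only
score = teds.evaluate(pred_html, true_html)
print(score)  # 1.0 means a perfect match
```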
4 | 5 | ## How this metric works 6 | 7 | Please see Section V in [our paper](https://arxiv.org/abs/1911.10683) for the principle of this metric. 8 | 9 | ## How to use the code 10 | 11 | ### Installation 12 | 13 | `pip install -r requirements.txt` 14 | 15 | ### Run the code 16 | 17 | Please see [this demo](demo.ipynb). 18 | 19 | ## Cite us 20 | 21 | ``` 22 | @article{zhong2019image, 23 | title={Image-based table recognition: data, model, and evaluation}, 24 | author={Zhong, Xu and ShafieiBavani, Elaheh and Jimeno Yepes, Antonio}, 25 | journal={arXiv preprint arXiv:1911.10683}, 26 | year={2019} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /srcs/PubTabNet/src/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Evaluate a single prediction against ground truth" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 5, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# Sample HTML code\n", 17 | "pred = '
Name of algoriNotablefeatures
MACS [23]Uses both a control library and local statistics to minimize bias
SICER [15]Designed for detecting diffusely enriched regions; for example, histone modification
PeakSEQ [24]Corrects for reference genome mappability and local statistics
SISSRs [25]High resolution, precise identification of binding-site location
F-seq [26]Uses kernel density estimation
'\n", 18 | "true = '
Name of algorithmNotable features
MACS [23]Uses both a control library and local statistics to minimize bias
SICER [14]Designed for detecting diffusely enriched regions; for example, histone modification
PeakSeq [24]Corrects for reference genome mappability and local statistics
SISSRs [25]High resolution, precise identification of binding-site location
F-seq [26]Uses kernel density estimation
'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 10, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "TEDS score: 0.9781765018607124\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "from metric import TEDS\n", 36 | "# Initialize TEDS object\n", 37 | "teds = TEDS()\n", 38 | "# Evaluate\n", 39 | "score = teds.evaluate(pred, true)\n", 40 | "print('TEDS score:', score)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## Batch evaluation with parallel threads" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 7, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import json\n", 57 | "import pprint\n", 58 | "from metric import TEDS" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 8, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# Load sample ground truth and predictions\n", 68 | "with open('sample_pred.json') as fp:\n", 69 | " pred_json = json.load(fp)\n", 70 | "with open('sample_gt.json') as fp:\n", 71 | " true_json = json.load(fp)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 9, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stderr", 81 | "output_type": "stream", 82 | "text": [ 83 | "100%|██████████| 19.0/19.0 [00:10<00:00, 1.50s/it]\n", 84 | "19it [00:00, 112400.25it/s]" 85 | ] 86 | }, 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "{'PMC2094709_004_00.png': 1.0,\n", 92 | " 'PMC2871264_002_00.png': 1.0,\n", 93 | " 'PMC2915972_003_00.png': 0.9298260149130074,\n", 94 | " 'PMC3160368_005_00.png': 0.994615695248351,\n", 95 | " 'PMC3568059_003_00.png': 0.9609420535891124,\n", 96 | " 'PMC3707453_006_00.png': 0.8538903625110521,\n", 97 | " 'PMC3765162_003_01.png': 0.9867342100509474,\n", 98 | " 'PMC3872294_001_00.png': 0.9863636363636363,\n", 99 | " 'PMC4196076_004_00.png': 0.9958653089334908,\n", 100 | " 'PMC4219599_004_00.png': 0.6029978075326913,\n", 101 | " 'PMC4297392_007_00.png': 0.8070175438596492,\n", 102 | " 'PMC4311460_007_00.png': 0.6576923076923077,\n", 103 | " 'PMC4357206_002_00.png': 0.9295181638546892,\n", 104 | " 'PMC4445578_009_01.png': 0.6754965084868096,\n", 105 | " 'PMC4969833_016_01.png': 1.0,\n", 106 | " 'PMC5303243_003_00.png': 0.6494374120956399,\n", 107 | " 'PMC5451934_004_00.png': 0.9978213507625272,\n", 108 | " 'PMC5755158_010_01.png': 1.0,\n", 109 | " 'PMC5849724_006_00.png': 0.9653439200120101,\n", 110 | " 'PMC6022086_007_00.png': 1.0}\n" 111 | ] 112 | }, 113 | { 114 | "name": "stderr", 115 | "output_type": "stream", 116 | "text": [ 117 | "\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "# Initialize TEDS object, using 4 parallel threads\n", 123 | "teds = TEDS(n_jobs=4)\n", 124 | "# Evaluate\n", 125 | "scores = teds.batch_evaluate(pred_json, true_json)\n", 126 | "# Print results\n", 127 | "pp = pprint.PrettyPrinter()\n", 128 | "pp.pprint(scores)" 129 | ] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.6.8" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | 
-------------------------------------------------------------------------------- /srcs/PubTabNet/src/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 IBM 2 | # Author: peter.zhong@au1.ibm.com 3 | # 4 | # This is free software; you can redistribute it and/or modify 5 | # it under the terms of the Apache 2.0 License. 6 | # 7 | # This software is distributed in the hope that it will be useful, 8 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | # Apache 2.0 License for more details. 11 | 12 | import distance 13 | from apted import APTED, Config 14 | from apted.helpers import Tree 15 | from lxml import etree, html 16 | from collections import deque 17 | from srcs.PubTabNet.src.parallel import parallel_process 18 | from tqdm import tqdm 19 | 20 | class TableTree(Tree): 21 | def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): 22 | self.tag = tag 23 | self.colspan = colspan 24 | self.rowspan = rowspan 25 | self.content = content 26 | self.children = list(children) 27 | 28 | def bracket(self): 29 | """Show tree using brackets notation""" 30 | if self.tag == 'td': 31 | result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \ 32 | (self.tag, self.colspan, self.rowspan, self.content) 33 | else: 34 | result = '"tag": %s' % self.tag 35 | for child in self.children: 36 | result += child.bracket() 37 | return "{{{}}}".format(result) 38 | 39 | 40 | class CustomConfig(Config): 41 | @staticmethod 42 | def maximum(*sequences): 43 | """Get maximum possible value 44 | """ 45 | return max(map(len, sequences)) 46 | 47 | def normalized_distance(self, *sequences): 48 | """Get distance from 0 to 1 49 | """ 50 | return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) 51 | 52 | def rename(self, node1, node2): 53 | """Compares attributes of trees""" 54 | if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): 55 | return 1. 56 | if node1.tag == 'td': 57 | if node1.content or node2.content: 58 | return self.normalized_distance(node1.content, node2.content) 59 | return 0.
60 | 61 | 62 | class TEDS(object): 63 | ''' Tree Edit Distance based Similarity 64 | ''' 65 | def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None): 66 | assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1' 67 | self.structure_only = structure_only 68 | self.n_jobs = n_jobs 69 | self.ignore_nodes = ignore_nodes 70 | self.__tokens__ = [] 71 | 72 | def tokenize(self, node): 73 | ''' Tokenizes table cells 74 | ''' 75 | self.__tokens__.append('<%s>' % node.tag) 76 | if node.text is not None: 77 | self.__tokens__ += list(node.text) 78 | for n in node.getchildren(): 79 | self.tokenize(n) 80 | if node.tag != 'unk': 81 | self.__tokens__.append('</%s>' % node.tag) 82 | if node.tag != 'td' and node.tail is not None: 83 | self.__tokens__ += list(node.tail) 84 | 85 | def load_html_tree(self, node, parent=None): 86 | ''' Converts HTML tree to the format required by apted 87 | ''' 88 | global __tokens__ 89 | if node.tag == 'td': 90 | if self.structure_only: 91 | cell = [] 92 | else: 93 | self.__tokens__ = [] 94 | self.tokenize(node) 95 | cell = self.__tokens__[1:-1].copy() 96 | new_node = TableTree(node.tag, 97 | int(node.attrib.get('colspan', '1')), 98 | int(node.attrib.get('rowspan', '1')), 99 | cell, *deque()) 100 | else: 101 | new_node = TableTree(node.tag, None, None, None, *deque()) 102 | if parent is not None: 103 | parent.children.append(new_node) 104 | if node.tag != 'td': 105 | for n in node.getchildren(): 106 | self.load_html_tree(n, new_node) 107 | if parent is None: 108 | return new_node 109 | 110 | def evaluate(self, pred, true): 111 | ''' Computes TEDS score between the prediction and the ground truth of a 112 | given sample 113 | ''' 114 | if (not pred) or (not true): 115 | return 0.0 116 | parser = html.HTMLParser(remove_comments=True, encoding='utf-8') 117 | pred = html.fromstring(pred, parser=parser) 118 | true = html.fromstring(true, parser=parser) 119 | if pred.xpath('body/table') and true.xpath('body/table'): 120 | pred = pred.xpath('body/table')[0] 121 | true = true.xpath('body/table')[0] 122 | if self.ignore_nodes: 123 | etree.strip_tags(pred, *self.ignore_nodes) 124 | etree.strip_tags(true, *self.ignore_nodes) 125 | n_nodes_pred = len(pred.xpath(".//*")) 126 | n_nodes_true = len(true.xpath(".//*")) 127 | n_nodes = max(n_nodes_pred, n_nodes_true) 128 | tree_pred = self.load_html_tree(pred) 129 | tree_true = self.load_html_tree(true) 130 | distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance() 131 | return 1.0 - (float(distance) / n_nodes) 132 | else: 133 | return 0.0 134 | 135 | def batch_evaluate(self, pred_json, true_json): 136 | ''' Computes TEDS score between the prediction and the ground truth of 137 | a batch of samples 138 | @params pred_json: {'FILENAME': 'HTML CODE', ...} 139 | @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...} 140 | @output: {'FILENAME': 'TEDS SCORE', ...} 141 | ''' 142 | samples = true_json.keys() 143 | if self.n_jobs == 1: 144 | scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)] 145 | else: 146 | inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples] 147 | scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1) 148 | scores = dict(zip(samples, scores)) 149 | return scores 150 | 151 | 152 | if __name__ == '__main__': 153 | import json 154 | import pprint 155 | with open('sample_pred.json') as fp: 156 |
pred_json = json.load(fp) 157 | with open('sample_gt.json') as fp: 158 | true_json = json.load(fp) 159 | teds = TEDS(n_jobs=4) 160 | scores = teds.batch_evaluate(pred_json, true_json) 161 | pp = pprint.PrettyPrinter() 162 | pp.pprint(scores) 163 | -------------------------------------------------------------------------------- /srcs/PubTabNet/src/parallel.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from concurrent.futures import ProcessPoolExecutor, as_completed 3 | 4 | def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0): 5 | """ 6 | A parallel version of the map function with a progress bar. 7 | 8 | Args: 9 | array (array-like): An array to iterate over. 10 | function (function): A python function to apply to the elements of array 11 | n_jobs (int, default=16): The number of cores to use 12 | use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of 13 | keyword arguments to function 14 | front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. 15 | Useful for catching bugs 16 | Returns: 17 | [function(array[0]), function(array[1]), ...] 18 | """ 19 | # We run the first few iterations serially to catch bugs 20 | if front_num > 0: 21 | front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]] 22 | else: 23 | front = [] 24 | # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. 25 | if n_jobs == 1: 26 | return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] 27 | # Assemble the workers 28 | with ProcessPoolExecutor(max_workers=n_jobs) as pool: 29 | # Pass the elements of array into function 30 | if use_kwargs: 31 | futures = [pool.submit(function, **a) for a in array[front_num:]] 32 | else: 33 | futures = [pool.submit(function, a) for a in array[front_num:]] 34 | kwargs = { 35 | 'total': len(futures), 36 | 'unit': 'it', 37 | 'unit_scale': True, 38 | 'leave': True 39 | } 40 | # Print out the progress as tasks complete 41 | for f in tqdm(as_completed(futures), **kwargs): 42 | pass 43 | out = [] 44 | # Get the results from the futures. 45 | for i, future in tqdm(enumerate(futures)): 46 | try: 47 | out.append(future.result()) 48 | except Exception as e: 49 | out.append(e) 50 | return front + out 51 | -------------------------------------------------------------------------------- /srcs/PubTabNet/src/requirements.txt: -------------------------------------------------------------------------------- 1 | pprint 2 | apted 3 | distance 4 | lxml 5 | tqdm 6 | -------------------------------------------------------------------------------- /srcs/PubTabNet/src/sample_gt.json: -------------------------------------------------------------------------------- 1 | {"PMC5755158_010_01.png": {"html": "
WeaningWeek 15Off-test
Weaning\u2013\u2013\u2013
Week 15\u20130.17 \u00b1 0.080.16 \u00b1 0.03
Off-test\u20130.80 \u00b1 0.240.19 \u00b1 0.09
", "tag_len": 44, "cell_len_max": 11, "width": 238, "height": 59, "type": "simple"}, "PMC4445578_009_01.png": {"html": "
Reactive astroglioisChanges in astrocytes morphologyChanges in molecules expression
Upregulated moleculesUpregulated or downregulated molecules
Mild to moderate astrogliosis\u2022 Hypertrophy of cell body\u2022 Structural elements: GFAP, nestin, vimentin\u2022 Inflammatory cell regulators: cytokines, growth factors, glutathione
\u2022 Astrocytes processes are are numerous and thicker\u2022 Transcriptional regulators: STAT3, NF\u03baB, Rheb-m TOR, cAMP, Olig2, SOX9 [61\u201365].\u2022 Transporters and pumps: AQP4 and Na+/K+ transporters [61, 66\u201369]
\u2022 Glutamate transporter [70\u201373]
\u2022 The non-overlapping domains of individual astrocytes are preserved\u2022 Vascular regulators: PGE, NO [74, 75]
\u2022 Energy provision: lactate [76]
\u2022 Molecules implicated in synapse formation and
Severe astrogliosis and glial scar\u2022 Intense hypertrophy of cell body\u2022 Remodeling: thrombospondin and Complement C1q [77, 78]
\u2022 Significant extension of processes\u2022 Molecules implicated in oxidative stress and providing protection from oxidative stress: NO, NOS, SOD, Glutathione [67, 68, 79]
\u2022 Proliferation
\u2022 Overlapping of individual domains
\u2022 Substantial reorganization of tissue architecture [60]
", "tag_len": 116, "cell_len_max": 129, "width": 486, "height": 248, "type": "complex"}, "PMC2871264_002_00.png": {"html": "
Name of algorithmNotable features
MACS [23]Uses both a control library and local statistics to minimize bias
SICER [14]Designed for detecting diffusely enriched regions; for example, histone modification
PeakSeq [24]Corrects for reference genome mappability and local statistics
SISSRs [25]High resolution, precise identification of binding-site location
F-seq [26]Uses kernel density estimation
", "tag_len": 40, "cell_len_max": 84, "width": 238, "height": 124, "type": "simple"}, "PMC3872294_001_00.png": {"html": "
HC (N = 20)FASD (N = 15)
Age (years)16.3 (2.1)15.3 (2.1)
IQ108 (15)*80 (15)*
Male/female (%male)12/8 (60%)10/5 (67%)
FASD sub diagnosis\u20138 FAS, 7 ARND
", "tag_len": 44, "cell_len_max": 19, "width": 251, "height": 88, "type": "simple"}, "PMC2915972_003_00.png": {"html": "
No of patients
Gender:
Men24
Women26
Age (years):
30-392
40-498
50-5915
60-6916
70-796
\u2265 803
Tumor site:
Bladder4
Breast10
Colorectal4
Esophageal9
Gynecological7
Lung6
Prostate10
Length of interval between baseline and follow-up interview
(median)
< 50 days22
\u2265 50 days28
", "tag_len": 142, "cell_len_max": 59, "width": 238, "height": 287, "type": "complex"}, "PMC4196076_004_00.png": {"html": "
miRNAChange relative to controlsDirection of regulationChromosomemiRNAChange relative to controlsDirection of regulationChromosome
hsa-miR-11812.13Up19hsa-miR-8742.97Up5
hsa-miR-125a-5p5.04Up19hsa-miR-8902.83UpX
hsa-miR-21-3p2.82Up17hsa-miR-9392.59Up8
hsa-miR-29b-1-5p3.12Up7hsa-miR-1290\u22127.56Down1
hsa-miR-3663-3p2.19Up10hsa-miR-1915-3p\u22122.63Down10
hsa-miR-3127-5p2.01Up2hsa-miR-2861\u22123.31Down9
hsa-miR-3663-3p2.03Up10hsa-miR-3665\u22122.37Down13
hsa-miR-371a-5p3.14Up19hsa-miR-4257\u22123.62Down1
hsa-miR-43272.95Up21hsa-miR-452-5p\u22122.54DownX
hsa-miR-584-5p2.31Up5hsa-miR-513a-5p\u22123.15DownX
hsa-miR-6025.74Up9hsa-miR-572\u22125.80Down4
hsa-miR-629-3p2.71Up15hsa-miR-629-3p\u22123.03Down15
hsa-miR-642b-3p2.10Up19hsa-miR-765\u22127.18Down1
hsa-miR-6513.91UpXhsa-miR-875-5p\u22123.91Down8
hsa-miR-7622.84Up16hsa-miR-940\u22122.31Down16
", "tag_len": 292, "cell_len_max": 29, "width": 486, "height": 236, "type": "simple"}, "PMC3160368_005_00.png": {"html": "
Methods (n-mers used)Average Sensitivity of 5-fold cross validation (%)Average Specificity of 5-fold cross validation (%)
FDAFSA(hexamers)84*86*
PromMachine(tetramers)86+81+
", "tag_len": 28, "cell_len_max": 52, "width": 238, "height": 71, "type": "simple"}, "PMC3707453_006_00.png": {"html": "
TFC Layer Thickness [\u03bcm]Star Magnitude 1Star Magnitude 6Saturation Charge [e-]Capacitance Linearity [%]
Signal @ 0.1s integr. [e-]Noise @ 0.1s integr. [e-]S/N at 10 bit A/D [dB]Signal @ 0.1s integr. [e-]Noise @ 0.1s integr. [e-]S/N at 10 bit A/D [dB]
0.51212004984718823581050000099.2
1.01439604265016101991327223298.6
1.51552204185017131471919710998.1
1.81599504185017591301917201897.8
2.01624004195017841221915957597.6
2.21645504205018071151914925497.5
", "tag_len": 160, "cell_len_max": 30, "width": 446, "height": 184, "type": "complex"}, "PMC4311460_007_00.png": {"html": "
NumberPatients
CategoryTypeCHP%(N = 4,560)%
IInflammation6,98711.33,53777.6
IIInfection3,6295.92,45153.8
IIIInjury5,5569.03,40174.6
IVSpecific conditions32,01651.9n.c.
VNeoplasms3,5925.82,461#54
Maligne1,4441,219 (27%)
Other-benign2,1481,758 (39%)
VICongenital4900.8n.c.
VIIOtherwise9,38315.2n.c.
TotalALL-types61,653100
", "tag_len": 220, "cell_len_max": 19, "width": 486, "height": 170, "type": "complex"}, "PMC5451934_004_00.png": {"html": "
ConditionPre Well-BeingPost Well-BeingPre-Post Change
TP (handler & dog interaction)46.33 \u00b1 7.41 148.69 \u00b1 7.22+2.36
DO (dog only interaction)49.78 \u00b1 7.9151.56 \u00b1 6.99+1.78 **
HO (handler only interaction)47.37 \u00b1 7.5746.43 \u00b1 8.03\u22120.94 **
", "tag_len": 44, "cell_len_max": 30, "width": 389, "height": 56, "type": "simple"}, "PMC5849724_006_00.png": {"html": "
AnalytesGC\u2013HRMSGC\u2013MS/MSGC\u2013MS
LOQ, (ng/CFPa)Estimated LOQ, (ng/cig)LOQ, (ng/CFPa)Estimated LOQ, (ng/cig)LOQ, (ng/CFPa)Estimated LOQ, (ng/cig)
Naphthalene0.510.0261178.7158.94108.175.41
Benzo[c]phenanthrene0.040.002NDND66.803.34
Benzo[a]anthracene0.030.00238.571.9338.111.91
Chrysene0.040.00250.132.5149.612.48
Cyclopenta[c,d]pyrene0.020.00148.842.4460.043.00
5-Methylchrysene0.040.002NDND2.480.12
Benzo[b]fluoranthene0.040.00211.440.575.080.25
Benzo[k]fluoranthene0.050.00312.410.625.070.25
Benzo[j]aceanthrylene0.090.005NDNDNDND
Benzo[a]pyrene0.040.0025.010.253.030.15
Indeno[1,2,3-c,d]pyrene0.020.0015.460.271.540.08
Dibenzo[a,h]anthracene0.070.0040.830.041.480.07
Dibenzo[a,l]pyrene0.050.003NDNDNDND
Dibenzo[a,e]pyrene0.040.0020.800.040.280.01
Dibenzo[a,i]pyrene0.060.0031.330.07NDND
Dibenzo[a,h]pyrene0.070.0042.990.15NDND
", "tag_len": 292, "cell_len_max": 27, "width": 486, "height": 253, "type": "complex"}, "PMC6022086_007_00.png": {"html": "
MethodData TypeMean (m)RMSE (m)P90% (m)PGSD (%)
Improved FCMGaofen-35.775.8910.0794.37
Sentinel-16.305.8314.0380.00
Original FCMGaofen-36.977.6613.8790.70
Sentinel-18.534.8113.1490.00
", "tag_len": 74, "cell_len_max": 12, "width": 409, "height": 77, "type": "complex"}, "PMC4297392_007_00.png": {"html": "
Treatment phaseAdverse eventNo. of patients
T1Swelling1
Itching1
Fever4
Throat infection1
Chest Congestion2
Total9
T2Diarrhea1
Body Pain1
Total2
T3Diarrhea1
Total1
T4Nil-
", "tag_len": 98, "cell_len_max": 17, "width": 238, "height": 185, "type": "complex"}, "PMC2094709_004_00.png": {"html": "
WeekDuration (min)Intensity (% HRR)Intensity (RPE)
12050 \u2013 609 \u2013 11
22050 \u2013 609 \u2013 11
3 \u2013 52560 \u2013 7011
6 \u2013 83060 \u2013 7011
9 \u2013 113070 \u2013 8011 \u2013 13
12 \u2013 143570 \u2013 8011 \u2013 13
15 & 164075 \u2013 8513 \u2013 15
", "tag_len": 84, "cell_len_max": 19, "width": 503, "height": 107, "type": "simple"}, "PMC3568059_003_00.png": {"html": "
Participants during the period;
0 to 3 months3 to 6 months6 to 12 months
Characteristicsn=72n=71n=65
Age, years, median (range)73 (50\u201394)73 (47\u201392)73 (47\u201390)
Patients, n (%)
Female33 (46)27 (38)26 (40)
Male39 (54)44 (62)39 (60)
Stroke classification (TOAST), n (%)
Large vessel disease17 (24)18 (25)17 (26)
Small vessel disease21 (29)21 (30)17 (26)
Cardioembolic stroke15 (21)11 (15)11 (17)
Cryptogenic stroke13 (18)14 (20)12 (19)
Intracerebral haemorrhage6 (8)7 (10)8 (12)
Side of lesion, n (%)
Right side lesion35 (49)32 (45)28 (43)
Left side lesion37 (51)39 (55)37 (57)
Hypertension47 (65)44 (62)41 (63)
Diabetes mellitus17 (24)18 (25)17 (26)
Results from clinical scales 1\u20137 days after stroke onset
BBS median (range) (n)35 (0\u201356) (n=71)41 (0\u201356) (n=70)41 (0\u201356) (n=64)
M-MAS UAS-95 median (range)45 (12\u201355) (n=65)47 (12\u201355) (n=65)50 (16\u201355) (n=59)
", "tag_len": 208, "cell_len_max": 56, "width": 486, "height": 296, "type": "complex"}, "PMC4357206_002_00.png": {"html": "
N = 121
Demographics
Age (yr) - median (IQR)62 (56-73)
Female sex (%)46 (38)
White race (%)112 (93)
Comorbidities (%)
Hypertension64 (53)
Chronic lung disease37 (31)
Active malignancy34 (28)
Diabetes mellitus29 (24)
Chronic kidney disease7 (6)
Congestive heart failure4 (3)
Chronic liver disease2 (2)
Severity of illness
APACHE II score - median (IQR)*14 (10-16)
Charlson Comorbidity Index - median (IQR)\u20202 (1-4)
ICU type
Surgical102 (84)
SICU66 (54)
TICU36 (30)
Nonsurgical19 (16)
CCU11 (9)
MICU8 (7)
Status of procedure (for surgical patients) (%)
Elective41 (34)
Urgent57 (47)
Days in hospital prior to enrollment \u2013 median (IQR)1 (1-3)
", "tag_len": 166, "cell_len_max": 51, "width": 238, "height": 381, "type": "simple"}, "PMC4219599_004_00.png": {"html": "
ORP (n = 9)RALP (n = 24)Total (n = 33)
Anthropometric data
Age (yr)60 (7)63 (6)62 (6)
Height (m)1.76 (0.07)1.75 (0.05)1.75 (0.06)
Weight (kg)92 (12)83 (10)86 (11)
BMI (kg.m-2)29.6 (4.5)27.3 (3.0)27.9 (3.6)
Preoperative factors
PSA (ng/mL)5.8 (4.2)5.0 (2.1)5.2 (2.8)
Preoperative Gleason score
3 + 31 (11%)5 (21%)6 (18%)
3 + 45 (56%)16 (67%)21 (64%)
4 + 33 (33%)2 (9%)5 (15%)
4 + 40 (0%)1 (4%)1 (3%)
Clinical tumour stage
cT14 (44%)13 (54%)17 (52%)
cT25 (56%)11 (46%)16 (48%)
cT30 (0%)0 (0%)0 (0%)
cT40 (0%)0 (0%)0 (0%)
Prostate volume (cc)40.2 (13.4)41.2 (12.5)40.9 (12.6)
Intraoperative factors
Nerve sparing
None3 (33%)3 (13%)6 (18%)
One bundle2 (22%)2 (9%)4 (12%)
Two bundles4 (44%)19 (79%)23 (70%)
Pelvic lymph node dissection7 (78%)2 (9%)a9 (27%)
Bladder neck preservation0 (0%)23 (96%)a23 (70%)
Postoperative factors
Postoperative Gleason score
3 + 31 (11%)3 (13%)4 (12%)
3 + 46 (67%)16 (67%)22 (67%)
4 + 32 (22%)5 (21%)7 (21%)
4 + 40 (0%)0 (0%)0 (0%)
Pathological tumour stage
pT26 (67%)18 (75%)24 (73%)
pT33 (33%)6 (25%)9 (27%)
pT40 (0%)0 (0%)0 (0%)
Positive lymph nodes1/7 (14%)0/2 (0%)1/9 (11%)
Positive margins2 (22%)2 (9%)4 (12%)
Duration of postoperative hospital stay (d)2.9 (0.3)2.0 (0.2)a2.3 (0.5)
Duration of postoperative catheterization (d)10.2 (3.0)8.4 (1.6)8.9 (2.2)
Anastomic structure0 (0%)1 (4%)1 (3%)
", "tag_len": 414, "cell_len_max": 45, "width": 486, "height": 577, "type": "simple"}, "PMC3765162_003_01.png": {"html": "
Men (n = 359)Women (n = 412)
Metabolic syndromeMetabolic syndrome
Baseline characteristicsYes (n = 163)No (n = 196)P valueYes (n = 96)No (n = 316)P value
Age (years)*61.86 (\u00b10.83)60.32 (\u00b10.77)0.1764.96 (\u00b10.88)58.52 (\u00b10.55)<0.001
Sitting Systolic BP (mmHg)*141.34 (\u00b11.27)132.26 (\u00b11.15)<0.001151.82 (\u00b11.16)137.49 (\u00b10.96)<0.001
Sitting Diastolic BP (mmHg)*85.69 (\u00b10.77)80.79 (\u00b10.73)<0.00189.27 (\u00b10.92)82.67 (\u00b10.51)<0.001
Antihypertensive Therapy (%)50.9%28.4%<0.00160.4%29.4%<0.001
Total Cholesterol (mmol/L)*5.61 (\u00b10.08)5.70 (\u00b10.08)0.566.04 (\u00b10.1)5.99 (\u00b10.06)0.67
LDL cholesterol (mmol/L)*3.44 (\u00b10.06)3.49 (\u00b10.06)0.523.58 (\u00b1 0.06)3.54 (\u00b1 0.04)0.66
HDL cholesterol (mmol/L)*1.03 (\u00b10.63)1.27 (\u00b10.02)<0.0011.20 (\u00b1 0.02)1.48 (\u00b10.016)<0.001
Triglycerides (mmol/L)\u20202.10 (1.63; 2.64)1.32 (0.98; 1.57)<0.0012.15 (1.78; 2.83)1.24 (0.97; 1.56)<0.001
Diabetes mellitus (%)30.7%6.3%<0.00133.3%2.3%<0.001
BMI (kg/m2)*29.88 (\u00b10.35)26.06 (\u00b10.2)<0.00132.39 (\u00b10.47)26.95 (\u00b10.25)<0.001
ApoA1 (g/L)*1.29 (\u00b10.013)1.40 (\u00b10.017)<0.0011.44 (\u00b10.02)1.55 (\u00b10.001)<0.001
ApoB (g/L)*1.21 (\u00b10.02)1.19 (\u00b10.02)0.481.23 (\u00b10.02)1.18 (\u00b10.014)0.044
Homa index\u20202.25(1.15; 4.18)0.94(0.51; 1.8)<0.0012.51 (1.67; 3.86)1.14 (0.72; 1.7)<0.001
IMTccMean (mm)*0.79 (\u00b10.15)0.76 (\u00b10.12)0.0840.77 (\u00b10.16)0.69 (\u00b10.13)<0.001
Sum of total plaque area (mm2)\u202053 (25; 103)42 (10;72)0.00216 (1; 44)8 (1;32)0.01
Sum of plaque area carotids (mm2)\u202022 (1; 39)12 (1; 27.5)0.0118.75 (1;25.75)1 (1; 19)0.013
Sum of plaque area femoral (mm2)\u202033(10; 62)23(1; 49)0.0111(1; 17.75)1(1; 6)0.012
", "tag_len": 316, "cell_len_max": 42, "width": 486, "height": 282, "type": "complex"}, "PMC5303243_003_00.png": {"html": "
CharacteristicsTotal (N = 613)MSSA(N = 508)MRSA (N = 105)OR (95%CI)P-value
Age (years)(median, quartiles)72 (66;79)75 (67;81)72 (65;78)N/A0.0048
Gender:Female322 (100.0)214 (82.3)57 (17.7)1.4 (0.93\u20132.16)0.5909
Male291 (100.0)255 (83.5)48 (16.5)
Step aging n (%)0,0849
Young Old311 (100.0)267 (85.9)44 (14.1)1.5 (1.00\u20132.35)
Old Old272 (100.0)219 (80.5)53 (19.5)0.7 (0.49\u20131.13)
Longevity30 (100.0)22 (73.3)8 (26.7)0.6 (0.24\u20131.27)
Disease n (%)<0.0001
PNU47 (100.0)28 (59.6)19 (40.4)0.3 (0.14\u20130.49)
BSI37 (100.0)27 (73.0)10 (27.0)0.5 (0.25\u20131.14)
SSTI416 (100.0)350 (84.1)66 (15.9)1.3 (0.85\u20132.03)
EI62 (100.0)56 (90.3)6 (9.7)1.7 (0.72\u20134.06)
Others51 (100.0)47 (92.2)4 (7.8)2.6 (0.91\u20137.31)
Place of the treatment infections n (%)0.0033
INPATIENTS430 (100.0)352 (81.4)78 (18.1)0.8 (0.49\u20131.26)
LTCF16 (100.0)9 (56.3)7 (43.8)0.3 (0.09\u20130.69)
OUTPATIENTS167 (100.0)147 (88.0)20 (12.0)1.7 (1.03\u20132.92)
Infections treated in hospitals (INPATIENTS N = 430, n (%))
ICU19 (100.0)12 (63.2)7 (36.8)2.8 (1.06\u20137.34)0.014
non-ICU411 (100.0)340 (82.7)71 (17.3)
", "tag_len": 290, "cell_len_max": 63, "width": 486, "height": 316, "type": "complex"}, "PMC4969833_016_01.png": {"html": "
HorizontalNormalVerticalTotal Object
Horizontal383546 (83%)
Normal154762 (87%)
Vertical22111401163 (98%)
", "tag_len": 52, "cell_len_max": 14, "width": 264, "height": 58, "type": "simple"}} -------------------------------------------------------------------------------- /srcs/PubTabNet/src/sample_pred.json: -------------------------------------------------------------------------------- 1 | {"PMC2094709_004_00.png": "\n \n \n \n \n \n \n \n
WeekDuration (min)Intensity (% HRR)Intensity (RPE)
12050 \u2013 609 \u2013 11
22050 \u2013 609 \u2013 11
3 \u2013 52560 \u2013 7011
6 \u2013 83060 \u2013 7011
9 \u2013 113070 \u2013 8011 \u2013 13
12 \u2013 143570 \u2013 8011 \u2013 13
15 & 164075 \u2013 8513 \u2013 15
\n \n ", "PMC2871264_002_00.png": "\n \n \n \n \n \n \n \n
Name of algorithmNotable features
MACS [23]Uses both a control library and local statistics to minimize bias
SICER [14]Designed for detecting diffusely enriched regions; for example, histone modification
PeakSeq [24]Corrects for reference genome mappability and local statistics
SISSRs [25]High resolution, precise identification of binding-site location
F-seq [26]Uses kernel density estimation
\n \n ", "PMC2915972_003_00.png": "\n \n \n \n \n \n \n \n
No of patients
Gender:
Men24
Women26
Age (years):
30-392
40-498
50-5915
60-6916
70-796
\u2265 803
Tumor site:
Bladder4
Breast10
Colorectal4
Exophageal9
Gynecological7
Lung6
Prostate10
Length of interval between baseline and follow-up interview (median)
< 50 days22
\u2265 50 days28
\n \n ", "PMC3160368_005_00.png": "\n \n \n \n \n \n \n \n
Methods (n-mers used)Average Sensitivity of 5-fold cross validation (%)Average Specificity of 5-fold cross validation (%)
FDAFSA (hexamers)84*86*
PromMachine (tetramers)86+81+
\n \n ", "PMC3568059_003_00.png": "\n \n \n \n \n \n \n \n
Participants during the period;
0 to 3 months3 to 6 months6 to 12 months
Characteristicsn=72n=71n=65
Age, years, median (range)73 (50\u201394)73 (47\u201392)73 (47\u201390)
Patients, n (%)
Female33 (46)27 (38)26 (40)
Male39 (54)44 (62)39 (60)
Stroke classification (TOAST), n (%)
Large vessel disease17 (24)18 (25)17 (26)
Small vessel disease21 (29)21 (30)17 (26)
Cardioembolic stroke15 (21)11 (15)11 (17)
Cryptogenic stroke13 (18)14 (20)12 (19)
Intracerebral haemorrhage6 (8)7 (10)8 (12)
Side of feision, n (%)
Right side lesion35 (49)32 (45)28 (43)
Left side lesion37 (51)39 (53)37 (57)
Hypertension47 (65)44 (62)41 (63)
Diabetes mellitus17 (24)18 (25)17 (26)
Results from clinical scales 1\u20137 days after stroke onset
BBS median (range) (n)35 (0\u201356) (n=71)41 (0\u201356) (n=70)41 (0\u201356) (n=46)
M-MAS UAS-IS median (range)45 (12\u201355) (n=65)47 (12\u201355) (n=65)50 (16\u201355) (n=56)
\n \n ", "PMC3707453_006_00.png": "\n \n \n \n \n \n \n \n
Star Magnitude 1Star Magnitude 6Saturation Charge [%]Capacitanc e Linearity [%]
Noise (g)SN at 10Signal (g)Noise (g)SN at 10 No AD [d]
121200498471882358105000099.2
1439604265016101991327223298.6
1552204185017131471919710998.1
1599504185017591301917201897.8
1624004195017841221915957597.6
164550420501801151914925497.5
\n \n ", "PMC3765162_003_01.png": "\n \n \n \n \n \n \n \n
Men (n = 359)Women (n = 412)
Metabolic syndromeMetabolic syndrome
Baseline characteristicsYes (n = 163)No (n = 196)P-valueYes (n = 96)No (n = 316)P value
Age (years)*61.86 (\u00b10.83)60.32 (\u00b10.77)0.1764.96 (\u00b10.88)58.52 (\u00b10.55)<0.001
Sitting Systolic BP (mmHg)*141.34 (\u00b11.27)132.26 (\u00b11.15)<0.001151.82 (\u00b11.16)137.4( (\u00b10.96)<0.001
Stitting Diastolic BP (mmHg)*85.69 (\u00b10.77)80.79 (\u00b10.73)<0.00189.27 (\u00b10.92)82.67 (\u00b10.51)<0.001
Antitypertensive Therapy (%)50.9%28.4%<0.00160.4%29.4%<0.001
Total Cholesterol (mmol/L)*5.61 (\u00b10.08)5.70 (\u00b10.08)0.566.04 (\u00b10.1)5.99 (\u00b10.06)0.67
LDL cholesterol (mmol/L)*3.44 (\u00b10.06)3.49 (\u00b10.06)0.523.58 (\u00b1 0.06)3.54 (\u00b1 0.04)0.66
HDL cholesterol (mmol/L)*1.03 (\u00b10.63)1.27 (\u00b10.02)<0.0011.20 (\u00b1 0.02)1.48 (\u00b10.016)<0.001
Triglycerides (mmol/L)*2.10 (1.63; 2.64)1.32 (0.98; 1.57)<0.0012.15 (1.78; 2.83)1.24 (0.97; 1.56)<0.001
Diabetes mellitus (%)30.7%6.3%<0.00133.3%2.3%<0.001
BMI (kg/m2)*29.88 (\u00b10.35)26.06 (\u00b10.2)<0.00122.39 (\u00b10.47)26.95 (\u00b10.25)<0.001
ApoA1 Ig/L*1.29 (\u00b10.013)1.40 (\u00b10.017)<0.0011.44 (\u00b10.02)1.55 (\u00b10.001)<0.001
ApoB (g/L)*1.21 (\u00b10.02)1.19 (\u00b10.02)0.481.23 (\u00b10.02)1.18 (\u00b10.014)0.044
Homa index*2.25(1.15; 4.18)0.94(0.51; 1.8)<0.0012.51 (1.67, 3.86)1.14 (0.72; 1.7)<0.001
MITCoffean (mm)*0.79 (\u00b10.15)0.76 (\u00b10.12)0.0840.77 (\u00b10.16)0.69 (\u00b10.13)<0.001
Sum of total plaque area (mm2)*53 (25; 100)42 (10/27)0.00216 (1; 44)8 (1;32)0.01
Sum of plaque area carotids (mm2)*22 (1; 39)12 (1; 27.5)0.0118.75 (1.25.75)1 (1; 19)0.013
Sum of plaque area femoral (mm3)*33(10 6,0)23(1, 49)0.01110 (-17.75)1(1; 6)0.012
\n \n ", "PMC3872294_001_00.png": "\n \n \n \n \n \n \n \n
HC (N = 20)FASD (N = 15)
Age (years)16.3 (2.1)15.3 (2.1)
IQ108 (15)*80 (15)*
Male/female (%male)12/8 (60%)10/5 (67%)
FASD sub diagnosis\u20138 FAS, 7 ARND
\n \n ", "PMC4196076_004_00.png": "\n \n \n \n \n \n \n \n
miRNAChange relative to controlsDirection of regulationChromosomemiRNAChange relative to controlsDirection of regulationChromosome
hsa-miR-11812.13Up19hsa-miR-8742.97Up5
hsa-miR-125a-5p5.04Up19hsa-miR-8902.83UpX
hsa-miR-21-3p2.82Up17hsa-miR-9392.59Up8
hsa-miR-29b-1-pp3.12Up7hsa-miR-1290\u22127.56Down1
hsa-miR-3665-3p2.19Up10hsa-miR-191-3-p\u22122.63Down10
hsa-miR-1327-5p2.01Up2hsa-miR-2861\u22123.31Down9
hsa-miR-3665-3p2.03Up10hsa-miR-3665\u22122.37Down13
hsa-miR-371a-5p3.14Up19hsa-miR-4357\u22123.62Down1
hsa-miR-43272.95Up21hsa-miR-452-5p\u22122.54DownX
hsa-miR-584-5p2.31Up5hsa-miR-513a-5p\u22123.15DownX
hsa-miR-6025.74Up9hsa-miR-572\u22125.80Down4
hsa-miR-629-3p2.71Up15hsa-miR-629-3p\u22123.03Down15
hsa-miR-642b-3p2.10Up19hsa-miR-165\u22127.18Down1
hsa-miR-6513.91UpXhsa-miR-875-5p\u22123.91Down8
hsa-miR-7622.84Up16hsa-miR-940\u22122.31Down16
\n \n ", "PMC4219599_004_00.png": "\n \n \n \n \n \n \n \n
SBE (n = 24)MEA 7n = 24Evele N = 24
Ethnopositive data
Age (yrs)0.1 (0)0.1 (0)43.9 (8)
Male (%)0.3 (0.0)0.1 (0.0)8.1 (10%)
Married0.1 (0.9)0.9 (0%)8.9 (11)
Married29.6 (4.3)27.0 (0.0)27.9 (161)
Preventions Fathers
1 + 11.0 (1%)5 (21%)5.2 (2.8)
1 + 15 (5.9%)1 (1.9%)8 (18%)
4 + 15 (5.9%)11 (5%)21 (69%)
4 + 13 (33%)1 (4%)3 (19%)
41 + 10 (0%)1 (4%)1 (1%)
Others increase stage
CT14 (6.4%)11 (54%)11 (52%)
-715 (5%)0 (0%)0 (0%)
CT25 (5%)0 (0%)0 (0%)
Private wound with schools0 (0%)0 (0%)0 (0%)
Non-sensitive factors40.2 (11.4)41.2 (13.3)45.0 (12.0)
Non-sensitive factors
None1 (11%)1 (13%)6 (18%)
None2 (2.9%)2 (9%)4 (1.7%)
None2 (2.9%)0 (0%)4 (1.9%)
Total survivor0 (0%)0 (0%)0 (0%)
Primary experience8 (9%)23 (80%)*0.0 (0%)
Postoperative followsors
1 + 01 (11%)1 (13%)4 (12%)
1 + 06 (6.7%)15 (57%)21 (61%)
4 + 12 (2%)5 (37%)2 (2%)
4 + 18 (29%)8 (29%)0 (0%)
Pathological survour stage
PT38 (37%)16 (39%)24 (17%)
PT38 (37%)6 (3%)5 (17%)
PT30 (0%)6 (3%)4 (3%)
Positive17 (14%)6 (3%)1 (0.1%)
Positive nempl nodes17 (14%)0.9 (0%)1.0 (1%)
Positive reference in complete hospital stay (n)2.0 (0.4)2.0 (0.2)2.2 (0.3)
Position of pressoreation compression (%)10.5 (10)4.4 (14)8.9 (2.2)
Duration of pressoreation collectivation (%)10.5 (10)8.4 (14)8.9 (9.2)
\n \n ", "PMC4297392_007_00.png": "\n \n \n \n \n \n \n \n
Treatment phaseAdverse eventNo. of patients
T1Swelling1
Itching1
Fever4
Throat infection1
Chest Congestion2
Total9
T2Diarrhea1
Body Pain1
Total2
T3Diarrhea1
Total1
T4Nil-
\n \n ", "PMC4311460_007_00.png": "\n \n \n \n \n \n \n \n
Number PatientsPatients
CategoryType CHP%(N = 4,560)%
IInflammation 6,98711.33,53777.6
IIInfection 3,6295.92,45153.8
IIIInjury 5,5569.03,40174.6
IVSpecific conditions 32,01651.9n.c.
VNeoplasms 3,5925.82,461#54
Maligne 1,219 (27%)
O,ther-benign2,148 1,758 (39%)
VICongenital 4900.8n.c.
VIIOtherwise 9,38315.2n.c.
TotalALL-types 100
\n \n ", "PMC4357206_002_00.png": "\n \n \n \n \n \n \n \n
N = 121
Demographics
Age (yr) - median (IQR)62 (56-73)
Female sex (%)46 (38)
White race (%)112 (93)
Comorbidities (%)
Hypertension64 (53)
Chronic lung disease37 (31)
Active malignancy34 (28)
Diabetes mellitus29 (24)
Chronic kidney disease7 (6)
Congestive heart failure4 (3)
Chronic liver disease2 (2)
Severity of illness
APACHE II score - median (IQR)*14 (10-16)
Chanlson Comorbidity Index - median (IQR)\u20202 (1-4)
ICU type
Surgical102 (84)
SICU66 (54)
TICU36 (30)
Nonsurgical19 (16)
CCU11 (9)
MICU8 (7)
Status of procedure (for surgical patients) (%)
Elective41 (34)
Urgent57 (47)
Dops in hospital prior to enrollment \u2013 median (IQR)1 (1-3)
\n \n ", "PMC4445578_009_01.png": "\n \n \n \n \n \n \n \n
Reactive astrogliossChanges in astrocytes morphologyChanges in molecules expression
Upregulated moleculesUpregulated or downregulated molecules
Mild to moderate astroglosis\u2022 Hypertrophy of cell body\u2022 Structural elements GFAP, nestin, virenetin\u2022 Inflammatory cell regulators, cytokines, growth factors, glutathione
\u2022 Astrocytes processes are are numeroca and thicker\u2022 Transcriptional regulators STAT3, NFASI (Pechem 1076, cAnP6 Chiga, SOX9 [61-65].Trassopteres and purprs; AQP4 and No YK+ transporters [26,64-69]
\u2022 Glutamate transporter [76-73]
\u2022 The non-overlapping domains of individual astrocytes are preserved\u2022 Vascular regulators: PGE, NO [74,75]
\u2022 Energy provision: lactate [76]
\u2022 Molecules implicated in synapse formation and
\u2022 Remodeling thrombospondin and Complement C1q [77,78]
- Significant extension of processes\u2022 Molecules implicated in ovidative stress, and providing protection from oxidative stress: NO, NOS, SOX, Glutathione [67,68,79]
\u2022 Proliferation
\u2022 Overlapping of individual domains
\u2022 Substantial reorganization of tissue activitecute [50]
\n \n ", "PMC4969833_016_01.png": "\n \n \n \n \n \n \n \n
HorizontalNormalVerticalTotal Object
Horizontal383546 (83%)
Normal154762 (87%)
Vertical22111401163 (98%)
\n \n ", "PMC5303243_003_00.png": "\n \n \n \n \n \n \n \n
CharacteristicsTotal (N = 613)MSSA (N = 508)MRSA (N = 105)OR (95%CI)P-value
Age (years) (median, quartiles)72 (66,79)75 (6731)72 (67,78)N/A0.0048
Gender322 (100.0)214 (82.3)57 (17.7)1.4 (0.93\u20132.16)0.5909
Male291 (100.0)255 (83.5)48 (16.5)
Step aging n (%)0,0849
Young Old311 (100.0)267 (85.9)44 (14.1)1.5 (1.00\u20132.35)
O6: O&272 (100.0)219 (80.5)53 (19.5)0.7 (0.49\u20131.13)
Longevity30 (100.0)22 (73.3)8 (26.7)0.6 (0.24\u20131.27)
Disease n (%)<0.0001
PNU47 (100.0)28 (59.6)19 (40.4)0.3 (0.14\u20130.49)
BSI37 (100.0)27 (73.0)10 (27.0)0.5 (0.25\u20131.14)
SSTI416 (100.0)350 (84.1)66 (15.9)1.3 (0.85\u20132.03)
EI62 (100.0)56 (90.3)6 (9.7)1.7 (0.72\u20134.06)
Others51 (100.0)47 (92.2)4 (7.8)2.6 (0.91\u20137.31)
Place of the treatment infections n (%)0.0033
INPATBENTS430 (100.0)352 (81.4)78 (18.1)0.8 (0.49\u20131.26)
LTCF16 (100.0)9 (56.3)7 (43.8)0.3 (0.09\u20130.69)
OUTPATIENTS167 (100.0)147 (88.0)20 (12.0)1.7 (1.03\u20132.92)
Infections treated in hospitals (NPATIENTS N = 430, n (%))
ICU19 (100.0)12 (63.2)7 (36.8)2.8 (1.06\u20137.34)0.014
non-ICU411 (100.0)340 (82.7)71 (17.3)
\n \n ", "PMC5451934_004_00.png": "\n \n \n \n \n \n \n \n
ConditionPre Well-BeingPost Well-BeingPre-Post-Change
TP (handler & dog interaction)46.33 \u00b1 7.41 148.69 \u00b1 7.22+2.36
DO (dog only interaction)49.78 \u00b1 7.9151.56 \u00b1 6.99+1.78 **
HO (handler only interaction)47.37 \u00b1 7.5746.43 \u00b1 8.03\u22120.94 **
\n \n ", "PMC5755158_010_01.png": "\n \n \n \n \n \n \n \n
WeaningWeek 15Off-test
Weaning\u2013\u2013\u2013
Week 15\u20130.17 \u00b1 0.080.16 \u00b1 0.03
Off-test\u20130.80 \u00b1 0.240.19 \u00b1 0.09
\n \n ", "PMC5849724_006_00.png": "\n \n \n \n \n \n \n \n
AnalytesGC-HRMSGC-MS/MSGC-MS
LOQ (ng/CIPP)Estimated LOQ, (ng/cig)LOQ, (ng/CPP)Estimated LOQ, (ng/cig)LOQ (ng/CIPP)Estimated LOQ, (ng/cig)
Naphthalene0.510.0261178.7158.94108.175.41
Benzolylphenamthene0.040.002NDND66.803.34
Benzolylanthracene0.030.00238.571.9338.111.91
Chrysene0.040.00250.132.5149.612.48
Cyclopentid,culysyner0.020.00148.842.4460.043.00
S-Methylchrysene0.040.002NDND2.480.12
Benzo[p]Iluonarthene0.040.00211.440.575.080.25
Benzol[Illicuranthene0.050.00312.410.625.070.25
Benzo[[aceanthrylene]0.090.005NDNDNDND
Benzoliglyreene0.040.0025.010.253.030.15
Indeno(1,2,1-cultypnee0.020.0015.460.271.540.08
Dibenodju/lipinthe cere0.070.0040.830.041.480.07
Dibenzolip/lyprene0.050.003NDNDNDND
Dibenzolyadyprene0.040.0020.800.040.280.01
Dibenzolyuloyene0.060.0031.330.07NDND
Dibenzolya/hyperene0.070.0042.990.15NDND
\n \n ", "PMC6022086_007_00.png": "\n \n \n \n \n \n \n \n
MethodData TypeMean (m)RMSE (m)P90% (m)PGSD (%)
Improved FCMGaofen-35.775.8910.0794.37
Sentinel-16.305.8314.0380.00
Original FCMGaofen-36.977.6613.8790.70
Sentinel-18.534.8113.1490.00
\n \n "} -------------------------------------------------------------------------------- /srcs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaywalnut310/linear-transformer-for-table-recognition/ffe1a64f0a9cf3b798d995e8a2e8babf16ba2aca/srcs/__init__.py -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "changing-agenda", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib inline\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "from IPython.core.display import display, HTML\n", 13 | "from hydra.experimental import compose, initialize\n", 14 | "import numpy as np\n", 15 | "\n", 16 | "import os\n", 17 | "import math\n", 18 | "import glob\n", 19 | "import json\n", 20 | "from PIL import Image\n", 21 | "import yaml\n", 22 | "import torch\n", 23 | "from torch import nn, optim\n", 24 | "from torch.nn import functional as F\n", 25 | "from torch.utils.data import DataLoader\n", 26 | "\n", 27 | "import utils\n", 28 | "import commons\n", 29 | "from models import TableRecognizer, get_positional_encoding\n", 30 | "import beam_search\n", 31 | "\n", 32 | "import tqdm\n", 33 | "from srcs.PubTabNet.src.metric import TEDS\n", 34 | "from srcs.PubTabNet.src.parallel import parallel_process" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "australian-canberra", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "class ImageLoader(torch.utils.data.Dataset):\n", 45 | " \"\"\"\n", 46 | " Load image\n", 47 | " \"\"\"\n", 48 | " def __init__(self, dir_path, cfg):\n", 49 | " self.cfg = cfg\n", 50 | " self.image_paths, self.lengths, \\\n", 51 | " self.image_heights, self.image_widths = self._build(dir_path)\n", 52 | " self.vocab = self._load_vocab()\n", 53 | "\n", 54 | " def _build(self, dir_path):\n", 55 | " image_paths = glob.glob(os.path.join(dir_path, \"*.png\"))\n", 56 | "\n", 57 | " tuple_list = []\n", 58 | " for image_path in image_paths:\n", 59 | " image = Image.open(image_path)\n", 60 | " w, h = [math.ceil(x / self.cfg.patch_length) for x in image.size]\n", 61 | " tuple_list.append((image_path, h*w, h, w))\n", 62 | " tuple_list.sort(key=lambda x: x[1], reverse=True)\n", 63 | " \n", 64 | " image_paths = []\n", 65 | " lengths = []\n", 66 | " image_heights = []\n", 67 | " image_widths = []\n", 68 | " for image_path, length, h, w in tuple_list:\n", 69 | " image_paths.append(image_path)\n", 70 | " lengths.append(length)\n", 71 | " image_heights.append(h)\n", 72 | " image_widths.append(w)\n", 73 | " return image_paths, lengths, image_heights, image_widths\n", 74 | "\n", 75 | " def _load_vocab(self):\n", 76 | " with open(self.cfg.vocab_path) as f:\n", 77 | " words = [x.replace('\\n', '') for x in f.readlines()]\n", 78 | " vocab = {word: idx for idx, word in enumerate(words)}\n", 79 | " return vocab\n", 80 | "\n", 81 | " def get_items(self, index):\n", 82 | " patch_length = self.cfg.patch_length\n", 83 | " h, w = self.image_heights[index], self.image_widths[index]\n", 84 | " c = 3\n", 85 | "\n", 86 | " image = Image.open(self.image_paths[index]).convert('RGB')\n", 87 | " image = (np.asarray(image, dtype=np.float32) / 255) * 2 - 1\n", 88 | " image = torch.from_numpy(image)\n", 89 | " image = torch.nn.functional.pad(image, [\n", 90 | " 0, 0,\n", 91 | 
" 0, (patch_length - (image.shape[1] % patch_length)) % patch_length,\n", 92 | " 0, (patch_length - (image.shape[0] % patch_length)) % patch_length\n", 93 | " ])\n", 94 | " image = image.view([h, patch_length, w, patch_length, c])\n", 95 | " image = image.permute(0, 2, 4, 1, 3)\n", 96 | " image = image.reshape(h * w, c * (patch_length ** 2))\n", 97 | "\n", 98 | " length = self.lengths[index]\n", 99 | " image_height = h\n", 100 | " image_width = w\n", 101 | " return (image, length, image_height, image_width)\n", 102 | "\n", 103 | " def __getitem__(self, index):\n", 104 | " return self.get_items(index)\n", 105 | "\n", 106 | " def __len__(self):\n", 107 | " return len(self.image_paths)\n", 108 | " \n", 109 | "\n", 110 | "class ImageCollate():\n", 111 | " \"\"\" Zero-pads model inputs\n", 112 | " \"\"\"\n", 113 | " def __call__(self, batch):\n", 114 | " \"\"\"Collate's training batch from image and text info\n", 115 | " Inputs:\n", 116 | " - batch: [img, t_tot, h_img, w_img]\n", 117 | "\n", 118 | " Outputs:\n", 119 | " - (img_padded, mask_img, pos_r, pos_c)\n", 120 | " \"\"\"\n", 121 | " max_len = max(x[1] for x in batch)\n", 122 | " b = len(batch)\n", 123 | " c = batch[0][0].size(1) # image patch size\n", 124 | "\n", 125 | " img_padded = torch.FloatTensor(b, max_len, c)\n", 126 | " mask_img = torch.FloatTensor(b, max_len, 1)\n", 127 | " pos_r = torch.FloatTensor(b, max_len, 1)\n", 128 | " pos_c = torch.FloatTensor(b, max_len, 1)\n", 129 | " \n", 130 | " img_padded.zero_()\n", 131 | " mask_img.zero_()\n", 132 | " pos_r.zero_()\n", 133 | " pos_c.zero_()\n", 134 | " for i in range(b):\n", 135 | " img, t_tot, h_img, w_img = batch[i]\n", 136 | "\n", 137 | " img_padded[i, :t_tot] = img\n", 138 | " mask_img[i, :t_tot] = 1\n", 139 | " pos_r[i, :t_tot] = torch.arange(h_img).unsqueeze(-1).repeat(1, w_img).view(-1, 1)\n", 140 | " pos_c[i, :t_tot] = torch.arange(w_img).repeat(h_img).view(-1, 1)\n", 141 | " return img_padded, mask_img, pos_r, pos_c" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": "better-place", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "def inference(self, x_img, mask_img, pos_r, pos_c, idx_start=1, idx_end=2, max_decode_len=10000, beam_size=1, top_beams=1, alpha=1., n_toks=5000):\n", 152 | " from tqdm import tqdm\n", 153 | " with torch.no_grad():\n", 154 | " b = x_img.size(0)\n", 155 | " nh = self.n_heads\n", 156 | " d = self.hidden_channels // self.n_heads\n", 157 | " dtype = x_img.dtype\n", 158 | " device = x_img.device\n", 159 | "\n", 160 | " x_emb_img = self.emb_img(x_img, mask_img, pos_r, pos_c)\n", 161 | " cache = [{\n", 162 | " \"kv\": [],\n", 163 | " \"k_cum\": []\n", 164 | " } for _ in range(self.n_layers)\n", 165 | " ]\n", 166 | " n_split = max(n_toks // x_img.size(1), 1)\n", 167 | " n_iter = math.ceil(b / n_split)\n", 168 | " for i in range(n_iter):\n", 169 | " print(\"%05d\" % i, end='\\r')\n", 170 | " x_emb_img_iter = x_emb_img[i*n_split:(i+1)*n_split]\n", 171 | " mask_img_iter = mask_img[i*n_split:(i+1)*n_split]\n", 172 | " b_iter = x_emb_img_iter.size(0)\n", 173 | " \n", 174 | " cache_each = [{\n", 175 | " \"kv\": torch.zeros(b_iter, 1, nh, d, d).to(dtype=torch.float, device=device),\n", 176 | " \"k_cum\": torch.zeros(b_iter, 1, nh, d).to(dtype=torch.float, device=device)\n", 177 | " } for _ in range(self.n_layers)\n", 178 | " ]\n", 179 | " _ = self.enc(x_emb_img_iter, mask_img_iter, cache_each)\n", 180 | " for l in range(self.n_layers):\n", 181 | " 
cache[l][\"kv\"].append(cache_each[l][\"kv\"].clone())\n", 182 | " cache[l][\"k_cum\"].append(cache_each[l][\"k_cum\"].clone())\n", 183 | " for l in range(self.n_layers):\n", 184 | " cache[l][\"kv\"] = torch.cat(cache[l][\"kv\"], 0)\n", 185 | " cache[l][\"k_cum\"] = torch.cat(cache[l][\"k_cum\"], 0)\n", 186 | "\n", 187 | " pos_enc = get_positional_encoding(\n", 188 | " torch.arange(max_decode_len).view(1,-1,1).to(device=device), \n", 189 | " self.hidden_channels\n", 190 | " )\n", 191 | " \n", 192 | " if beam_size == 1:\n", 193 | " finished = torch.BoolTensor(b,1).to(device=device).fill_(False)\n", 194 | " idx = torch.zeros(b,1).long().to(device=device) + idx_start\n", 195 | " ids = []\n", 196 | " for i in tqdm(range(max_decode_len)):\n", 197 | " x_emb_txt = self.emb_txt.emb(idx) + pos_enc[:,i:i+1]\n", 198 | " x = self.enc(x_emb_txt, None, cache)\n", 199 | " logit_txt = self.proj_txt(x)\n", 200 | " idx = torch.argmax(logit_txt, -1)\n", 201 | " ids.append(idx)\n", 202 | " finished |= torch.eq(idx, idx_end)\n", 203 | " if torch.all(finished):\n", 204 | " break\n", 205 | " return ids\n", 206 | " else:\n", 207 | " def symbols_to_logits_fn(ids, i, cache):\n", 208 | " x_emb_txt = self.emb_txt.emb(ids[:,i:i+1]) + pos_enc[:,i:i+1]\n", 209 | " x = self.enc(x_emb_txt, None, cache)\n", 210 | " logit_txt = self.proj_txt(x)\n", 211 | " return logit_txt, cache\n", 212 | " initial_ids = torch.zeros(b).long().to(device=device) + idx_start\n", 213 | " decoded_ids, scores = beam_search.beam_search(\n", 214 | " symbols_to_logits_fn,\n", 215 | " initial_ids,\n", 216 | " beam_size,\n", 217 | " max_decode_len,\n", 218 | " self.n_vocab,\n", 219 | " alpha,\n", 220 | " states=cache,\n", 221 | " eos_id=idx_end,\n", 222 | " stop_early=(top_beams == 1))\n", 223 | "\n", 224 | " if top_beams == 1:\n", 225 | " decoded_ids = decoded_ids[:, 0, 1:]\n", 226 | " scores = scores[:, 0]\n", 227 | " else:\n", 228 | " decoded_ids = decoded_ids[:, :top_beams, 1:]\n", 229 | " scores = scores[:, :top_beams]\n", 230 | " return decoded_ids, scores" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "retired-prison", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "class HParams(dict):\n", 241 | " def __getattr__(self, name):\n", 242 | " value = self[name]\n", 243 | " if isinstance(value, dict):\n", 244 | " value = HParams(value)\n", 245 | " return value " 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "funded-seminar", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "model_dir = './outputs/base/'\n", 256 | "with open(os.path.join(model_dir, \".hydra/config.yaml\"), \"r\") as f:\n", 257 | " hps = HParams(yaml.full_load(f))" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "id": "binding-causing", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "dataset = ImageLoader(\"/data/private/datasets/pubtabnet/val\", hps.data)\n", 268 | "collate_fn = ImageCollate()\n", 269 | "loader = DataLoader(dataset, num_workers=8, shuffle=False, pin_memory=False,\n", 270 | " collate_fn=collate_fn, batch_size=2**6)\n", 271 | "\n", 272 | "vocab_inv = {v: k for k, v in dataset.vocab.items()}" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "id": "conscious-linux", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "prefix = ''\n", 283 | "postfix = '
'\n", 284 | "html_strings = []\n", 285 | "with torch.no_grad():\n", 286 | " for i, elms in enumerate(loader):\n", 287 | " print(i)\n", 288 | " (img, mask_img, pos_r, pos_c) = elms\n", 289 | " img = img.cuda()\n", 290 | " mask_img = mask_img.cuda()\n", 291 | " pos_r = pos_r.cuda()\n", 292 | " pos_c = pos_c.cuda()\n", 293 | " \n", 294 | " ret, _ = inference(model, img, mask_img, pos_r, pos_c, beam_size=32, alpha=0.6, max_decode_len=min(10000, math.ceil(3 * img.shape[1])))\n", 295 | " ret = ret.cpu().numpy()\n", 296 | " for j, r in enumerate(ret):\n", 297 | " try:\n", 298 | " eos_pos = list(r).index(2)\n", 299 | " r = r[:eos_pos]\n", 300 | " except:\n", 301 | " pass\n", 302 | " html_string = prefix + \"\".join([vocab_inv[x] for x in r]) + postfix\n", 303 | " html_strings.append(html_string)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "id": "spectacular-backup", 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "with open(\"/data/private/datasets/pubtabnet/annotations/val.json\", \"r\") as f:\n", 314 | " data = json.load(f)\n", 315 | "\n", 316 | "image_names = [x.split(\"/\")[-1] for x in dataset.image_paths]\n", 317 | "pred = {img: txt for img, txt in zip(image_names, html_strings)}\n", 318 | "true = {x['image_path'].split(\"/\")[-1]: \"\".join(x['text'][1:-1]) for x in data if x['image_path'].split(\"/\")[-1] in pred}\n", 319 | "\n", 320 | "teds = TEDS(n_jobs=14)\n", 321 | "\n", 322 | "html_strings_pred = html_strings\n", 323 | "html_strings_tgt = [prefix + true[k] + postfix for k in image_names]\n", 324 | "\n", 325 | "inputs = [{\"pred\": pred, \"true\": true} for pred, true in zip(html_strings_pred, html_strings_tgt)]\n", 326 | "scores = parallel_process(inputs, teds.evaluate, use_kwargs=True, n_jobs=teds.n_jobs, front_num=1)\n", 327 | "np.mean(scores)" 328 | ] 329 | } 330 | ], 331 | "metadata": { 332 | "kernelspec": { 333 | "display_name": "Python 3", 334 | "language": "python", 335 | "name": "python3" 336 | }, 337 | "language_info": { 338 | "codemirror_mode": { 339 | "name": "ipython", 340 | "version": 3 341 | }, 342 | "file_extension": ".py", 343 | "mimetype": "text/x-python", 344 | "name": "python", 345 | "nbconvert_exporter": "python", 346 | "pygments_lexer": "ipython3", 347 | "version": "3.8.3" 348 | } 349 | }, 350 | "nbformat": 4, 351 | "nbformat_minor": 5 352 | } 353 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import hydra 4 | from omegaconf import OmegaConf 5 | import torch 6 | from torch import nn, optim 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.multiprocessing as mp 11 | import torch.distributed as dist 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.cuda.amp import autocast, GradScaler 14 | 15 | import utils 16 | import commons 17 | from models import TableRecognizer 18 | from data_utils import ( 19 | ImageTextLoader, 20 | ImageTextCollate, 21 | DistributedBucketSampler 22 | ) 23 | from losses import ( 24 | loss_fn_img, 25 | loss_fn_txt 26 | ) 27 | 28 | 29 | global_step = 0 30 | 31 | 32 | @hydra.main(config_path='configs/', config_name='linear_transformer') 33 | def main(hps): 34 | """Assume Single Node Multi GPUs Training Only""" 35 | assert torch.cuda.is_available(), "CPU training is not allowed." 
36 | assert OmegaConf.select(hps, "model_dir") != ".hydra", "Please specify model_dir." 37 | print(OmegaConf.to_yaml(hps)) 38 | 39 | n_gpus = torch.cuda.device_count() 40 | os.environ['MASTER_ADDR'] = 'localhost' 41 | os.environ['MASTER_PORT'] = '80000' 42 | 43 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) 44 | 45 | 46 | def run(rank, n_gpus, hps): 47 | global global_step 48 | if rank == 0: 49 | writer = SummaryWriter(log_dir='./') 50 | writer_eval = SummaryWriter(log_dir='./eval') 51 | 52 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 53 | torch.manual_seed(hps.train.seed) 54 | torch.cuda.set_device(rank) 55 | 56 | train_dataset = ImageTextLoader(hps.data.training_file_path, hps.data) 57 | train_sampler = DistributedBucketSampler( 58 | train_dataset, 59 | hps.train.num_tokens, 60 | num_replicas=n_gpus, 61 | rank=rank, 62 | shuffle=True) 63 | collate_fn = ImageTextCollate() 64 | train_loader = DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, 65 | collate_fn=collate_fn, batch_sampler=train_sampler) 66 | if rank == 0: 67 | eval_dataset = ImageTextLoader(hps.data.validation_file_path, hps.data) 68 | eval_sampler = DistributedBucketSampler( 69 | eval_dataset, 70 | hps.train.num_tokens, 71 | num_replicas=1, 72 | rank=rank, 73 | shuffle=True) 74 | eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, pin_memory=False, 75 | collate_fn=collate_fn, batch_sampler=eval_sampler) 76 | 77 | model = TableRecognizer( 78 | len(train_dataset.vocab), 79 | 3 * (hps.data.patch_length ** 2), 80 | **hps.model).cuda(rank) 81 | optim = torch.optim.Adam( 82 | model.parameters(), 83 | hps.train.learning_rate, 84 | betas=hps.train.betas, 85 | eps=hps.train.eps) 86 | model = DDP(model, device_ids=[rank]) 87 | 88 | try: 89 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path('./', "model_*.pth"), model, optim) 90 | global_step = (epoch_str - 1) * len(train_loader) 91 | except: 92 | epoch_str = 1 93 | global_step = 0 94 | 95 | scaler = GradScaler(enabled=hps.train.fp16_run) 96 | 97 | for epoch in range(epoch_str, hps.train.epochs + 1): 98 | if rank==0: 99 | train_and_evaluate(rank, epoch, hps, model, optim, scaler, [train_loader, eval_loader], [writer, writer_eval]) 100 | else: 101 | train_and_evaluate(rank, epoch, hps, model, optim, scaler, [train_loader, None], None) 102 | 103 | 104 | def train_and_evaluate(rank, epoch, hps, model, optim, scaler, loaders, writers): 105 | train_loader, eval_loader = loaders 106 | if writers is not None: 107 | writer, writer_eval = writers 108 | 109 | train_loader.batch_sampler.set_epoch(epoch) 110 | global global_step 111 | 112 | model.train() 113 | for batch_idx, (img, txt, mask_img, mask_txt, pos_r, pos_c, pos_t) in enumerate(train_loader): 114 | img = img.cuda(rank, non_blocking=True) 115 | txt = txt.cuda(rank, non_blocking=True) 116 | mask_img = mask_img.cuda(rank, non_blocking=True) 117 | mask_txt = mask_txt.cuda(rank, non_blocking=True) 118 | pos_r = pos_r.cuda(rank, non_blocking=True) 119 | pos_c = pos_c.cuda(rank, non_blocking=True) 120 | pos_t = pos_t.cuda(rank, non_blocking=True) 121 | 122 | img_i = img[:,:-1] 123 | img_o = img[:,1:] 124 | txt_i = txt[:,:-1] 125 | txt_o = txt[:,1:] 126 | mask_img_i = mask_img[:,:-1] 127 | mask_img_o = mask_img[:,1:] 128 | mask_txt_i = mask_txt[:,:-1] 129 | mask_txt_o = mask_txt[:,1:,0] 130 | 131 | with autocast(enabled=hps.train.fp16_run): 132 | logits_img, logits_txt = model(img_i, txt_i, mask_img_i, mask_txt_i, pos_r, pos_c, 
pos_t) 133 | with autocast(enabled=False): 134 | loss_img = loss_fn_img(logits_img, img_o, mask_img_o) 135 | loss_txt = loss_fn_txt(logits_txt, txt_o, mask_txt_o) 136 | loss_tot = loss_img * hps.train.lamb + loss_txt 137 | optim.zero_grad() 138 | scaler.scale(loss_tot).backward() 139 | scaler.unscale_(optim) 140 | grad_norm = commons.grad_norm(model.parameters()) 141 | scaler.step(optim) 142 | scaler.update() 143 | 144 | if rank==0: 145 | num_tokens = mask_img.sum() + mask_txt.sum() 146 | if global_step % hps.train.log_interval == 0: 147 | lr = optim.param_groups[0]['lr'] 148 | losses = [loss_tot, loss_img, loss_txt] 149 | print('Train Epoch: {} [{:.0f}%]'.format( 150 | epoch, 151 | 100. * batch_idx / len(train_loader))) 152 | print([x.item() for x in losses] + [global_step, lr]) 153 | 154 | scalar_dict = {"loss/total": loss_tot, "loss/img": loss_img, "loss/txt": loss_txt} 155 | scalar_dict.update({"learning_rate": lr, "grad_norm": grad_norm, "num_tokens": num_tokens}) 156 | 157 | utils.summarize( 158 | writer=writer, 159 | global_step=global_step, 160 | scalars=scalar_dict) 161 | 162 | if global_step % hps.train.eval_interval == 0: 163 | print("START: EVAL") 164 | eval_loader.batch_sampler.set_epoch(global_step) 165 | evaluate(hps, model, eval_loader, writer_eval) 166 | utils.save_checkpoint(model, optim, hps.train.learning_rate, epoch, 167 | "model_{}.pth".format(global_step) 168 | ) 169 | print("END: EVAL") 170 | global_step += 1 171 | 172 | if rank == 0: 173 | print('====> Epoch: {}'.format(epoch)) 174 | 175 | 176 | def evaluate(hps, model, eval_loader, writer_eval): 177 | model.eval() 178 | with torch.no_grad(): 179 | for batch_idx, (img, txt, mask_img, mask_txt, pos_r, pos_c, pos_t) in enumerate(eval_loader): 180 | img = img.cuda(0, non_blocking=True) 181 | txt = txt.cuda(0, non_blocking=True) 182 | mask_img = mask_img.cuda(0, non_blocking=True) 183 | mask_txt = mask_txt.cuda(0, non_blocking=True) 184 | pos_r = pos_r.cuda(0, non_blocking=True) 185 | pos_c = pos_c.cuda(0, non_blocking=True) 186 | pos_t = pos_t.cuda(0, non_blocking=True) 187 | 188 | img_i = img[:,:-1] 189 | img_o = img[:,1:] 190 | txt_i = txt[:,:-1] 191 | txt_o = txt[:,1:] 192 | mask_img_i = mask_img[:,:-1] 193 | mask_img_o = mask_img[:,1:] 194 | mask_txt_i = mask_txt[:,:-1] 195 | mask_txt_o = mask_txt[:,1:,0] 196 | 197 | with autocast(enabled=hps.train.fp16_run): 198 | logits_img, logits_txt = model(img_i, txt_i, mask_img_i, mask_txt_i, pos_r, pos_c, pos_t) 199 | with autocast(enabled=False): 200 | loss_img = loss_fn_img(logits_img, img_o, mask_img_o) 201 | loss_txt = loss_fn_txt(logits_txt, txt_o, mask_txt_o) 202 | loss_tot = loss_img * hps.train.lamb + loss_txt 203 | break 204 | 205 | scalar_dict = {"loss/total": loss_tot.item(), "loss/img": loss_img.item(), "loss/txt": loss_txt.item()} 206 | 207 | utils.summarize( 208 | writer=writer_eval, 209 | global_step=global_step, 210 | scalars=scalar_dict, 211 | ) 212 | model.train() 213 | 214 | 215 | if __name__ == "__main__": 216 | main() 217 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import subprocess 4 | import torch 5 | 6 | 7 | def load_checkpoint(checkpoint_path, model, optimizer=None): 8 | assert os.path.isfile(checkpoint_path) 9 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 10 | iteration = checkpoint_dict['iteration'] 11 | learning_rate = checkpoint_dict['learning_rate'] 12 | if 
optimizer is not None: 13 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 14 | saved_state_dict = checkpoint_dict['model'] 15 | if hasattr(model, 'module'): 16 | state_dict = model.module.state_dict() 17 | else: 18 | state_dict = model.state_dict() 19 | new_state_dict= {} 20 | for k, v in state_dict.items(): 21 | try: 22 | new_state_dict[k] = saved_state_dict[k] 23 | except: 24 | print("%s is not in the checkpoint" % k) 25 | new_state_dict[k] = v 26 | if hasattr(model, 'module'): 27 | model.module.load_state_dict(new_state_dict) 28 | else: 29 | model.load_state_dict(new_state_dict) 30 | print("Loaded checkpoint '{}' (iteration {})" .format( 31 | checkpoint_path, iteration)) 32 | return model, optimizer, learning_rate, iteration 33 | 34 | 35 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 36 | print("Saving model and optimizer state at iteration {} to {}".format( 37 | iteration, checkpoint_path)) 38 | if hasattr(model, 'module'): 39 | state_dict = model.module.state_dict() 40 | else: 41 | state_dict = model.state_dict() 42 | torch.save({'model': state_dict, 43 | 'iteration': iteration, 44 | 'optimizer': optimizer.state_dict(), 45 | 'learning_rate': learning_rate}, checkpoint_path) 46 | 47 | 48 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 49 | for k, v in scalars.items(): 50 | writer.add_scalar(k, v, global_step) 51 | for k, v in histograms.items(): 52 | writer.add_histogram(k, v, global_step) 53 | for k, v in images.items(): 54 | writer.add_image(k, v, global_step, dataformats='HWC') 55 | for k, v in audios.items(): 56 | writer.add_audio(k, v, global_step, audio_sampling_rate) 57 | 58 | 59 | def latest_checkpoint_path(dir_path, regex="model_*.pth"): 60 | f_list = glob.glob(os.path.join(dir_path, regex)) 61 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 62 | x = f_list[-1] 63 | print(x) 64 | return x 65 | --------------------------------------------------------------------------------
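
The checkpoint helpers in `utils.py` above are easiest to see end to end in isolation. Below is a minimal sketch, not part of the repository, that round-trips a checkpoint through `save_checkpoint`, `latest_checkpoint_path`, and `load_checkpoint`; the toy `nn.Linear` model and the `./outputs/demo` directory are illustrative placeholders, not names used by the project.

```python
# Minimal sketch: save a checkpoint, locate the newest one, and reload it.
# The nn.Linear model and ./outputs/demo directory are illustrative only.
import os
import torch
from torch import nn

import utils  # the utils.py shown above must be importable

os.makedirs("./outputs/demo", exist_ok=True)

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Signature: save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
utils.save_checkpoint(model, optimizer, 1e-4, 3, "./outputs/demo/model_3.pth")

# Picks the checkpoint whose filename contains the largest step number.
ckpt_path = utils.latest_checkpoint_path("./outputs/demo", "model_*.pth")

# Restores the weights (and optimizer state, if an optimizer is passed) and
# returns the stored learning rate and iteration for resuming training.
model, optimizer, learning_rate, iteration = utils.load_checkpoint(ckpt_path, model, optimizer)
print(learning_rate, iteration)  # -> 0.0001 3
```

This mirrors how `train.py` resumes: it calls `latest_checkpoint_path('./', "model_*.pth")` inside a try/except and falls back to epoch 1 when no checkpoint exists.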